57 changes: 57 additions & 0 deletions tests/test_runners.py
@@ -108,6 +108,63 @@ def test_integrate_double_nvt(
assert not torch.isnan(final_state.energy).any()


def test_integrate_double_nvt_multiple_temperatures(
ar_double_sim_state: SimState, lj_model: LennardJonesModel
) -> None:
"""Test NVT integration with LJ potential."""
n_steps = 5
_ = ts.integrate(
system=ar_double_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=[100.0, 200.0], # K
timestep=0.001, # ps
init_kwargs=dict(seed=481516),
)

batcher = ts.autobatching.BinningAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=ar_double_sim_state[0].n_atoms,
)
_ = ts.integrate(
system=ar_double_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=[100.0, 200.0], # K
timestep=0.001, # ps
autobatcher=batcher,
init_kwargs=dict(seed=481516),
)

# Temperature tensor with correct shape (n_steps, n_systems)
_ = ts.integrate(
system=ar_double_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1),
timestep=0.001, # ps
autobatcher=batcher,
init_kwargs=dict(seed=481516),
)

# Temperature tensor with incorrect shape (n_systems, n_steps)
with pytest.raises(ValueError, match="first dimension must be n_steps"):
_ = ts.integrate(
system=ar_double_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_langevin,
n_steps=n_steps,
temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1).T, # K
timestep=0.001, # ps
autobatcher=batcher,
init_kwargs=dict(seed=481516),
)


def test_integrate_double_nvt_with_reporter(
ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path
) -> None:
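As a usage sketch (hypothetical values, plain torch, not part of the test suite above), the 2D temperature form exercised in these tests can also encode a per-system temperature ramp, one row per step:

import torch

n_steps = 5
starts, ends = [100.0, 200.0], [300.0, 400.0]  # per-system start/end temperatures, K

# Linear ramp with shape (n_steps, n_systems)
ramp = torch.stack([torch.linspace(s, e, n_steps) for s, e in zip(starts, ends)], dim=1)
print(ramp.shape)  # torch.Size([5, 2])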
4 changes: 3 additions & 1 deletion torch_sim/autobatching.py
@@ -578,7 +578,9 @@ def load_states(self, states: T | Sequence[T]) -> float:
self.index_to_scaler = dict(enumerate(self.memory_scalers))
self.index_bins = to_constant_volume_bins(
self.index_to_scaler, max_volume=self.max_memory_scaler
)
        )  # list of dicts, each mapping original_index (int) -> memory_scaler (float)
# Convert to list of lists of indices
self.index_bins = [list(batch.keys()) for batch in self.index_bins]
self.batched_states = []
for index_bin in self.index_bins:
self.batched_states.append([self.state_slices[idx] for idx in index_bin])
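For illustration, a minimal sketch (with made-up indices and memory scalers, not the real bin-packing output) of the list-of-dicts to list-of-lists conversion performed in load_states above:

# Hypothetical bins: each dict maps an original state index to its memory scaler.
index_bins = [{0: 32.0, 2: 30.0}, {1: 64.0}]

# Keep only the original indices, one list per bin.
index_bins = [list(batch.keys()) for batch in index_bins]
print(index_bins)  # [[0, 2], [1]]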
100 changes: 88 additions & 12 deletions torch_sim/runners.py
@@ -100,6 +100,76 @@ def _configure_batches_iterator(
return batches


def _normalize_temperature_tensor(
temperature: float | list | torch.Tensor, n_steps: int, initial_state: SimState
) -> torch.Tensor:
"""Turn the temperature into a tensor of shape (n_steps,) or (n_steps, n_systems).

Args:
temperature (float | int | list | torch.Tensor): Temperature input
n_steps (int): Number of integration steps
        initial_state (SimState): Initial simulation state for dtype and device

    Returns:
torch.Tensor: Normalized temperature tensor
"""
# ---- Step 1: Convert to tensor ----
if isinstance(temperature, (float, int)):
return torch.full(
(n_steps,),
float(temperature),
dtype=initial_state.dtype,
device=initial_state.device,
)

# Convert list or tensor input to tensor
if isinstance(temperature, list):
temps = torch.tensor(
temperature, dtype=initial_state.dtype, device=initial_state.device
)
elif isinstance(temperature, torch.Tensor):
temps = temperature.to(dtype=initial_state.dtype, device=initial_state.device)
else:
raise TypeError(
f"Invalid temperature type: {type(temperature).__name__}. "
"Must be float, int, list, or torch.Tensor."
)

# ---- Step 2: Determine how to broadcast ----
    temps = torch.atleast_1d(temps)
    if temps.ndim > 2:
        raise ValueError(f"Temperature tensor must be 1D or 2D, got shape {temps.shape}.")

    if temps.ndim == 2:
        # 2D input must already be (n_steps, n_systems)
        if temps.shape[0] != n_steps:
            raise ValueError(
                "If temperature tensor is 2D, first dimension must be n_steps."
            )
        return temps

    # 1D input from here on
    if temps.shape[0] == 1:
        # A single value in a 1-element list/tensor
        return temps.repeat(n_steps)

    if initial_state.n_systems == n_steps and temps.shape[0] == n_steps:
        warnings.warn(
            "n_systems is equal to n_steps. Interpreting temperature array of length "
            "n_systems as temperatures for each system, broadcast over all steps.",
            stacklevel=2,
        )

    if temps.shape[0] == initial_state.n_systems:
        # Per-system temperatures -> broadcast over steps
        return temps.unsqueeze(0).expand(n_steps, -1)  # (n_steps, n_systems)

    if temps.shape[0] == n_steps:
        return temps  # one temperature per step: (n_steps,)

    raise ValueError(
        f"Temperature length ({temps.shape[0]}) must be one of:\n"
        f"  - n_steps ({n_steps}),\n"
        f"  - n_systems ({initial_state.n_systems}), or\n"
        f"  - 1 (a scalar)."
    )


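A minimal sketch of the broadcasting rules implemented above, using hypothetical values (n_steps=5, n_systems=2) and plain torch rather than the helper itself:

import torch

n_steps, n_systems = 5, 2

# scalar input -> one temperature per step, shared by all systems
torch.full((n_steps,), 300.0).shape                                   # (5,)

# 1D input of length n_systems -> broadcast over steps
torch.tensor([100.0, 200.0]).unsqueeze(0).expand(n_steps, -1).shape   # (5, 2)

# 2D input of shape (n_steps, n_systems) -> passed through unchanged
torch.rand(n_steps, n_systems).shape                                  # (5, 2)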
def integrate[T: SimState]( # noqa: C901
system: StateLike,
model: ModelInterface,
@@ -123,7 +193,11 @@ def integrate[T: SimState]( # noqa: C901
(init_func, step_func) functions.
n_steps (int): Number of integration steps
        temperature (float | ArrayLike): Temperature or array of temperatures for each
            step or system:
            - float: used for all steps and systems
            - 1D array of length n_steps: one temperature per step
            - 1D array of length n_systems: one temperature per system
            - 2D array (n_steps, n_systems): one temperature per step and system.
timestep (float): Integration time step
trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for
tracking trajectory. If a dict, will be passed to the TrajectoryReporter
@@ -140,18 +214,11 @@ def integrate[T: SimState]( # noqa: C901
T: Final state after integration
"""
unit_system = UnitSystem.metal
# create a list of temperatures
temps = (
[temperature] * n_steps
if isinstance(temperature, (float, int))
else list(temperature)
)
if len(temps) != n_steps:
raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}")

initial_state: SimState = ts.initialize_state(system, model.device, model.dtype)
dtype, device = initial_state.dtype, initial_state.device
kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature
kTs = _normalize_temperature_tensor(temperature, n_steps, initial_state)
kTs = kTs * unit_system.temperature
dt = torch.tensor(timestep * unit_system.time, dtype=dtype, device=device)

# Handle both string names and direct function tuples
@@ -192,7 +259,12 @@ def integrate[T: SimState]( # noqa: C901
# Handle both BinningAutoBatcher and list of tuples
for state, system_indices in batch_iterator:
# Pass correct parameters based on integrator type
state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **init_kwargs or {})
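        # If kTs is 2D (n_steps, n_systems), keep only this batch's system columns;
        # a 1D per-step kTs applies to every system and is passed through unchanged.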
batch_kT = (
kTs[:, system_indices] if (system_indices and len(kTs.shape) == 2) else kTs
)
state = init_func(
state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {}
)

# set up trajectory reporters
if autobatcher and trajectory_reporter is not None and og_filenames is not None:
@@ -204,7 +276,11 @@ def integrate[T: SimState]( # noqa: C901
# run the simulation
for step in range(1, n_steps + 1):
state = step_func(
state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs
state=state,
model=model,
dt=dt,
kT=batch_kT[step - 1],
**integrator_kwargs,
)

if trajectory_reporter:
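A minimal sketch (hypothetical shapes: n_steps=5, three systems total, a batch holding systems 0 and 2) of how batch_kT is sliced per batch and indexed per step in the loop above:

import torch

n_steps, system_indices = 5, [0, 2]

kT_1d = torch.full((n_steps,), 300.0)   # per-step kT, shared by all systems
kT_2d = torch.rand(n_steps, 3)          # per-step, per-system kT for 3 systems

batch_kT = kT_2d[:, system_indices]     # columns for this batch -> shape (5, 2)

print(kT_1d[0].shape)                   # torch.Size([]) -> scalar kT for the batch
print(batch_kT[0].shape)                # torch.Size([2]) -> one kT per system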