From 4157a20cec17069de2da1a46d1ca6f25ef890a70 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 18 Sep 2025 13:57:07 +0000 Subject: [PATCH 1/8] fix:orb squeeze incorrect energy shape --- torch_sim/models/orb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index fd65b23f..132f6d5c 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -416,7 +416,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].squeeze() + results[prop] = predictions[_property] if self.conservative: results["forces"] = results[self.model.grad_forces_name] From 0cbdbc61559ead5690ab34fbabe6e7ad37bddee4 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 27 Nov 2025 15:59:32 +0100 Subject: [PATCH 2/8] feat/different temperatures per systems ts.integrate --- tests/test_runners.py | 15 ++++++++ torch_sim/runners.py | 83 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/tests/test_runners.py b/tests/test_runners.py index c9b39f4b..71d0d151 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -108,6 +108,21 @@ def test_integrate_double_nvt( assert not torch.isnan(final_state.energy).any() +def test_integrate_double_nvt_multiple_temperatures( + ar_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """Test NVT integration with LJ potential.""" + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=10, + temperature=[100.0, 200.0], # K + timestep=0.001, # ps + init_kwargs=dict(seed=481516), + ) + + def test_integrate_double_nvt_with_reporter( ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b7fe73cb..d3c8c278 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -100,6 +100,67 @@ def _configure_batches_iterator( return batches +def _normalize_temperature_tensor( + temperature: float | list | torch.Tensor, n_steps: int, initial_state: SimState +) -> torch.Tensor: + """Turn the temperature into a tensor of shape (n_steps,) or (n_steps, n_systems). + + Args: + temperature (float | int | list | torch.Tensor): Temperature input + n_steps (int): Number of integration steps + initial_state (SimState): Initial simulation state for dtype and device + Returns: + torch.Tensor: Normalized temperature tensor + """ + # ---- Step 1: Convert to tensor ---- + if isinstance(temperature, (float, int)): + return torch.full( + (n_steps,), + float(temperature), + dtype=initial_state.dtype, + device=initial_state.device, + ) + + # Convert list or tensor input to tensor + if isinstance(temperature, list): + temps = torch.tensor( + temperature, dtype=initial_state.dtype, device=initial_state.device + ) + elif isinstance(temperature, torch.Tensor): + temps = temperature.to(dtype=initial_state.dtype, device=initial_state.device) + else: + raise TypeError( + f"Invalid temperature type: {type(temperature).__name__}. " + "Must be float, int, list, or torch.Tensor." + ) + + # ---- Step 2: Determine how to broadcast ---- + temps = torch.atleast_1d(temps) + if temps.ndim > 2: + raise ValueError(f"Temperature tensor must be 1D or 2D, got shape {temps.shape}.") + + if temps.shape[0] == 1: + # A single value in a 1-element list/tensor + return temps.repeat(n_steps) + + # This assumes that in case n_systems == n_steps, the user wants to apply + # different temperatures per system, not per step. + if temps.shape[0] == initial_state.n_systems: + # Interpret as single-step multi-system temperatures → broadcast over steps + return temps.unsqueeze(0).expand(n_steps, -1) # (n_steps, n_systems) + + if temps.shape[0] == n_steps: + return temps # already good: (n_steps,) or (n_steps, n_systems) + + raise ValueError( + f"Temperature length ({temps.shape[0]}) must be either:\n" + f" - n_steps ({n_steps}), or\n" + f" - n_systems ({initial_state.n_systems}), or\n" + f" - 1 (scalar),\n" + f"but got {temps.shape[0]}." + ) + + def integrate[T: SimState]( # noqa: C901 system: StateLike, model: ModelInterface, @@ -140,18 +201,11 @@ def integrate[T: SimState]( # noqa: C901 T: Final state after integration """ unit_system = UnitSystem.metal - # create a list of temperatures - temps = ( - [temperature] * n_steps - if isinstance(temperature, (float, int)) - else list(temperature) - ) - if len(temps) != n_steps: - raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}") initial_state: SimState = ts.initialize_state(system, model.device, model.dtype) dtype, device = initial_state.dtype, initial_state.device - kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature + kTs = _normalize_temperature_tensor(temperature, n_steps, initial_state) + kTs = kTs * unit_system.temperature dt = torch.tensor(timestep * unit_system.time, dtype=dtype, device=device) # Handle both string names and direct function tuples @@ -192,7 +246,10 @@ def integrate[T: SimState]( # noqa: C901 # Handle both BinningAutoBatcher and list of tuples for state, system_indices in batch_iterator: # Pass correct parameters based on integrator type - state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **init_kwargs or {}) + batch_kT = kTs[:, system_indices] if (system_indices and kTs.shape == 2) else kTs + state = init_func( + state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {} + ) # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: @@ -204,7 +261,11 @@ def integrate[T: SimState]( # noqa: C901 # run the simulation for step in range(1, n_steps + 1): state = step_func( - state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs + state=state, + model=model, + dt=dt, + kT=batch_kT[step - 1], + **integrator_kwargs, ) if trajectory_reporter: From fbb219de78a0b76c625da07ef78061ced9f86788 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 27 Nov 2025 16:02:59 +0100 Subject: [PATCH 3/8] docs --- torch_sim/runners.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index d3c8c278..3f874580 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -184,7 +184,11 @@ def integrate[T: SimState]( # noqa: C901 (init_func, step_func) functions. n_steps (int): Number of integration steps temperature (float | ArrayLike): Temperature or array of temperatures for each - step + step or system: + Float: used for all steps and systems + 1D array of length n_steps: used for each step + 1D array of length n_systems: used for each system + 2D array of shape (n_steps, n_systems): used for each step and system. timestep (float): Integration time step trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking trajectory. If a dict, will be passed to the TrajectoryReporter From a4994e1448f6e648a15e10fb312315229064d948 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 28 Nov 2025 12:01:44 +0100 Subject: [PATCH 4/8] fix if condition --- torch_sim/runners.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 3f874580..135c5f6e 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -250,7 +250,9 @@ def integrate[T: SimState]( # noqa: C901 # Handle both BinningAutoBatcher and list of tuples for state, system_indices in batch_iterator: # Pass correct parameters based on integrator type - batch_kT = kTs[:, system_indices] if (system_indices and kTs.shape == 2) else kTs + batch_kT = ( + kTs[:, system_indices] if (system_indices and len(kTs.shape) == 2) else kTs + ) state = init_func( state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {} ) From a431db4434e8af40a54d88c5e6b80e2b8c2be331 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 28 Nov 2025 12:02:04 +0100 Subject: [PATCH 5/8] modify index bins in BinningAutoBatcher to only keep indices --- torch_sim/autobatching.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 65bbbe12..cd6fca01 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -578,7 +578,9 @@ def load_states(self, states: T | Sequence[T]) -> float: self.index_to_scaler = dict(enumerate(self.memory_scalers)) self.index_bins = to_constant_volume_bins( self.index_to_scaler, max_volume=self.max_memory_scaler - ) + ) # list[dict[original_index: int, memory_scale:float]] + # Convert to list of lists of indices + self.index_bins = [list(batch.keys()) for batch in self.index_bins] self.batched_states = [] for index_bin in self.index_bins: self.batched_states.append([self.state_slices[idx] for idx in index_bin]) From 6dcb883049321e0aa7a534f6f5eb71d6e41b5b03 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 28 Nov 2025 12:11:39 +0100 Subject: [PATCH 6/8] test with binning autobatcher --- tests/test_runners.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_runners.py b/tests/test_runners.py index 71d0d151..f165f6f1 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -122,6 +122,22 @@ def test_integrate_double_nvt_multiple_temperatures( init_kwargs=dict(seed=481516), ) + batcher = ts.autobatching.BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=ar_double_sim_state[0].n_atoms, + ) + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=10, + temperature=[100.0, 200.0], # K + timestep=0.001, # ps + autobatcher=batcher, + init_kwargs=dict(seed=481516), + ) + def test_integrate_double_nvt_with_reporter( ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path From bd9507a4e14f0d7ba60eb818153aa7b476f2fb51 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 28 Nov 2025 12:16:32 +0100 Subject: [PATCH 7/8] add warning for a rare edge case --- torch_sim/runners.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 135c5f6e..5357e1dd 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -143,8 +143,13 @@ def _normalize_temperature_tensor( # A single value in a 1-element list/tensor return temps.repeat(n_steps) - # This assumes that in case n_systems == n_steps, the user wants to apply - # different temperatures per system, not per step. + if initial_state.n_systems == n_steps: + warnings.warn( + "n_systems is equal to n_steps. Interpreting temperature array of length " + "n_systems as temperatures for each system, broadcasted over steps.", + stacklevel=2, + ) + if temps.shape[0] == initial_state.n_systems: # Interpret as single-step multi-system temperatures → broadcast over steps return temps.unsqueeze(0).expand(n_steps, -1) # (n_steps, n_systems) From a314cdef088bc3c8ced9bf2f9aba7da6a1bc111d Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 28 Nov 2025 23:00:43 +0100 Subject: [PATCH 8/8] add test and add error in case temps is 2D to enforce (n_systems, n_steps) --- tests/test_runners.py | 30 ++++++++++++++++++++++++++++-- torch_sim/runners.py | 4 ++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/test_runners.py b/tests/test_runners.py index f165f6f1..7968a15d 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -112,11 +112,12 @@ def test_integrate_double_nvt_multiple_temperatures( ar_double_sim_state: SimState, lj_model: LennardJonesModel ) -> None: """Test NVT integration with LJ potential.""" + n_steps = 5 _ = ts.integrate( system=ar_double_sim_state, model=lj_model, integrator=ts.Integrator.nvt_langevin, - n_steps=10, + n_steps=n_steps, temperature=[100.0, 200.0], # K timestep=0.001, # ps init_kwargs=dict(seed=481516), @@ -131,13 +132,38 @@ def test_integrate_double_nvt_multiple_temperatures( system=ar_double_sim_state, model=lj_model, integrator=ts.Integrator.nvt_langevin, - n_steps=10, + n_steps=n_steps, temperature=[100.0, 200.0], # K timestep=0.001, # ps autobatcher=batcher, init_kwargs=dict(seed=481516), ) + # Temperature tensor with correct shape (n_steps, n_systems) + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=n_steps, + temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1), + timestep=0.001, # ps + autobatcher=batcher, + init_kwargs=dict(seed=481516), + ) + + # Temperature tensor with incorrect shape (n_systems, n_steps) + with pytest.raises(ValueError, match="first dimension must be n_steps"): + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=n_steps, + temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1).T, # K + timestep=0.001, # ps + autobatcher=batcher, + init_kwargs=dict(seed=481516), + ) + def test_integrate_double_nvt_with_reporter( ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 5357e1dd..8b39f55f 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -151,6 +151,10 @@ def _normalize_temperature_tensor( ) if temps.shape[0] == initial_state.n_systems: + if temps.ndim == 2: + raise ValueError( + "If temperature tensor is 2D, first dimension must be n_steps." + ) # Interpret as single-step multi-system temperatures → broadcast over steps return temps.unsqueeze(0).expand(n_steps, -1) # (n_steps, n_systems)