diff --git a/tests/test_runners.py b/tests/test_runners.py index c9b39f4b..7968a15d 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -108,6 +108,63 @@ def test_integrate_double_nvt( assert not torch.isnan(final_state.energy).any() +def test_integrate_double_nvt_multiple_temperatures( + ar_double_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """Test NVT integration with LJ potential.""" + n_steps = 5 + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=n_steps, + temperature=[100.0, 200.0], # K + timestep=0.001, # ps + init_kwargs=dict(seed=481516), + ) + + batcher = ts.autobatching.BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=ar_double_sim_state[0].n_atoms, + ) + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=n_steps, + temperature=[100.0, 200.0], # K + timestep=0.001, # ps + autobatcher=batcher, + init_kwargs=dict(seed=481516), + ) + + # Temperature tensor with correct shape (n_steps, n_systems) + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=n_steps, + temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1), + timestep=0.001, # ps + autobatcher=batcher, + init_kwargs=dict(seed=481516), + ) + + # Temperature tensor with incorrect shape (n_systems, n_steps) + with pytest.raises(ValueError, match="first dimension must be n_steps"): + _ = ts.integrate( + system=ar_double_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=n_steps, + temperature=torch.tensor([100.0, 200.0])[None, :].repeat(n_steps, 1).T, # K + timestep=0.001, # ps + autobatcher=batcher, + init_kwargs=dict(seed=481516), + ) + + def test_integrate_double_nvt_with_reporter( ar_double_sim_state: SimState, lj_model: LennardJonesModel, tmp_path: Path ) -> None: diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 65bbbe12..cd6fca01 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -578,7 +578,9 @@ def load_states(self, states: T | Sequence[T]) -> float: self.index_to_scaler = dict(enumerate(self.memory_scalers)) self.index_bins = to_constant_volume_bins( self.index_to_scaler, max_volume=self.max_memory_scaler - ) + ) # list[dict[original_index: int, memory_scale:float]] + # Convert to list of lists of indices + self.index_bins = [list(batch.keys()) for batch in self.index_bins] self.batched_states = [] for index_bin in self.index_bins: self.batched_states.append([self.state_slices[idx] for idx in index_bin]) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b7fe73cb..8b39f55f 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -100,6 +100,76 @@ def _configure_batches_iterator( return batches +def _normalize_temperature_tensor( + temperature: float | list | torch.Tensor, n_steps: int, initial_state: SimState +) -> torch.Tensor: + """Turn the temperature into a tensor of shape (n_steps,) or (n_steps, n_systems). + + Args: + temperature (float | int | list | torch.Tensor): Temperature input + n_steps (int): Number of integration steps + initial_state (SimState): Initial simulation state for dtype and device + Returns: + torch.Tensor: Normalized temperature tensor + """ + # ---- Step 1: Convert to tensor ---- + if isinstance(temperature, (float, int)): + return torch.full( + (n_steps,), + float(temperature), + dtype=initial_state.dtype, + device=initial_state.device, + ) + + # Convert list or tensor input to tensor + if isinstance(temperature, list): + temps = torch.tensor( + temperature, dtype=initial_state.dtype, device=initial_state.device + ) + elif isinstance(temperature, torch.Tensor): + temps = temperature.to(dtype=initial_state.dtype, device=initial_state.device) + else: + raise TypeError( + f"Invalid temperature type: {type(temperature).__name__}. " + "Must be float, int, list, or torch.Tensor." + ) + + # ---- Step 2: Determine how to broadcast ---- + temps = torch.atleast_1d(temps) + if temps.ndim > 2: + raise ValueError(f"Temperature tensor must be 1D or 2D, got shape {temps.shape}.") + + if temps.shape[0] == 1: + # A single value in a 1-element list/tensor + return temps.repeat(n_steps) + + if initial_state.n_systems == n_steps: + warnings.warn( + "n_systems is equal to n_steps. Interpreting temperature array of length " + "n_systems as temperatures for each system, broadcasted over steps.", + stacklevel=2, + ) + + if temps.shape[0] == initial_state.n_systems: + if temps.ndim == 2: + raise ValueError( + "If temperature tensor is 2D, first dimension must be n_steps." + ) + # Interpret as single-step multi-system temperatures → broadcast over steps + return temps.unsqueeze(0).expand(n_steps, -1) # (n_steps, n_systems) + + if temps.shape[0] == n_steps: + return temps # already good: (n_steps,) or (n_steps, n_systems) + + raise ValueError( + f"Temperature length ({temps.shape[0]}) must be either:\n" + f" - n_steps ({n_steps}), or\n" + f" - n_systems ({initial_state.n_systems}), or\n" + f" - 1 (scalar),\n" + f"but got {temps.shape[0]}." + ) + + def integrate[T: SimState]( # noqa: C901 system: StateLike, model: ModelInterface, @@ -123,7 +193,11 @@ def integrate[T: SimState]( # noqa: C901 (init_func, step_func) functions. n_steps (int): Number of integration steps temperature (float | ArrayLike): Temperature or array of temperatures for each - step + step or system: + Float: used for all steps and systems + 1D array of length n_steps: used for each step + 1D array of length n_systems: used for each system + 2D array of shape (n_steps, n_systems): used for each step and system. timestep (float): Integration time step trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking trajectory. If a dict, will be passed to the TrajectoryReporter @@ -140,18 +214,11 @@ def integrate[T: SimState]( # noqa: C901 T: Final state after integration """ unit_system = UnitSystem.metal - # create a list of temperatures - temps = ( - [temperature] * n_steps - if isinstance(temperature, (float, int)) - else list(temperature) - ) - if len(temps) != n_steps: - raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}") initial_state: SimState = ts.initialize_state(system, model.device, model.dtype) dtype, device = initial_state.dtype, initial_state.device - kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature + kTs = _normalize_temperature_tensor(temperature, n_steps, initial_state) + kTs = kTs * unit_system.temperature dt = torch.tensor(timestep * unit_system.time, dtype=dtype, device=device) # Handle both string names and direct function tuples @@ -192,7 +259,12 @@ def integrate[T: SimState]( # noqa: C901 # Handle both BinningAutoBatcher and list of tuples for state, system_indices in batch_iterator: # Pass correct parameters based on integrator type - state = init_func(state=state, model=model, kT=kTs[0], dt=dt, **init_kwargs or {}) + batch_kT = ( + kTs[:, system_indices] if (system_indices and len(kTs.shape) == 2) else kTs + ) + state = init_func( + state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {} + ) # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: @@ -204,7 +276,11 @@ def integrate[T: SimState]( # noqa: C901 # run the simulation for step in range(1, n_steps + 1): state = step_func( - state=state, model=model, dt=dt, kT=kTs[step - 1], **integrator_kwargs + state=state, + model=model, + dt=dt, + kT=batch_kT[step - 1], + **integrator_kwargs, ) if trajectory_reporter: