From 4157a20cec17069de2da1a46d1ca6f25ef890a70 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 18 Sep 2025 13:57:07 +0000 Subject: [PATCH 01/43] fix:orb squeeze incorrect energy shape --- torch_sim/models/orb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index fd65b23f..132f6d5c 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -416,7 +416,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: if not model_has_direct_heads and prop == "stress": continue _property = "energy" if prop == "free_energy" else prop - results[prop] = predictions[_property].squeeze() + results[prop] = predictions[_property] if self.conservative: results["forces"] = results[self.model.grad_forces_name] From 38c6138826f41190a5369e4060097183da2e5876 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Fri, 17 Oct 2025 19:54:46 +0200 Subject: [PATCH 02/43] First draft constraints --- tests/test_constraints.py | 77 +++++++ tests/test_state.py | 2 +- torch_sim/__init__.py | 1 + torch_sim/constraints.py | 406 +++++++++++++++++++++++++++++++++++ torch_sim/integrators/md.py | 17 +- torch_sim/integrators/npt.py | 4 +- torch_sim/integrators/nvt.py | 3 +- torch_sim/monte_carlo.py | 2 +- torch_sim/optimizers/fire.py | 1 + torch_sim/runners.py | 2 +- torch_sim/state.py | 44 +++- torch_sim/transforms.py | 64 ++++++ 12 files changed, 610 insertions(+), 13 deletions(-) create mode 100644 tests/test_constraints.py create mode 100644 torch_sim/constraints.py diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100644 index 00000000..c919ade8 --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 +1,77 @@ +import torch + +import torch_sim as ts +from tests.conftest import DTYPE +from torch_sim.constraints import FixAtoms, FixCom +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.transforms import unwrap_positions +from torch_sim.units import MetalUnits + + +def test_fix_com_nvt_langevin( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + n_steps = 1000 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + dofs_before = ar_double_sim_state.calc_dof() + ar_double_sim_state.constraints = [FixCom()] + assert torch.allclose(ar_double_sim_state.calc_dof(), dofs_before - 3) + + state = ts.nvt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) + positions = [] + system_masses = torch.zeros((state.n_systems, 1), dtype=DTYPE).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 1), + state.masses.unsqueeze(-1), + ) + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.clone()) + traj_positions = torch.stack(positions) + + unwrapped_positions = unwrap_positions( + traj_positions, ar_double_sim_state.cell, state.system_idx + ) + coms = torch.zeros((n_steps, state.n_systems, 3), dtype=DTYPE).scatter_add_( + 1, + state.system_idx[None, :, None].expand(n_steps, -1, 3), + state.masses.unsqueeze(-1) * unwrapped_positions, + ) + coms /= system_masses + coms_drift = coms - coms[0] + assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-4) + + +def test_fix_atoms_nvt_langevin( + ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel +): + n_steps = 1000 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + dofs_before = ar_double_sim_state.calc_dof() + ar_double_sim_state.constraints = [ + FixAtoms(indices=torch.tensor([0, 1], dtype=torch.long)) + ] + assert torch.allclose( + ar_double_sim_state.calc_dof(), dofs_before - torch.tensor([6, 0]) + ) + state = ts.nvt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) + positions = [] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.clone()) + traj_positions = torch.stack(positions) + + unwrapped_positions = unwrap_positions( + traj_positions, ar_double_sim_state.cell, state.system_idx + ) + diff_positions = unwrapped_positions - unwrapped_positions[0] + assert torch.max(diff_positions[:, :2]) < 1e-8 + assert torch.max(diff_positions[:, 2:]) > 1e-2 diff --git a/tests/test_state.py b/tests/test_state.py index 1657b382..95a04b8b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -30,7 +30,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) assert set(per_system_attrs) == {"cell"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) - assert set(global_attrs) == {"pbc"} + assert set(global_attrs) == {"pbc", "constraints"} def test_all_attributes_must_be_specified_in_scopes() -> None: diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index c86d732e..2f7ba4ab 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -7,6 +7,7 @@ import torch_sim as ts from torch_sim import ( autobatching, + constraints, elastic, io, math, diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py new file mode 100644 index 00000000..6d43b564 --- /dev/null +++ b/torch_sim/constraints.py @@ -0,0 +1,406 @@ +"""Constraints for molecular dynamics simulations. + +This module implements constraints inspired by ASE's constraint system, +adapted for the torch-sim framework with support for batched operations +and PyTorch tensors. + +The constraints affect degrees of freedom counting and modify forces, momenta, +and positions during MD simulations. +""" + +from __future__ import annotations + +import warnings +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import torch + + +if TYPE_CHECKING: + from torch_sim.state import SimState + + +class FixConstraint(ABC): + """Base class for all constraints in torch-sim. + + This is the abstract base class that all constraints must inherit from. + It defines the interface that constraints must implement to work with + the torch-sim MD system. + """ + + @abstractmethod + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get the number of degrees of freedom removed by this constraint. + + Args: + state: The simulation state + + Returns: + Number of degrees of freedom removed by this constraint + """ + + @abstractmethod + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Adjust positions to satisfy the constraint. + + This method should modify new_positions in-place to ensure the + constraint is satisfied. + + Args: + state: Current simulation state + new_positions: Proposed new positions to be adjusted + """ + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Adjust momenta to satisfy the constraint. + + This method should modify momenta in-place to ensure the constraint + is satisfied. By default, it calls adjust_forces with the momenta. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted + """ + # Default implementation: treat momenta like forces + self.adjust_forces(state, momenta) + + @abstractmethod + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Adjust forces to satisfy the constraint. + + This method should modify forces in-place to ensure the constraint + is satisfied. + + Args: + state: Current simulation state + forces: Forces to be adjusted + """ + + def copy(self) -> FixConstraint: + """Create a copy of this constraint. + + Returns: + A new instance of this constraint with the same parameters + """ + return type(self)(**self.__dict__) + + def todict(self) -> dict[str, Any]: + """Convert constraint to dictionary representation. + + Returns: + Dictionary representation of the constraint + """ + return {"name": self.__class__.__name__, "kwargs": self.__dict__.copy()} + + +class IndexedConstraint(FixConstraint): + """Base class for constraints that act on specific atom indices. + + This class provides common functionality for constraints that operate + on a subset of atoms, identified by their indices. + """ + + def __init__(self, indices: torch.Tensor | list[int] | None = None) -> None: + """Initialize indexed constraint. + + Args: + indices: Indices of atoms to constrain. Can be a tensor or list of integers. + + Raises: + ValueError: If both indices and mask are provided, or if indices have + wrong shape/type + """ + if indices is None: + # Empty constraint + self.index = torch.empty(0, dtype=torch.long) + return + + # Convert to tensor if needed + if not isinstance(indices, torch.Tensor): + indices = torch.tensor(indices) + + # Ensure we have the right shape and type + indices = torch.atleast_1d(indices) + if indices.ndim > 1: + raise ValueError( + "indices has wrong number of dimensions. " + f"Got {indices.ndim}, expected ndim <= 1" + ) + + if indices.dtype == torch.bool: + # Convert boolean mask to indices + indices = torch.where(indices)[0] + elif len(indices) == 0: + indices = torch.empty(0, dtype=torch.long) + elif torch.is_floating_point(indices): + raise ValueError( + f"Indices must be integers or boolean mask, not dtype={indices.dtype}" + ) + + # Check for duplicates + if len(torch.unique(indices)) < len(indices): + raise ValueError( + "The indices array contains duplicates. " + "Perhaps you want to specify a mask instead, but " + "forgot the mask= keyword." + ) + + self.index = indices.long() + + def get_indices(self) -> torch.Tensor: + """Get the constrained atom indices. + + Returns: + Tensor of atom indices affected by this constraint + """ + return self.index.clone() + + +class FixAtoms(IndexedConstraint): + """Constraint that fixes specified atoms in place. + + This constraint prevents the specified atoms from moving by: + - Resetting their positions to original values + - Setting their forces to zero + - Removing 3 degrees of freedom per fixed atom + + Examples: + Fix atoms with indices [0, 1, 2]: + >>> constraint = FixAtoms(indices=[0, 1, 2]) + + Fix atoms using a boolean mask: + >>> mask = torch.tensor([True, True, True, False, False]) + >>> constraint = FixAtoms(mask=mask) + """ + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + Each fixed atom removes 3 degrees of freedom (x, y, z motion). + + Args: + state: Simulation state + + Returns: + Number of degrees of freedom removed (3 * number of fixed atoms) + """ + fixed_atoms_system_idx = torch.bincount( + state.system_idx[self.index], minlength=state.n_systems + ) + return 3 * fixed_atoms_system_idx + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Reset positions of fixed atoms to their current values. + + Args: + state: Current simulation state + new_positions: Proposed positions to be adjusted in-place + """ + new_positions[self.index] = state.positions[self.index] + + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: # noqa: ARG002 + """Set forces on fixed atoms to zero. + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + forces[self.index] = 0.0 + + def __repr__(self) -> str: + """String representation of the constraint.""" + if len(self.index) <= 10: + indices_str = self.index.tolist() + else: + indices_str = f"{self.index[:5].tolist()}...{self.index[-5:].tolist()}" + return f"FixAtoms(indices={indices_str})" + + def todict(self) -> dict[str, Any]: + """Convert to dictionary representation. + + Returns: + Dictionary representation of the constraint + """ + return {"name": "FixAtoms", "kwargs": {"indices": self.index.tolist()}} + + +class FixCom(FixConstraint): + """Constraint that fixes the center of mass of all atoms per system. + + This constraint prevents the center of mass from moving by: + - Adjusting positions to maintain center of mass position + - Removing center of mass velocity from momenta + - Adjusting forces to remove net force + - Removing 3 degrees of freedom (center of mass translation) + + The constraint is applied to all atoms in the system. + """ + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + Fixing center of mass removes 3 degrees of freedom (x, y, z translation). + + Args: + state: Simulation state + + Returns: + Always returns 3 (center of mass translation degrees of freedom) + """ + # if self.index.numel() == 0: + # return 3 * torch.ones(state.n_systems, dtype=torch.long) + # removed_dof = torch.zeros(state.n_systems, dtype=torch.long) + # removed_dof[self.index] = 1 + # return 3 * removed_dof + return 3 * torch.ones(state.n_systems, dtype=torch.long) + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Adjust positions to maintain center of mass position. + + Args: + state: Current simulation state + new_positions: Proposed positions to be adjusted in-place + """ + dtype = state.positions.dtype + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses + ) + self.coms = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * state.positions, + ) + self.coms /= system_mass.unsqueeze(-1) + + new_com = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * state.positions, + ) + new_com /= system_mass.unsqueeze(-1) + displacement = -new_com + self.coms + new_positions += displacement[state.system_idx] + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Remove center of mass velocity from momenta. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted in-place + """ + # Compute center of mass momenta + dtype = momenta.dtype + com_momenta = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + momenta, + ) + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses + ) + velocity_com = com_momenta / system_mass.unsqueeze(-1) + momenta -= velocity_com[state.system_idx] * state.masses.unsqueeze(-1) + + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Remove net force to prevent center of mass acceleration. + + This implements the constraint from Eq. (3) and (7) in + https://doi.org/10.1021/jp9722824 + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + dtype = state.forces.dtype + system_square_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, torch.square(state.masses) + ) + lmd = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + forces * state.masses.unsqueeze(-1), + ) + lmd /= system_square_mass.unsqueeze(-1) + forces -= lmd[state.system_idx] * state.masses.unsqueeze(-1) + + def __repr__(self) -> str: + """String representation of the constraint.""" + return "FixCom()" + + def todict(self) -> dict[str, Any]: + """Convert to dictionary representation. + + Returns: + Dictionary representation of the constraint + """ + return {"name": "FixCom", "kwargs": {}} + + +def count_degrees_of_freedom( + state: SimState, constraints: list[FixConstraint] | None = None +) -> int: + """Count the total degrees of freedom in a system with constraints. + + This function calculates the total number of degrees of freedom by starting + with the unconstrained count (n_atoms * 3) and subtracting the degrees of + freedom removed by each constraint. + + Args: + state: Simulation state + constraints: List of active constraints (optional) + + Returns: + Total number of degrees of freedom + """ + # Start with unconstrained DOF + total_dof = state.n_atoms * 3 + + # Subtract DOF removed by constraints + if constraints is not None: + for constraint in constraints: + total_dof -= constraint.get_removed_dof(state) + + return max(0, total_dof) # Ensure non-negative + + +# WIP +def warn_if_overlapping_constraints(constraints: list[FixConstraint]) -> None: + """Issue warnings if constraints might overlap in problematic ways. + + This function checks for potential issues like multiple constraints + acting on the same atoms, which could lead to unexpected behavior. + + Args: + constraints: List of constraints to check + """ + indexed_constraints = [] + has_com_constraint = False + + for constraint in constraints: + if isinstance(constraint, IndexedConstraint): + indexed_constraints.append(constraint) + elif isinstance(constraint, FixCom): + has_com_constraint = True + + # Check for overlapping atom indices + if len(indexed_constraints) > 1: + all_indices = torch.cat([c.index for c in indexed_constraints]) + unique_indices = torch.unique(all_indices) + if len(unique_indices) < len(all_indices): + warnings.warn( + "Multiple constraints are acting on the same atoms. " + "This may lead to unexpected behavior.", + UserWarning, + stacklevel=2, + ) + + # Warn about COM constraint with fixed atoms + if has_com_constraint and indexed_constraints: + warnings.warn( + "Using FixCom together with other constraints may lead to " + "unexpected behavior. The center of mass constraint is applied " + "to all atoms, including those that may be constrained by other means.", + UserWarning, + stacklevel=2, + ) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 4d8f209a..2673e23e 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -10,7 +10,7 @@ from torch_sim.state import SimState -@dataclass +@dataclass(kw_only=True) class MDState(SimState): """State information for molecular dynamics simulations. @@ -55,6 +55,13 @@ def velocities(self) -> torch.Tensor: """ return self.momenta / self.masses.unsqueeze(-1) + def set_momenta(self, new_momenta: torch.Tensor) -> None: + """Set new momenta, applying any constraints as needed.""" + if self.constraints is not None: + for constraint in self.constraints: + constraint.adjust_momenta(self, new_momenta) + self.momenta = new_momenta + def calculate_momenta( positions: torch.Tensor, @@ -133,7 +140,7 @@ def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_momenta = state.momenta + state.forces * dt - state.momenta = new_momenta + state.set_momenta(new_momenta) return state @@ -153,14 +160,14 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_positions = state.positions + state.velocities * dt + state.set_positions(new_positions) if state.pbc: # Split positions and cells by system new_positions = transforms.pbc_wrap_batched( - new_positions, state.cell, state.system_idx + state.positions, state.cell, state.system_idx ) - - state.positions = new_positions + state.positions = new_positions # no constraints applied return state diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1ab4e7c3..c8464a8e 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -19,7 +19,7 @@ from torch_sim.typing import StateDict -@dataclass +@dataclass(kw_only=True) class NPTLangevinState(SimState): """State information for an NPT system with Langevin dynamics. @@ -755,7 +755,7 @@ def npt_langevin_step( return _npt_langevin_velocity_step(state, forces, dt, kT, alpha) -@dataclass +@dataclass(kw_only=True) class NPTNoseHooverState(MDState): """State information for an NPT system with Nose-Hoover chain thermostats. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index b4cb41ce..90baa0c1 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -129,6 +129,7 @@ def nvt_langevin_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, + constraints=state.constraints, ) @@ -196,7 +197,7 @@ def nvt_langevin_step( return momentum_step(state, dt / 2) -@dataclass +@dataclass(kw_only=True) class NVTNoseHooverState(MDState): """State information for an NVT system with a Nose-Hoover chain thermostat. diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 81fe2ab2..f75890a6 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -24,7 +24,7 @@ from torch_sim.state import SimState -@dataclass +@dataclass(kw_only=True) class SwapMCState(SimState): """State for Monte Carlo simulations with swap moves. diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 0a689432..898e58d2 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -80,6 +80,7 @@ def fire_init( "cell": state.cell.clone(), "atomic_numbers": state.atomic_numbers.clone(), "system_idx": state.system_idx.clone(), + "constraints": state.constraints, "pbc": state.pbc, # Optimization state "forces": forces, diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 220fb12a..a0efa071 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -571,7 +571,7 @@ def static( properties=properties, ) - @dataclass + @dataclass(kw_only=True) class StaticState(SimState): energy: torch.Tensor forces: torch.Tensor diff --git a/torch_sim/state.py b/torch_sim/state.py index a04fa5d3..6cd7d38a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -9,7 +9,7 @@ import typing from collections import defaultdict from collections.abc import Generator, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self import torch @@ -22,6 +22,7 @@ from ase import Atoms from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure +from torch_sim.constraints import FixConstraint @dataclass(init=False) @@ -51,6 +52,8 @@ class SimState: atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. + constraints (list["FixConstraint"] | None): List of constraints applied to the + system. Constraints affect degrees of freedom and modify positions. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary @@ -83,6 +86,7 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor system_idx: torch.Tensor + constraints: list[FixConstraint] | None = field(default_factory=lambda: None) _atom_attributes: ClassVar[set[str]] = { "positions", @@ -91,7 +95,7 @@ class SimState: "system_idx", } _system_attributes: ClassVar[set[str]] = {"cell"} - _global_attributes: ClassVar[set[str]] = {"pbc"} + _global_attributes: ClassVar[set[str]] = {"pbc", "constraints"} def __init__( self, @@ -101,6 +105,7 @@ def __init__( pbc: bool, # noqa: FBT001 atomic_numbers: torch.Tensor, system_idx: torch.Tensor | None = None, + constraints: list["FixConstraint"] | None = None, ) -> None: """Initialize the SimState and validate the arguments. @@ -113,12 +118,15 @@ def __init__( system_idx (torch.Tensor | None): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. If not provided, it is initialized to zeros. + constraints (list["FixConstraint"] | None): List of constraints applied to the + system. If None, no constraints are applied. """ self.positions = positions self.masses = masses self.cell = cell self.pbc = pbc self.atomic_numbers = atomic_numbers + self.constraints = constraints # Validate and process the state after initialization. # data validation and fill system_idx @@ -234,6 +242,38 @@ def row_vector_cell(self, value: torch.Tensor) -> None: """ self.cell = value.mT + def set_positions(self, new_positions: torch.Tensor) -> None: + """Set the positions and apply constraints if they exist. + + Args: + new_positions: New positions tensor with shape (n_atoms, 3) + """ + # Apply constraints if they exist + if self.constraints is not None: + for constraint in self.constraints: + constraint.adjust_positions(self, new_positions) + self.positions = new_positions + + def calc_dof(self) -> torch.Tensor: + """Calculate degrees of freedom accounting for constraints. + + Returns: + torch.Tensor: Number of degrees of freedom per system, with shape + (n_systems,). Each system starts with 3 * n_atoms_per_system degrees + of freedom, minus any degrees removed by constraints. + """ + # Start with unconstrained DOF: 3 degrees per atom + dof_per_system = 3 * self.n_atoms_per_system + + # Subtract DOF removed by constraints + if self.constraints is not None: + for constraint in self.constraints: + removed_dof = constraint.get_removed_dof(self) + dof_per_system -= removed_dof + + # Ensure non-negative DOF + return torch.clamp(dof_per_system, min=0) + def clone(self) -> Self: """Create a deep copy of the SimState. diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 1b2c416b..1f905048 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1157,3 +1157,67 @@ def safe_mask( """ masked = torch.where(mask, operand, torch.zeros_like(operand)) return torch.where(mask, fn(masked), torch.full_like(operand, placeholder)) + + +def unwrap_positions( + pos: torch.Tensor, box: torch.Tensor, system_idx: torch.Tensor +) -> torch.Tensor: + """Vectorized unwrapping for multiple systems without explicit loops. + + Parameters + ---------- + pos : (T, N_tot, 3) + Wrapped cartesian positions for all systems concatenated. + box : (n_systems, 3, 3) or (T, n_systems, 3, 3) + Box matrices, constant or time-dependent. + system_idx : (N_tot,) + For each atom, which system it belongs to (0..n_systems-1). + + Returns: + ------- + unwrapped_pos : (T, N_tot, 3) + Unwrapped cartesian positions. + """ + # -- Constant boxes per system + if box.ndim == 3: + inv_box = torch.inverse(box) # (n_systems, 3, 3) + + # Map each atom to its system's box + inv_box_atoms = inv_box[system_idx] # (N, 3, 3) + box_atoms = box[system_idx] # (N, 3, 3) + + # Compute fractional coordinates + frac = torch.einsum("tni,nij->tnj", pos, inv_box_atoms) + + # Fractional displacements and unwrap + dfrac = frac[1:] - frac[:-1] + dfrac -= torch.round(dfrac) + + # Back to Cartesian + dcart = torch.einsum("tni,nij->tnj", dfrac, box_atoms) + + # -- Time-dependent boxes per system + elif box.ndim == 4: + inv_box = torch.inverse(box) # (T, n_systems, 3, 3) + + # Gather each atom's box per frame efficiently + inv_box_atoms = inv_box[:, system_idx] # (T, N, 3, 3) + box_atoms = box[:, system_idx] # (T, N, 3, 3) + + # Compute fractional coordinates per frame + frac = torch.einsum("tni,tnij->tnj", pos, inv_box_atoms) + + dfrac = frac[1:] - frac[:-1] + dfrac -= torch.round(dfrac) + + dcart = torch.einsum("tni,tnij->tnj", dfrac, box_atoms[:-1]) + + else: + raise ValueError("box must have shape (n_systems,3,3) or (T,n_systems,3,3)") + + # Cumulative reconstruction + unwrapped = torch.empty_like(pos) + unwrapped[0] = pos[0] + unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] + + return unwrapped From 6eb3d785a95f03571661bdde30d113f203ac1459 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 21 Oct 2025 14:46:48 +0200 Subject: [PATCH 03/43] change base class name for constraint --- torch_sim/constraints.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 6d43b564..5631baa9 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -21,7 +21,7 @@ from torch_sim.state import SimState -class FixConstraint(ABC): +class Constraint(ABC): """Base class for all constraints in torch-sim. This is the abstract base class that all constraints must inherit from. @@ -77,7 +77,7 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted """ - def copy(self) -> FixConstraint: + def copy(self) -> Constraint: """Create a copy of this constraint. Returns: @@ -94,7 +94,7 @@ def todict(self) -> dict[str, Any]: return {"name": self.__class__.__name__, "kwargs": self.__dict__.copy()} -class IndexedConstraint(FixConstraint): +class IndexedConstraint(Constraint): """Base class for constraints that act on specific atom indices. This class provides common functionality for constraints that operate @@ -225,7 +225,7 @@ def todict(self) -> dict[str, Any]: return {"name": "FixAtoms", "kwargs": {"indices": self.index.tolist()}} -class FixCom(FixConstraint): +class FixCom(Constraint): """Constraint that fixes the center of mass of all atoms per system. This constraint prevents the center of mass from moving by: @@ -338,7 +338,7 @@ def todict(self) -> dict[str, Any]: def count_degrees_of_freedom( - state: SimState, constraints: list[FixConstraint] | None = None + state: SimState, constraints: list[Constraint] | None = None ) -> int: """Count the total degrees of freedom in a system with constraints. @@ -365,7 +365,7 @@ def count_degrees_of_freedom( # WIP -def warn_if_overlapping_constraints(constraints: list[FixConstraint]) -> None: +def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: """Issue warnings if constraints might overlap in problematic ways. This function checks for potential issues like multiple constraints From c630f39177b1d5f80e6db8c8713e33db442bbea0 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 21 Oct 2025 14:49:08 +0200 Subject: [PATCH 04/43] remove useless methods --- torch_sim/constraints.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 5631baa9..965d00e5 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -12,7 +12,7 @@ import warnings from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import torch @@ -77,22 +77,6 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted """ - def copy(self) -> Constraint: - """Create a copy of this constraint. - - Returns: - A new instance of this constraint with the same parameters - """ - return type(self)(**self.__dict__) - - def todict(self) -> dict[str, Any]: - """Convert constraint to dictionary representation. - - Returns: - Dictionary representation of the constraint - """ - return {"name": self.__class__.__name__, "kwargs": self.__dict__.copy()} - class IndexedConstraint(Constraint): """Base class for constraints that act on specific atom indices. @@ -216,14 +200,6 @@ def __repr__(self) -> str: indices_str = f"{self.index[:5].tolist()}...{self.index[-5:].tolist()}" return f"FixAtoms(indices={indices_str})" - def todict(self) -> dict[str, Any]: - """Convert to dictionary representation. - - Returns: - Dictionary representation of the constraint - """ - return {"name": "FixAtoms", "kwargs": {"indices": self.index.tolist()}} - class FixCom(Constraint): """Constraint that fixes the center of mass of all atoms per system. @@ -328,14 +304,6 @@ def __repr__(self) -> str: """String representation of the constraint.""" return "FixCom()" - def todict(self) -> dict[str, Any]: - """Convert to dictionary representation. - - Returns: - Dictionary representation of the constraint - """ - return {"name": "FixCom", "kwargs": {}} - def count_degrees_of_freedom( state: SimState, constraints: list[Constraint] | None = None From f5459b9cea1b09038a9422722197bfacdf199a7b Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 21 Oct 2025 15:30:44 +0200 Subject: [PATCH 05/43] change redundant definition --- examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 3 +-- examples/tutorials/hybrid_swap_tutorial.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 331d69a1..17e3c37e 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -67,9 +67,8 @@ class HybridSwapMCState(ts.SwapMCState, MDState): last_swap: Last swap attempted """ - last_permutation: torch.Tensor _atom_attributes = ( - ts.SwapMCState._atom_attributes | MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes # noqa: SLF001 ) _system_attributes = ( ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index c41a7f0b..80bc627f 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -100,9 +100,11 @@ class HybridSwapMCState(SwapMCState, MDState): from MDState. """ - last_permutation: torch.Tensor _atom_attributes = ( - MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 ) From 6b2710e3cdbeec23db1f35d8da09d6fdbe8dad2a Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 23 Oct 2025 10:13:14 +0200 Subject: [PATCH 06/43] constraint to optimizer, compatibility with state manipulation --- tests/test_constraints.py | 312 ++++++++++++++++++++++- torch_sim/constraints.py | 152 ++++++++--- torch_sim/integrators/md.py | 13 +- torch_sim/integrators/npt.py | 1 + torch_sim/integrators/nve.py | 1 + torch_sim/integrators/nvt.py | 1 + torch_sim/monte_carlo.py | 1 + torch_sim/optimizers/fire.py | 23 +- torch_sim/optimizers/gradient_descent.py | 6 +- torch_sim/optimizers/state.py | 37 ++- torch_sim/state.py | 155 ++++++++++- torch_sim/transforms.py | 27 ++ 12 files changed, 650 insertions(+), 79 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index c919ade8..ade7e6e8 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1,17 +1,80 @@ +from typing import get_args + +import pytest import torch import torch_sim as ts from tests.conftest import DTYPE from torch_sim.constraints import FixAtoms, FixCom +from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.transforms import unwrap_positions +from torch_sim.optimizers import FireFlavor +from torch_sim.transforms import get_centers_of_mass, unwrap_positions from torch_sim.units import MetalUnits +def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test adjustment of positions and momenta with FixCom constraint.""" + ar_supercell_sim_state.add_constraints([FixCom()]) + initial_positions = ar_supercell_sim_state.positions.clone() + ar_supercell_sim_state.set_positions(initial_positions + 0.5) + assert torch.allclose(ar_supercell_sim_state.positions, initial_positions, atol=1e-8) + + ar_supercell_mdstate = ts.nve_init( + state=ar_supercell_sim_state, + model=lj_model, + kT=torch.tensor(10.0, dtype=DTYPE), + seed=42, + ) + ar_supercell_mdstate.set_momenta(torch.randn_like(ar_supercell_mdstate.momenta) * 0.1) + assert torch.allclose( + ar_supercell_mdstate.momenta.mean(dim=0), torch.zeros(3, dtype=DTYPE), atol=1e-8 + ) + + +def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test adjustment of positions and momenta with FixAtoms constraint.""" + indices_to_fix = torch.tensor([0, 5, 10], dtype=torch.long) + ar_supercell_sim_state.add_constraints([FixAtoms(indices=indices_to_fix)]) + initial_positions = ar_supercell_sim_state.positions.clone() + # displacement = torch.randn_like(ar_supercell_sim_state.positions) * 0.5 + displacement = 0.5 + ar_supercell_sim_state.set_positions(initial_positions + displacement) + assert torch.allclose( + ar_supercell_sim_state.positions[indices_to_fix], + initial_positions[indices_to_fix], + atol=1e-8, + ) + # Check that other positions have changed + unfixed_indices = torch.tensor( + [i for i in range(ar_supercell_sim_state.n_atoms) if i not in indices_to_fix], + dtype=torch.long, + ) + assert not torch.allclose( + ar_supercell_sim_state.positions[unfixed_indices], + initial_positions[unfixed_indices], + atol=1e-8, + ) + + ar_supercell_mdstate = ts.nve_init( + state=ar_supercell_sim_state, + model=lj_model, + kT=torch.tensor(10.0, dtype=DTYPE), + seed=42, + ) + ar_supercell_mdstate.set_momenta(torch.randn_like(ar_supercell_mdstate.momenta) * 0.1) + assert torch.allclose( + ar_supercell_mdstate.momenta[indices_to_fix], + torch.zeros_like(ar_supercell_mdstate.momenta[indices_to_fix]), + atol=1e-8, + ) + + def test_fix_com_nvt_langevin( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ): - n_steps = 1000 + """Test FixCom constraint in NVT Langevin dynamics.""" + n_steps = 200 dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature @@ -33,23 +96,24 @@ def test_fix_com_nvt_langevin( positions.append(state.positions.clone()) traj_positions = torch.stack(positions) - unwrapped_positions = unwrap_positions( - traj_positions, ar_double_sim_state.cell, state.system_idx - ) + # unwrapped_positions = unwrap_positions( + # traj_positions, ar_double_sim_state.cell, state.system_idx + # ) coms = torch.zeros((n_steps, state.n_systems, 3), dtype=DTYPE).scatter_add_( 1, state.system_idx[None, :, None].expand(n_steps, -1, 3), - state.masses.unsqueeze(-1) * unwrapped_positions, + state.masses.unsqueeze(-1) * traj_positions, ) coms /= system_masses coms_drift = coms - coms[0] - assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-4) + assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-8) def test_fix_atoms_nvt_langevin( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ): - n_steps = 1000 + """Test FixAtoms constraint in NVT Langevin dynamics.""" + n_steps = 200 dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature @@ -74,4 +138,234 @@ def test_fix_atoms_nvt_langevin( ) diff_positions = unwrapped_positions - unwrapped_positions[0] assert torch.max(diff_positions[:, :2]) < 1e-8 - assert torch.max(diff_positions[:, 2:]) > 1e-2 + assert torch.max(diff_positions[:, 2:]) > 1e-3 + + +def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): + """Test that constraints are properly propagated during state manipulation.""" + # Set up constraints on the original state + ar_double_sim_state.add_constraints( + [FixAtoms(indices=torch.tensor([0, 1])), FixCom()] + ) + + # Extract individual systems from the double system state + first_system = ar_double_sim_state[0] + second_system = ar_double_sim_state[1] + concatenated_state = ts.concatenate_states( + [first_system, first_system, second_system] + ) + + # Verify constraint propagation to subsystems + assert len(first_system.constraints) == 2 + assert len(second_system.constraints) == 2 + assert len(concatenated_state.constraints) == 2 + + # Verify FixAtoms constraint indices are correctly mapped + assert torch.all(first_system.constraints[0].indices == torch.tensor([0, 1])) + assert torch.all(second_system.constraints[0].indices == torch.tensor([])) + assert torch.all( + concatenated_state.constraints[0].indices == torch.tensor([0, 1, 32, 33]) + ) + + # Verify FixCom constraint system masks + assert torch.all( + concatenated_state.constraints[1].system_idx == torch.tensor([0, 1, 2]) + ) + + # Test constraint propagation after splitting concatenated state + split_systems = concatenated_state.split() + assert len(split_systems[0].constraints) == 2 + assert torch.all(split_systems[0].constraints[0].indices == torch.tensor([0, 1])) + assert torch.all(split_systems[1].constraints[0].indices == torch.tensor([0, 1])) + assert torch.all( + split_systems[2].constraints[0].indices == torch.tensor([], dtype=torch.long) + ) + + # Test constraint manipulation with different configurations + ar_double_sim_state.constraints = [] + ar_double_sim_state.add_constraints([FixCom()]) + isolated_system = ar_double_sim_state[0] + assert torch.all( + isolated_system.constraints[0].system_idx == torch.tensor([0], dtype=torch.long) + ) + + # Test concatenation with mixed constraint states + isolated_system.constraints = [] + mixed_concatenated_state = ts.concatenate_states( + [isolated_system, ar_double_sim_state, isolated_system] + ) + assert torch.all( + mixed_concatenated_state.constraints[0].system_idx == torch.tensor([1, 2]) + ) + + +def test_fix_com_gradient_descent_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface +) -> None: + """Test FixCom constraint in Gradient Descent optimization.""" + # Add some random displacement to positions + perturbed_positions = ( + ar_supercell_sim_state.positions + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + ar_supercell_sim_state.positions = perturbed_positions + initial_state = ar_supercell_sim_state + ar_supercell_sim_state.add_constraints(FixCom()) + + initial_coms = get_centers_of_mass( + positions=initial_state.positions, + masses=initial_state.masses, + system_idx=initial_state.system_idx, + n_systems=initial_state.n_systems, + ) + + # Initialize Gradient Descent optimizer + state = ts.gradient_descent_init( + state=ar_supercell_sim_state, model=lj_model, lr=0.01 + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + while abs(energies[-2] - energies[-1]) > 1e-6: + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) + energies.append(state.energy.item()) + + final_coms = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=initial_state.n_systems, + ) + + assert torch.allclose(final_coms, initial_coms, atol=1e-4) + assert not torch.allclose(state.positions, initial_state.positions) + + +def test_fix_atoms_gradient_descent_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface +) -> None: + """Test FixAtoms constraint in Gradient Descent optimization.""" + # Add some random displacement to positions + perturbed_positions = ( + ar_supercell_sim_state.positions + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + ar_supercell_sim_state.positions = perturbed_positions + initial_state = ar_supercell_sim_state + initial_state.add_constraints(FixAtoms(indices=[0])) + initial_position = initial_state.positions[0].clone() + + # Initialize Gradient Descent optimizer + state = ts.gradient_descent_init( + state=ar_supercell_sim_state, model=lj_model, lr=0.01 + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + while abs(energies[-2] - energies[-1]) > 1e-6: + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) + energies.append(state.energy.item()) + + final_position = state.positions[0] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + assert not torch.allclose(state.positions, initial_state.positions) + + +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_test_atoms_fire_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, fire_flavor: FireFlavor +) -> None: + """Test FixAtoms constraint in FIRE optimization.""" + # Add some random displacement to positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + indices = torch.tensor([0, 2], dtype=torch.long) + current_sim_state.add_constraints(FixAtoms(indices=indices)) + + # Initialize FIRE optimizer + state = ts.fire_init( + current_sim_state, lj_model, fire_flavor=fire_flavor, dt_start=0.1 + ) + initial_position = state.positions[indices].clone() + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + energies.append(state.energy.item()) + steps_taken += 1 + + final_position = state.positions[indices] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + + +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_fix_com_fire_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, fire_flavor: FireFlavor +) -> None: + """Test FixCom constraint in FIRE optimization.""" + # Add some random displacement to positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + current_sim_state.add_constraints(FixCom()) + + # Initialize FIRE optimizer + state = ts.fire_init( + current_sim_state, lj_model, fire_flavor=fire_flavor, dt_start=0.1 + ) + initial_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + energies.append(state.energy.item()) + steps_taken += 1 + + final_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + assert torch.allclose(final_com, initial_com, atol=1e-4) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 965d00e5..5431fd8c 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -78,7 +78,7 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: """ -class IndexedConstraint(Constraint): +class AtomIndexedConstraint(Constraint): """Base class for constraints that act on specific atom indices. This class provides common functionality for constraints that operate @@ -97,7 +97,7 @@ def __init__(self, indices: torch.Tensor | list[int] | None = None) -> None: """ if indices is None: # Empty constraint - self.index = torch.empty(0, dtype=torch.long) + self.indices = torch.empty(0, dtype=torch.long) return # Convert to tensor if needed @@ -130,7 +130,7 @@ def __init__(self, indices: torch.Tensor | list[int] | None = None) -> None: "forgot the mask= keyword." ) - self.index = indices.long() + self.indices = indices.long() def get_indices(self) -> torch.Tensor: """Get the constrained atom indices. @@ -138,10 +138,45 @@ def get_indices(self) -> torch.Tensor: Returns: Tensor of atom indices affected by this constraint """ - return self.index.clone() + return self.indices.clone() -class FixAtoms(IndexedConstraint): +class SystemConstraint(Constraint): + """Base class for constraints that act on specific system indices. + + This class provides common functionality for constraints that operate + on a subset of systems, identified by their indices. + """ + + def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: + """Initialize indexed constraint. + + Args: + system_idx: Indices of systems to constrain. Can be a tensor or + list of integers. + + Raises: + ValueError: If both indices and mask are provided, or if indices have + wrong shape/type + """ + if system_idx is None: + # Empty constraint + self.system_idx = slice(None) # All systems + return + + # Convert to tensor if needed + system_idx = torch.as_tensor(system_idx) + + # Ensure we have the right shape and type + system_idx = torch.atleast_1d(system_idx) + if system_idx.ndim > 1: + raise ValueError( + "system_idx has wrong number of dimensions. " + f"Got {system_idx.ndim}, expected ndim <= 1" + ) + + +class FixAtoms(AtomIndexedConstraint): """Constraint that fixes specified atoms in place. This constraint prevents the specified atoms from moving by: @@ -170,7 +205,7 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: Number of degrees of freedom removed (3 * number of fixed atoms) """ fixed_atoms_system_idx = torch.bincount( - state.system_idx[self.index], minlength=state.n_systems + state.system_idx[self.indices], minlength=state.n_systems ) return 3 * fixed_atoms_system_idx @@ -181,7 +216,7 @@ def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None state: Current simulation state new_positions: Proposed positions to be adjusted in-place """ - new_positions[self.index] = state.positions[self.index] + new_positions[self.indices] = state.positions[self.indices] def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: # noqa: ARG002 """Set forces on fixed atoms to zero. @@ -190,18 +225,18 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: # noqa: state: Current simulation state forces: Forces to be adjusted in-place """ - forces[self.index] = 0.0 + forces[self.indices] = 0.0 def __repr__(self) -> str: """String representation of the constraint.""" - if len(self.index) <= 10: - indices_str = self.index.tolist() + if len(self.indices) <= 10: + indices_str = self.indices.tolist() else: - indices_str = f"{self.index[:5].tolist()}...{self.index[-5:].tolist()}" + indices_str = f"{self.indices[:5].tolist()}...{self.indices[-5:].tolist()}" return f"FixAtoms(indices={indices_str})" -class FixCom(Constraint): +class FixCom(SystemConstraint): """Constraint that fixes the center of mass of all atoms per system. This constraint prevents the center of mass from moving by: @@ -224,11 +259,10 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: Returns: Always returns 3 (center of mass translation degrees of freedom) """ - # if self.index.numel() == 0: - # return 3 * torch.ones(state.n_systems, dtype=torch.long) - # removed_dof = torch.zeros(state.n_systems, dtype=torch.long) - # removed_dof[self.index] = 1 - # return 3 * removed_dof + if self.system_idx != slice(None): + affected_systems = torch.zeros(state.n_systems, dtype=torch.long) + affected_systems[self.system_idx] = 1 + return 3 * affected_systems return 3 * torch.ones(state.n_systems, dtype=torch.long) def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: @@ -239,23 +273,35 @@ def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None new_positions: Proposed positions to be adjusted in-place """ dtype = state.positions.dtype - system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( - 0, state.system_idx, state.masses + n_systems = ( + state.n_systems if self.system_idx == slice(None) else len(self.system_idx) ) - self.coms = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( - 0, - state.system_idx.unsqueeze(-1).expand(-1, 3), - state.masses.unsqueeze(-1) * state.positions, + index_to_consider = ( + torch.isin(state.system_idx, self.system_idx) + if self.system_idx != slice(None) + else torch.ones(state.n_atoms, dtype=torch.bool) + ) + system_mass = torch.zeros(n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx[index_to_consider], state.masses[index_to_consider] ) - self.coms /= system_mass.unsqueeze(-1) + if not hasattr(self, "coms"): + self.coms = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), + state.masses[index_to_consider].unsqueeze(-1) + * state.positions[index_to_consider], + ) + self.coms /= system_mass.unsqueeze(-1) - new_com = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + new_com = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx.unsqueeze(-1).expand(-1, 3), - state.masses.unsqueeze(-1) * state.positions, + state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), + state.masses[index_to_consider].unsqueeze(-1) + * new_positions[index_to_consider], ) new_com /= system_mass.unsqueeze(-1) - displacement = -new_com + self.coms + displacement = torch.zeros(state.n_systems, 3, dtype=dtype) + displacement[self.system_idx] = -new_com + self.coms new_positions += displacement[state.system_idx] def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: @@ -267,16 +313,26 @@ def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: """ # Compute center of mass momenta dtype = momenta.dtype - com_momenta = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + n_systems = ( + state.n_systems if self.system_idx == slice(None) else len(self.system_idx) + ) + index_to_consider = ( + torch.isin(state.system_idx, self.system_idx) + if self.system_idx != slice(None) + else torch.ones(state.n_atoms, dtype=torch.bool) + ) + com_momenta = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx.unsqueeze(-1).expand(-1, 3), - momenta, + state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), + momenta[index_to_consider], ) - system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( - 0, state.system_idx, state.masses + system_mass = torch.zeros(n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx[index_to_consider], state.masses[index_to_consider] ) velocity_com = com_momenta / system_mass.unsqueeze(-1) - momenta -= velocity_com[state.system_idx] * state.masses.unsqueeze(-1) + velocity_change = torch.zeros(state.n_systems, 3, dtype=dtype) + velocity_change[self.system_idx] = velocity_com + momenta -= velocity_change[state.system_idx] * state.masses.unsqueeze(-1) def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: """Remove net force to prevent center of mass acceleration. @@ -288,17 +344,29 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: state: Current simulation state forces: Forces to be adjusted in-place """ - dtype = state.forces.dtype - system_square_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( - 0, state.system_idx, torch.square(state.masses) + dtype = state.positions.dtype + n_systems = ( + state.n_systems if self.system_idx == slice(None) else len(self.system_idx) + ) + index_to_consider = ( + torch.isin(state.system_idx, self.system_idx) + if self.system_idx != slice(None) + else torch.ones(state.n_atoms, dtype=torch.bool) + ) + system_square_mass = torch.zeros(n_systems, dtype=dtype).scatter_add_( + 0, + state.system_idx[index_to_consider], + torch.square(state.masses[index_to_consider]), ) - lmd = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + lmd = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx.unsqueeze(-1).expand(-1, 3), - forces * state.masses.unsqueeze(-1), + state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), + forces[index_to_consider] * state.masses[index_to_consider].unsqueeze(-1), ) lmd /= system_square_mass.unsqueeze(-1) - forces -= lmd[state.system_idx] * state.masses.unsqueeze(-1) + forces_change = torch.zeros(state.n_systems, 3, dtype=dtype) + forces_change[self.system_idx] = lmd + forces -= forces_change[state.system_idx] * state.masses.unsqueeze(-1) def __repr__(self) -> str: """String representation of the constraint.""" @@ -346,7 +414,7 @@ def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: has_com_constraint = False for constraint in constraints: - if isinstance(constraint, IndexedConstraint): + if isinstance(constraint, AtomIndexedConstraint): indexed_constraints.append(constraint) elif isinstance(constraint, FixCom): has_com_constraint = True diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 2673e23e..6161cfcd 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -5,7 +5,6 @@ import torch -from torch_sim import transforms from torch_sim.models.interface import ModelInterface from torch_sim.state import SimState @@ -162,12 +161,12 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: new_positions = state.positions + state.velocities * dt state.set_positions(new_positions) - if state.pbc: - # Split positions and cells by system - new_positions = transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx - ) - state.positions = new_positions # no constraints applied + # if state.pbc: + # # Split positions and cells by system + # new_positions = transforms.pbc_wrap_batched( + # state.positions, state.cell, state.system_idx + # ) + # state.positions = new_positions # no constraints applied return state diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index c8464a8e..8c633f06 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -643,6 +643,7 @@ def npt_langevin_init( cell_positions=cell_positions, cell_velocities=cell_velocities, cell_masses=cell_masses, + constraints=state.constraints, ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index d3773b3c..532add73 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -67,6 +67,7 @@ def nve_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, + constraints=state.constraints, ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 90baa0c1..c75cc8d7 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -328,6 +328,7 @@ def nvt_nose_hoover_init( system_idx=state.system_idx, chain=chain_fns.initialize(total_dof, KE, kT), _chain_fns=chain_fns, # Store the chain functions + constraints=state.constraints, ) diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index f75890a6..c17bf6ae 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -214,6 +214,7 @@ def swap_mc_init( system_idx=state.system_idx, energy=model_output["energy"], last_permutation=torch.arange(state.n_atoms, device=state.device), + constraints=state.constraints, ) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 898e58d2..110c6bbc 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -212,13 +212,15 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) # Position update - state.positions = state.positions + atom_wise_dt * state.velocities + # state.positions = state.positions + atom_wise_dt * state.velocities + state.set_positions(state.positions + atom_wise_dt * state.velocities) # Cell position updates are handled in the velocity update step above # Get new forces and energy model_output = model(state) - state.forces = model_output["forces"] + # state.forces = model_output["forces"] + state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] @@ -420,7 +422,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 cur_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.positions = ( + state.set_positions( torch.linalg.solve( cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) ).squeeze(-1) @@ -455,16 +457,19 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 new_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.positions = torch.bmm( - state.positions.unsqueeze(1), - new_deform_grad[state.system_idx].transpose(-2, -1), - ).squeeze(1) + state.set_positions( + torch.bmm( + state.positions.unsqueeze(1), + new_deform_grad[state.system_idx].transpose(-2, -1), + ).squeeze(1) + ) else: - state.positions = state.positions + dr_atom + state.set_positions(state.positions + dr_atom) # Get new forces, energy, and stress model_output = model(state) - state.forces = model_output["forces"] + # state.forces = model_output["forces"] + state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index bfdfcf3f..d100bfaf 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -61,6 +61,7 @@ def gradient_descent_init( "pbc": state.pbc, "atomic_numbers": state.atomic_numbers, "system_idx": state.system_idx, + "constraints": state.constraints, } if cell_filter is not None: # Create cell optimization state @@ -107,7 +108,8 @@ def gradient_descent_step( atom_lr = pos_lr[state.system_idx].unsqueeze(-1) # Update atomic positions - state.positions = state.positions + atom_lr * state.forces + # state.positions = state.positions + atom_lr * state.forces + state.set_positions(state.positions + atom_lr * state.forces) # Update cell if using cell optimization if isinstance(state, CellOptimState): @@ -117,7 +119,7 @@ def gradient_descent_step( # Get updated forces, energy, and stress model_output = model(state) - state.forces = model_output["forces"] + state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 2ab530db..b65455b7 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -7,7 +7,7 @@ from torch_sim.state import SimState -@dataclass(kw_only=True) +@dataclass(kw_only=True, init=False) class OptimState(SimState): """Unified state class for optimization algorithms. @@ -23,6 +23,41 @@ class OptimState(SimState): _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 _system_attributes = SimState._system_attributes | {"energy", "stress"} # noqa: SLF001 + def set_forces(self, new_forces: torch.Tensor) -> None: + """Set new forces in the optimization state.""" + if self.constraints is not None: + for constraint in self.constraints: + constraint.adjust_forces(self, new_forces) + self.forces = new_forces + + def __init__( + self, + *, + positions: torch.Tensor, + forces: torch.Tensor, + energy: torch.Tensor, + stress: torch.Tensor | None = None, + masses: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + atomic_numbers: torch.Tensor, + system_idx: torch.Tensor, + constraints: list | None = None, + ) -> None: + """Initialize optimization state.""" + super().__init__( + positions=positions, + masses=masses, + cell=cell, + pbc=pbc, + atomic_numbers=atomic_numbers, + system_idx=system_idx, + constraints=constraints, + ) + self.energy = energy + self.set_forces(forces) + self.stress = stress + @dataclass(kw_only=True) class FireState(OptimState): diff --git a/torch_sim/state.py b/torch_sim/state.py index 6cd7d38a..5fdff5ad 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -22,7 +22,8 @@ from ase import Atoms from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure -from torch_sim.constraints import FixConstraint + +from torch_sim.constraints import AtomIndexedConstraint, Constraint, SystemConstraint @dataclass(init=False) @@ -52,7 +53,7 @@ class SimState: atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. - constraints (list["FixConstraint"] | None): List of constraints applied to the + constraints (list["Constraint"] | None): List of constraints applied to the system. Constraints affect degrees of freedom and modify positions. Properties: @@ -86,7 +87,7 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor system_idx: torch.Tensor - constraints: list[FixConstraint] | None = field(default_factory=lambda: None) + constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 _atom_attributes: ClassVar[set[str]] = { "positions", @@ -105,7 +106,7 @@ def __init__( pbc: bool, # noqa: FBT001 atomic_numbers: torch.Tensor, system_idx: torch.Tensor | None = None, - constraints: list["FixConstraint"] | None = None, + constraints: list["Constraint"] | None = None, ) -> None: """Initialize the SimState and validate the arguments. @@ -118,7 +119,7 @@ def __init__( system_idx (torch.Tensor | None): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. If not provided, it is initialized to zeros. - constraints (list["FixConstraint"] | None): List of constraints applied to the + constraints (list["Constraint"] | None): List of constraints applied to the system. If None, no constraints are applied. """ self.positions = positions @@ -126,7 +127,7 @@ def __init__( self.cell = cell self.pbc = pbc self.atomic_numbers = atomic_numbers - self.constraints = constraints + self.constraints = constraints if constraints is not None else [] # Validate and process the state after initialization. # data validation and fill system_idx @@ -254,6 +255,23 @@ def set_positions(self, new_positions: torch.Tensor) -> None: constraint.adjust_positions(self, new_positions) self.positions = new_positions + def add_constraints(self, constraints: list[Constraint] | Constraint) -> None: + """Set the constraints for the SimState. + + Args: + constraints (list["Constraint"] | None): List of constraints to apply. + If None, no constraints are applied. + """ + # check it is a list + if isinstance(constraints, Constraint): + constraints = [constraints] + for constraint in constraints: + # if constraint.system_idx exists + if hasattr(constraint, "system_idx") and constraint.system_idx == slice(None): + constraint.system_idx = torch.arange(self.n_systems, device=self.device) + + self.constraints += constraints + def calc_dof(self) -> torch.Tensor: """Calculate degrees of freedom accounting for constraints. @@ -671,8 +689,16 @@ def _filter_attrs_by_mask( Returns: dict: Filtered attributes with appropriate handling for each scope """ + # atoms_mask = torch.isin(state.system_idx, torch.nonzero(system_mask).squeeze()) # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) + filtered_attrs["constraints"] = copy.deepcopy(filtered_attrs.get("constraints", [])) + + new_n_atoms_per_system = state.n_atoms_per_system[system_mask] + cum_sum_atoms = torch.cumsum(new_n_atoms_per_system, dim=0) + cum_sum_atoms = torch.cat( + (torch.tensor([0], device=cum_sum_atoms.device), cum_sum_atoms) + ) # Filter per-atom attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): @@ -695,6 +721,39 @@ def _filter_attrs_by_mask( dtype=attr_value.dtype, ) filtered_attrs[attr_name] = new_system_idxs + + # take into account constraints that are AtomIndexedConstraint + for constraint in filtered_attrs.get("constraints", []): + if isinstance(constraint, AtomIndexedConstraint): + constraint.indices = torch.tensor( + [ + i + - cum_sum_atoms[ + system_idx_map[ + old_system_indices[state.system_idx[i]].item() + ] + ] + for i in constraint.indices + if atom_mask[i] + ], + device=old_system_indices.device, + dtype=constraint.indices.dtype, + ) + elif isinstance(constraint, SystemConstraint) and isinstance( + constraint.system_idx, torch.Tensor + ): + # print(constraint.system_idx, system_mask) + # constraint.system_idx = constraint.system_idx[system_mask] + constraint.system_idx = torch.tensor( + [ + system_idx_map[idx.item()] + for idx in constraint.system_idx + if system_mask[idx] + ], + device=constraint.system_idx.device, + dtype=constraint.system_idx.dtype, + ) + else: filtered_attrs[attr_name] = attr_value[atom_mask] @@ -721,7 +780,7 @@ def _split_state[T: SimState](state: T) -> list[T]: list[SimState]: A list of SimState objects, each containing a single system """ - system_sizes = torch.bincount(state.system_idx).tolist() + system_sizes = state.n_atoms_per_system.tolist() split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): @@ -740,6 +799,7 @@ def _split_state[T: SimState](state: T) -> list[T]: # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) + n_atoms_cumsum = 0 for sys_idx in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system @@ -759,7 +819,30 @@ def _split_state[T: SimState](state: T) -> list[T]: # Add the global attributes **global_attrs, } + system_attrs["constraints"] = copy.deepcopy(system_attrs.get("constraints", [])) + for constraint in system_attrs.get("constraints", []): + if isinstance(constraint, SystemConstraint): + # Update system_mask to only include this system + constraint.system_idx = ( + torch.tensor([0], device=state.device, dtype=torch.int64) + if sys_idx in constraint.system_idx + else torch.tensor([], device=state.device, dtype=torch.int64) + ) + elif isinstance(constraint, AtomIndexedConstraint): + # Update atom_indices to only include atoms from this system + atom_start = n_atoms_cumsum + atom_end = n_atoms_cumsum + system_sizes[sys_idx] + constraint.indices = torch.tensor( + [ + idx - atom_start + for idx in constraint.indices + if atom_start <= idx < atom_end + ], + device=state.device, + dtype=torch.int64, + ) states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] + n_atoms_cumsum += system_sizes[sys_idx] return states @@ -848,7 +931,7 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor return type(state)(**filtered_attrs) # type: ignore[invalid-return-type] -def concatenate_states[T: SimState]( # noqa: C901 +def concatenate_states[T: SimState]( # noqa: C901, PLR0915 states: Sequence[T], device: torch.device | None = None ) -> T: """Concatenate a list of SimStates into a single SimState. @@ -885,12 +968,16 @@ def concatenate_states[T: SimState]( # noqa: C901 # Initialize result with global properties from first state concatenated = dict(get_attrs_for_scope(first_state, "global")) + del concatenated["constraints"] # will handle constraints separately # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 + n_atoms_offset = 0 + + constraints = {} # Process all states in a single pass for state in states: @@ -913,6 +1000,55 @@ def concatenate_states[T: SimState]( # noqa: C901 num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) + + if state.constraints is not None: + for constraint in state.constraints: + constraint_name = type(constraint).__name__ + if isinstance( + constraint, SystemConstraint + ) and constraint.system_idx == slice(None): + constraint.system_idx = torch.arange( + num_systems, device=target_device + ) + if constraint_name not in constraints: + # if it's IndexedConstraint then we need to adjust the indices + if isinstance(constraint, AtomIndexedConstraint): + new_constraint = copy.deepcopy(constraint) + new_constraint.indices = torch.empty( + 0, dtype=torch.long, device=target_device + ) + constraints[constraint_name] = new_constraint + elif isinstance(constraint, SystemConstraint): + new_constraint = copy.deepcopy(constraint) + new_constraint.system_idx = torch.empty( + 0, dtype=torch.long, device=target_device + ) + constraints[constraint_name] = new_constraint + else: + raise NotImplementedError( + f"Concatenation of constraint type " + f"{type(constraint)} is not implemented" + ) + # need to adjust the indices for IndexedConstraint + if isinstance(constraint, AtomIndexedConstraint): + new_constraint = constraints[constraint_name] + new_constraint.indices = torch.concat( + ( + new_constraint.indices, + constraint.indices + n_atoms_offset, + ) + ) + elif isinstance(constraint, SystemConstraint): + new_constraint = constraints[constraint_name] + new_constraint.system_idx = torch.concat( + ( + new_constraint.system_idx, + constraint.system_idx + system_offset, + ) + ) + constraints[constraint_name] = new_constraint + + n_atoms_offset += state.n_atoms system_offset += num_systems # Concatenate collected tensors @@ -930,8 +1066,9 @@ def concatenate_states[T: SimState]( # noqa: C901 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + constraints = list(constraints.values()) # Create a new instance of the same class - return state_class(**concatenated) + return state_class(**concatenated, constraints=constraints) def initialize_state( diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 1f905048..d13a615d 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1221,3 +1221,30 @@ def unwrap_positions( unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] return unwrapped + + +def get_centers_of_mass( + positions: torch.Tensor, + masses: torch.Tensor, + system_idx: torch.Tensor, + n_systems: int, +) -> torch.Tensor: + """Compute the centers of mass for each structure in the simulation state.s. + + Args: + positions (torch.Tensor): Atomic positions of shape (N, 3). + masses (torch.Tensor): Atomic masses of shape (N,). + system_idx (torch.Tensor): System indices for each atom of shape (N,). + n_systems (int): Total number of systems. + + Returns: + torch.Tensor: A tensor of shape (n_structures, 3) containing + the center of mass coordinates for each structure. + """ + coms = torch.zeros((n_systems, 3), dtype=positions.dtype).scatter_add_( + 0, + system_idx.unsqueeze(-1).expand(-1, 3), + masses.unsqueeze(-1) * positions, + ) + coms /= masses.unsqueeze(-1).sum(dim=0) + return coms From 7d6306916cb83f8c5a53ca8831687b234b3ef330 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Thu, 23 Oct 2025 10:37:12 +0200 Subject: [PATCH 07/43] test temperature, adapt calc_kt for reduced degrees of freedom --- tests/test_constraints.py | 61 ++++++++++++++++++++++----------------- torch_sim/quantities.py | 6 +++- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index ade7e6e8..2c7745e5 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -70,30 +70,36 @@ def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMo ) -def test_fix_com_nvt_langevin( - ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel -): +def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesModel): """Test FixCom constraint in NVT Langevin dynamics.""" - n_steps = 200 + n_steps = 1000 dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature - dofs_before = ar_double_sim_state.calc_dof() - ar_double_sim_state.constraints = [FixCom()] - assert torch.allclose(ar_double_sim_state.calc_dof(), dofs_before - 3) + dofs_before = cu_sim_state.calc_dof() + cu_sim_state.constraints = [FixCom()] + assert torch.allclose(cu_sim_state.calc_dof(), dofs_before - 3) - state = ts.nvt_langevin_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 - ) + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) positions = [] system_masses = torch.zeros((state.n_systems, 1), dtype=DTYPE).scatter_add_( 0, state.system_idx.unsqueeze(-1).expand(-1, 1), state.masses.unsqueeze(-1), ) + temperatures = [] for _step in range(n_steps): state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) positions.append(state.positions.clone()) + temp = ts.calc_kT( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + dof_per_system=state.calc_dof(), + ) + temperatures.append(temp / MetalUnits.temperature) + temperatures = torch.stack(temperatures) + traj_positions = torch.stack(positions) # unwrapped_positions = unwrap_positions( @@ -106,39 +112,42 @@ def test_fix_com_nvt_langevin( ) coms /= system_masses coms_drift = coms - coms[0] - assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-8) + assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-6) + assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 -def test_fix_atoms_nvt_langevin( - ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel -): +def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesModel): """Test FixAtoms constraint in NVT Langevin dynamics.""" - n_steps = 200 + n_steps = 1000 dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature - dofs_before = ar_double_sim_state.calc_dof() - ar_double_sim_state.constraints = [ - FixAtoms(indices=torch.tensor([0, 1], dtype=torch.long)) - ] - assert torch.allclose( - ar_double_sim_state.calc_dof(), dofs_before - torch.tensor([6, 0]) - ) - state = ts.nvt_langevin_init( - state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 - ) + dofs_before = cu_sim_state.calc_dof() + cu_sim_state.constraints = [FixAtoms(indices=torch.tensor([0, 1], dtype=torch.long))] + assert torch.allclose(cu_sim_state.calc_dof(), dofs_before - torch.tensor([6])) + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) positions = [] + temperatures = [] for _step in range(n_steps): state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) positions.append(state.positions.clone()) + temp = ts.calc_kT( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + dof_per_system=state.calc_dof(), + ) + temperatures.append(temp / MetalUnits.temperature) + temperatures = torch.stack(temperatures) traj_positions = torch.stack(positions) unwrapped_positions = unwrap_positions( - traj_positions, ar_double_sim_state.cell, state.system_idx + traj_positions, cu_sim_state.cell, state.system_idx ) diff_positions = unwrapped_positions - unwrapped_positions[0] assert torch.max(diff_positions[:, :2]) < 1e-8 assert torch.max(diff_positions[:, 2:]) > 1e-3 + assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index bcb824c0..a3953764 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -19,6 +19,7 @@ def calc_kT( # noqa: N802 momenta: torch.Tensor | None = None, velocities: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, + dof_per_system: torch.Tensor | None = None, ) -> torch.Tensor: """Calculate temperature in energy units from momenta/velocities and masses. @@ -28,6 +29,8 @@ def calc_kT( # noqa: N802 velocities (torch.Tensor | None): Particle velocities, shape (n_particles, n_dim) system_idx (torch.Tensor | None): Optional tensor indicating system membership of each particle + dof_per_system (torch.Tensor | None): Optional tensor indicating + degrees of freedom per system Returns: torch.Tensor: Scalar temperature value @@ -53,7 +56,8 @@ def calc_kT( # noqa: N802 # Count degrees of freedom per system system_sizes = torch.bincount(system_idx) - dof_per_system = system_sizes * squared_term.shape[-1] # multiply by n_dimensions + if dof_per_system is None: + dof_per_system = system_sizes * squared_term.shape[-1] # multiply by n_dimensions # Calculate temperature per system system_sums = torch.segment_reduce( From ad4fa0a367fb51bb231131ec0891d8c66e8316b3 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 9 Nov 2025 19:30:35 -0800 Subject: [PATCH 08/43] fix typo + unreleased changelog entry --- CHANGELOG.md | 10 ++++++++++ torch_sim/constraints.py | 3 +-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b7393f1..98ce9879 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,16 @@ # Changelog +## Unreleased + +### 🎉 New Features +* Constraints support for molecular dynamics and optimization by @thomasloux in [#294](https://github.com/TorchSim/torch-sim/pull/294) + - Added `FixAtoms` constraint to fix specific atoms in place + - Added `FixCom` constraint to prevent center of mass drift + - Constraints automatically adjust degrees of freedom for accurate temperature calculations + - Full support across all integrators (NVE, NVT, NPT) and optimizers (FIRE, Gradient Descent) + - Constraints preserved during state manipulation (slicing, splitting, concatenation) + ## v0.4.0 Thank you to everyone who contributed to this release! This release includes significant API improvements and breaking changes. @janosh led a major API redesign to improve usability. @stefanbringuier added heat flux calculations. @curtischong continued improving type safety across the codebase. @CompRhys, @orionarcher, @WillEngler, and @thomasloux all made valuable contributions. 🚀 diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 5431fd8c..4ebb839e 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -400,7 +400,6 @@ def count_degrees_of_freedom( return max(0, total_dof) # Ensure non-negative -# WIP def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: """Issue warnings if constraints might overlap in problematic ways. @@ -421,7 +420,7 @@ def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: # Check for overlapping atom indices if len(indexed_constraints) > 1: - all_indices = torch.cat([c.index for c in indexed_constraints]) + all_indices = torch.cat([c.indices for c in indexed_constraints]) unique_indices = torch.unique(all_indices) if len(unique_indices) < len(all_indices): warnings.warn( From 8beb9d93c4ca1b3e3b21045093f84d7c3c2b72fe Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 9 Nov 2025 19:51:56 -0800 Subject: [PATCH 09/43] renamed validate_constraints now called in SimState.add_constraints and checks atom indices exist in state if provided + that all constrained atoms belong to same system --- torch_sim/constraints.py | 48 ++++++++++++++++++++++++++++++++++------ torch_sim/state.py | 17 +++++++++++++- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 4ebb839e..d6564722 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -400,21 +400,55 @@ def count_degrees_of_freedom( return max(0, total_dof) # Ensure non-negative -def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: - """Issue warnings if constraints might overlap in problematic ways. +def validate_constraints( # noqa: C901 + constraints: list[Constraint], state: SimState | None = None +) -> None: + """Validate constraints for potential issues and incompatibilities. - This function checks for potential issues like multiple constraints - acting on the same atoms, which could lead to unexpected behavior. + This function checks for: + 1. Overlapping atom indices across multiple constraints + 2. AtomIndexedConstraints spanning multiple systems (requires state) + 3. Mixing FixCom with other constraints (warning only) Args: - constraints: List of constraints to check + constraints: List of constraints to validate + state: Optional SimState for validating atom indices belong to same system + + Raises: + ValueError: If constraints are invalid or span multiple systems + + Warns: + UserWarning: If constraints may lead to unexpected behavior """ + if not constraints: + return + indexed_constraints = [] has_com_constraint = False for constraint in constraints: if isinstance(constraint, AtomIndexedConstraint): indexed_constraints.append(constraint) + + # Validate that atom indices exist in state if provided + if state is not None and len(constraint.indices) > 0: + if constraint.indices.max() >= state.n_atoms: + raise ValueError( + f"Constraint {type(constraint).__name__} has indices up to " + f"{constraint.indices.max()}, but state only has {state.n_atoms} " + "atoms" + ) + + # Check that all constrained atoms belong to same system + constrained_system_indices = state.system_idx[constraint.indices] + unique_systems = torch.unique(constrained_system_indices) + if len(unique_systems) > 1: + raise ValueError( + f"Constraint {type(constraint).__name__} acts on atoms from " + f"multiple systems {unique_systems.tolist()}. Each constraint " + f"must operate within a single system." + ) + elif isinstance(constraint, FixCom): has_com_constraint = True @@ -427,7 +461,7 @@ def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: "Multiple constraints are acting on the same atoms. " "This may lead to unexpected behavior.", UserWarning, - stacklevel=2, + stacklevel=3, ) # Warn about COM constraint with fixed atoms @@ -437,5 +471,5 @@ def warn_if_overlapping_constraints(constraints: list[Constraint]) -> None: "unexpected behavior. The center of mass constraint is applied " "to all atoms, including those that may be constrained by other means.", UserWarning, - stacklevel=2, + stacklevel=3, ) diff --git a/torch_sim/state.py b/torch_sim/state.py index 57db8165..e47abf91 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -23,7 +23,12 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure -from torch_sim.constraints import AtomIndexedConstraint, Constraint, SystemConstraint +from torch_sim.constraints import ( + AtomIndexedConstraint, + Constraint, + SystemConstraint, + validate_constraints, +) @dataclass @@ -140,6 +145,9 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(initial_system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.constraints: + validate_constraints(self.constraints, state=self) + if self.cell.ndim != 3 and initial_system_idx is None: self.cell = self.cell.unsqueeze(0) @@ -249,6 +257,9 @@ def add_constraints(self, constraints: list[Constraint] | Constraint) -> None: Args: constraints (list["Constraint"] | None): List of constraints to apply. If None, no constraints are applied. + + Raises: + ValueError: If constraints are invalid or span multiple systems """ # check it is a list if isinstance(constraints, Constraint): @@ -258,6 +269,10 @@ def add_constraints(self, constraints: list[Constraint] | Constraint) -> None: if hasattr(constraint, "system_idx") and constraint.system_idx == slice(None): constraint.system_idx = torch.arange(self.n_systems, device=self.device) + # Validate new constraints before adding + all_constraints = self.constraints + constraints + validate_constraints(all_constraints, state=self) + self.constraints += constraints def get_number_of_degrees_of_freedom(self) -> torch.Tensor: From c577e1d503b953c7427d5dbaaaf7c87f000fd3a6 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 9 Nov 2025 19:52:39 -0800 Subject: [PATCH 10/43] tests for constraint validation warnings and errors - check that constraints work correctly with non-periodic boundaries and in batched states --- tests/test_constraints.py | 317 +++++++++++++++++++++++++++++++++++++- 1 file changed, 309 insertions(+), 8 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 2c7745e5..382fb9e8 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -5,7 +5,7 @@ import torch_sim as ts from tests.conftest import DTYPE -from torch_sim.constraints import FixAtoms, FixCom +from torch_sim.constraints import Constraint, FixAtoms, FixCom, validate_constraints from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.optimizers import FireFlavor @@ -20,15 +20,17 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMode ar_supercell_sim_state.set_positions(initial_positions + 0.5) assert torch.allclose(ar_supercell_sim_state.positions, initial_positions, atol=1e-8) - ar_supercell_mdstate = ts.nve_init( + ar_supercell_md_state = ts.nve_init( state=ar_supercell_sim_state, model=lj_model, kT=torch.tensor(10.0, dtype=DTYPE), seed=42, ) - ar_supercell_mdstate.set_momenta(torch.randn_like(ar_supercell_mdstate.momenta) * 0.1) + ar_supercell_md_state.set_momenta( + torch.randn_like(ar_supercell_md_state.momenta) * 0.1 + ) assert torch.allclose( - ar_supercell_mdstate.momenta.mean(dim=0), torch.zeros(3, dtype=DTYPE), atol=1e-8 + ar_supercell_md_state.momenta.mean(dim=0), torch.zeros(3, dtype=DTYPE), atol=1e-8 ) @@ -56,16 +58,18 @@ def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMo atol=1e-8, ) - ar_supercell_mdstate = ts.nve_init( + ar_supercell_md_state = ts.nve_init( state=ar_supercell_sim_state, model=lj_model, kT=torch.tensor(10.0, dtype=DTYPE), seed=42, ) - ar_supercell_mdstate.set_momenta(torch.randn_like(ar_supercell_mdstate.momenta) * 0.1) + ar_supercell_md_state.set_momenta( + torch.randn_like(ar_supercell_md_state.momenta) * 0.1 + ) assert torch.allclose( - ar_supercell_mdstate.momenta[indices_to_fix], - torch.zeros_like(ar_supercell_mdstate.momenta[indices_to_fix]), + ar_supercell_md_state.momenta[indices_to_fix], + torch.zeros_like(ar_supercell_md_state.momenta[indices_to_fix]), atol=1e-8, ) @@ -378,3 +382,300 @@ def test_fix_com_fire_optimization( ) assert torch.allclose(final_com, initial_com, atol=1e-4) + + +def test_fix_atoms_validation() -> None: + """Test FixAtoms construction and validation.""" + # Boolean mask conversion + mask = torch.zeros(10, dtype=torch.bool) + mask[:3] = True + assert torch.all(FixAtoms(indices=mask).indices == torch.tensor([0, 1, 2])) + + # Invalid indices + with pytest.raises(ValueError, match="Indices must be integers"): + FixAtoms(indices=torch.tensor([0.5, 1.5])) + with pytest.raises(ValueError, match="duplicates"): + FixAtoms(indices=torch.tensor([0, 1, 1])) + with pytest.raises(ValueError, match="wrong number of dimensions"): + FixAtoms(indices=torch.tensor([[0, 1]])) + + +def test_constraint_validation_warnings() -> None: + """Test validation warnings for constraint conflicts.""" + with pytest.warns(UserWarning, match="Multiple constraints.*same atoms"): + validate_constraints([FixAtoms(indices=[0, 1, 2]), FixAtoms(indices=[2, 3, 4])]) + with pytest.warns(UserWarning, match="FixCom together with other constraints"): + validate_constraints([FixCom(), FixAtoms(indices=[0, 1])]) + + +def test_constraint_validation_errors( + cu_sim_state: ts.SimState, + ar_double_sim_state: ts.SimState, + ar_supercell_sim_state: ts.SimState, +) -> None: + """Test validation errors for invalid constraints.""" + # Out of bounds + with pytest.raises(ValueError, match="has indices up to.*only has.*atoms"): + cu_sim_state.add_constraints(FixAtoms(indices=[0, 1, 100])) + + # Spanning multiple systems + with pytest.raises(ValueError, match="acts on atoms from multiple systems"): + ar_double_sim_state.add_constraints(FixAtoms(indices=[0, 32])) + + # Validation in __post_init__ + with pytest.raises(ValueError, match="duplicates"): + ts.SimState( + positions=ar_supercell_sim_state.positions.clone(), + masses=ar_supercell_sim_state.masses, + cell=ar_supercell_sim_state.cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers, + system_idx=ar_supercell_sim_state.system_idx, + constraints=[FixAtoms(indices=[0, 0, 1])], + ) + + +@pytest.mark.parametrize( + ("integrator", "constraint", "n_steps"), + [ + ("nve", FixAtoms(indices=[0, 1]), 100), + ("nvt_nose_hoover", FixCom(), 200), + ("npt_langevin", FixAtoms(indices=[0, 3]), 200), + ("npt_nose_hoover", FixCom(), 200), + ], +) +def test_integrators_with_constraints( + cu_sim_state: ts.SimState, + lj_model: LennardJonesModel, + integrator: str, + constraint: Constraint, + n_steps: int, +) -> None: + """Test all integrators respect constraints.""" + cu_sim_state.add_constraints(constraint) + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + + # Store initial state + if isinstance(constraint, FixAtoms): + initial = cu_sim_state.positions[constraint.indices].clone() + else: + initial = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + + # Run integration + if integrator == "nve": + state = ts.nve_init(cu_sim_state, lj_model, kT=kT, seed=42) + for _ in range(n_steps): + state = ts.nve_step(state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE)) + elif integrator == "nvt_nose_hoover": + state = ts.nvt_nose_hoover_init(cu_sim_state, lj_model, kT=kT) + for _ in range(n_steps): + state = ts.nvt_nose_hoover_step( + state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE), kT=kT + ) + elif integrator == "npt_langevin": + state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, seed=42) + for _ in range(n_steps): + state = ts.npt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE), + ) + else: # npt_nose_hoover + state = ts.npt_nose_hoover_init(cu_sim_state, lj_model, kT=kT) + for _ in range(n_steps): + state = ts.npt_nose_hoover_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE), + ) + + # Verify constraint held + if isinstance(constraint, FixAtoms): + assert torch.allclose(state.positions[constraint.indices], initial, atol=1e-6) + else: + final = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + assert torch.allclose(final, initial, atol=1e-5) + + +def test_multiple_constraints_and_dof( + cu_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test multiple constraints together with correct DOF calculation.""" + # Test DOF calculation + n = cu_sim_state.n_atoms + assert torch.all(cu_sim_state.calc_dof() == 3 * n) + cu_sim_state.add_constraints(FixAtoms(indices=[0])) + assert torch.all(cu_sim_state.calc_dof() == 3 * n - 3) + cu_sim_state.add_constraints(FixCom()) + assert torch.all(cu_sim_state.calc_dof() == 3 * n - 6) + + # Verify both constraints hold during dynamics + initial_pos = cu_sim_state.positions[0].clone() + initial_com = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + state = ts.nvt_langevin_init( + cu_sim_state, + lj_model, + kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, + seed=42, + ) + for _ in range(200): + state = ts.nvt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, + ) + assert torch.allclose(state.positions[0], initial_pos, atol=1e-6) + final_com = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + assert torch.allclose(final_com, initial_com, atol=1e-5) + + +@pytest.mark.parametrize( + ("cell_filter", "fire_flavor"), + [ + ("unit_cell", "ase_fire"), + ("frechet_cell", "ase_fire"), + ("frechet_cell", "vv_fire"), + ], +) +def test_cell_optimization_with_constraints( + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + cell_filter: str, + fire_flavor: FireFlavor, +) -> None: + """Test cell filters work with constraints.""" + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.05 + ) + ar_supercell_sim_state.add_constraints(FixAtoms(indices=[0, 1])) + state = ts.fire_init( + ar_supercell_sim_state, lj_model, cell_filter=cell_filter, fire_flavor=fire_flavor + ) + for _ in range(50): + state = ts.fire_step(state, lj_model, dt_max=0.1) + if state.forces.abs().max() < 0.05: + break + assert len(state.constraints) > 0 + + +def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: + """Test system-specific constraints in batched states.""" + s1, s2 = ar_double_sim_state.split() + s1.add_constraints(FixAtoms(indices=[0, 1])) + s2.add_constraints(FixCom()) + combined = ts.concatenate_states([s1, s2]) + assert len(combined.constraints) == 2 + assert isinstance(combined.constraints[0], FixAtoms) + assert torch.all(combined.constraints[0].indices == torch.tensor([0, 1])) + assert isinstance(combined.constraints[1], FixCom) + assert torch.all(combined.constraints[1].system_idx == torch.tensor([1])) + + +def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: + """Test constraints work with non-periodic boundaries.""" + state = ts.SimState( + positions=torch.tensor( + [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]], + dtype=DTYPE, + ), + masses=torch.ones(4, dtype=DTYPE) * 39.948, + cell=torch.eye(3, dtype=DTYPE) * 10.0, + pbc=False, + atomic_numbers=torch.full((4,), 18, dtype=torch.long), + system_idx=torch.zeros(4, dtype=torch.long), + ) + state.add_constraints(FixCom()) + initial = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + md_state = ts.nve_init(state, lj_model, kT=torch.tensor(100.0, dtype=DTYPE), seed=42) + for _ in range(100): + md_state = ts.nve_step(md_state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE)) + final = get_centers_of_mass( + md_state.positions, md_state.masses, md_state.system_idx, md_state.n_systems + ) + assert torch.allclose(final, initial, atol=1e-5) + + +def test_high_level_api_with_constraints( + cu_sim_state: ts.SimState, + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test high-level integrate() and optimize() APIs with constraints.""" + # Test integrate() + cu_sim_state.add_constraints(FixCom()) + initial_com = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + final = ts.integrate( + cu_sim_state, + lj_model, + integrator="nvt_langevin", + n_steps=100, + temperature=300.0, + timestep=0.001, + ) + final_com = get_centers_of_mass( + final.positions, final.masses, final.system_idx, final.n_systems + ) + assert torch.allclose(final_com, initial_com, atol=1e-5) + + # Test optimize() + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + ar_supercell_sim_state.add_constraints(FixAtoms(indices=[0, 1, 2])) + initial_pos = ar_supercell_sim_state.positions[[0, 1, 2]].clone() + final = ts.optimize(ar_supercell_sim_state, lj_model, optimizer="fire", max_steps=500) + assert torch.allclose(final.positions[[0, 1, 2]], initial_pos, atol=1e-5) + + +def test_temperature_with_constrained_dof( + cu_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test temperature calculation uses constrained DOF.""" + target = 300.0 + cu_sim_state.add_constraints([FixAtoms(indices=[0, 1]), FixCom()]) + state = ts.nvt_langevin_init( + cu_sim_state, + lj_model, + kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, + seed=42, + ) + temps = [] + for _ in range(1000): + state = ts.nvt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, + ) + temp = ts.calc_kT( + state.masses, state.momenta, state.system_idx, dof_per_system=state.calc_dof() + ) + temps.append(temp / MetalUnits.temperature) + avg = torch.mean(torch.stack(temps)[500:]) + assert abs(avg - target) / target < 0.30 From 9cfe52bfc60c7c2bacc27b334053b5ddf84ee494 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 10 Nov 2025 17:19:06 +0100 Subject: [PATCH 11/43] refactor to use getter setter and _constraints --- tests/test_constraints.py | 44 +++++++++++++++++++-------------------- torch_sim/state.py | 28 ++++++++++++++++++------- 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 382fb9e8..446fd170 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -15,7 +15,7 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): """Test adjustment of positions and momenta with FixCom constraint.""" - ar_supercell_sim_state.add_constraints([FixCom()]) + ar_supercell_sim_state.constraints = [FixCom()] initial_positions = ar_supercell_sim_state.positions.clone() ar_supercell_sim_state.set_positions(initial_positions + 0.5) assert torch.allclose(ar_supercell_sim_state.positions, initial_positions, atol=1e-8) @@ -37,7 +37,7 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMode def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): """Test adjustment of positions and momenta with FixAtoms constraint.""" indices_to_fix = torch.tensor([0, 5, 10], dtype=torch.long) - ar_supercell_sim_state.add_constraints([FixAtoms(indices=indices_to_fix)]) + ar_supercell_sim_state.constraints = [FixAtoms(indices=indices_to_fix)] initial_positions = ar_supercell_sim_state.positions.clone() # displacement = torch.randn_like(ar_supercell_sim_state.positions) * 0.5 displacement = 0.5 @@ -157,9 +157,7 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): """Test that constraints are properly propagated during state manipulation.""" # Set up constraints on the original state - ar_double_sim_state.add_constraints( - [FixAtoms(indices=torch.tensor([0, 1])), FixCom()] - ) + ar_double_sim_state.constraints = [FixAtoms(indices=torch.tensor([0, 1])), FixCom()] # Extract individual systems from the double system state first_system = ar_double_sim_state[0] @@ -196,7 +194,7 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): # Test constraint manipulation with different configurations ar_double_sim_state.constraints = [] - ar_double_sim_state.add_constraints([FixCom()]) + ar_double_sim_state.constraints = [FixCom()] isolated_system = ar_double_sim_state[0] assert torch.all( isolated_system.constraints[0].system_idx == torch.tensor([0], dtype=torch.long) @@ -224,7 +222,7 @@ def test_fix_com_gradient_descent_optimization( ar_supercell_sim_state.positions = perturbed_positions initial_state = ar_supercell_sim_state - ar_supercell_sim_state.add_constraints(FixCom()) + ar_supercell_sim_state.constraints = [FixCom()] initial_coms = get_centers_of_mass( positions=initial_state.positions, @@ -267,7 +265,7 @@ def test_fix_atoms_gradient_descent_optimization( ar_supercell_sim_state.positions = perturbed_positions initial_state = ar_supercell_sim_state - initial_state.add_constraints(FixAtoms(indices=[0])) + initial_state.constraints = [FixAtoms(indices=[0])] initial_position = initial_state.positions[0].clone() # Initialize Gradient Descent optimizer @@ -309,7 +307,7 @@ def test_test_atoms_fire_optimization( system_idx=ar_supercell_sim_state.system_idx.clone(), ) indices = torch.tensor([0, 2], dtype=torch.long) - current_sim_state.add_constraints(FixAtoms(indices=indices)) + current_sim_state.constraints = [FixAtoms(indices=indices)] # Initialize FIRE optimizer state = ts.fire_init( @@ -352,7 +350,7 @@ def test_fix_com_fire_optimization( atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), system_idx=ar_supercell_sim_state.system_idx.clone(), ) - current_sim_state.add_constraints(FixCom()) + current_sim_state.constraints = [FixCom()] # Initialize FIRE optimizer state = ts.fire_init( @@ -415,12 +413,12 @@ def test_constraint_validation_errors( ) -> None: """Test validation errors for invalid constraints.""" # Out of bounds - with pytest.raises(ValueError, match="has indices up to.*only has.*atoms"): - cu_sim_state.add_constraints(FixAtoms(indices=[0, 1, 100])) + with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"): + cu_sim_state.constraints = [FixAtoms(indices=[0, 1, 100])] # Spanning multiple systems with pytest.raises(ValueError, match="acts on atoms from multiple systems"): - ar_double_sim_state.add_constraints(FixAtoms(indices=[0, 32])) + ar_double_sim_state.constraints = [FixAtoms(indices=[0, 32])] # Validation in __post_init__ with pytest.raises(ValueError, match="duplicates"): @@ -452,7 +450,7 @@ def test_integrators_with_constraints( n_steps: int, ) -> None: """Test all integrators respect constraints.""" - cu_sim_state.add_constraints(constraint) + cu_sim_state.constraints = [constraint] kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature # Store initial state @@ -515,9 +513,9 @@ def test_multiple_constraints_and_dof( # Test DOF calculation n = cu_sim_state.n_atoms assert torch.all(cu_sim_state.calc_dof() == 3 * n) - cu_sim_state.add_constraints(FixAtoms(indices=[0])) + cu_sim_state.constraints = [FixAtoms(indices=[0])] assert torch.all(cu_sim_state.calc_dof() == 3 * n - 3) - cu_sim_state.add_constraints(FixCom()) + cu_sim_state.constraints = [FixCom()] assert torch.all(cu_sim_state.calc_dof() == 3 * n - 6) # Verify both constraints hold during dynamics @@ -566,7 +564,7 @@ def test_cell_optimization_with_constraints( ar_supercell_sim_state.positions += ( torch.randn_like(ar_supercell_sim_state.positions) * 0.05 ) - ar_supercell_sim_state.add_constraints(FixAtoms(indices=[0, 1])) + ar_supercell_sim_state.constraints = [FixAtoms(indices=[0, 1])] state = ts.fire_init( ar_supercell_sim_state, lj_model, cell_filter=cell_filter, fire_flavor=fire_flavor ) @@ -580,8 +578,8 @@ def test_cell_optimization_with_constraints( def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: """Test system-specific constraints in batched states.""" s1, s2 = ar_double_sim_state.split() - s1.add_constraints(FixAtoms(indices=[0, 1])) - s2.add_constraints(FixCom()) + s1.constraints = [FixAtoms(indices=[0, 1])] + s2.constraints = [FixCom()] combined = ts.concatenate_states([s1, s2]) assert len(combined.constraints) == 2 assert isinstance(combined.constraints[0], FixAtoms) @@ -603,7 +601,7 @@ def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: atomic_numbers=torch.full((4,), 18, dtype=torch.long), system_idx=torch.zeros(4, dtype=torch.long), ) - state.add_constraints(FixCom()) + state.constraints = [FixCom()] initial = get_centers_of_mass( state.positions, state.masses, state.system_idx, state.n_systems ) @@ -623,7 +621,7 @@ def test_high_level_api_with_constraints( ) -> None: """Test high-level integrate() and optimize() APIs with constraints.""" # Test integrate() - cu_sim_state.add_constraints(FixCom()) + cu_sim_state.constraints = [FixCom()] initial_com = get_centers_of_mass( cu_sim_state.positions, cu_sim_state.masses, @@ -647,7 +645,7 @@ def test_high_level_api_with_constraints( ar_supercell_sim_state.positions += ( torch.randn_like(ar_supercell_sim_state.positions) * 0.1 ) - ar_supercell_sim_state.add_constraints(FixAtoms(indices=[0, 1, 2])) + ar_supercell_sim_state.constraints = [FixAtoms(indices=[0, 1, 2])] initial_pos = ar_supercell_sim_state.positions[[0, 1, 2]].clone() final = ts.optimize(ar_supercell_sim_state, lj_model, optimizer="fire", max_steps=500) assert torch.allclose(final.positions[[0, 1, 2]], initial_pos, atol=1e-5) @@ -658,7 +656,7 @@ def test_temperature_with_constrained_dof( ) -> None: """Test temperature calculation uses constrained DOF.""" target = 300.0 - cu_sim_state.add_constraints([FixAtoms(indices=[0, 1]), FixCom()]) + cu_sim_state.constraints = [FixAtoms(indices=[0, 1]), FixCom()] state = ts.nvt_langevin_init( cu_sim_state, lj_model, diff --git a/torch_sim/state.py b/torch_sim/state.py index e47abf91..7cbc21e3 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -92,7 +92,7 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor system_idx: torch.Tensor | None = field(default=None) - constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 + _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 if TYPE_CHECKING: @@ -108,7 +108,7 @@ def system_idx(self) -> torch.Tensor: "system_idx", } _system_attributes: ClassVar[set[str]] = {"cell"} - _global_attributes: ClassVar[set[str]] = {"pbc", "constraints"} + _global_attributes: ClassVar[set[str]] = {"pbc"} def __post_init__(self) -> None: """Initialize the SimState and validate the arguments.""" @@ -251,7 +251,17 @@ def set_positions(self, new_positions: torch.Tensor) -> None: constraint.adjust_positions(self, new_positions) self.positions = new_positions - def add_constraints(self, constraints: list[Constraint] | Constraint) -> None: + @property + def constraints(self) -> list[Constraint]: + """Get the constraints for the SimState. + + Returns: + list["Constraint"]: List of constraints applied to the system. + """ + return self._constraints + + @constraints.setter + def constraints(self, constraints: list[Constraint] | Constraint) -> None: """Set the constraints for the SimState. Args: @@ -265,15 +275,17 @@ def add_constraints(self, constraints: list[Constraint] | Constraint) -> None: if isinstance(constraints, Constraint): constraints = [constraints] for constraint in constraints: - # if constraint.system_idx exists - if hasattr(constraint, "system_idx") and constraint.system_idx == slice(None): + if ( + isinstance(constraint, SystemConstraint) + and constraint.initialized is False + ): constraint.system_idx = torch.arange(self.n_systems, device=self.device) + constraint.initialized = True # Validate new constraints before adding - all_constraints = self.constraints + constraints - validate_constraints(all_constraints, state=self) + validate_constraints(constraints, state=self) - self.constraints += constraints + self._constraints = constraints def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate degrees of freedom accounting for constraints. From be30d45a169e52058fb52d181e1e582aa7334c68 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 10 Nov 2025 17:19:39 +0100 Subject: [PATCH 12/43] remove edge case slice(None) --- torch_sim/constraints.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index d6564722..d72696ee 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -159,9 +159,11 @@ def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: ValueError: If both indices and mask are provided, or if indices have wrong shape/type """ + self.initialized = True if system_idx is None: # Empty constraint - self.system_idx = slice(None) # All systems + self.system_idx = [] + self.initialized = True return # Convert to tensor if needed From 33d6025ee1ba721ff4e0fa9b81adf32725dc1242 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 10 Nov 2025 18:57:50 +0100 Subject: [PATCH 13/43] new API (remove slice(None) and _constraint as private var --- tests/test_constraints.py | 71 +++--- torch_sim/constraints.py | 33 ++- torch_sim/integrators/md.py | 25 +- torch_sim/integrators/npt.py | 31 +-- torch_sim/integrators/nve.py | 2 +- torch_sim/integrators/nvt.py | 17 +- torch_sim/models/einstein.py | 285 +++++++++++++++++++++++ torch_sim/optimizers/fire.py | 2 +- torch_sim/optimizers/gradient_descent.py | 2 +- torch_sim/optimizers/state.py | 24 +- torch_sim/runners.py | 1 - torch_sim/state.py | 123 +++++----- 12 files changed, 438 insertions(+), 178 deletions(-) create mode 100644 torch_sim/models/einstein.py diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 446fd170..22723c93 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -80,9 +80,11 @@ def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesM dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature - dofs_before = cu_sim_state.calc_dof() + dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() cu_sim_state.constraints = [FixCom()] - assert torch.allclose(cu_sim_state.calc_dof(), dofs_before - 3) + assert torch.allclose( + cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - 3 + ) state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) positions = [] @@ -99,7 +101,7 @@ def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesM masses=state.masses, momenta=state.momenta, system_idx=state.system_idx, - dof_per_system=state.calc_dof(), + dof_per_system=state.get_number_of_degrees_of_freedom(), ) temperatures.append(temp / MetalUnits.temperature) temperatures = torch.stack(temperatures) @@ -126,9 +128,11 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone dt = torch.tensor(0.001, dtype=DTYPE) kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature - dofs_before = cu_sim_state.calc_dof() + dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() cu_sim_state.constraints = [FixAtoms(indices=torch.tensor([0, 1], dtype=torch.long))] - assert torch.allclose(cu_sim_state.calc_dof(), dofs_before - torch.tensor([6])) + assert torch.allclose( + cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - torch.tensor([6]) + ) state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) positions = [] temperatures = [] @@ -139,7 +143,7 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone masses=state.masses, momenta=state.momenta, system_idx=state.system_idx, - dof_per_system=state.calc_dof(), + dof_per_system=state.get_number_of_degrees_of_freedom(), ) temperatures.append(temp / MetalUnits.temperature) temperatures = torch.stack(temperatures) @@ -188,9 +192,7 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): assert len(split_systems[0].constraints) == 2 assert torch.all(split_systems[0].constraints[0].indices == torch.tensor([0, 1])) assert torch.all(split_systems[1].constraints[0].indices == torch.tensor([0, 1])) - assert torch.all( - split_systems[2].constraints[0].indices == torch.tensor([], dtype=torch.long) - ) + assert len(split_systems[2].constraints) == 1 # Test constraint manipulation with different configurations ar_double_sim_state.constraints = [] @@ -408,7 +410,6 @@ def test_constraint_validation_warnings() -> None: def test_constraint_validation_errors( cu_sim_state: ts.SimState, - ar_double_sim_state: ts.SimState, ar_supercell_sim_state: ts.SimState, ) -> None: """Test validation errors for invalid constraints.""" @@ -416,10 +417,6 @@ def test_constraint_validation_errors( with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"): cu_sim_state.constraints = [FixAtoms(indices=[0, 1, 100])] - # Spanning multiple systems - with pytest.raises(ValueError, match="acts on atoms from multiple systems"): - ar_double_sim_state.constraints = [FixAtoms(indices=[0, 32])] - # Validation in __post_init__ with pytest.raises(ValueError, match="duplicates"): ts.SimState( @@ -429,7 +426,7 @@ def test_constraint_validation_errors( pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers, system_idx=ar_supercell_sim_state.system_idx, - constraints=[FixAtoms(indices=[0, 0, 1])], + _constraints=[FixAtoms(indices=[0, 0, 1])], ) @@ -452,6 +449,7 @@ def test_integrators_with_constraints( """Test all integrators respect constraints.""" cu_sim_state.constraints = [constraint] kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + dt = torch.tensor(0.001, dtype=DTYPE) # Store initial state if isinstance(constraint, FixAtoms): @@ -468,25 +466,23 @@ def test_integrators_with_constraints( if integrator == "nve": state = ts.nve_init(cu_sim_state, lj_model, kT=kT, seed=42) for _ in range(n_steps): - state = ts.nve_step(state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE)) + state = ts.nve_step(state, lj_model, dt=dt) elif integrator == "nvt_nose_hoover": - state = ts.nvt_nose_hoover_init(cu_sim_state, lj_model, kT=kT) + state = ts.nvt_nose_hoover_init(cu_sim_state, lj_model, kT=kT, dt=dt) for _ in range(n_steps): - state = ts.nvt_nose_hoover_step( - state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE), kT=kT - ) + state = ts.nvt_nose_hoover_step(state, lj_model, dt=dt, kT=kT) elif integrator == "npt_langevin": - state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, seed=42) + state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, seed=42, dt=dt) for _ in range(n_steps): state = ts.npt_langevin_step( state, lj_model, - dt=torch.tensor(0.001, dtype=DTYPE), + dt=dt, kT=kT, external_pressure=torch.tensor(0.0, dtype=DTYPE), ) else: # npt_nose_hoover - state = ts.npt_nose_hoover_init(cu_sim_state, lj_model, kT=kT) + state = ts.npt_nose_hoover_init(cu_sim_state, lj_model, kT=kT, dt=dt) for _ in range(n_steps): state = ts.npt_nose_hoover_step( state, @@ -512,11 +508,11 @@ def test_multiple_constraints_and_dof( """Test multiple constraints together with correct DOF calculation.""" # Test DOF calculation n = cu_sim_state.n_atoms - assert torch.all(cu_sim_state.calc_dof() == 3 * n) + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n) cu_sim_state.constraints = [FixAtoms(indices=[0])] - assert torch.all(cu_sim_state.calc_dof() == 3 * n - 3) - cu_sim_state.constraints = [FixCom()] - assert torch.all(cu_sim_state.calc_dof() == 3 * n - 6) + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 3) + cu_sim_state.constraints = [FixCom(), FixAtoms(indices=[0])] + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 6) # Verify both constraints hold during dynamics initial_pos = cu_sim_state.positions[0].clone() @@ -549,9 +545,9 @@ def test_multiple_constraints_and_dof( @pytest.mark.parametrize( ("cell_filter", "fire_flavor"), [ - ("unit_cell", "ase_fire"), - ("frechet_cell", "ase_fire"), - ("frechet_cell", "vv_fire"), + (ts.CellFilter.unit, "ase_fire"), + (ts.CellFilter.frechet, "ase_fire"), + (ts.CellFilter.frechet, "vv_fire"), ], ) def test_cell_optimization_with_constraints( @@ -596,7 +592,7 @@ def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: dtype=DTYPE, ), masses=torch.ones(4, dtype=DTYPE) * 39.948, - cell=torch.eye(3, dtype=DTYPE) * 10.0, + cell=torch.eye(3, dtype=DTYPE).unsqueeze(0) * 10.0, pbc=False, atomic_numbers=torch.full((4,), 18, dtype=torch.long), system_idx=torch.zeros(4, dtype=torch.long), @@ -631,11 +627,12 @@ def test_high_level_api_with_constraints( final = ts.integrate( cu_sim_state, lj_model, - integrator="nvt_langevin", - n_steps=100, + integrator=ts.Integrator.nvt_langevin, + n_steps=1, temperature=300.0, timestep=0.001, ) + final_com = get_centers_of_mass( final.positions, final.masses, final.system_idx, final.n_systems ) @@ -647,7 +644,9 @@ def test_high_level_api_with_constraints( ) ar_supercell_sim_state.constraints = [FixAtoms(indices=[0, 1, 2])] initial_pos = ar_supercell_sim_state.positions[[0, 1, 2]].clone() - final = ts.optimize(ar_supercell_sim_state, lj_model, optimizer="fire", max_steps=500) + final = ts.optimize( + ar_supercell_sim_state, lj_model, optimizer=ts.Optimizer.fire, max_steps=500 + ) assert torch.allclose(final.positions[[0, 1, 2]], initial_pos, atol=1e-5) @@ -671,9 +670,7 @@ def test_temperature_with_constrained_dof( dt=torch.tensor(0.001, dtype=DTYPE), kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, ) - temp = ts.calc_kT( - state.masses, state.momenta, state.system_idx, dof_per_system=state.calc_dof() - ) + temp = state.calc_kT() temps.append(temp / MetalUnits.temperature) avg = torch.mean(torch.stack(temps)[500:]) assert abs(avg - target) / target < 0.30 diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index d72696ee..1f40ab19 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -162,8 +162,8 @@ def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: self.initialized = True if system_idx is None: # Empty constraint - self.system_idx = [] - self.initialized = True + self.system_idx = torch.empty(0, dtype=torch.long) + self.initialized = False return # Convert to tensor if needed @@ -402,7 +402,7 @@ def count_degrees_of_freedom( return max(0, total_dof) # Ensure non-negative -def validate_constraints( # noqa: C901 +def validate_constraints( constraints: list[Constraint], state: SimState | None = None ) -> None: """Validate constraints for potential issues and incompatibilities. @@ -433,23 +433,16 @@ def validate_constraints( # noqa: C901 indexed_constraints.append(constraint) # Validate that atom indices exist in state if provided - if state is not None and len(constraint.indices) > 0: - if constraint.indices.max() >= state.n_atoms: - raise ValueError( - f"Constraint {type(constraint).__name__} has indices up to " - f"{constraint.indices.max()}, but state only has {state.n_atoms} " - "atoms" - ) - - # Check that all constrained atoms belong to same system - constrained_system_indices = state.system_idx[constraint.indices] - unique_systems = torch.unique(constrained_system_indices) - if len(unique_systems) > 1: - raise ValueError( - f"Constraint {type(constraint).__name__} acts on atoms from " - f"multiple systems {unique_systems.tolist()}. Each constraint " - f"must operate within a single system." - ) + if ( + (state is not None) + and (len(constraint.indices) > 0) + and (constraint.indices.max() >= state.n_atoms) + ): + raise ValueError( + f"Constraint {type(constraint).__name__} has indices up to " + f"{constraint.indices.max()}, but state only has {state.n_atoms} " + "atoms" + ) elif isinstance(constraint, FixCom): has_com_constraint = True diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index e83573c8..2de04ed5 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -6,7 +6,7 @@ import torch from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_temperature +from torch_sim.quantities import calc_kT from torch_sim.state import SimState from torch_sim.units import MetalUnits @@ -58,9 +58,8 @@ def velocities(self) -> torch.Tensor: def set_momenta(self, new_momenta: torch.Tensor) -> None: """Set new momenta, applying any constraints as needed.""" - if self.constraints is not None: - for constraint in self.constraints: - constraint.adjust_momenta(self, new_momenta) + for constraint in self.constraints: + constraint.adjust_momenta(self, new_momenta) self.momenta = new_momenta def calc_temperature( @@ -74,12 +73,19 @@ def calc_temperature( Returns: torch.Tensor: Calculated temperature """ - return calc_temperature( + return self.calc_kT() / units.temperature + + def calc_kT(self) -> torch.Tensor: # noqa: N802 + """Calculate kT from momenta, masses, and system indices. + + Returns: + torch.Tensor: Calculated kT in energy units + """ + return calc_kT( masses=self.masses, momenta=self.momenta, system_idx=self.system_idx, dof_per_system=self.get_number_of_degrees_of_freedom(), - units=units, ) @@ -181,13 +187,6 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_positions = state.positions + state.velocities * dt state.set_positions(new_positions) - - # if state.pbc: - # # Split positions and cells by system - # new_positions = transforms.pbc_wrap_batched( - # state.positions, state.cell, state.system_idx - # ) - # state.positions = new_positions # no constraints applied return state diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1ba8ed3d..e5337269 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -364,14 +364,7 @@ def _npt_langevin_position_step( ) # Update positions with all contributions - state.positions = c_1 + c_2.unsqueeze(-1) * c_3 - - # Apply periodic boundary conditions if needed - if state.pbc: - state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx - ) - + state.set_positions(c_1 + c_2.unsqueeze(-1) * c_3) return state @@ -435,7 +428,8 @@ def _npt_langevin_velocity_step( # Update momenta (velocities * masses) with all contributions new_velocities = c_1 + c_2 + c_3 - state.momenta = new_velocities * state.masses.unsqueeze(-1) + # Apply constraints. Is it correct to apply constraints here? + state.set_momenta(new_velocities * state.masses.unsqueeze(-1)) return state @@ -625,7 +619,7 @@ def npt_langevin_init( cell_velocities=cell_velocities, cell_masses=cell_masses, cell_alpha=cell_alpha, - constraints=state.constraints, + _constraints=state.constraints, ) @@ -1028,14 +1022,7 @@ def _npt_nose_hoover_exp_iL1( # noqa: N802 state.positions * (torch.exp(x_expanded) - 1) + dt * velocities * torch.exp(x_2_expanded) * sinh_expanded ) - new_positions = state.positions + new_positions - - # Apply periodic boundary conditions if needed - if state.pbc: - return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.system_idx - ) - return new_positions + return state.positions + new_positions def _npt_nose_hoover_exp_iL2( # noqa: N802 @@ -1245,7 +1232,7 @@ def _npt_nose_hoover_inner_step( # Update particle positions and forces positions = _npt_nose_hoover_exp_iL1(state, state.velocities, cell_velocities, dt) - state.positions = positions + state.set_positions(positions) state.cell = cell model_output = model(state) @@ -1266,8 +1253,8 @@ def _npt_nose_hoover_inner_step( cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) # Return updated state - state.positions = positions - state.momenta = momenta + state.set_positions(positions) + state.set_momenta(momenta) state.forces = model_output["forces"] state.energy = model_output["energy"] state.cell_position = cell_position @@ -1431,7 +1418,7 @@ def npt_nose_hoover_init( thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, - constraints=state.constraints, + _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 532add73..b4db4e6c 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -67,7 +67,7 @@ def nve_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, - constraints=state.constraints, + _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index ef52adc8..080e6d28 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -72,7 +72,7 @@ def _ou_step( c1.unsqueeze(-1) * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise ) - state.momenta = new_momenta + state.set_momenta(new_momenta) return state @@ -118,7 +118,6 @@ def nvt_langevin_init( "momenta", calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( positions=state.positions, momenta=momenta, @@ -129,13 +128,13 @@ def nvt_langevin_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, - constraints=state.constraints, + _constraints=state.constraints, ) def nvt_langevin_step( - model: ModelInterface, state: MDState, + model: ModelInterface, *, dt: float | torch.Tensor, kT: float | torch.Tensor, @@ -248,8 +247,8 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: def nvt_nose_hoover_init( - model: ModelInterface, state: SimState | StateDict, + model: ModelInterface, *, kT: torch.Tensor, dt: torch.Tensor, @@ -329,13 +328,13 @@ def nvt_nose_hoover_init( system_idx=state.system_idx, chain=chain_fns.initialize(dof_per_system, KE, kT), _chain_fns=chain_fns, # Store the chain functions - constraints=state.constraints, + _constraints=state.constraints, ) def nvt_nose_hoover_step( - model: ModelInterface, state: NVTNoseHooverState, + model: ModelInterface, *, dt: torch.Tensor, kT: torch.Tensor, @@ -374,7 +373,7 @@ def nvt_nose_hoover_step( # First half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) - state.momenta = momenta + state.set_momenta(momenta) # Full velocity Verlet step state = velocity_verlet(state=state, dt=dt, model=model) @@ -387,7 +386,7 @@ def nvt_nose_hoover_step( # Second half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) - state.momenta = momenta + state.set_momenta(momenta) state.chain = chain return state diff --git a/torch_sim/models/einstein.py b/torch_sim/models/einstein.py new file mode 100644 index 00000000..af452f38 --- /dev/null +++ b/torch_sim/models/einstein.py @@ -0,0 +1,285 @@ +"""Einstein model where each atom is treated as an independent 3D harmonic oscillator. + +Contrary to other models, the model energies depend on an absolute reference position, +so the model can only be used on systems that the model was initialized with. +As a analytical model, it can provide its Helmholtz free energy and can also generate +samples from the Boltzmann distribution at a given temperature. +""" + +import torch + +import torch_sim as ts +from torch_sim import SimState, units +from torch_sim.models.interface import ModelInterface + + +class EinsteinModel(ModelInterface): + """Einstein model where each atom is treated as an independent 3D harmonic oscillator. + Each atom has its own frequency. + + For this model: + E = sum_i 0.5 * k_i * (x_i - x0_i)^2 + F = -k_i * (x_i - x0_i) + k_i = m_i * omega_i^2 + + For best results, frequencies should be in the range of typical phonon frequencies. + They can be set for each atom type individually following energy balance from + a NVT simulation. From equipartition theorem: + = 3/2 k_B T + => omega = sqrt(3 k_B T / m ) + """ + + def __init__( + self, + equilibrium_position: torch.Tensor, # shape [N, 3] + frequencies: torch.Tensor, # shape [N] + system_idx: torch.Tensor | None = None, # shape [N] or None + masses: torch.Tensor | None = None, # shape [N] or None + reference_energy: float = 0.0, # reference energy value + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + compute_forces: bool = True, + compute_stress: bool = False, + ) -> None: + """Initialize the Einstein model. + + Args: + equilibrium_position: Tensor of shape [N, 3] with equilibrium positions. + frequencies: Tensor of shape [N] with frequencies for each atom + (same frequency in all 3 directions). + system_idx: Optional tensor of shape [N] with system indices for each atom. + If None, all atoms are assumed to belong to the same system. + masses: Optional tensor of shape [N] with masses for each atom. + If None, all masses are set to 1. + reference_energy: Reference energy value to add to the computed energy. + device: Device to use for the model (default: CPU). + dtype: Data type for the model (default: torch.float32). + compute_forces: Whether to compute forces in the model. + compute_stress: Whether to compute stress in the model. + + """ + super().__init__() + self._device = device or torch.device("cpu") + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + + equilibrium_position = torch.as_tensor( + equilibrium_position, device=self._device, dtype=self._dtype + ) + frequencies = torch.as_tensor( + frequencies, device=self._device, dtype=self._dtype + ) # [N, 3] + + if frequencies.shape[0] != equilibrium_position.shape[0]: + raise ValueError("frequencies shape must match equilibrium_position shape") + if frequencies.min() < 0: + raise ValueError("frequencies must be non-negative") + if frequencies.ndim == 0: + frequencies = frequencies.unsqueeze(0) + if frequencies.ndim != 1: + raise ValueError("frequencies must be a 1D tensor") + + if masses is None: + masses = torch.ones( + equilibrium_position.shape[0], dtype=self._dtype, device=self._device + ) + else: + masses = masses.to(self._device, self._dtype) + + if system_idx is not None: + system_idx = system_idx.to(self._device) + else: + system_idx = torch.zeros( + equilibrium_position.shape[0], dtype=torch.long, device=self._device + ) + + self.register_buffer("system_idx", system_idx.to(self._device)) + self.register_buffer("masses", masses) # [N] + self.register_buffer("x0", equilibrium_position) # [N, 3] + self.register_buffer("frequencies", frequencies) # [N] + self.register_buffer( + "reference_energy", + torch.tensor(reference_energy, dtype=self._dtype, device=self._device), + ) + + @classmethod + def from_atom_and_frequencies( + cls, + atom: SimState, + frequencies: torch.Tensor | float, + *, + reference_energy: float = 0.0, + compute_forces: bool = True, + compute_stress: bool = False, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + ) -> "EinsteinModel": + """Create an EinsteinModel from an ASE Atoms object and frequencies. + + Args: + atom: ASE Atoms object containing the reference structure. + frequencies: Tensor of shape [N] with frequencies for each atom + (same frequency in all 3 directions) or a scalar. + reference_energy: Reference energy value. + compute_forces: Whether to compute forces in the model. + compute_stress: Whether to compute stress in the model. + device: Device to use for the model (default: CPU). + dtype: Data type for the model (default: torch.float32). + + Returns: + EinsteinModel: An instance of the EinsteinModel. + """ + # Get equilibrium positions from the atoms object + equilibrium_position = atom.positions.clone().to(dtype=dtype, device=device) + + frequencies = torch.as_tensor(frequencies, dtype=dtype, device=device) + if frequencies.ndim == 0: + frequencies = frequencies.repeat(atom.positions.shape[0]) + if frequencies.shape[0] != atom.positions.shape[0]: + raise ValueError( + "frequencies must be a scalar or a tensor of shape [N] " + "where N is the number of atoms" + ) + + # Create and return an instance of EinsteinModel + return cls( + equilibrium_position=equilibrium_position, + frequencies=frequencies, + masses=atom.masses, + system_idx=atom.system_idx, + reference_energy=reference_energy, + compute_forces=compute_forces, + compute_stress=compute_stress, + device=device, + dtype=dtype, + ) + + def get_spring_constants(self) -> torch.Tensor: + """Get the spring constants for each atom in the Einstein model. + + Returns: + Tensor of shape [N] with spring constants k_i = m_i * omega_i^2 + for each atom. + """ + return self.masses * (self.frequencies**2) # [N] + + def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: + """Calculate energies and forces for the Einstein model. + + Args: + state: SimState or StateDict containing positions, cell, etc. + + Returns: + Dictionary containing energy, forces + """ + pos = state.positions.to(self._dtype) # [N, 3] + cell = state.cell.to(self._dtype) + + if cell.ndim == 2: + cell = cell.unsqueeze(0) # [1, 3, 3] + + # Get model parameters + x0 = torch.as_tensor(self.x0, dtype=self._dtype, device=self._device) + frequencies = torch.as_tensor( + self.frequencies, dtype=self._dtype, device=self._device + ) + masses = torch.as_tensor(self.masses, dtype=self._dtype, device=self._device) + + # Calculate displacements using periodic boundary conditions + if cell.shape[0] == 1: + disp = ts.transforms.minimum_image_displacement( + dr=pos - x0, cell=cell[0], pbc=state.pbc + ) + else: + disp = ts.transforms.minimum_image_displacement_batched( + pos - x0, cell, system_idx=state.system_idx, pbc=state.pbc + ) + + # Spring constants: k = m * omega^2 + spring_constants = masses * (frequencies**2) # [N] + + # Energy: E = 0.5 * k * x^2 + energies_per_mode = 0.5 * spring_constants * ((disp**2).sum(dim=1)) # [N] + total_energy = torch.zeros( + state.n_systems, dtype=self._dtype, device=self._device + ) + total_energy.scatter_add_(0, state.system_idx, energies_per_mode) + total_energy += self.reference_energy + + # Forces: F = -k * x + forces = -spring_constants.unsqueeze(-1) * disp # [N, 3] + + results = { + "energy": total_energy, + "forces": forces, + } + # Stress is not implemented for this model + if self._compute_stress: + results["stress"] = torch.zeros( + (state.n_systems, 3, 3), dtype=self._dtype, device=self._device + ) + + return results + + def get_free_energy(self, temperature: float) -> dict[str, torch.Tensor]: + """Compute free energy at a given temperature using Einstein model. + + Args: + temperature: Temperature in Kelvin. + + Returns: + Dictionary containing heat capacity, entropy, and free energy. + """ + # Boltzmann constant in eV/K + kB = units.BaseConstant.k_B / units.UnitConversion.eV_to_J + T = temperature + # Reduced Planck constant in eV*s + hbar = units.BaseConstant.h_planck / (2 * units.pi * units.UnitConversion.eV_to_J) + + frequencies_tensor = ( + torch.as_tensor(self.frequencies).clone() + * torch.as_tensor( + units.UnitConversion.eV_to_J / units.BaseConstant.amu + ).sqrt() + / units.UnitConversion.Ang_to_met + ) # Convert to rad/s + free_energy_per_atom = ( + -3 * kB * T * torch.log(kB * T / (hbar * frequencies_tensor)) + ) + + n_systems = self.system_idx.max().item() + 1 + free_energy_per_system = torch.zeros( + n_systems, dtype=self._dtype, device=self._device + ) + free_energy_per_system.scatter_add_(0, self.system_idx, free_energy_per_atom) + + return {"free_energy": free_energy_per_system} + + def sample(self, state: SimState, temperature: float) -> SimState: + """Generate samples from the Einstein model at a given temperature. + + Args: + state: Initial simulation state to sample from. + temperature: Temperature in Kelvin. + + Returns: + SimState containing sampled positions and velocities. + + The Boltzmann distribution for a harmonic oscillator leads to Gaussian + distributions + for both positions and velocities. + """ + N = self.x0.shape[0] + kB = units.BaseConstant.k_B / units.UnitConversion.eV_to_J + beta = 1.0 / (kB * temperature) # Inverse temperature in 1/eV + + # Sample positions from a normal distribution around equilibrium positions + stddev = torch.sqrt(1.0 / (self.masses * (self.frequencies**2) * beta)).unsqueeze( + -1 + ) + sampled_positions = self.x0 + torch.randn(N, 3, device=self._device) * stddev + state = state.clone() + state.positions = sampled_positions + return state diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 110c6bbc..d1c2454d 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -80,7 +80,7 @@ def fire_init( "cell": state.cell.clone(), "atomic_numbers": state.atomic_numbers.clone(), "system_idx": state.system_idx.clone(), - "constraints": state.constraints, + "_constraints": state.constraints, "pbc": state.pbc, # Optimization state "forces": forces, diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index d100bfaf..246aee1a 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -61,7 +61,7 @@ def gradient_descent_init( "pbc": state.pbc, "atomic_numbers": state.atomic_numbers, "system_idx": state.system_idx, - "constraints": state.constraints, + "_constraints": state.constraints, } if cell_filter is not None: # Create cell optimization state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index b65455b7..8f8885d0 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -25,9 +25,8 @@ class OptimState(SimState): def set_forces(self, new_forces: torch.Tensor) -> None: """Set new forces in the optimization state.""" - if self.constraints is not None: - for constraint in self.constraints: - constraint.adjust_forces(self, new_forces) + for constraint in self.constraints: + constraint.adjust_forces(self, new_forces) self.forces = new_forces def __init__( @@ -42,21 +41,20 @@ def __init__( pbc: torch.Tensor, atomic_numbers: torch.Tensor, system_idx: torch.Tensor, - constraints: list | None = None, + _constraints: list | None = None, ) -> None: """Initialize optimization state.""" - super().__init__( - positions=positions, - masses=masses, - cell=cell, - pbc=pbc, - atomic_numbers=atomic_numbers, - system_idx=system_idx, - constraints=constraints, - ) + self.positions = positions + self.masses = masses + self.cell = cell + self.pbc = pbc + self.atomic_numbers = atomic_numbers + self.system_idx = system_idx + self._constraints = _constraints self.energy = energy self.set_forces(forces) self.stress = stress + SimState.__post_init__(self) @dataclass(kw_only=True) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b2059aac..43c33db3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -165,7 +165,6 @@ def integrate[T: SimState]( # noqa: C901 f"integrator must be key from Integrator or a tuple of " f"(init_func, step_func), got {type(integrator)}" ) - # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator( initial_state, model, autobatcher=autobatcher diff --git a/torch_sim/state.py b/torch_sim/state.py index 7cbc21e3..d6228f3a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -246,9 +246,8 @@ def set_positions(self, new_positions: torch.Tensor) -> None: new_positions: New positions tensor with shape (n_atoms, 3) """ # Apply constraints if they exist - if self.constraints is not None: - for constraint in self.constraints: - constraint.adjust_positions(self, new_positions) + for constraint in self.constraints: + constraint.adjust_positions(self, new_positions) self.positions = new_positions @property @@ -275,9 +274,8 @@ def constraints(self, constraints: list[Constraint] | Constraint) -> None: if isinstance(constraints, Constraint): constraints = [constraints] for constraint in constraints: - if ( - isinstance(constraint, SystemConstraint) - and constraint.initialized is False + if (isinstance(constraint, SystemConstraint)) and ( + not constraint.initialized ): constraint.system_idx = torch.arange(self.n_systems, device=self.device) constraint.initialized = True @@ -322,6 +320,7 @@ def clone(self) -> Self: attrs[attr_name] = attr_value.clone() else: attrs[attr_name] = copy.deepcopy(attr_value) + attrs["_constraints"] = copy.deepcopy(self.constraints) return type(self)(**attrs) @@ -671,7 +670,7 @@ def _state_to_device[T: SimState]( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) # type: ignore[invalid-return-type] + return state def get_attrs_for_scope( @@ -722,7 +721,7 @@ def _filter_attrs_by_mask( # atoms_mask = torch.isin(state.system_idx, torch.nonzero(system_mask).squeeze()) # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) - filtered_attrs["constraints"] = copy.deepcopy(filtered_attrs.get("constraints", [])) + filtered_attrs["_constraints"] = copy.deepcopy(state.constraints) new_n_atoms_per_system = state.n_atoms_per_system[system_mask] cum_sum_atoms = torch.cumsum(new_n_atoms_per_system, dim=0) @@ -753,7 +752,7 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = new_system_idxs # take into account constraints that are AtomIndexedConstraint - for constraint in filtered_attrs.get("constraints", []): + for constraint in filtered_attrs.get("_constraints", []): if isinstance(constraint, AtomIndexedConstraint): constraint.indices = torch.tensor( [ @@ -769,11 +768,7 @@ def _filter_attrs_by_mask( device=old_system_indices.device, dtype=constraint.indices.dtype, ) - elif isinstance(constraint, SystemConstraint) and isinstance( - constraint.system_idx, torch.Tensor - ): - # print(constraint.system_idx, system_mask) - # constraint.system_idx = constraint.system_idx[system_mask] + elif isinstance(constraint, SystemConstraint): constraint.system_idx = torch.tensor( [ system_idx_map[idx.item()] @@ -849,8 +844,8 @@ def _split_state[T: SimState](state: T) -> list[T]: # Add the global attributes **global_attrs, } - system_attrs["constraints"] = copy.deepcopy(system_attrs.get("constraints", [])) - for constraint in system_attrs.get("constraints", []): + new_constraints = copy.deepcopy(state.constraints) + for constraint in new_constraints: if isinstance(constraint, SystemConstraint): # Update system_mask to only include this system constraint.system_idx = ( @@ -871,6 +866,22 @@ def _split_state[T: SimState](state: T) -> list[T]: device=state.device, dtype=torch.int64, ) + # Remove empty constraints + new_constraints = [ + constraint + for constraint in new_constraints + if ( + ( + isinstance(constraint, SystemConstraint) + and len(constraint.system_idx) > 0 + ) + or ( + isinstance(constraint, AtomIndexedConstraint) + and len(constraint.indices) > 0 + ) + ) + ] + system_attrs["_constraints"] = new_constraints states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] n_atoms_cumsum += system_sizes[sys_idx] @@ -998,7 +1009,6 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Initialize result with global properties from first state concatenated = dict(get_attrs_for_scope(first_state, "global")) - del concatenated["constraints"] # will handle constraints separately # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) @@ -1031,52 +1041,45 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) - if state.constraints is not None: - for constraint in state.constraints: - constraint_name = type(constraint).__name__ - if isinstance( - constraint, SystemConstraint - ) and constraint.system_idx == slice(None): - constraint.system_idx = torch.arange( - num_systems, device=target_device - ) - if constraint_name not in constraints: - # if it's IndexedConstraint then we need to adjust the indices - if isinstance(constraint, AtomIndexedConstraint): - new_constraint = copy.deepcopy(constraint) - new_constraint.indices = torch.empty( - 0, dtype=torch.long, device=target_device - ) - constraints[constraint_name] = new_constraint - elif isinstance(constraint, SystemConstraint): - new_constraint = copy.deepcopy(constraint) - new_constraint.system_idx = torch.empty( - 0, dtype=torch.long, device=target_device - ) - constraints[constraint_name] = new_constraint - else: - raise NotImplementedError( - f"Concatenation of constraint type " - f"{type(constraint)} is not implemented" - ) - # need to adjust the indices for IndexedConstraint + for constraint in state.constraints: + constraint_name = type(constraint).__name__ + if constraint_name not in constraints: + # if it's IndexedConstraint then we need to adjust the indices if isinstance(constraint, AtomIndexedConstraint): - new_constraint = constraints[constraint_name] - new_constraint.indices = torch.concat( - ( - new_constraint.indices, - constraint.indices + n_atoms_offset, - ) + new_constraint = copy.deepcopy(constraint) + new_constraint.indices = torch.empty( + 0, dtype=torch.long, device=target_device ) + constraints[constraint_name] = new_constraint elif isinstance(constraint, SystemConstraint): - new_constraint = constraints[constraint_name] - new_constraint.system_idx = torch.concat( - ( - new_constraint.system_idx, - constraint.system_idx + system_offset, - ) + new_constraint = copy.deepcopy(constraint) + new_constraint.system_idx = torch.empty( + 0, dtype=torch.long, device=target_device + ) + constraints[constraint_name] = new_constraint + else: + raise NotImplementedError( + f"Concatenation of constraint type " + f"{type(constraint)} is not implemented" + ) + # need to adjust the indices for IndexedConstraint + if isinstance(constraint, AtomIndexedConstraint): + new_constraint = constraints[constraint_name] + new_constraint.indices = torch.concat( + ( + new_constraint.indices, + constraint.indices + n_atoms_offset, ) - constraints[constraint_name] = new_constraint + ) + elif isinstance(constraint, SystemConstraint): + new_constraint = constraints[constraint_name] + new_constraint.system_idx = torch.concat( + ( + new_constraint.system_idx, + constraint.system_idx + system_offset, + ) + ) + constraints[constraint_name] = new_constraint n_atoms_offset += state.n_atoms system_offset += num_systems @@ -1098,7 +1101,7 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 constraints = list(constraints.values()) # Create a new instance of the same class - return state_class(**concatenated, constraints=constraints) + return state_class(**concatenated, _constraints=constraints) def initialize_state( From 399fbfdf2adef97ebed26592bb589b7261591de4 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 10 Nov 2025 18:58:15 +0100 Subject: [PATCH 14/43] correct get_centers_of_mass --- torch_sim/transforms.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index b14be48b..48bb550f 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1248,5 +1248,8 @@ def get_centers_of_mass( system_idx.unsqueeze(-1).expand(-1, 3), masses.unsqueeze(-1) * positions, ) - coms /= masses.unsqueeze(-1).sum(dim=0) + system_masses = torch.zeros((n_systems,), dtype=positions.dtype).scatter_add_( + 0, system_idx, masses + ) + coms /= system_masses.unsqueeze(-1) return coms From b31ba8064ad2e1ddac2fda12fa15193194671e9e Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 10 Nov 2025 19:07:53 +0100 Subject: [PATCH 15/43] add warnings for npt dynamics --- torch_sim/integrators/npt.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index e5337269..bad47d78 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1,5 +1,6 @@ """Implementations of NPT integrators.""" +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -600,6 +601,16 @@ def npt_langevin_init( ) cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau + if state.constraints: + # warn if constraints are present + warnings.warn( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Langevin dynamics." + "We recommend not using constraints with NPT dynamics for now.", + UserWarning, + stacklevel=3, + ) + # Create the initial state return NPTLangevinState( positions=state.positions, @@ -1399,6 +1410,16 @@ def npt_nose_hoover_init( forces = model_output["forces"] energy = model_output["energy"] + if state.constraints: + # warn if constraints are present + warnings.warn( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Nosé Hoover dynamics." + "We recommend not using constraints with NPT dynamics for now.", + UserWarning, + stacklevel=3, + ) + # Create initial state return NPTNoseHooverState( positions=state.positions, From 14839775202c7dc67b1dd3e6dc5dc202115a2e21 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 18 Nov 2025 16:56:18 -0500 Subject: [PATCH 16/43] simplify state updating in _filter_attrs_by_mask --- torch_sim/constraints.py | 71 +++++++++++++++++++++++++++++++++++++++- torch_sim/state.py | 38 +++++---------------- 2 files changed, 78 insertions(+), 31 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 1f40ab19..12c7dfd1 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -178,6 +178,71 @@ def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: ) +def update_constraint( + constraint: AtomIndexedConstraint | SystemConstraint, + atom_mask: torch.Tensor, + system_mask: torch.Tensor, +) -> Constraint: + """Update a constraint to account for atom and system masks. + + Args: + constraint: Constraint to update + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + # n_atoms_per_system = system_idx.bincount() + + def update_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + cumsum_atom_mask = torch.cumsum(~mask, dim=0) + new_indices = idx - cumsum_atom_mask[idx] + mask_indices = torch.where(mask)[0] + drop_indices = ~torch.isin(idx, mask_indices) + return new_indices[~drop_indices] + + if isinstance(constraint, AtomIndexedConstraint): + constraint.indices = update_indices(constraint.indices, atom_mask) + + elif isinstance(constraint, SystemConstraint): + constraint.system_idx = update_indices(constraint.system_idx, system_mask) + + else: + raise NotImplementedError( + f"Constraint type {type(constraint)} is not implemented" + ) + + return constraint + + +# def split_constraint( +# constraint: AtomIndexedConstraint | SystemConstraint, +# system_idx: torch.Tensor, +# ) -> list[AtomIndexedConstraint | SystemConstraint]: +# """Split a constraint into a list of constraints.""" +# n_atoms_per_system = system_idx.bincount() + +# # just atom indexed for now +# for sys_idx in range(max(system_idx) + 1): + + +# return [constraint] + + +# def merge_constraint_into_list( +# constraints: list[AtomIndexedConstraint | SystemConstraint], +# constraint: AtomIndexedConstraint | SystemConstraint, +# ) -> list[Constraint]: +# """Merge a constraint into a list of constraints. + +# Args: +# constraints: List of constraints +# constraint: Constraint to merge + +# Returns: +# List of constraints +# """ +# return constraints + + class FixAtoms(AtomIndexedConstraint): """Constraint that fixes specified atoms in place. @@ -220,7 +285,11 @@ def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None """ new_positions[self.indices] = state.positions[self.indices] - def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: # noqa: ARG002 + def adjust_forces( + self, + state: SimState, # noqa: ARG002 + forces: torch.Tensor, + ) -> None: """Set forces on fixed atoms to zero. Args: diff --git a/torch_sim/state.py b/torch_sim/state.py index d6228f3a..175984a1 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -15,6 +15,7 @@ import torch import torch_sim as ts +from torch_sim.constraints import update_constraint from torch_sim.typing import StateLike @@ -184,7 +185,7 @@ def n_atoms(self) -> int: @property def n_atoms_per_system(self) -> torch.Tensor: - """Number of atoms per system.""" + """Number of atoms per system. Length is n_systems.""" return ( self.system_idx.bincount() if self.system_idx is not None @@ -721,7 +722,12 @@ def _filter_attrs_by_mask( # atoms_mask = torch.isin(state.system_idx, torch.nonzero(system_mask).squeeze()) # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) - filtered_attrs["_constraints"] = copy.deepcopy(state.constraints) + + # take into account constraints that are AtomIndexedConstraint + filtered_attrs["_constraints"] = [ + update_constraint(constraint, atom_mask, system_mask) + for constraint in copy.deepcopy(state.constraints) + ] new_n_atoms_per_system = state.n_atoms_per_system[system_mask] cum_sum_atoms = torch.cumsum(new_n_atoms_per_system, dim=0) @@ -751,34 +757,6 @@ def _filter_attrs_by_mask( ) filtered_attrs[attr_name] = new_system_idxs - # take into account constraints that are AtomIndexedConstraint - for constraint in filtered_attrs.get("_constraints", []): - if isinstance(constraint, AtomIndexedConstraint): - constraint.indices = torch.tensor( - [ - i - - cum_sum_atoms[ - system_idx_map[ - old_system_indices[state.system_idx[i]].item() - ] - ] - for i in constraint.indices - if atom_mask[i] - ], - device=old_system_indices.device, - dtype=constraint.indices.dtype, - ) - elif isinstance(constraint, SystemConstraint): - constraint.system_idx = torch.tensor( - [ - system_idx_map[idx.item()] - for idx in constraint.system_idx - if system_mask[idx] - ], - device=constraint.system_idx.device, - dtype=constraint.system_idx.dtype, - ) - else: filtered_attrs[attr_name] = attr_value[atom_mask] From 3c267eb642f1bf4b8024eb32c2919aac0cb6df43 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 19 Nov 2025 13:36:09 -0500 Subject: [PATCH 17/43] simplify _split_state with select_sub_constraint function --- torch_sim/constraints.py | 43 +++++++++++++++++++++++++++--------- torch_sim/state.py | 47 ++++++++-------------------------------- 2 files changed, 42 insertions(+), 48 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 12c7dfd1..b32a5def 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -213,18 +213,41 @@ def update_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return constraint -# def split_constraint( -# constraint: AtomIndexedConstraint | SystemConstraint, -# system_idx: torch.Tensor, -# ) -> list[AtomIndexedConstraint | SystemConstraint]: -# """Split a constraint into a list of constraints.""" -# n_atoms_per_system = system_idx.bincount() - -# # just atom indexed for now -# for sys_idx in range(max(system_idx) + 1): +def select_sub_constraint( + constraint: AtomIndexedConstraint | SystemConstraint, + atom_idx: torch.Tensor, + sys_idx: int, +) -> AtomIndexedConstraint | SystemConstraint | None: + """Select a constraint for a given atom and system index. + Args: + constraint: Constraint to select + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + """Split a constraint into a list of constraints.""" -# return [constraint] + # TODO: finish this + # TODO: we can probably eliminate the for loop and make this a split + # out constraint function that just bumps out a single constraint + # then embed this function in the split_state function + if isinstance(constraint, AtomIndexedConstraint): + mask = torch.isin(constraint.indices, atom_idx) + masked_indices = constraint.indices[mask] + new_atom_idx = masked_indices - atom_idx.min() + if len(new_atom_idx) == 0: + return None + return type(constraint)(new_atom_idx) + + if isinstance(constraint, SystemConstraint): + mask = torch.isin(constraint.system_idx, sys_idx) + masked_system_idx = constraint.system_idx[mask] + new_system_idx = masked_system_idx - sys_idx + if len(new_system_idx) == 0: + return None + return type(constraint)(new_system_idx) + + raise NotImplementedError(f"Constraint type {type(constraint)} is not implemented") # def merge_constraint_into_list( diff --git a/torch_sim/state.py b/torch_sim/state.py index 175984a1..ff21b25d 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -15,7 +15,7 @@ import torch import torch_sim as ts -from torch_sim.constraints import update_constraint +from torch_sim.constraints import select_sub_constraint, update_constraint from torch_sim.typing import StateLike @@ -802,7 +802,8 @@ def _split_state[T: SimState](state: T) -> list[T]: # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) - n_atoms_cumsum = 0 + zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) + cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) for sys_idx in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system @@ -822,46 +823,16 @@ def _split_state[T: SimState](state: T) -> list[T]: # Add the global attributes **global_attrs, } - new_constraints = copy.deepcopy(state.constraints) - for constraint in new_constraints: - if isinstance(constraint, SystemConstraint): - # Update system_mask to only include this system - constraint.system_idx = ( - torch.tensor([0], device=state.device, dtype=torch.int64) - if sys_idx in constraint.system_idx - else torch.tensor([], device=state.device, dtype=torch.int64) - ) - elif isinstance(constraint, AtomIndexedConstraint): - # Update atom_indices to only include atoms from this system - atom_start = n_atoms_cumsum - atom_end = n_atoms_cumsum + system_sizes[sys_idx] - constraint.indices = torch.tensor( - [ - idx - atom_start - for idx in constraint.indices - if atom_start <= idx < atom_end - ], - device=state.device, - dtype=torch.int64, - ) - # Remove empty constraints + + atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1]) new_constraints = [ - constraint - for constraint in new_constraints - if ( - ( - isinstance(constraint, SystemConstraint) - and len(constraint.system_idx) > 0 - ) - or ( - isinstance(constraint, AtomIndexedConstraint) - and len(constraint.indices) > 0 - ) - ) + new_constraint + for constraint in state.constraints + if (new_constraint := select_sub_constraint(constraint, atom_idx, sys_idx)) ] + system_attrs["_constraints"] = new_constraints states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] - n_atoms_cumsum += system_sizes[sys_idx] return states From 06400e120781bc60951df8935c0b39f6f3651b90 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Wed, 19 Nov 2025 15:46:06 -0500 Subject: [PATCH 18/43] make constraint handling more modular with methods, merge states currently broken --- torch_sim/constraints.py | 198 ++++++++++++++++++++++++--------------- torch_sim/state.py | 68 +++----------- 2 files changed, 134 insertions(+), 132 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index b32a5def..9755b952 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -77,6 +77,37 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted """ + @abstractmethod + def update_constraint( + self, atom_mask: torch.Tensor, system_mask: torch.Tensor + ) -> Constraint: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + + @abstractmethod + def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> Constraint: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + + Returns: + Constraint for the given atom and system index + """ + + +def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + cumsum_atom_mask = torch.cumsum(~mask, dim=0) + new_indices = idx - cumsum_atom_mask[idx] + mask_indices = torch.where(mask)[0] + drop_indices = ~torch.isin(idx, mask_indices) + return new_indices[~drop_indices] + class AtomIndexedConstraint(Constraint): """Base class for constraints that act on specific atom indices. @@ -140,6 +171,38 @@ def get_indices(self) -> torch.Tensor: """ return self.indices.clone() + def update_constraint( + self, + atom_mask: torch.Tensor, + system_mask: torch.Tensor, # noqa: ARG002 + ) -> Constraint: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + self.indices = _mask_constraint_indices(self.indices, atom_mask) + return self + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, + sys_idx: int, # noqa: ARG002 + ) -> Constraint: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + mask = torch.isin(self.indices, atom_idx) + masked_indices = self.indices[mask] + new_atom_idx = masked_indices - atom_idx.min() + if len(new_atom_idx) == 0: + return None + return type(self)(new_atom_idx) + class SystemConstraint(Constraint): """Base class for constraints that act on specific system indices. @@ -177,93 +240,76 @@ def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: f"Got {system_idx.ndim}, expected ndim <= 1" ) + def update_constraint( + self, + atom_mask: torch.Tensor, # noqa: ARG002 + system_mask: torch.Tensor, + ) -> Constraint: + """Update the constraint to account for atom and system masks. -def update_constraint( - constraint: AtomIndexedConstraint | SystemConstraint, - atom_mask: torch.Tensor, - system_mask: torch.Tensor, -) -> Constraint: - """Update a constraint to account for atom and system masks. - - Args: - constraint: Constraint to update - atom_mask: Boolean mask for atoms to keep - system_mask: Boolean mask for systems to keep - """ - # n_atoms_per_system = system_idx.bincount() - - def update_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - cumsum_atom_mask = torch.cumsum(~mask, dim=0) - new_indices = idx - cumsum_atom_mask[idx] - mask_indices = torch.where(mask)[0] - drop_indices = ~torch.isin(idx, mask_indices) - return new_indices[~drop_indices] - - if isinstance(constraint, AtomIndexedConstraint): - constraint.indices = update_indices(constraint.indices, atom_mask) - - elif isinstance(constraint, SystemConstraint): - constraint.system_idx = update_indices(constraint.system_idx, system_mask) - - else: - raise NotImplementedError( - f"Constraint type {type(constraint)} is not implemented" - ) - - return constraint - - -def select_sub_constraint( - constraint: AtomIndexedConstraint | SystemConstraint, - atom_idx: torch.Tensor, - sys_idx: int, -) -> AtomIndexedConstraint | SystemConstraint | None: - """Select a constraint for a given atom and system index. + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + self.system_idx = _mask_constraint_indices(self.system_idx, system_mask) + return self - Args: - constraint: Constraint to select - atom_idx: Atom indices for a single system - sys_idx: System index for a single system - """ - """Split a constraint into a list of constraints.""" - - # TODO: finish this - # TODO: we can probably eliminate the for loop and make this a split - # out constraint function that just bumps out a single constraint - # then embed this function in the split_state function - if isinstance(constraint, AtomIndexedConstraint): - mask = torch.isin(constraint.indices, atom_idx) - masked_indices = constraint.indices[mask] - new_atom_idx = masked_indices - atom_idx.min() - if len(new_atom_idx) == 0: - return None - return type(constraint)(new_atom_idx) + def select_sub_constraint( + self, + atom_idx: torch.Tensor, # noqa: ARG002 + sys_idx: int, + ) -> Constraint: + """Select a constraint for a given atom and system index. - if isinstance(constraint, SystemConstraint): - mask = torch.isin(constraint.system_idx, sys_idx) - masked_system_idx = constraint.system_idx[mask] + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + mask = torch.isin(self.system_idx, sys_idx) + masked_system_idx = self.system_idx[mask] new_system_idx = masked_system_idx - sys_idx if len(new_system_idx) == 0: return None - return type(constraint)(new_system_idx) + return type(self)(new_system_idx) - raise NotImplementedError(f"Constraint type {type(constraint)} is not implemented") +def merge_constraints( + constraint_lists: list[list[AtomIndexedConstraint | SystemConstraint]], + num_atoms_per_state: torch.Tensor, +) -> list[Constraint]: + """Merge constraints from multiple systems into a single list of constraints. -# def merge_constraint_into_list( -# constraints: list[AtomIndexedConstraint | SystemConstraint], -# constraint: AtomIndexedConstraint | SystemConstraint, -# ) -> list[Constraint]: -# """Merge a constraint into a list of constraints. + Args: + constraint_lists: List of lists of constraints + num_atoms_per_state: Number of atoms per system -# Args: -# constraints: List of constraints -# constraint: Constraint to merge + Returns: + List of merged constraints + """ + from collections import defaultdict + + cumsum_atoms = torch.cumsum(num_atoms_per_state, dim=0) - num_atoms_per_state[0] + + # aggregate updated constraint indices by constraint type + constraint_indices: dict[type[Constraint], list[torch.Tensor]] = defaultdict(list) + for i, constraint_list in enumerate(constraint_lists): + for constraint in constraint_list: + if isinstance(constraint, AtomIndexedConstraint): + idxs = constraint.indices + offset = cumsum_atoms[i] + elif isinstance(constraint, SystemConstraint): + idxs = constraint.system_idx + offset = i + else: + raise NotImplementedError( + f"Constraint type {type(constraint)} is not implemented" + ) + constraint_indices[type(constraint)].append(idxs + offset) -# Returns: -# List of constraints -# """ -# return constraints + return [ + constraint_type(torch.cat(idxs)) + for constraint_type, idxs in constraint_indices.items() + ] class FixAtoms(AtomIndexedConstraint): diff --git a/torch_sim/state.py b/torch_sim/state.py index ff21b25d..9ce49121 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -15,7 +15,6 @@ import torch import torch_sim as ts -from torch_sim.constraints import select_sub_constraint, update_constraint from torch_sim.typing import StateLike @@ -25,9 +24,9 @@ from pymatgen.core import Structure from torch_sim.constraints import ( - AtomIndexedConstraint, Constraint, SystemConstraint, + merge_constraints, validate_constraints, ) @@ -725,16 +724,10 @@ def _filter_attrs_by_mask( # take into account constraints that are AtomIndexedConstraint filtered_attrs["_constraints"] = [ - update_constraint(constraint, atom_mask, system_mask) + constraint.update_constraint(atom_mask, system_mask) for constraint in copy.deepcopy(state.constraints) ] - new_n_atoms_per_system = state.n_atoms_per_system[system_mask] - cum_sum_atoms = torch.cumsum(new_n_atoms_per_system, dim=0) - cum_sum_atoms = torch.cat( - (torch.tensor([0], device=cum_sum_atoms.device), cum_sum_atoms) - ) - # Filter per-atom attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": @@ -828,7 +821,7 @@ def _split_state[T: SimState](state: T) -> list[T]: new_constraints = [ new_constraint for constraint in state.constraints - if (new_constraint := select_sub_constraint(constraint, atom_idx, sys_idx)) + if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx)) ] system_attrs["_constraints"] = new_constraints @@ -921,7 +914,7 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor return type(state)(**filtered_attrs) # type: ignore[invalid-return-type] -def concatenate_states[T: SimState]( # noqa: C901, PLR0915 +def concatenate_states[T: SimState]( # noqa: C901 states: Sequence[T], device: torch.device | None = None ) -> T: """Concatenate a list of SimStates into a single SimState. @@ -964,9 +957,7 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 - n_atoms_offset = 0 - - constraints = {} + num_atoms_per_state = [] # Process all states in a single pass for state in states: @@ -989,48 +980,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) + num_atoms_per_state.append(state.n_atoms) - for constraint in state.constraints: - constraint_name = type(constraint).__name__ - if constraint_name not in constraints: - # if it's IndexedConstraint then we need to adjust the indices - if isinstance(constraint, AtomIndexedConstraint): - new_constraint = copy.deepcopy(constraint) - new_constraint.indices = torch.empty( - 0, dtype=torch.long, device=target_device - ) - constraints[constraint_name] = new_constraint - elif isinstance(constraint, SystemConstraint): - new_constraint = copy.deepcopy(constraint) - new_constraint.system_idx = torch.empty( - 0, dtype=torch.long, device=target_device - ) - constraints[constraint_name] = new_constraint - else: - raise NotImplementedError( - f"Concatenation of constraint type " - f"{type(constraint)} is not implemented" - ) - # need to adjust the indices for IndexedConstraint - if isinstance(constraint, AtomIndexedConstraint): - new_constraint = constraints[constraint_name] - new_constraint.indices = torch.concat( - ( - new_constraint.indices, - constraint.indices + n_atoms_offset, - ) - ) - elif isinstance(constraint, SystemConstraint): - new_constraint = constraints[constraint_name] - new_constraint.system_idx = torch.concat( - ( - new_constraint.system_idx, - constraint.system_idx + system_offset, - ) - ) - constraints[constraint_name] = new_constraint - - n_atoms_offset += state.n_atoms system_offset += num_systems # Concatenate collected tensors @@ -1048,7 +999,12 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) - constraints = list(constraints.values()) + constraint_lists = [state.constraints for state in states] + + constraints = merge_constraints( + constraint_lists, torch.tensor(num_atoms_per_state, device=target_device) + ) + # Create a new instance of the same class return state_class(**concatenated, _constraints=constraints) From 4081973344dafb1310a04550604ddfa21d7103a5 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 21 Nov 2025 11:36:17 -0500 Subject: [PATCH 19/43] No longer allow initializing FixCom() or FixAtoms() with empty arguments --- tests/test_constraints.py | 48 ++++++++++++++++++++++++--------------- torch_sim/constraints.py | 19 ++++------------ torch_sim/state.py | 15 ++---------- 3 files changed, 36 insertions(+), 46 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 22723c93..bbcf8f04 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -15,7 +15,7 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): """Test adjustment of positions and momenta with FixCom constraint.""" - ar_supercell_sim_state.constraints = [FixCom()] + ar_supercell_sim_state.constraints = [FixCom([0])] initial_positions = ar_supercell_sim_state.positions.clone() ar_supercell_sim_state.set_positions(initial_positions + 0.5) assert torch.allclose(ar_supercell_sim_state.positions, initial_positions, atol=1e-8) @@ -30,7 +30,9 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMode torch.randn_like(ar_supercell_md_state.momenta) * 0.1 ) assert torch.allclose( - ar_supercell_md_state.momenta.mean(dim=0), torch.zeros(3, dtype=DTYPE), atol=1e-8 + ar_supercell_md_state.momenta.mean(dim=0), + torch.zeros(3, dtype=DTYPE), + atol=1e-8, ) @@ -81,7 +83,7 @@ def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesM kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() - cu_sim_state.constraints = [FixCom()] + cu_sim_state.constraints = [FixCom([0])] assert torch.allclose( cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - 3 ) @@ -161,7 +163,10 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): """Test that constraints are properly propagated during state manipulation.""" # Set up constraints on the original state - ar_double_sim_state.constraints = [FixAtoms(indices=torch.tensor([0, 1])), FixCom()] + ar_double_sim_state.constraints = [ + FixAtoms(indices=torch.tensor([0, 1])), + FixCom([0, 1]), + ] # Extract individual systems from the double system state first_system = ar_double_sim_state[0] @@ -196,7 +201,7 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): # Test constraint manipulation with different configurations ar_double_sim_state.constraints = [] - ar_double_sim_state.constraints = [FixCom()] + ar_double_sim_state.constraints = [FixCom([0, 1])] isolated_system = ar_double_sim_state[0] assert torch.all( isolated_system.constraints[0].system_idx == torch.tensor([0], dtype=torch.long) @@ -224,7 +229,7 @@ def test_fix_com_gradient_descent_optimization( ar_supercell_sim_state.positions = perturbed_positions initial_state = ar_supercell_sim_state - ar_supercell_sim_state.constraints = [FixCom()] + ar_supercell_sim_state.constraints = [FixCom([0])] initial_coms = get_centers_of_mass( positions=initial_state.positions, @@ -289,7 +294,9 @@ def test_fix_atoms_gradient_descent_optimization( @pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) def test_test_atoms_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, fire_flavor: FireFlavor + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + fire_flavor: FireFlavor, ) -> None: """Test FixAtoms constraint in FIRE optimization.""" # Add some random displacement to positions @@ -333,7 +340,9 @@ def test_test_atoms_fire_optimization( @pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) def test_fix_com_fire_optimization( - ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface, fire_flavor: FireFlavor + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + fire_flavor: FireFlavor, ) -> None: """Test FixCom constraint in FIRE optimization.""" # Add some random displacement to positions @@ -352,7 +361,7 @@ def test_fix_com_fire_optimization( atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), system_idx=ar_supercell_sim_state.system_idx.clone(), ) - current_sim_state.constraints = [FixCom()] + current_sim_state.constraints = [FixCom([0])] # Initialize FIRE optimizer state = ts.fire_init( @@ -405,7 +414,7 @@ def test_constraint_validation_warnings() -> None: with pytest.warns(UserWarning, match="Multiple constraints.*same atoms"): validate_constraints([FixAtoms(indices=[0, 1, 2]), FixAtoms(indices=[2, 3, 4])]) with pytest.warns(UserWarning, match="FixCom together with other constraints"): - validate_constraints([FixCom(), FixAtoms(indices=[0, 1])]) + validate_constraints([FixCom([0]), FixAtoms(indices=[0, 1])]) def test_constraint_validation_errors( @@ -434,9 +443,9 @@ def test_constraint_validation_errors( ("integrator", "constraint", "n_steps"), [ ("nve", FixAtoms(indices=[0, 1]), 100), - ("nvt_nose_hoover", FixCom(), 200), + ("nvt_nose_hoover", FixCom([0]), 200), ("npt_langevin", FixAtoms(indices=[0, 3]), 200), - ("npt_nose_hoover", FixCom(), 200), + ("npt_nose_hoover", FixCom([0]), 200), ], ) def test_integrators_with_constraints( @@ -511,7 +520,7 @@ def test_multiple_constraints_and_dof( assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n) cu_sim_state.constraints = [FixAtoms(indices=[0])] assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 3) - cu_sim_state.constraints = [FixCom(), FixAtoms(indices=[0])] + cu_sim_state.constraints = [FixCom([0]), FixAtoms(indices=[0])] assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 6) # Verify both constraints hold during dynamics @@ -562,7 +571,10 @@ def test_cell_optimization_with_constraints( ) ar_supercell_sim_state.constraints = [FixAtoms(indices=[0, 1])] state = ts.fire_init( - ar_supercell_sim_state, lj_model, cell_filter=cell_filter, fire_flavor=fire_flavor + ar_supercell_sim_state, + lj_model, + cell_filter=cell_filter, + fire_flavor=fire_flavor, ) for _ in range(50): state = ts.fire_step(state, lj_model, dt_max=0.1) @@ -575,7 +587,7 @@ def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: """Test system-specific constraints in batched states.""" s1, s2 = ar_double_sim_state.split() s1.constraints = [FixAtoms(indices=[0, 1])] - s2.constraints = [FixCom()] + s2.constraints = [FixCom([0])] combined = ts.concatenate_states([s1, s2]) assert len(combined.constraints) == 2 assert isinstance(combined.constraints[0], FixAtoms) @@ -597,7 +609,7 @@ def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: atomic_numbers=torch.full((4,), 18, dtype=torch.long), system_idx=torch.zeros(4, dtype=torch.long), ) - state.constraints = [FixCom()] + state.constraints = [FixCom([0])] initial = get_centers_of_mass( state.positions, state.masses, state.system_idx, state.n_systems ) @@ -617,7 +629,7 @@ def test_high_level_api_with_constraints( ) -> None: """Test high-level integrate() and optimize() APIs with constraints.""" # Test integrate() - cu_sim_state.constraints = [FixCom()] + cu_sim_state.constraints = [FixCom([0])] initial_com = get_centers_of_mass( cu_sim_state.positions, cu_sim_state.masses, @@ -655,7 +667,7 @@ def test_temperature_with_constrained_dof( ) -> None: """Test temperature calculation uses constrained DOF.""" target = 300.0 - cu_sim_state.constraints = [FixAtoms(indices=[0, 1]), FixCom()] + cu_sim_state.constraints = [FixAtoms(indices=[0, 1]), FixCom([0])] state = ts.nvt_langevin_init( cu_sim_state, lj_model, diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 9755b952..65e7aea3 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -116,7 +116,7 @@ class AtomIndexedConstraint(Constraint): on a subset of atoms, identified by their indices. """ - def __init__(self, indices: torch.Tensor | list[int] | None = None) -> None: + def __init__(self, indices: torch.Tensor | list[int]) -> None: """Initialize indexed constraint. Args: @@ -126,11 +126,6 @@ def __init__(self, indices: torch.Tensor | list[int] | None = None) -> None: ValueError: If both indices and mask are provided, or if indices have wrong shape/type """ - if indices is None: - # Empty constraint - self.indices = torch.empty(0, dtype=torch.long) - return - # Convert to tensor if needed if not isinstance(indices, torch.Tensor): indices = torch.tensor(indices) @@ -211,7 +206,7 @@ class SystemConstraint(Constraint): on a subset of systems, identified by their indices. """ - def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: + def __init__(self, system_idx: torch.Tensor | list[int]) -> None: """Initialize indexed constraint. Args: @@ -222,13 +217,6 @@ def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: ValueError: If both indices and mask are provided, or if indices have wrong shape/type """ - self.initialized = True - if system_idx is None: - # Empty constraint - self.system_idx = torch.empty(0, dtype=torch.long) - self.initialized = False - return - # Convert to tensor if needed system_idx = torch.as_tensor(system_idx) @@ -239,6 +227,7 @@ def __init__(self, system_idx: torch.Tensor | list[int] | None = None) -> None: "system_idx has wrong number of dimensions. " f"Got {system_idx.ndim}, expected ndim <= 1" ) + self.system_idx = system_idx def update_constraint( self, @@ -510,7 +499,7 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: def __repr__(self) -> str: """String representation of the constraint.""" - return "FixCom()" + return f"FixCom(system_idx={self.system_idx})" def count_degrees_of_freedom( diff --git a/torch_sim/state.py b/torch_sim/state.py index 9ce49121..2b7255ad 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -23,12 +23,7 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure -from torch_sim.constraints import ( - Constraint, - SystemConstraint, - merge_constraints, - validate_constraints, -) +from torch_sim.constraints import Constraint, merge_constraints, validate_constraints @dataclass @@ -273,12 +268,6 @@ def constraints(self, constraints: list[Constraint] | Constraint) -> None: # check it is a list if isinstance(constraints, Constraint): constraints = [constraints] - for constraint in constraints: - if (isinstance(constraint, SystemConstraint)) and ( - not constraint.initialized - ): - constraint.system_idx = torch.arange(self.n_systems, device=self.device) - constraint.initialized = True # Validate new constraints before adding validate_constraints(constraints, state=self) @@ -999,8 +988,8 @@ def concatenate_states[T: SimState]( # noqa: C901 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Merge constraints constraint_lists = [state.constraints for state in states] - constraints = merge_constraints( constraint_lists, torch.tensor(num_atoms_per_state, device=target_device) ) From 35749c30a28c7bd40f93c830991b469dd803ea6c Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 21 Nov 2025 11:44:30 -0500 Subject: [PATCH 20/43] vibe code and verify some tests --- tests/test_constraints.py | 156 +++++++++++++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 1 deletion(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index bbcf8f04..24708693 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -5,7 +5,13 @@ import torch_sim as ts from tests.conftest import DTYPE -from torch_sim.constraints import Constraint, FixAtoms, FixCom, validate_constraints +from torch_sim.constraints import ( + Constraint, + FixAtoms, + FixCom, + merge_constraints, + validate_constraints, +) from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.optimizers import FireFlavor @@ -686,3 +692,151 @@ def test_temperature_with_constrained_dof( temps.append(temp / MetalUnits.temperature) avg = torch.mean(torch.stack(temps)[500:]) assert abs(avg - target) / target < 0.30 + + +def test_system_constraint_update_and_select() -> None: + """Test update_constraint and select_sub_constraint for SystemConstraint.""" + # Create a FixCom constraint for systems 0, 1, 2 + constraint = FixCom([0, 1, 2]) + + # Test update_constraint with system_mask + # Keep systems 0 and 2 (drop system 1) + atom_mask = torch.ones(10, dtype=torch.bool) + system_mask = torch.tensor([True, False, True], dtype=torch.bool) + updated_constraint = constraint.update_constraint(atom_mask, system_mask) + + # System indices should be renumbered: [0, 2] -> [0, 1] + assert torch.all(updated_constraint.system_idx == torch.tensor([0, 1])) + + # Test select_sub_constraint + # Select system 1 from the original constraint + constraint = FixCom([0, 1, 2]) + atom_idx = torch.arange(5, 10) # Atoms for a specific system + sys_idx = 1 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with system_idx = [0] (renumbered from 1) + assert sub_constraint is not None + assert torch.all(sub_constraint.system_idx == torch.tensor([0])) + + # Test when system is not in constraint + constraint = FixCom([0, 2]) + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) + assert sub_constraint is None + + +def test_atom_indexed_constraint_update_and_select() -> None: + """Test update_constraint and select_sub_constraint for AtomIndexedConstraint.""" + # Create a FixAtoms constraint for atoms 0, 1, 5, 8 + constraint = FixAtoms(indices=[0, 1, 5, 8]) + + # Test update_constraint with atom_mask + # Keep atoms 0, 1, 2, 3, 5, 6, 7, 8 (drop atoms 4) + atom_mask = torch.tensor( + [True, True, True, True, False, True, True, True, True], dtype=torch.bool + ) + system_mask = torch.ones(2, dtype=torch.bool) + updated_constraint = constraint.update_constraint(atom_mask, system_mask) + + # Atom indices should be renumbered: + # Original: [0, 1, 5, 8] + # After dropping atom 4: [0, 1, 4, 7] (indices shift down by 1 after index 4) + assert torch.all(updated_constraint.indices == torch.tensor([0, 1, 4, 7])) + + # Test select_sub_constraint + # Select atoms that belong to a specific system + constraint = FixAtoms(indices=[0, 1, 5, 8]) + atom_idx = torch.tensor([0, 1, 2, 3, 4]) # Atoms for first system + sys_idx = 0 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with only atoms 0, 1 (within atom_idx range) + # Renumbered to start from 0 + assert sub_constraint is not None + assert torch.all(sub_constraint.indices == torch.tensor([0, 1])) + + # Test with different atom range + constraint = FixAtoms(indices=[0, 1, 5, 8]) + atom_idx = torch.tensor([5, 6, 7, 8, 9]) # Atoms for second system + sys_idx = 1 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with atoms 5, 8 renumbered to [0, 3] + assert sub_constraint is not None + assert torch.all(sub_constraint.indices == torch.tensor([0, 3])) + + # Test when no atoms in range + constraint = FixAtoms(indices=[0, 1]) + atom_idx = torch.tensor([5, 6, 7, 8]) + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) + assert sub_constraint is None + + +def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: + """Test merge_constraints combines constraints from multiple systems.""" + # Split the double system state + s1, s2 = ar_double_sim_state.split() + n_atoms_s1 = s1.n_atoms + n_atoms_s2 = s2.n_atoms + + # Create constraints for each system + # System 1: Fix atoms 0, 1 and fix COM for system 0 + s1_constraints = [ + FixAtoms(indices=[0, 1]), + FixCom([0]), + ] + + # System 2: Fix atoms 2, 3 and fix COM for system 0 + s2_constraints = [ + FixAtoms(indices=[2, 3]), + FixCom([0]), + ] + + # Merge constraints + constraint_lists = [s1_constraints, s2_constraints] + num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2]) + merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) + + # Should have 2 constraints: one FixAtoms and one FixCom + assert len(merged_constraints) == 2 + + # Find FixAtoms and FixCom in merged list + fix_atoms = None + fix_com = None + for constraint in merged_constraints: + if isinstance(constraint, FixAtoms): + fix_atoms = constraint + elif isinstance(constraint, FixCom): + fix_com = constraint + + assert fix_atoms is not None + assert fix_com is not None + + # FixAtoms should have indices [0, 1] from s1 and [2+n_atoms_s1, 3+n_atoms_s1] from s2 + expected_atom_indices = torch.tensor([0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1]) + assert torch.all(fix_atoms.indices == expected_atom_indices) + + # FixCom should have system_idx [0, 1] (one for each original system) + expected_system_indices = torch.tensor([0, 1]) + assert torch.all(fix_com.system_idx == expected_system_indices) + + # Test with three systems + s3 = s1.clone() + s3_constraints = [FixAtoms(indices=[0])] + constraint_lists = [s1_constraints, s2_constraints, s3_constraints] + num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2, s3.n_atoms]) + merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) + + # Find FixAtoms + fix_atoms = None + for constraint in merged_constraints: + if isinstance(constraint, FixAtoms): + fix_atoms = constraint + break + + assert fix_atoms is not None + # Should include atoms from all three systems with proper offsets + expected_atom_indices = torch.tensor( + [0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1, 0 + n_atoms_s1 + n_atoms_s2] + ) + assert torch.all(fix_atoms.indices == expected_atom_indices) From 0688bfe94d060fa4bbee0c672ba2dbc3599c135a Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 15:39:39 +0100 Subject: [PATCH 21/43] rename update_constraint to select_constraint, remove None Constraint for select, Update test --- tests/test_constraints.py | 25 +++++++++++------------ torch_sim/constraints.py | 43 ++++++++++++++++++++------------------- torch_sim/state.py | 8 +++++++- 3 files changed, 41 insertions(+), 35 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 24708693..1d3f2204 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -170,25 +170,24 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): """Test that constraints are properly propagated during state manipulation.""" # Set up constraints on the original state ar_double_sim_state.constraints = [ - FixAtoms(indices=torch.tensor([0, 1])), + FixAtoms(indices=torch.tensor([0, 1])), # Only applied to first system FixCom([0, 1]), ] # Extract individual systems from the double system state - first_system = ar_double_sim_state[0] - second_system = ar_double_sim_state[1] + first_system = ar_double_sim_state[0] # FixAtoms + FixCom + second_system = ar_double_sim_state[1] # FixCom only concatenated_state = ts.concatenate_states( [first_system, first_system, second_system] ) # Verify constraint propagation to subsystems assert len(first_system.constraints) == 2 - assert len(second_system.constraints) == 2 + assert len(second_system.constraints) == 1 assert len(concatenated_state.constraints) == 2 # Verify FixAtoms constraint indices are correctly mapped assert torch.all(first_system.constraints[0].indices == torch.tensor([0, 1])) - assert torch.all(second_system.constraints[0].indices == torch.tensor([])) assert torch.all( concatenated_state.constraints[0].indices == torch.tensor([0, 1, 32, 33]) ) @@ -673,7 +672,7 @@ def test_temperature_with_constrained_dof( ) -> None: """Test temperature calculation uses constrained DOF.""" target = 300.0 - cu_sim_state.constraints = [FixAtoms(indices=[0, 1]), FixCom([0])] + cu_sim_state.constraints = [FixAtoms(indices=[0, 1, 2])] state = ts.nvt_langevin_init( cu_sim_state, lj_model, @@ -681,7 +680,7 @@ def test_temperature_with_constrained_dof( seed=42, ) temps = [] - for _ in range(1000): + for _ in range(4000): state = ts.nvt_langevin_step( state, lj_model, @@ -695,15 +694,15 @@ def test_temperature_with_constrained_dof( def test_system_constraint_update_and_select() -> None: - """Test update_constraint and select_sub_constraint for SystemConstraint.""" + """Test select_constraint and select_sub_constraint for SystemConstraint.""" # Create a FixCom constraint for systems 0, 1, 2 constraint = FixCom([0, 1, 2]) - # Test update_constraint with system_mask + # Test select_constraint with system_mask # Keep systems 0 and 2 (drop system 1) atom_mask = torch.ones(10, dtype=torch.bool) system_mask = torch.tensor([True, False, True], dtype=torch.bool) - updated_constraint = constraint.update_constraint(atom_mask, system_mask) + updated_constraint = constraint.select_constraint(atom_mask, system_mask) # System indices should be renumbered: [0, 2] -> [0, 1] assert torch.all(updated_constraint.system_idx == torch.tensor([0, 1])) @@ -726,17 +725,17 @@ def test_system_constraint_update_and_select() -> None: def test_atom_indexed_constraint_update_and_select() -> None: - """Test update_constraint and select_sub_constraint for AtomIndexedConstraint.""" + """Test select_constraint and select_sub_constraint for AtomIndexedConstraint.""" # Create a FixAtoms constraint for atoms 0, 1, 5, 8 constraint = FixAtoms(indices=[0, 1, 5, 8]) - # Test update_constraint with atom_mask + # Test select_constraint with atom_mask # Keep atoms 0, 1, 2, 3, 5, 6, 7, 8 (drop atoms 4) atom_mask = torch.tensor( [True, True, True, True, False, True, True, True, True], dtype=torch.bool ) system_mask = torch.ones(2, dtype=torch.bool) - updated_constraint = constraint.update_constraint(atom_mask, system_mask) + updated_constraint = constraint.select_constraint(atom_mask, system_mask) # Atom indices should be renumbered: # Original: [0, 1, 5, 8] diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 65e7aea3..8ddc57fd 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -12,7 +12,7 @@ import warnings from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import torch @@ -78,9 +78,9 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: """ @abstractmethod - def update_constraint( + def select_constraint( self, atom_mask: torch.Tensor, system_mask: torch.Tensor - ) -> Constraint: + ) -> None | Self: """Update the constraint to account for atom and system masks. Args: @@ -89,7 +89,7 @@ def update_constraint( """ @abstractmethod - def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> Constraint: + def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> None | Self: """Select a constraint for a given atom and system index. Args: @@ -166,25 +166,28 @@ def get_indices(self) -> torch.Tensor: """ return self.indices.clone() - def update_constraint( + def select_constraint( self, atom_mask: torch.Tensor, system_mask: torch.Tensor, # noqa: ARG002 - ) -> Constraint: + ) -> None | Self: """Update the constraint to account for atom and system masks. Args: atom_mask: Boolean mask for atoms to keep system_mask: Boolean mask for systems to keep """ - self.indices = _mask_constraint_indices(self.indices, atom_mask) - return self + indices = self.indices.clone() + indices = _mask_constraint_indices(indices, atom_mask) + if len(indices) == 0: + return None + return type(self)(indices) def select_sub_constraint( self, atom_idx: torch.Tensor, sys_idx: int, # noqa: ARG002 - ) -> Constraint: + ) -> None | Self: """Select a constraint for a given atom and system index. Args: @@ -227,39 +230,37 @@ def __init__(self, system_idx: torch.Tensor | list[int]) -> None: "system_idx has wrong number of dimensions. " f"Got {system_idx.ndim}, expected ndim <= 1" ) - self.system_idx = system_idx + self.system_idx: torch.Tensor = system_idx - def update_constraint( + def select_constraint( self, atom_mask: torch.Tensor, # noqa: ARG002 system_mask: torch.Tensor, - ) -> Constraint: + ) -> None | Self: """Update the constraint to account for atom and system masks. Args: atom_mask: Boolean mask for atoms to keep system_mask: Boolean mask for systems to keep """ - self.system_idx = _mask_constraint_indices(self.system_idx, system_mask) - return self + system_idx = self.system_idx.clone() + system_idx = _mask_constraint_indices(system_idx, system_mask) + if len(system_idx) == 0: + return None + return type(self)(system_idx) def select_sub_constraint( self, atom_idx: torch.Tensor, # noqa: ARG002 sys_idx: int, - ) -> Constraint: + ) -> None | Self: """Select a constraint for a given atom and system index. Args: atom_idx: Atom indices for a single system sys_idx: System index for a single system """ - mask = torch.isin(self.system_idx, sys_idx) - masked_system_idx = self.system_idx[mask] - new_system_idx = masked_system_idx - sys_idx - if len(new_system_idx) == 0: - return None - return type(self)(new_system_idx) + return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None def merge_constraints( diff --git a/torch_sim/state.py b/torch_sim/state.py index 2b7255ad..c573d322 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -713,9 +713,15 @@ def _filter_attrs_by_mask( # take into account constraints that are AtomIndexedConstraint filtered_attrs["_constraints"] = [ - constraint.update_constraint(atom_mask, system_mask) + constraint.select_constraint(atom_mask, system_mask) for constraint in copy.deepcopy(state.constraints) ] + # Remove any None constraints resulting from selection + filtered_attrs["_constraints"] = [ + constraint + for constraint in filtered_attrs["_constraints"] + if constraint is not None + ] # Filter per-atom attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): From be161e31aafd6917a6a2ddc352cee3745b0a0ce1 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 15:52:18 +0100 Subject: [PATCH 22/43] change to _constraint name --- torch_sim/monte_carlo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index c17bf6ae..3e0ef47d 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -214,7 +214,7 @@ def swap_mc_init( system_idx=state.system_idx, energy=model_output["energy"], last_permutation=torch.arange(state.n_atoms, device=state.device), - constraints=state.constraints, + _constraints=state.constraints, ) From 6afab52f8c76d685b592e9733747b7c81eee4e57 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 15:53:21 +0100 Subject: [PATCH 23/43] revert to previous return as it actually also change the device/dtype (taken from variables) --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index c573d322..9c4c022e 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -659,7 +659,7 @@ def _state_to_device[T: SimState]( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return state + return type(state)(**attrs) def get_attrs_for_scope( From 4aa144790171e0dd3125e20fa17ee3b22f9048f5 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 16:01:20 +0100 Subject: [PATCH 24/43] use post_init to enforce constraint on forces --- torch_sim/optimizers/state.py | 33 +++++---------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 8f8885d0..bd652857 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -7,7 +7,7 @@ from torch_sim.state import SimState -@dataclass(kw_only=True, init=False) +@dataclass(kw_only=True) class OptimState(SimState): """Unified state class for optimization algorithms. @@ -25,36 +25,13 @@ class OptimState(SimState): def set_forces(self, new_forces: torch.Tensor) -> None: """Set new forces in the optimization state.""" - for constraint in self.constraints: + for constraint in self._constraints: constraint.adjust_forces(self, new_forces) self.forces = new_forces - def __init__( - self, - *, - positions: torch.Tensor, - forces: torch.Tensor, - energy: torch.Tensor, - stress: torch.Tensor | None = None, - masses: torch.Tensor, - cell: torch.Tensor, - pbc: torch.Tensor, - atomic_numbers: torch.Tensor, - system_idx: torch.Tensor, - _constraints: list | None = None, - ) -> None: - """Initialize optimization state.""" - self.positions = positions - self.masses = masses - self.cell = cell - self.pbc = pbc - self.atomic_numbers = atomic_numbers - self.system_idx = system_idx - self._constraints = _constraints - self.energy = energy - self.set_forces(forces) - self.stress = stress - SimState.__post_init__(self) + def __post_init__(self) -> None: + """Post-initialization to ensure SimState setup.""" + self.set_forces(self.forces) @dataclass(kw_only=True) From e61e452f08842d23fe2eacb301661c4a616263fa Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 16:01:35 +0100 Subject: [PATCH 25/43] constraint is not a global_attrs anymore --- tests/test_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_state.py b/tests/test_state.py index b3332aeb..1b293668 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -30,7 +30,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) assert set(per_system_attrs) == {"cell"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) - assert set(global_attrs) == {"pbc", "constraints"} + assert set(global_attrs) == {"pbc"} def test_all_attributes_must_be_specified_in_scopes() -> None: From 8144ed63f4ae8d238d42d7b1c5d3764808311c0f Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 16:09:19 +0100 Subject: [PATCH 26/43] increase slightly steps to test FixCom --- tests/test_constraints.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 1d3f2204..58b5854a 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -645,11 +645,10 @@ def test_high_level_api_with_constraints( cu_sim_state, lj_model, integrator=ts.Integrator.nvt_langevin, - n_steps=1, + n_steps=50, temperature=300.0, timestep=0.001, ) - final_com = get_centers_of_mass( final.positions, final.masses, final.system_idx, final.n_systems ) From 940827b34384498b64d701e47c556c8f771ff85a Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 16:09:41 +0100 Subject: [PATCH 27/43] add _constraint to attributes so that it's kept when cloning simstate --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 9c4c022e..4b571eca 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -204,6 +204,7 @@ def attributes(self) -> dict[str, torch.Tensor]: for attr in self._atom_attributes | self._system_attributes | self._global_attributes + | {"_constraints"} } @property @@ -309,7 +310,6 @@ def clone(self) -> Self: attrs[attr_name] = attr_value.clone() else: attrs[attr_name] = copy.deepcopy(attr_value) - attrs["_constraints"] = copy.deepcopy(self.constraints) return type(self)(**attrs) From 1cbd0b00b842f1ff4cdd766b67a313fb680d09dd Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 16:24:14 +0100 Subject: [PATCH 28/43] compute com for all and only subselect depending on system_idx, remove slice(None) and add FixCom.coms --- torch_sim/constraints.py | 84 ++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 54 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 8ddc57fd..b763f044 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -378,6 +378,8 @@ class FixCom(SystemConstraint): The constraint is applied to all atoms in the system. """ + coms: torch.Tensor | None = None + def get_removed_dof(self, state: SimState) -> torch.Tensor: """Get number of removed degrees of freedom. @@ -389,11 +391,9 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: Returns: Always returns 3 (center of mass translation degrees of freedom) """ - if self.system_idx != slice(None): - affected_systems = torch.zeros(state.n_systems, dtype=torch.long) - affected_systems[self.system_idx] = 1 - return 3 * affected_systems - return 3 * torch.ones(state.n_systems, dtype=torch.long) + affected_systems = torch.zeros(state.n_systems, dtype=torch.long) + affected_systems[self.system_idx] = 1 + return 3 * affected_systems def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: """Adjust positions to maintain center of mass position. @@ -403,35 +403,27 @@ def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None new_positions: Proposed positions to be adjusted in-place """ dtype = state.positions.dtype - n_systems = ( - state.n_systems if self.system_idx == slice(None) else len(self.system_idx) - ) - index_to_consider = ( - torch.isin(state.system_idx, self.system_idx) - if self.system_idx != slice(None) - else torch.ones(state.n_atoms, dtype=torch.bool) + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses ) - system_mass = torch.zeros(n_systems, dtype=dtype).scatter_add_( - 0, state.system_idx[index_to_consider], state.masses[index_to_consider] - ) - if not hasattr(self, "coms"): - self.coms = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( + if self.coms is None: + self.coms = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), - state.masses[index_to_consider].unsqueeze(-1) - * state.positions[index_to_consider], + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * state.positions, ) self.coms /= system_mass.unsqueeze(-1) - new_com = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( + new_com = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), - state.masses[index_to_consider].unsqueeze(-1) - * new_positions[index_to_consider], + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * new_positions, ) new_com /= system_mass.unsqueeze(-1) displacement = torch.zeros(state.n_systems, 3, dtype=dtype) - displacement[self.system_idx] = -new_com + self.coms + displacement[self.system_idx] = ( + -new_com[self.system_idx] + self.coms[self.system_idx] + ) new_positions += displacement[state.system_idx] def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: @@ -443,25 +435,17 @@ def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: """ # Compute center of mass momenta dtype = momenta.dtype - n_systems = ( - state.n_systems if self.system_idx == slice(None) else len(self.system_idx) - ) - index_to_consider = ( - torch.isin(state.system_idx, self.system_idx) - if self.system_idx != slice(None) - else torch.ones(state.n_atoms, dtype=torch.bool) - ) - com_momenta = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( + com_momenta = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), - momenta[index_to_consider], + state.system_idx.unsqueeze(-1).expand(-1, 3), + momenta, ) - system_mass = torch.zeros(n_systems, dtype=dtype).scatter_add_( - 0, state.system_idx[index_to_consider], state.masses[index_to_consider] + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses ) velocity_com = com_momenta / system_mass.unsqueeze(-1) velocity_change = torch.zeros(state.n_systems, 3, dtype=dtype) - velocity_change[self.system_idx] = velocity_com + velocity_change[self.system_idx] = velocity_com[self.system_idx] momenta -= velocity_change[state.system_idx] * state.masses.unsqueeze(-1) def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: @@ -475,27 +459,19 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted in-place """ dtype = state.positions.dtype - n_systems = ( - state.n_systems if self.system_idx == slice(None) else len(self.system_idx) - ) - index_to_consider = ( - torch.isin(state.system_idx, self.system_idx) - if self.system_idx != slice(None) - else torch.ones(state.n_atoms, dtype=torch.bool) - ) - system_square_mass = torch.zeros(n_systems, dtype=dtype).scatter_add_( + system_square_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( 0, - state.system_idx[index_to_consider], - torch.square(state.masses[index_to_consider]), + state.system_idx, + torch.square(state.masses), ) - lmd = torch.zeros((n_systems, 3), dtype=dtype).scatter_add_( + lmd = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( 0, - state.system_idx[index_to_consider].unsqueeze(-1).expand(-1, 3), - forces[index_to_consider] * state.masses[index_to_consider].unsqueeze(-1), + state.system_idx.unsqueeze(-1).expand(-1, 3), + forces * state.masses.unsqueeze(-1), ) lmd /= system_square_mass.unsqueeze(-1) forces_change = torch.zeros(state.n_systems, 3, dtype=dtype) - forces_change[self.system_idx] = lmd + forces_change[self.system_idx] = lmd[self.system_idx] forces -= forces_change[state.system_idx] * state.masses.unsqueeze(-1) def __repr__(self) -> str: From 6e098957d3acb705ae3279ab069cc428dc89282c Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:13:31 +0100 Subject: [PATCH 29/43] remove comments --- torch_sim/optimizers/fire.py | 3 --- torch_sim/optimizers/gradient_descent.py | 1 - 2 files changed, 4 deletions(-) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index d1c2454d..9163fa71 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -212,14 +212,12 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) # Position update - # state.positions = state.positions + atom_wise_dt * state.velocities state.set_positions(state.positions + atom_wise_dt * state.velocities) # Cell position updates are handled in the velocity update step above # Get new forces and energy model_output = model(state) - # state.forces = model_output["forces"] state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: @@ -468,7 +466,6 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # Get new forces, energy, and stress model_output = model(state) - # state.forces = model_output["forces"] state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 246aee1a..4a563b86 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -108,7 +108,6 @@ def gradient_descent_step( atom_lr = pos_lr[state.system_idx].unsqueeze(-1) # Update atomic positions - # state.positions = state.positions + atom_lr * state.forces state.set_positions(state.positions + atom_lr * state.forces) # Update cell if using cell optimization From 7d8890f3517f5c754b8b0d72b5df38cc4e6e62fe Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:13:43 +0100 Subject: [PATCH 30/43] remove comment and raise if dof is negative --- torch_sim/state.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 4b571eca..5426d67d 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -293,7 +293,9 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: dof_per_system -= removed_dof # Ensure non-negative DOF - return torch.clamp(dof_per_system, min=0) + if (dof_per_system <= 0).any(): + raise ValueError("Degrees of freedom cannot be zero or negative") + return dof_per_system def clone(self) -> Self: """Create a deep copy of the SimState. @@ -707,7 +709,6 @@ def _filter_attrs_by_mask( Returns: dict: Filtered attributes with appropriate handling for each scope """ - # atoms_mask = torch.isin(state.system_idx, torch.nonzero(system_mask).squeeze()) # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) From be55c9b877aeb76de1d29010d24e80af19af2834 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:14:32 +0100 Subject: [PATCH 31/43] remove unwrap_pos and add dummy state to test for validate_constraints --- tests/test_constraints.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 58b5854a..c80d15ec 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -15,7 +15,7 @@ from torch_sim.models.interface import ModelInterface from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.optimizers import FireFlavor -from torch_sim.transforms import get_centers_of_mass, unwrap_positions +from torch_sim.transforms import get_centers_of_mass from torch_sim.units import MetalUnits @@ -116,9 +116,6 @@ def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesM traj_positions = torch.stack(positions) - # unwrapped_positions = unwrap_positions( - # traj_positions, ar_double_sim_state.cell, state.system_idx - # ) coms = torch.zeros((n_steps, state.n_systems, 3), dtype=DTYPE).scatter_add_( 1, state.system_idx[None, :, None].expand(n_steps, -1, 3), @@ -157,10 +154,7 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone temperatures = torch.stack(temperatures) traj_positions = torch.stack(positions) - unwrapped_positions = unwrap_positions( - traj_positions, cu_sim_state.cell, state.system_idx - ) - diff_positions = unwrapped_positions - unwrapped_positions[0] + diff_positions = traj_positions - traj_positions[0] assert torch.max(diff_positions[:, :2]) < 1e-8 assert torch.max(diff_positions[:, 2:]) > 1e-3 assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 @@ -414,12 +408,15 @@ def test_fix_atoms_validation() -> None: FixAtoms(indices=torch.tensor([[0, 1]])) -def test_constraint_validation_warnings() -> None: +def test_constraint_validation_warnings(ar_double_sim_state: ts.SimState) -> None: """Test validation warnings for constraint conflicts.""" with pytest.warns(UserWarning, match="Multiple constraints.*same atoms"): - validate_constraints([FixAtoms(indices=[0, 1, 2]), FixAtoms(indices=[2, 3, 4])]) + validate_constraints( + [FixAtoms(indices=[0, 1, 2]), FixAtoms(indices=[2, 3, 4])], + ar_double_sim_state, + ) with pytest.warns(UserWarning, match="FixCom together with other constraints"): - validate_constraints([FixCom([0]), FixAtoms(indices=[0, 1])]) + validate_constraints([FixCom([0]), FixAtoms(indices=[0, 1])], ar_double_sim_state) def test_constraint_validation_errors( From 33c6e9208f2e23434411596760f73e62eb576a1a Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:14:50 +0100 Subject: [PATCH 32/43] ruff happy, simplify function --- torch_sim/constraints.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index b763f044..b22dacb9 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -506,9 +506,19 @@ def count_degrees_of_freedom( return max(0, total_dof) # Ensure non-negative -def validate_constraints( - constraints: list[Constraint], state: SimState | None = None +def check_no_index_out_of_bounds( + indices: torch.Tensor, max_state_indices: int, constraint_name: str ) -> None: + """Check that constraint indices are within bounds of the state.""" + if (len(indices) > 0) and (indices.max() >= max_state_indices): + raise ValueError( + f"Constraint {constraint_name} has indices up to " + f"{indices.max()}, but state only has {max_state_indices} " + "atoms" + ) + + +def validate_constraints(constraints: list[Constraint], state: SimState) -> None: """Validate constraints for potential issues and incompatibilities. This function checks for: @@ -518,7 +528,7 @@ def validate_constraints( Args: constraints: List of constraints to validate - state: Optional SimState for validating atom indices belong to same system + state: SimState to check against Raises: ValueError: If constraints are invalid or span multiple systems @@ -537,18 +547,15 @@ def validate_constraints( indexed_constraints.append(constraint) # Validate that atom indices exist in state if provided - if ( - (state is not None) - and (len(constraint.indices) > 0) - and (constraint.indices.max() >= state.n_atoms) - ): - raise ValueError( - f"Constraint {type(constraint).__name__} has indices up to " - f"{constraint.indices.max()}, but state only has {state.n_atoms} " - "atoms" - ) + check_no_index_out_of_bounds( + constraint.indices, state.n_atoms, type(constraint).__name__ + ) + elif isinstance(constraint, SystemConstraint): + check_no_index_out_of_bounds( + constraint.system_idx, state.n_systems, type(constraint).__name__ + ) - elif isinstance(constraint, FixCom): + if isinstance(constraint, FixCom): has_com_constraint = True # Check for overlapping atom indices From d99a1a76600c3691773e8b42a4c1535b19a19e25 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:14:58 +0100 Subject: [PATCH 33/43] test for unwrap_positions --- tests/test_transforms.py | 64 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 067abad8..6d07ce3d 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -7,6 +7,8 @@ import torch_sim as ts import torch_sim.transforms as ft from tests.conftest import DEVICE, DTYPE +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.units import MetalUnits def test_inverse_box_scalar() -> None: @@ -1289,3 +1291,65 @@ def test_build_linked_cell_neighborhood_basic() -> None: # Verify that there are neighbors from both batches assert torch.any(system_mapping == 0) assert torch.any(system_mapping == 1) + + +def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + n_steps = 50 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + # Same cell + state = ts.nvt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) + state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) + positions = [state.positions.detach().clone()] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.detach().clone()) + + positions = torch.stack(positions) + wrapped_positions = torch.stack( + [ + ft.pbc_wrap_batched(positions, state.cell, state.system_idx) + for positions in positions + ] + ) + unwrapped_positions = ft.unwrap_positions( + wrapped_positions, + state.cell, + state.system_idx, + ) + assert torch.allclose(unwrapped_positions, positions, atol=1e-5) + + # Different cell + state = ts.npt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42, dt=dt + ) + state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) + positions = [state.positions.detach().clone()] + cells = [state.cell.detach().clone()] + for _step in range(n_steps): + state = ts.npt_langevin_step( + model=lj_model, + state=state, + dt=dt, + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE, device=DEVICE), + ) + positions.append(state.positions.detach().clone()) + cells.append(state.cell.detach().clone()) + + positions = torch.stack(positions) + wrapped_positions = torch.stack( + [ + ft.pbc_wrap_batched(positions, cell, state.system_idx) + for positions, cell in zip(positions, cells, strict=True) + ] + ) + unwrapped_positions = ft.unwrap_positions( + wrapped_positions, + state.cell, + state.system_idx, + ) + assert torch.allclose(unwrapped_positions, positions, atol=1e-5) From eb26975e33227c953a19ceb1e91ded661f2a3bdb Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:20:37 +0100 Subject: [PATCH 34/43] silence ruff --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index a08d68ba..3c5feace 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -112,7 +112,7 @@ def pbc(self) -> torch.Tensor: _system_attributes: ClassVar[set[str]] = {"cell"} _global_attributes: ClassVar[set[str]] = {"pbc"} - def __post_init__(self) -> None: + def __post_init__(self) -> None: # noqa: C901 """Initialize the SimState and validate the arguments.""" # Check that positions, masses and atomic numbers have compatible shapes shapes = [ From c15a0127c8a30ed7f5b4a5f3aa6261b955230942 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:29:28 +0100 Subject: [PATCH 35/43] modify args names --- torch_sim/transforms.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index aa4d9265..28acb977 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1178,15 +1178,15 @@ def safe_mask( def unwrap_positions( - pos: torch.Tensor, box: torch.Tensor, system_idx: torch.Tensor + positions: torch.Tensor, cells: torch.Tensor, system_idx: torch.Tensor ) -> torch.Tensor: """Vectorized unwrapping for multiple systems without explicit loops. Parameters ---------- - pos : (T, N_tot, 3) + positions : (T, N_tot, 3) Wrapped cartesian positions for all systems concatenated. - box : (n_systems, 3, 3) or (T, n_systems, 3, 3) + cells : (n_systems, 3, 3) or (T, n_systems, 3, 3) Box matrices, constant or time-dependent. system_idx : (N_tot,) For each atom, which system it belongs to (0..n_systems-1). @@ -1197,15 +1197,15 @@ def unwrap_positions( Unwrapped cartesian positions. """ # -- Constant boxes per system - if box.ndim == 3: - inv_box = torch.inverse(box) # (n_systems, 3, 3) + if cells.ndim == 3: + inv_box = torch.inverse(cells) # (n_systems, 3, 3) # Map each atom to its system's box inv_box_atoms = inv_box[system_idx] # (N, 3, 3) - box_atoms = box[system_idx] # (N, 3, 3) + box_atoms = cells[system_idx] # (N, 3, 3) # Compute fractional coordinates - frac = torch.einsum("tni,nij->tnj", pos, inv_box_atoms) + frac = torch.einsum("tni,nij->tnj", positions, inv_box_atoms) # Fractional displacements and unwrap dfrac = frac[1:] - frac[:-1] @@ -1215,15 +1215,15 @@ def unwrap_positions( dcart = torch.einsum("tni,nij->tnj", dfrac, box_atoms) # -- Time-dependent boxes per system - elif box.ndim == 4: - inv_box = torch.inverse(box) # (T, n_systems, 3, 3) + elif cells.ndim == 4: + inv_box = torch.inverse(cells) # (T, n_systems, 3, 3) # Gather each atom's box per frame efficiently inv_box_atoms = inv_box[:, system_idx] # (T, N, 3, 3) - box_atoms = box[:, system_idx] # (T, N, 3, 3) + box_atoms = cells[:, system_idx] # (T, N, 3, 3) # Compute fractional coordinates per frame - frac = torch.einsum("tni,tnij->tnj", pos, inv_box_atoms) + frac = torch.einsum("tni,tnij->tnj", positions, inv_box_atoms) dfrac = frac[1:] - frac[:-1] dfrac -= torch.round(dfrac) @@ -1234,8 +1234,8 @@ def unwrap_positions( raise ValueError("box must have shape (n_systems,3,3) or (T,n_systems,3,3)") # Cumulative reconstruction - unwrapped = torch.empty_like(pos) - unwrapped[0] = pos[0] + unwrapped = torch.empty_like(positions) + unwrapped[0] = positions[0] unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] return unwrapped From 87644fa0fbe644bda64e6795e938073b31f5fa37 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:29:45 +0100 Subject: [PATCH 36/43] reduce precision for test_unwrap --- tests/test_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index bd9b9f81..fe9e8c3f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1332,7 +1332,7 @@ def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJon state.cell, state.system_idx, ) - assert torch.allclose(unwrapped_positions, positions, atol=1e-5) + assert torch.allclose(unwrapped_positions, positions, atol=1e-4) # Different cell state = ts.npt_langevin_init( @@ -1364,4 +1364,4 @@ def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJon state.cell, state.system_idx, ) - assert torch.allclose(unwrapped_positions, positions, atol=1e-5) + assert torch.allclose(unwrapped_positions, positions, atol=1e-4) From 95857b03d78caeae4428a4ee4dd3a834872da3d8 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:29:55 +0100 Subject: [PATCH 37/43] updates names --- tests/test_constraints.py | 28 +++++++++++++-------------- torch_sim/constraints.py | 40 +++++++++++++++++++-------------------- torch_sim/state.py | 2 +- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index c80d15ec..e0fd3d64 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -181,9 +181,9 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): assert len(concatenated_state.constraints) == 2 # Verify FixAtoms constraint indices are correctly mapped - assert torch.all(first_system.constraints[0].indices == torch.tensor([0, 1])) + assert torch.all(first_system.constraints[0].atom_idx == torch.tensor([0, 1])) assert torch.all( - concatenated_state.constraints[0].indices == torch.tensor([0, 1, 32, 33]) + concatenated_state.constraints[0].atom_idx == torch.tensor([0, 1, 32, 33]) ) # Verify FixCom constraint system masks @@ -194,8 +194,8 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): # Test constraint propagation after splitting concatenated state split_systems = concatenated_state.split() assert len(split_systems[0].constraints) == 2 - assert torch.all(split_systems[0].constraints[0].indices == torch.tensor([0, 1])) - assert torch.all(split_systems[1].constraints[0].indices == torch.tensor([0, 1])) + assert torch.all(split_systems[0].constraints[0].atom_idx == torch.tensor([0, 1])) + assert torch.all(split_systems[1].constraints[0].atom_idx == torch.tensor([0, 1])) assert len(split_systems[2].constraints) == 1 # Test constraint manipulation with different configurations @@ -397,7 +397,7 @@ def test_fix_atoms_validation() -> None: # Boolean mask conversion mask = torch.zeros(10, dtype=torch.bool) mask[:3] = True - assert torch.all(FixAtoms(indices=mask).indices == torch.tensor([0, 1, 2])) + assert torch.all(FixAtoms(indices=mask).atom_idx == torch.tensor([0, 1, 2])) # Invalid indices with pytest.raises(ValueError, match="Indices must be integers"): @@ -464,7 +464,7 @@ def test_integrators_with_constraints( # Store initial state if isinstance(constraint, FixAtoms): - initial = cu_sim_state.positions[constraint.indices].clone() + initial = cu_sim_state.positions[constraint.atom_idx].clone() else: initial = get_centers_of_mass( cu_sim_state.positions, @@ -505,7 +505,7 @@ def test_integrators_with_constraints( # Verify constraint held if isinstance(constraint, FixAtoms): - assert torch.allclose(state.positions[constraint.indices], initial, atol=1e-6) + assert torch.allclose(state.positions[constraint.atom_idx], initial, atol=1e-6) else: final = get_centers_of_mass( state.positions, state.masses, state.system_idx, state.n_systems @@ -593,7 +593,7 @@ def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: combined = ts.concatenate_states([s1, s2]) assert len(combined.constraints) == 2 assert isinstance(combined.constraints[0], FixAtoms) - assert torch.all(combined.constraints[0].indices == torch.tensor([0, 1])) + assert torch.all(combined.constraints[0].atom_idx == torch.tensor([0, 1])) assert isinstance(combined.constraints[1], FixCom) assert torch.all(combined.constraints[1].system_idx == torch.tensor([1])) @@ -721,7 +721,7 @@ def test_system_constraint_update_and_select() -> None: def test_atom_indexed_constraint_update_and_select() -> None: - """Test select_constraint and select_sub_constraint for AtomIndexedConstraint.""" + """Test select_constraint and select_sub_constraint for AtomConstraint.""" # Create a FixAtoms constraint for atoms 0, 1, 5, 8 constraint = FixAtoms(indices=[0, 1, 5, 8]) @@ -736,7 +736,7 @@ def test_atom_indexed_constraint_update_and_select() -> None: # Atom indices should be renumbered: # Original: [0, 1, 5, 8] # After dropping atom 4: [0, 1, 4, 7] (indices shift down by 1 after index 4) - assert torch.all(updated_constraint.indices == torch.tensor([0, 1, 4, 7])) + assert torch.all(updated_constraint.atom_idx == torch.tensor([0, 1, 4, 7])) # Test select_sub_constraint # Select atoms that belong to a specific system @@ -748,7 +748,7 @@ def test_atom_indexed_constraint_update_and_select() -> None: # Should return a constraint with only atoms 0, 1 (within atom_idx range) # Renumbered to start from 0 assert sub_constraint is not None - assert torch.all(sub_constraint.indices == torch.tensor([0, 1])) + assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 1])) # Test with different atom range constraint = FixAtoms(indices=[0, 1, 5, 8]) @@ -758,7 +758,7 @@ def test_atom_indexed_constraint_update_and_select() -> None: # Should return a constraint with atoms 5, 8 renumbered to [0, 3] assert sub_constraint is not None - assert torch.all(sub_constraint.indices == torch.tensor([0, 3])) + assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 3])) # Test when no atoms in range constraint = FixAtoms(indices=[0, 1]) @@ -809,7 +809,7 @@ def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: # FixAtoms should have indices [0, 1] from s1 and [2+n_atoms_s1, 3+n_atoms_s1] from s2 expected_atom_indices = torch.tensor([0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1]) - assert torch.all(fix_atoms.indices == expected_atom_indices) + assert torch.all(fix_atoms.atom_idx == expected_atom_indices) # FixCom should have system_idx [0, 1] (one for each original system) expected_system_indices = torch.tensor([0, 1]) @@ -834,4 +834,4 @@ def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: expected_atom_indices = torch.tensor( [0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1, 0 + n_atoms_s1 + n_atoms_s2] ) - assert torch.all(fix_atoms.indices == expected_atom_indices) + assert torch.all(fix_atoms.atom_idx == expected_atom_indices) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index b22dacb9..f6b33f63 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -109,7 +109,7 @@ def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Ten return new_indices[~drop_indices] -class AtomIndexedConstraint(Constraint): +class AtomConstraint(Constraint): """Base class for constraints that act on specific atom indices. This class provides common functionality for constraints that operate @@ -156,7 +156,7 @@ def __init__(self, indices: torch.Tensor | list[int]) -> None: "forgot the mask= keyword." ) - self.indices = indices.long() + self.atom_idx = indices.long() def get_indices(self) -> torch.Tensor: """Get the constrained atom indices. @@ -164,7 +164,7 @@ def get_indices(self) -> torch.Tensor: Returns: Tensor of atom indices affected by this constraint """ - return self.indices.clone() + return self.atom_idx.clone() def select_constraint( self, @@ -177,7 +177,7 @@ def select_constraint( atom_mask: Boolean mask for atoms to keep system_mask: Boolean mask for systems to keep """ - indices = self.indices.clone() + indices = self.atom_idx.clone() indices = _mask_constraint_indices(indices, atom_mask) if len(indices) == 0: return None @@ -194,8 +194,8 @@ def select_sub_constraint( atom_idx: Atom indices for a single system sys_idx: System index for a single system """ - mask = torch.isin(self.indices, atom_idx) - masked_indices = self.indices[mask] + mask = torch.isin(self.atom_idx, atom_idx) + masked_indices = self.atom_idx[mask] new_atom_idx = masked_indices - atom_idx.min() if len(new_atom_idx) == 0: return None @@ -264,7 +264,7 @@ def select_sub_constraint( def merge_constraints( - constraint_lists: list[list[AtomIndexedConstraint | SystemConstraint]], + constraint_lists: list[list[AtomConstraint | SystemConstraint]], num_atoms_per_state: torch.Tensor, ) -> list[Constraint]: """Merge constraints from multiple systems into a single list of constraints. @@ -284,8 +284,8 @@ def merge_constraints( constraint_indices: dict[type[Constraint], list[torch.Tensor]] = defaultdict(list) for i, constraint_list in enumerate(constraint_lists): for constraint in constraint_list: - if isinstance(constraint, AtomIndexedConstraint): - idxs = constraint.indices + if isinstance(constraint, AtomConstraint): + idxs = constraint.atom_idx offset = cumsum_atoms[i] elif isinstance(constraint, SystemConstraint): idxs = constraint.system_idx @@ -302,7 +302,7 @@ def merge_constraints( ] -class FixAtoms(AtomIndexedConstraint): +class FixAtoms(AtomConstraint): """Constraint that fixes specified atoms in place. This constraint prevents the specified atoms from moving by: @@ -331,7 +331,7 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: Number of degrees of freedom removed (3 * number of fixed atoms) """ fixed_atoms_system_idx = torch.bincount( - state.system_idx[self.indices], minlength=state.n_systems + state.system_idx[self.atom_idx], minlength=state.n_systems ) return 3 * fixed_atoms_system_idx @@ -342,7 +342,7 @@ def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None state: Current simulation state new_positions: Proposed positions to be adjusted in-place """ - new_positions[self.indices] = state.positions[self.indices] + new_positions[self.atom_idx] = state.positions[self.atom_idx] def adjust_forces( self, @@ -355,14 +355,14 @@ def adjust_forces( state: Current simulation state forces: Forces to be adjusted in-place """ - forces[self.indices] = 0.0 + forces[self.atom_idx] = 0.0 def __repr__(self) -> str: """String representation of the constraint.""" - if len(self.indices) <= 10: - indices_str = self.indices.tolist() + if len(self.atom_idx) <= 10: + indices_str = self.atom_idx.tolist() else: - indices_str = f"{self.indices[:5].tolist()}...{self.indices[-5:].tolist()}" + indices_str = f"{self.atom_idx[:5].tolist()}...{self.atom_idx[-5:].tolist()}" return f"FixAtoms(indices={indices_str})" @@ -523,7 +523,7 @@ def validate_constraints(constraints: list[Constraint], state: SimState) -> None This function checks for: 1. Overlapping atom indices across multiple constraints - 2. AtomIndexedConstraints spanning multiple systems (requires state) + 2. AtomConstraints spanning multiple systems (requires state) 3. Mixing FixCom with other constraints (warning only) Args: @@ -543,12 +543,12 @@ def validate_constraints(constraints: list[Constraint], state: SimState) -> None has_com_constraint = False for constraint in constraints: - if isinstance(constraint, AtomIndexedConstraint): + if isinstance(constraint, AtomConstraint): indexed_constraints.append(constraint) # Validate that atom indices exist in state if provided check_no_index_out_of_bounds( - constraint.indices, state.n_atoms, type(constraint).__name__ + constraint.atom_idx, state.n_atoms, type(constraint).__name__ ) elif isinstance(constraint, SystemConstraint): check_no_index_out_of_bounds( @@ -560,7 +560,7 @@ def validate_constraints(constraints: list[Constraint], state: SimState) -> None # Check for overlapping atom indices if len(indexed_constraints) > 1: - all_indices = torch.cat([c.indices for c in indexed_constraints]) + all_indices = torch.cat([c.atom_idx for c in indexed_constraints]) unique_indices = torch.unique(all_indices) if len(unique_indices) < len(all_indices): warnings.warn( diff --git a/torch_sim/state.py b/torch_sim/state.py index 3c5feace..1d14b1ee 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -724,7 +724,7 @@ def _filter_attrs_by_mask( # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) - # take into account constraints that are AtomIndexedConstraint + # take into account constraints that are AtomConstraint filtered_attrs["_constraints"] = [ constraint.select_constraint(atom_mask, system_mask) for constraint in copy.deepcopy(state.constraints) From 7022df256b6e99d5de29eaf482847efb94c8143e Mon Sep 17 00:00:00 2001 From: thomasloux Date: Mon, 24 Nov 2025 17:31:17 +0100 Subject: [PATCH 38/43] remove einsteinModel (not for this PR) --- torch_sim/models/einstein.py | 285 ----------------------------------- 1 file changed, 285 deletions(-) delete mode 100644 torch_sim/models/einstein.py diff --git a/torch_sim/models/einstein.py b/torch_sim/models/einstein.py deleted file mode 100644 index af452f38..00000000 --- a/torch_sim/models/einstein.py +++ /dev/null @@ -1,285 +0,0 @@ -"""Einstein model where each atom is treated as an independent 3D harmonic oscillator. - -Contrary to other models, the model energies depend on an absolute reference position, -so the model can only be used on systems that the model was initialized with. -As a analytical model, it can provide its Helmholtz free energy and can also generate -samples from the Boltzmann distribution at a given temperature. -""" - -import torch - -import torch_sim as ts -from torch_sim import SimState, units -from torch_sim.models.interface import ModelInterface - - -class EinsteinModel(ModelInterface): - """Einstein model where each atom is treated as an independent 3D harmonic oscillator. - Each atom has its own frequency. - - For this model: - E = sum_i 0.5 * k_i * (x_i - x0_i)^2 - F = -k_i * (x_i - x0_i) - k_i = m_i * omega_i^2 - - For best results, frequencies should be in the range of typical phonon frequencies. - They can be set for each atom type individually following energy balance from - a NVT simulation. From equipartition theorem: - = 3/2 k_B T - => omega = sqrt(3 k_B T / m ) - """ - - def __init__( - self, - equilibrium_position: torch.Tensor, # shape [N, 3] - frequencies: torch.Tensor, # shape [N] - system_idx: torch.Tensor | None = None, # shape [N] or None - masses: torch.Tensor | None = None, # shape [N] or None - reference_energy: float = 0.0, # reference energy value - *, - device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - compute_forces: bool = True, - compute_stress: bool = False, - ) -> None: - """Initialize the Einstein model. - - Args: - equilibrium_position: Tensor of shape [N, 3] with equilibrium positions. - frequencies: Tensor of shape [N] with frequencies for each atom - (same frequency in all 3 directions). - system_idx: Optional tensor of shape [N] with system indices for each atom. - If None, all atoms are assumed to belong to the same system. - masses: Optional tensor of shape [N] with masses for each atom. - If None, all masses are set to 1. - reference_energy: Reference energy value to add to the computed energy. - device: Device to use for the model (default: CPU). - dtype: Data type for the model (default: torch.float32). - compute_forces: Whether to compute forces in the model. - compute_stress: Whether to compute stress in the model. - - """ - super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype - self._compute_forces = compute_forces - self._compute_stress = compute_stress - - equilibrium_position = torch.as_tensor( - equilibrium_position, device=self._device, dtype=self._dtype - ) - frequencies = torch.as_tensor( - frequencies, device=self._device, dtype=self._dtype - ) # [N, 3] - - if frequencies.shape[0] != equilibrium_position.shape[0]: - raise ValueError("frequencies shape must match equilibrium_position shape") - if frequencies.min() < 0: - raise ValueError("frequencies must be non-negative") - if frequencies.ndim == 0: - frequencies = frequencies.unsqueeze(0) - if frequencies.ndim != 1: - raise ValueError("frequencies must be a 1D tensor") - - if masses is None: - masses = torch.ones( - equilibrium_position.shape[0], dtype=self._dtype, device=self._device - ) - else: - masses = masses.to(self._device, self._dtype) - - if system_idx is not None: - system_idx = system_idx.to(self._device) - else: - system_idx = torch.zeros( - equilibrium_position.shape[0], dtype=torch.long, device=self._device - ) - - self.register_buffer("system_idx", system_idx.to(self._device)) - self.register_buffer("masses", masses) # [N] - self.register_buffer("x0", equilibrium_position) # [N, 3] - self.register_buffer("frequencies", frequencies) # [N] - self.register_buffer( - "reference_energy", - torch.tensor(reference_energy, dtype=self._dtype, device=self._device), - ) - - @classmethod - def from_atom_and_frequencies( - cls, - atom: SimState, - frequencies: torch.Tensor | float, - *, - reference_energy: float = 0.0, - compute_forces: bool = True, - compute_stress: bool = False, - device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - ) -> "EinsteinModel": - """Create an EinsteinModel from an ASE Atoms object and frequencies. - - Args: - atom: ASE Atoms object containing the reference structure. - frequencies: Tensor of shape [N] with frequencies for each atom - (same frequency in all 3 directions) or a scalar. - reference_energy: Reference energy value. - compute_forces: Whether to compute forces in the model. - compute_stress: Whether to compute stress in the model. - device: Device to use for the model (default: CPU). - dtype: Data type for the model (default: torch.float32). - - Returns: - EinsteinModel: An instance of the EinsteinModel. - """ - # Get equilibrium positions from the atoms object - equilibrium_position = atom.positions.clone().to(dtype=dtype, device=device) - - frequencies = torch.as_tensor(frequencies, dtype=dtype, device=device) - if frequencies.ndim == 0: - frequencies = frequencies.repeat(atom.positions.shape[0]) - if frequencies.shape[0] != atom.positions.shape[0]: - raise ValueError( - "frequencies must be a scalar or a tensor of shape [N] " - "where N is the number of atoms" - ) - - # Create and return an instance of EinsteinModel - return cls( - equilibrium_position=equilibrium_position, - frequencies=frequencies, - masses=atom.masses, - system_idx=atom.system_idx, - reference_energy=reference_energy, - compute_forces=compute_forces, - compute_stress=compute_stress, - device=device, - dtype=dtype, - ) - - def get_spring_constants(self) -> torch.Tensor: - """Get the spring constants for each atom in the Einstein model. - - Returns: - Tensor of shape [N] with spring constants k_i = m_i * omega_i^2 - for each atom. - """ - return self.masses * (self.frequencies**2) # [N] - - def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: - """Calculate energies and forces for the Einstein model. - - Args: - state: SimState or StateDict containing positions, cell, etc. - - Returns: - Dictionary containing energy, forces - """ - pos = state.positions.to(self._dtype) # [N, 3] - cell = state.cell.to(self._dtype) - - if cell.ndim == 2: - cell = cell.unsqueeze(0) # [1, 3, 3] - - # Get model parameters - x0 = torch.as_tensor(self.x0, dtype=self._dtype, device=self._device) - frequencies = torch.as_tensor( - self.frequencies, dtype=self._dtype, device=self._device - ) - masses = torch.as_tensor(self.masses, dtype=self._dtype, device=self._device) - - # Calculate displacements using periodic boundary conditions - if cell.shape[0] == 1: - disp = ts.transforms.minimum_image_displacement( - dr=pos - x0, cell=cell[0], pbc=state.pbc - ) - else: - disp = ts.transforms.minimum_image_displacement_batched( - pos - x0, cell, system_idx=state.system_idx, pbc=state.pbc - ) - - # Spring constants: k = m * omega^2 - spring_constants = masses * (frequencies**2) # [N] - - # Energy: E = 0.5 * k * x^2 - energies_per_mode = 0.5 * spring_constants * ((disp**2).sum(dim=1)) # [N] - total_energy = torch.zeros( - state.n_systems, dtype=self._dtype, device=self._device - ) - total_energy.scatter_add_(0, state.system_idx, energies_per_mode) - total_energy += self.reference_energy - - # Forces: F = -k * x - forces = -spring_constants.unsqueeze(-1) * disp # [N, 3] - - results = { - "energy": total_energy, - "forces": forces, - } - # Stress is not implemented for this model - if self._compute_stress: - results["stress"] = torch.zeros( - (state.n_systems, 3, 3), dtype=self._dtype, device=self._device - ) - - return results - - def get_free_energy(self, temperature: float) -> dict[str, torch.Tensor]: - """Compute free energy at a given temperature using Einstein model. - - Args: - temperature: Temperature in Kelvin. - - Returns: - Dictionary containing heat capacity, entropy, and free energy. - """ - # Boltzmann constant in eV/K - kB = units.BaseConstant.k_B / units.UnitConversion.eV_to_J - T = temperature - # Reduced Planck constant in eV*s - hbar = units.BaseConstant.h_planck / (2 * units.pi * units.UnitConversion.eV_to_J) - - frequencies_tensor = ( - torch.as_tensor(self.frequencies).clone() - * torch.as_tensor( - units.UnitConversion.eV_to_J / units.BaseConstant.amu - ).sqrt() - / units.UnitConversion.Ang_to_met - ) # Convert to rad/s - free_energy_per_atom = ( - -3 * kB * T * torch.log(kB * T / (hbar * frequencies_tensor)) - ) - - n_systems = self.system_idx.max().item() + 1 - free_energy_per_system = torch.zeros( - n_systems, dtype=self._dtype, device=self._device - ) - free_energy_per_system.scatter_add_(0, self.system_idx, free_energy_per_atom) - - return {"free_energy": free_energy_per_system} - - def sample(self, state: SimState, temperature: float) -> SimState: - """Generate samples from the Einstein model at a given temperature. - - Args: - state: Initial simulation state to sample from. - temperature: Temperature in Kelvin. - - Returns: - SimState containing sampled positions and velocities. - - The Boltzmann distribution for a harmonic oscillator leads to Gaussian - distributions - for both positions and velocities. - """ - N = self.x0.shape[0] - kB = units.BaseConstant.k_B / units.UnitConversion.eV_to_J - beta = 1.0 / (kB * temperature) # Inverse temperature in 1/eV - - # Sample positions from a normal distribution around equilibrium positions - stddev = torch.sqrt(1.0 / (self.masses * (self.frequencies**2) * beta)).unsqueeze( - -1 - ) - sampled_positions = self.x0 + torch.randn(N, 3, device=self._device) * stddev - state = state.clone() - state.positions = sampled_positions - return state From 07624f0eca04a0c104f070a332138e67007d4540 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 26 Nov 2025 10:35:22 +0100 Subject: [PATCH 39/43] rename var and add mask --- torch_sim/constraints.py | 74 ++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index f6b33f63..e68dd897 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -116,47 +116,44 @@ class AtomConstraint(Constraint): on a subset of atoms, identified by their indices. """ - def __init__(self, indices: torch.Tensor | list[int]) -> None: + def __init__( + self, + atom_idx: torch.Tensor | list[int] | None = None, + atom_mask: torch.Tensor | list[int] | None = None, + ) -> None: """Initialize indexed constraint. Args: - indices: Indices of atoms to constrain. Can be a tensor or list of integers. + atom_idx: Indices of atoms to constrain. Can be a tensor or list of integers. + atom_mask: Boolean mask for atoms to constrain. Raises: ValueError: If both indices and mask are provided, or if indices have wrong shape/type """ + if atom_idx is not None and atom_mask is not None: + raise ValueError("Provide either atom_idx or atom_mask, not both.") + if atom_mask is not None: + atom_mask = torch.as_tensor(atom_mask) + atom_idx = torch.where(atom_mask)[0] + # Convert to tensor if needed - if not isinstance(indices, torch.Tensor): - indices = torch.tensor(indices) + atom_idx = torch.as_tensor(atom_idx) # Ensure we have the right shape and type - indices = torch.atleast_1d(indices) - if indices.ndim > 1: + atom_idx = torch.atleast_1d(atom_idx) + if atom_idx.ndim > 1: raise ValueError( - "indices has wrong number of dimensions. " - f"Got {indices.ndim}, expected ndim <= 1" + "atom_idx has wrong number of dimensions. " + f"Got {atom_idx.ndim}, expected ndim <= 1" ) - if indices.dtype == torch.bool: - # Convert boolean mask to indices - indices = torch.where(indices)[0] - elif len(indices) == 0: - indices = torch.empty(0, dtype=torch.long) - elif torch.is_floating_point(indices): + if torch.is_floating_point(atom_idx): raise ValueError( - f"Indices must be integers or boolean mask, not dtype={indices.dtype}" + f"Indices must be integers or boolean mask, not dtype={atom_idx.dtype}" ) - # Check for duplicates - if len(torch.unique(indices)) < len(indices): - raise ValueError( - "The indices array contains duplicates. " - "Perhaps you want to specify a mask instead, but " - "forgot the mask= keyword." - ) - - self.atom_idx = indices.long() + self.atom_idx = atom_idx.long() def get_indices(self) -> torch.Tensor: """Get the constrained atom indices. @@ -209,17 +206,28 @@ class SystemConstraint(Constraint): on a subset of systems, identified by their indices. """ - def __init__(self, system_idx: torch.Tensor | list[int]) -> None: + def __init__( + self, + system_idx: torch.Tensor | list[int] | None = None, + system_mask: torch.Tensor | list[int] | None = None, + ) -> None: """Initialize indexed constraint. Args: - system_idx: Indices of systems to constrain. Can be a tensor or - list of integers. + system_idx: Indices of systems to constrain. + Can be a tensor or list of integers. + system_mask: Boolean mask for systems to constrain. Raises: ValueError: If both indices and mask are provided, or if indices have - wrong shape/type + wrong shape/type """ + if system_idx is not None and system_mask is not None: + raise ValueError("Provide either system_idx or system_mask, not both.") + if system_mask is not None: + system_idx = torch.as_tensor(system_idx) + system_idx = torch.where(system_mask)[0] + # Convert to tensor if needed system_idx = torch.as_tensor(system_idx) @@ -230,7 +238,13 @@ def __init__(self, system_idx: torch.Tensor | list[int]) -> None: "system_idx has wrong number of dimensions. " f"Got {system_idx.ndim}, expected ndim <= 1" ) - self.system_idx: torch.Tensor = system_idx + + if torch.is_floating_point(system_idx): + raise ValueError( + f"Indices must be integers or boolean mask, not dtype={system_idx.dtype}" + ) + + self.system_idx = system_idx.long() def select_constraint( self, @@ -312,7 +326,7 @@ class FixAtoms(AtomConstraint): Examples: Fix atoms with indices [0, 1, 2]: - >>> constraint = FixAtoms(indices=[0, 1, 2]) + >>> constraint = FixAtoms(atom_idx=[0, 1, 2]) Fix atoms using a boolean mask: >>> mask = torch.tensor([True, True, True, False, False]) From 0940919412b35511fa15db36456da3b6f5931b32 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 26 Nov 2025 10:47:52 +0100 Subject: [PATCH 40/43] remove comment now that a warning is set up for NPT MD with constraints --- torch_sim/integrators/npt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index bad47d78..388a4427 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -429,7 +429,7 @@ def _npt_langevin_velocity_step( # Update momenta (velocities * masses) with all contributions new_velocities = c_1 + c_2 + c_3 - # Apply constraints. Is it correct to apply constraints here? + # Apply constraints. state.set_momenta(new_velocities * state.masses.unsqueeze(-1)) return state From b49e3091e80ff71b9d2afdab563798101646fd76 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 26 Nov 2025 17:33:55 +0100 Subject: [PATCH 41/43] Add duplicate error in FixAtoms (subclass of AtomConstraint will handle duplicate differently) and SystemConstraint --- torch_sim/constraints.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index e68dd897..01cfa540 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -239,6 +239,10 @@ def __init__( f"Got {system_idx.ndim}, expected ndim <= 1" ) + # Check for duplicates + if len(system_idx) != len(torch.unique(system_idx)): + raise ValueError("Duplicate system indices found in SystemConstraint.") + if torch.is_floating_point(system_idx): raise ValueError( f"Indices must be integers or boolean mask, not dtype={system_idx.dtype}" @@ -333,6 +337,17 @@ class FixAtoms(AtomConstraint): >>> constraint = FixAtoms(mask=mask) """ + def __init__( + self, + atom_idx: torch.Tensor | list[int] | None = None, + atom_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize FixAtoms constraint and check for duplicate indices.""" + super().__init__(atom_idx=atom_idx, atom_mask=atom_mask) + # Check duplicates + if len(self.atom_idx) != len(torch.unique(self.atom_idx)): + raise ValueError("Duplicate atom indices found in FixAtoms constraint.") + def get_removed_dof(self, state: SimState) -> torch.Tensor: """Get number of removed degrees of freedom. From 65fd0cfd486764952c6f965100abca41fe52fe67 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 26 Nov 2025 17:34:09 +0100 Subject: [PATCH 42/43] rename args FixAtoms tests --- tests/test_constraints.py | 62 ++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index e0fd3d64..7704bb8e 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -45,7 +45,7 @@ def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesMode def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): """Test adjustment of positions and momenta with FixAtoms constraint.""" indices_to_fix = torch.tensor([0, 5, 10], dtype=torch.long) - ar_supercell_sim_state.constraints = [FixAtoms(indices=indices_to_fix)] + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices_to_fix)] initial_positions = ar_supercell_sim_state.positions.clone() # displacement = torch.randn_like(ar_supercell_sim_state.positions) * 0.5 displacement = 0.5 @@ -134,7 +134,7 @@ def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJone kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() - cu_sim_state.constraints = [FixAtoms(indices=torch.tensor([0, 1], dtype=torch.long))] + cu_sim_state.constraints = [FixAtoms(atom_idx=torch.tensor([0, 1], dtype=torch.long))] assert torch.allclose( cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - torch.tensor([6]) ) @@ -164,7 +164,7 @@ def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): """Test that constraints are properly propagated during state manipulation.""" # Set up constraints on the original state ar_double_sim_state.constraints = [ - FixAtoms(indices=torch.tensor([0, 1])), # Only applied to first system + FixAtoms(atom_idx=torch.tensor([0, 1])), # Only applied to first system FixCom([0, 1]), ] @@ -271,7 +271,7 @@ def test_fix_atoms_gradient_descent_optimization( ar_supercell_sim_state.positions = perturbed_positions initial_state = ar_supercell_sim_state - initial_state.constraints = [FixAtoms(indices=[0])] + initial_state.constraints = [FixAtoms(atom_idx=[0])] initial_position = initial_state.positions[0].clone() # Initialize Gradient Descent optimizer @@ -315,7 +315,7 @@ def test_test_atoms_fire_optimization( system_idx=ar_supercell_sim_state.system_idx.clone(), ) indices = torch.tensor([0, 2], dtype=torch.long) - current_sim_state.constraints = [FixAtoms(indices=indices)] + current_sim_state.constraints = [FixAtoms(atom_idx=indices)] # Initialize FIRE optimizer state = ts.fire_init( @@ -397,26 +397,28 @@ def test_fix_atoms_validation() -> None: # Boolean mask conversion mask = torch.zeros(10, dtype=torch.bool) mask[:3] = True - assert torch.all(FixAtoms(indices=mask).atom_idx == torch.tensor([0, 1, 2])) + assert torch.all(FixAtoms(atom_mask=mask).atom_idx == torch.tensor([0, 1, 2])) # Invalid indices with pytest.raises(ValueError, match="Indices must be integers"): - FixAtoms(indices=torch.tensor([0.5, 1.5])) - with pytest.raises(ValueError, match="duplicates"): - FixAtoms(indices=torch.tensor([0, 1, 1])) + FixAtoms(atom_idx=torch.tensor([0.5, 1.5])) + with pytest.raises(ValueError, match="Duplicate"): + FixAtoms(atom_idx=torch.tensor([0, 1, 1])) with pytest.raises(ValueError, match="wrong number of dimensions"): - FixAtoms(indices=torch.tensor([[0, 1]])) + FixAtoms(atom_idx=torch.tensor([[0, 1]])) def test_constraint_validation_warnings(ar_double_sim_state: ts.SimState) -> None: """Test validation warnings for constraint conflicts.""" with pytest.warns(UserWarning, match="Multiple constraints.*same atoms"): validate_constraints( - [FixAtoms(indices=[0, 1, 2]), FixAtoms(indices=[2, 3, 4])], + [FixAtoms(atom_idx=[0, 1, 2]), FixAtoms(atom_idx=[2, 3, 4])], ar_double_sim_state, ) with pytest.warns(UserWarning, match="FixCom together with other constraints"): - validate_constraints([FixCom([0]), FixAtoms(indices=[0, 1])], ar_double_sim_state) + validate_constraints( + [FixCom([0]), FixAtoms(atom_idx=[0, 1])], ar_double_sim_state + ) def test_constraint_validation_errors( @@ -426,10 +428,10 @@ def test_constraint_validation_errors( """Test validation errors for invalid constraints.""" # Out of bounds with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"): - cu_sim_state.constraints = [FixAtoms(indices=[0, 1, 100])] + cu_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 100])] # Validation in __post_init__ - with pytest.raises(ValueError, match="duplicates"): + with pytest.raises(ValueError, match="Duplicate"): ts.SimState( positions=ar_supercell_sim_state.positions.clone(), masses=ar_supercell_sim_state.masses, @@ -437,16 +439,16 @@ def test_constraint_validation_errors( pbc=ar_supercell_sim_state.pbc, atomic_numbers=ar_supercell_sim_state.atomic_numbers, system_idx=ar_supercell_sim_state.system_idx, - _constraints=[FixAtoms(indices=[0, 0, 1])], + _constraints=[FixAtoms(atom_idx=[0, 0, 1])], ) @pytest.mark.parametrize( ("integrator", "constraint", "n_steps"), [ - ("nve", FixAtoms(indices=[0, 1]), 100), + ("nve", FixAtoms(atom_idx=[0, 1]), 100), ("nvt_nose_hoover", FixCom([0]), 200), - ("npt_langevin", FixAtoms(indices=[0, 3]), 200), + ("npt_langevin", FixAtoms(atom_idx=[0, 3]), 200), ("npt_nose_hoover", FixCom([0]), 200), ], ) @@ -520,9 +522,9 @@ def test_multiple_constraints_and_dof( # Test DOF calculation n = cu_sim_state.n_atoms assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n) - cu_sim_state.constraints = [FixAtoms(indices=[0])] + cu_sim_state.constraints = [FixAtoms(atom_idx=[0])] assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 3) - cu_sim_state.constraints = [FixCom([0]), FixAtoms(indices=[0])] + cu_sim_state.constraints = [FixCom([0]), FixAtoms(atom_idx=[0])] assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 6) # Verify both constraints hold during dynamics @@ -571,7 +573,7 @@ def test_cell_optimization_with_constraints( ar_supercell_sim_state.positions += ( torch.randn_like(ar_supercell_sim_state.positions) * 0.05 ) - ar_supercell_sim_state.constraints = [FixAtoms(indices=[0, 1])] + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1])] state = ts.fire_init( ar_supercell_sim_state, lj_model, @@ -588,7 +590,7 @@ def test_cell_optimization_with_constraints( def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: """Test system-specific constraints in batched states.""" s1, s2 = ar_double_sim_state.split() - s1.constraints = [FixAtoms(indices=[0, 1])] + s1.constraints = [FixAtoms(atom_idx=[0, 1])] s2.constraints = [FixCom([0])] combined = ts.concatenate_states([s1, s2]) assert len(combined.constraints) == 2 @@ -655,7 +657,7 @@ def test_high_level_api_with_constraints( ar_supercell_sim_state.positions += ( torch.randn_like(ar_supercell_sim_state.positions) * 0.1 ) - ar_supercell_sim_state.constraints = [FixAtoms(indices=[0, 1, 2])] + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 2])] initial_pos = ar_supercell_sim_state.positions[[0, 1, 2]].clone() final = ts.optimize( ar_supercell_sim_state, lj_model, optimizer=ts.Optimizer.fire, max_steps=500 @@ -668,7 +670,7 @@ def test_temperature_with_constrained_dof( ) -> None: """Test temperature calculation uses constrained DOF.""" target = 300.0 - cu_sim_state.constraints = [FixAtoms(indices=[0, 1, 2])] + cu_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 2])] state = ts.nvt_langevin_init( cu_sim_state, lj_model, @@ -723,7 +725,7 @@ def test_system_constraint_update_and_select() -> None: def test_atom_indexed_constraint_update_and_select() -> None: """Test select_constraint and select_sub_constraint for AtomConstraint.""" # Create a FixAtoms constraint for atoms 0, 1, 5, 8 - constraint = FixAtoms(indices=[0, 1, 5, 8]) + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) # Test select_constraint with atom_mask # Keep atoms 0, 1, 2, 3, 5, 6, 7, 8 (drop atoms 4) @@ -740,7 +742,7 @@ def test_atom_indexed_constraint_update_and_select() -> None: # Test select_sub_constraint # Select atoms that belong to a specific system - constraint = FixAtoms(indices=[0, 1, 5, 8]) + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) atom_idx = torch.tensor([0, 1, 2, 3, 4]) # Atoms for first system sys_idx = 0 sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) @@ -751,7 +753,7 @@ def test_atom_indexed_constraint_update_and_select() -> None: assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 1])) # Test with different atom range - constraint = FixAtoms(indices=[0, 1, 5, 8]) + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) atom_idx = torch.tensor([5, 6, 7, 8, 9]) # Atoms for second system sys_idx = 1 sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) @@ -761,7 +763,7 @@ def test_atom_indexed_constraint_update_and_select() -> None: assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 3])) # Test when no atoms in range - constraint = FixAtoms(indices=[0, 1]) + constraint = FixAtoms(atom_idx=[0, 1]) atom_idx = torch.tensor([5, 6, 7, 8]) sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) assert sub_constraint is None @@ -777,13 +779,13 @@ def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: # Create constraints for each system # System 1: Fix atoms 0, 1 and fix COM for system 0 s1_constraints = [ - FixAtoms(indices=[0, 1]), + FixAtoms(atom_idx=[0, 1]), FixCom([0]), ] # System 2: Fix atoms 2, 3 and fix COM for system 0 s2_constraints = [ - FixAtoms(indices=[2, 3]), + FixAtoms(atom_idx=[2, 3]), FixCom([0]), ] @@ -817,7 +819,7 @@ def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: # Test with three systems s3 = s1.clone() - s3_constraints = [FixAtoms(indices=[0])] + s3_constraints = [FixAtoms(atom_idx=[0])] constraint_lists = [s1_constraints, s2_constraints, s3_constraints] num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2, s3.n_atoms]) merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) From fee207fc45311e5c1bcafda98ba28ae2bf987ea9 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Wed, 26 Nov 2025 17:36:48 +0100 Subject: [PATCH 43/43] system_idx for constraint must be dim 1 --- torch_sim/constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 01cfa540..d352d539 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -142,7 +142,7 @@ def __init__( # Ensure we have the right shape and type atom_idx = torch.atleast_1d(atom_idx) - if atom_idx.ndim > 1: + if atom_idx.ndim != 1: raise ValueError( "atom_idx has wrong number of dimensions. " f"Got {atom_idx.ndim}, expected ndim <= 1" @@ -233,7 +233,7 @@ def __init__( # Ensure we have the right shape and type system_idx = torch.atleast_1d(system_idx) - if system_idx.ndim > 1: + if system_idx.ndim != 1: raise ValueError( "system_idx has wrong number of dimensions. " f"Got {system_idx.ndim}, expected ndim <= 1"