diff --git a/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py b/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py new file mode 100644 index 00000000..74c983d0 --- /dev/null +++ b/examples/scripts/2_Structural_optimization/2.8_MACE_LBFGS.py @@ -0,0 +1,79 @@ +"""Batched MACE L-BFGS optimizer with ASE comparison.""" + +# /// script +# dependencies = ["mace-torch>=0.3.12"] +# /// +import os + +import numpy as np +import torch +from ase.build import bulk +from ase.optimize import LBFGS as ASE_LBFGS +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.mace import MaceModel, MaceUrls + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +loaded_model = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) + +SMOKE_TEST = os.getenv("CI") is not None +N_steps = 10 if SMOKE_TEST else 200 + +rng = np.random.default_rng(seed=0) + +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +model = MaceModel( + model=loaded_model, + device=device, + compute_forces=True, + compute_stress=True, + dtype=dtype, + enable_cueq=False, +) + +# torch-sim batched L-BFGS +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +initial_results = model(state) +state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0) + +for _ in range(N_steps): + state = ts.lbfgs_step(state=state, model=model, max_history=100) + +ts_final = [e.item() for e in state.energy] + +# ASE L-BFGS comparison +ase_calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) +ase_final = [] +for atoms in atoms_list: + atoms.calc = ase_calc + optimizer = ASE_LBFGS(atoms, logfile=None) + optimizer.run(fmax=0.01, steps=N_steps) + ase_final.append(atoms.get_potential_energy()) + +# Results +print(f"Initial energies: {[f'{e.item():.4f}' for e in initial_results['energy']]}") +print(f"torch-sim final: {[f'{e:.4f}' for e in ts_final]}") +print(f"ASE final: {[f'{e:.4f}' for e in ase_final]}") diff --git a/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py b/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py new file mode 100644 index 00000000..138ced4b --- /dev/null +++ b/examples/scripts/2_Structural_optimization/2.9_MACE_BFGS.py @@ -0,0 +1,79 @@ +"""Batched MACE BFGS optimizer with ASE comparison.""" + +# /// script +# dependencies = ["mace-torch>=0.3.12"] +# /// +import os + +import numpy as np +import torch +from ase.build import bulk +from ase.optimize import BFGS as ASE_BFGS +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.mace import MaceModel, MaceUrls + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +loaded_model = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) + +SMOKE_TEST = os.getenv("CI") is not None +N_steps = 10 if SMOKE_TEST else 200 + +rng = 
np.random.default_rng(seed=0) + +si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2)) +si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape) + +cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2)) +cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape) + +fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2)) +fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape) + +atoms_list = [si_dc, cu_dc, fe_dc] + +model = MaceModel( + model=loaded_model, + device=device, + compute_forces=True, + compute_stress=True, + dtype=dtype, + enable_cueq=False, +) + +# torch-sim batched BFGS +state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype) +initial_results = model(state) +state = ts.bfgs_init(state=state, model=model, alpha=70.0) + +for _ in range(N_steps): + state = ts.bfgs_step(state=state, model=model) + +ts_final = [e.item() for e in state.energy] + +# ASE BFGS comparison +ase_calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + default_dtype=str(dtype).removeprefix("torch."), + device=str(device), +) +ase_final = [] +for atoms in atoms_list: + atoms.calc = ase_calc + optimizer = ASE_BFGS(atoms, logfile=None, alpha=70.0) + optimizer.run(fmax=0.01, steps=N_steps) + ase_final.append(atoms.get_potential_energy()) + +# Results +print(f"Initial energies: {[f'{e.item():.4f}' for e in initial_results['energy']]}") +print(f"torch-sim final: {[f'{e:.4f}' for e in ts_final]}") +print(f"ASE final: {[f'{e:.4f}' for e in ase_final]}") diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index f632cbfa..ece43c54 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -52,13 +52,19 @@ from torch_sim.monte_carlo import SwapMCState, swap_mc_init, swap_mc_step from torch_sim.optimizers import ( OPTIM_REGISTRY, + BFGSState, FireState, + LBFGSState, Optimizer, OptimState, + bfgs_init, + bfgs_step, fire_init, fire_step, gradient_descent_init, gradient_descent_step, + lbfgs_init, + lbfgs_step, ) from torch_sim.optimizers.cell_filters import ( CELL_FILTER_REGISTRY, diff --git a/torch_sim/optimizers/__init__.py b/torch_sim/optimizers/__init__.py index 850cfcac..3a8d4faf 100644 --- a/torch_sim/optimizers/__init__.py +++ b/torch_sim/optimizers/__init__.py @@ -10,13 +10,20 @@ from enum import StrEnum from typing import Any, Final, Literal, get_args +from torch_sim.optimizers.bfgs import bfgs_init, bfgs_step from torch_sim.optimizers.cell_filters import CellFireState, CellOptimState # noqa: F401 from torch_sim.optimizers.fire import fire_init, fire_step from torch_sim.optimizers.gradient_descent import ( gradient_descent_init, gradient_descent_step, ) -from torch_sim.optimizers.state import FireState, OptimState # noqa: F401 +from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step +from torch_sim.optimizers.state import ( # noqa: F401 + BFGSState, + FireState, + LBFGSState, + OptimState, +) FireFlavor = Literal["vv_fire", "ase_fire"] @@ -28,9 +35,13 @@ class Optimizer(StrEnum): gradient_descent = "gradient_descent" fire = "fire" + lbfgs = "lbfgs" + bfgs = "bfgs" OPTIM_REGISTRY: Final[dict[Optimizer, tuple[Callable[..., Any], Callable[..., Any]]]] = { Optimizer.gradient_descent: (gradient_descent_init, gradient_descent_step), Optimizer.fire: (fire_init, fire_step), + Optimizer.lbfgs: (lbfgs_init, lbfgs_step), + Optimizer.bfgs: (bfgs_init, bfgs_step), } diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py new file mode 100644 index 00000000..bb7021c5 --- /dev/null +++ b/torch_sim/optimizers/bfgs.py @@ -0,0 +1,282 @@ +"""BFGS 
(Broyden-Fletcher-Goldfarb-Shanno) optimizer implementation. + +This module provides a batched BFGS optimizer that maintains the full Hessian +matrix for each system. This is suitable for systems with a small to moderate +number of atoms, where the $O(N^2)$ memory cost is acceptable. + +The implementation handles batches of systems with different numbers of atoms +by padding vectors to the maximum number of atoms in the batch. The Hessian +matrices are similarly padded to shape (n_systems, 3*max_atoms, 3*max_atoms). +""" + +from typing import TYPE_CHECKING + +import torch + +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import BFGSState + + +def _get_atom_indices_per_system( + system_idx: torch.Tensor, n_systems: int +) -> torch.Tensor: + """Compute the index of each atom within its system. + + Assumes atoms are grouped contiguously by system. + + Args: + system_idx: Tensor of system indices [n_atoms] + n_systems: Number of systems + + Returns: + Tensor of [0, 1, 2, ..., 0, 1, ...] [n_atoms] + """ + # We assume contiguous atoms for each system, which is standard in SimState + counts = torch.bincount(system_idx, minlength=n_systems) + # Create ranges [0...n-1] for each system and concatenate + indices = [torch.arange(c, device=system_idx.device) for c in counts] + return torch.cat(indices) + + +def _pad_to_dense( + flat_tensor: torch.Tensor, + system_idx: torch.Tensor, + atom_idx_in_system: torch.Tensor, + n_systems: int, + max_atoms: int, +) -> torch.Tensor: + """Convert a packed tensor to a padded dense tensor. + + Args: + flat_tensor: [n_atoms, D] + system_idx: [n_atoms] + atom_idx_in_system: [n_atoms] + n_systems: int + max_atoms: int + + Returns: + dense_tensor: [n_systems, max_atoms, D] + """ + D = flat_tensor.shape[1] + dense = torch.zeros( + (n_systems, max_atoms, D), dtype=flat_tensor.dtype, device=flat_tensor.device + ) + dense[system_idx, atom_idx_in_system] = flat_tensor + return dense + + +def bfgs_init( + state: SimState | StateDict, + model: "ModelInterface", + *, + max_step: float = 0.2, + alpha: float = 70.0, +) -> "BFGSState": + """Create an initial BFGSState. + + Initializes the Hessian as Identity * alpha. 
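+    Subsequent ``bfgs_step`` calls refine each system's Hessian with a rank-2
+    update built from position and force differences, then obtain the trial
+    displacement from the Hessian's eigendecomposition, rescaling the step so
+    that no atom moves farther than ``max_step``.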
+ + Args: + state: Input state + model: Model + max_step: Maximum step size (Angstrom) + alpha: Initial Hessian stiffness (eV/A^2) + + Returns: + BFGSState + """ + from torch_sim.optimizers import BFGSState + + tensor_args = {"device": model.device, "dtype": model.dtype} + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems + + counts = state.n_atoms_per_system + max_atoms = int(counts.max().item()) if len(counts) > 0 else 0 + atom_idx = _get_atom_indices_per_system(state.system_idx, n_systems) + + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + stress = model_output["stress"] + + # shape: (n_systems, 3*max_atoms, 3*max_atoms) + dim = 3 * max_atoms + hessian = torch.eye(dim, **tensor_args).unsqueeze(0).repeat(n_systems, 1, 1) * alpha + + alpha_t = torch.full((n_systems,), alpha, **tensor_args) + max_step_t = torch.full((n_systems,), max_step, **tensor_args) + n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) + + return BFGSState( + positions=state.positions.clone(), + masses=state.masses.clone(), + cell=state.cell.clone(), + atomic_numbers=state.atomic_numbers.clone(), + forces=forces, + energy=energy, + stress=stress, + hessian=hessian, + prev_forces=forces.clone(), + prev_positions=state.positions.clone(), + alpha=alpha_t, + max_step=max_step_t, + n_iter=n_iter, + atom_idx_in_system=atom_idx, + max_atoms=max_atoms, + # passed to __post_init__ + system_idx=state.system_idx.clone(), + pbc=state.pbc, + ) + + +def bfgs_step( + state: "BFGSState", + model: "ModelInterface", +) -> "BFGSState": + """Perform one BFGS optimization step. + + Updates the Hessian estimate and moves atoms. + + Args: + state: Current optimization state + model: Calculator model + + Returns: + Updated state + """ + eps = 1e-7 + + # Pack flat tensors into dense batched tensors + # shape: (n_systems, max_atoms * 3) + pos_new = _pad_to_dense( + state.positions, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + forces_new = _pad_to_dense( + state.forces, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + pos_old = _pad_to_dense( + state.prev_positions, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + forces_old = _pad_to_dense( + state.prev_forces, + state.system_idx, + state.atom_idx_in_system, + state.n_systems, + state.max_atoms, + ).reshape(state.n_systems, -1) + + # Calculate displacements and force changes + # dpos: (n_systems, max_atoms * 3) + dpos = pos_new - pos_old + dforces = -(forces_new - forces_old) + + # Identify systems with significant movement + max_disp = torch.max(torch.abs(dpos), dim=1).values + update_mask = max_disp >= eps + + # Update Hessian for active systems + if update_mask.any(): + idx = update_mask + H = state.hessian[idx] + + # shape: (n_active, D, 1) + dp = dpos[idx].unsqueeze(2) + df = dforces[idx].unsqueeze(2) # noqa: PD901 + + # shape: (n_active, 1) + a = torch.bmm(dp.transpose(1, 2), df).squeeze(2) + + # shape: (n_active, D, 1) + dg = torch.bmm(H, dp) + + # shape: (n_active, 1) + b = torch.bmm(dp.transpose(1, 2), dg).squeeze(2) + + # Rank-2 update + # shape: (n_active, D, D) + term1 = torch.bmm(df, df.transpose(1, 2)) / (a.unsqueeze(2) + 1e-30) + term2 = torch.bmm(dg, dg.transpose(1, 2)) / (b.unsqueeze(2) + 1e-30) + + state.hessian[idx] = H - term1 - 
term2 + + # Calculate step direction using eigendecomposition + # gradient: (n_systems, D, 1) + # Step p = H^-1 * F + direction = forces_new.unsqueeze(2) + + # omega: (n_systems, D), V: (n_systems, D, D) + omega, V = torch.linalg.eigh(state.hessian) + + # shape: (n_systems, 1, D) + abs_omega = torch.abs(omega).unsqueeze(1) + abs_omega = torch.where(abs_omega < 1e-30, torch.ones_like(abs_omega), abs_omega) + + # Project direction onto eigenvectors and scale + # shape: (n_systems, D, 1) + vt_g = torch.bmm(V.transpose(1, 2), direction) + scaled = vt_g / abs_omega.transpose(1, 2) + + # Transform back to original basis + # shape: (n_systems, D) + step_dense = torch.bmm(V, scaled).squeeze(2) + + # Scale step if it exceeds max_step + # step_atoms: (n_systems, max_atoms, 3) + step_atoms = step_dense.view(state.n_systems, state.max_atoms, 3) + # atom_norms: (n_systems, max_atoms) + atom_norms = torch.norm(step_atoms, dim=2) + + # max_disp_per_sys: (n_systems,) + max_disp_per_sys = torch.max(atom_norms, dim=1).values + + scale = torch.ones_like(max_disp_per_sys) + needs_scale = max_disp_per_sys > state.max_step + scale[needs_scale] = state.max_step[needs_scale] / ( + max_disp_per_sys[needs_scale] + 1e-30 + ) + + # shape: (n_systems, D) + step_dense = step_dense * scale.unsqueeze(1) + + # Unpack dense step back to flat valid atoms + flat_step = step_dense.view(state.n_systems, state.max_atoms, 3)[ + state.system_idx, state.atom_idx_in_system + ] + + new_positions = state.positions + flat_step + + state.prev_positions = state.positions.clone() + state.prev_forces = state.forces.clone() + state.positions = new_positions + + model_output = model(state) + state.forces = model_output["forces"] + state.energy = model_output["energy"] + state.stress = model_output["stress"] + state.n_iter += 1 + + return state diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py new file mode 100644 index 00000000..6c6408d3 --- /dev/null +++ b/torch_sim/optimizers/lbfgs.py @@ -0,0 +1,286 @@ +"""L-BFGS (Limited-memory BFGS) optimizer implementation. + +This module provides a batched L-BFGS optimizer for atomic structure relaxation. +L-BFGS is a quasi-Newton method that approximates the inverse Hessian using +a limited history of position and gradient differences, making it memory-efficient +for large systems while achieving superlinear convergence near the minimum. +""" + +from typing import TYPE_CHECKING + +import torch + +import torch_sim.math as tsm +from torch_sim.state import SimState +from torch_sim.typing import StateDict + + +if TYPE_CHECKING: + from torch_sim.models.interface import ModelInterface + from torch_sim.optimizers import LBFGSState + + +def lbfgs_init( + state: SimState | StateDict, + model: "ModelInterface", + *, + step_size: float = 0.1, + alpha: float | None = None, +) -> "LBFGSState": + r"""Create an initial LBFGSState from a SimState or state dict. + + Initializes forces/energy, clears the (s, y) memory, and broadcasts the + fixed step size to all systems. + + Args: + state: Input state as SimState object or state parameter dict + model: Model that computes energies, forces, and optionally stress + step_size: Fixed per-system step length (damping factor). + If using ASE mode (fixed alpha), set this to 1.0 (or your damping). + If using dynamic mode (default), 0.1 is a safe starting point. + alpha: Initial inverse Hessian stiffness guess (ASE parameter). + If provided (e.g. 70.0), fixes H0 = 1/alpha for all steps (ASE-style). 
+ If None (default), H0 is updated dynamically (Standard L-BFGS). + + Returns: + LBFGSState with initialized optimization tensors + + Notes: + The optimizer supports two modes of operation: + 1. **Standard L-BFGS (default)**: Set `alpha=None`. The inverse Hessian + diagonal $H_0$ is updated dynamically at each step using the scaling + $\gamma_k = (s^T y) / (y^T y)$. This is the standard behavior described + by Nocedal & Wright. + 2. **ASE Compatibility Mode**: Set `alpha` (e.g. 70.0) and `step_size=1.0`. + The inverse Hessian diagonal is fixed at $H_0 = 1/\alpha$ throughout the + optimization, and the step is scaled by `step_size` (damping). + This matches `ase.optimize.LBFGS(alpha=70.0, damping=1.0)`. + """ + from torch_sim.optimizers import LBFGSState + + tensor_args = {"device": model.device, "dtype": model.dtype} + + if not isinstance(state, SimState): + state = SimState(**state) + + n_systems = state.n_systems + + # Get initial forces and energy from model + model_output = model(state) + energy = model_output["energy"] + forces = model_output["forces"] + stress = model_output["stress"] + + # Initialize empty history tensors + # History shape: [max_history, n_atoms, 3] but we start with 0 entries + s_history = torch.zeros((0, state.n_atoms, 3), **tensor_args) + y_history = torch.zeros((0, state.n_atoms, 3), **tensor_args) + + # Alpha tensor: 0.0 means dynamic, >0 means fixed + alpha_val = 0.0 if alpha is None else alpha + alpha_tensor = torch.full((n_systems,), alpha_val, **tensor_args) + + return LBFGSState( + # Copy SimState attributes + positions=state.positions.clone(), + masses=state.masses.clone(), + cell=state.cell.clone(), + atomic_numbers=state.atomic_numbers.clone(), + system_idx=state.system_idx.clone(), + pbc=state.pbc, + # Optimization state + forces=forces, + energy=energy, + stress=stress, + # L-BFGS specific state + prev_forces=forces.clone(), + prev_positions=state.positions.clone(), + s_history=s_history, + y_history=y_history, + step_size=torch.full((n_systems,), step_size, **tensor_args), + alpha=alpha_tensor, + n_iter=torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + ) + + +def lbfgs_step( # noqa: PLR0915 + state: "LBFGSState", + model: "ModelInterface", + *, + max_history: int = 10, + max_step: float = 0.2, + curvature_eps: float = 1e-12, +) -> "LBFGSState": + r"""Advance one L-BFGS iteration using the two-loop recursion. + + Computes the search direction via the two-loop recursion, applies a + fixed step with optional per-system capping, evaluates new forces and + energy, and updates the limited-memory history with a curvature check. + + Algorithm (per system s): + 1) Evaluate gradient g_k = ∇E(x_k) = -f(x_k) + 2) Perform L-BFGS two-loop recursion using up to `max_history` pairs + (s_i, y_i) to compute d_k = -H_k g_k + 3) Fixed step update with optional per-system step capping by `max_step` + 4) Curvature check and history update: accept (s_k, y_k) if ⟨y_k, s_k⟩ > ε + + Args: + state: Current L-BFGS optimization state + model: Model that computes energies, forces, and optionally stress + max_history: Number of (s, y) pairs retained for the two-loop recursion. + max_step: If set, caps the maximum per-atom displacement per iteration. + curvature_eps: Threshold for the curvature ⟨y, s⟩ used to accept new + history pairs. + + Returns: + Updated LBFGSState after one optimization step + + Notes: + - If `state.alpha > 0` (ASE mode), the initial inverse Hessian estimate is + fixed at $H_0 = 1/\alpha$. 
+ - Otherwise (Standard mode), $H_0$ varies at each step based on the + curvature of the most recent history pair. + + References: + - Nocedal & Wright, Numerical Optimization (L-BFGS two-loop recursion). + """ + device, dtype = model.device, model.dtype + eps = 1e-8 if dtype == torch.float32 else 1e-16 + + # Current gradient + g = -state.forces + + # Two-loop recursion to compute search direction d = -H_k g_k + q = g.clone() + alphas: list[torch.Tensor] = [] # per-history, shape [n_systems] + + # First loop (from newest to oldest) + for i in range(state.s_history.shape[0] - 1, -1, -1): + s_i = state.s_history[i] + y_i = state.y_history[i] + + ys = tsm.batched_vdot(y_i, s_i, state.system_idx) # y^T s per system + rho = torch.where( + ys.abs() > curvature_eps, + 1.0 / (ys + eps), + torch.zeros_like(ys), + ) + sq = tsm.batched_vdot(s_i, q, state.system_idx) + alpha = rho * sq + alphas.append(alpha) + + # q <- q - alpha * y_i (broadcast per system to atoms) + alpha_atom = alpha[state.system_idx].unsqueeze(-1) + q = q - alpha_atom * y_i + + # Initial H0 scaling: gamma = (s^T y)/(y^T y) using the last pair + # Dynamic gamma (Standard L-BFGS) + if state.s_history.shape[0] > 0: + s_last = state.s_history[-1] + y_last = state.y_history[-1] + sy = tsm.batched_vdot(s_last, y_last, state.system_idx) + yy = tsm.batched_vdot(y_last, y_last, state.system_idx) + gamma_dynamic = torch.where( + yy.abs() > curvature_eps, + sy / (yy + eps), + torch.ones_like(yy), + ) + else: + gamma_dynamic = torch.ones((state.n_systems,), device=device, dtype=dtype) + + # Fixed gamma (ASE style: 1/alpha) + # If state.alpha > 0, use that. Else use dynamic. + is_fixed = state.alpha > 1e-6 + gamma_fixed = 1.0 / (state.alpha + eps) + gamma = torch.where(is_fixed, gamma_fixed, gamma_dynamic) + + z = gamma[state.system_idx].unsqueeze(-1) * q + + # Second loop (from oldest to newest) + for i in range(state.s_history.shape[0]): + s_i = state.s_history[i] + y_i = state.y_history[i] + + ys = tsm.batched_vdot(y_i, s_i, state.system_idx) + rho = torch.where( + ys.abs() > curvature_eps, + 1.0 / (ys + eps), + torch.zeros_like(ys), + ) + yz = tsm.batched_vdot(y_i, z, state.system_idx) + beta = rho * yz + + alpha = alphas[state.s_history.shape[0] - 1 - i] + # z <- z + s_i * (alpha - beta) + coeff = (alpha - beta)[state.system_idx].unsqueeze(-1) + z = z + coeff * s_i + + d = -z # search direction + + # Optional per-system max step cap + # Compute per-atom step with current step_size + t_atoms = state.step_size[state.system_idx].unsqueeze(-1) + step = t_atoms * d + + # Per-atom norms + norms = torch.linalg.norm(step, dim=1) + + # Per-system max norm + sys_max = torch.zeros(state.n_systems, device=device, dtype=dtype) + sys_max.scatter_reduce_(0, state.system_idx, norms, reduce="amax", include_self=False) + + # Scaling factors per system: <= 1.0 + scale = torch.where( + sys_max > max_step, + max_step / (sys_max + eps), + torch.ones_like(sys_max), + ) + scale_atoms = scale[state.system_idx].unsqueeze(-1) + step = scale_atoms * step + + # Update positions + new_positions = state.positions + step + + # Evaluate new forces/energy + state.positions = new_positions + model_output = model(state) + new_forces = model_output["forces"] + new_energy = model_output["energy"] + new_stress = model_output["stress"] + + # Build new (s, y) + s_new = state.positions - state.prev_positions + y_new = -new_forces - (-state.prev_forces) # g_new - g_prev = -(f_new - f_prev) + + # Curvature check per system; if bad, clear history (conservative) + sy = 
tsm.batched_vdot(s_new, y_new, state.system_idx) + bad_curv = sy <= curvature_eps + + if bad_curv.any(): + # Clear entire history to preserve correctness + s_hist = torch.zeros((0, state.n_atoms, 3), device=device, dtype=dtype) + y_hist = torch.zeros((0, state.n_atoms, 3), device=device, dtype=dtype) + else: + # Append and trim if needed + if state.s_history.shape[0] == 0: + s_hist = s_new.unsqueeze(0) + y_hist = y_new.unsqueeze(0) + else: + s_hist = torch.cat([state.s_history, s_new.unsqueeze(0)], dim=0) + y_hist = torch.cat([state.y_history, y_new.unsqueeze(0)], dim=0) + if s_hist.shape[0] > max_history: + s_hist = s_hist[-max_history:] + y_hist = y_hist[-max_history:] + + # Update state + state.forces = new_forces + state.energy = new_energy + state.stress = new_stress + + state.prev_forces = new_forces.clone() + state.prev_positions = state.positions.clone() + state.s_history = s_hist + state.y_history = y_hist + state.n_iter = state.n_iter + 1 + + return state diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 2ab530db..92ace560 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -40,4 +40,84 @@ class FireState(OptimState): _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 +@dataclass(kw_only=True) +class BFGSState(OptimState): + """State for batched BFGS optimization. + + Stores the state needed to run a batched BFGS optimizer that maintains + an approximate Hessian or inverse Hessian. + + Attributes: + hessian: Hessian matrix [n_systems, 3*max_atoms, 3*max_atoms] + prev_forces: Previous-step forces [n_atoms, 3] + prev_positions: Previous-step positions [n_atoms, 3] + alpha: Initial Hessian scale [n_systems] + max_step: Maximum step size [n_systems] + n_iter: Per-system iteration counter [n_systems] (int32) + atom_idx_in_system: Index of each atom within its system [n_atoms] + max_atoms: Maximum number of atoms in any system (int) + """ + + hessian: torch.Tensor + prev_forces: torch.Tensor + prev_positions: torch.Tensor + alpha: torch.Tensor + max_step: torch.Tensor + n_iter: torch.Tensor + atom_idx_in_system: torch.Tensor + max_atoms: int + + _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 + "prev_forces", + "prev_positions", + "atom_idx_in_system", + } + _system_attributes = OptimState._system_attributes | { # noqa: SLF001 + "hessian", + "alpha", + "max_step", + "n_iter", + "max_atoms", + } + + +@dataclass(kw_only=True) +class LBFGSState(OptimState): + """State for batched L-BFGS minimization (no line search). + + Stores the state needed to run a batched Limited-memory BFGS optimizer that + uses a fixed step size and the classical two-loop recursion to compute + approximate inverse-Hessian-vector products. All tensors are batched across + systems via `system_idx`. 
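+    The per-system ``alpha`` field stores the fixed Hessian stiffness used in
+    ASE-compatibility mode (``H0 = 1/alpha``); a value of 0 means the dynamic
+    ``gamma = (s^T y) / (y^T y)`` scaling is used instead.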
+ + Attributes: + prev_forces: Previous-step forces [n_atoms, 3] + prev_positions: Previous-step positions [n_atoms, 3] + s_history: Displacement history [h, n_atoms, 3] + y_history: Gradient-diff history [h, n_atoms, 3] + step_size: Per-system fixed step size [n_systems] + n_iter: Per-system iteration counter [n_systems] (int32) + """ + + prev_forces: torch.Tensor + prev_positions: torch.Tensor + s_history: torch.Tensor + y_history: torch.Tensor + step_size: torch.Tensor + alpha: torch.Tensor + n_iter: torch.Tensor + + _atom_attributes = OptimState._atom_attributes | { # noqa: SLF001 + "prev_forces", + "prev_positions", + } + _system_attributes = OptimState._system_attributes | { # noqa: SLF001 + "s_history", + "y_history", + "step_size", + "alpha", + "n_iter", + } + + # there's no GradientDescentState, it's the same as OptimState
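
For orientation, here is a minimal usage sketch driving either new optimizer
through the OPTIM_REGISTRY entries added above instead of calling the
init/step functions directly. The relax() helper is hypothetical, `model`
stands for any ModelInterface (for instance the MaceModel built in the example
scripts), and the force-based stopping criterion is an illustrative choice,
not something this diff provides:

    import torch

    import torch_sim as ts
    from torch_sim.optimizers import OPTIM_REGISTRY, Optimizer


    def relax(atoms_list, model, method=Optimizer.lbfgs, max_steps=200, fmax=0.02):
        """Relax a batch of ASE Atoms with one of the registered optimizers."""
        init_fn, step_fn = OPTIM_REGISTRY[method]
        state = ts.io.atoms_to_state(atoms_list, device=model.device, dtype=model.dtype)
        state = init_fn(state=state, model=model)
        for _ in range(max_steps):
            state = step_fn(state=state, model=model)
            # stop once the largest per-atom force norm drops below fmax
            if torch.linalg.norm(state.forces, dim=1).max() < fmax:
                break
        return state

Dispatching through the registry keeps scripts such as 2.8 and 2.9 agnostic to
the optimizer: switching between Optimizer.lbfgs and Optimizer.bfgs only
changes the `method` argument.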