
Conversation

@abhijeetgangan
Collaborator

Summary

  • Add L-BFGS and BFGS optimization schemes.

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

@abhijeetgangan
Collaborator Author

See detailed comment here.

Collaborator

@orionarcher left a comment

Going to leave the physics review to someone else; just one API comment.

Collaborator

@curtischong left a comment

Just learned the algorithms today, so I may have missed things. Great work, especially since the ase implementation isn't very well documented.

    Returns:
        Updated state
    """
    eps = 1e-7

One idea is to surface this eps as a parameter of the function (maybe renamed update_hessian_eps) so people can choose to always update the approximated Hessian for all systems. If we do that, I think we should also increase the + 1e-30 epsilon (used when we divide the BFGS update terms by a and b) to 1e-7 (similar to ase), or use eps = 1e-8 if dtype == torch.float32 else 1e-16 (similar to our L-BFGS implementation), so the blowup isn't as large when we divide by a number near 0. A rough sketch of both ideas is below.
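
A minimal sketch of what I mean, written as a standalone update function (update_hessian_eps and the a/b names are illustrative, not the PR's actual signature):

import torch

def bfgs_update(
    hessian: torch.Tensor,  # (n_systems, D, D) approximate Hessians
    s: torch.Tensor,        # (n_systems, D) position differences
    y: torch.Tensor,        # (n_systems, D) gradient differences
    update_hessian_eps: float = 1e-7,  # surfaced skip-update threshold
) -> torch.Tensor:
    # dtype-dependent guard against dividing by a number near 0
    eps = 1e-8 if hessian.dtype == torch.float32 else 1e-16

    a = torch.einsum("bi,bi->b", y, s)           # y^T s per system
    Bs = torch.einsum("bij,bj->bi", hessian, s)  # B s per system
    b = torch.einsum("bi,bi->b", s, Bs)          # s^T B s per system

    # standard BFGS update, guarded by eps instead of + 1e-30
    update = (
        torch.einsum("bi,bj->bij", y, y) / (a + eps)[:, None, None]
        - torch.einsum("bi,bj->bij", Bs, Bs) / (b + eps)[:, None, None]
    )

    # only update systems whose curvature clears the threshold
    good = (a > update_hessian_eps)[:, None, None]
    return torch.where(good, hessian + update, hessian)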

sy = tsm.batched_vdot(s_new, y_new, state.system_idx)
bad_curv = sy <= curvature_eps

if bad_curv.any():

I think we can replace this section with

if bad_curv.any():
    # Clear history for systems with bad curves to preserve correctness
    keep_mask = ~bad_curv[state.system_idx].unsqueeze(1)  # [n_atoms, 1]
    state.s_history = state.s_history * keep_mask
    state.y_history = state.y_history * keep_mask

    s_new = s_new * keep_mask  # zero out bad-curvature systems (appending zeros is a no-op)
    y_new = y_new * keep_mask

# Append and trim if needed
if state.s_history.shape[0] == 0:
    s_hist = s_new.unsqueeze(0)
    y_hist = y_new.unsqueeze(0)
else:
    s_hist = torch.cat([state.s_history, s_new.unsqueeze(0)], dim=0)
    y_hist = torch.cat([state.y_history, y_new.unsqueeze(0)], dim=0)
if s_hist.shape[0] > max_history:
    s_hist = s_hist[-max_history:]
    y_hist = y_hist[-max_history:]

so we only clear the history for systems with bad curvature.

state: "LBFGSState",
model: "ModelInterface",
*,
max_history: int = 10,

I feel like we can set this during init; I don't see people changing this value mid-optimization. It would also let us preallocate the history tensors (rough sketch below).
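
A minimal sketch of the preallocation idea, assuming hypothetical state fields (hist_len, head, and the (max_history, n_atoms, 3) layout are illustrative, not the PR's actual layout):

import torch

class LBFGSState:
    def __init__(
        self,
        n_atoms: int,
        max_history: int = 10,
        dtype: torch.dtype = torch.float64,
        device: str = "cpu",
    ) -> None:
        self.max_history = max_history
        # preallocated ring buffers instead of torch.cat on every step
        self.s_history = torch.zeros(max_history, n_atoms, 3, dtype=dtype, device=device)
        self.y_history = torch.zeros(max_history, n_atoms, 3, dtype=dtype, device=device)
        self.hist_len = 0  # number of valid entries
        self.head = 0      # next write slot in the ring buffer

    def push(self, s_new: torch.Tensor, y_new: torch.Tensor) -> None:
        # overwrite the oldest entry once the buffer is full
        self.s_history[self.head] = s_new
        self.y_history[self.head] = y_new
        self.head = (self.head + 1) % self.max_history
        self.hist_len = min(self.hist_len + 1, self.max_history)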

alpha = alphas[state.s_history.shape[0] - 1 - i]
# z <- z + s_i * (alpha - beta)
coeff = (alpha - beta)[state.system_idx].unsqueeze(-1)
z = z + coeff * s_i

nit: change to z = z + s_i * coeff so it matches the comment

sys_max = torch.zeros(state.n_systems, device=device, dtype=dtype)
sys_max.scatter_reduce_(0, state.system_idx, norms, reduce="amax", include_self=False)

# Scaling factors per system: <= 1.0

consider saying "# scale down the step so atoms move at most max_step" or "# scale step if it exceeds max_step", like you mention in bfgs. Something like the sketch below.
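
For concreteness, a minimal sketch of the clamp these lines feed into (the two lines after scatter_reduce_ are my assumption about the surrounding code, not a quote from the PR):

# scale down the step so atoms move at most max_step
sys_max = torch.zeros(state.n_systems, device=device, dtype=dtype)
sys_max.scatter_reduce_(0, state.system_idx, norms, reduce="amax", include_self=False)
scale = torch.clamp(max_step / sys_max.clamp(min=1e-12), max=1.0)  # per-system factor <= 1.0
step = step * scale[state.system_idx].unsqueeze(-1)                # broadcast to atoms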

direction = forces_new.unsqueeze(2)

# omega: (n_systems, D), V: (n_systems, D, D)
omega, V = torch.linalg.eigh(state.hessian)
Collaborator

@curtischong Nov 28, 2025

I know that ase doesn't do this, but I'd appreciate a comment here explaining that we do an eigendecomposition (rather than simply updating the inverse Hessian directly) so we can take torch.abs(omega) of the eigenvalues, ensuring the approximated Hessian is always positive definite and the step always points downhill, which matters more for atomistic systems. I was a bit confused when this didn't match the vanilla BFGS algorithm. Something like the sketch below.
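
A minimal sketch of the comment I had in mind, written in the quoted snippet's style (variable names follow the snippet; the clamp value is illustrative):

# We eigendecompose the Hessian (rather than updating the inverse
# Hessian directly, as vanilla BFGS does) so we can take |omega|:
# this forces the approximated Hessian to be positive definite, so
# the step always points downhill, which matters for atomistic systems.
omega, V = torch.linalg.eigh(state.hessian)        # (n_systems, D), (n_systems, D, D)
inv_abs = 1.0 / torch.abs(omega).clamp(min=1e-12)  # invert |eigenvalues|

# step = V diag(1/|omega|) V^T f, batched over systems
proj = V.transpose(1, 2) @ direction               # V^T f, (n_systems, D, 1)
step = V @ (inv_abs.unsqueeze(-1) * proj)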

@abhijeetgangan
Collaborator Author

Thanks for the comments. I will update those over the weekend.
