Implement lbfgs and bfgs #365
Conversation
See detailed comment here.
orionarcher left a comment:
Going to leave the physics review to someone else; just one API comment.
curtischong left a comment:
Just learned the algorithms today, so I may have missed things. Great work, especially because the ase implementation isn't well documented.
    Returns:
        Updated state
    """
    eps = 1e-7
One idea: we could surface this eps as a parameter to the function (maybe rename it update_hessian_eps) so people can choose to always update the approximated Hessian for all systems. If we do this, I think we should also increase the + 1e-30 epsilon (added when we divide the BFGS update terms by a and b) to 1e-7, similar to ase, or use eps = 1e-8 if dtype == torch.float32 else 1e-16, similar to our lbfgs implementation, so the blow-up isn't as large when we divide by a number near 0.
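For concreteness, here's a minimal, self-contained sketch of the guarded update for a single system; the function name and the update_hessian_eps parameter are hypothetical, not the PR's actual API:

import torch

def bfgs_hessian_update(
    hessian: torch.Tensor,  # (D, D) current Hessian approximation
    s: torch.Tensor,  # (D,) position step
    y: torch.Tensor,  # (D,) gradient difference
    update_hessian_eps: float | None = None,
) -> torch.Tensor:
    # Dtype-aware default, mirroring the lbfgs implementation
    if update_hessian_eps is None:
        update_hessian_eps = 1e-8 if s.dtype == torch.float32 else 1e-16
    a = torch.dot(y, s)
    hs = hessian @ s
    b = torch.dot(s, hs)
    # Guarding both denominators with the same eps keeps the update terms
    # from blowing up when a or b is near zero
    return (
        hessian
        + torch.outer(y, y) / (a + update_hessian_eps)
        - torch.outer(hs, hs) / (b + update_hessian_eps)
    )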
sy = tsm.batched_vdot(s_new, y_new, state.system_idx)
bad_curv = sy <= curvature_eps

if bad_curv.any():
I think we can replace this section with:

if bad_curv.any():
    # Clear history for systems with bad curves to preserve correctness
    keep_mask = ~bad_curv[state.system_idx].unsqueeze(1)  # [n_atoms, 1]
    state.s_history = state.s_history * keep_mask
    state.y_history = state.y_history * keep_mask
    s_new = s_new * keep_mask  # set to 0 for bad curves (no-op)
    y_new = y_new * keep_mask

# Append and trim if needed
if state.s_history.shape[0] == 0:
    s_hist = s_new.unsqueeze(0)
    y_hist = y_new.unsqueeze(0)
else:
    s_hist = torch.cat([state.s_history, s_new.unsqueeze(0)], dim=0)
    y_hist = torch.cat([state.y_history, y_new.unsqueeze(0)], dim=0)
if s_hist.shape[0] > max_history:
    s_hist = s_hist[-max_history:]
    y_hist = y_hist[-max_history:]
This way we only clear out the history for systems with bad curvature.
state: "LBFGSState",
model: "ModelInterface",
*,
max_history: int = 10,
I feel like we can set this during init; I don't see people changing this value throughout the optimization. It would also let us preallocate the history tensors.
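A rough sketch of what preallocation could look like once max_history is fixed at init; the class and field names here are hypothetical, not the PR's actual state layout:

import torch

class LBFGSHistory:
    def __init__(self, max_history: int, n_atoms: int, device=None, dtype=torch.float64):
        self.max_history = max_history
        # Preallocated buffers: [max_history, n_atoms, 3], filled in order
        self.s_history = torch.zeros(max_history, n_atoms, 3, device=device, dtype=dtype)
        self.y_history = torch.zeros(max_history, n_atoms, 3, device=device, dtype=dtype)
        self.n_stored = 0  # number of slots currently holding valid pairs

    def push(self, s_new: torch.Tensor, y_new: torch.Tensor) -> None:
        if self.n_stored == self.max_history:
            # Buffer full: shift out the oldest pair instead of torch.cat-ing
            self.s_history = torch.roll(self.s_history, shifts=-1, dims=0)
            self.y_history = torch.roll(self.y_history, shifts=-1, dims=0)
            self.n_stored -= 1
        self.s_history[self.n_stored] = s_new
        self.y_history[self.n_stored] = y_new
        self.n_stored += 1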
alpha = alphas[state.s_history.shape[0] - 1 - i]
# z <- z + s_i * (alpha - beta)
coeff = (alpha - beta)[state.system_idx].unsqueeze(-1)
z = z + coeff * s_i
nit: change to z = z + s_i * coeff so it matches the comment
sys_max = torch.zeros(state.n_systems, device=device, dtype=dtype)
sys_max.scatter_reduce_(0, state.system_idx, norms, reduce="amax", include_self=False)

# Scaling factors per system: <= 1.0
Consider saying "scale down step so atoms move at most max_step" or "# Scale step if it exceeds max_step", like you do in bfgs.
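For reference, a self-contained sketch of the scaling being described; clamp_step and its arguments are hypothetical names, and only the scatter_reduce_ call comes from the quoted diff:

import torch

def clamp_step(
    step: torch.Tensor,  # (n_atoms, 3) proposed displacements
    system_idx: torch.Tensor,  # (n_atoms,) system index per atom
    n_systems: int,
    max_step: float,
) -> torch.Tensor:
    norms = step.norm(dim=1)  # per-atom displacement magnitude
    sys_max = torch.zeros(n_systems, device=step.device, dtype=step.dtype)
    sys_max.scatter_reduce_(0, system_idx, norms, reduce="amax", include_self=False)
    # Scale factor <= 1.0: leave small steps alone, shrink large ones so
    # no atom in the system moves farther than max_step
    scale = (max_step / sys_max.clamp_min(1e-12)).clamp(max=1.0)
    return step * scale[system_idx].unsqueeze(1)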
direction = forces_new.unsqueeze(2)

# omega: (n_systems, D), V: (n_systems, D, D)
omega, V = torch.linalg.eigh(state.hessian)
I know that ase doesn't do this, but I'd appreciate a comment here saying that we do an eigendecomposition (rather than simply updating the inverse Hessian directly) so we can get the eigenvalues and take torch.abs(omega), ensuring the approximated Hessian is always positive definite and the step always points downhill, which is more important for atomistic systems. I was a bit confused when this didn't match the vanilla BFGS algorithm.
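Something like this sketch is what I eventually worked out; it's my own reconstruction of the idea, not the PR's code, and newton_step is a hypothetical name:

import torch

def newton_step(hessian: torch.Tensor, forces: torch.Tensor) -> torch.Tensor:
    # hessian: (n_systems, D, D), forces: (n_systems, D)
    direction = forces.unsqueeze(2)  # (n_systems, D, 1)
    omega, V = torch.linalg.eigh(hessian)
    # abs() flips negative-curvature modes instead of following them uphill,
    # making the effective Hessian positive definite
    inv_omega = 1.0 / torch.abs(omega).clamp_min(1e-12)
    # step = V diag(1/|omega|) V^T f, batched over systems
    step = V @ (inv_omega.unsqueeze(2) * (V.transpose(1, 2) @ direction))
    return step.squeeze(2)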
Thanks for the comments. I will update those over the weekend.