Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion descent/optim/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,15 @@ def levenberg_marquardt(
loss_history = []
has_converged = False

best_x, best_loss = x.clone(), closure_prev[0]
best_x, best_loss = x.clone().detach(), closure_prev[0]

for step in range(config.max_steps):
loss_prev, gradient_prev, hessian_prev = closure_prev

dx, expected_improvement, damping_adjusted, damping_factor = _step(
gradient_prev, hessian_prev, trust_radius, config
)
_LOGGER.info(f"The current step is {dx.detach().cpu().numpy()}")

if config.mode.lower() == _HESSIAN_SEARCH:
dx, expected_improvement = _hessian_diagonal_search(
Expand All @@ -553,6 +554,7 @@ def levenberg_marquardt(
_LOGGER.info(f"{config.mode} step found (length {dx_norm:.4e})")

x_next = correct_fn(x + dx).requires_grad_(x.requires_grad)
_LOGGER.info(f"The next set of parameters to try {x_next.detach().cpu().numpy()}")

loss, gradient, hessian = closure_fn(x_next, True, True)
loss_delta = loss - loss_prev
Expand Down
71 changes: 41 additions & 30 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,32 +318,25 @@ def _bulk_config(temperature: float, pressure: float) -> SimulationConfig:
pressure = pressure * openmm.unit.atmosphere

return SimulationConfig(
max_mols=256,
max_mols=1000,
gen_coords=smee.mm.GenerateCoordsConfig(),
equilibrate=[
smee.mm.MinimizationConfig(),
# short NVT equilibration simulation
smee.mm.SimulationConfig(
temperature=temperature,
pressure=None,
n_steps=50000,
timestep=1.0 * openmm.unit.femtosecond,
),
# short NPT equilibration simulation
smee.mm.SimulationConfig(
temperature=temperature,
pressure=pressure,
n_steps=50000,
timestep=1.0 * openmm.unit.femtosecond,
n_steps=100000,
timestep=2.0 * openmm.unit.femtosecond,
),
],
production=smee.mm.SimulationConfig(
temperature=temperature,
pressure=pressure,
n_steps=500000,
n_steps=1000000,
timestep=2.0 * openmm.unit.femtosecond,
),
production_frequency=1000,
production_frequency=2000,
)


Expand Down Expand Up @@ -758,23 +751,41 @@ def closure_fn(
compute_hessian: bool,
):
force_field = trainable.to_force_field(x)

y_ref, _, y_pred, _ = descent.targets.thermo.predict(
dataset,
force_field,
topologies,
pathlib.Path.cwd(),
None,
per_type_scales,
verbose,
)
loss, gradient, hessian = ((y_pred - y_ref) ** 2).sum(), None, None

if compute_hessian:
hessian = descent.utils.loss.approximate_hessian(x, y_pred)
if compute_gradient:
gradient = torch.autograd.grad(loss, x, retain_graph=True)[0].detach()

return loss.detach(), gradient, hessian
total_loss, grad, hess = torch.zeros(size=(1,), device=x.device.type), None, None
for i in range(len(dataset)):
y_ref, _, y_pred, _ = descent.targets.thermo.predict(
dataset.select(indices=[i]),
force_field,
topologies,
pathlib.Path.cwd(),
None,
per_type_scales,
verbose,
)
loss = (y_pred - y_ref) ** 2

if compute_hessian:
hessian = descent.utils.loss.approximate_hessian(x, y_pred).detach()
if hess is None:
hess = hessian
else:
hess += hessian
if compute_gradient:
gradient = torch.autograd.grad(loss, x, retain_graph=True)[0].detach()
if grad is None:
grad = gradient
else:
grad += gradient

total_loss += loss.detach()
# clear the graph
torch.cuda.empty_cache()

# if compute_gradient:
# grad = sum(grad[1:], grad[0]).detach()
# if compute_hessian:
# hess = sum(hess[1:], hess[0]).detach()
# total_loss = sum(total_loss[1:], total_loss[0]).detach()
return total_loss, grad, hess

return closure_fn
2 changes: 1 addition & 1 deletion descent/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def combined_closure_fn(

if verbose:
verbose_rows.append(
{"target": name, "loss": float(f"{local_loss:.5f}")}
{"target": name, "loss": float(f"{local_loss.item():.5f}")}
)

loss = sum(loss[1:], loss[0])
Expand Down