2 changes: 1 addition & 1 deletion descent/targets/dimers.py
@@ -409,7 +409,7 @@ def report(
         col: rmse_format for col in data_stats.columns if col.startswith("RMSE")
     }
     formatters_full = {
-        **{col: "html" for col in ["Dimer", "Energy [kcal/mol]"]},
+        **dict.fromkeys(["Dimer", "Energy [kcal/mol]"], "html"),
         **{col: rmse_format for col in data_full.columns if col.startswith("RMSE")},
     }

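The dimers.py change swaps a dict comprehension that maps every key to the same constant for the equivalent dict.fromkeys constructor, the rewrite linters such as ruff (rule C420) suggest. dict.fromkeys reuses one value object for every key, which only matters for mutable values; the value here is the immutable string "html", so the two forms build identical dicts. A minimal sketch of the equivalence, using the column names from the hunk:

    # Both forms build {"Dimer": "html", "Energy [kcal/mol]": "html"}.
    cols = ["Dimer", "Energy [kcal/mol]"]

    by_comprehension = {col: "html" for col in cols}
    by_fromkeys = dict.fromkeys(cols, "html")

    assert by_comprehension == by_fromkeys
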
4 changes: 2 additions & 2 deletions descent/targets/energy.py
@@ -116,8 +116,8 @@ def predict(
         )
 
         coords = (
-            coords_flat.reshape(len(energy_ref), -1, 3)
-        ).detach().requires_grad_(True)
+            (coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True)
+        )
         topology = topologies[smiles]
 
         energy_pred = smee.compute_energy(topology, force_field, coords)
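
Both spellings in the energy.py hunk build the same tensor: the reshaped coordinates are detached from any upstream graph and re-marked as a leaf that requires grad, so the predicted energy can be differentiated with respect to them (e.g. to recover forces); only the parenthesis placement changes, likely formatter-driven. A minimal sketch of the pattern with placeholder data (variable names mirror the diff; the energy expression is a stand-in for smee.compute_energy):

    import torch

    energy_ref = torch.zeros(2)       # two reference frames
    coords_flat = torch.arange(12.0)  # 2 frames x 2 atoms x 3 components

    # Reshape to (n_frames, n_atoms, 3) and make the result a grad-requiring leaf.
    coords = (coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True)

    energy = (coords**2).sum()  # stand-in for smee.compute_energy(...)
    (grad,) = torch.autograd.grad(energy, coords)
    assert grad.shape == coords.shape  # one gradient entry per coordinate
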
2 changes: 1 addition & 1 deletion descent/targets/thermo.py
@@ -703,7 +703,7 @@ def predict(
 
         verbose_rows.append(
             {
-                "type": f'{entry["type"]} [{entry["units"]}]',
+                "type": f"{entry['type']} [{entry['units']}]",
                 "smiles_a": descent.utils.molecule.unmap_smiles(entry["smiles_a"]),
                 "smiles_b": (
                     ""
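
The thermo.py change only flips the f-string quoting: double quotes outside, single quotes for the nested subscript keys, matching the double-quote style formatters like black default to. The rendered string is identical either way, as this small check with an illustrative entry shows:

    entry = {"type": "density", "units": "g/mL"}  # illustrative entry

    old_style = f'{entry["type"]} [{entry["units"]}]'
    new_style = f"{entry['type']} [{entry['units']}]"

    assert old_style == new_style == "density [g/mL]"
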
12 changes: 8 additions & 4 deletions descent/tests/targets/test_energy.py
@@ -78,7 +78,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
                     [9.0, 10.0, 11.0],
                     [12.0, 13.0, 14.0],
                     [15.0, 16.0, 17.0],
-                ], dtype=torch.float64
+                ],
+                dtype=torch.float64,
             )
             / math.sqrt(6.0 * 3.0),
             torch.tensor([7.899425506591797, -7.89942741394043]) / math.sqrt(2.0),
@@ -90,7 +91,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
                     [0.0, -137.45770263671875, 0.0],
                     [102.62999725341797, 68.72884368896484, 0.0],
                     [-102.62999725341797, 68.72884368896484, 0.0],
-                ], dtype=torch.float64
+                ],
+                dtype=torch.float64,
             )
             / math.sqrt(6.0 * 3.0),
         ),
@@ -106,7 +108,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
                     [9.0, 10.0, 11.0],
                     [12.0, 13.0, 14.0],
                     [15.0, 16.0, 17.0],
-                ], dtype=torch.float64
+                ],
+                dtype=torch.float64,
             ),
             torch.tensor([0.0, -15.798852920532227]),
             -torch.tensor(
@@ -117,7 +120,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
                     [0.0, -137.45770263671875, 0.0],
                     [102.62999725341797, 68.72884368896484, 0.0],
                     [-102.62999725341797, 68.72884368896484, 0.0],
-                ], dtype=torch.float64
+                ],
+                dtype=torch.float64,
             ),
         ),
     ],
163 changes: 163 additions & 0 deletions descent/tests/test_train.py
@@ -88,6 +88,23 @@ def test_validate_keys_limits(self):
        ):
            AttributeConfig(cols=["scale_14"], limits={"scale_15": (0.1, 0.2)})

    def test_validate_keys_regularize(self):
        with pytest.raises(
            pydantic.ValidationError, match="cannot regularize non-trainable parameters"
        ):
            AttributeConfig(cols=["scale_14"], regularize={"scale_15": 0.01})

    def test_regularize_field(self):
        config = AttributeConfig(
            cols=["scale_14", "scale_15"],
            regularize={"scale_14": 0.01, "scale_15": 0.001},
        )
        assert config.regularize == {"scale_14": 0.01, "scale_15": 0.001}

    def test_regularize_empty(self):
        config = AttributeConfig(cols=["scale_14"])
        assert config.regularize == {}


class TestParameterConfig:
    def test_validate_include_exclude(self):
@@ -268,3 +285,149 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs):

        assert values.shape == expected_values.shape
        assert torch.allclose(values, expected_values)

    def test_regularized_idxs_no_regularization(
        self, mock_ff, mock_parameter_configs, mock_attribute_configs
    ):
        trainable = Trainable(
            mock_ff,
            parameters=mock_parameter_configs,
            attributes=mock_attribute_configs,
        )

        assert len(trainable.regularized_idxs) == 0
        assert len(trainable.regularization_weights) == 0

    def test_regularized_idxs_with_parameter_regularization(self, mock_ff):
        parameter_configs = {
            "vdW": ParameterConfig(
                cols=["epsilon", "sigma"],
                regularize={"epsilon": 0.01, "sigma": 0.001},
            ),
        }
        attribute_configs = {}

        trainable = Trainable(
            mock_ff,
            parameters=parameter_configs,
            attributes=attribute_configs,
        )

        # vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma
        # So we should have 4 regularized values total: 2 epsilons + 2 sigmas
        expected_idxs = torch.tensor([0, 1, 2, 3], dtype=torch.long)
        assert torch.equal(trainable.regularized_idxs, expected_idxs)

        # Check the weights match what we configured
        # Interleaved: row 0 (eps, sig), row 1 (eps, sig)
        expected_weights = torch.tensor(
            [0.01, 0.001, 0.01, 0.001], dtype=trainable.regularization_weights.dtype
        )
        assert torch.allclose(trainable.regularization_weights, expected_weights)

    def test_regularized_idxs_with_attribute_regularization(self, mock_ff):
        parameter_configs = {}
        attribute_configs = {
            "vdW": AttributeConfig(
                cols=["scale_14", "scale_15"],
                regularize={"scale_14": 0.05},
            )
        }

        trainable = Trainable(
            mock_ff,
            parameters=parameter_configs,
            attributes=attribute_configs,
        )

        # Only scale_14 should be regularized (1 attribute)
        expected_idxs = torch.tensor([0], dtype=torch.long)
        assert torch.equal(trainable.regularized_idxs, expected_idxs)

        expected_weights = torch.tensor(
            [0.05], dtype=trainable.regularization_weights.dtype
        )
        assert torch.allclose(trainable.regularization_weights, expected_weights)

    def test_regularized_idxs_with_mixed_regularization(self, mock_ff):
        parameter_configs = {
            "vdW": ParameterConfig(
                cols=["epsilon", "sigma"],
                regularize={"epsilon": 0.02},
                include=[mock_ff.potentials_by_type["vdW"].parameter_keys[0]],
            ),
        }
        attribute_configs = {
            "vdW": AttributeConfig(
                cols=["scale_14"],
                regularize={"scale_14": 0.1},
            )
        }

        trainable = Trainable(
            mock_ff,
            parameters=parameter_configs,
            attributes=attribute_configs,
        )

        # Only first vdW parameter row is included, with only epsilon regularized
        # Plus scale_14 attribute
        expected_idxs = torch.tensor([0, 2], dtype=torch.long)
        assert torch.equal(trainable.regularized_idxs, expected_idxs)

        # First should be epsilon (0.02), second should be scale_14 (0.1)
        expected_weights = torch.tensor(
            [0.02, 0.1], dtype=trainable.regularization_weights.dtype
        )
        assert torch.allclose(trainable.regularization_weights, expected_weights)

    def test_regularized_idxs_excluded_parameters(self, mock_ff):
        parameter_configs = {
            "Bonds": ParameterConfig(
                cols=["k", "length"],
                regularize={"k": 0.01, "length": 0.02},
                exclude=[mock_ff.potentials_by_type["Bonds"].parameter_keys[0]],
            ),
        }
        attribute_configs = {}

        trainable = Trainable(
            mock_ff,
            parameters=parameter_configs,
            attributes=attribute_configs,
        )

        # Only second bond parameter row should be included (first is excluded)
        # Both k and length are regularized
        expected_idxs = torch.tensor([0, 1], dtype=torch.long)
        assert torch.equal(trainable.regularized_idxs, expected_idxs)

        expected_weights = torch.tensor(
            [0.01, 0.02], dtype=trainable.regularization_weights.dtype
        )
        assert torch.allclose(trainable.regularization_weights, expected_weights)

    def test_regularization_indices_match_unfrozen_values(self, mock_ff):
        parameter_configs = {
            "vdW": ParameterConfig(
                cols=["epsilon"],
                regularize={"epsilon": 0.01},
            ),
        }
        attribute_configs = {}

        trainable = Trainable(
            mock_ff,
            parameters=parameter_configs,
            attributes=attribute_configs,
        )

        values = trainable.to_values()

        # Regularization indices should be valid indices into the unfrozen values
        assert trainable.regularized_idxs.max() < len(values)
        assert trainable.regularized_idxs.min() >= 0

        # We should be able to index the values tensor with regularization indices
        regularized_values = values[trainable.regularized_idxs]
        assert len(regularized_values) == len(trainable.regularized_idxs)
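
Taken together, the new tests pin down the regularization contract on Trainable: regularized_idxs is a torch.long tensor of positions into the flat vector returned by to_values() (parameter columns interleaved row by row, attribute values appended after all parameters), regularization_weights holds the matching per-value weights, and AttributeConfig rejects regularize keys that are not listed in cols. A minimal sketch of how a fitting loop might consume the pair, assuming a weighted L2 penalty (the loss descent actually applies is not part of this diff):

    import torch

    def regularization_penalty(
        values: torch.Tensor,     # flat vector from trainable.to_values()
        idxs: torch.Tensor,       # trainable.regularized_idxs
        weights: torch.Tensor,    # trainable.regularization_weights
        reference: torch.Tensor,  # e.g. the initial values, same shape as values
    ) -> torch.Tensor:
        # Penalize only the regularized entries, each by its configured weight.
        delta = values[idxs] - reference[idxs]
        return (weights * delta**2).sum()

    # Hypothetical usage inside a training step:
    #   loss = data_loss + regularization_penalty(
    #       values, trainable.regularized_idxs, trainable.regularization_weights, values_0
    #   )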