From b9dfda0c28e011fb11c7c3dc3486304a70bb79a4 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Fri, 28 Nov 2025 17:14:07 +0000 Subject: [PATCH 1/4] Store regularization info in AttributeConfig and Trainable --- descent/tests/test_train.py | 163 ++++++++++++++++++++++++++++++++++++ descent/train.py | 102 +++++++++++++++++++--- 2 files changed, 252 insertions(+), 13 deletions(-) diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index 3884a55..69fc3d4 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -88,6 +88,23 @@ def test_validate_keys_limits(self): ): AttributeConfig(cols=["scale_14"], limits={"scale_15": (0.1, 0.2)}) + def test_validate_keys_regularize(self): + with pytest.raises( + pydantic.ValidationError, match="cannot regularize non-trainable parameters" + ): + AttributeConfig(cols=["scale_14"], regularize={"scale_15": 0.01}) + + def test_regularize_field(self): + config = AttributeConfig( + cols=["scale_14", "scale_15"], + regularize={"scale_14": 0.01, "scale_15": 0.001}, + ) + assert config.regularize == {"scale_14": 0.01, "scale_15": 0.001} + + def test_regularize_empty(self): + config = AttributeConfig(cols=["scale_14"]) + assert config.regularize == {} + class TestParameterConfig: def test_validate_include_exclude(self): @@ -268,3 +285,149 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs): assert values.shape == expected_values.shape assert torch.allclose(values, expected_values) + + def test_regularized_idxs_no_regularization( + self, mock_ff, mock_parameter_configs, mock_attribute_configs + ): + trainable = Trainable( + mock_ff, + parameters=mock_parameter_configs, + attributes=mock_attribute_configs, + ) + + assert len(trainable.regularized_idxs) == 0 + assert len(trainable.regularization_weights) == 0 + + def test_regularized_idxs_with_parameter_regularization(self, mock_ff): + parameter_configs = { + "vdW": ParameterConfig( + cols=["epsilon", "sigma"], + regularize={"epsilon": 0.01, "sigma": 0.001}, + ), + } + attribute_configs = {} + + trainable = Trainable( + mock_ff, + parameters=parameter_configs, + attributes=attribute_configs, + ) + + # vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma + # So we should have 4 regularized values total: 2 epsilons + 2 sigmas + assert len(trainable.regularized_idxs) == 4 + assert len(trainable.regularization_weights) == 4 + + # Check the weights match what we configured + # Interleaved: row 0 (eps, sig), row 1 (eps, sig) + expected_weights = torch.tensor( + [0.01, 0.001, 0.01, 0.001], dtype=trainable.regularization_weights.dtype + ) + assert torch.allclose(trainable.regularization_weights, expected_weights) + + def test_regularized_idxs_with_attribute_regularization(self, mock_ff): + parameter_configs = {} + attribute_configs = { + "vdW": AttributeConfig( + cols=["scale_14", "scale_15"], + regularize={"scale_14": 0.05}, + ) + } + + trainable = Trainable( + mock_ff, + parameters=parameter_configs, + attributes=attribute_configs, + ) + + # Only scale_14 should be regularized (1 attribute) + assert len(trainable.regularized_idxs) == 1 + assert len(trainable.regularization_weights) == 1 + + expected_weights = torch.tensor( + [0.05], dtype=trainable.regularization_weights.dtype + ) + assert torch.allclose(trainable.regularization_weights, expected_weights) + + def test_regularized_idxs_with_mixed_regularization(self, mock_ff): + parameter_configs = { + "vdW": ParameterConfig( + cols=["epsilon", "sigma"], + regularize={"epsilon": 0.02}, 
+ include=[mock_ff.potentials_by_type["vdW"].parameter_keys[0]], + ), + } + attribute_configs = { + "vdW": AttributeConfig( + cols=["scale_14"], + regularize={"scale_14": 0.1}, + ) + } + + trainable = Trainable( + mock_ff, + parameters=parameter_configs, + attributes=attribute_configs, + ) + + # Only first vdW parameter row is included, with only epsilon regularized + # Plus scale_14 attribute + assert len(trainable.regularized_idxs) == 2 + assert len(trainable.regularization_weights) == 2 + + # First should be epsilon (0.02), second should be scale_14 (0.1) + expected_weights = torch.tensor( + [0.02, 0.1], dtype=trainable.regularization_weights.dtype + ) + assert torch.allclose(trainable.regularization_weights, expected_weights) + + def test_regularized_idxs_excluded_parameters(self, mock_ff): + parameter_configs = { + "Bonds": ParameterConfig( + cols=["k", "length"], + regularize={"k": 0.01, "length": 0.02}, + exclude=[mock_ff.potentials_by_type["Bonds"].parameter_keys[0]], + ), + } + attribute_configs = {} + + trainable = Trainable( + mock_ff, + parameters=parameter_configs, + attributes=attribute_configs, + ) + + # Only second bond parameter row should be included (first is excluded) + # Both k and length are regularized + assert len(trainable.regularized_idxs) == 2 + assert len(trainable.regularization_weights) == 2 + + expected_weights = torch.tensor( + [0.01, 0.02], dtype=trainable.regularization_weights.dtype + ) + assert torch.allclose(trainable.regularization_weights, expected_weights) + + def test_regularization_indices_match_unfrozen_values(self, mock_ff): + parameter_configs = { + "vdW": ParameterConfig( + cols=["epsilon"], + regularize={"epsilon": 0.01}, + ), + } + attribute_configs = {} + + trainable = Trainable( + mock_ff, + parameters=parameter_configs, + attributes=attribute_configs, + ) + + values = trainable.to_values() + + # Regularization indices should be valid indices into the unfrozen values + assert trainable.regularized_idxs.max() < len(values) + assert trainable.regularized_idxs.min() >= 0 + + # We should be able to index the values tensor with regularization indices + regularized_values = values[trainable.regularized_idxs] + assert len(regularized_values) == len(trainable.regularized_idxs) diff --git a/descent/train.py b/descent/train.py index 0aaea7e..5ad4292 100644 --- a/descent/train.py +++ b/descent/train.py @@ -63,9 +63,11 @@ def _convert_keys(value: typing.Any) -> typing.Any: return value value = [ - _PotentialKey(**v.dict()) - if isinstance(v, openff.interchange.models.PotentialKey) - else v + ( + _PotentialKey(**v.dict()) + if isinstance(v, openff.interchange.models.PotentialKey) + else v + ) for v in value ] return value @@ -94,6 +96,12 @@ class AttributeConfig(pydantic.BaseModel): "none indicates no constraint.", ) + regularize: dict[str, float] = pydantic.Field( + {}, + description="The regularization strength to apply to each parameter, e.g. " + "'k': 0.01, 'epsilon': 0.001. 
Parameters not listed are not regularized.", + ) + if pydantic.__version__.startswith("1."): @pydantic.root_validator @@ -102,11 +110,14 @@ def _validate_keys(cls, values): scales = values.get("scales") limits = values.get("limits") + regularize = values.get("regularize") if any(key not in cols for key in scales): raise ValueError("cannot scale non-trainable parameters") if any(key not in cols for key in limits): raise ValueError("cannot clamp non-trainable parameters") + if any(key not in cols for key in regularize): + raise ValueError("cannot regularize non-trainable parameters") return values @@ -122,6 +133,9 @@ def _validate_keys(self): if any(key not in self.cols for key in self.limits): raise ValueError("cannot clamp non-trainable parameters") + if any(key not in self.cols for key in self.regularize): + raise ValueError("cannot regularize non-trainable parameters") + return self @@ -208,13 +222,16 @@ def _prepare( clamp_lower = [] clamp_upper = [] + regularized_idxs = [] + regularization_weights = [] + for potential_type, potential in zip(potential_types, potentials, strict=True): potential_config = config[potential_type] potential_cols = getattr(potential, f"{attr[:-1]}_cols") - assert ( - len({*potential_config.cols} - {*potential_cols}) == 0 - ), f"unknown columns: {potential_cols}" + assert len({*potential_config.cols} - {*potential_cols}) == 0, ( + f"unknown columns: {potential_cols}" + ) potential_values = getattr(potential, attr).detach().clone() potential_values_flat = potential_values.flatten() @@ -242,13 +259,24 @@ def _prepare( key_to_row[key] for key in unfrozen_keys if key not in excluded_keys } - unfrozen_idxs.extend( - unfrozen_col_offset + col_idx + row_idx * potential_values.shape[-1] - for row_idx in range(n_rows) - if row_idx in unfrozen_rows - for col_idx, col in enumerate(potential_cols) - if col in potential_config.cols - ) + # Track unfrozen and regularized indices + for row_idx in range(n_rows): + if row_idx not in unfrozen_rows: + continue + for col_idx, col in enumerate(potential_cols): + if col not in potential_config.cols: + continue + + flat_idx = ( + unfrozen_col_offset + + col_idx + + row_idx * potential_values.shape[-1] + ) + unfrozen_idxs.append(flat_idx) + + if col in potential_config.regularize: + regularized_idxs.append(flat_idx) + regularization_weights.append(potential_config.regularize[col]) unfrozen_col_offset += len(potential_values_flat) @@ -290,6 +318,12 @@ def _prepare( smee.utils.tensor_like(scales, values), smee.utils.tensor_like(clamp_lower, values), smee.utils.tensor_like(clamp_upper, values), + torch.tensor(regularized_idxs), + ( + smee.utils.tensor_like(regularization_weights, values) + if regularization_weights + else smee.utils.tensor_like([], values) + ), ) def __init__( @@ -315,6 +349,8 @@ def __init__( param_scales, param_clamp_lower, param_clamp_upper, + param_regularized_idxs, + param_regularization_weights, ) = self._prepare(force_field, parameters, "parameters") ( self._attr_types, @@ -324,6 +360,8 @@ def __init__( attr_scales, attr_clamp_lower, attr_clamp_upper, + attr_regularized_idxs, + attr_regularization_weights, ) = self._prepare(force_field, attributes, "attributes") self._values = torch.cat([param_values, attr_values]) @@ -341,6 +379,34 @@ def __init__( self._unfrozen_idxs ] + # Store regularization information + all_regularized_idxs = torch.cat( + [param_regularized_idxs, attr_regularized_idxs + len(param_scales)] + ).long() + all_regularization_weights = torch.cat( + [param_regularization_weights, 
attr_regularization_weights] + ) + + # Map global indices to unfrozen indices + idx_mapping = {idx.item(): i for i, idx in enumerate(self._unfrozen_idxs)} + self._regularized_idxs = torch.tensor( + [ + idx_mapping[idx.item()] + for idx in all_regularized_idxs + if idx.item() in idx_mapping + ] + ).long() + regularization_weights = [ + all_regularization_weights[i] + for i, idx in enumerate(all_regularized_idxs) + if idx.item() in idx_mapping + ] + self._regularization_weights = ( + torch.stack(regularization_weights) + if regularization_weights + else torch.tensor([]) + ) + @torch.no_grad() def to_values(self) -> torch.Tensor: """Returns unfrozen parameter and attribute values as a flat tensor.""" @@ -381,3 +447,13 @@ def clamp(self, values_flat: torch.Tensor) -> torch.Tensor: return (values_flat / self._scales).clamp( min=self._clamp_lower, max=self._clamp_upper ) * self._scales + + @property + def regularized_idxs(self) -> torch.Tensor: + """The indices of parameters/attributes to regularize.""" + return self._regularized_idxs + + @property + def regularization_weights(self) -> torch.Tensor: + """The regularization weights for parameters/attributes to regularize.""" + return self._regularization_weights From feb75c851817b15b82f50ccfd85cb30811adaf71 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Fri, 28 Nov 2025 17:16:08 +0000 Subject: [PATCH 2/4] Formatting --- descent/targets/dimers.py | 2 +- descent/targets/energy.py | 4 ++-- descent/targets/thermo.py | 2 +- descent/tests/targets/test_energy.py | 12 ++++++++---- descent/utils/loss.py | 2 +- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/descent/targets/dimers.py b/descent/targets/dimers.py index a45ec82..f084dfd 100644 --- a/descent/targets/dimers.py +++ b/descent/targets/dimers.py @@ -409,7 +409,7 @@ def report( col: rmse_format for col in data_stats.columns if col.startswith("RMSE") } formatters_full = { - **{col: "html" for col in ["Dimer", "Energy [kcal/mol]"]}, + **dict.fromkeys(["Dimer", "Energy [kcal/mol]"], "html"), **{col: rmse_format for col in data_full.columns if col.startswith("RMSE")}, } diff --git a/descent/targets/energy.py b/descent/targets/energy.py index ff610d7..b9f89d1 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -116,8 +116,8 @@ def predict( ) coords = ( - coords_flat.reshape(len(energy_ref), -1, 3) - ).detach().requires_grad_(True) + (coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True) + ) topology = topologies[smiles] energy_pred = smee.compute_energy(topology, force_field, coords) diff --git a/descent/targets/thermo.py b/descent/targets/thermo.py index 3cae352..0457146 100644 --- a/descent/targets/thermo.py +++ b/descent/targets/thermo.py @@ -703,7 +703,7 @@ def predict( verbose_rows.append( { - "type": f'{entry["type"]} [{entry["units"]}]', + "type": f"{entry['type']} [{entry['units']}]", "smiles_a": descent.utils.molecule.unmap_smiles(entry["smiles_a"]), "smiles_b": ( "" diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py index cbd75c0..95dd2a8 100644 --- a/descent/tests/targets/test_energy.py +++ b/descent/tests/targets/test_energy.py @@ -78,7 +78,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0], - ], dtype=torch.float64 + ], + dtype=torch.float64, ) / math.sqrt(6.0 * 3.0), torch.tensor([7.899425506591797, -7.89942741394043]) / math.sqrt(2.0), @@ -90,7 +91,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [0.0, 
-137.45770263671875, 0.0], [102.62999725341797, 68.72884368896484, 0.0], [-102.62999725341797, 68.72884368896484, 0.0], - ], dtype=torch.float64 + ], + dtype=torch.float64, ) / math.sqrt(6.0 * 3.0), ), @@ -106,7 +108,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0], - ], dtype=torch.float64 + ], + dtype=torch.float64, ), torch.tensor([0.0, -15.798852920532227]), -torch.tensor( @@ -117,7 +120,8 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [0.0, -137.45770263671875, 0.0], [102.62999725341797, 68.72884368896484, 0.0], [-102.62999725341797, 68.72884368896484, 0.0], - ], dtype=torch.float64 + ], + dtype=torch.float64, ), ), ], diff --git a/descent/utils/loss.py b/descent/utils/loss.py index 5b019ab..bccc369 100644 --- a/descent/utils/loss.py +++ b/descent/utils/loss.py @@ -73,7 +73,7 @@ def combine_closures( A combined closure function. """ - weights = weights if weights is not None else {name: 1.0 for name in closures} + weights = weights if weights is not None else dict.fromkeys(closures, 1.0) if len(closures) == 0: raise NotImplementedError("At least one closure function is required.") From 17bad727f73cdf0c1864dd67d01e10c602641878 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Thu, 11 Dec 2025 10:59:31 +0000 Subject: [PATCH 3/4] Clarify docstring --- descent/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/descent/train.py b/descent/train.py index 5ad4292..9fbee76 100644 --- a/descent/train.py +++ b/descent/train.py @@ -229,9 +229,9 @@ def _prepare( potential_config = config[potential_type] potential_cols = getattr(potential, f"{attr[:-1]}_cols") - assert len({*potential_config.cols} - {*potential_cols}) == 0, ( - f"unknown columns: {potential_cols}" - ) + assert ( + len({*potential_config.cols} - {*potential_cols}) == 0 + ), f"unknown columns: {potential_cols}" potential_values = getattr(potential, attr).detach().clone() potential_values_flat = potential_values.flatten() @@ -450,7 +450,8 @@ def clamp(self, values_flat: torch.Tensor) -> torch.Tensor: @property def regularized_idxs(self) -> torch.Tensor: - """The indices of parameters/attributes to regularize.""" + """The indices (within the tensor returned by to_values) + of parameters/attributes to regularize.""" return self._regularized_idxs @property From 2ffa43566eac6992199b5407216011001f981cd0 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Thu, 11 Dec 2025 11:05:02 +0000 Subject: [PATCH 4/4] Make regularization tests more precise --- descent/tests/test_train.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index 69fc3d4..1b445ee 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -315,8 +315,8 @@ def test_regularized_idxs_with_parameter_regularization(self, mock_ff): # vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma # So we should have 4 regularized values total: 2 epsilons + 2 sigmas - assert len(trainable.regularized_idxs) == 4 - assert len(trainable.regularization_weights) == 4 + expected_idxs = torch.tensor([0, 1, 2, 3], dtype=torch.long) + assert torch.equal(trainable.regularized_idxs, expected_idxs) # Check the weights match what we configured # Interleaved: row 0 (eps, sig), row 1 (eps, sig) @@ -341,8 +341,8 @@ def test_regularized_idxs_with_attribute_regularization(self, mock_ff): ) # Only scale_14 should be regularized (1 attribute) - assert 
len(trainable.regularized_idxs) == 1 - assert len(trainable.regularization_weights) == 1 + expected_idxs = torch.tensor([0], dtype=torch.long) + assert torch.equal(trainable.regularized_idxs, expected_idxs) expected_weights = torch.tensor( [0.05], dtype=trainable.regularization_weights.dtype @@ -372,8 +372,8 @@ def test_regularized_idxs_with_mixed_regularization(self, mock_ff): # Only first vdW parameter row is included, with only epsilon regularized # Plus scale_14 attribute - assert len(trainable.regularized_idxs) == 2 - assert len(trainable.regularization_weights) == 2 + expected_idxs = torch.tensor([0, 2], dtype=torch.long) + assert torch.equal(trainable.regularized_idxs, expected_idxs) # First should be epsilon (0.02), second should be scale_14 (0.1) expected_weights = torch.tensor( @@ -399,8 +399,8 @@ def test_regularized_idxs_excluded_parameters(self, mock_ff): # Only second bond parameter row should be included (first is excluded) # Both k and length are regularized - assert len(trainable.regularized_idxs) == 2 - assert len(trainable.regularization_weights) == 2 + expected_idxs = torch.tensor([0, 1], dtype=torch.long) + assert torch.equal(trainable.regularized_idxs, expected_idxs) expected_weights = torch.tensor( [0.01, 0.02], dtype=trainable.regularization_weights.dtype
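
As a usage sketch (not prescribed by the patch series itself): the `regularized_idxs` and `regularization_weights` properties added above are plain tensors indexed against `Trainable.to_values()`, so a training loop can fold them into its loss however it likes. The helper below assumes only the API introduced in these patches and uses an L2-style penalty purely for illustration; `make_regularized_closure` and `base_loss_fn` are hypothetical names, not part of descent.

    import torch

    def make_regularized_closure(trainable, base_loss_fn):
        # Reference values to pull regularized entries towards, taken from the
        # force field values at construction time.
        initial_values = trainable.to_values().detach().clone()

        def closure(values_flat: torch.Tensor) -> torch.Tensor:
            loss = base_loss_fn(values_flat)

            if len(trainable.regularized_idxs) > 0:
                # Squared deviation of each regularized value from its starting
                # value, weighted per column as configured via ``regularize={...}``.
                delta = (
                    values_flat[trainable.regularized_idxs]
                    - initial_values[trainable.regularized_idxs]
                )
                loss = loss + (trainable.regularization_weights * delta**2).sum()

            return loss

        return closure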