Skip to content

Commit d0cbc79

Browse files
committed
fix: minimal test fix
1 parent 6a0f226 commit d0cbc79

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,19 @@
1818
DTYPE = torch.float64
1919

2020

21+
# TODO: recreate these to minimize changes in #274 to be deleted in #264
22+
@pytest.fixture
23+
def device() -> torch.device:
24+
"""Fixture for torch.device."""
25+
return DEVICE
26+
27+
28+
@pytest.fixture
29+
def dtype() -> torch.dtype:
30+
"""Fixture for torch.dtype."""
31+
return DTYPE
32+
33+
2134
@pytest.fixture
2235
def lj_model() -> LennardJonesModel:
2336
"""Create a Lennard-Jones model with reasonable parameters for Ar."""

tests/test_optimizers_vs_ase.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import traceback
12
from typing import TYPE_CHECKING, Any
23

34
import pytest
@@ -8,14 +9,52 @@
89

910
import torch_sim as ts
1011
from torch_sim.io import atoms_to_state, state_to_atoms, state_to_structures
11-
from torch_sim.models.mace import MaceModel
12+
from torch_sim.models.mace import MaceModel, MaceUrls
1213
from torch_sim.optimizers import frechet_cell_fire, unit_cell_fire
1314

1415

1516
if TYPE_CHECKING:
1617
from mace.calculators import MACECalculator
1718

1819

20+
@pytest.fixture
21+
def ts_mace_mpa() -> MaceModel:
22+
"""Provides a MACE MP model instance for the optimizer tests."""
23+
try:
24+
from mace.calculators.foundations_models import mace_mp
25+
except ImportError:
26+
pytest.skip(
27+
f"MACE not installed: {traceback.format_exc()}", allow_module_level=True
28+
)
29+
30+
# Use float64 for potentially higher precision needed in optimization
31+
dtype = getattr(torch, dtype_str := "float64")
32+
raw_mace = mace_mp(
33+
model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str
34+
)
35+
return MaceModel(
36+
model=raw_mace,
37+
device=torch.device("cpu"),
38+
dtype=dtype,
39+
compute_forces=True,
40+
compute_stress=True,
41+
)
42+
43+
44+
@pytest.fixture
45+
def ase_mace_mpa() -> "MACECalculator":
46+
"""Provides an ASE MACECalculator instance using mace_mp."""
47+
try:
48+
from mace.calculators.foundations_models import mace_mp
49+
except ImportError:
50+
pytest.skip(
51+
f"MACE not installed: {traceback.format_exc()}", allow_module_level=True
52+
)
53+
54+
# Ensure dtype matches the one used in the torch-sim fixture (float64)
55+
return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64")
56+
57+
1958
def _compare_ase_and_ts_states(
2059
ts_current_system_state: ts.state.SimState,
2160
filtered_ase_atoms_for_run: Any,
@@ -77,7 +116,7 @@ def _compare_ase_and_ts_states(
77116

78117
def _run_and_compare_optimizers(
79118
initial_sim_state_fixture: ts.state.SimState,
80-
torchsim_mace_mpa: MaceModel,
119+
ts_mace_mpa: MaceModel,
81120
ase_mace_mpa: "MACECalculator",
82121
torch_sim_optimizer_type: str,
83122
ase_filter_class: Any,
@@ -89,7 +128,7 @@ def _run_and_compare_optimizers(
89128
"""Run and compare optimizations between torch-sim and ASE."""
90129
pytest.importorskip("mace")
91130
dtype = torch.float64
92-
device = torchsim_mace_mpa.device
131+
device = ts_mace_mpa.device
93132

94133
ts_current_system_state = initial_sim_state_fixture.clone()
95134

@@ -117,7 +156,7 @@ def _run_and_compare_optimizers(
117156
force_tol=force_tol, include_cell_forces=True
118157
)
119158

120-
results = torchsim_mace_mpa(ts_current_system_state)
159+
results = ts_mace_mpa(ts_current_system_state)
121160
ts_initial_system_state = ts_current_system_state.clone()
122161
ts_initial_system_state.forces = results["forces"]
123162
ts_initial_system_state.energy = results["energy"]
@@ -136,7 +175,7 @@ def _run_and_compare_optimizers(
136175
if steps_for_current_segment > 0:
137176
updated_ts_state = ts.optimize(
138177
system=ts_current_system_state,
139-
model=torchsim_mace_mpa,
178+
model=ts_mace_mpa,
140179
optimizer=optimizer_callable_for_ts_optimize,
141180
max_steps=steps_for_current_segment,
142181
convergence_fn=convergence_fn,
@@ -269,7 +308,7 @@ def test_optimizer_vs_ase_parametrized(
269308
force_tol: float,
270309
tolerances: dict[str, float],
271310
test_id_prefix: str,
272-
torchsim_mace_mpa: MaceModel,
311+
ts_mace_mpa: MaceModel,
273312
ase_mace_mpa: "MACECalculator",
274313
request: pytest.FixtureRequest,
275314
) -> None:
@@ -279,7 +318,7 @@ def test_optimizer_vs_ase_parametrized(
279318

280319
_run_and_compare_optimizers(
281320
initial_sim_state_fixture=initial_sim_state_fixture,
282-
torchsim_mace_mpa=torchsim_mace_mpa,
321+
ts_mace_mpa=ts_mace_mpa,
283322
ase_mace_mpa=ase_mace_mpa,
284323
torch_sim_optimizer_type=torch_sim_optimizer_type,
285324
ase_filter_class=ase_filter_class,

0 commit comments

Comments
 (0)