1+ import traceback
12from typing import TYPE_CHECKING , Any
23
34import pytest
89
910import torch_sim as ts
1011from torch_sim .io import atoms_to_state , state_to_atoms , state_to_structures
11- from torch_sim .models .mace import MaceModel
12+ from torch_sim .models .mace import MaceModel , MaceUrls
1213from torch_sim .optimizers import frechet_cell_fire , unit_cell_fire
1314
1415
1516if TYPE_CHECKING :
1617 from mace .calculators import MACECalculator
1718
1819
20+ @pytest .fixture
21+ def ts_mace_mpa () -> MaceModel :
22+ """Provides a MACE MP model instance for the optimizer tests."""
23+ try :
24+ from mace .calculators .foundations_models import mace_mp
25+ except ImportError :
26+ pytest .skip (
27+ f"MACE not installed: { traceback .format_exc ()} " , allow_module_level = True
28+ )
29+
30+ # Use float64 for potentially higher precision needed in optimization
31+ dtype = getattr (torch , dtype_str := "float64" )
32+ raw_mace = mace_mp (
33+ model = MaceUrls .mace_mp_small , return_raw_model = True , default_dtype = dtype_str
34+ )
35+ return MaceModel (
36+ model = raw_mace ,
37+ device = torch .device ("cpu" ),
38+ dtype = dtype ,
39+ compute_forces = True ,
40+ compute_stress = True ,
41+ )
42+
43+
44+ @pytest .fixture
45+ def ase_mace_mpa () -> "MACECalculator" :
46+ """Provides an ASE MACECalculator instance using mace_mp."""
47+ try :
48+ from mace .calculators .foundations_models import mace_mp
49+ except ImportError :
50+ pytest .skip (
51+ f"MACE not installed: { traceback .format_exc ()} " , allow_module_level = True
52+ )
53+
54+ # Ensure dtype matches the one used in the torch-sim fixture (float64)
55+ return mace_mp (model = MaceUrls .mace_mp_small , default_dtype = "float64" )
56+
57+
1958def _compare_ase_and_ts_states (
2059 ts_current_system_state : ts .state .SimState ,
2160 filtered_ase_atoms_for_run : Any ,
@@ -77,7 +116,7 @@ def _compare_ase_and_ts_states(
77116
78117def _run_and_compare_optimizers (
79118 initial_sim_state_fixture : ts .state .SimState ,
80- torchsim_mace_mpa : MaceModel ,
119+ ts_mace_mpa : MaceModel ,
81120 ase_mace_mpa : "MACECalculator" ,
82121 torch_sim_optimizer_type : str ,
83122 ase_filter_class : Any ,
@@ -89,7 +128,7 @@ def _run_and_compare_optimizers(
89128 """Run and compare optimizations between torch-sim and ASE."""
90129 pytest .importorskip ("mace" )
91130 dtype = torch .float64
92- device = torchsim_mace_mpa .device
131+ device = ts_mace_mpa .device
93132
94133 ts_current_system_state = initial_sim_state_fixture .clone ()
95134
@@ -117,7 +156,7 @@ def _run_and_compare_optimizers(
117156 force_tol = force_tol , include_cell_forces = True
118157 )
119158
120- results = torchsim_mace_mpa (ts_current_system_state )
159+ results = ts_mace_mpa (ts_current_system_state )
121160 ts_initial_system_state = ts_current_system_state .clone ()
122161 ts_initial_system_state .forces = results ["forces" ]
123162 ts_initial_system_state .energy = results ["energy" ]
@@ -136,7 +175,7 @@ def _run_and_compare_optimizers(
136175 if steps_for_current_segment > 0 :
137176 updated_ts_state = ts .optimize (
138177 system = ts_current_system_state ,
139- model = torchsim_mace_mpa ,
178+ model = ts_mace_mpa ,
140179 optimizer = optimizer_callable_for_ts_optimize ,
141180 max_steps = steps_for_current_segment ,
142181 convergence_fn = convergence_fn ,
@@ -269,7 +308,7 @@ def test_optimizer_vs_ase_parametrized(
269308 force_tol : float ,
270309 tolerances : dict [str , float ],
271310 test_id_prefix : str ,
272- torchsim_mace_mpa : MaceModel ,
311+ ts_mace_mpa : MaceModel ,
273312 ase_mace_mpa : "MACECalculator" ,
274313 request : pytest .FixtureRequest ,
275314) -> None :
@@ -279,7 +318,7 @@ def test_optimizer_vs_ase_parametrized(
279318
280319 _run_and_compare_optimizers (
281320 initial_sim_state_fixture = initial_sim_state_fixture ,
282- torchsim_mace_mpa = torchsim_mace_mpa ,
321+ ts_mace_mpa = ts_mace_mpa ,
283322 ase_mace_mpa = ase_mace_mpa ,
284323 torch_sim_optimizer_type = torch_sim_optimizer_type ,
285324 ase_filter_class = ase_filter_class ,
0 commit comments