|
1 | 1 | import traceback |
2 | | -import urllib.request |
3 | | -from enum import StrEnum |
4 | 2 | from pathlib import Path |
5 | 3 |
|
6 | 4 | import pytest |
7 | 5 |
|
8 | | -from tests.conftest import DEVICE |
9 | | -from tests.models.conftest import make_model_calculator_consistency_test |
| 6 | +from tests.conftest import DEVICE, DTYPE |
| 7 | +from tests.models.conftest import ( |
| 8 | + consistency_test_simstate_fixtures, |
| 9 | + make_model_calculator_consistency_test, |
| 10 | + make_validate_model_outputs_test, |
| 11 | +) |
10 | 12 |
|
11 | 13 |
|
12 | 14 | try: |
13 | 15 | from nequip.ase import NequIPCalculator |
| 16 | + from nequip.scripts.compile import main |
14 | 17 |
|
15 | 18 | from torch_sim.models.nequip_framework import ( |
16 | 19 | NequIPFrameworkModel, |
|
22 | 25 | ) |
23 | 26 |
|
24 | 27 |
|
25 | | -class NequIPUrls(StrEnum): |
26 | | - """Checkpoint download URLs for NequIP models.""" |
27 | | - |
28 | | - Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth" |
29 | | - |
30 | | - |
31 | 28 | @pytest.fixture(scope="session") |
32 | | -def model_path_nequip(tmp_path_factory: pytest.TempPathFactory) -> Path: |
33 | | - tmp_path = tmp_path_factory.mktemp("nequip_checkpoints") |
34 | | - model_name = "Si.nequip.pth" |
35 | | - model_path = Path(tmp_path) / model_name |
36 | | - |
37 | | - if not model_path.is_file(): |
38 | | - urllib.request.urlretrieve(NequIPUrls.Si, model_path) # noqa: S310 |
| 29 | +def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path: |
| 30 | + """Compile NequIP OAM-L model from nequip.net.""" |
| 31 | + tmp_path = tmp_path_factory.mktemp("nequip_compiled") |
| 32 | + output_model_name = "mir-group__NequIP-OAM-L__0.1.nequip.pt2" |
| 33 | + output_path = Path(tmp_path) / output_model_name |
| 34 | + |
| 35 | + main( |
| 36 | + args=[ |
| 37 | + "nequip.net:mir-group/NequIP-OAM-L:0.1", |
| 38 | + str(output_path), |
| 39 | + "--mode", |
| 40 | + "aotinductor", |
| 41 | + "--device", |
| 42 | + "cuda", |
| 43 | + "--target", |
| 44 | + "ase", |
| 45 | + ] |
| 46 | + ) |
39 | 47 |
|
40 | | - return model_path |
| 48 | + return output_path |
41 | 49 |
|
42 | 50 |
|
43 | 51 | @pytest.fixture |
44 | | -def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel: |
| 52 | +def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel: |
45 | 53 | """Create an NequIPModel wrapper for the pretrained model.""" |
46 | 54 | compiled_model, (r_max, type_names) = from_compiled_model( |
47 | | - model_path_nequip, device=DEVICE |
| 55 | + compiled_nequip_model_path, device=DEVICE |
48 | 56 | ) |
49 | 57 | return NequIPFrameworkModel( |
50 | 58 | model=compiled_model, |
@@ -74,11 +82,18 @@ def test_nequip_initialization(model_path_nequip: Path) -> None: |
74 | 82 | assert model._device == DEVICE # noqa: SLF001 |
75 | 83 |
|
76 | 84 |
|
77 | | -test_nequip_consistency = make_model_calculator_consistency_test( |
| 85 | +test_metatomic_consistency = make_model_calculator_consistency_test( |
78 | 86 | test_name="nequip", |
79 | 87 | model_fixture_name="nequip_model", |
80 | 88 | calculator_fixture_name="nequip_calculator", |
81 | | - sim_state_names=("si_sim_state", "rattled_si_sim_state"), |
| 89 | + sim_state_names=consistency_test_simstate_fixtures, |
| 90 | + energy_atol=5e-5, |
| 91 | + dtype=DTYPE, |
| 92 | + device=DEVICE, |
82 | 93 | ) |
83 | 94 |
|
84 | | -# TODO (AG): Test multi element models |
| 95 | +test_metatomic_model_outputs = make_validate_model_outputs_test( |
| 96 | + model_fixture_name="nequip_model", |
| 97 | + dtype=DTYPE, |
| 98 | + device=DEVICE, |
| 99 | +) |
0 commit comments