Skip to content

Commit 6a0f226

Browse files
committed
fea: make test match other models using foundation nequip
1 parent ff71965 commit 6a0f226

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

tests/models/test_nequip_framework.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import traceback
2-
import urllib.request
3-
from enum import StrEnum
42
from pathlib import Path
53

64
import pytest
75

8-
from tests.conftest import DEVICE
9-
from tests.models.conftest import make_model_calculator_consistency_test
6+
from tests.conftest import DEVICE, DTYPE
7+
from tests.models.conftest import (
8+
consistency_test_simstate_fixtures,
9+
make_model_calculator_consistency_test,
10+
make_validate_model_outputs_test,
11+
)
1012

1113

1214
try:
1315
from nequip.ase import NequIPCalculator
16+
from nequip.scripts.compile import main
1417

1518
from torch_sim.models.nequip_framework import (
1619
NequIPFrameworkModel,
@@ -22,29 +25,34 @@
2225
)
2326

2427

25-
class NequIPUrls(StrEnum):
26-
"""Checkpoint download URLs for NequIP models."""
27-
28-
Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth"
29-
30-
3128
@pytest.fixture(scope="session")
32-
def model_path_nequip(tmp_path_factory: pytest.TempPathFactory) -> Path:
33-
tmp_path = tmp_path_factory.mktemp("nequip_checkpoints")
34-
model_name = "Si.nequip.pth"
35-
model_path = Path(tmp_path) / model_name
36-
37-
if not model_path.is_file():
38-
urllib.request.urlretrieve(NequIPUrls.Si, model_path) # noqa: S310
29+
def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
30+
"""Compile NequIP OAM-L model from nequip.net."""
31+
tmp_path = tmp_path_factory.mktemp("nequip_compiled")
32+
output_model_name = "mir-group__NequIP-OAM-L__0.1.nequip.pt2"
33+
output_path = Path(tmp_path) / output_model_name
34+
35+
main(
36+
args=[
37+
"nequip.net:mir-group/NequIP-OAM-L:0.1",
38+
str(output_path),
39+
"--mode",
40+
"aotinductor",
41+
"--device",
42+
"cuda",
43+
"--target",
44+
"ase",
45+
]
46+
)
3947

40-
return model_path
48+
return output_path
4149

4250

4351
@pytest.fixture
44-
def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel:
52+
def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel:
4553
"""Create an NequIPModel wrapper for the pretrained model."""
4654
compiled_model, (r_max, type_names) = from_compiled_model(
47-
model_path_nequip, device=DEVICE
55+
compiled_nequip_model_path, device=DEVICE
4856
)
4957
return NequIPFrameworkModel(
5058
model=compiled_model,
@@ -74,11 +82,18 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
7482
assert model._device == DEVICE # noqa: SLF001
7583

7684

77-
test_nequip_consistency = make_model_calculator_consistency_test(
85+
test_metatomic_consistency = make_model_calculator_consistency_test(
7886
test_name="nequip",
7987
model_fixture_name="nequip_model",
8088
calculator_fixture_name="nequip_calculator",
81-
sim_state_names=("si_sim_state", "rattled_si_sim_state"),
89+
sim_state_names=consistency_test_simstate_fixtures,
90+
energy_atol=5e-5,
91+
dtype=DTYPE,
92+
device=DEVICE,
8293
)
8394

84-
# TODO (AG): Test multi element models
95+
test_metatomic_model_outputs = make_validate_model_outputs_test(
96+
model_fixture_name="nequip_model",
97+
dtype=DTYPE,
98+
device=DEVICE,
99+
)

0 commit comments

Comments
 (0)