Skip to content

Commit a317235

Browse files
Refactor internal structure to separate dataclasses from converter (#68)
* Refactor internal structure to separate dataclasses from converter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add more tests --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 21708b0 commit a317235

File tree

11 files changed

+136
-118
lines changed

11 files changed

+136
-118
lines changed

pyiron_dataclasses/v1/converter.py

Lines changed: 24 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
1-
from typing import Callable, Union
1+
from typing import Callable
22

33
from pint import UnitRegistry
44

5-
from pyiron_dataclasses.v1.atomistic import (
5+
from pyiron_dataclasses.v1.jobs.atomistic import (
66
Cell,
77
GenericInput,
88
GenericOutput,
99
Structure,
1010
Units,
1111
)
12-
from pyiron_dataclasses.v1.dft import (
12+
from pyiron_dataclasses.v1.jobs.dft import (
1313
ChargeDensity,
1414
DensityOfStates,
1515
ElectronicStructure,
1616
OutputGenericDFT,
1717
)
18-
from pyiron_dataclasses.v1.job import (
18+
from pyiron_dataclasses.v1.jobs.generic import (
1919
Executable,
2020
GenericDict,
2121
Interactive,
2222
Server,
2323
)
24-
from pyiron_dataclasses.v1.lammps import (
24+
from pyiron_dataclasses.v1.jobs.lammps import (
2525
LammpsInput,
2626
LammpsInputFiles,
2727
LammpsJob,
2828
LammpsOutput,
2929
LammpsPotential,
3030
)
31-
from pyiron_dataclasses.v1.murn import (
31+
from pyiron_dataclasses.v1.jobs.murn import (
3232
MurnaghanInput,
3333
MurnaghanJob,
3434
MurnaghanOutput,
3535
)
36-
from pyiron_dataclasses.v1.sphinx import (
36+
from pyiron_dataclasses.v1.jobs.sphinx import (
3737
BornOppenheimer,
3838
PawPot,
3939
ScfDiag,
@@ -58,14 +58,19 @@
5858
SphinxStructure,
5959
SphinxWaves,
6060
)
61-
from pyiron_dataclasses.v1.vasp import (
61+
from pyiron_dataclasses.v1.jobs.vasp import (
6262
OutCar,
6363
PotCar,
6464
VaspInput,
6565
VaspJob,
6666
VaspOutput,
6767
VaspResources,
6868
)
69+
from pyiron_dataclasses.v1.shared import (
70+
convert_datacontainer_to_dictionary,
71+
convert_generic_parameters_to_dictionary,
72+
convert_generic_parameters_to_string,
73+
)
6974

7075

7176
def get_dataclass(job_dict: dict) -> Callable:
@@ -80,13 +85,13 @@ def get_dataclass(job_dict: dict) -> Callable:
8085

8186
def _convert_sphinx_job_dict(job_dict: dict) -> SphinxJob:
8287
ureg = UnitRegistry()
83-
sphinx_input_parameter_dict = _convert_datacontainer_to_dictionary(
88+
sphinx_input_parameter_dict = convert_datacontainer_to_dictionary(
8489
data_container_dict=job_dict["input"]["parameters"]
8590
)
86-
generic_input_dict = _convert_generic_parameters_to_dictionary(
91+
generic_input_dict = convert_generic_parameters_to_dictionary(
8792
generic_parameter_dict=job_dict["input"]["generic"],
8893
)
89-
output_dict = _convert_datacontainer_to_dictionary(
94+
output_dict = convert_datacontainer_to_dictionary(
9095
data_container_dict=job_dict["output"]["generic"]
9196
)
9297
if "ricQN" in sphinx_input_parameter_dict["sphinx"]["main"]:
@@ -478,7 +483,7 @@ def _convert_sphinx_job_dict(job_dict: dict) -> SphinxJob:
478483

479484
def _convert_lammps_job_dict(job_dict: dict) -> LammpsJob:
480485
ureg = UnitRegistry()
481-
generic_input_dict = _convert_generic_parameters_to_dictionary(
486+
generic_input_dict = convert_generic_parameters_to_dictionary(
482487
generic_parameter_dict=job_dict["input"]["generic"],
483488
)
484489
return LammpsJob(
@@ -556,10 +561,10 @@ def _convert_lammps_job_dict(job_dict: dict) -> LammpsJob:
556561
species=job_dict["input"]["potential_inp"]["potential"]["Species"],
557562
),
558563
input_files=LammpsInputFiles(
559-
control_inp=_convert_generic_parameters_to_string(
564+
control_inp=convert_generic_parameters_to_string(
560565
generic_parameter_dict=job_dict["input"]["control_inp"]
561566
),
562-
potential_inp=_convert_generic_parameters_to_string(
567+
potential_inp=convert_generic_parameters_to_string(
563568
generic_parameter_dict=job_dict["input"]["potential_inp"]
564569
),
565570
),
@@ -622,7 +627,7 @@ def _convert_lammps_job_dict(job_dict: dict) -> LammpsJob:
622627

623628
def _convert_vasp_job_dict(job_dict):
624629
ureg = UnitRegistry()
625-
generic_input_dict = _convert_generic_parameters_to_dictionary(
630+
generic_input_dict = convert_generic_parameters_to_dictionary(
626631
generic_parameter_dict=job_dict["input"]["generic"],
627632
)
628633
return VaspJob(
@@ -704,14 +709,14 @@ def _convert_vasp_job_dict(job_dict):
704709
fix_spin_constraint=generic_input_dict.get("fix_spin_constraint", None),
705710
max_iter=generic_input_dict.get("max_iter", None),
706711
),
707-
incar=_convert_generic_parameters_to_string(
712+
incar=convert_generic_parameters_to_string(
708713
generic_parameter_dict=job_dict["input"]["incar"]
709714
),
710-
kpoints=_convert_generic_parameters_to_string(
715+
kpoints=convert_generic_parameters_to_string(
711716
generic_parameter_dict=job_dict["input"]["kpoints"]
712717
),
713718
potcar=PotCar(
714-
xc=_convert_generic_parameters_to_dictionary(
719+
xc=convert_generic_parameters_to_dictionary(
715720
generic_parameter_dict=job_dict["input"]["potcar"]
716721
)["xc"]
717722
),
@@ -927,7 +932,7 @@ def _convert_vasp_job_dict(job_dict):
927932

928933

929934
def _convert_murnaghan_job_dict(job_dict):
930-
input_dict = _convert_generic_parameters_to_dictionary(
935+
input_dict = convert_generic_parameters_to_dictionary(
931936
generic_parameter_dict=job_dict["input"]["parameters"]
932937
)
933938
return MurnaghanJob(
@@ -967,91 +972,3 @@ def _convert_murnaghan_job_dict(job_dict):
967972
structure=job_dict["output"]["structure"],
968973
),
969974
)
970-
971-
972-
def _convert_generic_parameters_to_string(generic_parameter_dict: dict) -> str:
973-
output_str = ""
974-
for p, v in zip(
975-
generic_parameter_dict["data_dict"]["Parameter"],
976-
generic_parameter_dict["data_dict"]["Value"],
977-
):
978-
output_str += p.replace("___", " ") + " " + str(v) + "\n"
979-
return output_str[:-1]
980-
981-
982-
def _convert_generic_parameters_to_dictionary(generic_parameter_dict: dict) -> dict:
983-
return {
984-
p: v
985-
for p, v in zip(
986-
generic_parameter_dict["data_dict"]["Parameter"],
987-
generic_parameter_dict["data_dict"]["Value"],
988-
)
989-
}
990-
991-
992-
def _filter_dict(input_dict: dict, remove_keys_lst: list) -> dict:
993-
def recursive_filter(input_value: dict, remove_keys_lst: list) -> dict:
994-
if isinstance(input_value, dict):
995-
return _filter_dict(input_dict=input_value, remove_keys_lst=remove_keys_lst)
996-
else:
997-
return input_value
998-
999-
return {
1000-
k: recursive_filter(input_value=v, remove_keys_lst=remove_keys_lst)
1001-
for k, v in input_dict.items()
1002-
if k not in remove_keys_lst
1003-
}
1004-
1005-
1006-
def _sort_dictionary_from_datacontainer(input_dict: dict) -> Union[dict, list]:
1007-
def recursive_sort(input_value: dict) -> Union[dict, list]:
1008-
if isinstance(input_value, dict):
1009-
return _sort_dictionary_from_datacontainer(input_dict=input_value)
1010-
else:
1011-
return input_value
1012-
1013-
ind_dict, content_dict = {}, {}
1014-
content_lst_flag = False
1015-
for k, v in input_dict.items():
1016-
if "__index_" in k:
1017-
key, ind = k.split("__index_")
1018-
if key == "":
1019-
content_lst_flag = True
1020-
ind_dict[int(ind)] = recursive_sort(input_value=v)
1021-
else:
1022-
ind_dict[int(ind)] = key
1023-
content_dict[key] = recursive_sort(input_value=v)
1024-
else:
1025-
content_dict[k] = recursive_sort(input_value=v)
1026-
if content_lst_flag:
1027-
return [ind_dict[ind] for ind in sorted(list(ind_dict.keys()))]
1028-
elif len(ind_dict) == len(content_dict):
1029-
return {
1030-
ind_dict[ind]: content_dict[ind_dict[ind]]
1031-
for ind in sorted(list(ind_dict.keys()))
1032-
}
1033-
elif len(ind_dict) == 0:
1034-
return content_dict
1035-
else:
1036-
raise KeyError(ind_dict, content_dict)
1037-
1038-
1039-
def _convert_datacontainer_to_dictionary(data_container_dict: dict) -> dict:
1040-
output_dict = _sort_dictionary_from_datacontainer(
1041-
input_dict=_filter_dict(
1042-
input_dict=data_container_dict,
1043-
remove_keys_lst=[
1044-
"NAME",
1045-
"TYPE",
1046-
"OBJECT",
1047-
"DICT_VERSION",
1048-
"HDF_VERSION",
1049-
"READ_ONLY",
1050-
"VERSION",
1051-
],
1052-
)
1053-
)
1054-
if isinstance(output_dict, dict):
1055-
return output_dict
1056-
else:
1057-
raise TypeError("datacontainer was not converted to a dictionary.")

pyiron_dataclasses/v1/jobs/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from pyiron_dataclasses.v1.dft import OutputGenericDFT
6+
from pyiron_dataclasses.v1.jobs.dft import OutputGenericDFT
77

88

99
@dataclass
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from dataclasses import dataclass
22
from typing import List, Optional
33

4-
from pyiron_dataclasses.v1.atomistic import (
4+
from pyiron_dataclasses.v1.jobs.atomistic import (
55
GenericInput,
66
GenericOutput,
77
Structure,
88
)
9-
from pyiron_dataclasses.v1.job import (
9+
from pyiron_dataclasses.v1.jobs.generic import (
1010
Executable,
1111
GenericDict,
1212
Interactive,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import numpy as np
55

6-
from pyiron_dataclasses.v1.atomistic import Structure
7-
from pyiron_dataclasses.v1.job import Executable, Server
6+
from pyiron_dataclasses.v1.jobs.atomistic import Structure
7+
from pyiron_dataclasses.v1.jobs.generic import Server
88

99

1010
@dataclass
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33

44
import numpy as np
55

6-
from pyiron_dataclasses.v1.atomistic import (
6+
from pyiron_dataclasses.v1.jobs.atomistic import (
77
GenericInput,
88
GenericOutput,
99
Structure,
1010
)
11-
from pyiron_dataclasses.v1.dft import (
11+
from pyiron_dataclasses.v1.jobs.dft import (
1212
ChargeDensity,
1313
ElectronicStructure,
1414
)
15-
from pyiron_dataclasses.v1.job import (
15+
from pyiron_dataclasses.v1.jobs.generic import (
1616
Executable,
1717
GenericDict,
1818
Interactive,
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
import numpy as np
44

5-
from pyiron_dataclasses.v1.atomistic import (
5+
from pyiron_dataclasses.v1.jobs.atomistic import (
66
GenericInput,
77
GenericOutput,
88
Structure,
99
)
10-
from pyiron_dataclasses.v1.dft import (
10+
from pyiron_dataclasses.v1.jobs.dft import (
1111
ChargeDensity,
1212
ElectronicStructure,
1313
)
14-
from pyiron_dataclasses.v1.job import (
14+
from pyiron_dataclasses.v1.jobs.generic import (
1515
Executable,
1616
GenericDict,
1717
Interactive,

pyiron_dataclasses/v1/shared.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from typing import Union
2+
3+
4+
def convert_generic_parameters_to_string(generic_parameter_dict: dict) -> str:
5+
output_str = ""
6+
for p, v in zip(
7+
generic_parameter_dict["data_dict"]["Parameter"],
8+
generic_parameter_dict["data_dict"]["Value"],
9+
):
10+
output_str += p.replace("___", " ") + " " + str(v) + "\n"
11+
return output_str[:-1]
12+
13+
14+
def convert_generic_parameters_to_dictionary(generic_parameter_dict: dict) -> dict:
15+
return {
16+
p: v
17+
for p, v in zip(
18+
generic_parameter_dict["data_dict"]["Parameter"],
19+
generic_parameter_dict["data_dict"]["Value"],
20+
)
21+
}
22+
23+
24+
def convert_datacontainer_to_dictionary(data_container_dict: dict) -> dict:
25+
output_dict = _sort_dictionary_from_datacontainer(
26+
input_dict=_filter_dict(
27+
input_dict=data_container_dict,
28+
remove_keys_lst=[
29+
"NAME",
30+
"TYPE",
31+
"OBJECT",
32+
"DICT_VERSION",
33+
"HDF_VERSION",
34+
"READ_ONLY",
35+
"VERSION",
36+
],
37+
)
38+
)
39+
if isinstance(output_dict, dict):
40+
return output_dict
41+
else:
42+
raise TypeError("datacontainer was not converted to a dictionary.")
43+
44+
45+
def _filter_dict(input_dict: dict, remove_keys_lst: list) -> dict:
46+
def recursive_filter(input_value: dict, remove_keys_lst: list) -> dict:
47+
if isinstance(input_value, dict):
48+
return _filter_dict(input_dict=input_value, remove_keys_lst=remove_keys_lst)
49+
else:
50+
return input_value
51+
52+
return {
53+
k: recursive_filter(input_value=v, remove_keys_lst=remove_keys_lst)
54+
for k, v in input_dict.items()
55+
if k not in remove_keys_lst
56+
}
57+
58+
59+
def _sort_dictionary_from_datacontainer(input_dict: dict) -> Union[dict, list]:
60+
def recursive_sort(input_value: dict) -> Union[dict, list]:
61+
if isinstance(input_value, dict):
62+
return _sort_dictionary_from_datacontainer(input_dict=input_value)
63+
else:
64+
return input_value
65+
66+
ind_dict, content_dict = {}, {}
67+
content_lst_flag = False
68+
for k, v in input_dict.items():
69+
if "__index_" in k:
70+
key, ind = k.split("__index_")
71+
if key == "":
72+
content_lst_flag = True
73+
ind_dict[int(ind)] = recursive_sort(input_value=v)
74+
else:
75+
ind_dict[int(ind)] = key
76+
content_dict[key] = recursive_sort(input_value=v)
77+
else:
78+
content_dict[k] = recursive_sort(input_value=v)
79+
if content_lst_flag:
80+
return [ind_dict[ind] for ind in sorted(list(ind_dict.keys()))]
81+
elif len(ind_dict) == len(content_dict):
82+
return {
83+
ind_dict[ind]: content_dict[ind_dict[ind]]
84+
for ind in sorted(list(ind_dict.keys()))
85+
}
86+
elif len(ind_dict) == 0:
87+
return content_dict
88+
else:
89+
raise KeyError(ind_dict, content_dict)

0 commit comments

Comments
 (0)