Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions arc/family/arc_families_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# encoding: utf-8

"""
This module contains unit tests for the kinetic families defined under arc.data.families.
"""

import unittest
import os

from arc.family.family import ReactionFamily, get_reaction_family_products, get_recipe_actions
from arc.imports import settings
from arc.reaction.reaction import ARCReaction
from arc.species.species import ARCSpecies

ARC_FAMILIES_PATH = settings['ARC_FAMILIES_PATH']


class TestCarbonylBasedHydrolysisReactionFamily(unittest.TestCase):
"""
Contains unit tests for the carbonyl-based hydrolysis reaction family.
"""

@classmethod
def setUpClass(cls):
"""Set up the test by defining the carbonyl-based hydrolysis reaction family."""
cls.family = ReactionFamily('carbonyl_based_hydrolysis')

def test_carbonyl_based_hydrolysis_reaction(self):
"""Test if carbonyl_based hydrolysis products are correctly generated."""
carbonyl = ARCSpecies(label='carbonyl', smiles='CC(=O)OC')
water = ARCSpecies(label='H2O', smiles='O')
acid = ARCSpecies(label='acid', smiles='CC(=O)O')
alcohol = ARCSpecies(label='alcohol', smiles='CO')
rxn = ARCReaction(r_species=[carbonyl, water], p_species=[acid, alcohol])
products = get_reaction_family_products(rxn)
product_smiles = [p.to_smiles() for p in products[0]['products']]
expected_product_smiles = ['CC(=O)O', 'CO']
self.assertEqual(product_smiles, expected_product_smiles)

def test_recipe_actions(self):
"""Test if the reaction recipe is applied correctly."""
groups_file_path = os.path.join(ARC_FAMILIES_PATH, 'carbonyl_based_hydrolysis.py')
with open(groups_file_path, 'r') as f:
groups_as_lines = f.readlines()
actions = get_recipe_actions(groups_as_lines)
expected_actions = [
['BREAK_BOND', '*1', 1, '*2'],
['BREAK_BOND', '*3', 1, '*4'],
['FORM_BOND', '*1', 1, '*4'],
['FORM_BOND', '*2', 1, '*3'],
]
self.assertEqual(actions, expected_actions)

def test_carbonyl_based_hydrolysis_withP(self):
"""Test if carbonyl-based hydrolysis products are correctly generated."""
carbonyl= ARCSpecies(label='carbonyl', smiles='CP(=O)(OC)O')
water = ARCSpecies(label='H2O', smiles='O')
acid = ARCSpecies(label='acid', smiles='CP(=O)(O)O')
alcohol = ARCSpecies(label='alcohol', smiles='CO')
rxn = ARCReaction(r_species=[carbonyl, water], p_species=[acid, alcohol])
products = get_reaction_family_products(rxn)
product_smiles = [p.to_smiles() for p in products[0]['products']]
expected_product_smiles = ['CP(=O)(O)O', 'CO']
self.assertEqual(product_smiles, expected_product_smiles)


class TestNitrileHydrolysisReactionFamily(unittest.TestCase):
"""
Contains unit tests for the nitrile hydrolysis reaction family.
"""

@classmethod
def setUpClass(cls):
"""Set up the test by defining the nitrile hydrolysis reaction family."""
cls.family = ReactionFamily('nitrile_hydrolysis')

def test_nitrile_hydrolysis_reaction(self):
"""Test if nitrile hydrolysis products are correctly generated."""
nitrile = ARCSpecies(label='nitrile', smiles='CC#N')
water = ARCSpecies(label='H2O', smiles='O')
acid = ARCSpecies(label='acid', smiles='CC(=N)O')
rxn = ARCReaction(r_species=[nitrile, water], p_species=[acid])
products = get_reaction_family_products(rxn)
product_smiles = [p.to_smiles() for p in products[0]['products']]
expected_product_smiles = ['CC(=N)O']
self.assertEqual(product_smiles, expected_product_smiles)

def test_recipe_actions(self):
"""Test if the reaction recipe is applied correctly for nitrile hydrolysis."""
groups_file_path = os.path.join(ARC_FAMILIES_PATH, 'nitrile_hydrolysis.py')
with open(groups_file_path, 'r') as f:
groups_as_lines = f.readlines()
actions = get_recipe_actions(groups_as_lines)
expected_actions =[
['CHANGE_BOND', '*1', -1, '*2'],
['BREAK_BOND', '*3', 1, '*4'],
['FORM_BOND', '*1', 1, '*4'],
['FORM_BOND', '*2', 1, '*3'],
]
self.assertEqual(actions, expected_actions)


class TestEtherHydrolysisReactionFamily(unittest.TestCase):
"""
Contains unit tests for the ether hydrolysis reaction family.
"""

@classmethod
def setUpClass(cls):
"""Set up the test by defining the ether hydrolysis reaction family."""
cls.family = ReactionFamily('ether_hydrolysis')

def test_ether_hydrolysis_reaction(self):
"""Test if ether hydrolysis products are correctly generated."""
ether = ARCSpecies(label='ether', smiles='CCOC')
water = ARCSpecies(label='H2O', smiles='O')
alcohol1 = ARCSpecies(label='alcohol1', smiles='CCO')
alcohol2 = ARCSpecies(label='alcohol2', smiles='CO')
rxn = ARCReaction(r_species=[ether, water], p_species=[alcohol1, alcohol2])
products = get_reaction_family_products(rxn)
product_smiles = [p.to_smiles() for p in products[0]['products']]
expected_product_smiles = ['CCO', 'CO']
self.assertEqual(product_smiles, expected_product_smiles)

def test_recipe_actions(self):
"""Test if the reaction recipe is applied correctly."""
groups_file_path = os.path.join(ARC_FAMILIES_PATH, 'ether_hydrolysis.py')
with open(groups_file_path, 'r') as f:
groups_as_lines = f.readlines()
actions = get_recipe_actions(groups_as_lines)
expected_actions = [
['BREAK_BOND', '*1', 1, '*2'],
['BREAK_BOND', '*3', 1, '*4'],
['FORM_BOND', '*1', 1, '*4'],
['FORM_BOND', '*2', 1, '*3'],
]
self.assertEqual(actions, expected_actions)


if __name__ == '__main__':
unittest.main()
24 changes: 20 additions & 4 deletions arc/family/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,11 +588,12 @@ def get_all_families(rmg_family_set: Union[List[str], str] = 'default',
rmg_families.extend(list(families))
else:
rmg_families = list(family_sets[rmg_family_set]) \
if isinstance(rmg_family_set, str) and rmg_family_set in family_sets else rmg_family_set
if isinstance(rmg_family_set, str) and rmg_family_set in family_sets else [rmg_family_set]
if consider_arc_families:
arc_families = [os.path.splitext(family)[0] for family in os.listdir(ARC_FAMILIES_PATH)]
rmg_families = [rmg_families] if isinstance(rmg_families, str) else rmg_families
arc_families = [arc_families] if isinstance(arc_families, str) else arc_families
for family in os.listdir(ARC_FAMILIES_PATH):
if family.startswith('.') or family.startswith('_'):
continue
arc_families.append(os.path.splitext(family)[0])
return rmg_families + arc_families if rmg_families is not None else arc_families


Expand Down Expand Up @@ -862,3 +863,18 @@ def isomorphic_products(rxn: 'ARCReaction',
"""
p_species = rxn.get_reactants_and_products(return_copies=True)[1]
return check_product_isomorphism(products, p_species)

def check_family_name(family: str
) -> bool:
"""
Check whether the family name is defined.

Args:
family (str): The family name.

Returns:
bool: Whether the family is defined.
"""
if not isinstance(family, str) and family is not None:
raise TypeError("Family name must be a string or None.")
return family in get_all_families() or family is None
15 changes: 14 additions & 1 deletion arc/family/family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_rmg_recommended_family_sets,
is_own_reverse,
is_reversible,
check_family_name
)
from arc.molecule import Group, Molecule
from arc.molecule.resonance import generate_resonance_structures_safely
Expand Down Expand Up @@ -701,7 +702,9 @@ def test_get_all_families(self):
self.assertIn('intra_OH_migration', families)
families = get_all_families(consider_rmg_families=False)
self.assertIsInstance(families, list)
self.assertIn('hydrolysis', families)
self.assertIn('carbonyl_based_hydrolysis', families)
self.assertIn('ether_hydrolysis', families)
self.assertIn('nitrile_hydrolysis', families)
families = get_all_families(rmg_family_set=['H_Abstraction'])
self.assertEqual(families, ['H_Abstraction'])

Expand Down Expand Up @@ -1059,6 +1062,16 @@ def test_get_isomorphic_subgraph(self):
)
self.assertEqual(isomorphic_subgraph, {0: '*3', 4: '*1', 7: '*2'})

def test_check_family_name(self):
"""Test check family name function"""
self.assertTrue(check_family_name('H_Abstraction'))
self.assertTrue(check_family_name('ether_hydrolysis'))
self.assertFalse(check_family_name('etherhydrolysis'))
self.assertFalse(check_family_name('amine_hydrolysis'))
self.assertTrue(check_family_name(None))
with self.assertRaises(TypeError):
check_family_name(123)


if __name__ == '__main__':
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
21 changes: 13 additions & 8 deletions arc/job/adapters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
default_job_settings, global_ess_settings, rotor_scan_resolution = \
settings['default_job_settings'], settings['global_ess_settings'], settings['rotor_scan_resolution']


ts_adapters_by_rmg_family = {'1+2_Cycloaddition': ['kinbot'],
'1,2_Insertion_CO': ['kinbot'],
'1,2_Insertion_carbene': ['kinbot'],
Expand All @@ -43,6 +42,9 @@
'Cyclopentadiene_scission': ['gcn', 'xtb_gsm'],
'Diels_alder_addition': ['kinbot'],
'H_Abstraction': ['heuristics', 'autotst'],
'carbonyl_based_hydrolysis': ['heuristics'],
'ether_hydrolysis': ['heuristics'],
'nitrile_hydrolysis': ['heuristics'],
'HO2_Elimination_from_PeroxyRadical': ['kinbot'],
'Intra_2+2_cycloaddition_Cd': ['gcn', 'xtb_gsm'],
'Intra_5_membered_conjugated_C=C_C=C_addition': ['gcn', 'xtb_gsm'],
Expand Down Expand Up @@ -117,7 +119,7 @@ def _initialize_adapter(obj: 'JobAdapter',
times_rerun: int = 0,
torsions: Optional[List[List[int]]] = None,
tsg: Optional[int] = None,
xyz: Optional[Union[dict,List[dict]]] = None,
xyz: Optional[Union[dict, List[dict]]] = None,
):
"""
A common Job adapter initializer function.
Expand Down Expand Up @@ -161,7 +163,7 @@ def _initialize_adapter(obj: 'JobAdapter',
obj.job_num = job_num
obj.job_server_name = job_server_name
obj.job_status = job_status \
or ['initializing', {'status': 'initializing', 'keywords': list(), 'error': '', 'line': ''}]
or ['initializing', {'status': 'initializing', 'keywords': list(), 'error': '', 'line': ''}]
obj.job_type = job_type if isinstance(job_type, str) else job_type[0] # always a string
obj.job_types = job_type if isinstance(job_type, list) else [job_type] # always a list
# When restarting ARC and re-setting the jobs, ``level`` is a string, convert it to a Level object instance
Expand Down Expand Up @@ -211,7 +213,7 @@ def _initialize_adapter(obj: 'JobAdapter',
obj.is_ts = obj.species[0].is_ts
obj.species_label = list()
for spc in obj.species:
obj.charge.append(spc.charge)
obj.charge.append(spc.charge)
obj.multiplicity.append(spc.multiplicity)
obj.species_label.append(spc.label)
elif obj.reactions is not None:
Expand Down Expand Up @@ -286,9 +288,9 @@ def is_species_restricted(obj: 'JobAdapter',
bool: Whether to run as restricted (``True``) or not (``False``).
"""

if obj.level.method_type in ['force_field','composite','semiempirical']:
if obj.level.method_type in ['force_field', 'composite', 'semiempirical']:
return True

multiplicity = obj.multiplicity if species is None else species.multiplicity
number_of_radicals = obj.species[0].number_of_radicals if species is None else species.number_of_radicals
species_label = obj.species[0].label if species is None else species.label
Expand Down Expand Up @@ -322,7 +324,8 @@ def check_argument_consistency(obj: 'JobAdapter'):
raise NotImplementedError(f'The {obj.job_adapter} job adapter does not support ESS scans.')
if obj.job_type == 'scan' and divmod(360, obj.scan_res)[1]:
raise ValueError(f'Got an illegal rotor scan resolution of {obj.scan_res}.')
if obj.job_type == 'scan' and ((not obj.species[0].rotors_dict or obj.rotor_index is None) and obj.torsions is None):
if obj.job_type == 'scan' and (
(not obj.species[0].rotors_dict or obj.rotor_index is None) and obj.torsions is None):
# If this is a scan job type and species.rotors_dict is empty (e.g., via pipe), then torsions must be set up.
raise ValueError('Either torsions or a species rotors_dict along with a rotor_index argument '
'must be specified for an ESS scan job.')
Expand Down Expand Up @@ -406,7 +409,7 @@ def update_input_dict_with_args(args: dict,
else:
if 'keywords' not in input_dict.keys():
input_dict['keywords'] = ''
# Check if input_dict['keywords'] already contains a value
# Check if input_dict['keywords'] already contains a value
if input_dict['keywords']:
input_dict['keywords'] += f' {value}'
else:
Expand Down Expand Up @@ -444,6 +447,7 @@ def update_input_dict_with_args(args: dict,

return input_dict


def input_dict_strip(input_dict: dict) -> dict:
"""
Strip all values in the input dict of leading and trailing whitespace.
Expand Down Expand Up @@ -536,6 +540,7 @@ def which(command: Union[str, list],
else:
return ans


def combine_parameters(input_dict: dict, terms: list) -> Tuple[dict, List]:
"""
Extract and combine specific parameters from a dictionary's string values based on a list of terms.
Expand Down
Loading
Loading