diff --git a/conkit/core/contactmap.py b/conkit/core/contactmap.py index db7db3fe..310f884f 100644 --- a/conkit/core/contactmap.py +++ b/conkit/core/contactmap.py @@ -738,13 +738,14 @@ def match_naive(self, other, add_false_negatives=False, match_other=False, inpla else: contact.false_positive = True - if not add_false_negatives: + if not add_false_negatives and not match_other: return contact_map1 for contact in contact_map2: if contact.id not in contact_map1_set: contact.false_negative = True - contact_map1.add(contact) + if add_false_negatives: + contact_map1.add(contact) return contact_map1 @@ -800,15 +801,16 @@ def match(self, other, add_false_negatives=False, match_other=False, remove_unma else: contact_map2 = other._inplace(False) - if contact_map1.empty and add_false_negatives: + if contact_map1.empty: for contact in contact_map2: contact.false_negative = True + if add_false_negatives: + contact_map1.add(contact) + return contact_map1 if contact_map2.empty: for contact in contact_map1: contact.false_positive = True - - if contact_map2.empty or contact_map1.empty: return contact_map1 # ================================================================ @@ -876,12 +878,12 @@ def match(self, other, add_false_negatives=False, match_other=False, remove_unma # ================================================================ # 3. Add false negatives # ================================================================ - if add_false_negatives: - for contactid in contact_map2.as_list(): - contactid = tuple(contactid) - if contactid not in contact_map1: + for contactid in contact_map2.as_list(): + contactid = tuple(contactid) + if contactid not in contact_map1: + contact_map2[contactid].status = ContactMatchState.false_negative + if add_false_negatives: contact = contact_map2[contactid].copy() - contact.false_negative = True contact_map1.add(contact) # ================================================================ diff --git a/conkit/core/tests/test_contactmap.py b/conkit/core/tests/test_contactmap.py index d2e40266..48896f2f 100644 --- a/conkit/core/tests/test_contactmap.py +++ b/conkit/core/tests/test_contactmap.py @@ -535,6 +535,7 @@ def test_match_4(self): contact_map1.match(contact_map2, match_other=True, remove_unmatched=True, inplace=True) self.assertEqual([TP, TP], [c.status for c in contact_map1]) + self.assertEqual([TP, TP, FN], [c.status for c in contact_map2]) self.assertEqual([2, 2, 3], [c.res1_altseq for c in contact_map2]) self.assertEqual([6, 7, 5], [c.res2_altseq for c in contact_map2]) @@ -891,6 +892,27 @@ def test_match_14(self): contact_map2.set_sequence_register() contact_map1.match(contact_map2, add_false_negatives=True, inplace=True, match_other=True) + self.assertEqual([FN, FN, FN, FN], [c.status for c in contact_map1]) + self.assertEqual([FN, FN, FN, FN], [c.status for c in contact_map2]) + self.assertListEqual([[1, 5], [1, 7], [2, 7], [3, 4]], contact_map2.as_list()) + + def test_match_15(self): + contact_map1 = ContactMap("foo") + contact_map1.sequence = Sequence("foo", "AICDEFGH") + contact_map1.set_sequence_register() + + contact_map2 = ContactMap("bar") + for i, params in enumerate([(1, 5, 1.0), (1, 7, 1.0), (2, 7, 1.0), (3, 4, 1.0)]): + contact = Contact(*params) + contact.res1_altseq = params[0] + contact.res2_altseq = params[1] + contact.status = TP + contact_map2.add(contact) + contact_map2.sequence = Sequence("bar", "AICDEFGH") + contact_map2.set_sequence_register() + + contact_map1.match(contact_map2, inplace=True, match_other=True) + self.assertEqual([], [c.status for c in contact_map1]) self.assertEqual([FN, FN, FN, FN], [c.status for c in contact_map2]) self.assertListEqual([[1, 5], [1, 7], [2, 7], [3, 4]], contact_map2.as_list()) @@ -946,9 +968,47 @@ def test_match_naive_3(self): contact_map2.set_sequence_register() contact_map1.match_naive(contact_map2, add_false_negatives=True, inplace=True, match_other=True) + self.assertEqual([FN, FN, FN, FN], [c.status for c in contact_map1]) self.assertEqual([FN, FN, FN, FN], [c.status for c in contact_map2]) self.assertListEqual([[1, 5], [1, 7], [2, 7], [3, 4]], contact_map2.as_list()) + def test_match_naive_4(self): + contact_map1 = ContactMap("foo") + for params in [(1, 5, 1.0), (1, 6, 1.0), (2, 7, 1.0), (3, 5, 1.0), (2, 8, 1.0)]: + contact = Contact(*params) + contact_map1.add(contact) + + contact_map2 = ContactMap("bar") + for i, params in enumerate([(1, 5, 1.0), (1, 7, 1.0), (2, 7, 1.0), (3, 4, 1.0)]): + contact = Contact(*params) + contact.res1_altseq = params[0] + contact.res2_altseq = params[1] + contact.status = TP + contact_map2.add(contact) + + contact_map1.match_naive(contact_map2, inplace=True, match_other=True) + self.assertEqual([TP, FP, TP, FP, FP], [c.status for c in contact_map1]) + self.assertEqual([TP, FN, TP, FN], [c.status for c in contact_map2]) + + def test_match_naive_5(self): + contact_map1 = ContactMap("foo") + contact_map1.sequence = Sequence("foo", "AICDEFGH") + contact_map1.set_sequence_register() + + contact_map2 = ContactMap("bar") + for i, params in enumerate([(1, 5, 1.0), (1, 7, 1.0), (2, 7, 1.0), (3, 4, 1.0)]): + contact = Contact(*params) + contact.res1_altseq = params[0] + contact.res2_altseq = params[1] + contact.status = TP + contact_map2.add(contact) + contact_map2.sequence = Sequence("bar", "AICDEFGH") + contact_map2.set_sequence_register() + + contact_map1.match_naive(contact_map2, inplace=True, match_other=True) + self.assertEqual([], [c.status for c in contact_map1]) + self.assertEqual([FN, FN, FN, FN], [c.status for c in contact_map2]) + def test_remove_neighbors_1(self): contact_map = ContactMap("test") for c in [Contact(1, 5, 1.0), Contact(3, 3, 0.4), Contact(2, 4, 0.1), Contact(5, 1, 0.2)]: