diff --git a/src/templates.py b/src/templates.py index d31c65b..bdb06ad 100644 --- a/src/templates.py +++ b/src/templates.py @@ -12,27 +12,27 @@ def __init__(self, input: str, pos: int, message: str) -> None: class TemplateType: args: List['TemplateType'] + inner: 'TemplateType' - def __init__(self, name: str, args=None) -> None: + def __init__(self, name: str, args=[], inner=None) -> None: super().__init__() - if args is None: - args = [] self.name = name self.args = args + self.inner = inner @property def is_wildcard(self): return self.name == "*" def __repr__(self) -> str: - return '<{}: {!r} [{}]>'.format(self.__class__.__name__, self.name, ",".join(repr(x) for x in self.args)) + return '<{}: {!r} [{}]> -> {!r}'.format(self.__class__.__name__, self.name, ",".join(repr(x) for x in self.args), self.inner) def __str__(self) -> str: if len(self.args) <= 0: return self.name else: - return '{}<{}>'.format(self.name, ", ".join(str(x) for x in self.args)) + return '{}<{}>{}'.format(self.name, ", ".join(str(x) for x in self.args), str(self.inner) if self.inner is not None else '') def matches(self, other: 'TemplateType', matched_args: List[str] = None) -> bool: if self.is_wildcard: @@ -48,7 +48,13 @@ def matches(self, other: 'TemplateType', matched_args: List[str] = None) -> bool if not left.matches(right, matched_args): return False - return self.name == other.name + if self.name != other.name: + return False + + if (self.inner is None and other.inner is not None) or (self.inner is not None and other.inner is None): + return False + + return (self.inner is None and other.inner is None) or self.inner.matches(other.inner, matched_args) TEMPLATE_LIST_REGEX = re.compile("[<>,]") @@ -102,7 +108,12 @@ def _template_type_parse_runner(input: str, start: int) -> Tuple[TemplateType, i .format("" if arg_start >= len(input) else input[arg_start])) arg_start += 1 # Consume the '>' - return TemplateType(input[start:name_end], args), arg_start + if arg_start < len(input): + # consume remaining itemsas inner + arg_type, arg_end = _template_type_parse_runner(input, arg_start) + return TemplateType(input[start:name_end], args, arg_type), arg_end + else: + return TemplateType(input[start:name_end], args), arg_start def parse_template_type(input: str) -> TemplateType: diff --git a/test/test_templates.py b/test/test_templates.py index bace83e..cd02bb9 100644 --- a/test/test_templates.py +++ b/test/test_templates.py @@ -76,6 +76,44 @@ def test_nested_wildcard(self): self.assertTrue(vector_type.args[0].is_wildcard) self.assertEqual(0, len(vector_type.args[0].args)) + def test_simple_list_inner_non_template(self): + template_type = templates.parse_template_type("test::template_class::inner_class") + + self.assertEqual("test::template_class", template_type.name) + + self.assertEqual(1, len(template_type.args)) + + self.assertEqual("float", template_type.args[0].name) + + self.assertEqual(0, len(template_type.args[0].args)) + + self.assertNotEqual(None, template_type.inner) + + self.assertEqual("::inner_class", template_type.inner.name) + + self.assertEqual(0, len(template_type.inner.args)) + + def test_simple_list_inner_simple_list(self): + template_type = templates.parse_template_type("test::template_class::inner_class") + + self.assertEqual("test::template_class", template_type.name) + + self.assertEqual(1, len(template_type.args)) + + self.assertEqual("float", template_type.args[0].name) + + self.assertEqual(0, len(template_type.args[0].args)) + + self.assertNotEqual(None, template_type.inner) + + self.assertEqual("::inner_class", template_type.inner.name) + + self.assertEqual(1, len(template_type.inner.args)) + + self.assertEqual("int", template_type.inner.args[0].name) + + self.assertEqual(0, len(template_type.inner.args[0].args)) + def test_missing_closing_brance(self): with self.assertRaises(TemplateException): templates.parse_template_type("test::template_class<")