Skip to content

Commit de7c35c

Browse files
committed
Cleaned colang directory
1 parent 5c0c461 commit de7c35c

File tree

12 files changed

+106
-46
lines changed

12 files changed

+106
-46
lines changed

nemoguardrails/colang/v1_0/lang/comd_parser.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,21 +360,22 @@ def parse_md_file(file_name, content=None):
360360
continue
361361

362362
# Make sure we have the type of the symbol in the name of the symbol
363-
sym = _get_typed_symbol_name(sym, symbol_type)
363+
if sym is not None:
364+
sym = _get_typed_symbol_name(sym, symbol_type)
364365

365-
# For objects, we translate the "string" type to "kb:Object:prop|partial"
366-
param_type = _get_param_type(parts[1])
367-
if symbol_type == "object" and param_type in ["string", "text"]:
368-
object_name = split_max(sym, ":", 1)[1]
369-
param_type = f"kb:{object_name}:{parts[0]}|partial"
366+
# For objects, we translate the "string" type to "kb:Object:prop|partial"
367+
param_type = _get_param_type(parts[1])
368+
if symbol_type == "object" and param_type in ["string", "text"]:
369+
object_name = split_max(sym, ":", 1)[1]
370+
param_type = f"kb:{object_name}:{parts[0]}|partial"
370371

371-
# TODO: figure out a cleaner way to deal with this
372-
# For the "type:time" type, we transform it into "lookup:time"
373-
if param_type == "type:time":
374-
param_type = "lookup:time"
372+
# TODO: figure out a cleaner way to deal with this
373+
# For the "type:time" type, we transform it into "lookup:time"
374+
if param_type == "type:time":
375+
param_type = "lookup:time"
375376

376-
result["mappings"].append((f"{sym}:{parts[0]}", param_type))
377-
symbol_params.append(parts[0])
377+
result["mappings"].append((f"{sym}:{parts[0]}", param_type))
378+
symbol_params.append(parts[0])
378379

379380
elif line.startswith("-") or line.startswith("*"):
380381
if sym is None:

nemoguardrails/colang/v1_0/lang/coyml_parser.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -420,14 +420,20 @@ def _extract_elements(items: List) -> List[dict]:
420420
# for `if` flow elements, we have to go recursively
421421
if element["_type"] == "if":
422422
if_element = element
423-
then_elements = _extract_elements(if_element["then"])
424-
else_elements = _extract_elements(if_element["else"])
423+
then_items = (
424+
if_element["then"] if isinstance(if_element["then"], list) else []
425+
)
426+
else_items = (
427+
if_element["else"] if isinstance(if_element["else"], list) else []
428+
)
429+
then_elements = _extract_elements(then_items)
430+
else_elements = _extract_elements(else_items)
425431

426432
# Remove the raw info
427433
del if_element["then"]
428434
del if_element["else"]
429435

430-
if_element["_next_else"] = len(then_elements) + 1
436+
if_element["_next_else"] = str(len(then_elements) + 1)
431437

432438
# Add the "if"
433439
elements.append(if_element)
@@ -437,30 +443,35 @@ def _extract_elements(items: List) -> List[dict]:
437443

438444
# if we have "else" elements, we need to adjust also add a jump
439445
if len(else_elements) > 0:
440-
elements.append({"_type": "jump", "_next": len(else_elements) + 1})
441-
if_element["_next_else"] += 1
446+
elements.append(
447+
{"_type": "jump", "_next": str(len(else_elements) + 1)}
448+
)
449+
if_element["_next_else"] = str(int(if_element["_next_else"]) + 1)
442450

443451
# Add the "else" elements
444452
elements.extend(else_elements)
445453

446454
# WHILE
447455
elif element["_type"] == "while":
448456
while_element = element
449-
do_elements = _extract_elements(while_element["do"])
457+
do_items = (
458+
while_element["do"] if isinstance(while_element["do"], list) else []
459+
)
460+
do_elements = _extract_elements(do_items)
450461
n = len(do_elements)
451462

452463
# Remove the raw info
453464
del while_element["do"]
454465

455466
# On break we have to skip n elements and 1 jump, hence we go to n+2
456-
while_element["_next_on_break"] = n + 2
467+
while_element["_next_on_break"] = str(n + 2)
457468

458469
# We need to compute the jumps on break and on continue for each element
459470
for j in range(n):
460471
# however, we make sure we don't override an inner loop
461472
if "_next_on_break" not in do_elements[j]:
462-
do_elements[j]["_next_on_break"] = n + 1 - j
463-
do_elements[j]["_next_on_continue"] = -1 * j - 1
473+
do_elements[j]["_next_on_break"] = str(n + 1 - j)
474+
do_elements[j]["_next_on_continue"] = str(-1 * j - 1)
464475

465476
# Add the "while"
466477
elements.append(while_element)
@@ -500,7 +511,7 @@ def _extract_elements(items: List) -> List[dict]:
500511
branch_element = {
501512
"_type": "branch",
502513
# these are the relative positions to the current position
503-
"branch_heads": [],
514+
"branch_heads": [], # type: ignore
504515
}
505516
branch_element_pos = len(elements)
506517
elements.append(branch_element)
@@ -520,7 +531,7 @@ def _extract_elements(items: List) -> List[dict]:
520531
branch_element["_source_mapping"] = branch_path[0]["_source_mapping"]
521532

522533
# Create the jump element
523-
jump_element = {"_type": "jump", "_next": 1}
534+
jump_element = {"_type": "jump", "_next": 1} # type: ignore
524535

525536
# We compute how far we need to jump based on the remaining branches
526537
j = branch_idx + 1

nemoguardrails/colang/v1_0/lang/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,14 @@ def get_numbered_lines(content: str):
8686
current_comment = None
8787
multiline_string = False
8888
current_string = None
89+
multiline_indentation = 0
8990
while i < len(raw_lines):
9091
raw_line = raw_lines[i].strip()
9192

9293
# handle multiline string
9394
if multiline_string:
95+
if current_string is None:
96+
current_string = ""
9497
current_string += "\n" + raw_line
9598
if raw_line.endswith('"'):
9699
multiline_string = False
@@ -139,6 +142,8 @@ def get_numbered_lines(content: str):
139142
continue
140143

141144
if multiline_comment:
145+
if current_comment is None:
146+
current_comment = ""
142147
if raw_line.endswith('"""'):
143148
current_comment += "\n" + raw_line[0:-3]
144149
multiline_comment = False

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,8 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
659659
if isinstance(result, ActionResult):
660660
return_value = result.return_value
661661
return_events = result.events
662-
context_updates.update(result.context_updates)
662+
if result.context_updates is not None:
663+
context_updates.update(result.context_updates)
663664

664665
# If we have an action result key, we also record the update.
665666
if action_result_key:
@@ -730,10 +731,17 @@ async def _get_action_resp(
730731
)
731732
except Exception as e:
732733
log.info(f"Exception {e} while making request to {action_name}")
734+
if not isinstance(result, dict):
735+
result = {"value": result}
733736
return result, status
734737

735738
except Exception as e:
736739
log.info(f"Failed to get response from {action_name} due to exception {e}")
740+
741+
# Ensure result is a dict as expected by the return type
742+
if not isinstance(result, dict):
743+
result = {"value": result}
744+
737745
return result, status
738746

739747
async def _process_start_flow(self, events: List[dict], processing_log: List[dict]) -> List[dict]:

nemoguardrails/colang/v1_0/runtime/sliding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import Optional
17+
from typing import TYPE_CHECKING, Optional
18+
19+
if TYPE_CHECKING:
20+
from nemoguardrails.colang.v1_0.runtime.flows import FlowConfig, State
1821

1922
from nemoguardrails.colang.v1_0.runtime.eval import eval_expression
2023

nemoguardrails/colang/v2_x/lang/colang_ast.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,17 @@ def get(self, key, default_value=None):
7777

7878
def __eq__(self, other):
7979
if isinstance(other, self.__class__):
80-
return self.__hash__() == other.__hash__()
80+
return self.hash() == other.hash()
8181
return NotImplemented
8282

8383
def hash(self):
8484
"""Return the hash for the current object."""
8585
return hash(_make_hashable(self))
8686

87+
def __hash__(self):
88+
"""Return the hash for the current object."""
89+
return self.hash()
90+
8791

8892
ElementType = Union[Element, dict]
8993

nemoguardrails/colang/v2_x/lang/expansion.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ def expand_elements(
106106
if e.args[0]:
107107
error = e.args[0]
108108

109-
if hasattr(element, "_source") and element._source:
109+
if (
110+
not isinstance(element, dict)
111+
and hasattr(element, "_source")
112+
and element._source is not None
113+
and hasattr(element._source, "line")
114+
):
110115
# TODO: Resolve source line to Colang file level
111116
raise ColangSyntaxError(error + f" on source line {element._source.line}")
112117
else:
@@ -413,10 +418,15 @@ def _expand_match_element(
413418

414419
for idx, element in enumerate(and_group["elements"]):
415420
new_elements.append(event_label_elements[idx])
421+
# Ensure element is valid for SpecOp
422+
if isinstance(element, (dict, Spec)):
423+
spec_element: Union[dict, Spec] = element
424+
else:
425+
spec_element = {}
416426
new_elements.append(
417427
SpecOp(
418428
op="match",
419-
spec=element,
429+
spec=spec_element,
420430
)
421431
)
422432
new_elements.append(goto_end_element)
@@ -433,8 +443,8 @@ def _expand_match_element(
433443

434444
else:
435445
# Multiple and-groups combined by or
436-
fork_uid: str = new_var_uuid()
437-
fork_element = ForkHead(fork_uid=fork_uid)
446+
or_fork_uid: str = new_var_uuid()
447+
fork_element = ForkHead(fork_uid=or_fork_uid)
438448
group_label_elements: List[Label] = []
439449
failure_label_name = f"failure_label_{new_var_uuid()}"
440450
failure_label_element = Label(name=failure_label_name)
@@ -463,12 +473,12 @@ def _expand_match_element(
463473

464474
new_elements.append(failure_label_element)
465475
new_elements.append(WaitForHeads(number=len(or_group)))
466-
new_elements.append(MergeHeads(fork_uid=fork_uid))
476+
new_elements.append(MergeHeads(fork_uid=or_fork_uid))
467477
new_elements.append(CatchPatternFailure(label=None))
468478
new_elements.append(Abort())
469479

470480
new_elements.append(end_label_element)
471-
new_elements.append(MergeHeads(fork_uid=fork_uid))
481+
new_elements.append(MergeHeads(fork_uid=or_fork_uid))
472482
new_elements.append(CatchPatternFailure(label=None))
473483

474484
else:

nemoguardrails/colang/v2_x/lang/transformer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,17 @@ def _flow_def(self, children: dict, meta: Meta) -> Flow:
168168

169169
assert member_name_el["_type"] == "var_name"
170170
member_name = member_name_el["elements"][0][1:]
171-
member_def = FlowReturnMemberDef(name=member_name)
171+
return_member_def_obj = FlowReturnMemberDef(name=member_name)
172172

173173
# If we have a default value, we also use that
174174
if len(return_member_def["elements"]) == 2:
175175
default_value_el = return_member_def["elements"][1]
176176
assert default_value_el["_type"] == "expr"
177-
member_def.default_value_expr = default_value_el["elements"][0]
177+
return_member_def_obj.default_value_expr = default_value_el[
178+
"elements"
179+
][0]
178180

179-
return_member_defs.append(member_def)
181+
return_member_defs.append(return_member_def_obj)
180182

181183
elements[0:0] = [
182184
SpecOp(
@@ -546,7 +548,7 @@ def _non_var_spec_and(self, children: list, meta: Meta) -> dict:
546548
val["_source"] = self.__source(meta)
547549
return val
548550

549-
def __default__(self, data, children: list, meta: Meta) -> dict:
551+
def __default__(self, data, children: list, meta: Meta) -> Any:
550552
"""Default function that is called if there is no attribute matching ``data``
551553
552554
Can be overridden. Defaults to creating

nemoguardrails/colang/v2_x/lang/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
def dataclass_to_dict(obj: Any) -> Any:
21-
if is_dataclass(obj):
21+
if is_dataclass(obj) and not isinstance(obj, type):
2222
return {k: dataclass_to_dict(v) for k, v in asdict(obj).items()}
2323
elif isinstance(obj, list):
2424
return [dataclass_to_dict(v) for v in obj]

nemoguardrails/colang/v2_x/runtime/eval.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def _regex_findall(pattern: str, string: str) -> List[str]:
202202
def _pretty_str(data: Any) -> str:
203203
if isinstance(data, (dict, list, set)):
204204
string = json.dumps(data, indent=4)
205-
return SimplifyFormatter().format(string)
205+
# SimplifyFormatter.format() accepts string as well as LogRecord
206+
return str(SimplifyFormatter().format(string)) # type: ignore
206207
return str(data)
207208

208209

@@ -245,27 +246,27 @@ def _get_type(val: Any) -> str:
245246

246247
def _less_than_operator(v_ref: Any) -> ComparisonExpression:
247248
"""Create less then comparison expression."""
248-
return ComparisonExpression(lambda val, v_ref=v_ref: val < v_ref, v_ref)
249+
return ComparisonExpression(lambda val: val < v_ref, v_ref)
249250

250251

251252
def _equal_or_less_than_operator(v_ref: Any) -> ComparisonExpression:
252253
"""Create equal or less than comparison expression."""
253-
return ComparisonExpression(lambda val, val_ref=v_ref: val <= val_ref, v_ref)
254+
return ComparisonExpression(lambda val: val <= v_ref, v_ref)
254255

255256

256257
def _greater_than_operator(v_ref: Any) -> ComparisonExpression:
257258
"""Create less then comparison expression."""
258-
return ComparisonExpression(lambda val, val_ref=v_ref: val > val_ref, v_ref)
259+
return ComparisonExpression(lambda val: val > v_ref, v_ref)
259260

260261

261262
def _equal_or_greater_than_operator(v_ref: Any) -> ComparisonExpression:
262263
"""Create equal or less than comparison expression."""
263-
return ComparisonExpression(lambda val, val_ref=v_ref: val >= val_ref, v_ref)
264+
return ComparisonExpression(lambda val: val >= v_ref, v_ref)
264265

265266

266267
def _not_equal_to_operator(v_ref: Any) -> ComparisonExpression:
267268
"""Create a not equal comparison expression."""
268-
return ComparisonExpression(lambda val, val_ref=v_ref: val != val_ref, v_ref)
269+
return ComparisonExpression(lambda val: val != v_ref, v_ref)
269270

270271

271272
def _flows_info(state: State, flow_instance_uid: Optional[str] = None) -> dict:

0 commit comments

Comments
 (0)