From de7c35c9f400dcb3b99bb95a38cc11964ec18d0d Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 8 Sep 2025 20:54:31 -0500 Subject: [PATCH 1/5] Cleaned colang directory --- .../colang/v1_0/lang/comd_parser.py | 25 +++++++------- .../colang/v1_0/lang/coyml_parser.py | 33 ++++++++++++------- nemoguardrails/colang/v1_0/lang/utils.py | 5 +++ nemoguardrails/colang/v1_0/runtime/runtime.py | 10 +++++- nemoguardrails/colang/v1_0/runtime/sliding.py | 5 ++- nemoguardrails/colang/v2_x/lang/colang_ast.py | 6 +++- nemoguardrails/colang/v2_x/lang/expansion.py | 22 +++++++++---- .../colang/v2_x/lang/transformer.py | 10 +++--- nemoguardrails/colang/v2_x/lang/utils.py | 2 +- nemoguardrails/colang/v2_x/runtime/eval.py | 13 ++++---- nemoguardrails/colang/v2_x/runtime/flows.py | 15 ++++++++- nemoguardrails/logging/processing_log.py | 6 ++-- 12 files changed, 106 insertions(+), 46 deletions(-) diff --git a/nemoguardrails/colang/v1_0/lang/comd_parser.py b/nemoguardrails/colang/v1_0/lang/comd_parser.py index 25ce5f289..1f6d6709d 100644 --- a/nemoguardrails/colang/v1_0/lang/comd_parser.py +++ b/nemoguardrails/colang/v1_0/lang/comd_parser.py @@ -360,21 +360,22 @@ def parse_md_file(file_name, content=None): continue # Make sure we have the type of the symbol in the name of the symbol - sym = _get_typed_symbol_name(sym, symbol_type) + if sym is not None: + sym = _get_typed_symbol_name(sym, symbol_type) - # For objects, we translate the "string" type to "kb:Object:prop|partial" - param_type = _get_param_type(parts[1]) - if symbol_type == "object" and param_type in ["string", "text"]: - object_name = split_max(sym, ":", 1)[1] - param_type = f"kb:{object_name}:{parts[0]}|partial" + # For objects, we translate the "string" type to "kb:Object:prop|partial" + param_type = _get_param_type(parts[1]) + if symbol_type == "object" and param_type in ["string", "text"]: + object_name = split_max(sym, ":", 1)[1] + param_type = f"kb:{object_name}:{parts[0]}|partial" - # TODO: figure out a cleaner way to deal with this - # For the "type:time" type, we transform it into "lookup:time" - if param_type == "type:time": - param_type = "lookup:time" + # TODO: figure out a cleaner way to deal with this + # For the "type:time" type, we transform it into "lookup:time" + if param_type == "type:time": + param_type = "lookup:time" - result["mappings"].append((f"{sym}:{parts[0]}", param_type)) - symbol_params.append(parts[0]) + result["mappings"].append((f"{sym}:{parts[0]}", param_type)) + symbol_params.append(parts[0]) elif line.startswith("-") or line.startswith("*"): if sym is None: diff --git a/nemoguardrails/colang/v1_0/lang/coyml_parser.py b/nemoguardrails/colang/v1_0/lang/coyml_parser.py index 93c036886..acc9893bb 100644 --- a/nemoguardrails/colang/v1_0/lang/coyml_parser.py +++ b/nemoguardrails/colang/v1_0/lang/coyml_parser.py @@ -420,14 +420,20 @@ def _extract_elements(items: List) -> List[dict]: # for `if` flow elements, we have to go recursively if element["_type"] == "if": if_element = element - then_elements = _extract_elements(if_element["then"]) - else_elements = _extract_elements(if_element["else"]) + then_items = ( + if_element["then"] if isinstance(if_element["then"], list) else [] + ) + else_items = ( + if_element["else"] if isinstance(if_element["else"], list) else [] + ) + then_elements = _extract_elements(then_items) + else_elements = _extract_elements(else_items) # Remove the raw info del if_element["then"] del if_element["else"] - if_element["_next_else"] = len(then_elements) + 1 + if_element["_next_else"] = str(len(then_elements) + 1) # Add the "if" elements.append(if_element) @@ -437,8 +443,10 @@ def _extract_elements(items: List) -> List[dict]: # if we have "else" elements, we need to adjust also add a jump if len(else_elements) > 0: - elements.append({"_type": "jump", "_next": len(else_elements) + 1}) - if_element["_next_else"] += 1 + elements.append( + {"_type": "jump", "_next": str(len(else_elements) + 1)} + ) + if_element["_next_else"] = str(int(if_element["_next_else"]) + 1) # Add the "else" elements elements.extend(else_elements) @@ -446,21 +454,24 @@ def _extract_elements(items: List) -> List[dict]: # WHILE elif element["_type"] == "while": while_element = element - do_elements = _extract_elements(while_element["do"]) + do_items = ( + while_element["do"] if isinstance(while_element["do"], list) else [] + ) + do_elements = _extract_elements(do_items) n = len(do_elements) # Remove the raw info del while_element["do"] # On break we have to skip n elements and 1 jump, hence we go to n+2 - while_element["_next_on_break"] = n + 2 + while_element["_next_on_break"] = str(n + 2) # We need to compute the jumps on break and on continue for each element for j in range(n): # however, we make sure we don't override an inner loop if "_next_on_break" not in do_elements[j]: - do_elements[j]["_next_on_break"] = n + 1 - j - do_elements[j]["_next_on_continue"] = -1 * j - 1 + do_elements[j]["_next_on_break"] = str(n + 1 - j) + do_elements[j]["_next_on_continue"] = str(-1 * j - 1) # Add the "while" elements.append(while_element) @@ -500,7 +511,7 @@ def _extract_elements(items: List) -> List[dict]: branch_element = { "_type": "branch", # these are the relative positions to the current position - "branch_heads": [], + "branch_heads": [], # type: ignore } branch_element_pos = len(elements) elements.append(branch_element) @@ -520,7 +531,7 @@ def _extract_elements(items: List) -> List[dict]: branch_element["_source_mapping"] = branch_path[0]["_source_mapping"] # Create the jump element - jump_element = {"_type": "jump", "_next": 1} + jump_element = {"_type": "jump", "_next": 1} # type: ignore # We compute how far we need to jump based on the remaining branches j = branch_idx + 1 diff --git a/nemoguardrails/colang/v1_0/lang/utils.py b/nemoguardrails/colang/v1_0/lang/utils.py index 69631ee16..5a37be11e 100644 --- a/nemoguardrails/colang/v1_0/lang/utils.py +++ b/nemoguardrails/colang/v1_0/lang/utils.py @@ -86,11 +86,14 @@ def get_numbered_lines(content: str): current_comment = None multiline_string = False current_string = None + multiline_indentation = 0 while i < len(raw_lines): raw_line = raw_lines[i].strip() # handle multiline string if multiline_string: + if current_string is None: + current_string = "" current_string += "\n" + raw_line if raw_line.endswith('"'): multiline_string = False @@ -139,6 +142,8 @@ def get_numbered_lines(content: str): continue if multiline_comment: + if current_comment is None: + current_comment = "" if raw_line.endswith('"""'): current_comment += "\n" + raw_line[0:-3] multiline_comment = False diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py index 69023d05e..27081510a 100644 --- a/nemoguardrails/colang/v1_0/runtime/runtime.py +++ b/nemoguardrails/colang/v1_0/runtime/runtime.py @@ -659,7 +659,8 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]: if isinstance(result, ActionResult): return_value = result.return_value return_events = result.events - context_updates.update(result.context_updates) + if result.context_updates is not None: + context_updates.update(result.context_updates) # If we have an action result key, we also record the update. if action_result_key: @@ -730,10 +731,17 @@ async def _get_action_resp( ) except Exception as e: log.info(f"Exception {e} while making request to {action_name}") + if not isinstance(result, dict): + result = {"value": result} return result, status except Exception as e: log.info(f"Failed to get response from {action_name} due to exception {e}") + + # Ensure result is a dict as expected by the return type + if not isinstance(result, dict): + result = {"value": result} + return result, status async def _process_start_flow(self, events: List[dict], processing_log: List[dict]) -> List[dict]: diff --git a/nemoguardrails/colang/v1_0/runtime/sliding.py b/nemoguardrails/colang/v1_0/runtime/sliding.py index b1e4513ab..54b4a265c 100644 --- a/nemoguardrails/colang/v1_0/runtime/sliding.py +++ b/nemoguardrails/colang/v1_0/runtime/sliding.py @@ -14,7 +14,10 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from nemoguardrails.colang.v1_0.runtime.flows import FlowConfig, State from nemoguardrails.colang.v1_0.runtime.eval import eval_expression diff --git a/nemoguardrails/colang/v2_x/lang/colang_ast.py b/nemoguardrails/colang/v2_x/lang/colang_ast.py index 8e54cfe17..9c2d1c518 100644 --- a/nemoguardrails/colang/v2_x/lang/colang_ast.py +++ b/nemoguardrails/colang/v2_x/lang/colang_ast.py @@ -77,13 +77,17 @@ def get(self, key, default_value=None): def __eq__(self, other): if isinstance(other, self.__class__): - return self.__hash__() == other.__hash__() + return self.hash() == other.hash() return NotImplemented def hash(self): """Return the hash for the current object.""" return hash(_make_hashable(self)) + def __hash__(self): + """Return the hash for the current object.""" + return self.hash() + ElementType = Union[Element, dict] diff --git a/nemoguardrails/colang/v2_x/lang/expansion.py b/nemoguardrails/colang/v2_x/lang/expansion.py index 5162b584c..64e864235 100644 --- a/nemoguardrails/colang/v2_x/lang/expansion.py +++ b/nemoguardrails/colang/v2_x/lang/expansion.py @@ -106,7 +106,12 @@ def expand_elements( if e.args[0]: error = e.args[0] - if hasattr(element, "_source") and element._source: + if ( + not isinstance(element, dict) + and hasattr(element, "_source") + and element._source is not None + and hasattr(element._source, "line") + ): # TODO: Resolve source line to Colang file level raise ColangSyntaxError(error + f" on source line {element._source.line}") else: @@ -413,10 +418,15 @@ def _expand_match_element( for idx, element in enumerate(and_group["elements"]): new_elements.append(event_label_elements[idx]) + # Ensure element is valid for SpecOp + if isinstance(element, (dict, Spec)): + spec_element: Union[dict, Spec] = element + else: + spec_element = {} new_elements.append( SpecOp( op="match", - spec=element, + spec=spec_element, ) ) new_elements.append(goto_end_element) @@ -433,8 +443,8 @@ def _expand_match_element( else: # Multiple and-groups combined by or - fork_uid: str = new_var_uuid() - fork_element = ForkHead(fork_uid=fork_uid) + or_fork_uid: str = new_var_uuid() + fork_element = ForkHead(fork_uid=or_fork_uid) group_label_elements: List[Label] = [] failure_label_name = f"failure_label_{new_var_uuid()}" failure_label_element = Label(name=failure_label_name) @@ -463,12 +473,12 @@ def _expand_match_element( new_elements.append(failure_label_element) new_elements.append(WaitForHeads(number=len(or_group))) - new_elements.append(MergeHeads(fork_uid=fork_uid)) + new_elements.append(MergeHeads(fork_uid=or_fork_uid)) new_elements.append(CatchPatternFailure(label=None)) new_elements.append(Abort()) new_elements.append(end_label_element) - new_elements.append(MergeHeads(fork_uid=fork_uid)) + new_elements.append(MergeHeads(fork_uid=or_fork_uid)) new_elements.append(CatchPatternFailure(label=None)) else: diff --git a/nemoguardrails/colang/v2_x/lang/transformer.py b/nemoguardrails/colang/v2_x/lang/transformer.py index 74d77613f..8e75499c1 100644 --- a/nemoguardrails/colang/v2_x/lang/transformer.py +++ b/nemoguardrails/colang/v2_x/lang/transformer.py @@ -168,15 +168,17 @@ def _flow_def(self, children: dict, meta: Meta) -> Flow: assert member_name_el["_type"] == "var_name" member_name = member_name_el["elements"][0][1:] - member_def = FlowReturnMemberDef(name=member_name) + return_member_def_obj = FlowReturnMemberDef(name=member_name) # If we have a default value, we also use that if len(return_member_def["elements"]) == 2: default_value_el = return_member_def["elements"][1] assert default_value_el["_type"] == "expr" - member_def.default_value_expr = default_value_el["elements"][0] + return_member_def_obj.default_value_expr = default_value_el[ + "elements" + ][0] - return_member_defs.append(member_def) + return_member_defs.append(return_member_def_obj) elements[0:0] = [ SpecOp( @@ -546,7 +548,7 @@ def _non_var_spec_and(self, children: list, meta: Meta) -> dict: val["_source"] = self.__source(meta) return val - def __default__(self, data, children: list, meta: Meta) -> dict: + def __default__(self, data, children: list, meta: Meta) -> Any: """Default function that is called if there is no attribute matching ``data`` Can be overridden. Defaults to creating diff --git a/nemoguardrails/colang/v2_x/lang/utils.py b/nemoguardrails/colang/v2_x/lang/utils.py index 122024b6e..5bda97551 100644 --- a/nemoguardrails/colang/v2_x/lang/utils.py +++ b/nemoguardrails/colang/v2_x/lang/utils.py @@ -18,7 +18,7 @@ def dataclass_to_dict(obj: Any) -> Any: - if is_dataclass(obj): + if is_dataclass(obj) and not isinstance(obj, type): return {k: dataclass_to_dict(v) for k, v in asdict(obj).items()} elif isinstance(obj, list): return [dataclass_to_dict(v) for v in obj] diff --git a/nemoguardrails/colang/v2_x/runtime/eval.py b/nemoguardrails/colang/v2_x/runtime/eval.py index 47b3659e7..83d93cda2 100644 --- a/nemoguardrails/colang/v2_x/runtime/eval.py +++ b/nemoguardrails/colang/v2_x/runtime/eval.py @@ -202,7 +202,8 @@ def _regex_findall(pattern: str, string: str) -> List[str]: def _pretty_str(data: Any) -> str: if isinstance(data, (dict, list, set)): string = json.dumps(data, indent=4) - return SimplifyFormatter().format(string) + # SimplifyFormatter.format() accepts string as well as LogRecord + return str(SimplifyFormatter().format(string)) # type: ignore return str(data) @@ -245,27 +246,27 @@ def _get_type(val: Any) -> str: def _less_than_operator(v_ref: Any) -> ComparisonExpression: """Create less then comparison expression.""" - return ComparisonExpression(lambda val, v_ref=v_ref: val < v_ref, v_ref) + return ComparisonExpression(lambda val: val < v_ref, v_ref) def _equal_or_less_than_operator(v_ref: Any) -> ComparisonExpression: """Create equal or less than comparison expression.""" - return ComparisonExpression(lambda val, val_ref=v_ref: val <= val_ref, v_ref) + return ComparisonExpression(lambda val: val <= v_ref, v_ref) def _greater_than_operator(v_ref: Any) -> ComparisonExpression: """Create less then comparison expression.""" - return ComparisonExpression(lambda val, val_ref=v_ref: val > val_ref, v_ref) + return ComparisonExpression(lambda val: val > v_ref, v_ref) def _equal_or_greater_than_operator(v_ref: Any) -> ComparisonExpression: """Create equal or less than comparison expression.""" - return ComparisonExpression(lambda val, val_ref=v_ref: val >= val_ref, v_ref) + return ComparisonExpression(lambda val: val >= v_ref, v_ref) def _not_equal_to_operator(v_ref: Any) -> ComparisonExpression: """Create a not equal comparison expression.""" - return ComparisonExpression(lambda val, val_ref=v_ref: val != val_ref, v_ref) + return ComparisonExpression(lambda val: val != v_ref, v_ref) def _flows_info(state: State, flow_instance_uid: Optional[str] = None) -> dict: diff --git a/nemoguardrails/colang/v2_x/runtime/flows.py b/nemoguardrails/colang/v2_x/runtime/flows.py index 7cadcfd93..9923e61a2 100644 --- a/nemoguardrails/colang/v2_x/runtime/flows.py +++ b/nemoguardrails/colang/v2_x/runtime/flows.py @@ -23,7 +23,20 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Callable, Deque, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + List, + Optional, + Tuple, + Union, +) + +if TYPE_CHECKING: + from nemoguardrails.rails.llm.config import RailsConfig from dataclasses_json import dataclass_json diff --git a/nemoguardrails/logging/processing_log.py b/nemoguardrails/logging/processing_log.py index 54def219a..983ff0e7a 100644 --- a/nemoguardrails/logging/processing_log.py +++ b/nemoguardrails/logging/processing_log.py @@ -14,7 +14,7 @@ # limitations under the License. import contextvars -from typing import List +from typing import List, Optional from nemoguardrails.rails.llm.options import ( ActivatedRail, @@ -23,7 +23,9 @@ ) # The processing log for the current async stack -processing_log_var = contextvars.ContextVar("processing_log", default=None) +processing_log_var: contextvars.ContextVar[ + Optional[List[dict]] +] = contextvars.ContextVar("processing_log", default=None) def compute_generation_log(processing_log: List[dict]) -> GenerationLog: From 2299504d8353186883bf8b13f9e787883325ef3e Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 13:35:34 -0500 Subject: [PATCH 2/5] Revert "Dummy commit to set up the chore/type-clean-guardrails PR and branch" This reverts commit 71d00f083fb59bda34c82b82eea85602c1710265. --- nemoguardrails/actions/llm/generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 938f0d3d0..4a3f2c1a4 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -130,7 +130,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow) -> None: + def _extract_user_message_example(self, flow: Flow): """Heuristic to extract user message examples from a flow.""" elements = [item for item in flow.elements if item["_type"] != "doc_string_stmt" and item["_type"] != "stmt"] if len(elements) != 2: From f1045f58d6100c9bb280647afe0170f864dfed44 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Fri, 24 Oct 2025 12:02:26 -0500 Subject: [PATCH 3/5] Final cleanups --- .../colang/v1_0/lang/comd_parser.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/nemoguardrails/colang/v1_0/lang/comd_parser.py b/nemoguardrails/colang/v1_0/lang/comd_parser.py index 1f6d6709d..dd735a219 100644 --- a/nemoguardrails/colang/v1_0/lang/comd_parser.py +++ b/nemoguardrails/colang/v1_0/lang/comd_parser.py @@ -360,22 +360,24 @@ def parse_md_file(file_name, content=None): continue # Make sure we have the type of the symbol in the name of the symbol - if sym is not None: - sym = _get_typed_symbol_name(sym, symbol_type) + if not sym or not isinstance(sym, str): + raise ValueError(f"sym must be a non-empty string, found {sym}") + + sym = _get_typed_symbol_name(sym, symbol_type) - # For objects, we translate the "string" type to "kb:Object:prop|partial" - param_type = _get_param_type(parts[1]) - if symbol_type == "object" and param_type in ["string", "text"]: - object_name = split_max(sym, ":", 1)[1] - param_type = f"kb:{object_name}:{parts[0]}|partial" + # For objects, we translate the "string" type to "kb:Object:prop|partial" + param_type = _get_param_type(parts[1]) + if symbol_type == "object" and param_type in ["string", "text"]: + object_name = split_max(sym, ":", 1)[1] + param_type = f"kb:{object_name}:{parts[0]}|partial" - # TODO: figure out a cleaner way to deal with this - # For the "type:time" type, we transform it into "lookup:time" - if param_type == "type:time": - param_type = "lookup:time" + # TODO: figure out a cleaner way to deal with this + # For the "type:time" type, we transform it into "lookup:time" + if param_type == "type:time": + param_type = "lookup:time" - result["mappings"].append((f"{sym}:{parts[0]}", param_type)) - symbol_params.append(parts[0]) + result["mappings"].append((f"{sym}:{parts[0]}", param_type)) + symbol_params.append(parts[0]) elif line.startswith("-") or line.startswith("*"): if sym is None: From 84484f560f08a072db233a60de6966332437dce7 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:57:54 +0000 Subject: [PATCH 4/5] Format with ruff --- nemoguardrails/colang/v1_0/lang/coyml_parser.py | 16 ++++------------ nemoguardrails/colang/v2_x/lang/transformer.py | 4 +--- nemoguardrails/logging/processing_log.py | 6 +++--- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/nemoguardrails/colang/v1_0/lang/coyml_parser.py b/nemoguardrails/colang/v1_0/lang/coyml_parser.py index acc9893bb..c6977c1b4 100644 --- a/nemoguardrails/colang/v1_0/lang/coyml_parser.py +++ b/nemoguardrails/colang/v1_0/lang/coyml_parser.py @@ -420,12 +420,8 @@ def _extract_elements(items: List) -> List[dict]: # for `if` flow elements, we have to go recursively if element["_type"] == "if": if_element = element - then_items = ( - if_element["then"] if isinstance(if_element["then"], list) else [] - ) - else_items = ( - if_element["else"] if isinstance(if_element["else"], list) else [] - ) + then_items = if_element["then"] if isinstance(if_element["then"], list) else [] + else_items = if_element["else"] if isinstance(if_element["else"], list) else [] then_elements = _extract_elements(then_items) else_elements = _extract_elements(else_items) @@ -443,9 +439,7 @@ def _extract_elements(items: List) -> List[dict]: # if we have "else" elements, we need to adjust also add a jump if len(else_elements) > 0: - elements.append( - {"_type": "jump", "_next": str(len(else_elements) + 1)} - ) + elements.append({"_type": "jump", "_next": str(len(else_elements) + 1)}) if_element["_next_else"] = str(int(if_element["_next_else"]) + 1) # Add the "else" elements @@ -454,9 +448,7 @@ def _extract_elements(items: List) -> List[dict]: # WHILE elif element["_type"] == "while": while_element = element - do_items = ( - while_element["do"] if isinstance(while_element["do"], list) else [] - ) + do_items = while_element["do"] if isinstance(while_element["do"], list) else [] do_elements = _extract_elements(do_items) n = len(do_elements) diff --git a/nemoguardrails/colang/v2_x/lang/transformer.py b/nemoguardrails/colang/v2_x/lang/transformer.py index 8e75499c1..2c1968fec 100644 --- a/nemoguardrails/colang/v2_x/lang/transformer.py +++ b/nemoguardrails/colang/v2_x/lang/transformer.py @@ -174,9 +174,7 @@ def _flow_def(self, children: dict, meta: Meta) -> Flow: if len(return_member_def["elements"]) == 2: default_value_el = return_member_def["elements"][1] assert default_value_el["_type"] == "expr" - return_member_def_obj.default_value_expr = default_value_el[ - "elements" - ][0] + return_member_def_obj.default_value_expr = default_value_el["elements"][0] return_member_defs.append(return_member_def_obj) diff --git a/nemoguardrails/logging/processing_log.py b/nemoguardrails/logging/processing_log.py index 983ff0e7a..b23505110 100644 --- a/nemoguardrails/logging/processing_log.py +++ b/nemoguardrails/logging/processing_log.py @@ -23,9 +23,9 @@ ) # The processing log for the current async stack -processing_log_var: contextvars.ContextVar[ - Optional[List[dict]] -] = contextvars.ContextVar("processing_log", default=None) +processing_log_var: contextvars.ContextVar[Optional[List[dict]]] = contextvars.ContextVar( + "processing_log", default=None +) def compute_generation_log(processing_log: List[dict]) -> GenerationLog: From 1a0b696503bc9a3ce84874d6feb8736c3a49c7fa Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Fri, 12 Dec 2025 13:21:22 +0000 Subject: [PATCH 5/5] Fixed type errors --- nemoguardrails/colang/runtime.py | 13 +++--- .../colang/v1_0/lang/colang_parser.py | 43 ++++++++++++------- .../colang/v1_0/lang/coyml_parser.py | 14 +++--- nemoguardrails/colang/v1_0/lang/utils.py | 4 +- nemoguardrails/colang/v1_0/runtime/flows.py | 30 +++++++++---- nemoguardrails/colang/v1_0/runtime/runtime.py | 26 +++++++---- nemoguardrails/colang/v2_x/lang/parser.py | 6 ++- nemoguardrails/colang/v2_x/runtime/runtime.py | 32 ++++++++------ .../colang/v2_x/runtime/serialization.py | 5 ++- .../colang/v2_x/runtime/statemachine.py | 33 +++++++++++--- pyproject.toml | 1 + 11 files changed, 140 insertions(+), 67 deletions(-) diff --git a/nemoguardrails/colang/runtime.py b/nemoguardrails/colang/runtime.py index 377155cd7..1765922f3 100644 --- a/nemoguardrails/colang/runtime.py +++ b/nemoguardrails/colang/runtime.py @@ -32,28 +32,31 @@ def __init__(self, config: RailsConfig, verbose: bool = False): self.verbose = verbose # Register the actions with the dispatcher. + imported_paths = config.imported_paths if config.imported_paths else {} self.action_dispatcher = ActionDispatcher( config_path=config.config_path, - import_paths=list(config.imported_paths.values()), + import_paths=list(imported_paths.values()), ) if hasattr(self, "_run_output_rails_in_parallel_streaming"): self.action_dispatcher.register_action( - self._run_output_rails_in_parallel_streaming, + getattr(self, "_run_output_rails_in_parallel_streaming"), name="run_output_rails_in_parallel_streaming", ) if hasattr(self, "_run_flows_in_parallel"): - self.action_dispatcher.register_action(self._run_flows_in_parallel, name="run_flows_in_parallel") + self.action_dispatcher.register_action( + getattr(self, "_run_flows_in_parallel"), name="run_flows_in_parallel" + ) if hasattr(self, "_run_input_rails_in_parallel"): self.action_dispatcher.register_action( - self._run_input_rails_in_parallel, name="run_input_rails_in_parallel" + getattr(self, "_run_input_rails_in_parallel"), name="run_input_rails_in_parallel" ) if hasattr(self, "_run_output_rails_in_parallel"): self.action_dispatcher.register_action( - self._run_output_rails_in_parallel, name="run_output_rails_in_parallel" + getattr(self, "_run_output_rails_in_parallel"), name="run_output_rails_in_parallel" ) # The list of additional parameters that can be passed to the actions. diff --git a/nemoguardrails/colang/v1_0/lang/colang_parser.py b/nemoguardrails/colang/v1_0/lang/colang_parser.py index 255776276..4be068ac7 100644 --- a/nemoguardrails/colang/v1_0/lang/colang_parser.py +++ b/nemoguardrails/colang/v1_0/lang/colang_parser.py @@ -295,6 +295,7 @@ def _create_namespace(self, namespace): # Now, append the new one self.current_namespaces.append(namespace) self.current_namespace = ".".join(self.current_namespaces) + assert self.next_line is not None, "next_line must not be None when creating namespace" self.current_indentation = self.next_line["indentation"] self.current_indentations.append(self.next_line["indentation"]) @@ -318,7 +319,7 @@ def _include_source_mappings(self): # Include the source mapping information if required if self.include_source_mapping: if self.current_element and "_source_mapping" not in self.current_element: - self.current_element["_source_mapping"] = { + self.current_element["_source_mapping"] = { # type: ignore[assignment] "filename": self.filename, "line_number": self.current_line["number"], "line_text": self.current_line["text"], @@ -771,6 +772,10 @@ def _process_define(self): # Finally, we parse the markdown content self._extract_markdown() + def _insert_topic_flow_definition(self) -> None: + """Insert a topic flow definition. Currently not implemented.""" + raise NotImplementedError("Topic flow definitions are not yet implemented") + def _extract_indentation_levels(self): """Helper to extract the indentation levels higher than the current line.""" indentations = [] @@ -910,13 +915,14 @@ def _extract_params(self, param_lines: Optional[List] = None): yaml_value = {"$0": yaml_value} # self.current_element.update(yaml_value) - for k in yaml_value.keys(): - # if the key tarts with $, we remove it - param_name = k - if param_name[0] == "$": - param_name = param_name[1:] + if self.current_element is not None and isinstance(self.current_element, dict): + for k in yaml_value.keys(): + # if the key tarts with $, we remove it + param_name = k + if param_name[0] == "$": + param_name = param_name[1:] - self.current_element[param_name] = yaml_value[k] + self.current_element[param_name] = yaml_value[k] # type: ignore[assignment] def _is_test_flow(self): """Returns true if the current flow is a test one. @@ -956,6 +962,7 @@ def _is_sample_flow(self): def _parse_when(self): # TODO: deal with "when" after "else when" + assert self.next_line is not None, "Expected next line after 'when' statement." assert self.next_line["indentation"] > self.current_line["indentation"], ( "Expected indented block after 'when' statement." ) @@ -1280,6 +1287,7 @@ def _parse_bot(self): # Finally, decide what to include in the element if utterance_id is None: + assert utterance_text is not None, "utterance_text must not be None when utterance_id is None" self.current_element["bot"] = { "_type": "element", "text": utterance_text[1:-1], @@ -1298,11 +1306,12 @@ def _parse_bot(self): # If there was a bot message with a snippet, we also add an expect # TODO: can this be handled better? try: - if "snippet" in self.current_element["bot"]: + bot_element = self.current_element["bot"] + if isinstance(bot_element, dict) and "snippet" in bot_element: self.branches[-1]["elements"].append( { "expect": "snippet", - "snippet": self.current_element["bot"]["snippet"], + "snippet": bot_element["snippet"], } ) # noinspection PyBroadException @@ -1374,7 +1383,7 @@ def _parse_do(self): if return_vars: return_vars = get_stripped_tokens(return_vars.split(",")) return_vars = [_var[1:] if _var[0] == "$" else _var for _var in return_vars] - self.current_element["_return_vars"] = return_vars + self.current_element["_return_vars"] = return_vars # type: ignore[assignment] # Add to current branch self.branches[-1]["elements"].append(self.current_element) @@ -1477,6 +1486,7 @@ def _parse_if_branch(self, if_condition): self.current_element = {"if": if_condition, "then": []} self.branches[-1]["elements"].append(self.current_element) + assert self.next_line is not None, "next_line must not be None when parsing if branch" self.ifs.append( { "element": self.current_element, @@ -1519,6 +1529,7 @@ def _parse_while(self): self.current_element = {"while": while_condition, "do": []} self.branches[-1]["elements"].append(self.current_element) + assert self.next_line is not None, "next_line must not be None when parsing while" # Add a new branch for the then part self.branches.append( { @@ -1533,6 +1544,7 @@ def _parse_any(self): } self.branches[-1]["elements"].append(self.current_element) + assert self.next_line is not None, "next_line must not be None when parsing any" # Add a new branch for the then part self.branches.append( { @@ -1562,6 +1574,7 @@ def _parse_infer(self): } self.branches[-1]["elements"].append(self.current_element) + assert self.next_line is not None, "next_line must not be None when parsing infer" # Add a new branch for the then part self.branches.append( { @@ -1600,7 +1613,7 @@ def _parse_return(self): } if return_values: - self.current_element["_return_values"] = return_values + self.current_element["_return_values"] = return_values # type: ignore[assignment] self.branches[-1]["elements"].append(self.current_element) @@ -1697,15 +1710,15 @@ def parse(self): exception = Exception(error) # Decorate the exception with where the parsing failed - exception.filename = self.filename - exception.line = self.current_line["number"] - exception.error = str(ex) + exception.filename = self.filename # type: ignore[attr-defined] + exception.line = self.current_line["number"] # type: ignore[attr-defined] + exception.error = str(ex) # type: ignore[attr-defined] raise exception self.current_line_idx += 1 - result = {"flows": self.flows} + result: dict = {"flows": self.flows} if self.imports: result["imports"] = self.imports diff --git a/nemoguardrails/colang/v1_0/lang/coyml_parser.py b/nemoguardrails/colang/v1_0/lang/coyml_parser.py index c6977c1b4..2f61271b6 100644 --- a/nemoguardrails/colang/v1_0/lang/coyml_parser.py +++ b/nemoguardrails/colang/v1_0/lang/coyml_parser.py @@ -241,7 +241,7 @@ def _dict_to_element(d): elif d_type in ["break"]: element = {"_type": "break"} elif d_type in ["return"]: - element = {"_type": "jump", "_next": "-1", "_absolute": True} + element = {"_type": "jump", "_next": -1, "_absolute": True} # Include the return values information if "_return_values" in d: @@ -429,7 +429,7 @@ def _extract_elements(items: List) -> List[dict]: del if_element["then"] del if_element["else"] - if_element["_next_else"] = str(len(then_elements) + 1) + if_element["_next_else"] = len(then_elements) + 1 # type: ignore[arg-type] # Add the "if" elements.append(if_element) @@ -439,8 +439,8 @@ def _extract_elements(items: List) -> List[dict]: # if we have "else" elements, we need to adjust also add a jump if len(else_elements) > 0: - elements.append({"_type": "jump", "_next": str(len(else_elements) + 1)}) - if_element["_next_else"] = str(int(if_element["_next_else"]) + 1) + elements.append({"_type": "jump", "_next": len(else_elements) + 1}) # type: ignore[dict-item] + if_element["_next_else"] += 1 # type: ignore[arg-type, operator] # Add the "else" elements elements.extend(else_elements) @@ -456,14 +456,14 @@ def _extract_elements(items: List) -> List[dict]: del while_element["do"] # On break we have to skip n elements and 1 jump, hence we go to n+2 - while_element["_next_on_break"] = str(n + 2) + while_element["_next_on_break"] = n + 2 # type: ignore[arg-type] # We need to compute the jumps on break and on continue for each element for j in range(n): # however, we make sure we don't override an inner loop if "_next_on_break" not in do_elements[j]: - do_elements[j]["_next_on_break"] = str(n + 1 - j) - do_elements[j]["_next_on_continue"] = str(-1 * j - 1) + do_elements[j]["_next_on_break"] = n + 1 - j + do_elements[j]["_next_on_continue"] = -1 * j - 1 # Add the "while" elements.append(while_element) diff --git a/nemoguardrails/colang/v1_0/lang/utils.py b/nemoguardrails/colang/v1_0/lang/utils.py index 5a37be11e..8d40821b7 100644 --- a/nemoguardrails/colang/v1_0/lang/utils.py +++ b/nemoguardrails/colang/v1_0/lang/utils.py @@ -408,7 +408,7 @@ def parse_package_name(text): return package_name -def string_hash(s): +def string_hash(s: str) -> str: """A simple string hash with an equivalent implementation in javascript. module.exports.string_hash = function(s){ @@ -426,7 +426,7 @@ def string_hash(s): """ _hash = 0 if len(s) == 0: - return 0 + return "0" for i in range(len(s)): _char = ord(s[i]) _hash = ((_hash << 5) - _hash) + _char diff --git a/nemoguardrails/colang/v1_0/runtime/flows.py b/nemoguardrails/colang/v1_0/runtime/flows.py index c341cf879..26f864365 100644 --- a/nemoguardrails/colang/v1_0/runtime/flows.py +++ b/nemoguardrails/colang/v1_0/runtime/flows.py @@ -18,12 +18,15 @@ from dataclasses import dataclass, field from enum import Enum from time import time -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from nemoguardrails.colang.v1_0.runtime.eval import eval_expression from nemoguardrails.colang.v1_0.runtime.sliding import slide from nemoguardrails.utils import new_event_dict, new_uuid +if TYPE_CHECKING: + from nemoguardrails.rails.llm.config import RailsConfig + @dataclass class FlowConfig: @@ -91,7 +94,7 @@ class FlowState: status: FlowStatus = FlowStatus.ACTIVE # The UID of the flows that interrupted this one - interrupted_by = None + interrupted_by: Optional[str] = None @dataclass @@ -240,11 +243,11 @@ def _call_subflow(new_state: State, flow_state: FlowState) -> Optional[FlowState Optional[FlowState]: The state of the subflow, if applicable. """ flow_config = new_state.flow_configs[flow_state.flow_id] - subflow_id = flow_config.elements[flow_state.head]["flow_name"] + subflow_id: str = flow_config.elements[flow_state.head]["flow_name"] # Basic support for referring a subflow using a variable if subflow_id.startswith("$"): - subflow_id = eval_expression(subflow_id, new_state.context) + subflow_id = str(eval_expression(subflow_id, new_state.context)) # parameter support if _flow_id_has_params(subflow_id): @@ -300,7 +303,10 @@ def _slide_with_subflows(state: State, flow_state: FlowState) -> Optional[int]: should_continue = True while should_continue: should_continue = False - flow_state.head = slide(state, flow_config, flow_state.head) + slide_result = slide(state, flow_config, flow_state.head) + if slide_result is None: + break + flow_state.head = slide_result # We check if we reached a point where we need to call a subflow if flow_state.head >= 0: @@ -441,7 +447,10 @@ def compute_next_state(state: State, event: dict) -> State: continue # We try to slide first, just in case a flow starts with sliding logic - start_head = slide(new_state, flow_config, 0) + slide_result = slide(new_state, flow_config, 0) + if slide_result is None: + continue + start_head: int = slide_result # If the first element matches the current event, # or, if the flow is explicitly started by a `start_flow` event, @@ -488,7 +497,12 @@ def compute_next_state(state: State, event: dict) -> State: # If we have aborted flows, and the current flow is an extension, when we interrupt them. # We are only interested when the extension flow actually decided, not just started. - if decision_flow_config and decision_flow_config.is_extension and decision_flow_state.head > 1: + if ( + decision_flow_config + and decision_flow_state + and decision_flow_config.is_extension + and decision_flow_state.head > 1 + ): for flow_state in new_state.flow_states: if flow_state.status == FlowStatus.ABORTED and state.flow_configs[flow_state.flow_id].is_interruptible: flow_state.status = FlowStatus.INTERRUPTED @@ -665,7 +679,7 @@ def compute_context(history: List[dict]): Returns: dict: The computed context. """ - context = { + context: dict = { "last_user_message": None, "last_bot_message": None, } diff --git a/nemoguardrails/colang/v1_0/runtime/runtime.py b/nemoguardrails/colang/v1_0/runtime/runtime.py index 27081510a..704e050f6 100644 --- a/nemoguardrails/colang/v1_0/runtime/runtime.py +++ b/nemoguardrails/colang/v1_0/runtime/runtime.py @@ -439,10 +439,15 @@ def filter_and_append(logs, target_log): async def _run_input_rails_in_parallel(self, flows: List[str], events: List[dict]) -> ActionResult: """Run the input rails in parallel.""" - pre_events = [(await create_event({"_type": "StartInputRail", "flow_id": flow})).events[0] for flow in flows] - post_events = [ - (await create_event({"_type": "InputRailFinished", "flow_id": flow})).events[0] for flow in flows - ] + pre_events = [] + post_events = [] + for flow in flows: + pre_result = await create_event({"_type": "StartInputRail", "flow_id": flow}) + post_result = await create_event({"_type": "InputRailFinished", "flow_id": flow}) + if pre_result.events: + pre_events.append(pre_result.events[0]) + if post_result.events: + post_events.append(post_result.events[0]) return await self._run_flows_in_parallel( flows=flows, events=events, pre_events=pre_events, post_events=post_events @@ -450,10 +455,15 @@ async def _run_input_rails_in_parallel(self, flows: List[str], events: List[dict async def _run_output_rails_in_parallel(self, flows: List[str], events: List[dict]) -> ActionResult: """Run the output rails in parallel.""" - pre_events = [(await create_event({"_type": "StartOutputRail", "flow_id": flow})).events[0] for flow in flows] - post_events = [ - (await create_event({"_type": "OutputRailFinished", "flow_id": flow})).events[0] for flow in flows - ] + pre_events = [] + post_events = [] + for flow in flows: + pre_result = await create_event({"_type": "StartOutputRail", "flow_id": flow}) + post_result = await create_event({"_type": "OutputRailFinished", "flow_id": flow}) + if pre_result.events: + pre_events.append(pre_result.events[0]) + if post_result.events: + post_events.append(post_result.events[0]) return await self._run_flows_in_parallel( flows=flows, events=events, pre_events=pre_events, post_events=post_events diff --git a/nemoguardrails/colang/v2_x/lang/parser.py b/nemoguardrails/colang/v2_x/lang/parser.py index 270288414..a7941cbd1 100644 --- a/nemoguardrails/colang/v2_x/lang/parser.py +++ b/nemoguardrails/colang/v2_x/lang/parser.py @@ -17,8 +17,10 @@ import os import re import textwrap +from typing import Any import yaml +from lark import Tree from nemoguardrails.colang.v2_x.lang.colang_ast import Flow, Import from nemoguardrails.colang.v2_x.lang.grammar.load import load_lark_parser @@ -38,7 +40,7 @@ def __init__(self, include_source_mapping: bool = False): # Initialize the Lark Parser self._lark_parser = load_lark_parser(self.grammar_path) - def get_parsing_tree(self, content: str) -> dict: + def get_parsing_tree(self, content: str) -> Tree[Any]: """Helper to get only the parsing tree. Args: @@ -133,7 +135,7 @@ def parse_content(self, content: str, print_tokens: bool = False, print_parsing_ import_el: Import = element if import_el.path: result["import_paths"].append(import_el.path) - else: + elif import_el.package: # If we have a package name, we need to translate it to a path result["import_paths"].append(os.path.join(*import_el.package.split("."))) diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py index 6980714bc..0aecfbaae 100644 --- a/nemoguardrails/colang/v2_x/runtime/runtime.py +++ b/nemoguardrails/colang/v2_x/runtime/runtime.py @@ -24,7 +24,7 @@ from nemoguardrails.actions.actions import ActionResult from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.runtime import Runtime -from nemoguardrails.colang.v2_x.lang.colang_ast import Decorator, Flow +from nemoguardrails.colang.v2_x.lang.colang_ast import Decorator, ElementType, Flow from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message from nemoguardrails.colang.v2_x.runtime.errors import ( ColangRuntimeError, @@ -138,9 +138,13 @@ async def _remove_flows_action(self, state: "State", **args: dict) -> None: def _init_flow_configs(self) -> None: """Initializes the flow configs based on the config.""" - self.flow_configs = create_flow_configs_from_flow_list(self.config.flows) + # Type assertion: config.flows contains Flow objects at runtime + from nemoguardrails.colang.v2_x.lang.colang_ast import Flow - async def generate_events(self, events: List[dict]) -> List[dict]: + flows = [f for f in self.config.flows if isinstance(f, Flow)] + self.flow_configs = create_flow_configs_from_flow_list(flows) + + async def generate_events(self, events: List[dict], processing_log: Optional[List[dict]] = None) -> List[dict]: raise NotImplementedError("Stateless API not supported for Colang 2.x, yet.") @staticmethod @@ -167,7 +171,7 @@ async def _process_start_action( action_name: str, action_params: dict, context: dict, - events: List[dict], + events: List[Union[dict, Event]], state: "State", ) -> Tuple[Any, List[dict], dict]: """Starts the specified action, waits for it to finish and posts back the result.""" @@ -268,7 +272,9 @@ async def _get_action_resp( # Call the Actions Server if it is available. # But not for system actions, those should still run locally. if action_meta.get("is_system_action", False) or self.config.actions_server_url is None: - result, status = await self.action_dispatcher.execute_action(action_name, kwargs) + action_result, action_status = await self.action_dispatcher.execute_action(action_name, kwargs) + result = action_result if action_result is not None else result + status = action_status else: url = urljoin(self.config.actions_server_url, "/v1/actions/run") # action server execute action path data = {"action_name": action_name, "action_parameters": kwargs} @@ -280,11 +286,9 @@ async def _get_action_resp( f"Got status code {resp.status} while getting response from {action_name}" ) - resp = await resp.json() - result, status = ( - resp.get("result", result), - resp.get("status", status), - ) + resp_json = await resp.json() + result = resp_json.get("result") or result + status = resp_json.get("status") or status except Exception as e: log.info("Exception %s while making request to %s", e, action_name) return result, status @@ -342,6 +346,8 @@ async def _get_async_actions_finished_events(self, main_flow_uid: str) -> Tuple[ "Local action finished with an exception!", exc_info=True, ) + self.async_actions[main_flow_uid].remove(finished_task) + continue self.async_actions[main_flow_uid].remove(finished_task) @@ -387,7 +393,7 @@ async def process_events( """ output_events = [] - input_events: List[Union[dict, InternalEvent]] = events.copy() + input_events: List[Union[dict, InternalEvent]] = list(events) local_running_actions: List[asyncio.Task[dict]] = [] if state is None or state == {}: @@ -668,9 +674,11 @@ def create_flow_configs_from_flow_list(flows: List[Flow]) -> Dict[str, FlowConfi ]: raise ColangSyntaxError(f"Flow '{flow.name}' starts with a keyword!") + # Cast elements to ElementType list for type compatibility + elements: List[ElementType] = list(flow.elements) config = FlowConfig( id=flow.name, - elements=flow.elements, + elements=elements, decorators=convert_decorator_list_to_dictionary(flow.decorators), parameters=flow.parameters, return_members=flow.return_members, diff --git a/nemoguardrails/colang/v2_x/runtime/serialization.py b/nemoguardrails/colang/v2_x/runtime/serialization.py index e924ab435..322299ff9 100644 --- a/nemoguardrails/colang/v2_x/runtime/serialization.py +++ b/nemoguardrails/colang/v2_x/runtime/serialization.py @@ -211,7 +211,10 @@ def state_to_json(state: State, indent: bool = False): def json_to_state(s: str) -> State: """Helper to decode a State object from a JSON string.""" data = json.loads(s) - state = decode_from_dict(data, refs={}) + decoded = decode_from_dict(data, refs={}) + if not isinstance(decoded, State): + raise ValueError(f"Expected State object, got {type(decoded)}") + state: State = decoded # Redo the callbacks. for flow_uid, flow_state in state.flow_states.items(): diff --git a/nemoguardrails/colang/v2_x/runtime/statemachine.py b/nemoguardrails/colang/v2_x/runtime/statemachine.py index 838081b6a..f2f6f35d1 100644 --- a/nemoguardrails/colang/v2_x/runtime/statemachine.py +++ b/nemoguardrails/colang/v2_x/runtime/statemachine.py @@ -29,6 +29,7 @@ Break, CatchPatternFailure, Continue, + Element, ElementType, EndScope, ForkHead, @@ -87,13 +88,17 @@ def initialize_state(state: State) -> None: state.flow_states = dict() + current_flow_config: Optional[FlowConfig] = None try: # TODO: Think about where to put this for flow_config in state.flow_configs.values(): + current_flow_config = flow_config initialize_flow(state, flow_config) except Exception as e: - if e.args[0]: - raise ColangSyntaxError(e.args[0] + f" in flow `{flow_config.id}` ({flow_config.source_file})") + if e.args[0] and current_flow_config: + raise ColangSyntaxError( + e.args[0] + f" in flow `{current_flow_config.id}` ({current_flow_config.source_file})" + ) else: raise ColangSyntaxError() from e @@ -119,7 +124,7 @@ def initialize_flow(state: State, flow_config: FlowConfig) -> None: # Extract all the label elements for idx, element in enumerate(flow_config.elements): if isinstance(element, Label): - flow_config.element_labels.update({element["name"]: idx}) + flow_config.element_labels.update({element.name: idx}) def create_flow_instance( @@ -872,7 +877,7 @@ def _advance_head_front(state: State, heads: List[FlowHead]) -> List[FlowHead]: # In case there were any runtime error the flow will be aborted (fail) source_line = "unknown" element = flow_config.elements[head.position] - if hasattr(element, "_source") and element._source: + if isinstance(element, Element) and element._source: source_line = str(element._source.line) log.warning( "Flow '%s' failed on line %s (%s) due to Colang runtime exception: %s", @@ -1955,6 +1960,8 @@ def get_event_name_from_element(state: State, flow_state: FlowState, element: Sp if element_spec.members is None: raise ColangValueError("Missing event attributes!") event_name = member["name"] + if event_name is None: + raise ColangValueError("Event name is required") event = obj.get_event(event_name, {}) return event.name else: @@ -1967,6 +1974,8 @@ def get_event_name_from_element(state: State, flow_state: FlowState, element: Sp flow_config = state.flow_configs[element_spec.name] temp_flow_state = create_flow_instance(flow_config, "", "", {}) flow_event_name = element_spec.members[0]["name"] + if flow_event_name is None: + raise ColangValueError("Flow event name is required") flow_event: InternalEvent = temp_flow_state.get_event(flow_event_name, {}) del flow_event.arguments["source_flow_instance_uid"] del flow_event.arguments["flow_instance_uid"] @@ -1976,6 +1985,8 @@ def get_event_name_from_element(state: State, flow_state: FlowState, element: Sp assert element_spec.name action = Action(element_spec.name, {}, flow_state.flow_id) event_name = element_spec.members[0]["name"] + if event_name is None: + raise ColangValueError("Action event name is required") action_event: ActionEvent = action.get_event(event_name, {}) return action_event.name else: @@ -2029,7 +2040,9 @@ def get_event_from_element(state: State, flow_state: FlowState, element: SpecOp) if element_spec.members is None: raise ColangValueError("Missing event attributes!") event_name = member["name"] - event_arguments = member["arguments"] + if event_name is None: + raise ColangValueError("Event name is required") + event_arguments = member["arguments"] or {} event_arguments = _evaluate_arguments(event_arguments, _get_eval_context(state, flow_state)) event = obj.get_event(event_name, event_arguments) @@ -2051,8 +2064,12 @@ def get_event_from_element(state: State, flow_state: FlowState, element: SpecOp) flow_config = state.flow_configs[element_spec.name] temp_flow_state = create_flow_instance(flow_config, "", "", {}) flow_event_name = element_spec.members[0]["name"] + if flow_event_name is None: + raise ColangValueError("Flow event name is required") flow_event_arguments = element_spec.arguments - flow_event_arguments.update(element_spec.members[0]["arguments"]) + member_arguments = element_spec.members[0]["arguments"] + if member_arguments: + flow_event_arguments.update(member_arguments) flow_event_arguments = _evaluate_arguments(flow_event_arguments, _get_eval_context(state, flow_state)) flow_event: InternalEvent = temp_flow_state.get_event(flow_event_name, flow_event_arguments) del flow_event.arguments["source_flow_instance_uid"] @@ -2067,7 +2084,9 @@ def get_event_from_element(state: State, flow_state: FlowState, element: SpecOp) action = Action(element_spec.name, action_arguments, flow_state.flow_id) # TODO: refactor the following repetition of code (see above) event_name = element_spec.members[0]["name"] - event_arguments = element_spec.members[0]["arguments"] + if event_name is None: + raise ColangValueError("Action event name is required") + event_arguments = element_spec.members[0]["arguments"] or {} event_arguments = _evaluate_arguments(event_arguments, _get_eval_context(state, flow_state)) action_event: ActionEvent = action.get_event(event_name, event_arguments) if element["op"] == "match": diff --git a/pyproject.toml b/pyproject.toml index ef6bf9218..2069f8ac7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ include = [ "nemoguardrails/server/**", "tests/test_callbacks.py", "nemoguardrails/benchmark/**", + "nemoguardrails/colang/**", ] exclude = [ "nemoguardrails/llm/providers/trtllm/**",