From 3dc8b3b2de95a95b7ac61351a6e8e6767fba0104 Mon Sep 17 00:00:00 2001 From: Ming-Jer Lee Date: Wed, 4 Feb 2026 19:41:11 -0800 Subject: [PATCH 1/3] fix: Address mechanical review findings (security, performance, code hygiene) - Fix Jinja2 SSTI vulnerability using SandboxedEnvironment (Item 1) - Remove unused cloudpickle dependency (Item 2) - Replace print() with structured logging, strip emoji (Item 3) - Replace list.pop(0) with deque.popleft() for O(1) BFS (Item 4) - Add edge adjacency indices for O(1) graph lookups (Item 5) - Narrow except Exception to specific types at 17 locations (Item 6) --- pyproject.toml | 1 - src/clgraph/__init__.py | 4 +- src/clgraph/agent.py | 7 +- src/clgraph/column.py | 26 +++++-- src/clgraph/execution.py | 71 +++++++------------ src/clgraph/lineage_builder.py | 13 ++-- src/clgraph/mcp/server.py | 4 ++ src/clgraph/models.py | 15 ++-- src/clgraph/multi_query.py | 30 ++++---- src/clgraph/pipeline.py | 109 +++++++++++++++-------------- src/clgraph/table.py | 2 +- src/clgraph/tools/sql.py | 6 +- src/clgraph/visualizations.py | 7 +- tests/test_validation_framework.py | 16 ++--- uv.lock | 11 --- 15 files changed, 164 insertions(+), 158 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2474913..ecf2f9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "sqlglot>=28.0.0", "graphviz>=0.20.0", "jinja2>=3.0.0", - "cloudpickle>=3.1.2", ] [project.optional-dependencies] diff --git a/src/clgraph/__init__.py b/src/clgraph/__init__.py index 29c3b07..197ab28 100644 --- a/src/clgraph/__init__.py +++ b/src/clgraph/__init__.py @@ -5,10 +5,10 @@ """ try: - from importlib.metadata import version + from importlib.metadata import PackageNotFoundError, version __version__ = version("clgraph") -except Exception: +except PackageNotFoundError: # Fallback when package metadata is not available (e.g., mounted as volume) __version__ = "dev" diff --git a/src/clgraph/agent.py b/src/clgraph/agent.py index 3744cc3..bb944d9 100644 --- a/src/clgraph/agent.py +++ b/src/clgraph/agent.py @@ -20,6 +20,7 @@ print(result.data["sql"]) """ +import logging import re from dataclasses import dataclass from enum import Enum @@ -30,6 +31,8 @@ if TYPE_CHECKING: from .pipeline import Pipeline +logger = logging.getLogger(__name__) + class QuestionType(Enum): """Types of questions the agent can handle.""" @@ -199,8 +202,7 @@ def query(self, question: str) -> AgentResult: # Classify question question_type = self._classify_question(question) - if self.verbose: - print(f"Question type: {question_type}") + logger.debug("Question type: %s", question_type) # Route to appropriate handler try: @@ -225,6 +227,7 @@ def query(self, question: str) -> AgentResult: else: return self._handle_general(question) except Exception as e: + logger.error("Agent query failed: %s", e, exc_info=True) return AgentResult( answer=f"Sorry, I encountered an error: {e}", question_type=question_type, diff --git a/src/clgraph/column.py b/src/clgraph/column.py index a584e83..c3d09fa 100644 --- a/src/clgraph/column.py +++ b/src/clgraph/column.py @@ -6,8 +6,9 @@ """ import logging +from collections import deque from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set from .models import ( ColumnEdge, @@ -64,8 +65,9 @@ def generate_description(column: ColumnNode, llm: Any, pipeline: "Pipeline"): column.description = response.content.strip() column.description_source = DescriptionSource.GENERATED - 
except Exception: - # Fallback to simple rule-based description + except (ImportError, ValueError, AttributeError, RuntimeError) as e: + # Fallback to simple rule-based description if LLM fails + logger.debug("LLM description generation failed: %s", e) _generate_fallback_description(column) @@ -233,6 +235,11 @@ class PipelineLineageGraph: edges: List[ColumnEdge] = field(default_factory=list) issues: List[ValidationIssue] = field(default_factory=list) # Validation issues + # Adjacency indices: full_name -> list of edges + _outgoing_index: Dict[str, List[ColumnEdge]] = field(default_factory=dict, repr=False) + _incoming_index: Dict[str, List[ColumnEdge]] = field(default_factory=dict, repr=False) + _column_deps_cache: Optional[Dict[str, Set[str]]] = field(default=None, repr=False) + def add_column(self, column: ColumnNode) -> ColumnNode: """Add a column node to the graph""" self.columns[column.full_name] = column @@ -241,6 +248,9 @@ def add_column(self, column: ColumnNode) -> ColumnNode: def add_edge(self, edge: ColumnEdge): """Add a lineage edge""" self.edges.append(edge) + self._outgoing_index.setdefault(edge.from_node.full_name, []).append(edge) + self._incoming_index.setdefault(edge.to_node.full_name, []).append(edge) + self._column_deps_cache = None # Invalidate cache def add_issue(self, issue: ValidationIssue): """Add a validation issue and log it""" @@ -263,9 +273,14 @@ def _build_column_dependencies(self) -> Dict[str, Set[str]]: Build dependency map: column_full_name -> set of column_full_names it depends on. This is the column-level equivalent of TableDependencyGraph._build_table_dependencies. + Returns cached result when available; invalidated by add_edge(). + Returns: Dict mapping column full_name to set of upstream column full_names """ + if self._column_deps_cache is not None: + return self._column_deps_cache + deps: Dict[str, Set[str]] = {} for full_name in self.columns: @@ -278,6 +293,7 @@ def _build_column_dependencies(self) -> Dict[str, Set[str]]: if to_name in deps: deps[to_name].add(from_name) + self._column_deps_cache = deps return deps def get_upstream(self, full_name: str) -> List[ColumnNode]: @@ -398,10 +414,10 @@ def to_simplified(self) -> "PipelineLineageGraph": for col_name, col in table_columns.items(): # BFS backward to find all reachable table columns visited: Set[str] = set() - queue = [col_name] + queue = deque([col_name]) while queue: - current = queue.pop(0) + current = queue.popleft() if current in visited: continue visited.add(current) diff --git a/src/clgraph/execution.py b/src/clgraph/execution.py index a897b8a..52aba83 100644 --- a/src/clgraph/execution.py +++ b/src/clgraph/execution.py @@ -6,10 +6,13 @@ """ import asyncio +import logging import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from .pipeline import Pipeline @@ -118,9 +121,7 @@ def execute_sql(sql: str): result = executor.run(execute_sql, max_workers=4) print(f"Completed {len(result['completed'])} queries") """ - if verbose: - print(f"šŸš€ Starting pipeline execution ({len(self.table_graph.queries)} queries)") - print() + logger.info("Starting pipeline execution (%d queries)", len(self.table_graph.queries)) # Track completed queries completed = set() @@ -132,8 +133,7 @@ def execute_sql(sql: str): # Execute level by level for level_num, level_queries in enumerate(levels, 1): - if verbose: - print(f"šŸ“Š Level {level_num}: 
{len(level_queries)} queries") + logger.info("Level %d: %d queries", level_num, len(level_queries)) # Execute queries in this level concurrently with ThreadPoolExecutor(max_workers=max_workers) as pool: @@ -151,31 +151,21 @@ def execute_sql(sql: str): try: future.result() completed.add(query_id) - - if verbose: - print(f" āœ… {query_id}") + logger.info("Completed: %s", query_id) except Exception as e: failed.append((query_id, str(e))) - - if verbose: - print(f" āŒ {query_id}: {e}") - - if verbose: - print() + logger.debug("Query %s execution failed", query_id, exc_info=True) + logger.warning("Failed: %s: %s", query_id, e) elapsed = time.time() - start_time # Summary - if verbose: - print("=" * 60) - print(f"āœ… Pipeline completed in {elapsed:.2f}s") - print(f" Successful: {len(completed)}") - print(f" Failed: {len(failed)}") - if failed: - print("\nāš ļø Failed queries:") - for query_id, error in failed: - print(f" - {query_id}: {error}") - print("=" * 60) + logger.info("Pipeline completed in %.2fs", elapsed) + logger.info("Successful: %d", len(completed)) + logger.info("Failed: %d", len(failed)) + if failed: + for query_id, error in failed: + logger.warning("Failed query - %s: %s", query_id, error) return { "completed": list(completed), @@ -214,9 +204,7 @@ async def execute_sql(sql: str): result = await executor.async_run(execute_sql, max_workers=4) print(f"Completed {len(result['completed'])} queries") """ - if verbose: - print(f"šŸš€ Starting async pipeline execution ({len(self.table_graph.queries)} queries)") - print() + logger.info("Starting async pipeline execution (%d queries)", len(self.table_graph.queries)) # Track completed queries completed = set() @@ -231,8 +219,7 @@ async def execute_sql(sql: str): # Execute level by level for level_num, level_queries in enumerate(levels, 1): - if verbose: - print(f"šŸ“Š Level {level_num}: {len(level_queries)} queries") + logger.info("Level %d: %d queries", level_num, len(level_queries)) async def execute_with_semaphore(query_id: str, sql: str): """Execute query with semaphore for concurrency control""" @@ -240,12 +227,11 @@ async def execute_with_semaphore(query_id: str, sql: str): try: await executor(sql) completed.add(query_id) - if verbose: - print(f" āœ… {query_id}") + logger.info("Completed: %s", query_id) except Exception as e: failed.append((query_id, str(e))) - if verbose: - print(f" āŒ {query_id}: {e}") + logger.debug("Async query %s execution failed", query_id, exc_info=True) + logger.warning("Failed: %s: %s", query_id, e) # Execute queries in this level concurrently tasks = [] @@ -257,22 +243,15 @@ async def execute_with_semaphore(query_id: str, sql: str): # Wait for all tasks in this level to complete await asyncio.gather(*tasks) - if verbose: - print() - elapsed = time.time() - start_time # Summary - if verbose: - print("=" * 60) - print(f"āœ… Pipeline completed in {elapsed:.2f}s") - print(f" Successful: {len(completed)}") - print(f" Failed: {len(failed)}") - if failed: - print("\nāš ļø Failed queries:") - for query_id, error in failed: - print(f" - {query_id}: {error}") - print("=" * 60) + logger.info("Pipeline completed in %.2fs", elapsed) + logger.info("Successful: %d", len(completed)) + logger.info("Failed: %d", len(failed)) + if failed: + for query_id, error in failed: + logger.warning("Failed query - %s: %s", query_id, error) return { "completed": list(completed), diff --git a/src/clgraph/lineage_builder.py b/src/clgraph/lineage_builder.py index 086f6cc..d3a5b58 100644 --- a/src/clgraph/lineage_builder.py +++ 
b/src/clgraph/lineage_builder.py @@ -5,6 +5,7 @@ Includes SQLColumnTracer wrapper for backward compatibility. """ +from collections import deque from typing import Any, Dict, List, Optional, Set, Tuple, TypedDict import sqlglot @@ -558,7 +559,7 @@ def _qualify_sql_with_schema( # Return the qualified SQL return qualified.sql(dialect=dialect) - except Exception: + except (sqlglot.errors.SqlglotError, KeyError, ValueError, TypeError): # If qualification fails, return original SQL # The lineage builder will handle unqualified columns as before return sql_query @@ -1505,7 +1506,7 @@ def _extract_columns_from_expr( ) col_name = col.name result.append((table_ref, col_name)) - except Exception: + except (sqlglot.errors.SqlglotError, ValueError, TypeError): # If parsing fails, try simple extraction for "table.column" format if "." in expr_str: parts = expr_str.split(".") @@ -3219,10 +3220,10 @@ def get_forward_lineage(self, input_columns: List[str]) -> Dict[str, Any]: # BFS forward from each start node for start_node in start_nodes: visited = set() - queue = [(start_node, [start_node.full_name], [])] + queue = deque([(start_node, [start_node.full_name], [])]) while queue: - current, path, transformations = queue.pop(0) + current, path, transformations = queue.popleft() if current.full_name in visited: continue @@ -3290,10 +3291,10 @@ def get_backward_lineage(self, output_columns: List[str]) -> BackwardLineageResu # BFS backward from each start node for start_node in start_nodes: visited = set() - queue = [(start_node, [start_node.full_name], [])] + queue = deque([(start_node, [start_node.full_name], [])]) while queue: - current, path, transformations = queue.pop(0) + current, path, transformations = queue.popleft() if current.full_name in visited: continue diff --git a/src/clgraph/mcp/server.py b/src/clgraph/mcp/server.py index 226c61c..d7379f0 100644 --- a/src/clgraph/mcp/server.py +++ b/src/clgraph/mcp/server.py @@ -8,6 +8,7 @@ import argparse import asyncio import json +import logging from typing import Any, Dict, List from ..pipeline import Pipeline @@ -25,6 +26,8 @@ MCP_AVAILABLE = False Server = None # type: ignore +logger = logging.getLogger(__name__) + def _convert_param_type(param_type: ParameterType) -> str: """Convert ParameterType to JSON Schema type.""" @@ -146,6 +149,7 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: return [TextContent(type="text", text=json.dumps(content, indent=2, default=str))] except Exception as e: + logger.error("MCP tool execution failed: %s", e, exc_info=True) error_content = { "success": False, "error": f"Tool execution failed: {str(e)}", diff --git a/src/clgraph/models.py b/src/clgraph/models.py index 7b84552..1d70d25 100644 --- a/src/clgraph/models.py +++ b/src/clgraph/models.py @@ -9,6 +9,7 @@ """ import logging +from collections import deque from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple @@ -653,6 +654,10 @@ class ColumnLineageGraph: warnings: List[str] = field(default_factory=list) # Legacy validation warnings (deprecated) issues: List[ValidationIssue] = field(default_factory=list) # Structured validation issues + # Adjacency indices: full_name -> list of edges + _outgoing_index: Dict[str, List[ColumnEdge]] = field(default_factory=dict, repr=False) + _incoming_index: Dict[str, List[ColumnEdge]] = field(default_factory=dict, repr=False) + def add_node(self, node: ColumnNode): """Add a column node to the graph""" 
self.nodes[node.full_name] = node @@ -668,6 +673,8 @@ def add_edge(self, edge: ColumnEdge): # Add edge if not duplicate if edge not in self.edges: self.edges.append(edge) + self._outgoing_index.setdefault(edge.from_node.full_name, []).append(edge) + self._incoming_index.setdefault(edge.to_node.full_name, []).append(edge) def add_warning(self, warning: str): """Add a validation warning (deprecated - use add_issue instead)""" @@ -700,11 +707,11 @@ def get_output_nodes(self) -> List[ColumnNode]: def get_edges_from(self, node: ColumnNode) -> List[ColumnEdge]: """Get all edges originating from a node""" - return [e for e in self.edges if e.from_node == node] + return self._outgoing_index.get(node.full_name, []) def get_edges_to(self, node: ColumnNode) -> List[ColumnEdge]: """Get all edges pointing to a node""" - return [e for e in self.edges if e.to_node == node] + return self._incoming_index.get(node.full_name, []) def to_simplified(self) -> "ColumnLineageGraph": """ @@ -738,10 +745,10 @@ def to_simplified(self) -> "ColumnLineageGraph": for output_node in output_nodes: # BFS/DFS backward to find all reachable input nodes visited: Set[str] = set() - queue = [output_node.full_name] + queue = deque([output_node.full_name]) while queue: - current = queue.pop(0) + current = queue.popleft() if current in visited: continue visited.add(current) diff --git a/src/clgraph/multi_query.py b/src/clgraph/multi_query.py index 7bf6e3f..8e7bde2 100644 --- a/src/clgraph/multi_query.py +++ b/src/clgraph/multi_query.py @@ -93,23 +93,27 @@ def resolve(self, context: Dict): def _resolve_template(self, template: str, context: Dict) -> str: """Resolve a single template with context""" - # Try Jinja2 resolution + # Try Jinja2 resolution using sandboxed environment try: - from jinja2 import Template as JinjaTemplate # type: ignore[import-untyped] + from jinja2.sandbox import SandboxedEnvironment # type: ignore[import-untyped] - jinja_template = JinjaTemplate(template) + env = SandboxedEnvironment() + jinja_template = env.from_string(template) return jinja_template.render(**context) + except ImportError: + return self._resolve_fstring_template(template, context) except Exception: - # If Jinja2 fails, try f-string style - try: - # Simple variable substitution - for key, value in context.items(): - template = template.replace(f"{{{key}}}", str(value)) - template = template.replace(f"{{{{ {key} }}}}", str(value)) - return template - except Exception: - # If all fails, return original - return template + return self._resolve_fstring_template(template, context) + + def _resolve_fstring_template(self, template: str, context: Dict) -> str: + """Resolve template using simple f-string style substitution""" + try: + for key, value in context.items(): + template = template.replace(f"{{{key}}}", str(value)) + template = template.replace(f"{{{{ {key} }}}}", str(value)) + return template + except (KeyError, ValueError): + return template class MultiQueryParser: diff --git a/src/clgraph/pipeline.py b/src/clgraph/pipeline.py index 629157c..71f825d 100644 --- a/src/clgraph/pipeline.py +++ b/src/clgraph/pipeline.py @@ -9,9 +9,12 @@ - Airflow DAG generation """ +import logging +from collections import deque from datetime import datetime from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union +import sqlglot.errors from sqlglot import exp from .column import ( @@ -33,6 +36,8 @@ ) from .table import TableDependencyGraph +logger = logging.getLogger(__name__) + class PipelineLineageBuilder: """ @@ -106,13 +111,11 @@ def 
build(self, pipeline_or_graph) -> "Pipeline": self._add_query_edges(pipeline, query, query_lineage) else: # No SELECT to analyze (e.g., UPDATE without SELECT) - print(f"Info: Skipping lineage for {query_id} (no SELECT statement)") - except Exception as e: - # If lineage fails, skip this query - print(f"Warning: Failed to build lineage for {query_id}: {e}") - import traceback - - traceback.print_exc() + logger.info("Skipping lineage for %s (no SELECT statement)", query_id) + except (sqlglot.errors.SqlglotError, KeyError, ValueError, TypeError) as e: + # If lineage fails due to SQL parsing or data issues, skip this query + logger.warning("Failed to build lineage for %s: %s", query_id, e) + logger.debug("Traceback for %s lineage failure", query_id, exc_info=True) continue # Step 3: Add cross-query edges @@ -854,6 +857,14 @@ def edges(self) -> List[ColumnEdge]: """Access edges through column_graph for backward compatibility""" return self.column_graph.edges + def _get_incoming_edges(self, full_name: str) -> List[ColumnEdge]: + """Get incoming edges for a column using adjacency index.""" + return self.column_graph._incoming_index.get(full_name, []) + + def _get_outgoing_edges(self, full_name: str) -> List[ColumnEdge]: + """Get outgoing edges for a column using adjacency index.""" + return self.column_graph._outgoing_index.get(full_name, []) + def get_column( self, table_name: str, column_name: str, query_id: Optional[str] = None ) -> Optional[ColumnNode]: @@ -1114,7 +1125,7 @@ def _generate_query_id(sql: str, dialect: str, id_counts: Dict[str, int]) -> str id_counts[base_id] += 1 return f"{base_id}_{id_counts[base_id]}" - except Exception: + except (sqlglot.errors.SqlglotError, KeyError, AttributeError): # Fallback if parsing fails base_id = "query" if base_id not in id_counts: @@ -1335,17 +1346,17 @@ def trace_column_backward(self, table_name: str, column_name: str) -> List[Colum # BFS backward through edges visited = set() - queue = list(start_columns) + queue = deque(start_columns) sources = [] while queue: - current = queue.pop(0) + current = queue.popleft() if current.full_name in visited: continue visited.add(current.full_name) # Find incoming edges - incoming = [e for e in self.edges if e.to_node.full_name == current.full_name] + incoming = self._get_incoming_edges(current.full_name) if not incoming: # No incoming edges = source column @@ -1402,12 +1413,12 @@ def trace_column_backward_full( # BFS backward through edges, collecting all nodes and edges visited = set() - queue = list(start_columns) + queue = deque(start_columns) all_nodes = [] all_edges = [] while queue: - current = queue.pop(0) + current = queue.popleft() if current.full_name in visited: continue visited.add(current.full_name) @@ -1415,7 +1426,7 @@ def trace_column_backward_full( # Optionally skip CTE columns if not include_ctes and current.layer == "cte": # Still need to traverse through CTEs to find real tables - incoming = [e for e in self.edges if e.to_node.full_name == current.full_name] + incoming = self._get_incoming_edges(current.full_name) for edge in incoming: queue.append(edge.from_node) continue @@ -1423,7 +1434,7 @@ def trace_column_backward_full( all_nodes.append(current) # Find incoming edges - incoming = [e for e in self.edges if e.to_node.full_name == current.full_name] + incoming = self._get_incoming_edges(current.full_name) for edge in incoming: all_edges.append(edge) @@ -1489,17 +1500,17 @@ def trace_column_forward(self, table_name: str, column_name: str) -> List[Column # BFS forward through edges 
visited = set() - queue = list(start_columns) + queue = deque(start_columns) descendants = [] while queue: - current = queue.pop(0) + current = queue.popleft() if current.full_name in visited: continue visited.add(current.full_name) # Find outgoing edges - outgoing = [e for e in self.edges if e.from_node.full_name == current.full_name] + outgoing = self._get_outgoing_edges(current.full_name) if not outgoing: # No outgoing edges = final column @@ -1551,12 +1562,12 @@ def trace_column_forward_full( # BFS forward through edges, collecting all nodes and edges visited = set() - queue = list(start_columns) + queue = deque(start_columns) all_nodes = [] all_edges = [] while queue: - current = queue.pop(0) + current = queue.popleft() if current.full_name in visited: continue visited.add(current.full_name) @@ -1564,7 +1575,7 @@ def trace_column_forward_full( # Optionally skip CTE columns if not include_ctes and current.layer == "cte": # Still need to traverse through CTEs to find real tables - outgoing = [e for e in self.edges if e.from_node.full_name == current.full_name] + outgoing = self._get_outgoing_edges(current.full_name) for edge in outgoing: queue.append(edge.to_node) continue @@ -1572,7 +1583,7 @@ def trace_column_forward_full( all_nodes.append(current) # Find outgoing edges - outgoing = [e for e in self.edges if e.from_node.full_name == current.full_name] + outgoing = self._get_outgoing_edges(current.full_name) for edge in outgoing: all_edges.append(edge) @@ -1641,11 +1652,11 @@ def get_lineage_path( to_full_names = {col.full_name for col in to_columns} # BFS with path tracking, starting from all matching source columns - queue = [(col, []) for col in from_columns] + queue = deque((col, []) for col in from_columns) visited = set() while queue: - current, path = queue.pop(0) + current, path = queue.popleft() if current.full_name in visited: continue visited.add(current.full_name) @@ -1654,9 +1665,8 @@ def get_lineage_path( return path # Find outgoing edges - for edge in self.edges: - if edge.from_node.full_name == current.full_name: - queue.append((edge.to_node, path + [edge])) + for edge in self._get_outgoing_edges(current.full_name): + queue.append((edge.to_node, path + [edge])) return [] # No path found @@ -1688,18 +1698,16 @@ def generate_all_descriptions(self, batch_size: int = 10, verbose: bool = True): ): columns_to_process.append(col) - if verbose: - print(f"šŸ“Š Generating descriptions for {len(columns_to_process)} columns...") + logger.info("Generating descriptions for %d columns...", len(columns_to_process)) # Process columns for i, col in enumerate(columns_to_process): - if verbose and (i + 1) % batch_size == 0: - print(f" Processed {i + 1}/{len(columns_to_process)} columns...") + if (i + 1) % batch_size == 0: + logger.info("Processed %d/%d columns...", i + 1, len(columns_to_process)) generate_description(col, self.llm, self) - if verbose: - print(f"āœ… Done! Generated {len(columns_to_process)} descriptions") + logger.info("Done! Generated %d descriptions", len(columns_to_process)) def propagate_all_metadata(self, verbose: bool = True): """ @@ -1722,11 +1730,10 @@ def propagate_all_metadata(self, verbose: bool = True): # This handles metadata set via SQL comments on output columns output_columns = [col for col in self.columns.values() if col.layer == "output"] - if verbose: - print( - f"šŸ“Š Pass 1: Propagating metadata backward from " - f"{len(output_columns)} output columns..." 
- ) + logger.info( + "Pass 1: Propagating metadata backward from %d output columns...", + len(output_columns), + ) for col in output_columns: propagate_metadata_backward(col, self) @@ -1744,17 +1751,16 @@ def propagate_all_metadata(self, verbose: bool = True): if col.table_name == target_table and col.is_computed(): columns_to_process.append(col) - if verbose: - print( - f"šŸ“Š Pass 2: Propagating metadata forward for {len(columns_to_process)} columns..." - ) + logger.info( + "Pass 2: Propagating metadata forward for %d columns...", + len(columns_to_process), + ) # Process columns for col in columns_to_process: propagate_metadata(col, self) - if verbose: - print(f"āœ… Done! Propagated metadata for {len(columns_to_process)} columns") + logger.info("Done! Propagated metadata for %d columns", len(columns_to_process)) def get_pii_columns(self) -> List[ColumnNode]: """ @@ -2021,10 +2027,10 @@ def split(self, sinks: List) -> List["Pipeline"]: # Find all queries needed for this sink visited = set() - queue = [sink_table] + queue = deque([sink_table]) while queue: - current_table = queue.pop(0) + current_table = queue.popleft() if current_table in visited: continue visited.add(current_table) @@ -2762,7 +2768,7 @@ def print_issues(self, severity: Optional[str | IssueSeverity] = None): issues = self.get_issues(severity=severity) if severity else self.get_all_issues() if not issues: - print("āœ… No validation issues found!") + logger.info("No validation issues found") return # Group by severity @@ -2772,18 +2778,15 @@ def print_issues(self, severity: Optional[str | IssueSeverity] = None): for issue in issues: by_severity[issue.severity.value].append(issue) - # Print by severity (errors first, then warnings, then info) + # Log by severity (errors first, then warnings, then info) for sev in ["error", "warning", "info"]: if sev not in by_severity: continue issues_list = by_severity[sev] - icon = {"error": "āŒ", "warning": "āš ļø", "info": "ā„¹ļø"}[sev] - print(f"\n{icon} {sev.upper()} ({len(issues_list)})") - print("=" * 80) - + logger.info("%s (%d)", sev.upper(), len(issues_list)) for issue in issues_list: - print(f"\n{issue}") + logger.info("%s", issue) __all__ = [ diff --git a/src/clgraph/table.py b/src/clgraph/table.py index 214d48f..d06350b 100644 --- a/src/clgraph/table.py +++ b/src/clgraph/table.py @@ -86,7 +86,7 @@ def generate_description(self, llm, lineage_graph): response = chain.invoke({}) self.description = response.content.strip() - except Exception: + except (ImportError, ValueError, AttributeError, RuntimeError): # Fallback to simple rule-based description self._generate_fallback_description() diff --git a/src/clgraph/tools/sql.py b/src/clgraph/tools/sql.py index 8bc0792..d46d620 100644 --- a/src/clgraph/tools/sql.py +++ b/src/clgraph/tools/sql.py @@ -142,7 +142,7 @@ def run( return self._generate_two_stage(question, include_explanation) else: return self._generate_direct(question, include_explanation) - except Exception as e: + except (ImportError, ValueError, AttributeError, RuntimeError) as e: return ToolResult.error_result(f"SQL generation failed: {e}") def _generate_direct(self, question: str, include_explanation: bool) -> ToolResult: @@ -267,7 +267,7 @@ def _select_tables(self, question: str, builder: ContextBuilder) -> List[str]: valid = [t for t in selected if t in self.pipeline.table_graph.tables] if valid: return valid - except Exception: + except (ImportError, ValueError, AttributeError, RuntimeError): pass # Fallback to keyword selection @@ -429,7 +429,7 @@ def run(self, 
sql: str, detail_level: str = "normal") -> ToolResult: }, message=explanation[:200] + "..." if len(explanation) > 200 else explanation, ) - except Exception as e: + except (ImportError, ValueError, AttributeError, RuntimeError) as e: return ToolResult.error_result(f"Failed to explain query: {e}") def _extract_tables(self, sql: str) -> List[str]: diff --git a/src/clgraph/visualizations.py b/src/clgraph/visualizations.py index f56f95d..67c8edf 100644 --- a/src/clgraph/visualizations.py +++ b/src/clgraph/visualizations.py @@ -5,6 +5,7 @@ No business logic - just presentation layer. """ +from collections import deque from typing import TYPE_CHECKING, List, Tuple, Union import graphviz @@ -622,11 +623,11 @@ def visualize_column_path( # Traverse backward to find all dependencies using BFS visited = set() - to_visit = [target_node] + to_visit = deque([target_node]) relevant_nodes = [] while to_visit: - node = to_visit.pop(0) + node = to_visit.popleft() if node.full_name in visited: continue @@ -808,7 +809,7 @@ def visualize_table_dependencies_with_levels( for level_num, level_queries in enumerate(levels, 1): for qid in level_queries: query_to_level[qid] = level_num - except Exception: + except (RuntimeError, KeyError, AttributeError): query_to_level = {} # Get source and final tables diff --git a/tests/test_validation_framework.py b/tests/test_validation_framework.py index 1cd8726..fca53b5 100644 --- a/tests/test_validation_framework.py +++ b/tests/test_validation_framework.py @@ -358,7 +358,7 @@ def test_has_warnings(self): class TestValidationReporting: """Test validation reporting and output formatting.""" - def test_print_issues_no_crash(self, capsys): + def test_print_issues_no_crash(self, caplog): """Test that print_issues() doesn't crash.""" queries = [ ( @@ -372,12 +372,12 @@ def test_print_issues_no_crash(self, capsys): pipeline = Pipeline(queries, dialect="bigquery") # Should not crash - pipeline.print_issues() + with caplog.at_level("INFO", logger="clgraph.pipeline"): + pipeline.print_issues() - captured = capsys.readouterr() - assert "ERROR" in captured.out or "āŒ" in captured.out + assert "ERROR" in caplog.text - def test_print_issues_severity_filter(self, capsys): + def test_print_issues_severity_filter(self, caplog): """Test printing issues filtered by severity.""" queries = [ ("bad", "CREATE TABLE r AS SELECT * FROM t1, t2"), @@ -386,11 +386,11 @@ def test_print_issues_severity_filter(self, capsys): pipeline = Pipeline(queries, dialect="bigquery") # Print only errors - pipeline.print_issues(severity=IssueSeverity.ERROR) + with caplog.at_level("INFO", logger="clgraph.pipeline"): + pipeline.print_issues(severity=IssueSeverity.ERROR) - captured = capsys.readouterr() # Should show ERROR but not INFO - assert "ERROR" in captured.out or "āŒ" in captured.out + assert "ERROR" in caplog.text class TestValidationIntegration: diff --git a/uv.lock b/uv.lock index eacb6d2..8c0851f 100644 --- a/uv.lock +++ b/uv.lock @@ -743,7 +743,6 @@ name = "clgraph" version = "0.0.3" source = { editable = "." 
} dependencies = [ - { name = "cloudpickle" }, { name = "graphviz" }, { name = "jinja2" }, { name = "sqlglot" }, @@ -805,7 +804,6 @@ requires-dist = [ { name = "apache-airflow", marker = "extra == 'airflow'", specifier = ">=2.7.0,<3.0.0" }, { name = "build", marker = "extra == 'build'", specifier = ">=1.0.0" }, { name = "clgraph", extras = ["llm", "mcp", "airflow"], marker = "extra == 'all'" }, - { name = "cloudpickle", specifier = ">=3.1.2" }, { name = "duckdb", marker = "extra == 'dev'", specifier = ">=0.9.0" }, { name = "graphviz", specifier = ">=0.20.0" }, { name = "graphviz", marker = "extra == 'examples'", specifier = ">=0.20.0" }, @@ -861,15 +859,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7a/7e/c08007d3fb2bbefb430437a3573373590abedc03566b785d7d6763b22480/clickclick-20.10.2-py2.py3-none-any.whl", hash = "sha256:c8f33e6d9ec83f68416dd2136a7950125bd256ec39ccc9a85c6e280a16be2bb5", size = 7368, upload-time = "2020-10-03T13:36:49.842Z" }, ] -[[package]] -name = "cloudpickle" -version = "3.1.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, -] - [[package]] name = "colorama" version = "0.4.6" From 2d7b8531130b1db16970812af52ccdc0810f57b2 Mon Sep 17 00:00:00 2001 From: Ming-Jer Lee Date: Wed, 4 Feb 2026 19:57:36 -0800 Subject: [PATCH 2/3] fix: Resolve LSP violation in BaseTool.run() signature Add *args to BaseTool.run() so subclasses can define specific parameters without violating the Liskov Substitution Principle. --- src/clgraph/tools/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/clgraph/tools/base.py b/src/clgraph/tools/base.py index 9d8622b..638bda1 100644 --- a/src/clgraph/tools/base.py +++ b/src/clgraph/tools/base.py @@ -175,7 +175,7 @@ def parameters(self) -> Dict[str, ParameterSpec]: pass @abstractmethod - def run(self, **kwargs) -> ToolResult: + def run(self, *args, **kwargs) -> ToolResult: """ Execute the tool with given parameters. 
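[Reviewer illustration — not part of the patch series.] A minimal sketch of the override pattern PATCH 2/3 enables, assuming the stripped-down ToolResult below (the real BaseTool and ToolResult in src/clgraph/tools/base.py carry more fields, and the subclass name here is assumed; its signature mirrors the run() shown in the tools/sql.py hunk of PATCH 1/3). Once the abstract run() accepts *args/**kwargs, a subclass may declare narrow, explicit parameters without the type checker flagging an incompatible override.

    # Hedged sketch only: simplified types for illustration, not the library's API.
    from abc import ABC, abstractmethod
    from dataclasses import dataclass

    @dataclass
    class ToolResult:
        success: bool
        message: str = ""

    class BaseTool(ABC):
        @abstractmethod
        def run(self, *args, **kwargs) -> ToolResult:
            """Execute the tool with given parameters."""
            ...

    class ExplainQueryTool(BaseTool):
        # Explicit parameters are generally accepted by checkers here because the
        # base signature (*args, **kwargs) is treated as allowing any call shape.
        def run(self, sql: str, detail_level: str = "normal") -> ToolResult:
            return ToolResult(success=True, message=f"explained {len(sql)}-char query ({detail_level})")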
From 1efee1791845f27109ff43eb8c1a424d083a32cb Mon Sep 17 00:00:00 2001 From: Ming-Jer Lee Date: Wed, 4 Feb 2026 20:13:54 -0800 Subject: [PATCH 3/3] fix: Resolve all ty type checker diagnostics - agent.py: Use lambda for sorted key to avoid list[Sized] inference - lineage_builder.py: Add type: ignore for union-attr, ensure str types - mcp/server.py: Add type: ignore for optional mcp imports, assert Server availability, annotate schema dict - multi_query.py: Remove unused type: ignore comment - orchestrators: Add type: ignore for optional dagster/prefect imports - query_parser.py: Use str() to ensure alias types after guard blocks - tools/governance.py: Suppress mixed-type dict append with type: ignore - visualizations.py: Replace hasattr with isinstance for type narrowing --- src/clgraph/agent.py | 6 ++++-- src/clgraph/lineage_builder.py | 8 ++++---- src/clgraph/mcp/server.py | 13 +++++++------ src/clgraph/multi_query.py | 2 +- src/clgraph/orchestrators/dagster.py | 4 ++-- src/clgraph/orchestrators/prefect.py | 6 ++++-- src/clgraph/query_parser.py | 6 ++++++ src/clgraph/tools/governance.py | 17 ++++++++--------- src/clgraph/visualizations.py | 10 ++++------ 9 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/clgraph/agent.py b/src/clgraph/agent.py index bb944d9..c959205 100644 --- a/src/clgraph/agent.py +++ b/src/clgraph/agent.py @@ -270,7 +270,9 @@ def _extract_table_column(self, question: str) -> tuple: """Extract table and column references from question.""" # First, try to match against known table names (handles schema.table.column) # Sort by length descending to match longer table names first - known_tables = sorted(self.pipeline.table_graph.tables.keys(), key=len, reverse=True) + known_tables = sorted( + self.pipeline.table_graph.tables.keys(), key=lambda t: len(t), reverse=True + ) for table_name in known_tables: # Check for table.column pattern @@ -567,7 +569,7 @@ def _handle_sql_explain(self, question: str) -> AgentResult: result = self.registry.run("explain_query", sql=sql) return AgentResult( - answer=result.message if result.success else result.error, + answer=(result.message if result.success else result.error) or "", question_type=QuestionType.SQL_EXPLAIN, tool_used="explain_query", tool_result=result, diff --git a/src/clgraph/lineage_builder.py b/src/clgraph/lineage_builder.py index d3a5b58..2baca54 100644 --- a/src/clgraph/lineage_builder.py +++ b/src/clgraph/lineage_builder.py @@ -983,7 +983,7 @@ def _create_lateral_correlation_edges(self, unit: QueryUnit): For each correlated column (reference to outer table), create an edge showing the correlation relationship. 
""" - lateral_alias = unit.name + lateral_alias = unit.name or "" for correlated_col in unit.correlated_columns: # Parse table.column format @@ -2651,7 +2651,7 @@ def _parse_aggregate_spec(self, ast_node: Optional[exp.Expression]) -> Optional[ # Extract ORDER BY within aggregate (fallback for standard syntax) if not order_by and hasattr(agg_func, "order") and agg_func.order: - for order_expr in agg_func.order.expressions: + for order_expr in agg_func.order.expressions: # type: ignore[union-attr] col_name = "" direction = "asc" nulls = None @@ -2724,7 +2724,7 @@ def _get_aggregate_func_name(self, node: exp.Expression) -> str: elif isinstance(node, exp.Max): return "MAX" elif hasattr(node, "sql_name"): - return node.sql_name().upper() + return node.sql_name().upper() # type: ignore[union-attr] elif hasattr(node, "name") and node.name: return node.name.upper() return "AGGREGATE" @@ -3097,7 +3097,7 @@ def _validate_qualified_columns_in_joins( category=IssueCategory.UNQUALIFIED_COLUMN, message=( f"Unqualified column '{col_name}' in expression for '{output_col_name}'. " - f"With multiple tables ({', '.join(available_tables)}), " + f"With multiple tables ({', '.join(str(t) for t in available_tables)}), " f"the source table is ambiguous." ), query_id=self.query_id, diff --git a/src/clgraph/mcp/server.py b/src/clgraph/mcp/server.py index d7379f0..b06055f 100644 --- a/src/clgraph/mcp/server.py +++ b/src/clgraph/mcp/server.py @@ -17,14 +17,14 @@ # MCP SDK imports - these are optional dependencies try: - from mcp.server import Server - from mcp.server.stdio import stdio_server - from mcp.types import Resource, TextContent, Tool + from mcp.server import Server # type: ignore[unresolved-import] + from mcp.server.stdio import stdio_server # type: ignore[unresolved-import] + from mcp.types import Resource, TextContent, Tool # type: ignore[unresolved-import] MCP_AVAILABLE = True except ImportError: MCP_AVAILABLE = False - Server = None # type: ignore + Server = None logger = logging.getLogger(__name__) @@ -107,6 +107,7 @@ def create_mcp_server( registry.register_all(BASIC_TOOLS) # Create server + assert Server is not None, "MCP SDK should be available at this point" server = Server("clgraph-lineage") # Register tool list handler @@ -206,7 +207,7 @@ async def read_resource(uri: str) -> str: def _get_full_schema(pipeline: Pipeline) -> str: """Get full pipeline schema as JSON.""" - schema = { + schema: Dict[str, Any] = { "dialect": pipeline.dialect, "tables": {}, } @@ -321,7 +322,7 @@ async def run_mcp_server_async( server = create_mcp_server(pipeline, llm, include_llm_tools) async with stdio_server() as (read_stream, write_stream): - await server.run(read_stream, write_stream) + await server.run(read_stream, write_stream) # type: ignore[possibly-missing-attribute] def run_mcp_server( diff --git a/src/clgraph/multi_query.py b/src/clgraph/multi_query.py index 8e7bde2..071170b 100644 --- a/src/clgraph/multi_query.py +++ b/src/clgraph/multi_query.py @@ -95,7 +95,7 @@ def _resolve_template(self, template: str, context: Dict) -> str: """Resolve a single template with context""" # Try Jinja2 resolution using sandboxed environment try: - from jinja2.sandbox import SandboxedEnvironment # type: ignore[import-untyped] + from jinja2.sandbox import SandboxedEnvironment env = SandboxedEnvironment() jinja_template = env.from_string(template) diff --git a/src/clgraph/orchestrators/dagster.py b/src/clgraph/orchestrators/dagster.py index 545a9a1..41a8833 100644 --- a/src/clgraph/orchestrators/dagster.py +++ 
b/src/clgraph/orchestrators/dagster.py @@ -89,7 +89,7 @@ def to_assets( - Deployment: Drop the definitions.py file in your Dagster workspace """ try: - import dagster as dg + import dagster as dg # type: ignore[unresolved-import] except ImportError as e: raise ImportError( "Dagster is required for asset generation. " @@ -226,7 +226,7 @@ def to_job( - Deployment: Drop the definitions.py file in your Dagster workspace """ try: - import dagster as dg + import dagster as dg # type: ignore[unresolved-import] except ImportError as e: raise ImportError( "Dagster is required for job generation. " diff --git a/src/clgraph/orchestrators/prefect.py b/src/clgraph/orchestrators/prefect.py index d75f4ab..366cb75 100644 --- a/src/clgraph/orchestrators/prefect.py +++ b/src/clgraph/orchestrators/prefect.py @@ -89,7 +89,7 @@ def to_flow( - Use to_deployment() for scheduled execution """ try: - from prefect import flow, task + from prefect import flow, task # type: ignore[unresolved-import] except ImportError as e: raise ImportError( "Prefect is required for flow generation. " @@ -212,7 +212,9 @@ def to_deployment( - Use work_pool_name to specify execution environment """ try: - from prefect import flow as flow_decorator # noqa: F401 + from prefect import ( # type: ignore[unresolved-import] + flow as flow_decorator, # noqa: F401 + ) except ImportError as e: raise ImportError( "Prefect is required for deployment generation. " diff --git a/src/clgraph/query_parser.py b/src/clgraph/query_parser.py index 543a29b..ab463af 100644 --- a/src/clgraph/query_parser.py +++ b/src/clgraph/query_parser.py @@ -863,6 +863,9 @@ def process_lateral_flatten(lateral_node: exp.Lateral, parent_unit: QueryUnit): flatten_alias = f"_flatten_{self.subquery_counter}" self.subquery_counter += 1 + # Ensure str type for type checker (flatten_alias is guaranteed non-empty) + flatten_alias = str(flatten_alias) + # Store FLATTEN info parent_unit.unnest_sources[flatten_alias] = { "source_table": source_table, @@ -914,6 +917,9 @@ def process_lateral_subquery( lateral_alias = f"_lateral_{self.subquery_counter}" self.subquery_counter += 1 + # Ensure str type for type checker (lateral_alias is guaranteed non-empty) + lateral_alias = str(lateral_alias) + # Find all column references in the subquery correlated_columns: List[str] = [] for col in subquery.find_all(exp.Column): diff --git a/src/clgraph/tools/governance.py b/src/clgraph/tools/governance.py index 5372ed3..2916f14 100644 --- a/src/clgraph/tools/governance.py +++ b/src/clgraph/tools/governance.py @@ -59,15 +59,14 @@ def run(self, table: Optional[str] = None, include_lineage: bool = False) -> Too key = f"{impact.table_name}.{impact.column_name}" if key not in derived_pii: derived_pii.add(key) - pii_columns.append( - { - "table": impact.table_name, - "column": impact.column_name, - "description": impact.description, - "owner": impact.owner, - "derived_from_pii": True, - } - ) + derived_entry = { + "table": impact.table_name, + "column": impact.column_name, + "description": impact.description, + "owner": impact.owner, + "derived_from_pii": True, + } + pii_columns.append(derived_entry) # type: ignore[invalid-argument-type] if not pii_columns: msg = "No PII columns found" diff --git a/src/clgraph/visualizations.py b/src/clgraph/visualizations.py index 67c8edf..e00b8e3 100644 --- a/src/clgraph/visualizations.py +++ b/src/clgraph/visualizations.py @@ -606,12 +606,11 @@ def visualize_column_path( # Determine graph type and get appropriate accessors # PipelineLineageGraph uses .columns, 
ColumnLineageGraph uses .nodes - if hasattr(graph, "columns"): + if isinstance(graph, ColumnLineageGraph): + nodes_dict = graph.nodes + else: # PipelineLineageGraph nodes_dict = graph.columns - else: - # ColumnLineageGraph - nodes_dict = graph.nodes # Find target node if target_column not in nodes_dict: @@ -635,8 +634,7 @@ def visualize_column_path( relevant_nodes.append(node) # Get incoming edges - different methods for different graph types - if hasattr(graph, "get_edges_to"): - # ColumnLineageGraph has get_edges_to method + if isinstance(graph, ColumnLineageGraph): incoming = graph.get_edges_to(node) else: # PipelineLineageGraph - filter edges manually