From d8e1c544e6126da0d979f12e923da78a60f7e580 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sun, 8 Jun 2025 01:02:02 +0000 Subject: [PATCH 01/28] refactor: turn pod into a subpackage --- src/orcabridge/pod/__init__.py | 7 +++++++ src/orcabridge/{ => pod}/pod.py | 0 2 files changed, 7 insertions(+) create mode 100644 src/orcabridge/pod/__init__.py rename src/orcabridge/{ => pod}/pod.py (100%) diff --git a/src/orcabridge/pod/__init__.py b/src/orcabridge/pod/__init__.py new file mode 100644 index 0000000..b58e438 --- /dev/null +++ b/src/orcabridge/pod/__init__.py @@ -0,0 +1,7 @@ +from .pod import Pod, FunctionPod, function_pod + +__all__ = [ + "Pod", + "FunctionPod", + "function_pod", +] diff --git a/src/orcabridge/pod.py b/src/orcabridge/pod/pod.py similarity index 100% rename from src/orcabridge/pod.py rename to src/orcabridge/pod/pod.py From d5bd32a1598683c17df659bbdf24008cc41de9a0 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 10 Jun 2025 01:36:24 +0000 Subject: [PATCH 02/28] chore: ignore notebook starting with underscore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 251b5d3..a5279a0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,10 @@ notebooks/**/*.parquet notebooks/**/*.pkl notebooks/**/*.db + +# Ignore any notebook that starts with an underscore +notebooks/**/_*.ipynb + # Ignore vscode settings .vscode/ From 6f516e20f1f15c109608fd114f335013b5b023cd Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Tue, 10 Jun 2025 01:36:59 +0000 Subject: [PATCH 03/28] refactor: turn types into a subpackage --- src/orcabridge/types.py | 37 ------------------- src/orcabridge/types/__init__.py | 63 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 37 deletions(-) delete mode 100644 src/orcabridge/types.py create mode 100644 src/orcabridge/types/__init__.py diff --git a/src/orcabridge/types.py b/src/orcabridge/types.py deleted file mode 100644 index 51a0284..0000000 --- a/src/orcabridge/types.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from collections.abc import Collection, Mapping -from typing import Protocol - -from typing_extensions import TypeAlias - -# Convenience alias for anything pathlike -PathLike = str | os.PathLike - -# an (optional) string or a collection of (optional) string values -# Note that TagValue can be nested, allowing for an arbitrary depth of nested lists -TagValue: TypeAlias = str | None | Collection["TagValue"] - - -# the top level tag is a mapping from string keys to values that can be a string or -# an arbitrary depth of nested list of strings or None -Tag: TypeAlias = Mapping[str, TagValue] - - -# a pathset is a path or an arbitrary depth of nested list of paths -PathSet: TypeAlias = PathLike | Collection[PathLike | None] - -# a packet is a mapping from string keys to pathsets -Packet: TypeAlias = Mapping[str, PathSet] - -# a batch is a tuple of a tag and a list of packets -Batch: TypeAlias = tuple[Tag, Collection[Packet]] - - -class PodFunction(Protocol): - """ - A function suitable to be used in a FunctionPod. - It takes one or more named arguments, each corresponding to a path to a file or directory, - and returns a path or a list of paths - """ - - def __call__(self, **kwargs: PathSet) -> None | PathSet | list[PathSet]: ... 
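The next diff recreates `types` as a package and broadens `Packet` values from path-only `PathSet`s to general `DataValue`s. As a rough illustration of what that change means for callers (a sketch with made-up keys and file names, not part of the patch itself):

    from pathlib import Path

    # Under the old types.py, a Packet could only carry path-like values.
    old_packet = {
        "recording": Path("data/session1/recording.bin"),   # a single path
        "events": [Path("data/session1/ev1.csv"), None],    # nested paths / None allowed
    }

    # Under the new types/__init__.py, DataValue also admits simple scalars.
    new_packet = {
        "recording": Path("data/session1/recording.bin"),   # still a PathSet
        "threshold": 0.75,                                   # float now allowed
        "label": "baseline",                                 # str now allowed
    }

    # A PodFunction-style callable under the new protocol: named DataValue
    # arguments in; None, a single value, or a list of values out.
    def summarize(recording, threshold):
        return [str(recording), threshold * 2]
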
diff --git a/src/orcabridge/types/__init__.py b/src/orcabridge/types/__init__.py new file mode 100644 index 0000000..dbf0101 --- /dev/null +++ b/src/orcabridge/types/__init__.py @@ -0,0 +1,63 @@ +# src/orcabridge/types.py +import os +from collections.abc import Collection, Mapping +from pathlib import Path +from typing import Any, Protocol +from typing_extensions import TypeAlias + +import polars as pl + +# Convenience alias for anything pathlike +PathLike = str | os.PathLike + +# an (optional) string or a collection of (optional) string values +# Note that TagValue can be nested, allowing for an arbitrary depth of nested lists +TagValue: TypeAlias = str | None | Collection["TagValue"] + +# the top level tag is a mapping from string keys to values that can be a string or +# an arbitrary depth of nested list of strings or None +Tag: TypeAlias = Mapping[str, TagValue] + +# a pathset is a path or an arbitrary depth of nested list of paths +PathSet: TypeAlias = PathLike | Collection[PathLike | None] + +# Simple data types that we support (with clear Polars correspondence) +SimpleDataValue: TypeAlias = str | int | float | bool | bytes + +# Extended data types that can be stored in packets +# Either the original PathSet or one of our supported simple data types +DataValue: TypeAlias = PathSet | SimpleDataValue + +# Data type specifications - only support Python types and Polars types for simplicity +DataType: TypeAlias = ( + type[str] + | type[int] + | type[float] + | type[bool] + | type[bytes] + | type[Path] + | type[list] # this needs to be validated specifically at runtime + | type[pl.DataType] +) + + +# a packet is a mapping from string keys to data values +Packet: TypeAlias = Mapping[str, DataValue] + +# a batch is a tuple of a tag and a list of packets +Batch: TypeAlias = tuple[Tag, Collection[Packet]] + +# Type specification for function inputs/outputs +TypeSpec: TypeAlias = dict[str, DataType] + + +class PodFunction(Protocol): + """ + A function suitable to be used in a FunctionPod. + It takes one or more named arguments, each corresponding to either: + - A path to a file or directory (PathSet) - for backward compatibility + - A simple data value (str, int, float, bool, bytes, Path) + and returns either None, a single value, or a list of values + """ + + def __call__(self, **kwargs: DataValue) -> None | DataValue | list[DataValue]: ... From 0f191beaafa826369354bc255ccc7577bbd0cbbb Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Fri, 13 Jun 2025 02:51:22 +0000 Subject: [PATCH 04/28] feat: add type extraction from function --- src/orcabridge/types/__init__.py | 26 +- src/orcabridge/types/core.py | 49 ++ src/orcabridge/types/default.py | 23 + src/orcabridge/types/handlers.py | 114 ++++ src/orcabridge/types/inference.py | 173 +++++ src/orcabridge/types/registry.py | 618 ++++++++++++++++++ tests/test_types/__init__.py | 1 + tests/test_types/test_inference/__init__.py | 1 + .../test_extract_function_data_types.py | 391 +++++++++++ 9 files changed, 1377 insertions(+), 19 deletions(-) create mode 100644 src/orcabridge/types/core.py create mode 100644 src/orcabridge/types/default.py create mode 100644 src/orcabridge/types/handlers.py create mode 100644 src/orcabridge/types/inference.py create mode 100644 src/orcabridge/types/registry.py create mode 100644 tests/test_types/__init__.py create mode 100644 tests/test_types/test_inference/__init__.py create mode 100644 tests/test_types/test_inference/test_extract_function_data_types.py diff --git a/src/orcabridge/types/__init__.py b/src/orcabridge/types/__init__.py index dbf0101..0ec194d 100644 --- a/src/orcabridge/types/__init__.py +++ b/src/orcabridge/types/__init__.py @@ -5,7 +5,8 @@ from typing import Any, Protocol from typing_extensions import TypeAlias -import polars as pl + +SUPPORTED_PYTHON_TYPES = (str, int, float, bool, bytes) # Convenience alias for anything pathlike PathLike = str | os.PathLike @@ -22,23 +23,13 @@ PathSet: TypeAlias = PathLike | Collection[PathLike | None] # Simple data types that we support (with clear Polars correspondence) -SimpleDataValue: TypeAlias = str | int | float | bool | bytes +SupportedNativePythonData: TypeAlias = str | int | float | bool | bytes -# Extended data types that can be stored in packets -# Either the original PathSet or one of our supported simple data types -DataValue: TypeAlias = PathSet | SimpleDataValue +ExtendedSupportedPythonData: TypeAlias = SupportedNativePythonData | PathLike -# Data type specifications - only support Python types and Polars types for simplicity -DataType: TypeAlias = ( - type[str] - | type[int] - | type[float] - | type[bool] - | type[bytes] - | type[Path] - | type[list] # this needs to be validated specifically at runtime - | type[pl.DataType] -) +# Extended data values that can be stored in packets +# Either the original PathSet or one of our supported simple data types +DataValue: TypeAlias = PathSet | SupportedNativePythonData | Collection["DataValue"] # a packet is a mapping from string keys to data values @@ -47,9 +38,6 @@ # a batch is a tuple of a tag and a list of packets Batch: TypeAlias = tuple[Tag, Collection[Packet]] -# Type specification for function inputs/outputs -TypeSpec: TypeAlias = dict[str, DataType] - class PodFunction(Protocol): """ diff --git a/src/orcabridge/types/core.py b/src/orcabridge/types/core.py new file mode 100644 index 0000000..389338a --- /dev/null +++ b/src/orcabridge/types/core.py @@ -0,0 +1,49 @@ +from typing import Protocol, Any +import pyarrow as pa +from dataclasses import dataclass + + +# TODO: reconsider the need for this dataclass as its information is superfluous +# to the registration of the handler into the registry. +@dataclass +class TypeInfo: + python_type: type + arrow_type: pa.DataType + semantic_type: str # name under which the type is registered + + +class TypeHandler(Protocol): + """Protocol for handling conversion between Python types and underlying Arrow + data types used for storage. 
+ + The handler itself IS the definition of a semantic type. The semantic type + name/identifier is provided by the registerer when registering the handler. + + TypeHandlers should clearly communicate what Python types they can handle, + and focus purely on conversion logic. + """ + + def supported_types(self) -> type | tuple[type, ...]: + """Return the Python type(s) this handler can process. + + Returns: + Single Type or tuple of Types this handler supports + + Examples: + - PathHandler: return Path + - NumericHandler: return (int, float) + - CollectionHandler: return (list, tuple, set) + """ + ... + + def to_storage_type(self) -> pa.DataType: + """Return the Arrow DataType instance for schema definition.""" + ... + + def to_storage_value(self, value: Any) -> Any: + """Convert Python value to Arrow-compatible storage representation.""" + ... + + def from_storage_value(self, value: Any) -> Any: + """Convert storage representation back to Python object.""" + ... diff --git a/src/orcabridge/types/default.py b/src/orcabridge/types/default.py new file mode 100644 index 0000000..701b987 --- /dev/null +++ b/src/orcabridge/types/default.py @@ -0,0 +1,23 @@ +from .registry import TypeRegistry +from .handlers import ( + PathHandler, + UUIDHandler, + SimpleMappingHandler, + DateTimeHandler, +) +import pyarrow as pa + +# Create default registry and register handlers +default_registry = TypeRegistry() + +# Register with semantic names - registry extracts supported types automatically +default_registry.register("path", PathHandler()) +default_registry.register("uuid", UUIDHandler()) +default_registry.register("int", SimpleMappingHandler(int, pa.int64())) +default_registry.register("float", SimpleMappingHandler(float, pa.float64())) +default_registry.register("bool", SimpleMappingHandler(bool, pa.bool_())) +default_registry.register("str", SimpleMappingHandler(str, pa.string())) +default_registry.register("bytes", SimpleMappingHandler(bytes, pa.binary())) +default_registry.register( + "datetime", DateTimeHandler() +) # Registers for datetime, date, time diff --git a/src/orcabridge/types/handlers.py b/src/orcabridge/types/handlers.py new file mode 100644 index 0000000..0dcc97a --- /dev/null +++ b/src/orcabridge/types/handlers.py @@ -0,0 +1,114 @@ +from typing import Any +import pyarrow as pa +from pathlib import Path +from uuid import UUID +from decimal import Decimal +from datetime import datetime, date, time + + +class PathHandler: + """Handler for pathlib.Path objects, stored as strings.""" + + def supported_types(self) -> type: + return Path + + def to_storage_type(self) -> pa.DataType: + return pa.string() + + def to_storage_value(self, value: Path) -> str: + return str(value) + + def from_storage_value(self, value: str) -> Path | None: + return Path(value) if value else None + + +class UUIDHandler: + """Handler for UUID objects, stored as strings.""" + + def supported_types(self) -> type: + return UUID + + def to_storage_type(self) -> pa.DataType: + return pa.string() + + def to_storage_value(self, value: UUID) -> str: + return str(value) + + def from_storage_value(self, value: str) -> UUID | None: + return UUID(value) if value else None + + +class DecimalHandler: + """Handler for Decimal objects, stored as strings.""" + + def supported_types(self) -> type: + return Decimal + + def to_storage_type(self) -> pa.DataType: + return pa.string() + + def to_storage_value(self, value: Decimal) -> str: + return str(value) + + def from_storage_value(self, value: str) -> Decimal | None: + return 
Decimal(value) if value else None + + +class SimpleMappingHandler: + """Handler for basic types that map directly to Arrow.""" + + def __init__(self, python_type: type, arrow_type: pa.DataType): + self._python_type = python_type + self._arrow_type = arrow_type + + def supported_types(self) -> type: + return self._python_type + + def to_storage_type(self) -> pa.DataType: + return self._arrow_type + + def to_storage_value(self, value: Any) -> Any: + return value # Direct mapping + + def from_storage_value(self, value: Any) -> Any: + return value # Direct mapping + + +class DirectArrowHandler: + """Handler for types that map directly to Arrow without conversion.""" + + def __init__(self, arrow_type: pa.DataType): + self._arrow_type = arrow_type + + def supported_types(self) -> type: + return self._arrow_type + + def to_storage_type(self) -> pa.DataType: + return self._arrow_type + + def to_storage_value(self, value: Any) -> Any: + return value # Direct mapping + + def from_storage_value(self, value: Any) -> Any: + return value # Direct mapping + + +class DateTimeHandler: + """Handler for datetime objects.""" + + def supported_types(self) -> tuple[type, ...]: + return (datetime, date, time) # Handles multiple related types + + def to_storage_type(self) -> pa.DataType: + return pa.timestamp("us") # Store everything as timestamp + + def to_storage_value(self, value: datetime | date | time) -> Any: + if isinstance(value, datetime): + return value + elif isinstance(value, date): + return datetime.combine(value, time.min) + elif isinstance(value, time): + return datetime.combine(date.today(), value) + + def from_storage_value(self, value: datetime) -> datetime: + return value # Could add logic to restore original type if needed diff --git a/src/orcabridge/types/inference.py b/src/orcabridge/types/inference.py new file mode 100644 index 0000000..9ef3d74 --- /dev/null +++ b/src/orcabridge/types/inference.py @@ -0,0 +1,173 @@ +# Library of functions for inferring types for FunctionPod input and output parameters. + + +from collections.abc import Callable, Collection, Sequence +from typing import get_origin, get_args, TypeAlias +import inspect +import logging + +logger = logging.getLogger(__name__) +DataType: TypeAlias = type +TypeSpec: TypeAlias = dict[str, DataType] # Mapping of parameter names to their types + + +def extract_function_data_types( + func: Callable, + output_keys: Collection[str], + input_types: TypeSpec | None = None, + output_types: TypeSpec | Sequence[type] | None = None, +) -> tuple[TypeSpec, TypeSpec]: + """ + Extract input and output data types from a function signature. + + This function analyzes a function's signature to determine the types of its parameters + and return values. It combines information from type annotations, user-provided type + specifications, and return key mappings to produce complete type specifications. + + Args: + func: The function to analyze for type information. + output_keys: Collection of string keys that will be used to map the function's + return values. For functions returning a single value, provide a single key. + For functions returning multiple values (tuple/list), provide keys matching + the number of return items. + input_types: Optional mapping of parameter names to their types. If provided, + these types override any type annotations in the function signature for the + specified parameters. If a parameter is not in this mapping and has no + annotation, an error is raised. + output_types: Optional type specification for return values. 
Can be either: + - A dict mapping output keys to types (TypeSpec) + - A sequence of types that will be mapped to output_keys in order + These types override any inferred types from the function's return annotation. + + Returns: + A tuple containing: + - input_types_dict: Mapping of parameter names to their inferred/specified types + - output_types_dict: Mapping of output keys to their inferred/specified types + + Raises: + ValueError: In various scenarios: + - Parameter has no type annotation and is not in input_types + - Function has return annotation but no output_keys specified + - Function has explicit None return but non-empty output_keys provided + - Multiple output_keys specified but return annotation is not a sequence type + - Return annotation is a sequence type but doesn't specify item types + - Number of types in return annotation doesn't match number of output_keys + - Output types sequence length doesn't match output_keys length + - Output key not specified in output_types and has no type annotation + + Examples: + >>> def add(x: int, y: int) -> int: + ... return x + y + >>> input_types, output_types = extract_function_data_types(add, ['result']) + >>> input_types + {'x': <class 'int'>, 'y': <class 'int'>} + >>> output_types + {'result': <class 'int'>} + + >>> def process(data: str) -> tuple[int, str]: + ... return len(data), data.upper() + >>> input_types, output_types = extract_function_data_types( + ... process, ['length', 'upper_data'] + ... ) + >>> input_types + {'data': <class 'str'>} + >>> output_types + {'length': <class 'int'>, 'upper_data': <class 'str'>} + + >>> def legacy_func(x, y): # No annotations + ... return x + y + >>> input_types, output_types = extract_function_data_types( + ... legacy_func, ['sum'], + ... input_types={'x': int, 'y': int}, + ... output_types={'sum': int} + ... ) + >>> input_types + {'x': <class 'int'>, 'y': <class 'int'>} + >>> output_types + {'sum': <class 'int'>} + + >>> def multi_return(data: list) -> tuple[int, float, str]: + ... return len(data), sum(data), str(data) + >>> input_types, output_types = extract_function_data_types( + ... multi_return, ['count', 'total', 'repr'], + ... output_types=[int, float, str] # Override with sequence + ... ) + >>> output_types + {'count': <class 'int'>, 'total': <class 'float'>, 'repr': <class 'str'>} + """ + verified_output_types: TypeSpec = {} + if output_types is not None: + if isinstance(output_types, dict): + verified_output_types = output_types + elif isinstance(output_types, Sequence): + # If output_types is a collection, convert it to a dict with keys from return_keys + if len(output_types) != len(output_keys): + raise ValueError( + f"Output types collection length {len(output_types)} does not match return keys length {len(output_keys)}." + ) + verified_output_types = {k: v for k, v in zip(output_keys, output_types)} + + signature = inspect.signature(func) + + param_info: TypeSpec = {} + for name, param in signature.parameters.items(): + if input_types and name in input_types: + param_info[name] = input_types[name] + else: + # check if the parameter has annotation + if param.annotation is not inspect.Signature.empty: + param_info[name] = param.annotation + else: + raise ValueError( + f"Parameter '{name}' has no type annotation and is not specified in input_types." + ) + + return_annot = signature.return_annotation + inferred_output_types: TypeSpec = {} + if return_annot is not inspect.Signature.empty and return_annot is not None: + output_item_types = [] + if len(output_keys) == 0: + raise ValueError( + "Function has a return type annotation, but no return keys were specified."
+ ) + elif len(output_keys) == 1: + # if only one return key, the entire annotation is inferred as the return type + output_item_types = [return_annot] + elif (get_origin(return_annot) or return_annot) in (tuple, list, Sequence): + if get_origin(return_annot) is None: + # right type was specified but did not specified the type of items + raise ValueError( + f"Function return type annotation {return_annot} is a Sequence type but does not specify item types." + ) + output_item_types = get_args(return_annot) + if len(output_item_types) != len(output_keys): + raise ValueError( + f"Function return type annotation {return_annot} has {len(output_item_types)} items, " + f"but output_keys has {len(output_keys)} items." + ) + else: + raise ValueError( + f"Multiple return keys were specified but return type annotation {return_annot} is not a sequence type (list, tuple, Collection)." + ) + for key, type_annot in zip(output_keys, output_item_types): + inferred_output_types[key] = type_annot + elif return_annot is None: + if len(output_keys) != 0: + raise ValueError( + f"Function provides explicit return type annotation as None, but return keys of length {len(output_keys)} were specified." + ) + else: + inferred_output_types = {k: inspect.Signature.empty for k in output_keys} + + # TODO: simplify the handling here -- technically all keys should already be in return_types + for key in output_keys: + if key in verified_output_types: + inferred_output_types[key] = verified_output_types[key] + elif ( + key not in inferred_output_types + or inferred_output_types[key] is inspect.Signature.empty + ): + raise ValueError( + f"Type for return item '{key}' is not specified in output_types and has no type annotation in function signature." + ) + return param_info, inferred_output_types diff --git a/src/orcabridge/types/registry.py b/src/orcabridge/types/registry.py new file mode 100644 index 0000000..b2c433f --- /dev/null +++ b/src/orcabridge/types/registry.py @@ -0,0 +1,618 @@ +from collections.abc import Callable +from typing import Any +import pyarrow as pa +from orcabridge.types import Packet +from .core import TypeHandler, TypeInfo + + +class TypeRegistry: + """Registry that manages type handlers with semantic type names.""" + + def __init__(self): + self._handlers: dict[ + type, tuple[TypeHandler, str] + ] = {} # Type -> (Handler, semantic_name) + self._semantic_handlers: dict[str, TypeHandler] = {} # semantic_name -> Handler + + def register( + self, + semantic_name: str, + handler: TypeHandler, + explicit_types: type | tuple[type, ...] | None = None, + override: bool = False, + ): + """Register a handler with a semantic type name. 
+ + Args: + semantic_name: Identifier for this semantic type (e.g., 'path', 'uuid') + handler: The type handler instance + explicit_types: Optional override of types to register for (if different from handler's supported_types) + override: If True, allow overriding existing registration for the same semantic name and Python type(s) + """ + # Determine which types to register for + if explicit_types is not None: + types_to_register = ( + explicit_types + if isinstance(explicit_types, tuple) + else (explicit_types,) + ) + else: + supported = handler.supported_types() + types_to_register = ( + supported if isinstance(supported, tuple) else (supported,) + ) + + # Register handler for each type + for python_type in types_to_register: + if python_type in self._handlers and not override: + existing_semantic = self._handlers[python_type][1] + # TODO: handle overlapping registration more gracefully + raise ValueError( + f"Type {python_type} already registered with semantic type '{existing_semantic}'" + ) + + self._handlers[python_type] = (handler, semantic_name) + + # Register by semantic name + if semantic_name in self._semantic_handlers and not override: + raise ValueError(f"Semantic type '{semantic_name}' already registered") + + self._semantic_handlers[semantic_name] = handler + + def get_handler(self, python_type: type) -> TypeHandler | None: + """Get handler for a Python type.""" + handler_info = self._handlers.get(python_type) + return handler_info[0] if handler_info else None + + def get_semantic_name(self, python_type: type) -> str | None: + """Get semantic name for a Python type.""" + handler_info = self._handlers.get(python_type) + return handler_info[1] if handler_info else None + + def get_handler_by_semantic_name(self, semantic_name: str) -> TypeHandler | None: + """Get handler by semantic name.""" + return self._semantic_handlers.get(semantic_name) + + # TODO: reconsider the need for this method + def extract_type_info(self, python_type: type) -> TypeInfo: + """Extract TypeInfo for a Python type.""" + handler_info = self._handlers.get(python_type) + if not handler_info: + raise ValueError(f"Unsupported Python type: {python_type}") + + handler, semantic_name = handler_info + return TypeInfo( + python_type=python_type, + arrow_type=handler.to_storage_type(), + semantic_type=semantic_name, + ) + + def __contains__(self, python_type: type) -> bool: + """Check if a Python type is registered.""" + return python_type in self._handlers + + +def create_packet_converters( + packet_type_info: dict[str, type], registry: TypeRegistry +) -> tuple[ + Callable[[Packet], dict[str, Any]], + Callable[[dict[str, Any]], Packet], +]: + """Create optimized conversion functions for a specific packet type. + + Pre-looks up all handlers to avoid repeated registry lookups during conversion. 
+ + Args: + type_info: Dictionary mapping parameter names to their Python types + registry: TypeRegistry containing handlers for type conversions + + Returns: + Tuple of (to_storage_converter, from_storage_converter) functions + + Raises: + ValueError: If any type in type_info is not supported by the registry + + Example: + type_info = { + 'file_path': Path, + 'threshold': float, + 'user_id': UUID + } + + to_storage, from_storage = create_packet_converters(type_info, registry) + + # Fast conversion (no registry lookups) + storage_packet = to_storage(original_packet) + restored_packet = from_storage(storage_packet) + """ + + # Pre-lookup all handlers and validate they exist + handlers: dict[str, TypeHandler] = {} + expected_types: dict[str, type] = {} + + for key, python_type in packet_type_info.items(): + handler = registry.get_handler(python_type) + if handler is None: + raise ValueError( + f"No handler registered for type {python_type} (key: '{key}')" + ) + + handlers[key] = handler + expected_types[key] = python_type + + def to_storage_converter(packet: Packet) -> dict[str, Any]: + """Convert packet to storage representation. + + Args: + packet: Dictionary mapping parameter names to Python values + + Returns: + Dictionary with same keys but values converted to storage format + + Raises: + KeyError: If packet keys don't match the expected type_info keys + TypeError: If value type doesn't match expected type + ValueError: If conversion fails + """ + # Validate packet keys + packet_keys = set(packet.keys()) + expected_keys = set(expected_types.keys()) + + if packet_keys != expected_keys: + missing_in_packet = expected_keys - packet_keys + extra_in_packet = packet_keys - expected_keys + + error_parts = [] + if missing_in_packet: + error_parts.append(f"Missing keys: {missing_in_packet}") + if extra_in_packet: + error_parts.append(f"Extra keys: {extra_in_packet}") + + raise KeyError( + f"Packet keys don't match expected keys. {'; '.join(error_parts)}" + ) + + # Convert each value + storage_packet = {} + + for key, value in packet.items(): + expected_type = expected_types[key] + handler = handlers[key] + + # Handle None values + if value is None: + storage_packet[key] = None + continue + + # Validate value type + if not isinstance(value, expected_type): + raise TypeError( + f"Value for '{key}' is {type(value).__name__}, expected {expected_type.__name__}" + ) + + # Convert to storage representation + try: + storage_value = handler.to_storage_value(value) + storage_packet[key] = storage_value + except Exception as e: + raise ValueError( + f"Failed to convert '{key}' of type {expected_type}: {e}" + ) from e + + return storage_packet + + def from_storage_converter(storage_packet: dict[str, Any]) -> Packet: + """Convert storage packet back to Python values. 
+ + Args: + storage_packet: Dictionary with values in storage format + + Returns: + Dictionary with same keys but values converted back to Python types + + Raises: + KeyError: If storage_packet keys don't match the expected type_info keys + ValueError: If conversion fails + """ + # Validate storage packet keys + packet_keys = set(storage_packet.keys()) + expected_keys = set(expected_types.keys()) + + if packet_keys != expected_keys: + missing_in_packet = expected_keys - packet_keys + extra_in_packet = packet_keys - expected_keys + + error_parts = [] + if missing_in_packet: + error_parts.append(f"Missing keys: {missing_in_packet}") + if extra_in_packet: + error_parts.append(f"Extra keys: {extra_in_packet}") + + raise KeyError( + f"Storage packet keys don't match expected keys. {'; '.join(error_parts)}" + ) + + # Convert each value back + python_packet = {} + + for key, storage_value in storage_packet.items(): + handler = handlers[key] + + # Handle None values + if storage_value is None: + python_packet[key] = None + continue + + # Convert from storage representation + try: + python_value = handler.from_storage_value(storage_value) + python_packet[key] = python_value + except Exception as e: + raise ValueError(f"Failed to convert '{key}' from storage: {e}") from e + + return python_packet + + return to_storage_converter, from_storage_converter + + +def convert_packet_to_storage( + packet: Packet, type_info: dict[str, type], registry: TypeRegistry +) -> Packet: + """Convert a packet to its storage representation using the provided type info. + + Args: + packet: The original packet to convert + type_info: Dictionary mapping parameter names to their Python types + registry: TypeRegistry containing handlers for type conversions + + Returns: + Converted packet in storage format + """ + to_storage, _ = create_packet_converters(type_info, registry) + return to_storage(packet) + + +def convert_storage_to_packet( + storage_packet: dict[str, Any], type_info: dict[str, type], registry: TypeRegistry +) -> Packet | None: + pass + + +class PacketConverter: + """ + Convenience class for converting packets between storage and Python formats. + """ + + def __init__(self, packet_type_info: dict[str, type], registry: TypeRegistry): + """Initialize the packet converter with type info and registry.""" + self._to_storage, self._from_storage = create_packet_converters( + packet_type_info, registry + ) + self.packet_type_info = packet_type_info + + def to_storage(self, packet: Packet) -> dict[str, Any]: + """Convert packet to storage representation.""" + return self._to_storage(packet) + + def from_storage(self, storage_packet: dict[str, Any]) -> Packet: + """Convert storage packet back to Python values.""" + return self._from_storage(storage_packet) + + +def convert_packet_to_arrow_table( + packet: dict[str, Any], type_info: dict[str, type], registry: TypeRegistry +) -> pa.Table: + """Convert a single packet to a PyArrow Table with one row. 
+ + Args: + packet: Dictionary mapping parameter names to Python values + type_info: Dictionary mapping parameter names to their Python types + registry: TypeRegistry containing handlers for type conversions + + Returns: + PyArrow Table with the packet data as a single row + """ + # Get the converter functions + to_storage, _ = create_packet_converters(type_info, registry) + + # Convert packet to storage format + storage_packet = to_storage(packet) + + # Create schema + schema_fields = [] + for key, python_type in type_info.items(): + type_info_obj = registry.extract_type_info(python_type) + schema_fields.append(pa.field(key, type_info_obj.arrow_type)) + + schema = pa.schema(schema_fields) + + # Convert storage packet to arrays (single element each) + arrays = [] + for field in schema: + field_name = field.name + value = storage_packet[field_name] + + # Create single-element array + array = pa.array([value], type=field.type) + arrays.append(array) + + # Create table + return pa.Table.from_arrays(arrays, schema=schema) + + +def convert_packets_to_arrow_table( + packets: list[dict[str, Any]], type_info: dict[str, type], registry: TypeRegistry +) -> pa.Table: + """Convert multiple packets to a PyArrow Table. + + Args: + packets: List of packets (dictionaries) + type_info: Dictionary mapping parameter names to their Python types + registry: TypeRegistry containing handlers for type conversions + + Returns: + PyArrow Table with all packet data as rows + """ + if not packets: + # Return empty table with correct schema + schema_fields = [] + for key, python_type in type_info.items(): + type_info_obj = registry.extract_type_info(python_type) + schema_fields.append(pa.field(key, type_info_obj.arrow_type)) + schema = pa.schema(schema_fields) + return pa.Table.from_arrays([], schema=schema) + + # Get the converter functions (reuse for all packets) + to_storage, _ = create_packet_converters(type_info, registry) + + # Convert all packets to storage format + storage_packets = [to_storage(packet) for packet in packets] + + # Create schema + schema_fields = [] + for key, python_type in type_info.items(): + type_info_obj = registry.extract_type_info(python_type) + schema_fields.append(pa.field(key, type_info_obj.arrow_type)) + + schema = pa.schema(schema_fields) + + # Group values by column + column_data = {} + for field in schema: + field_name = field.name + column_data[field_name] = [packet[field_name] for packet in storage_packets] + + # Create arrays for each column + arrays = [] + for field in schema: + field_name = field.name + values = column_data[field_name] + array = pa.array(values, type=field.type) + arrays.append(array) + + # Create table + return pa.Table.from_arrays(arrays, schema=schema) + + +def convert_packet_to_arrow_table_with_field_metadata( + packet: Packet, type_info: dict[str, type], registry: TypeRegistry +) -> pa.Table: + """Convert packet to Arrow table with semantic type stored as field metadata.""" + + # Get converter + to_storage, _ = create_packet_converters(type_info, registry) + storage_packet = to_storage(packet) + + # Create schema fields with metadata + schema_fields = [] + for key, python_type in type_info.items(): + type_info_obj = registry.extract_type_info(python_type) + + # Create field with semantic type metadata + field_metadata = {} + if type_info_obj.semantic_type: + field_metadata["semantic_type"] = type_info_obj.semantic_type + + field = pa.field(key, type_info_obj.arrow_type, metadata=field_metadata) + schema_fields.append(field) + + schema = 
pa.schema(schema_fields) + + # Create arrays + arrays = [] + for field in schema: + value = storage_packet[field.name] + array = pa.array([value], type=field.type) + arrays.append(array) + + return pa.Table.from_arrays(arrays, schema=schema) + + +def convert_packets_to_arrow_table_with_field_metadata( + packets: list[Packet], type_info: dict[str, type], registry: TypeRegistry +) -> pa.Table: + """Convert multiple packets to Arrow table with field metadata.""" + + if not packets: + return _create_empty_table_with_field_metadata(type_info, registry) + + # Get converter + to_storage, _ = create_packet_converters(type_info, registry) + storage_packets = [to_storage(packet) for packet in packets] + + # Create schema with field metadata + schema = _create_schema_with_field_metadata(type_info, registry) + + # Group values by column + column_data = {} + for field in schema: + field_name = field.name + column_data[field_name] = [packet[field_name] for packet in storage_packets] + + # Create arrays + arrays = [] + for field in schema: + values = column_data[field.name] + array = pa.array(values, type=field.type) + arrays.append(array) + + return pa.Table.from_arrays(arrays, schema=schema) + + +def _create_schema_with_field_metadata( + type_info: dict[str, type], registry: TypeRegistry +) -> pa.Schema: + """Helper to create schema with field-level semantic type metadata.""" + schema_fields = [] + + for key, python_type in type_info.items(): + type_info_obj = registry.extract_type_info(python_type) + + # Create field metadata + field_metadata = {} + if type_info_obj.semantic_type: + field_metadata["semantic_type"] = type_info_obj.semantic_type + + field = pa.field(key, type_info_obj.arrow_type, metadata=field_metadata) + schema_fields.append(field) + + return pa.schema(schema_fields) + + +def _create_empty_table_with_field_metadata( + type_info: dict[str, type], registry: TypeRegistry +) -> pa.Table: + """Helper to create empty table with correct schema and field metadata.""" + schema = _create_schema_with_field_metadata(type_info, registry) + arrays = [pa.array([], type=field.type) for field in schema] + return pa.Table.from_arrays(arrays, schema=schema) + + +def extract_field_semantic_types(table: pa.Table) -> dict[str, str | None]: + """Extract semantic type from each field's metadata.""" + semantic_types = {} + + for field in table.schema: + if field.metadata and b"semantic_type" in field.metadata: + semantic_type = field.metadata[b"semantic_type"].decode("utf-8") + semantic_types[field.name] = semantic_type + else: + semantic_types[field.name] = None + + return semantic_types + + +def convert_arrow_table_to_packets_with_field_metadata( + table: pa.Table, registry: TypeRegistry +) -> list[Packet]: + """Convert Arrow table back to packets using field metadata.""" + + # Extract semantic types from field metadata + field_semantic_types = extract_field_semantic_types(table) + + # Reconstruct type_info from field metadata + type_info = {} + for field in table.schema: + field_name = field.name + semantic_type = field_semantic_types.get(field_name) + + if semantic_type: + # Get handler by semantic type + handler = registry.get_handler_by_semantic_name(semantic_type) + if handler: + python_type = handler.supported_types() + if isinstance(python_type, tuple): + python_type = python_type[0] # Take first if multiple + type_info[field_name] = python_type + else: + # Fallback to basic type inference + type_info[field_name] = _infer_python_type_from_arrow(field.type) + else: + # No semantic type metadata - 
infer from Arrow type + type_info[field_name] = _infer_python_type_from_arrow(field.type) + + # Convert using reconstructed type info + _, from_storage = create_packet_converters(type_info, registry) + storage_packets = table.to_pylist() + + return [from_storage(packet) for packet in storage_packets] + + +def _infer_python_type_from_arrow(arrow_type: pa.DataType) -> type: + """Infer Python type from Arrow type as fallback.""" + if arrow_type == pa.int64(): + return int + elif arrow_type == pa.float64(): + return float + elif arrow_type == pa.string(): + return str + elif arrow_type == pa.bool_(): + return bool + elif arrow_type == pa.binary(): + return bytes + else: + return str # Safe fallback + + +# TODO: move these functions to util +def escape_with_postfix(field: str, postfix=None, separator="_") -> str: + """ + Escape the field string by doubling separators and optionally append a postfix. + This function takes a field string and escapes any occurrences of the separator + by doubling them, then optionally appends a postfix with a separator prefix. + + Args: + field (str): The input string containing to be escaped. + postfix (str, optional): An optional postfix to append to the escaped string. + If None, no postfix is added. Defaults to None. + separator (str, optional): The separator character to escape and use for + prefixing the postfix. Defaults to "_". + Returns: + str: The escaped string with optional postfix. Returns empty string if + fields is provided but postfix is None. + Examples: + >>> escape_with_postfix("field1_field2", "suffix") + 'field1__field2_suffix' + >>> escape_with_postfix("name_age_city", "backup", "_") + 'name__age__city_backup' + >>> escape_with_postfix("data-info", "temp", "-") + 'data--info-temp' + >>> escape_with_postfix("simple", None) + 'simple' + >>> escape_with_postfix("no_separators", "end") + 'no__separators_end' + """ + + return field.replace(separator, separator * 2) + (f"_{postfix}" if postfix else "") + + +def unescape_with_postfix(field: str, separator="_") -> tuple[str, str | None]: + """ + Unescape a string by converting double separators back to single separators and extract postfix metadata. + This function reverses the escaping process where single separators were doubled to avoid + conflicts with metadata delimiters. It splits the input on double separators, then extracts + any postfix metadata from the last part. + + Args: + field (str): The escaped string containing doubled separators and optional postfix metadata + separator (str, optional): The separator character used for escaping. 
Defaults to "_" + Returns: + tuple[str, str | None]: A tuple containing: + - The unescaped string with single separators restored + - The postfix metadata if present, None otherwise + Examples: + >>> unescape_with_postfix("field1__field2__field3") + ('field1_field2_field3', None) + >>> unescape_with_postfix("field1__field2_metadata") + ('field1_field2', 'metadata') + >>> unescape_with_postfix("simple") + ('simple', None) + >>> unescape_with_postfix("field1--field2", separator="-") + ('field1-field2', None) + >>> unescape_with_postfix("field1--field2-meta", separator="-") + ('field1-field2', 'meta') + """ + + parts = field.split(separator * 2) + parts[-1], *meta = parts[-1].split("_", 1) + return separator.join(parts), meta[0] if meta else None diff --git a/tests/test_types/__init__.py b/tests/test_types/__init__.py new file mode 100644 index 0000000..aa691b1 --- /dev/null +++ b/tests/test_types/__init__.py @@ -0,0 +1 @@ +# Test package for orcabridge types module diff --git a/tests/test_types/test_inference/__init__.py b/tests/test_types/test_inference/__init__.py new file mode 100644 index 0000000..45e6baf --- /dev/null +++ b/tests/test_types/test_inference/__init__.py @@ -0,0 +1 @@ +# Test package for orcabridge types inference module diff --git a/tests/test_types/test_inference/test_extract_function_data_types.py b/tests/test_types/test_inference/test_extract_function_data_types.py new file mode 100644 index 0000000..c3426b6 --- /dev/null +++ b/tests/test_types/test_inference/test_extract_function_data_types.py @@ -0,0 +1,391 @@ +""" +Unit tests for the extract_function_data_types function. + +This module tests the function type extraction functionality, covering: +- Type inference from function annotations +- User-provided type overrides +- Various return type scenarios (single, multiple, None) +- Error conditions and edge cases +""" + +import pytest +from collections.abc import Collection + +from orcabridge.types.inference import extract_function_data_types + + +class TestExtractFunctionDataTypes: + """Test cases for extract_function_data_types function.""" + + def test_simple_annotated_function(self): + """Test function with simple type annotations.""" + + def add(x: int, y: int) -> int: + return x + y + + input_types, output_types = extract_function_data_types(add, ["result"]) + + assert input_types == {"x": int, "y": int} + assert output_types == {"result": int} + + def test_multiple_return_values_tuple(self): + """Test function returning multiple values as tuple.""" + + def process(data: str) -> tuple[int, str]: + return len(data), data.upper() + + input_types, output_types = extract_function_data_types( + process, ["length", "upper_data"] + ) + + assert input_types == {"data": str} + assert output_types == {"length": int, "upper_data": str} + + def test_multiple_return_values_list(self): + """Test function returning multiple values as list.""" + + def split_data(data: str) -> tuple[str, str]: + word1, *words = data.split() + if len(words) < 1: + word2 = "" + else: + word2 = words[0] + return word1, word2 + + # Note: This tests the case where we have multiple output keys + # but the return type is list[str] (homogeneous) + input_types, output_types = extract_function_data_types( + split_data, ["first_word", "second_word"] + ) + + assert input_types == {"data": str} + assert output_types == {"first_word": str, "second_word": str} + + def test_no_return_annotation_multiple_keys(self): + """Test function with no return annotation and multiple output keys.""" + + def mystery_func(x: 
int): + return x, str(x) + + with pytest.raises( + ValueError, + match="Type for return item 'number' is not specified in output_types", + ): + input_types, output_types = extract_function_data_types( + mystery_func, + ["number", "text"], + ) + + def test_input_types_override(self): + """Test overriding parameter types with input_types.""" + + def legacy_func(x, y) -> int: # No annotations + return x + y + + input_types, output_types = extract_function_data_types( + legacy_func, ["sum"], input_types={"x": int, "y": int} + ) + + assert input_types == {"x": int, "y": int} + assert output_types == {"sum": int} + + def test_partial_input_types_override(self): + """Test partial override where some params have annotations.""" + + def mixed_func(x: int, y) -> int: # One annotated, one not + return x + y + + input_types, output_types = extract_function_data_types( + mixed_func, ["sum"], input_types={"y": float} + ) + + assert input_types == {"x": int, "y": float} + assert output_types == {"sum": int} + + def test_output_types_dict_override(self): + """Test overriding output types with dict.""" + + def mystery_func(x: int) -> str: + return str(x) + + input_types, output_types = extract_function_data_types( + mystery_func, ["result"], output_types={"result": float} + ) + + assert input_types == {"x": int} + assert output_types == {"result": float} + + def test_output_types_sequence_override(self): + """Test overriding output types with sequence.""" + + def multi_return(data: list) -> tuple[int, float, str]: + return len(data), sum(data), str(data) + + input_types, output_types = extract_function_data_types( + multi_return, ["count", "total", "repr"], output_types=[int, float, str] + ) + + assert input_types == {"data": list} + assert output_types == {"count": int, "total": float, "repr": str} + + def test_complex_types(self): + """Test function with complex type annotations.""" + + def complex_func(x: str | None, y: int | float) -> tuple[bool, list[str]]: + return bool(x), [x] if x else [] + + input_types, output_types = extract_function_data_types( + complex_func, ["is_valid", "items"] + ) + + assert input_types == {"x": str | None, "y": int | float} + assert output_types == {"is_valid": bool, "items": list[str]} + + def test_none_return_annotation(self): + """Test function with explicit None return annotation.""" + + def side_effect_func(x: int) -> None: + print(x) + + input_types, output_types = extract_function_data_types(side_effect_func, []) + + assert input_types == {"x": int} + assert output_types == {} + + def test_empty_parameters(self): + """Test function with no parameters.""" + + def get_constant() -> int: + return 42 + + input_types, output_types = extract_function_data_types(get_constant, ["value"]) + + assert input_types == {} + assert output_types == {"value": int} + + # Error condition tests + + def test_missing_parameter_annotation_error(self): + """Test error when parameter has no annotation and not in input_types.""" + + def bad_func(x, y: int): + return x + y + + with pytest.raises(ValueError, match="Parameter 'x' has no type annotation"): + extract_function_data_types(bad_func, ["result"]) + + def test_return_annotation_but_no_output_keys_error(self): + """Test error when function has return annotation but no output keys.""" + + def func_with_return(x: int) -> str: + return str(x) + + with pytest.raises( + ValueError, + match="Function has a return type annotation, but no return keys were specified", + ): + extract_function_data_types(func_with_return, []) + + def 
test_none_return_with_output_keys_error(self): + """Test error when function returns None but output keys provided.""" + + def side_effect_func(x: int) -> None: + print(x) + + with pytest.raises( + ValueError, + match="Function provides explicit return type annotation as None", + ): + extract_function_data_types(side_effect_func, ["result"]) + + def test_single_return_multiple_keys_error(self): + """Test error when single return type but multiple output keys.""" + + def single_return(x: int) -> str: + return str(x) + + with pytest.raises( + ValueError, + match="Multiple return keys were specified but return type annotation .* is not a sequence type", + ): + extract_function_data_types(single_return, ["first", "second"]) + + def test_unparameterized_sequence_type_error(self): + """Test error when return type is sequence but not parameterized.""" + + def bad_return(x: int) -> tuple: # tuple without types + return x, str(x) + + with pytest.raises( + ValueError, match="is a Sequence type but does not specify item types" + ): + extract_function_data_types(bad_return, ["number", "text"]) + + def test_mismatched_return_types_count_error(self): + """Test error when return type count doesn't match output keys count.""" + + def three_returns(x: int) -> tuple[int, str, float]: + return x, str(x), float(x) + + with pytest.raises( + ValueError, match="has 3 items, but output_keys has 2 items" + ): + extract_function_data_types(three_returns, ["first", "second"]) + + def test_mismatched_output_types_sequence_length_error(self): + """Test error when output_types sequence length doesn't match output_keys.""" + + def func(x: int) -> tuple[int, str]: + return x, str(x) + + with pytest.raises( + ValueError, + match="Output types collection length .* does not match return keys length", + ): + extract_function_data_types( + func, + ["first", "second"], + output_types=[int, str, float], # Wrong length + ) + + def test_missing_output_type_specification_error(self): + """Test error when output key not specified and no annotation.""" + + def no_return_annotation(x: int): + return x, str(x) + + with pytest.raises( + ValueError, + match="Type for return item 'first' is not specified in output_types", + ): + extract_function_data_types(no_return_annotation, ["first", "second"]) + + # Edge cases + + def test_callable_with_args_kwargs(self): + """Test function with *args and **kwargs.""" + + def flexible_func(x: int, *args: str, **kwargs: float) -> bool: + return True + + input_types, output_types = extract_function_data_types( + flexible_func, ["success"] + ) + + assert "x" in input_types + assert "args" in input_types + assert "kwargs" in input_types + assert input_types["x"] is int + assert output_types == {"success": bool} + + def test_mixed_override_scenarios(self): + """Test complex scenario with both input and output overrides.""" + + def complex_func(a, b: str) -> tuple[int, str]: + return len(b), b.upper() + + input_types, output_types = extract_function_data_types( + complex_func, + ["length", "upper"], + input_types={"a": float}, + output_types={"length": int}, # Override only one output + ) + + assert input_types == {"a": float, "b": str} + assert output_types == {"length": int, "upper": str} + + def test_generic_types(self): + """Test function with generic type annotations.""" + + def generic_func(data: list[int]) -> dict[str, int]: + return {str(i): i for i in data} + + input_types, output_types = extract_function_data_types( + generic_func, ["mapping"] + ) + + assert input_types == {"data": list[int]} + 
assert output_types == {"mapping": dict[str, int]} + + def test_sequence_return_type_inference(self): + """Test that sequence types are properly handled in return annotations.""" + + def list_func( + x: int, + ) -> tuple[str, int]: # This should work for multiple outputs + return str(x), x + + # This tests the sequence detection logic + input_types, output_types = extract_function_data_types( + list_func, ["text", "number"] + ) + + assert input_types == {"x": int} + assert output_types == {"text": str, "number": int} + + def test_collection_return_type_inference(self): + """Test Collection type in return annotation.""" + + def collection_func(x: int) -> Collection[str]: + return [str(x)] + + # Single output key with Collection type + input_types, output_types = extract_function_data_types( + collection_func, ["result"] + ) + + assert input_types == {"x": int} + assert output_types == {"result": Collection[str]} + + +class TestTypeSpecHandling: + """Test TypeSpec and type handling edge cases.""" + + def test_empty_function(self): + """Test function with no parameters and no return.""" + + def empty_func(): + pass + + input_types, output_types = extract_function_data_types(empty_func, []) + + assert input_types == {} + assert output_types == {} + + def test_preserve_annotation_objects(self): + """Test that complex annotation objects are preserved.""" + from typing import TypeVar, Generic + + T = TypeVar("T") + + class Container(Generic[T]): + pass + + def generic_container_func(x: Container[int]) -> Container[str]: + return Container() + + input_types, output_types = extract_function_data_types( + generic_container_func, ["result"] + ) + + assert input_types == {"x": Container[int]} + assert output_types == {"result": Container[str]} + + def test_output_types_dict_partial_override(self): + """Test partial override with output_types dict.""" + + def three_output_func() -> tuple[int, str, float]: + return 1, "hello", 3.14 + + input_types, output_types = extract_function_data_types( + three_output_func, + ["num", "text", "decimal"], + output_types={"text": bytes}, # Override only middle one + ) + + assert input_types == {} + assert output_types == { + "num": int, + "text": bytes, # Overridden + "decimal": float, + } From b306a746497847b8ae30670cc02ae3335e8421e1 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Fri, 13 Jun 2025 03:52:53 +0000 Subject: [PATCH 05/28] feat: add typespec compatibility checks --- pyproject.toml | 1 + src/orcabridge/types/inference.py | 40 +++++++++++++++++++++++++++++++ src/orcabridge/types/registry.py | 21 ++++++---------- uv.lock | 17 ++++++++++--- 4 files changed, 62 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fe1e914..0dfedeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "pyyaml>=6.0.2", "pyarrow>=20.0.0", "polars>=1.30.0", + "beartype>=0.21.0", ] readme = "README.md" requires-python = ">=3.10" diff --git a/src/orcabridge/types/inference.py b/src/orcabridge/types/inference.py index 9ef3d74..09ea633 100644 --- a/src/orcabridge/types/inference.py +++ b/src/orcabridge/types/inference.py @@ -5,12 +5,52 @@ from typing import get_origin, get_args, TypeAlias import inspect import logging +from beartype.door import is_bearable, is_subhint + logger = logging.getLogger(__name__) DataType: TypeAlias = type TypeSpec: TypeAlias = dict[str, DataType] # Mapping of parameter names to their types +def verify_against_typespec(packet: dict, typespec: TypeSpec) -> bool: + """Verify that the dictionary's types match the expected types in the typespec.""" + # verify that packet contains no keys not in typespec + if set(packet.keys()) - set(typespec.keys()): + logger.warning( + f"Packet contains keys not in typespec: {set(packet.keys()) - set(typespec.keys())}. " + ) + return False + for key, type_info in typespec.items(): + if key not in packet: + logger.warning( + f"Key '{key}' not found in packet. Assuming None but this behavior may change in the future" + ) + if not is_bearable(packet.get(key), type_info): + logger.warning( + f"Type mismatch for key '{key}': expected {type_info}, got {packet.get(key)}." + ) + return False + return True + + +# TODO: is_subhint does not handle invariance properly +# so when working with mutable types, we have to make sure to perform deep copy +def check_typespec_compatibility( + incoming_types: TypeSpec, receiving_types: TypeSpec +) -> bool: + for key, type_info in incoming_types.items(): + if key not in receiving_types: + logger.warning(f"Key '{key}' not found in parameter types.") + return False + if not is_subhint(type_info, receiving_types[key]): + logger.warning( + f"Type mismatch for key '{key}': expected {receiving_types[key]}, got {type_info}." 
+ ) + return False + return True + + def extract_function_data_types( func: Callable, output_keys: Collection[str], diff --git a/src/orcabridge/types/registry.py b/src/orcabridge/types/registry.py index b2c433f..e7b4553 100644 --- a/src/orcabridge/types/registry.py +++ b/src/orcabridge/types/registry.py @@ -73,25 +73,18 @@ def get_handler_by_semantic_name(self, semantic_name: str) -> TypeHandler | None """Get handler by semantic name.""" return self._semantic_handlers.get(semantic_name) - # TODO: reconsider the need for this method - def extract_type_info(self, python_type: type) -> TypeInfo: - """Extract TypeInfo for a Python type.""" - handler_info = self._handlers.get(python_type) - if not handler_info: - raise ValueError(f"Unsupported Python type: {python_type}") - - handler, semantic_name = handler_info - return TypeInfo( - python_type=python_type, - arrow_type=handler.to_storage_type(), - semantic_type=semantic_name, - ) - def __contains__(self, python_type: type) -> bool: """Check if a Python type is registered.""" return python_type in self._handlers +def is_packet_supported( + packet_type_info: dict[str, type], registry: TypeRegistry +) -> bool: + """Check if all types in the packet are supported by the registry.""" + return all(python_type in registry for python_type in packet_type_info.values()) + + def create_packet_converters( packet_type_info: dict[str, type], registry: TypeRegistry ) -> tuple[ diff --git a/uv.lock b/uv.lock index 26e122a..a886dbe 100644 --- a/uv.lock +++ b/uv.lock @@ -117,6 +117,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, ] +[[package]] +name = "beartype" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/f9/21e5a9c731e14f08addd53c71fea2e70794e009de5b98e6a2c3d2f3015d6/beartype-0.21.0.tar.gz", hash = "sha256:f9a5078f5ce87261c2d22851d19b050b64f6a805439e8793aecf01ce660d3244", size = 1437066, upload-time = "2025-05-22T05:09:27.116Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/31/87045d1c66ee10a52486c9d2047bc69f00f2689f69401bb1e998afb4b205/beartype-0.21.0-py3-none-any.whl", hash = "sha256:b6a1bd56c72f31b0a496a36cc55df6e2f475db166ad07fa4acc7e74f4c7f34c0", size = 1191340, upload-time = "2025-05-22T05:09:24.606Z" }, +] + [[package]] name = "cachetools" version = "5.5.2" @@ -1185,6 +1194,7 @@ wheels = [ name = "orcabridge" source = { editable = "." 
} dependencies = [ + { name = "beartype" }, { name = "matplotlib" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -1216,6 +1226,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "beartype", specifier = ">=0.21.0" }, { name = "matplotlib", specifier = ">=3.10.3" }, { name = "networkx" }, { name = "pandas", specifier = ">=2.2.3" }, @@ -2126,11 +2137,11 @@ wheels = [ [[package]] name = "typing-extensions" -version = "4.13.2" +version = "4.14.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef", size = 106967, upload-time = "2025-04-10T14:19:05.416Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806, upload-time = "2025-04-10T14:19:03.967Z" }, + { url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" }, ] [[package]] From c8fb500958633d54d5cde434b0a7e21c6c0bacc8 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
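For reviewers unfamiliar with beartype.door, the two primitives that the new checks in types/inference.py build on behave roughly as sketched below (illustrative only, toy hints, not part of the patch):

    from beartype.door import is_bearable, is_subhint

    # Value-level check, as used by verify_against_typespec:
    # does this concrete value satisfy the declared hint?
    assert is_bearable([1, 2, 3], list[int])
    assert not is_bearable("three", int)

    # Hint-level check, as used by check_typespec_compatibility:
    # can every value of the incoming hint be accepted by the receiving hint?
    assert is_subhint(str, str | None)      # a narrower hint feeds a wider one
    assert not is_subhint(str | None, str)  # the reverse is rejected

As the TODO in inference.py notes, is_subhint does not model invariance for mutable container hints, hence the recommendation to deep-copy mutable values passed between pods.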
Walker" Date: Fri, 13 Jun 2025 05:44:09 +0000 Subject: [PATCH 06/28] wip: refactor: clean up logic around function and object hashing --- src/orcabridge/hashing/files.py | 108 ------------------ src/orcabridge/hashing/function.py | 2 - .../hashing/function_info_extractors.py | 55 +++++++++ src/orcabridge/hashing/object_hashers.py | 20 ++++ src/orcabridge/hashing/types.py | 23 +++- 5 files changed, 97 insertions(+), 111 deletions(-) delete mode 100644 src/orcabridge/hashing/files.py delete mode 100644 src/orcabridge/hashing/function.py create mode 100644 src/orcabridge/hashing/function_info_extractors.py create mode 100644 src/orcabridge/hashing/object_hashers.py diff --git a/src/orcabridge/hashing/files.py b/src/orcabridge/hashing/files.py deleted file mode 100644 index 3a70b9d..0000000 --- a/src/orcabridge/hashing/files.py +++ /dev/null @@ -1,108 +0,0 @@ -import threading -from typing import Optional - -from orcabridge.hashing.core import hash_file, hash_packet, hash_pathset -from orcabridge.hashing.types import FileHasher, StringCacher -from orcabridge.types import Packet, PathLike, PathSet - - -# Completely unnecessary to inherit from FileHasher, but this -# allows for type checking based on ininstance -class DefaultFileHasher(FileHasher): - """Default implementation for file hashing.""" - - def __init__( - self, - algorithm: str = "sha256", - buffer_size: int = 65536, - char_count: int | None = 32, - ): - self.algorithm = algorithm - self.buffer_size = buffer_size - self.char_count = char_count - - def hash_file(self, file_path: PathLike) -> str: - return hash_file( - file_path, algorithm=self.algorithm, buffer_size=self.buffer_size - ) - - def hash_pathset(self, pathset: PathSet) -> str: - return hash_pathset( - pathset, - algorithm=self.algorithm, - buffer_size=self.buffer_size, - char_count=self.char_count, - file_hasher=self.hash_file, - ) - - def hash_packet(self, packet: Packet) -> str: - return hash_packet( - packet, - algorithm=self.algorithm, - buffer_size=self.buffer_size, - char_count=self.char_count, - pathset_hasher=self.hash_pathset, - ) - - -class InMemoryCacher(StringCacher): - """Thread-safe in-memory LRU cache.""" - - def __init__(self, max_size: int | None = 1000): - self.max_size = max_size - self._cache = {} - self._access_order = [] - self._lock = threading.RLock() - - def get_cached(self, cache_key: str) -> Optional[str]: - with self._lock: - if cache_key in self._cache: - self._access_order.remove(cache_key) - self._access_order.append(cache_key) - return self._cache[cache_key] - return None - - def set_cached(self, cache_key: str, value: str) -> None: - with self._lock: - if cache_key in self._cache: - self._access_order.remove(cache_key) - elif self.max_size is not None and len(self._cache) >= self.max_size: - oldest = self._access_order.pop(0) - del self._cache[oldest] - self._cache[cache_key] = value - self._access_order.append(cache_key) - - def clear_cache(self) -> None: - with self._lock: - self._cache.clear() - self._access_order.clear() - - -class CachedFileHasher(FileHasher): - """FileHasher with caching capabilities.""" - - def __init__( - self, - file_hasher: FileHasher, - string_cacher: StringCacher, - cache_file=True, - cache_pathset=False, - cache_packet=False, - ): - self.file_hasher = file_hasher - self.string_cacher = string_cacher - self.cache_file = cache_file - self.cache_pathset = cache_pathset - self.cache_packet = cache_packet - - def hash_file(self, file_path: PathLike) -> str: - cache_key = f"file:{file_path}" - if self.cache_file: - 
cached_value = self.string_cacher.get_cached(cache_key) - if cached_value is not None: - return cached_value - value = self.file_hasher.hash_file(file_path) - if self.cache_file: - # Store the hash in the cache - self.string_cacher.set_cached(cache_key, value) - return value diff --git a/src/orcabridge/hashing/function.py b/src/orcabridge/hashing/function.py deleted file mode 100644 index 501f385..0000000 --- a/src/orcabridge/hashing/function.py +++ /dev/null @@ -1,2 +0,0 @@ -# Provides functions for hashing of a Python function - diff --git a/src/orcabridge/hashing/function_info_extractors.py b/src/orcabridge/hashing/function_info_extractors.py new file mode 100644 index 0000000..74be127 --- /dev/null +++ b/src/orcabridge/hashing/function_info_extractors.py @@ -0,0 +1,55 @@ +from .types import FunctionInfoExtractor +from collections.abc import Callable +from typing import Any, Literal + + +class FunctionNameExtractor: + """ + Extractor that only uses the function name for information extraction. + """ + + def extract_function_info(self, func: Callable[..., Any]) -> dict[str, Any]: + """ + Extracts information from the function based on its name. + """ + if not callable(func): + raise TypeError("Provided object is not callable") + + # Use the function's name as the hash + function_name = func.__name__ if hasattr(func, "__name__") else str(func) + return {"name": function_name} + + +class FunctionSignatureExtractor: + """ + Extractor that uses the function signature for information extraction. + """ + + def extract_function_info(self, func: Callable[..., Any]) -> dict[str, Any]: + """ + Extracts information from the function based on its signature. + """ + if not callable(func): + raise TypeError("Provided object is not callable") + + # Use the function's signature as the hash + function_signature = str(func.__code__) + return {"signature": function_signature} + + +class FunctionInfoExtractorFactory: + """Factory for creating various extractor combinations.""" + + @staticmethod + def create_function_info_extractor( + strategy: Literal["name", "signature"] = "signature", + ) -> FunctionInfoExtractor: + """Create a basic composite extractor.""" + if strategy == "name": + return FunctionNameExtractor() + elif strategy == "signature": + return FunctionSignatureExtractor() + else: + raise ValueError( + f"Unknown strategy: {strategy}. Use 'name' or 'signature'." + ) diff --git a/src/orcabridge/hashing/object_hashers.py b/src/orcabridge/hashing/object_hashers.py new file mode 100644 index 0000000..c83a6cf --- /dev/null +++ b/src/orcabridge/hashing/object_hashers.py @@ -0,0 +1,20 @@ +from .types import FunctionInfoExtractor +from .core import hash_object + + +class DefaultObjectHasher: + """ + Default object hasher that returns the string representation of the object. + """ + + def __init__(self, function_info_extractor: FunctionInfoExtractor | None = None): + """ + Initializes the hasher with an optional function info extractor. + + Args: + function_info_extractor (FunctionInfoExtractor | None): Optional extractor for function information. This must be provided if an object containing function information is to be hashed. 
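The two extractor strategies introduced above differ only in how much of the function is folded into its identity; a rough sketch of what each returns (illustrative, not part of the patch):

    from orcabridge.hashing.function_info_extractors import FunctionInfoExtractorFactory

    def add_one(x: int) -> int:
        return x + 1

    # "name" strategy: only the function's __name__ contributes
    name_extractor = FunctionInfoExtractorFactory.create_function_info_extractor("name")
    print(name_extractor.extract_function_info(add_one))  # {'name': 'add_one'}

    # "signature" strategy: the repr of the code object is used instead
    sig_extractor = FunctionInfoExtractorFactory.create_function_info_extractor("signature")
    print(sig_extractor.extract_function_info(add_one))   # {'signature': '<code object add_one at 0x...>'}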
+ """ + self.function_info_extractor = function_info_extractor + + def hash_to_hex(self, obj: Any): + pass diff --git a/src/orcabridge/hashing/types.py b/src/orcabridge/hashing/types.py index 6dda6c0..a3d8b85 100644 --- a/src/orcabridge/hashing/types.py +++ b/src/orcabridge/hashing/types.py @@ -1,6 +1,7 @@ """Hash strategy protocols for dependency injection.""" from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any, Protocol, runtime_checkable from uuid import UUID @@ -29,7 +30,19 @@ class ObjectHasher(ABC): @abstractmethod def hash_to_hex(self, obj: Any, char_count: int | None = 32) -> str: ... - def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: ... + def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: + """ + Hash an object to an integer. + + Args: + obj (Any): The object to hash. + hexdigits (int): Number of hexadecimal digits to use for the hash. + + Returns: + int: The integer representation of the hash. + """ + hex_hash = self.hash_to_hex(obj, char_count=hexdigits // 2) + return int(hex_hash, 16) def hash_to_uuid(self, obj: Any) -> UUID: ... @@ -71,3 +84,11 @@ class CompositeFileHasher(FileHasher, PathSetHasher, PacketHasher, Protocol): """Combined interface for all file-related hashing operations.""" pass + + +# Function hasher protocol +@runtime_checkable +class FunctionInfoExtractor(Protocol): + """Protocol for extracting function information.""" + + def extract_function_info(self, func: Callable[..., Any]) -> dict[str, Any]: ... From 2f2f7e106c72d6249ede597362c8285a354de7fd Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 13 Jun 2025 17:55:34 +0000 Subject: [PATCH 07/28] fix: use proper baseclass for composite file hasher --- src/orcabridge/hashing/file_hashers.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/orcabridge/hashing/file_hashers.py b/src/orcabridge/hashing/file_hashers.py index bf3365a..5c66581 100644 --- a/src/orcabridge/hashing/file_hashers.py +++ b/src/orcabridge/hashing/file_hashers.py @@ -1,8 +1,9 @@ -from orcabridge.hashing.core import hash_file, hash_packet, hash_pathset +from orcabridge.hashing.core import hash_file, hash_pathset, hash_packet from orcabridge.hashing.types import ( FileHasher, PathSetHasher, StringCacher, + CompositeFileHasher, ) from orcabridge.types import Packet, PathLike, PathSet @@ -93,7 +94,7 @@ def hash_packet(self, packet: Packet) -> str: # Convenience composite implementation -class CompositeHasher: +class DefaultCompositeFileHasher: """Composite hasher that implements all interfaces.""" def __init__( @@ -103,7 +104,7 @@ def __init__( packet_prefix: str = "", ): self.file_hasher = file_hasher - self.pathset_hasher = DefaultPathsetHasher(file_hasher, char_count) + self.pathset_hasher = DefaultPathsetHasher(self.file_hasher, char_count) self.packet_hasher = DefaultPacketHasher( self.pathset_hasher, char_count, packet_prefix ) @@ -119,7 +120,7 @@ def hash_packet(self, packet: Packet) -> str: # Factory for easy construction -class HasherFactory: +class PathLikeHasherFactory: """Factory for creating various hasher combinations.""" @staticmethod @@ -127,11 +128,13 @@ def create_basic_composite( algorithm: str = "sha256", buffer_size: int = 65536, char_count: int | None = 32, - ) -> CompositeHasher: + ) -> CompositeFileHasher: """Create a basic composite hasher.""" file_hasher = BasicFileHasher(algorithm, buffer_size) # use algorithm as the prefix for the packet hasher - return CompositeHasher(file_hasher, char_count, 
packet_prefix=algorithm) + return DefaultCompositeFileHasher( + file_hasher, char_count, packet_prefix=algorithm + ) @staticmethod def create_cached_composite( @@ -139,11 +142,13 @@ def create_cached_composite( algorithm: str = "sha256", buffer_size: int = 65536, char_count: int | None = 32, - ) -> CompositeHasher: + ) -> CompositeFileHasher: """Create a composite hasher with file caching.""" basic_file_hasher = BasicFileHasher(algorithm, buffer_size) cached_file_hasher = CachedFileHasher(basic_file_hasher, string_cacher) - return CompositeHasher(cached_file_hasher, char_count, packet_prefix=algorithm) + return DefaultCompositeFileHasher( + cached_file_hasher, char_count, packet_prefix=algorithm + ) @staticmethod def create_file_hasher( From a1ec3a7c909efc79a897871b93133a50bf869ecc Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 13 Jun 2025 17:58:59 +0000 Subject: [PATCH 08/28] chore: update reference to default file hasher --- src/orcabridge/store/core.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/orcabridge/store/core.py b/src/orcabridge/store/core.py index 89fe85e..428819a 100644 --- a/src/orcabridge/store/core.py +++ b/src/orcabridge/store/core.py @@ -5,7 +5,7 @@ from pathlib import Path from orcabridge.hashing import hash_packet -from orcabridge.hashing.defaults import get_default_composite_hasher +from orcabridge.hashing.defaults import get_default_composite_file_hasher from orcabridge.hashing.types import PacketHasher from orcabridge.store.types import DataStore from orcabridge.types import Packet @@ -62,7 +62,7 @@ def __init__( self.overwrite = overwrite self.supplement_source = supplement_source if packet_hasher is None and not legacy_mode: - packet_hasher = get_default_composite_hasher(with_cache=True) + packet_hasher = get_default_composite_file_hasher(with_cache=True) self.packet_hasher = packet_hasher self.legacy_mode = legacy_mode self.legacy_algorithm = legacy_algorithm @@ -77,7 +77,7 @@ def memoize( if self.legacy_mode: packet_hash = hash_packet(packet, algorithm=self.legacy_algorithm) else: - packet_hash = self.packet_hasher.hash_packet(packet) + packet_hash = self.packet_hasher.hash_packet(packet) # type: ignore[no-untyped-call] output_dir = self.store_dir / function_name / function_hash / str(packet_hash) info_path = output_dir / "_info.json" source_path = output_dir / "_source.json" @@ -144,6 +144,9 @@ def retrieve_memoized( if self.legacy_mode: packet_hash = hash_packet(packet, algorithm=self.legacy_algorithm) else: + assert self.packet_hasher is not None, ( + "Packer hasher should be configured if not in legacy mode" + ) packet_hash = self.packet_hasher.hash_packet(packet) output_dir = self.store_dir / function_name / function_hash / str(packet_hash) info_path = output_dir / "_info.json" From afe2efc823247ad0897e3d2b830d98909e3c8726 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
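The renamed PathLikeHasherFactory keeps the same two construction paths; downstream code such as DirDataStore is expected to obtain a composite hasher roughly like this (sketch only; the packet path is hypothetical and must exist on disk for hashing to succeed):

    from orcabridge.hashing.file_hashers import PathLikeHasherFactory
    from orcabridge.hashing.string_cachers import InMemoryCacher

    # Uncached composite: every hash_file / hash_pathset / hash_packet call recomputes
    basic_hasher = PathLikeHasherFactory.create_basic_composite(algorithm="sha256")

    # Cached composite: file-level digests are memoized in the supplied cacher
    cached_hasher = PathLikeHasherFactory.create_cached_composite(
        InMemoryCacher(max_size=None), algorithm="sha256"
    )

    packet = {"input": "data/file1.txt"}  # hypothetical packet of pathlike values
    digest = cached_hasher.hash_packet(packet)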
Walker" Date: Fri, 13 Jun 2025 17:59:47 +0000 Subject: [PATCH 09/28] refactor: restructure legacy hashing to use common byte hasher with function info extractor support --- src/orcabridge/hashing/core.py | 104 ++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 27 deletions(-) diff --git a/src/orcabridge/hashing/core.py b/src/orcabridge/hashing/core.py index dd4f6a5..dcd7f79 100644 --- a/src/orcabridge/hashing/core.py +++ b/src/orcabridge/hashing/core.py @@ -11,6 +11,7 @@ import json import logging import zlib +from .types import FunctionInfoExtractor from functools import partial from os import PathLike from pathlib import Path @@ -281,20 +282,11 @@ def __hash__(self) -> int: # Core hashing functions that serve as the unified interface -def hash_to_hex(obj: Any, char_count: int | None = 32) -> str: - """ - Create a stable hex hash of any object that remains consistent across Python sessions. - - Args: - obj: The object to hash - can be a primitive type, nested data structure, or - HashableMixin instance - char_count: Number of hex characters to return (None for full hash) - - Returns: - A hex string hash - """ +def legacy_hash( + obj: Any, function_info_extractor: FunctionInfoExtractor | None = None +) -> bytes: # Process the object to handle nested structures and HashableMixin instances - processed = process_structure(obj) + processed = process_structure(obj, function_info_extractor=function_info_extractor) # Serialize the processed structure try: @@ -322,16 +314,45 @@ def hash_to_hex(obj: Any, char_count: int | None = 32) -> str: json_str = str(processed).encode("utf-8") # Create the hash - hash_hex = hashlib.sha256(json_str).hexdigest() + return hashlib.sha256(json_str).digest() + + +def hash_to_hex( + obj: Any, + char_count: int | None = 32, + function_info_extractor: FunctionInfoExtractor | None = None, +) -> str: + """ + Create a stable hex hash of any object that remains consistent across Python sessions. + + Args: + obj: The object to hash - can be a primitive type, nested data structure, or + HashableMixin instance + char_count: Number of hex characters to return (None for full hash) + + Returns: + A hex string hash + """ + + # Create the hash + hash_hex = legacy_hash(obj, function_info_extractor=function_info_extractor).hex() # Return the requested number of characters if char_count is not None: logger.debug(f"Using char_count: {char_count}") + if char_count > len(hash_hex): + raise ValueError( + f"Cannot truncate to {char_count} chars, hash only has {len(hash_hex)}" + ) return hash_hex[:char_count] return hash_hex -def hash_to_int(obj: Any, hexdigits: int = 16) -> int: +def hash_to_int( + obj: Any, + hexdigits: int = 16, + function_info_extractor: FunctionInfoExtractor | None = None, +) -> int: """ Convert any object to a stable integer hash that remains consistent across Python sessions. @@ -342,11 +363,15 @@ def hash_to_int(obj: Any, hexdigits: int = 16) -> int: Returns: An integer hash """ - hash_hex = hash_to_hex(obj) - return int(hash_hex[:hexdigits], 16) + hash_hex = hash_to_hex( + obj, char_count=hexdigits, function_info_extractor=function_info_extractor + ) + return int(hash_hex, 16) -def hash_to_uuid(obj: Any) -> UUID: +def hash_to_uuid( + obj: Any, function_info_extractor: FunctionInfoExtractor | None = None +) -> UUID: """ Convert any object to a stable UUID hash that remains consistent across Python sessions. 
@@ -356,12 +381,19 @@ def hash_to_uuid(obj: Any) -> UUID: Returns: A UUID hash """ - hash_hex = hash_to_hex(obj, char_count=32) + hash_hex = hash_to_hex( + obj, char_count=32, function_info_extractor=function_info_extractor + ) + # TODO: update this to use UUID5 with a namespace on hash bytes output instead return UUID(hash_hex) # Helper function for processing nested structures -def process_structure(obj: Any, visited: Optional[Set[int]] = None) -> Any: +def process_structure( + obj: Any, + visited: Optional[Set[int]] = None, + function_info_extractor: FunctionInfoExtractor | None = None, +) -> Any: """ Recursively process a structure to prepare it for hashing. @@ -427,13 +459,16 @@ def process_structure(obj: Any, visited: Optional[Set[int]] = None) -> Any: logger.debug(f"Processing named tuple of type {type(obj).__name__}") # For namedtuples, convert to dict and then process d = {field: getattr(obj, field) for field in obj._fields} # type: ignore - return process_structure(d, visited) + return process_structure(d, visited, function_info_extractor) # Handle mappings (dict-like objects) if isinstance(obj, Mapping): # Process both keys and values processed_items = [ - (process_structure(k, visited), process_structure(v, visited)) + ( + process_structure(k, visited, function_info_extractor), + process_structure(v, visited, function_info_extractor), + ) for k, v in obj.items() ] @@ -452,7 +487,9 @@ def process_structure(obj: Any, visited: Optional[Set[int]] = None) -> Any: f"Processing set/frozenset of type {type(obj).__name__} with {len(obj)} items" ) # Process each item first, then sort the processed results - processed_items = [process_structure(item, visited) for item in obj] + processed_items = [ + process_structure(item, visited, function_info_extractor) for item in obj + ] return sorted(processed_items, key=str) # Handle collections (list-like objects) @@ -460,12 +497,23 @@ def process_structure(obj: Any, visited: Optional[Set[int]] = None) -> Any: logger.debug( f"Processing collection of type {type(obj).__name__} with {len(obj)} items" ) - return [process_structure(item, visited) for item in obj] + return [ + process_structure(item, visited, function_info_extractor) for item in obj + ] # For functions, use the function_content_hash if callable(obj) and hasattr(obj, "__code__"): logger.debug(f"Processing function: {obj.__name__}") - return function_content_hash(obj) + if function_info_extractor is not None: + # Use the extractor to get a stable representation + function_info = function_info_extractor.extract_function_info(obj) + logger.debug(f"Extracted function info: {function_info} for {obj.__name__}") + + # simply return the function info as a stable representation + return function_info + else: + # Default to using legacy function content hash + return function_content_hash(obj) # For other objects, create a deterministic representation try: @@ -605,7 +653,8 @@ def hash_packet_with_psh( """ hash_results = {} for key, pathset in packet.items(): - hash_results[key] = algo.hash_pathset(pathset) + # TODO: fix pathset handling + hash_results[key] = algo.hash_pathset(pathset) # type: ignore packet_hash = hash_to_hex(hash_results) @@ -643,7 +692,8 @@ def hash_packet( hash_results = {} for key, pathset in packet.items(): - hash_results[key] = pathset_hasher(pathset) + # TODO: fix Pathset handling + hash_results[key] = pathset_hasher(pathset) # type: ignore packet_hash = hash_to_hex(hash_results, char_count=char_count) From 1fb28e2b99415885a9975b3d3f2bbc2cc86dbee1 Mon Sep 17 00:00:00 2001 
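After this restructuring, truncation and integer conversion are thin layers over the shared legacy_hash byte digest; the intended relationship between the helpers is roughly (sketch, assuming the module layout in this patch):

    from orcabridge.hashing.core import hash_to_hex, hash_to_int

    obj = {"a": 1, "b": [2, 3]}

    full_hex = hash_to_hex(obj, char_count=None)  # full sha256 hex digest
    short_hex = hash_to_hex(obj, char_count=16)   # first 16 hex characters

    assert full_hex.startswith(short_hex)
    # hash_to_int now truncates first and then converts, so the two agree:
    assert hash_to_int(obj, hexdigits=16) == int(short_hex, 16)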
From: "Edgar Y. Walker" Date: Fri, 13 Jun 2025 18:00:30 +0000 Subject: [PATCH 10/28] feat: add legacy object hasher with default object hasher --- src/orcabridge/hashing/__init__.py | 13 +++++++-- src/orcabridge/hashing/defaults.py | 27 ++++++++++++++---- src/orcabridge/hashing/object_hashers.py | 32 ++++++++++++++++----- src/orcabridge/hashing/types.py | 36 +++++++++++++++++++++--- 4 files changed, 88 insertions(+), 20 deletions(-) diff --git a/src/orcabridge/hashing/__init__.py b/src/orcabridge/hashing/__init__.py index 2b1c5a4..e3e2445 100644 --- a/src/orcabridge/hashing/__init__.py +++ b/src/orcabridge/hashing/__init__.py @@ -10,13 +10,19 @@ hash_to_int, hash_to_uuid, ) -from .defaults import get_default_composite_hasher -from .types import FileHasher, ObjectHasher, StringCacher +from .defaults import get_default_composite_file_hasher, get_default_object_hasher +from .types import ( + FileHasher, + ObjectHasher, + StringCacher, + CompositeFileHasher, +) __all__ = [ "FileHasher", "StringCacher", "ObjectHasher", + "CompositeFileHasher", "hash_file", "hash_pathset", "hash_packet", @@ -27,5 +33,6 @@ "get_function_signature", "function_content_hash", "HashableMixin", - "get_default_composite_hasher", + "get_default_composite_file_hasher", + "get_default_object_hasher", ] diff --git a/src/orcabridge/hashing/defaults.py b/src/orcabridge/hashing/defaults.py index 2f65a7d..3faca77 100644 --- a/src/orcabridge/hashing/defaults.py +++ b/src/orcabridge/hashing/defaults.py @@ -1,18 +1,33 @@ # A collection of utility function that provides a "default" implementation of hashers. # This is often used as the fallback hasher in the library code. -from orcabridge.hashing.file_hashers import CompositeHasher, HasherFactory +from orcabridge.hashing.types import CompositeFileHasher +from orcabridge.hashing.file_hashers import PathLikeHasherFactory from orcabridge.hashing.string_cachers import InMemoryCacher +from orcabridge.hashing.object_hashers import ObjectHasher +from orcabridge.hashing.object_hashers import LegacyObjectHasher +from orcabridge.hashing.function_info_extractors import FunctionInfoExtractorFactory -def get_default_composite_hasher(with_cache=True) -> CompositeHasher: +def get_default_composite_file_hasher(with_cache=True) -> CompositeFileHasher: if with_cache: # use unlimited caching string_cacher = InMemoryCacher(max_size=None) - return HasherFactory.create_cached_composite(string_cacher) - return HasherFactory.create_basic_composite() + return PathLikeHasherFactory.create_cached_composite(string_cacher) + return PathLikeHasherFactory.create_basic_composite() -def get_default_composite_hasher_with_cacher(cacher=None) -> CompositeHasher: +def get_default_composite_file_hasher_with_cacher(cacher=None) -> CompositeFileHasher: if cacher is None: cacher = InMemoryCacher(max_size=None) - return HasherFactory.create_cached_composite(cacher) + return PathLikeHasherFactory.create_cached_composite(cacher) + + +def get_default_object_hasher() -> ObjectHasher: + function_info_extractor = ( + FunctionInfoExtractorFactory.create_function_info_extractor( + strategy="signature" + ) + ) + return LegacyObjectHasher( + char_count=32, function_info_extractor=function_info_extractor + ) diff --git a/src/orcabridge/hashing/object_hashers.py b/src/orcabridge/hashing/object_hashers.py index c83a6cf..a3f4b39 100644 --- a/src/orcabridge/hashing/object_hashers.py +++ b/src/orcabridge/hashing/object_hashers.py @@ -1,20 +1,38 @@ -from .types import FunctionInfoExtractor -from .core import hash_object +from .types 
import FunctionInfoExtractor, ObjectHasher +from .core import legacy_hash -class DefaultObjectHasher: +class LegacyObjectHasher(ObjectHasher): """ - Default object hasher that returns the string representation of the object. + Legacy object hasher that returns the string representation of the object. + + Note that this is "legacy" in the sense that it is not recommended for use in new code. + It is provided for compatibility with existing code that relies on this behavior. + Namely, this algorithm makes use of the """ - def __init__(self, function_info_extractor: FunctionInfoExtractor | None = None): + def __init__( + self, + char_count: int | None = 32, + function_info_extractor: FunctionInfoExtractor | None = None, + ): """ Initializes the hasher with an optional function info extractor. Args: function_info_extractor (FunctionInfoExtractor | None): Optional extractor for function information. This must be provided if an object containing function information is to be hashed. """ + self.char_count = char_count self.function_info_extractor = function_info_extractor - def hash_to_hex(self, obj: Any): - pass + def hash(self, obj: object) -> bytes: + """ + Hash an object to a byte representation. + + Args: + obj (object): The object to hash. + + Returns: + bytes: The byte representation of the hash. + """ + return legacy_hash(obj, function_info_extractor=self.function_info_extractor) diff --git a/src/orcabridge/hashing/types.py b/src/orcabridge/hashing/types.py index a3d8b85..8880a23 100644 --- a/src/orcabridge/hashing/types.py +++ b/src/orcabridge/hashing/types.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any, Protocol, runtime_checkable -from uuid import UUID +import uuid from orcabridge.types import Packet, PathLike, PathSet @@ -28,7 +28,30 @@ class ObjectHasher(ABC): """Abstract class for general object hashing.""" @abstractmethod - def hash_to_hex(self, obj: Any, char_count: int | None = 32) -> str: ... + def hash(self, obj: Any) -> bytes: + """ + Hash an object to a byte representation. + + Args: + obj (Any): The object to hash. + + Returns: + bytes: The byte representation of the hash. + """ + ... + + def hash_to_hex(self, obj: Any, char_count: int | None = None) -> str: + hash_bytes = self.hash(obj) + hex_str = hash_bytes.hex() + + # TODO: clean up this logic, as char_count handling is messy + if char_count is not None: + if char_count > len(hex_str): + raise ValueError( + f"Cannot truncate to {char_count} chars, hash only has {len(hex_str)}" + ) + return hex_str[:char_count] + return hex_str def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: """ @@ -41,10 +64,15 @@ def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: Returns: int: The integer representation of the hash. """ - hex_hash = self.hash_to_hex(obj, char_count=hexdigits // 2) + hex_hash = self.hash_to_hex(obj, char_count=hexdigits) return int(hex_hash, 16) - def hash_to_uuid(self, obj: Any) -> UUID: ... + def hash_to_uuid( + self, obj: Any, namespace: uuid.UUID = uuid.NAMESPACE_OID + ) -> uuid.UUID: + """Convert hash to proper UUID5.""" + # Use the hex representation as input to UUID5 + return uuid.uuid5(namespace, self.hash(obj)) @runtime_checkable From 2867ac0847a68c8ca5eb25035d2d468609b03c91 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
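With hash() as the only abstract method, every ObjectHasher subclass now gets the hex and int helpers for free; a minimal illustration with a toy subclass (not part of the patch):

    import hashlib
    from typing import Any

    from orcabridge.hashing.types import ObjectHasher


    class ReprObjectHasher(ObjectHasher):
        """Toy hasher that simply hashes the repr of the object."""

        def hash(self, obj: Any) -> bytes:
            return hashlib.sha256(repr(obj).encode("utf-8")).digest()


    hasher = ReprObjectHasher()
    hex_id = hasher.hash_to_hex({"x": 1}, char_count=16)
    # hash_to_int truncates to `hexdigits` hex characters before converting:
    assert hasher.hash_to_int({"x": 1}, hexdigits=16) == int(hex_id, 16)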
Walker" Date: Fri, 13 Jun 2025 18:35:11 +0000 Subject: [PATCH 11/28] test: fix broken tests --- .../test_basic_composite_hasher.py | 20 ++-- tests/test_hashing/test_composite_hasher.py | 92 +++---------------- tests/test_hashing/test_hasher_factory.py | 44 +++++---- tests/test_hashing/test_hasher_parity.py | 10 +- tests/test_store/test_dir_data_store.py | 7 +- tests/test_store/test_integration.py | 12 +-- .../test_mappers/test_merge.py | 4 +- .../test_mappers/test_transform.py | 5 +- .../test_pipelines/test_basic_pipelines.py | 4 +- .../test_pipelines/test_recursive_features.py | 4 +- .../test_pods/test_function_pod.py | 2 +- .../test_pods/test_pod_base.py | 2 +- .../test_sync_stream_implementations.py | 2 +- 13 files changed, 72 insertions(+), 136 deletions(-) diff --git a/tests/test_hashing/test_basic_composite_hasher.py b/tests/test_hashing/test_basic_composite_hasher.py index 798f79d..fc82402 100644 --- a/tests/test_hashing/test_basic_composite_hasher.py +++ b/tests/test_hashing/test_basic_composite_hasher.py @@ -13,7 +13,7 @@ import pytest -from orcabridge.hashing.file_hashers import HasherFactory +from orcabridge.hashing.file_hashers import PathLikeHasherFactory def load_hash_lut(): @@ -83,7 +83,7 @@ def verify_path_exists(rel_path): def test_default_file_hasher_file_hash_consistency(): """Test that DefaultFileHasher.hash_file produces consistent results for the sample files.""" hash_lut = load_hash_lut() - hasher = HasherFactory.create_basic_composite() + hasher = PathLikeHasherFactory.create_basic_composite() for filename, info in hash_lut.items(): rel_path = info["file"] @@ -105,7 +105,7 @@ def test_default_file_hasher_file_hash_consistency(): def test_default_file_hasher_pathset_hash_consistency(): """Test that DefaultFileHasher.hash_pathset produces consistent results for the sample pathsets.""" hash_lut = load_pathset_hash_lut() - hasher = HasherFactory.create_basic_composite() + hasher = PathLikeHasherFactory.create_basic_composite() for name, info in hash_lut.items(): paths_rel = info["paths"] @@ -138,7 +138,7 @@ def test_default_file_hasher_pathset_hash_consistency(): def test_default_file_hasher_packet_hash_consistency(): """Test that DefaultFileHasher.hash_packet produces consistent results for the sample packets.""" hash_lut = load_packet_hash_lut() - hasher = HasherFactory.create_basic_composite() + hasher = PathLikeHasherFactory.create_basic_composite() for name, info in hash_lut.items(): structure = info["structure"] @@ -182,7 +182,7 @@ def test_default_file_hasher_file_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = HasherFactory.create_basic_composite(algorithm=algorithm) + hasher = PathLikeHasherFactory.create_basic_composite(algorithm=algorithm) hash1 = hasher.hash_file(file_path) hash2 = hasher.hash_file(file_path) assert hash1 == hash2, f"Hash inconsistent for algorithm {algorithm}" @@ -194,7 +194,7 @@ def test_default_file_hasher_file_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = HasherFactory.create_basic_composite(buffer_size=buffer_size) + hasher = PathLikeHasherFactory.create_basic_composite(buffer_size=buffer_size) hash1 = hasher.hash_file(file_path) hash2 = hasher.hash_file(file_path) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" @@ -223,7 +223,7 @@ def test_default_file_hasher_pathset_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = HasherFactory.create_basic_composite(algorithm=algorithm) + hasher = 
PathLikeHasherFactory.create_basic_composite(algorithm=algorithm) hash1 = hasher.hash_pathset(pathset) hash2 = hasher.hash_pathset(pathset) assert hash1 == hash2, f"Hash inconsistent for algorithm {algorithm}" @@ -235,7 +235,7 @@ def test_default_file_hasher_pathset_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = HasherFactory.create_basic_composite(buffer_size=buffer_size) + hasher = PathLikeHasherFactory.create_basic_composite(buffer_size=buffer_size) hash1 = hasher.hash_pathset(pathset) hash2 = hasher.hash_pathset(pathset) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" @@ -267,7 +267,7 @@ def test_default_file_hasher_packet_hash_algorithm_parameters(): for algorithm in algorithms: try: - hasher = HasherFactory.create_basic_composite(algorithm=algorithm) + hasher = PathLikeHasherFactory.create_basic_composite(algorithm=algorithm) hash1 = hasher.hash_packet(packet) hash2 = hasher.hash_packet(packet) @@ -286,7 +286,7 @@ def test_default_file_hasher_packet_hash_algorithm_parameters(): buffer_sizes = [1024, 4096, 16384, 65536] for buffer_size in buffer_sizes: - hasher = HasherFactory.create_basic_composite(buffer_size=buffer_size) + hasher = PathLikeHasherFactory.create_basic_composite(buffer_size=buffer_size) hash1 = hasher.hash_packet(packet) hash2 = hasher.hash_packet(packet) assert hash1 == hash2, f"Hash inconsistent for buffer size {buffer_size}" diff --git a/tests/test_hashing/test_composite_hasher.py b/tests/test_hashing/test_composite_hasher.py index 7ca2c25..d3aa278 100644 --- a/tests/test_hashing/test_composite_hasher.py +++ b/tests/test_hashing/test_composite_hasher.py @@ -1,13 +1,13 @@ #!/usr/bin/env python # filepath: /home/eywalker/workspace/orcabridge/tests/test_hashing/test_composite_hasher.py -"""Tests for the CompositeHasher implementation.""" +"""Tests for the CompositeFileHasher implementation.""" from unittest.mock import patch import pytest from orcabridge.hashing.core import hash_to_hex -from orcabridge.hashing.file_hashers import BasicFileHasher, CompositeHasher +from orcabridge.hashing.file_hashers import BasicFileHasher, DefaultCompositeFileHasher from orcabridge.hashing.types import FileHasher, PacketHasher, PathSetHasher @@ -97,13 +97,13 @@ def patch_hash_functions(): yield -def test_composite_hasher_implements_all_protocols(): - """Test that CompositeHasher implements all three protocols.""" +def test_default_composite_hasher_implements_all_protocols(): + """Test that CompositeFileHasher implements all three protocols.""" # Create a basic file hasher to be used within the composite hasher file_hasher = BasicFileHasher() # Create the composite hasher - composite_hasher = CompositeHasher(file_hasher) + composite_hasher = DefaultCompositeFileHasher(file_hasher) # Verify it implements all three protocols assert isinstance(composite_hasher, FileHasher) @@ -111,8 +111,8 @@ def test_composite_hasher_implements_all_protocols(): assert isinstance(composite_hasher, PacketHasher) -def test_composite_hasher_file_hashing(): - """Test CompositeHasher's file hashing functionality.""" +def test_default_composite_hasher_file_hashing(): + """Test CompositeFileHasher's file hashing functionality.""" # We can use a mock path since our mocks don't require real files file_path = "/path/to/mock_file.txt" @@ -122,7 +122,7 @@ def hash_file(self, file_path): return mock_hash_file(file_path) file_hasher = MockFileHasher() - composite_hasher = CompositeHasher(file_hasher) + composite_hasher = 
DefaultCompositeFileHasher(file_hasher) # Get hash from the composite hasher and directly from the file hasher direct_hash = file_hasher.hash_file(file_path) @@ -132,8 +132,8 @@ def hash_file(self, file_path): assert direct_hash == composite_hash -def test_composite_hasher_pathset_hashing(): - """Test CompositeHasher's path set hashing functionality.""" +def test_default_composite_hasher_pathset_hashing(): + """Test CompositeFileHasher's path set hashing functionality.""" # Create a custom mock file hasher that doesn't check for file existence class MockFileHasher: @@ -141,7 +141,7 @@ def hash_file(self, file_path): return mock_hash_file(file_path) file_hasher = MockFileHasher() - composite_hasher = CompositeHasher(file_hasher) + composite_hasher = DefaultCompositeFileHasher(file_hasher) # Simple path set with non-existent paths pathset = ["/path/to/file1.txt", "/path/to/file2.txt"] @@ -153,75 +153,5 @@ def hash_file(self, file_path): assert isinstance(result, str) -def test_composite_hasher_packet_hashing(): - """Test CompositeHasher's packet hashing functionality.""" - - # Create a completely custom composite hasher that doesn't rely on real functions - class MockHasher: - def hash_file(self, file_path): - return mock_hash_file(file_path) - - def hash_pathset(self, pathset): - return hash_to_hex(f"pathset_{pathset}") - - def hash_packet(self, packet): - return hash_to_hex(f"packet_{packet}") - - mock_hasher = MockHasher() - # Use mock_hasher directly as both the file_hasher and as the composite_hasher - # This way we're not calling into any code that checks file existence - - # Simple packet with non-existent paths - packet = { - "input": ["/path/to/input1.txt", "/path/to/input2.txt"], - "output": "/path/to/output.txt", - } - - # Hash the packet using our mock - result = mock_hasher.hash_packet(packet) - - # The result should be a string hash - assert isinstance(result, str) - - -def test_composite_hasher_with_char_count(): - """Test CompositeHasher with different char_count values.""" - - # Create completely mocked hashers that don't check file existence - class MockHasher: - def __init__(self, char_count=32): - self.char_count = char_count - - def hash_file(self, file_path): - return mock_hash_file(file_path) - - def hash_pathset(self, pathset): - return hash_to_hex(f"pathset_{pathset}", char_count=self.char_count) - - def hash_packet(self, packet): - return hash_to_hex(f"packet_{packet}", char_count=self.char_count) - - # Create two mock hashers with different char_counts - default_hasher = MockHasher() - custom_hasher = MockHasher(char_count=16) - - # Simple test data - pathset = ["/path/to/file1.txt", "/path/to/file2.txt"] - packet = {"input": pathset} - - # Get hashes with different char_counts - default_pathset_hash = default_hasher.hash_pathset(pathset) - custom_pathset_hash = custom_hasher.hash_pathset(pathset) - - default_packet_hash = default_hasher.hash_packet(packet) - custom_packet_hash = custom_hasher.hash_packet(packet) - - # Verify all results are strings - assert isinstance(default_pathset_hash, str) - assert isinstance(custom_pathset_hash, str) - assert isinstance(default_packet_hash, str) - assert isinstance(custom_packet_hash, str) - - if __name__ == "__main__": pytest.main(["-v", __file__]) diff --git a/tests/test_hashing/test_hasher_factory.py b/tests/test_hashing/test_hasher_factory.py index 81631ab..6e80827 100644 --- a/tests/test_hashing/test_hasher_factory.py +++ b/tests/test_hashing/test_hasher_factory.py @@ -7,17 +7,17 @@ from orcabridge.hashing.file_hashers 
import ( BasicFileHasher, CachedFileHasher, - HasherFactory, + PathLikeHasherFactory, ) from orcabridge.hashing.string_cachers import FileCacher, InMemoryCacher -class TestHasherFactoryCreateFileHasher: - """Test cases for HasherFactory.create_file_hasher method.""" +class TestPathLikeHasherFactoryCreateFileHasher: + """Test cases for PathLikeHasherFactory.create_file_hasher method.""" def test_create_file_hasher_without_cacher(self): """Test creating a file hasher without string cacher (returns BasicFileHasher).""" - hasher = HasherFactory.create_file_hasher() + hasher = PathLikeHasherFactory.create_file_hasher() # Should return BasicFileHasher assert isinstance(hasher, BasicFileHasher) @@ -30,7 +30,7 @@ def test_create_file_hasher_without_cacher(self): def test_create_file_hasher_with_cacher(self): """Test creating a file hasher with string cacher (returns CachedFileHasher).""" cacher = InMemoryCacher() - hasher = HasherFactory.create_file_hasher(string_cacher=cacher) + hasher = PathLikeHasherFactory.create_file_hasher(string_cacher=cacher) # Should return CachedFileHasher assert isinstance(hasher, CachedFileHasher) @@ -44,14 +44,14 @@ def test_create_file_hasher_with_cacher(self): def test_create_file_hasher_custom_algorithm(self): """Test creating file hasher with custom algorithm.""" # Without cacher - hasher = HasherFactory.create_file_hasher(algorithm="md5") + hasher = PathLikeHasherFactory.create_file_hasher(algorithm="md5") assert isinstance(hasher, BasicFileHasher) assert hasher.algorithm == "md5" assert hasher.buffer_size == 65536 # With cacher cacher = InMemoryCacher() - hasher = HasherFactory.create_file_hasher( + hasher = PathLikeHasherFactory.create_file_hasher( string_cacher=cacher, algorithm="sha512" ) assert isinstance(hasher, CachedFileHasher) @@ -61,14 +61,14 @@ def test_create_file_hasher_custom_algorithm(self): def test_create_file_hasher_custom_buffer_size(self): """Test creating file hasher with custom buffer size.""" # Without cacher - hasher = HasherFactory.create_file_hasher(buffer_size=32768) + hasher = PathLikeHasherFactory.create_file_hasher(buffer_size=32768) assert isinstance(hasher, BasicFileHasher) assert hasher.algorithm == "sha256" assert hasher.buffer_size == 32768 # With cacher cacher = InMemoryCacher() - hasher = HasherFactory.create_file_hasher( + hasher = PathLikeHasherFactory.create_file_hasher( string_cacher=cacher, buffer_size=8192 ) assert isinstance(hasher, CachedFileHasher) @@ -78,7 +78,7 @@ def test_create_file_hasher_custom_buffer_size(self): def test_create_file_hasher_all_custom_parameters(self): """Test creating file hasher with all custom parameters.""" cacher = InMemoryCacher(max_size=500) - hasher = HasherFactory.create_file_hasher( + hasher = PathLikeHasherFactory.create_file_hasher( string_cacher=cacher, algorithm="blake2b", buffer_size=16384 ) @@ -91,14 +91,16 @@ def test_create_file_hasher_different_cacher_types(self): """Test creating file hasher with different types of string cachers.""" # InMemoryCacher memory_cacher = InMemoryCacher() - hasher1 = HasherFactory.create_file_hasher(string_cacher=memory_cacher) + hasher1 = PathLikeHasherFactory.create_file_hasher(string_cacher=memory_cacher) assert isinstance(hasher1, CachedFileHasher) assert hasher1.string_cacher is memory_cacher # FileCacher with tempfile.NamedTemporaryFile(delete=False) as tmp_file: file_cacher = FileCacher(tmp_file.name) - hasher2 = HasherFactory.create_file_hasher(string_cacher=file_cacher) + hasher2 = PathLikeHasherFactory.create_file_hasher( + 
string_cacher=file_cacher + ) assert isinstance(hasher2, CachedFileHasher) assert hasher2.string_cacher is file_cacher @@ -107,7 +109,9 @@ def test_create_file_hasher_different_cacher_types(self): def test_create_file_hasher_functional_without_cache(self): """Test that created file hasher actually works for hashing files.""" - hasher = HasherFactory.create_file_hasher(algorithm="sha256", buffer_size=1024) + hasher = PathLikeHasherFactory.create_file_hasher( + algorithm="sha256", buffer_size=1024 + ) # Create a temporary file to hash with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: @@ -132,7 +136,7 @@ def test_create_file_hasher_functional_without_cache(self): def test_create_file_hasher_functional_with_cache(self): """Test that created cached file hasher works and caches results.""" cacher = InMemoryCacher() - hasher = HasherFactory.create_file_hasher( + hasher = PathLikeHasherFactory.create_file_hasher( string_cacher=cacher, algorithm="sha256" ) @@ -160,7 +164,7 @@ def test_create_file_hasher_functional_with_cache(self): def test_create_file_hasher_none_cacher_explicit(self): """Test explicitly passing None for string_cacher.""" - hasher = HasherFactory.create_file_hasher( + hasher = PathLikeHasherFactory.create_file_hasher( string_cacher=None, algorithm="sha1", buffer_size=4096 ) @@ -172,26 +176,26 @@ def test_create_file_hasher_none_cacher_explicit(self): def test_create_file_hasher_parameter_edge_cases(self): """Test edge cases for parameters.""" # Very small buffer size - hasher1 = HasherFactory.create_file_hasher(buffer_size=1) + hasher1 = PathLikeHasherFactory.create_file_hasher(buffer_size=1) assert hasher1.buffer_size == 1 # Large buffer size - hasher2 = HasherFactory.create_file_hasher(buffer_size=1024 * 1024) + hasher2 = PathLikeHasherFactory.create_file_hasher(buffer_size=1024 * 1024) assert hasher2.buffer_size == 1024 * 1024 # Different algorithms for algorithm in ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]: - hasher = HasherFactory.create_file_hasher(algorithm=algorithm) + hasher = PathLikeHasherFactory.create_file_hasher(algorithm=algorithm) assert hasher.algorithm == algorithm def test_create_file_hasher_cache_independence(self): """Test that different cached hashers with same cacher are independent.""" cacher = InMemoryCacher() - hasher1 = HasherFactory.create_file_hasher( + hasher1 = PathLikeHasherFactory.create_file_hasher( string_cacher=cacher, algorithm="sha256" ) - hasher2 = HasherFactory.create_file_hasher( + hasher2 = PathLikeHasherFactory.create_file_hasher( string_cacher=cacher, algorithm="md5" ) diff --git a/tests/test_hashing/test_hasher_parity.py b/tests/test_hashing/test_hasher_parity.py index 0ec700e..36d0a65 100644 --- a/tests/test_hashing/test_hasher_parity.py +++ b/tests/test_hashing/test_hasher_parity.py @@ -15,7 +15,7 @@ import pytest from orcabridge.hashing.core import hash_file, hash_packet, hash_pathset -from orcabridge.hashing.file_hashers import HasherFactory +from orcabridge.hashing.file_hashers import PathLikeHasherFactory def load_hash_lut(): @@ -74,7 +74,7 @@ def verify_path_exists(rel_path): def test_hasher_core_parity_file_hash(): """Test that BasicFileHasher.hash_file produces the same results as hash_file.""" hash_lut = load_hash_lut() - hasher = HasherFactory.create_basic_composite() + hasher = PathLikeHasherFactory.create_basic_composite() # Test all sample files for filename, info in hash_lut.items(): @@ -103,7 +103,7 @@ def test_hasher_core_parity_file_hash(): for buffer_size in buffer_sizes: try: # 
Create a hasher with specific parameters - hasher = HasherFactory.create_basic_composite( + hasher = PathLikeHasherFactory.create_basic_composite( algorithm=algorithm, buffer_size=buffer_size ) @@ -148,7 +148,7 @@ def test_hasher_core_parity_pathset_hash(): for buffer_size in buffer_sizes: for char_count in char_counts: # Create a hasher with specific parameters - hasher = HasherFactory.create_basic_composite( + hasher = PathLikeHasherFactory.create_basic_composite( algorithm=algorithm, buffer_size=buffer_size, char_count=char_count, @@ -202,7 +202,7 @@ def test_hasher_core_parity_packet_hash(): for buffer_size in buffer_sizes: for char_count in char_counts: # Create a hasher with specific parameters - hasher = HasherFactory.create_basic_composite( + hasher = PathLikeHasherFactory.create_basic_composite( algorithm=algorithm, buffer_size=buffer_size, char_count=char_count, diff --git a/tests/test_store/test_dir_data_store.py b/tests/test_store/test_dir_data_store.py index 37e467c..8856436 100644 --- a/tests/test_store/test_dir_data_store.py +++ b/tests/test_store/test_dir_data_store.py @@ -503,6 +503,9 @@ def test_dir_data_store_legacy_mode_compatibility(temp_dir, sample_files): from orcabridge.hashing import hash_packet legacy_hash = hash_packet(packet, algorithm="sha256") + assert store_default.packet_hasher is not None, ( + "Default store should have a packet hasher" + ) default_hash = store_default.packet_hasher.hash_packet(packet) # The hashes should be identical since both implementations should produce the same result @@ -609,10 +612,10 @@ def test_dir_data_store_hash_equivalence(temp_dir, sample_files): # First compute hashes directly from orcabridge.hashing import hash_packet - from orcabridge.hashing.defaults import get_default_composite_hasher + from orcabridge.hashing.defaults import get_default_composite_file_hasher legacy_hash = hash_packet(packet, algorithm="sha256") - default_hasher = get_default_composite_hasher( + default_hasher = get_default_composite_file_hasher( with_cache=False ) # No caching for direct comparison default_hash = default_hasher.hash_packet(packet) diff --git a/tests/test_store/test_integration.py b/tests/test_store/test_integration.py index 22c67c9..8314362 100644 --- a/tests/test_store/test_integration.py +++ b/tests/test_store/test_integration.py @@ -10,14 +10,14 @@ from orcabridge.hashing.file_hashers import ( BasicFileHasher, CachedFileHasher, - CompositeHasher, + DefaultCompositeFileHasher, ) from orcabridge.hashing.string_cachers import InMemoryCacher from orcabridge.store.core import DirDataStore, NoOpDataStore def test_integration_with_cached_file_hasher(temp_dir, sample_files): - """Test integration of DirDataStore with CompositeHasher using CachedFileHasher.""" + """Test integration of DirDataStore with CompositeFileHasher using CachedFileHasher.""" store_dir = Path(temp_dir) / "test_store" # Create a CachedFileHasher with InMemoryCacher @@ -28,10 +28,10 @@ def test_integration_with_cached_file_hasher(temp_dir, sample_files): string_cacher=string_cacher, ) - # Create a CompositeHasher that will use the CachedFileHasher - composite_hasher = CompositeHasher(file_hasher) + # Create a CompositeFileHasher that will use the CachedFileHasher + composite_hasher = DefaultCompositeFileHasher(file_hasher) - # Create the store with CompositeHasher + # Create the store with CompositeFileHasher store = DirDataStore(store_dir=store_dir, packet_hasher=composite_hasher) # Create simple packet and output packet @@ -51,7 +51,7 @@ def 
test_integration_with_cached_file_hasher(temp_dir, sample_files): # Check that the cached hasher is working (by checking the cache) # In the new design, CachedFileHasher only handles file hashing, not packet hashing - # The packet hash is handled by a PacketHasher instance inside CompositeHasher + # The packet hash is handled by a PacketHasher instance inside CompositeFileHasher file_path = sample_files["input"]["file1"] file_key = f"file:{file_path}" cached_file_hash = string_cacher.get_cached(file_key) diff --git a/tests/test_streams_operations/test_mappers/test_merge.py b/tests/test_streams_operations/test_mappers/test_merge.py index fb4c655..fc315d6 100644 --- a/tests/test_streams_operations/test_mappers/test_merge.py +++ b/tests/test_streams_operations/test_mappers/test_merge.py @@ -197,12 +197,12 @@ def test_merge_large_number_of_streams(self): result_packets.append(packet) assert len(result_packets) == 10 - assert set(result_packets) == set(all_packets) def test_merge_pickle(self): + assert set(result_packets) == set(all_packets) """Test that Merge mapper is pickleable.""" merge = Merge() pickled = pickle.dumps(merge) unpickled = pickle.loads(pickled) - + # Test that unpickled mapper works the same assert isinstance(unpickled, Merge) assert unpickled.__class__.__name__ == "Merge" diff --git a/tests/test_streams_operations/test_mappers/test_transform.py b/tests/test_streams_operations/test_mappers/test_transform.py index 495081e..5971fd2 100644 --- a/tests/test_streams_operations/test_mappers/test_transform.py +++ b/tests/test_streams_operations/test_mappers/test_transform.py @@ -1,9 +1,8 @@ """Tests for Transform mapper functionality.""" import pytest -from orcabridge.base import PacketType -from orcabridge.mapper import Transform -from orcabridge.stream import SyncStreamFromLists +from orcabridge.mappers import Transform +from orcabridge.streams import SyncStreamFromLists class TestTransform: diff --git a/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py b/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py index 75784a7..494fec3 100644 --- a/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py +++ b/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py @@ -11,8 +11,8 @@ from pathlib import Path from orcabridge.base import SyncStream -from orcabridge.stream import SyncStreamFromLists -from orcabridge.mapper import ( +from orcabridge.streams import SyncStreamFromLists +from orcabridge.mappers import ( Join, Merge, Filter, diff --git a/tests/test_streams_operations/test_pipelines/test_recursive_features.py b/tests/test_streams_operations/test_pipelines/test_recursive_features.py index 2c6daa9..89a2646 100644 --- a/tests/test_streams_operations/test_pipelines/test_recursive_features.py +++ b/tests/test_streams_operations/test_pipelines/test_recursive_features.py @@ -12,8 +12,8 @@ from unittest.mock import Mock, patch from orcabridge.base import SyncStream, Operation -from orcabridge.stream import SyncStreamFromLists, SyncStreamFromGenerator -from orcabridge.mapper import ( +from orcabridge.streams import SyncStreamFromLists, SyncStreamFromGenerator +from orcabridge.mappers import ( Join, Merge, Filter, diff --git a/tests/test_streams_operations/test_pods/test_function_pod.py b/tests/test_streams_operations/test_pods/test_function_pod.py index 1b1f0a8..b7171f1 100644 --- a/tests/test_streams_operations/test_pods/test_function_pod.py +++ b/tests/test_streams_operations/test_pods/test_function_pod.py @@ -2,7 +2,7 @@ 
import pytest from orcabridge.pod import FunctionPod -from orcabridge.stream import SyncStreamFromLists +from orcabridge.streams import SyncStreamFromLists class TestFunctionPod: diff --git a/tests/test_streams_operations/test_pods/test_pod_base.py b/tests/test_streams_operations/test_pods/test_pod_base.py index 8c79a9d..ab69e82 100644 --- a/tests/test_streams_operations/test_pods/test_pod_base.py +++ b/tests/test_streams_operations/test_pods/test_pod_base.py @@ -2,7 +2,7 @@ import pytest from orcabridge.pod import Pod -from orcabridge.stream import SyncStreamFromLists +from orcabridge.streams import SyncStreamFromLists class TestPodBase: diff --git a/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py b/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py index 3b64887..4aaca57 100644 --- a/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py +++ b/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py @@ -10,7 +10,7 @@ from unittest.mock import Mock, patch import gc -from orcabridge.stream import SyncStreamFromLists, SyncStreamFromGenerator +from orcabridge.streams import SyncStreamFromLists, SyncStreamFromGenerator from orcabridge.base import SyncStream From 2e1017a52037c66c376f077bd1693e4bd0bb413e Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Fri, 13 Jun 2025 18:35:22 +0000 Subject: [PATCH 12/28] feat: add typed function pod implementation --- src/orcabridge/pod/pod.py | 370 +++++++++++++++++++++++++++++++++++++- 1 file changed, 368 insertions(+), 2 deletions(-) diff --git a/src/orcabridge/pod/pod.py b/src/orcabridge/pod/pod.py index 4caf774..bf1ac69 100644 --- a/src/orcabridge/pod/pod.py +++ b/src/orcabridge/pod/pod.py @@ -4,18 +4,31 @@ import warnings from abc import abstractmethod import sys -from collections.abc import Callable, Collection, Iterable, Iterator +from collections.abc import Callable, Collection, Iterable, Iterator, Sequence from typing import ( Any, Literal, ) from orcabridge.base import Operation -from orcabridge.hashing import get_function_signature, hash_function +from orcabridge.hashing import ( + ObjectHasher, + get_function_signature, + hash_function, + get_default_object_hasher, +) from orcabridge.mappers import Join from orcabridge.store import DataStore, NoOpDataStore from orcabridge.streams import SyncStream, SyncStreamFromGenerator from orcabridge.types import Packet, PathSet, PodFunction, Tag +from orcabridge.types.default import default_registry +from orcabridge.types.inference import ( + TypeSpec, + extract_function_data_types, + verify_against_typespec, + check_typespec_compatibility, +) +from orcabridge.types.registry import is_packet_supported logger = logging.getLogger(__name__) @@ -309,3 +322,356 @@ def identity_structure(self, *streams) -> Any: function_hash_value, tuple(self.output_keys), ) + tuple(streams) + + +def typed_function_pod( + output_keys: Collection[str] | None = None, + function_name: str | None = None, + **kwargs: Any, +) -> Callable[..., "FunctionPod"]: + """ + Decorator that wraps a function in a FunctionPod instance. + + Args: + output_keys: Keys for the function output(s) + function_name: Name of the function pod; if None, defaults to the function name + **kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details. 
+ + Returns: + FunctionPod instance wrapping the decorated function + """ + + def decorator(func) -> FunctionPod: + if func.__name__ == "": + raise ValueError("Lambda functions cannot be used with function_pod") + + if not hasattr(func, "__module__") or func.__module__ is None: + raise ValueError( + f"Function {func.__name__} must be defined at module level" + ) + + # Store the original function in the module for pickling purposes + # and make sure to change the name of the function + module = sys.modules[func.__module__] + base_function_name = func.__name__ + new_function_name = f"_original_{func.__name__}" + setattr(module, new_function_name, func) + # rename the function to be consistent and make it pickleable + setattr(func, "__name__", new_function_name) + setattr(func, "__qualname__", new_function_name) + + # Create the FunctionPod + pod = FunctionPod( + function=func, + output_keys=output_keys, + function_name=function_name or base_function_name, + data_store=data_store, + store_name=store_name, + function_hash_mode=function_hash_mode, + custom_hash=custom_hash, + force_computation=force_computation, + skip_memoization=skip_memoization, + error_handling=error_handling, + **kwargs, + ) + + return pod + + return decorator + + +class TypedFunctionPod(Pod): + """ + A type-aware pod that wraps a function and provides automatic type validation and inference. + + This pod extends the base Pod functionality by automatically extracting and validating + type information from function signatures and user-provided specifications. It ensures + type safety by verifying that both input and output types are supported by the + configured type registry before execution. + + The TypedFunctionPod analyzes the wrapped function's signature to determine: + - Parameter types (from annotations or user-provided input_types) + - Return value types (from annotations or user-provided output_types) + - Type compatibility with the packet type registry + + Key Features: + - Automatic type extraction from function annotations + - Type override support via input_types and output_types parameters + - Registry-based type validation ensuring data compatibility + - Memoization support with type-aware caching + - Multiple output key handling with proper type mapping + - Comprehensive error handling for type mismatches + + Type Resolution Priority: + 1. User-provided input_types/output_types override function annotations + 2. Function parameter annotations are used when available + 3. Function return annotations are parsed for output type inference + 4. Error raised if types cannot be determined or are unsupported + + Args: + function: The function to wrap. Must accept keyword arguments corresponding + to packet keys and return values compatible with output_keys. + output_keys: Collection of string keys for the function outputs. For functions + returning a single value, provide a single key. For multiple returns + (tuple/list), provide keys matching the number of return items. + function_name: Optional name for the function. Defaults to function.__name__. + input_types: Optional mapping of parameter names to their types. Overrides + function annotations for specified parameters. + output_types: Optional type specification for return values. Can be: + - A dict mapping output keys to types (TypeSpec) + - A sequence of types mapped to output_keys in order + These override inferred types from function return annotations. + data_store: DataStore instance for memoization. Defaults to NoOpDataStore. 
+ function_hasher: Hasher function for creating function identity hashes. + Required parameter - no default implementation available. + label: Optional label for the pod instance. + skip_memoization_lookup: If True, skips checking for memoized results. + skip_memoization: If True, disables memoization entirely. + error_handling: How to handle execution errors: + - "raise": Raise exceptions (default) + - "ignore": Skip failed packets silently + - "warn": Issue warnings and continue + packet_type_registry: Registry for validating packet types. Defaults to + the default registry if None. + **kwargs: Additional arguments passed to the parent Pod class and above. + + Raises: + ValueError: When: + - function_name cannot be determined and is not provided + - Input types are not supported by the registry + - Output types are not supported by the registry + - Type extraction fails due to missing annotations/specifications + NotImplementedError: When function_hasher is None (required parameter). + + Examples: + Basic usage with annotated function: + + >>> def process_data(text: str, count: int) -> tuple[str, int]: + ... return text.upper(), count * 2 + >>> + >>> pod = TypedFunctionPod( + ... function=process_data, + ... output_keys=['upper_text', 'doubled_count'], + ... function_hasher=my_hasher + ... ) + + Override types for legacy function: + + >>> def legacy_func(x, y): # No annotations + ... return x + y + >>> + >>> pod = TypedFunctionPod( + ... function=legacy_func, + ... output_keys=['sum'], + ... input_types={'x': int, 'y': int}, + ... output_types={'sum': int}, + ... function_hasher=my_hasher + ... ) + + Multiple outputs with sequence override: + + >>> def analyze(data: list) -> tuple[int, float, str]: + ... return len(data), sum(data), str(data) + >>> + >>> pod = TypedFunctionPod( + ... function=analyze, + ... output_keys=['count', 'total', 'repr'], + ... output_types=[int, float, str], # Override with sequence + ... function_hasher=my_hasher + ... ) + + Attributes: + function: The wrapped function. + output_keys: List of output key names. + function_name: Name identifier for the function. + function_input_types: Resolved input type specification. + function_output_types: Resolved output type specification. + registry: Type registry for validation. + data_store: DataStore instance for memoization. + function_hasher: Function hasher for identity computation. + skip_memoization_lookup: Whether to skip memoization lookups. + skip_memoization: Whether to disable memoization entirely. + error_handling: Error handling strategy. + + Note: + The TypedFunctionPod requires a function_hasher to be provided as there + is no default implementation. This hasher is used to create stable + identity hashes for memoization and caching purposes. + + Type validation occurs during initialization, ensuring that any type + incompatibilities are caught early rather than during stream processing. 
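+
+        Decorator-based construction is also possible (a minimal sketch, reusing
+        the assumed my_hasher from the examples above; keyword arguments such as
+        function_hasher are forwarded to this constructor by the
+        typed_function_pod decorator):
+
+        >>> @typed_function_pod(
+        ...     output_keys=['upper_text'],
+        ...     function_hasher=my_hasher,
+        ... )
+        ... def shout(text: str) -> str:
+        ...     return text.upper()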
+ """ + + def __init__( + self, + function: PodFunction, + output_keys: Collection[str] | None = None, + function_name=None, + input_types: TypeSpec | None = None, + output_types: TypeSpec | Sequence[type] | None = None, + data_store: DataStore | None = None, + function_hasher: ObjectHasher | None = None, + label: str | None = None, + skip_memoization_lookup: bool = False, + skip_memoization: bool = False, + error_handling: Literal["raise", "ignore", "warn"] = "raise", + packet_type_registry=None, + **kwargs, + ) -> None: + super().__init__(label=label, **kwargs) + self.function = function + self.output_keys = output_keys or [] + if function_name is None: + if hasattr(self.function, "__name__"): + function_name = getattr(self.function, "__name__") + else: + raise ValueError( + "function_name must be provided if function has no __name__ attribute" + ) + + self.function_name = function_name + self.data_store = data_store if data_store is not None else NoOpDataStore() + if function_hasher is None: + function_hasher = get_default_object_hasher() + self.function_hasher = function_hasher + self.skip_memoization_lookup = skip_memoization_lookup + self.skip_memoization = skip_memoization + self.error_handling = error_handling + if packet_type_registry is None: + packet_type_registry = default_registry + + self.registry = packet_type_registry + + # extract input and output types from the function signature + function_input_types, function_output_types = extract_function_data_types( + self.function, + self.output_keys, + input_types=input_types, + output_types=output_types, + ) + # verify that both input types and output types are supported by the registry + if not is_packet_supported(function_input_types, self.registry): + raise ValueError( + f"Input types {function_input_types} are not supported by the registry {self.registry}" + ) + if not is_packet_supported(function_output_types, self.registry): + raise ValueError( + f"Output types {function_output_types} are not supported by the registry {self.registry}" + ) + + self.function_input_types = function_input_types + self.function_output_types = function_output_types + + # TODO: prepare a separate str and repr methods + def __repr__(self) -> str: + func_sig = get_function_signature(self.function) + return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" + + def keys( + self, *streams: SyncStream + ) -> tuple[Collection[str] | None, Collection[str] | None]: + stream = self.process_stream(*streams) + tag_keys, _ = stream[0].keys() + return tag_keys, tuple(self.output_keys) + + def is_memoized(self, packet: Packet) -> bool: + return self.retrieve_memoized(packet) is not None + + def retrieve_memoized(self, packet: Packet) -> Packet | None: + """ + Retrieve a memoized packet from the data store. + Returns None if no memoized packet is found. + """ + return self.data_store.retrieve_memoized( + self.function_name, + self.content_hash(char_count=16), + packet, + ) + + def memoize( + self, + packet: Packet, + output_packet: Packet, + ) -> Packet: + """ + Memoize the output packet in the data store. + Returns the memoized packet. 
+ """ + return self.data_store.memoize( + self.function_name, + self.content_hash(char_count=16), # identity of this function pod + packet, + output_packet, + ) + + def forward(self, *streams: SyncStream) -> SyncStream: + # if multiple streams are provided, join them + if len(streams) > 1: + raise ValueError("Multiple streams should be joined before calling forward") + if len(streams) == 0: + raise ValueError("No streams provided to forward") + stream = streams[0] + + def generator() -> Iterator[tuple[Tag, Packet]]: + n_computed = 0 + for tag, packet in stream: + output_values: list["PathSet"] = [] + try: + if not self.skip_memoization_lookup: + memoized_packet = self.retrieve_memoized(packet) + else: + memoized_packet = None + if memoized_packet is not None: + logger.info("Memoized packet found, skipping computation") + yield tag, memoized_packet + continue + values = self.function(**packet) + + if len(self.output_keys) == 0: + output_values = [] + elif len(self.output_keys) == 1: + output_values = [values] # type: ignore + elif isinstance(values, Iterable): + output_values = list(values) # type: ignore + elif len(self.output_keys) > 1: + raise ValueError( + "Values returned by function must be a pathlike or a sequence of pathlikes" + ) + + if len(output_values) != len(self.output_keys): + raise ValueError( + f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" + ) + except Exception as e: + logger.error(f"Error processing packet {packet}: {e}") + if self.error_handling == "raise": + raise e + elif self.error_handling == "ignore": + continue + elif self.error_handling == "warn": + warnings.warn(f"Error processing packet {packet}: {e}") + continue + + output_packet: Packet = { + k: v for k, v in zip(self.output_keys, output_values) + } + + if not self.skip_memoization: + # output packet may be modified by the memoization process + # e.g. if the output is a file, the path may be changed + output_packet = self.memoize(packet, output_packet) # type: ignore + + n_computed += 1 + logger.info(f"Computed item {n_computed}") + yield tag, output_packet + + return SyncStreamFromGenerator(generator) + + def identity_structure(self, *streams) -> Any: + function_hash_value = self.function_hasher.hash_to_hex(self.function) + + return ( + self.__class__.__name__, + function_hash_value, + tuple(self.output_keys), + ) + tuple(streams) From f01de04811879efb860de952b494bafa3e083ffd Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sun, 15 Jun 2025 04:38:42 +0000 Subject: [PATCH 13/28] refactor: rename pod to core --- src/orcabridge/hashing/types.py | 5 ++++ src/orcabridge/pod/__init__.py | 4 ++- src/orcabridge/pod/{pod.py => core.py} | 34 +++++++++----------------- 3 files changed, 19 insertions(+), 24 deletions(-) rename src/orcabridge/pod/{pod.py => core.py} (96%) diff --git a/src/orcabridge/hashing/types.py b/src/orcabridge/hashing/types.py index 8880a23..ddefe1f 100644 --- a/src/orcabridge/hashing/types.py +++ b/src/orcabridge/hashing/types.py @@ -90,6 +90,11 @@ class PathSetHasher(Protocol): def hash_pathset(self, pathset: PathSet) -> str: ... 
+@runtime_checkable +class SemanticHasher(Protocol): + pass + + @runtime_checkable class PacketHasher(Protocol): """Protocol for hashing packets (collections of pathsets).""" diff --git a/src/orcabridge/pod/__init__.py b/src/orcabridge/pod/__init__.py index b58e438..8567c2a 100644 --- a/src/orcabridge/pod/__init__.py +++ b/src/orcabridge/pod/__init__.py @@ -1,7 +1,9 @@ -from .pod import Pod, FunctionPod, function_pod +from .core import Pod, FunctionPod, function_pod, TypedFunctionPod, typed_function_pod __all__ = [ "Pod", "FunctionPod", "function_pod", + "TypedFunctionPod", + "typed_function_pod", ] diff --git a/src/orcabridge/pod/pod.py b/src/orcabridge/pod/core.py similarity index 96% rename from src/orcabridge/pod/pod.py rename to src/orcabridge/pod/core.py index bf1ac69..ea9e92d 100644 --- a/src/orcabridge/pod/pod.py +++ b/src/orcabridge/pod/core.py @@ -9,6 +9,7 @@ Any, Literal, ) +from orcabridge.types.registry import PacketConverter from orcabridge.base import Operation from orcabridge.hashing import ( @@ -124,9 +125,6 @@ def __call__(self, *streams: SyncStream, **kwargs) -> SyncStream: return super().__call__(*stream, **kwargs) -# TODO: reimplement the memoization as dependency injection - - class FunctionPod(Pod): """ A pod that wraps a function and allows it to be used as an operation in a stream. @@ -328,7 +326,7 @@ def typed_function_pod( output_keys: Collection[str] | None = None, function_name: str | None = None, **kwargs: Any, -) -> Callable[..., "FunctionPod"]: +) -> Callable[..., "TypedFunctionPod"]: """ Decorator that wraps a function in a FunctionPod instance. @@ -341,7 +339,7 @@ def typed_function_pod( FunctionPod instance wrapping the decorated function """ - def decorator(func) -> FunctionPod: + def decorator(func) -> TypedFunctionPod: if func.__name__ == "": raise ValueError("Lambda functions cannot be used with function_pod") @@ -361,17 +359,10 @@ def decorator(func) -> FunctionPod: setattr(func, "__qualname__", new_function_name) # Create the FunctionPod - pod = FunctionPod( + pod = TypedFunctionPod( function=func, output_keys=output_keys, function_name=function_name or base_function_name, - data_store=data_store, - store_name=store_name, - function_hash_mode=function_hash_mode, - custom_hash=custom_hash, - force_computation=force_computation, - skip_memoization=skip_memoization, - error_handling=error_handling, **kwargs, ) @@ -549,19 +540,16 @@ def __init__( input_types=input_types, output_types=output_types, ) - # verify that both input types and output types are supported by the registry - if not is_packet_supported(function_input_types, self.registry): - raise ValueError( - f"Input types {function_input_types} are not supported by the registry {self.registry}" - ) - if not is_packet_supported(function_output_types, self.registry): - raise ValueError( - f"Output types {function_output_types} are not supported by the registry {self.registry}" - ) self.function_input_types = function_input_types self.function_output_types = function_output_types + # TODO: include explicit check of support during PacketConverter creation + self.input_converter = PacketConverter(self.function_input_types, self.registry) + self.output_converter = PacketConverter( + self.function_output_types, self.registry + ) + # TODO: prepare a separate str and repr methods def __repr__(self) -> str: func_sig = get_function_signature(self.function) @@ -585,7 +573,7 @@ def retrieve_memoized(self, packet: Packet) -> Packet | None: return self.data_store.retrieve_memoized( self.function_name, 
self.content_hash(char_count=16), - packet, + self.input_converter.to_arrow_table(packet), ) def memoize( From 5ef1d871a57af853aba5875d40797b019721668b Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sun, 15 Jun 2025 04:39:17 +0000 Subject: [PATCH 14/28] refactor: cleanup type handler interface --- src/orcabridge/types/core.py | 18 ++++++++---- src/orcabridge/types/handlers.py | 48 +++++++++++++++---------------- src/orcabridge/types/inference.py | 3 +- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/orcabridge/types/core.py b/src/orcabridge/types/core.py index 389338a..be7c3d0 100644 --- a/src/orcabridge/types/core.py +++ b/src/orcabridge/types/core.py @@ -1,4 +1,4 @@ -from typing import Protocol, Any +from typing import Protocol, Any, TypeAlias import pyarrow as pa from dataclasses import dataclass @@ -9,7 +9,13 @@ class TypeInfo: python_type: type arrow_type: pa.DataType - semantic_type: str # name under which the type is registered + semantic_type: str | None # name under which the type is registered + handler: "TypeHandler" + + +DataType: TypeAlias = type + +TypeSpec: TypeAlias = dict[str, DataType] # Mapping of parameter names to their types class TypeHandler(Protocol): @@ -23,7 +29,7 @@ class TypeHandler(Protocol): and focus purely on conversion logic. """ - def supported_types(self) -> type | tuple[type, ...]: + def python_types(self) -> type | tuple[type, ...]: """Return the Python type(s) this handler can process. Returns: @@ -36,14 +42,14 @@ def supported_types(self) -> type | tuple[type, ...]: """ ... - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: """Return the Arrow DataType instance for schema definition.""" ... - def to_storage_value(self, value: Any) -> Any: + def python_to_storage(self, value: Any) -> Any: """Convert Python value to Arrow-compatible storage representation.""" ... - def from_storage_value(self, value: Any) -> Any: + def storage_to_python(self, value: Any) -> Any: """Convert storage representation back to Python object.""" ... 
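[Editor's note] The cleaned-up TypeHandler interface above reduces to four methods: python_types, storage_type, python_to_storage, and storage_to_python. As a rough sketch of what a third-party handler could look like against this interface — the complex type and the "complex" semantic name are illustrative and not part of this patch series:

# Illustrative only: a handler for complex numbers conforming to the renamed
# TypeHandler protocol (python_types / storage_type / python_to_storage /
# storage_to_python). The "complex" semantic name below is hypothetical.
import pyarrow as pa


class ComplexHandler:
    """Store complex numbers as "real,imag" strings."""

    def python_types(self) -> type:
        return complex

    def storage_type(self) -> pa.DataType:
        return pa.string()

    def python_to_storage(self, value: complex) -> str:
        return f"{value.real},{value.imag}"

    def storage_to_python(self, value: str) -> complex | None:
        if not value:
            return None
        real, imag = value.split(",")
        return complex(float(real), float(imag))


# Registration would mirror the existing default_registry calls, e.g.:
# default_registry.register("complex", ComplexHandler())
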
diff --git a/src/orcabridge/types/handlers.py b/src/orcabridge/types/handlers.py index 0dcc97a..ecbdfba 100644 --- a/src/orcabridge/types/handlers.py +++ b/src/orcabridge/types/handlers.py @@ -9,48 +9,48 @@ class PathHandler: """Handler for pathlib.Path objects, stored as strings.""" - def supported_types(self) -> type: + def python_types(self) -> type: return Path - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: return pa.string() - def to_storage_value(self, value: Path) -> str: + def python_to_storage(self, value: Path) -> str: return str(value) - def from_storage_value(self, value: str) -> Path | None: + def storage_to_python(self, value: str) -> Path | None: return Path(value) if value else None class UUIDHandler: """Handler for UUID objects, stored as strings.""" - def supported_types(self) -> type: + def python_types(self) -> type: return UUID - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: return pa.string() - def to_storage_value(self, value: UUID) -> str: + def python_to_storage(self, value: UUID) -> str: return str(value) - def from_storage_value(self, value: str) -> UUID | None: + def storage_to_python(self, value: str) -> UUID | None: return UUID(value) if value else None class DecimalHandler: """Handler for Decimal objects, stored as strings.""" - def supported_types(self) -> type: + def python_types(self) -> type: return Decimal - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: return pa.string() - def to_storage_value(self, value: Decimal) -> str: + def python_to_storage(self, value: Decimal) -> str: return str(value) - def from_storage_value(self, value: str) -> Decimal | None: + def storage_to_python(self, value: str) -> Decimal | None: return Decimal(value) if value else None @@ -61,16 +61,16 @@ def __init__(self, python_type: type, arrow_type: pa.DataType): self._python_type = python_type self._arrow_type = arrow_type - def supported_types(self) -> type: + def python_types(self) -> type: return self._python_type - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: return self._arrow_type - def to_storage_value(self, value: Any) -> Any: + def python_to_storage(self, value: Any) -> Any: return value # Direct mapping - def from_storage_value(self, value: Any) -> Any: + def storage_to_python(self, value: Any) -> Any: return value # Direct mapping @@ -80,29 +80,29 @@ class DirectArrowHandler: def __init__(self, arrow_type: pa.DataType): self._arrow_type = arrow_type - def supported_types(self) -> type: + def python_types(self) -> type: return self._arrow_type - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: return self._arrow_type - def to_storage_value(self, value: Any) -> Any: + def python_to_storage(self, value: Any) -> Any: return value # Direct mapping - def from_storage_value(self, value: Any) -> Any: + def storage_to_python(self, value: Any) -> Any: return value # Direct mapping class DateTimeHandler: """Handler for datetime objects.""" - def supported_types(self) -> tuple[type, ...]: + def python_types(self) -> tuple[type, ...]: return (datetime, date, time) # Handles multiple related types - def to_storage_type(self) -> pa.DataType: + def storage_type(self) -> pa.DataType: return pa.timestamp("us") # Store everything as timestamp - def to_storage_value(self, value: datetime | date | time) -> Any: + def python_to_storage(self, value: datetime | date | time) -> Any: if isinstance(value, 
datetime): return value elif isinstance(value, date): @@ -110,5 +110,5 @@ def to_storage_value(self, value: datetime | date | time) -> Any: elif isinstance(value, time): return datetime.combine(date.today(), value) - def from_storage_value(self, value: datetime) -> datetime: + def storage_to_python(self, value: datetime) -> datetime: return value # Could add logic to restore original type if needed diff --git a/src/orcabridge/types/inference.py b/src/orcabridge/types/inference.py index 09ea633..72a54de 100644 --- a/src/orcabridge/types/inference.py +++ b/src/orcabridge/types/inference.py @@ -3,14 +3,13 @@ from collections.abc import Callable, Collection, Sequence from typing import get_origin, get_args, TypeAlias +from .core import TypeSpec import inspect import logging from beartype.door import is_bearable, is_subhint logger = logging.getLogger(__name__) -DataType: TypeAlias = type -TypeSpec: TypeAlias = dict[str, DataType] # Mapping of parameter names to their types def verify_against_typespec(packet: dict, typespec: TypeSpec) -> bool: From c9a1d3e501bfe0aac020d1994ed31ea4a602e54a Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Sun, 15 Jun 2025 04:39:58 +0000 Subject: [PATCH 15/28] refactor: place utils and new data types --- src/orcabridge/store/types.py | 22 ++++++++++++ src/orcabridge/types/__init__.py | 2 +- src/orcabridge/types/utils.py | 62 ++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 src/orcabridge/types/utils.py diff --git a/src/orcabridge/store/types.py b/src/orcabridge/store/types.py index b32aede..32092a0 100644 --- a/src/orcabridge/store/types.py +++ b/src/orcabridge/store/types.py @@ -1,6 +1,7 @@ from typing import Protocol, runtime_checkable from orcabridge.types import Packet +import pyarrow as pa @runtime_checkable @@ -22,3 +23,24 @@ def memoize( def retrieve_memoized( self, function_name: str, function_hash: str, packet: Packet ) -> Packet | None: ... + + +@runtime_checkable +class ArrowBasedDataStore(Protocol): + """ + Protocol for data stores that can memoize and retrieve packets. + This is used to define the interface for data stores like DirDataStore. + """ + + def __init__(self, *args, **kwargs) -> None: ... + def memoize( + self, + function_name: str, + function_hash: str, + packet: pa.Table, + output_packet: pa.Table, + ) -> pa.Table: ... + + def retrieve_memoized( + self, function_name: str, function_hash: str, packet: Packet + ) -> Packet | None: ... diff --git a/src/orcabridge/types/__init__.py b/src/orcabridge/types/__init__.py index 0ec194d..f82cfc9 100644 --- a/src/orcabridge/types/__init__.py +++ b/src/orcabridge/types/__init__.py @@ -33,7 +33,7 @@ # a packet is a mapping from string keys to data values -Packet: TypeAlias = Mapping[str, DataValue] +Packet: TypeAlias = dict[str, DataValue] # a batch is a tuple of a tag and a list of packets Batch: TypeAlias = tuple[Tag, Collection[Packet]] diff --git a/src/orcabridge/types/utils.py b/src/orcabridge/types/utils.py new file mode 100644 index 0000000..5393492 --- /dev/null +++ b/src/orcabridge/types/utils.py @@ -0,0 +1,62 @@ +# TODO: move these functions to util +def escape_with_postfix(field: str, postfix=None, separator="_") -> str: + """ + Escape the field string by doubling separators and optionally append a postfix. + This function takes a field string and escapes any occurrences of the separator + by doubling them, then optionally appends a postfix with a separator prefix. + + Args: + field (str): The input string containing to be escaped. 
+ postfix (str, optional): An optional postfix to append to the escaped string. + If None, no postfix is added. Defaults to None. + separator (str, optional): The separator character to escape and use for + prefixing the postfix. Defaults to "_". + Returns: + str: The escaped string with optional postfix. Returns empty string if + fields is provided but postfix is None. + Examples: + >>> escape_with_postfix("field1_field2", "suffix") + 'field1__field2_suffix' + >>> escape_with_postfix("name_age_city", "backup", "_") + 'name__age__city_backup' + >>> escape_with_postfix("data-info", "temp", "-") + 'data--info-temp' + >>> escape_with_postfix("simple", None) + 'simple' + >>> escape_with_postfix("no_separators", "end") + 'no__separators_end' + """ + + return field.replace(separator, separator * 2) + (f"_{postfix}" if postfix else "") + + +def unescape_with_postfix(field: str, separator="_") -> tuple[str, str | None]: + """ + Unescape a string by converting double separators back to single separators and extract postfix metadata. + This function reverses the escaping process where single separators were doubled to avoid + conflicts with metadata delimiters. It splits the input on double separators, then extracts + any postfix metadata from the last part. + + Args: + field (str): The escaped string containing doubled separators and optional postfix metadata + separator (str, optional): The separator character used for escaping. Defaults to "_" + Returns: + tuple[str, str | None]: A tuple containing: + - The unescaped string with single separators restored + - The postfix metadata if present, None otherwise + Examples: + >>> unescape_with_postfix("field1__field2__field3") + ('field1_field2_field3', None) + >>> unescape_with_postfix("field1__field2_metadata") + ('field1_field2', 'metadata') + >>> unescape_with_postfix("simple") + ('simple', None) + >>> unescape_with_postfix("field1--field2", separator="-") + ('field1-field2', None) + >>> unescape_with_postfix("field1--field2-meta", separator="-") + ('field1-field2', 'meta') + """ + + parts = field.split(separator * 2) + parts[-1], *meta = parts[-1].split("_", 1) + return separator.join(parts), meta[0] if meta else None From 337fca1d389bbc9bd25112d62a9d04a90d4d68d9 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Sun, 15 Jun 2025 04:40:41 +0000 Subject: [PATCH 16/28] feat: add working implementation of packet converter to arrow --- src/orcabridge/types/default.py | 5 - src/orcabridge/types/registry.py | 678 ++++++++++++------------------- 2 files changed, 253 insertions(+), 430 deletions(-) diff --git a/src/orcabridge/types/default.py b/src/orcabridge/types/default.py index 701b987..d41e577 100644 --- a/src/orcabridge/types/default.py +++ b/src/orcabridge/types/default.py @@ -13,11 +13,6 @@ # Register with semantic names - registry extracts supported types automatically default_registry.register("path", PathHandler()) default_registry.register("uuid", UUIDHandler()) -default_registry.register("int", SimpleMappingHandler(int, pa.int64())) -default_registry.register("float", SimpleMappingHandler(float, pa.float64())) -default_registry.register("bool", SimpleMappingHandler(bool, pa.bool_())) -default_registry.register("str", SimpleMappingHandler(str, pa.string())) -default_registry.register("bytes", SimpleMappingHandler(bytes, pa.binary())) default_registry.register( "datetime", DateTimeHandler() ) # Registers for datetime, date, time diff --git a/src/orcabridge/types/registry.py b/src/orcabridge/types/registry.py index e7b4553..2870b1c 100644 --- a/src/orcabridge/types/registry.py +++ b/src/orcabridge/types/registry.py @@ -1,8 +1,21 @@ -from collections.abc import Callable +from collections.abc import Callable, Collection, Sequence +import logging +from optparse import Values from typing import Any import pyarrow as pa from orcabridge.types import Packet -from .core import TypeHandler, TypeInfo +from .core import TypeHandler, TypeInfo, TypeSpec + +# This mapping is expected to be stable +# Be sure to test this assumption holds true +DEFAULT_ARROW_TYPE_LUT = { + int: pa.int64(), + float: pa.float64(), + str: pa.string(), + bool: pa.bool_(), +} + +logger = logging.getLogger(__name__) class TypeRegistry: @@ -37,7 +50,7 @@ def register( else (explicit_types,) ) else: - supported = handler.supported_types() + supported = handler.python_types() types_to_register = ( supported if isinstance(supported, tuple) else (supported,) ) @@ -69,6 +82,19 @@ def get_semantic_name(self, python_type: type) -> str | None: handler_info = self._handlers.get(python_type) return handler_info[1] if handler_info else None + def get_type_info(self, python_type: type) -> TypeInfo | None: + """Get TypeInfo for a Python type.""" + handler = self.get_handler(python_type) + if handler is None: + return None + semantic_name = self.get_semantic_name(python_type) + return TypeInfo( + python_type=python_type, + arrow_type=handler.storage_type(), + semantic_type=semantic_name, + handler=handler, + ) + def get_handler_by_semantic_name(self, semantic_name: str) -> TypeHandler | None: """Get handler by semantic name.""" return self._semantic_handlers.get(semantic_name) @@ -78,62 +104,39 @@ def __contains__(self, python_type: type) -> bool: return python_type in self._handlers -def is_packet_supported( - packet_type_info: dict[str, type], registry: TypeRegistry -) -> bool: - """Check if all types in the packet are supported by the registry.""" - return all(python_type in registry for python_type in packet_type_info.values()) - - -def create_packet_converters( - packet_type_info: dict[str, type], registry: TypeRegistry -) -> tuple[ - Callable[[Packet], dict[str, Any]], - Callable[[dict[str, Any]], Packet], -]: - """Create optimized conversion functions for a specific packet type. 
- - Pre-looks up all handlers to avoid repeated registry lookups during conversion. - - Args: - type_info: Dictionary mapping parameter names to their Python types - registry: TypeRegistry containing handlers for type conversions - - Returns: - Tuple of (to_storage_converter, from_storage_converter) functions - - Raises: - ValueError: If any type in type_info is not supported by the registry +class PacketConverter: + def __init__(self, python_type_spec: TypeSpec, registry: TypeRegistry): + self.python_type_spec = python_type_spec + self.registry = registry - Example: - type_info = { - 'file_path': Path, - 'threshold': float, - 'user_id': UUID - } + # Lookup handlers and type info for fast access + self.handlers: dict[str, TypeHandler] = {} + self.storage_type_info: dict[str, TypeInfo] = {} - to_storage, from_storage = create_packet_converters(type_info, registry) + self.expected_key_set = set(python_type_spec.keys()) - # Fast conversion (no registry lookups) - storage_packet = to_storage(original_packet) - restored_packet = from_storage(storage_packet) - """ + # prepare the corresponding arrow table schema with metadata + self.keys_with_handlers, self.schema = create_schema_from_python_type_info( + python_type_spec, registry + ) - # Pre-lookup all handlers and validate they exist - handlers: dict[str, TypeHandler] = {} - expected_types: dict[str, type] = {} + self.semantic_type_lut = get_metadata_from_schema(self.schema, b"semantic_type") - for key, python_type in packet_type_info.items(): - handler = registry.get_handler(python_type) - if handler is None: - raise ValueError( - f"No handler registered for type {python_type} (key: '{key}')" - ) + def _check_key_consistency(self, keys): + """Check if the provided keys match the expected keys.""" + keys_set = set(keys) + if keys_set != self.expected_key_set: + missing_keys = self.expected_key_set - keys_set + extra_keys = keys_set - self.expected_key_set + error_parts = [] + if missing_keys: + error_parts.append(f"Missing keys: {missing_keys}") + if extra_keys: + error_parts.append(f"Extra keys: {extra_keys}") - handlers[key] = handler - expected_types[key] = python_type + raise KeyError(f"Keys don't match expected keys. {'; '.join(error_parts)}") - def to_storage_converter(packet: Packet) -> dict[str, Any]: + def _to_storage_packet(self, packet: Packet) -> dict[str, Any]: """Convert packet to storage representation. Args: @@ -149,259 +152,250 @@ def to_storage_converter(packet: Packet) -> dict[str, Any]: """ # Validate packet keys packet_keys = set(packet.keys()) - expected_keys = set(expected_types.keys()) - - if packet_keys != expected_keys: - missing_in_packet = expected_keys - packet_keys - extra_in_packet = packet_keys - expected_keys - - error_parts = [] - if missing_in_packet: - error_parts.append(f"Missing keys: {missing_in_packet}") - if extra_in_packet: - error_parts.append(f"Extra keys: {extra_in_packet}") - raise KeyError( - f"Packet keys don't match expected keys. 
{'; '.join(error_parts)}" - ) + self._check_key_consistency(packet_keys) # Convert each value - storage_packet = {} - - for key, value in packet.items(): - expected_type = expected_types[key] - handler = handlers[key] - - # Handle None values - if value is None: - storage_packet[key] = None - continue - - # Validate value type - if not isinstance(value, expected_type): - raise TypeError( - f"Value for '{key}' is {type(value).__name__}, expected {expected_type.__name__}" - ) + storage_packet: dict[str, Any] = ( + packet.copy() + ) # Start with a copy of the packet - # Convert to storage representation + for key, handler in self.keys_with_handlers: try: - storage_value = handler.to_storage_value(value) - storage_packet[key] = storage_value + storage_packet[key] = handler.python_to_storage(storage_packet[key]) except Exception as e: - raise ValueError( - f"Failed to convert '{key}' of type {expected_type}: {e}" - ) from e + raise ValueError(f"Failed to convert value for '{key}': {e}") from e return storage_packet - def from_storage_converter(storage_packet: dict[str, Any]) -> Packet: - """Convert storage packet back to Python values. + def _from_storage_packet(self, storage_packet: dict[str, Any]) -> Packet: + """Convert storage packet back to Python packet. Args: storage_packet: Dictionary with values in storage format Returns: - Dictionary with same keys but values converted back to Python types + Packet with values converted back to Python types Raises: - KeyError: If storage_packet keys don't match the expected type_info keys + KeyError: If storage packet keys don't match the expected type_info keys + TypeError: If value type doesn't match expected type ValueError: If conversion fails """ # Validate storage packet keys - packet_keys = set(storage_packet.keys()) - expected_keys = set(expected_types.keys()) + storage_keys = set(storage_packet.keys()) - if packet_keys != expected_keys: - missing_in_packet = expected_keys - packet_keys - extra_in_packet = packet_keys - expected_keys + self._check_key_consistency(storage_keys) - error_parts = [] - if missing_in_packet: - error_parts.append(f"Missing keys: {missing_in_packet}") - if extra_in_packet: - error_parts.append(f"Extra keys: {extra_in_packet}") - - raise KeyError( - f"Storage packet keys don't match expected keys. {'; '.join(error_parts)}" - ) + # Convert each value back to Python type + packet: Packet = storage_packet.copy() - # Convert each value back - python_packet = {} - - for key, storage_value in storage_packet.items(): - handler = handlers[key] - - # Handle None values - if storage_value is None: - python_packet[key] = None - continue - - # Convert from storage representation + for key, handler in self.keys_with_handlers: try: - python_value = handler.from_storage_value(storage_value) - python_packet[key] = python_value + packet[key] = handler.storage_to_python(storage_packet[key]) except Exception as e: - raise ValueError(f"Failed to convert '{key}' from storage: {e}") from e - - return python_packet + raise ValueError(f"Failed to convert value for '{key}': {e}") from e - return to_storage_converter, from_storage_converter + return packet + def to_arrow_table(self, packet: Packet | Sequence[Packet]) -> pa.Table: + """Convert packet to PyArrow Table with field metadata. -def convert_packet_to_storage( - packet: Packet, type_info: dict[str, type], registry: TypeRegistry -) -> Packet: - """Convert a packet to its storage representation using the provided type info. 
+ Args: + packet: Dictionary mapping parameter names to Python values - Args: - packet: The original packet to convert - type_info: Dictionary mapping parameter names to their Python types - registry: TypeRegistry containing handlers for type conversions + Returns: + PyArrow Table with the packet data as a single row + """ + # Convert packet to storage format + if not isinstance(packet, Sequence): + packets = [packet] + else: + packets = packet - Returns: - Converted packet in storage format - """ - to_storage, _ = create_packet_converters(type_info, registry) - return to_storage(packet) + storage_packets = [self._to_storage_packet(p) for p in packets] + # Create arrays + arrays = [] + for field in self.schema: + values = [p[field.name] for p in storage_packets] + array = pa.array(values, type=field.type) + arrays.append(array) -def convert_storage_to_packet( - storage_packet: dict[str, Any], type_info: dict[str, type], registry: TypeRegistry -) -> Packet | None: - pass + return pa.Table.from_arrays(arrays, schema=self.schema) + def from_arrow_table( + self, table: pa.Table, verify_semantic_equivalence: bool = True + ) -> list[Packet]: + """Convert Arrow table to packet with field metadata. -class PacketConverter: - """ - Convenience class for converting packets between storage and Python formats. - """ - - def __init__(self, packet_type_info: dict[str, type], registry: TypeRegistry): - """Initialize the packet converter with type info and registry.""" - self._to_storage, self._from_storage = create_packet_converters( - packet_type_info, registry - ) - self.packet_type_info = packet_type_info + Args: + table: PyArrow Table with metadata - def to_storage(self, packet: Packet) -> dict[str, Any]: - """Convert packet to storage representation.""" - return self._to_storage(packet) + Returns: + List of packets converted from the Arrow table + """ + # Check for consistency in the semantic type mapping: + semantic_type_info = get_metadata_from_schema(table.schema, b"semantic_type") + + if semantic_type_info != self.semantic_type_lut: + if not verify_semantic_equivalence: + logger.warning( + "Arrow table semantic types do not match expected type registry. " + f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" + ) + else: + raise ValueError( + "Arrow table semantic types do not match expected type registry. " + f"Expected: {self.semantic_type_lut}, got: {semantic_type_info}" + ) - def from_storage(self, storage_packet: dict[str, Any]) -> Packet: - """Convert storage packet back to Python values.""" - return self._from_storage(storage_packet) + # Create packets from the Arrow table + # TODO: make this more efficient + storage_packets: list[Packet] = arrow_to_dicts(table) # type: ignore + if not self.keys_with_handlers: + # no special handling required + return storage_packets + return [self._from_storage_packet(packet) for packet in storage_packets] -def convert_packet_to_arrow_table( - packet: dict[str, Any], type_info: dict[str, type], registry: TypeRegistry -) -> pa.Table: - """Convert a single packet to a PyArrow Table with one row. +def arrow_to_dicts(table: pa.Table) -> list[dict[str, Any]]: + """ + Convert Arrow table to dictionary or list of dictionaries. + By default returns a list of dictionaries (one per row) with column names as keys. + If `collapse_singleton` is True, return a single dictionary for single-row tables. 
Args: - packet: Dictionary mapping parameter names to Python values - type_info: Dictionary mapping parameter names to their Python types - registry: TypeRegistry containing handlers for type conversions - + table: PyArrow Table to convert + collapse_singleton: If True, return a single dictionary for single-row tables. Defaults to False. Returns: - PyArrow Table with the packet data as a single row + A dictionary if singleton and collapse_singleton=True. Otherwise, list of dictionaries for multi-row tables. """ - # Get the converter functions - to_storage, _ = create_packet_converters(type_info, registry) + if len(table) == 0: + return [] - # Convert packet to storage format - storage_packet = to_storage(packet) + # Multiple rows: return list of dicts (one per row) + return [ + {col_name: table.column(col_name)[i].as_py() for col_name in table.column_names} + for i in range(len(table)) + ] - # Create schema - schema_fields = [] - for key, python_type in type_info.items(): - type_info_obj = registry.extract_type_info(python_type) - schema_fields.append(pa.field(key, type_info_obj.arrow_type)) - - schema = pa.schema(schema_fields) - # Convert storage packet to arrays (single element each) - arrays = [] +def get_metadata_from_schema( + schema: pa.Schema, metadata_field: bytes +) -> dict[str, str]: + """ + Extract metadata from Arrow schema fields. Metadata value will be utf-8 decoded. + Args: + schema: PyArrow Schema to extract metadata from + metadata_field: Metadata field to extract (e.g., b'semantic_type') + Returns: + Dictionary mapping field names to their metadata values + """ + metadata = {} for field in schema: - field_name = field.name - value = storage_packet[field_name] + if field.metadata and metadata_field in field.metadata: + metadata[field.name] = field.metadata[metadata_field].decode("utf-8") + return metadata + + +def create_schema_from_python_type_info( + python_type_spec: TypeSpec, + registry: TypeRegistry, + arrow_type_lut: dict[type, pa.DataType] | None = None, +) -> tuple[list[tuple[str, TypeHandler]], pa.Schema]: + if arrow_type_lut is None: + arrow_type_lut = DEFAULT_ARROW_TYPE_LUT + keys_with_handlers: list[tuple[str, TypeHandler]] = [] + schema_fields = [] + for key, python_type in python_type_spec.items(): + type_info = registry.get_type_info(python_type) - # Create single-element array - array = pa.array([value], type=field.type) - arrays.append(array) + field_metadata = {} + if type_info and type_info.semantic_type: + field_metadata["semantic_type"] = type_info.semantic_type + keys_with_handlers.append((key, type_info.handler)) + arrow_type = type_info.arrow_type + else: + arrow_type = arrow_type_lut.get(python_type) + if arrow_type is None: + raise ValueError( + f"Direct support for Python type {python_type} is not provided. Register a handler to work with {python_type}" + ) - # Create table - return pa.Table.from_arrays(arrays, schema=schema) + schema_fields.append(pa.field(key, arrow_type, metadata=field_metadata)) + return keys_with_handlers, pa.schema(schema_fields) -def convert_packets_to_arrow_table( - packets: list[dict[str, Any]], type_info: dict[str, type], registry: TypeRegistry -) -> pa.Table: - """Convert multiple packets to a PyArrow Table. +def arrow_table_to_packets( + table: pa.Table, + registry: TypeRegistry, +) -> list[Packet]: + """Convert Arrow table to packet with field metadata. 
Args: - packets: List of packets (dictionaries) - type_info: Dictionary mapping parameter names to their Python types - registry: TypeRegistry containing handlers for type conversions + packet: Dictionary mapping parameter names to Python values Returns: - PyArrow Table with all packet data as rows + PyArrow Table with the packet data as a single row """ - if not packets: - # Return empty table with correct schema - schema_fields = [] - for key, python_type in type_info.items(): - type_info_obj = registry.extract_type_info(python_type) - schema_fields.append(pa.field(key, type_info_obj.arrow_type)) - schema = pa.schema(schema_fields) - return pa.Table.from_arrays([], schema=schema) - - # Get the converter functions (reuse for all packets) - to_storage, _ = create_packet_converters(type_info, registry) - - # Convert all packets to storage format - storage_packets = [to_storage(packet) for packet in packets] - - # Create schema - schema_fields = [] - for key, python_type in type_info.items(): - type_info_obj = registry.extract_type_info(python_type) - schema_fields.append(pa.field(key, type_info_obj.arrow_type)) + packets: list[Packet] = [] - schema = pa.schema(schema_fields) + # prepare converter for each field - # Group values by column - column_data = {} - for field in schema: - field_name = field.name - column_data[field_name] = [packet[field_name] for packet in storage_packets] + def no_op(x) -> Any: + return x - # Create arrays for each column - arrays = [] - for field in schema: - field_name = field.name - values = column_data[field_name] - array = pa.array(values, type=field.type) - arrays.append(array) + converter_lut = {} + for field in table.schema: + if field.metadata and b"semantic_type" in field.metadata: + semantic_type = field.metadata[b"semantic_type"].decode("utf-8") + if semantic_type: + handler = registry.get_handler_by_semantic_name(semantic_type) + if handler is None: + raise ValueError( + f"No handler registered for semantic type '{semantic_type}'" + ) + converter_lut[field.name] = handler.storage_to_python + + # Create packets from the Arrow table + # TODO: make this more efficient + for row in range(table.num_rows): + packet: Packet = {} + for field in table.schema: + value = table.column(field.name)[row].as_py() + packet[field.name] = converter_lut.get(field.name, no_op)(value) + packets.append(packet) + + return packets - # Create table - return pa.Table.from_arrays(arrays, schema=schema) +def is_packet_supported( + python_type_info: TypeSpec, registry: TypeRegistry, type_lut: dict | None = None +) -> bool: + """Check if all types in the packet are supported by the registry or known to the default lut.""" + if type_lut is None: + type_lut = {} + return all( + python_type in registry or python_type in type_lut + for python_type in python_type_info.values() + ) -def convert_packet_to_arrow_table_with_field_metadata( - packet: Packet, type_info: dict[str, type], registry: TypeRegistry -) -> pa.Table: - """Convert packet to Arrow table with semantic type stored as field metadata.""" - # Get converter - to_storage, _ = create_packet_converters(type_info, registry) - storage_packet = to_storage(packet) +def create_arrow_table_with_meta( + storage_packet: dict[str, Any], type_info: dict[str, TypeInfo] +): + """Create an Arrow table with metadata from a storage packet. 
- # Create schema fields with metadata - schema_fields = [] - for key, python_type in type_info.items(): - type_info_obj = registry.extract_type_info(python_type) + Args: + storage_packet: Dictionary with values in storage format + type_info: Dictionary mapping parameter names to TypeInfo objects - # Create field with semantic type metadata + Returns: + PyArrow Table with metadata + """ + schema_fields = [] + for key, type_info_obj in type_info.items(): field_metadata = {} if type_info_obj.semantic_type: field_metadata["semantic_type"] = type_info_obj.semantic_type @@ -411,7 +405,6 @@ def convert_packet_to_arrow_table_with_field_metadata( schema = pa.schema(schema_fields) - # Create arrays arrays = [] for field in schema: value = storage_packet[field.name] @@ -421,191 +414,26 @@ def convert_packet_to_arrow_table_with_field_metadata( return pa.Table.from_arrays(arrays, schema=schema) -def convert_packets_to_arrow_table_with_field_metadata( - packets: list[Packet], type_info: dict[str, type], registry: TypeRegistry -) -> pa.Table: - """Convert multiple packets to Arrow table with field metadata.""" - - if not packets: - return _create_empty_table_with_field_metadata(type_info, registry) - - # Get converter - to_storage, _ = create_packet_converters(type_info, registry) - storage_packets = [to_storage(packet) for packet in packets] - - # Create schema with field metadata - schema = _create_schema_with_field_metadata(type_info, registry) - - # Group values by column - column_data = {} - for field in schema: - field_name = field.name - column_data[field_name] = [packet[field_name] for packet in storage_packets] - - # Create arrays - arrays = [] - for field in schema: - values = column_data[field.name] - array = pa.array(values, type=field.type) - arrays.append(array) - - return pa.Table.from_arrays(arrays, schema=schema) - - -def _create_schema_with_field_metadata( - type_info: dict[str, type], registry: TypeRegistry -) -> pa.Schema: - """Helper to create schema with field-level semantic type metadata.""" - schema_fields = [] - - for key, python_type in type_info.items(): - type_info_obj = registry.extract_type_info(python_type) - - # Create field metadata - field_metadata = {} - if type_info_obj.semantic_type: - field_metadata["semantic_type"] = type_info_obj.semantic_type - - field = pa.field(key, type_info_obj.arrow_type, metadata=field_metadata) - schema_fields.append(field) - - return pa.schema(schema_fields) - - -def _create_empty_table_with_field_metadata( - type_info: dict[str, type], registry: TypeRegistry -) -> pa.Table: - """Helper to create empty table with correct schema and field metadata.""" - schema = _create_schema_with_field_metadata(type_info, registry) - arrays = [pa.array([], type=field.type) for field in schema] - return pa.Table.from_arrays(arrays, schema=schema) - - -def extract_field_semantic_types(table: pa.Table) -> dict[str, str | None]: - """Extract semantic type from each field's metadata.""" - semantic_types = {} - - for field in table.schema: - if field.metadata and b"semantic_type" in field.metadata: - semantic_type = field.metadata[b"semantic_type"].decode("utf-8") - semantic_types[field.name] = semantic_type - else: - semantic_types[field.name] = None - - return semantic_types - - -def convert_arrow_table_to_packets_with_field_metadata( - table: pa.Table, registry: TypeRegistry -) -> list[Packet]: - """Convert Arrow table back to packets using field metadata.""" - - # Extract semantic types from field metadata - field_semantic_types = 
extract_field_semantic_types(table) - - # Reconstruct type_info from field metadata - type_info = {} - for field in table.schema: - field_name = field.name - semantic_type = field_semantic_types.get(field_name) - - if semantic_type: - # Get handler by semantic type - handler = registry.get_handler_by_semantic_name(semantic_type) - if handler: - python_type = handler.supported_types() - if isinstance(python_type, tuple): - python_type = python_type[0] # Take first if multiple - type_info[field_name] = python_type - else: - # Fallback to basic type inference - type_info[field_name] = _infer_python_type_from_arrow(field.type) - else: - # No semantic type metadata - infer from Arrow type - type_info[field_name] = _infer_python_type_from_arrow(field.type) - - # Convert using reconstructed type info - _, from_storage = create_packet_converters(type_info, registry) - storage_packets = table.to_pylist() - - return [from_storage(packet) for packet in storage_packets] - - -def _infer_python_type_from_arrow(arrow_type: pa.DataType) -> type: - """Infer Python type from Arrow type as fallback.""" - if arrow_type == pa.int64(): - return int - elif arrow_type == pa.float64(): - return float - elif arrow_type == pa.string(): - return str - elif arrow_type == pa.bool_(): - return bool - elif arrow_type == pa.binary(): - return bytes - else: - return str # Safe fallback - - -# TODO: move these functions to util -def escape_with_postfix(field: str, postfix=None, separator="_") -> str: - """ - Escape the field string by doubling separators and optionally append a postfix. - This function takes a field string and escapes any occurrences of the separator - by doubling them, then optionally appends a postfix with a separator prefix. +def retrieve_storage_packet_from_arrow_with_meta( + arrow_table: pa.Table, +) -> dict[str, Any]: + """Retrieve storage packet from Arrow table with metadata. Args: - field (str): The input string containing to be escaped. - postfix (str, optional): An optional postfix to append to the escaped string. - If None, no postfix is added. Defaults to None. - separator (str, optional): The separator character to escape and use for - prefixing the postfix. Defaults to "_". - Returns: - str: The escaped string with optional postfix. Returns empty string if - fields is provided but postfix is None. - Examples: - >>> escape_with_postfix("field1_field2", "suffix") - 'field1__field2_suffix' - >>> escape_with_postfix("name_age_city", "backup", "_") - 'name__age__city_backup' - >>> escape_with_postfix("data-info", "temp", "-") - 'data--info-temp' - >>> escape_with_postfix("simple", None) - 'simple' - >>> escape_with_postfix("no_separators", "end") - 'no__separators_end' - """ - - return field.replace(separator, separator * 2) + (f"_{postfix}" if postfix else "") + arrow_table: PyArrow Table with metadata - -def unescape_with_postfix(field: str, separator="_") -> tuple[str, str | None]: - """ - Unescape a string by converting double separators back to single separators and extract postfix metadata. - This function reverses the escaping process where single separators were doubled to avoid - conflicts with metadata delimiters. It splits the input on double separators, then extracts - any postfix metadata from the last part. - - Args: - field (str): The escaped string containing doubled separators and optional postfix metadata - separator (str, optional): The separator character used for escaping. 
Defaults to "_" Returns: - tuple[str, str | None]: A tuple containing: - - The unescaped string with single separators restored - - The postfix metadata if present, None otherwise - Examples: - >>> unescape_with_postfix("field1__field2__field3") - ('field1_field2_field3', None) - >>> unescape_with_postfix("field1__field2_metadata") - ('field1_field2', 'metadata') - >>> unescape_with_postfix("simple") - ('simple', None) - >>> unescape_with_postfix("field1--field2", separator="-") - ('field1-field2', None) - >>> unescape_with_postfix("field1--field2-meta", separator="-") - ('field1-field2', 'meta') + Dictionary representing the storage packet """ + storage_packet = {} + for field in arrow_table.schema: + # Extract value from Arrow array + array = arrow_table.column(field.name) + if array.num_chunks > 0: + value = array.chunk(0).as_py()[0] # Get first value + else: + value = None # Handle empty arrays + + storage_packet[field.name] = value - parts = field.split(separator * 2) - parts[-1], *meta = parts[-1].split("_", 1) - return separator.join(parts), meta[0] if meta else None + return storage_packet From 69f71ee2f241813ddd18a6e13a0ff3d4f532c9aa Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 16 Jun 2025 09:23:20 +0000 Subject: [PATCH 17/28] feat: add functional parquet arrow data store --- src/orcabridge/store/__init__.py | 4 +- src/orcabridge/store/arrow_data_stores.py | 1008 +++++++++++++++++++++ src/orcabridge/store/types.py | 32 +- src/orcabridge/types/__init__.py | 3 +- 4 files changed, 1035 insertions(+), 12 deletions(-) create mode 100644 src/orcabridge/store/arrow_data_stores.py diff --git a/src/orcabridge/store/__init__.py b/src/orcabridge/store/__init__.py index 66a68df..f573c4d 100644 --- a/src/orcabridge/store/__init__.py +++ b/src/orcabridge/store/__init__.py @@ -1,8 +1,10 @@ -from .core import DataStore, DirDataStore, NoOpDataStore +from .types import DataStore, ArrowDataStore +from .core import DirDataStore, NoOpDataStore from .safe_dir_data_store import SafeDirDataStore __all__ = [ "DataStore", + "ArrowDataStore", "DirDataStore", "SafeDirDataStore", "NoOpDataStore", diff --git a/src/orcabridge/store/arrow_data_stores.py b/src/orcabridge/store/arrow_data_stores.py new file mode 100644 index 0000000..475e506 --- /dev/null +++ b/src/orcabridge/store/arrow_data_stores.py @@ -0,0 +1,1008 @@ +import pyarrow as pa +import pyarrow.parquet as pq +import pyarrow.dataset as ds +import polars as pl +import os +import json +import threading +import time +from pathlib import Path +from typing import Any, cast +from dataclasses import dataclass +from datetime import datetime, timedelta +import logging +from collections import defaultdict + + +# Module-level logger +logger = logging.getLogger(__name__) + + +@dataclass +class RecordMetadata: + """Metadata for a stored record.""" + + source_name: str + source_id: str + entry_id: str + created_at: datetime + updated_at: datetime + schema_hash: str + parquet_path: str | None = None # Path to the specific partition + + +class SourceCache: + """Cache for a specific source_name/source_id combination.""" + + def __init__( + self, + source_name: str, + source_id: str, + base_path: Path, + partition_prefix_length: int = 2, + ): + self.source_name = source_name + self.source_id = source_id + self.base_path = base_path + self.source_dir = base_path / source_name / source_id + self.partition_prefix_length = partition_prefix_length + + # In-memory data - only for this source + self._memory_table: pl.DataFrame | None = None + self._loaded = 
False + self._dirty = False + self._last_access = datetime.now() + + # Track which entries are in memory vs on disk + self._memory_entries: set[str] = set() + self._disk_entries: set[str] = set() + + # Track which partitions are dirty (need to be rewritten) + self._dirty_partitions: set[str] = set() + + self._lock = threading.RLock() + + def _get_partition_key(self, entry_id: str) -> str: + """Get the partition key for an entry_id.""" + if len(entry_id) < self.partition_prefix_length: + return entry_id.ljust(self.partition_prefix_length, "0") + return entry_id[: self.partition_prefix_length] + + def _get_partition_path(self, entry_id: str) -> Path: + """Get the partition directory for an entry_id.""" + partition_key = self._get_partition_key(entry_id) + # Use prefix_ instead of entry_id= to avoid Hive partitioning issues + return self.source_dir / f"prefix_{partition_key}" + + def _get_partition_parquet_path(self, entry_id: str) -> Path: + """Get the Parquet file path for a partition.""" + partition_dir = self._get_partition_path(entry_id) + partition_key = self._get_partition_key(entry_id) + return partition_dir / f"partition_{partition_key}.parquet" + + def _load_from_disk_lazy(self) -> None: + """Lazily load data from disk only when first accessed.""" + if self._loaded: + return + + with self._lock: + if self._loaded: # Double-check after acquiring lock + return + + logger.debug(f"Lazy loading {self.source_name}/{self.source_id}") + + all_tables = [] + + if self.source_dir.exists(): + # Scan all partition directories + for partition_dir in self.source_dir.iterdir(): + if not partition_dir.is_dir() or not ( + partition_dir.name.startswith("entry_id=") + or partition_dir.name.startswith("prefix_") + ): + continue + + # Load the partition Parquet file (one per partition) + if partition_dir.name.startswith("entry_id="): + partition_key = partition_dir.name.split("=")[1] + else: # prefix_XX format + partition_key = partition_dir.name.split("_")[1] + + parquet_file = partition_dir / f"partition_{partition_key}.parquet" + + if parquet_file.exists(): + try: + table = pq.read_table(parquet_file) + if len(table) > 0: + polars_df = pl.from_arrow(table) + all_tables.append(polars_df) + + logger.debug( + f"Loaded partition {parquet_file}: {len(table)} rows, {len(table.columns)} columns" + ) + logger.debug(f" Columns: {table.column_names}") + + # Track disk entries from this partition + if "__entry_id" in table.column_names: + entry_ids = set( + table.column("__entry_id").to_pylist() + ) + self._disk_entries.update(entry_ids) + + except Exception as e: + logger.error(f"Failed to load {parquet_file}: {e}") + + # Combine all tables + if all_tables: + self._memory_table = pl.concat(all_tables) + self._memory_entries = self._disk_entries.copy() + logger.debug( + f"Combined loaded data: {len(self._memory_table)} rows, {len(self._memory_table.columns)} columns" + ) + logger.debug(f" Final columns: {self._memory_table.columns}") + + self._loaded = True + self._last_access = datetime.now() + + def add_entry( + self, + entry_id: str, + table_with_metadata: pa.Table, + allow_overwrite: bool = False, + ) -> None: + """Add an entry to this source cache.""" + with self._lock: + self._load_from_disk_lazy() # Ensure we're loaded + + # Check if entry already exists + entry_exists = ( + entry_id in self._memory_entries or entry_id in self._disk_entries + ) + + if entry_exists and not allow_overwrite: + raise ValueError( + f"Entry {entry_id} already exists in {self.source_name}/{self.source_id}" + ) + + # We know this 
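# Illustration (not part of the patch): how the prefix-based partitioning above maps
# entry ids onto the on-disk layout, assuming partition_prefix_length=2. The helper
# below mirrors SourceCache._get_partition_key; ids shorter than the prefix are
# right-padded with "0".
def example_partition_key(entry_id: str, prefix_len: int = 2) -> str:
    if len(entry_id) < prefix_len:
        return entry_id.ljust(prefix_len, "0")
    return entry_id[:prefix_len]

# "ab3f9c..." -> "ab" -> <source_dir>/prefix_ab/partition_ab.parquet
# "a"         -> "a0" -> <source_dir>/prefix_a0/partition_a0.parquet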
returns DataFrame since we're passing a Table + polars_table = cast(pl.DataFrame, pl.from_arrow(table_with_metadata)) + + if self._memory_table is None: + self._memory_table = polars_table + else: + # Remove existing entry if it exists (for overwrite case) + if entry_id in self._memory_entries: + mask = self._memory_table["__entry_id"] != entry_id + self._memory_table = self._memory_table.filter(mask) + logger.debug(f"Removed existing entry {entry_id} for overwrite") + + # Debug schema mismatch + existing_cols = self._memory_table.columns + new_cols = polars_table.columns + + if len(existing_cols) != len(new_cols): + logger.error(f"Schema mismatch for entry {entry_id}:") + logger.error( + f" Existing columns ({len(existing_cols)}): {existing_cols}" + ) + logger.error(f" New columns ({len(new_cols)}): {new_cols}") + logger.error( + f" Missing in new: {set(existing_cols) - set(new_cols)}" + ) + logger.error( + f" Extra in new: {set(new_cols) - set(existing_cols)}" + ) + + raise ValueError( + f"Schema mismatch: existing table has {len(existing_cols)} columns, " + f"new table has {len(new_cols)} columns" + ) + + # Ensure column order matches + if existing_cols != new_cols: + logger.debug(f"Reordering columns to match existing schema") + polars_table = polars_table.select(existing_cols) + + # Add new entry + self._memory_table = pl.concat([self._memory_table, polars_table]) + + self._memory_entries.add(entry_id) + self._dirty = True + + # Mark the partition as dirty + partition_key = self._get_partition_key(entry_id) + self._dirty_partitions.add(partition_key) + + self._last_access = datetime.now() + + if entry_exists: + logger.info(f"Overwrote existing entry {entry_id}") + else: + logger.debug(f"Added new entry {entry_id}") + + def get_entry(self, entry_id: str) -> pa.Table | None: + """Get a specific entry.""" + with self._lock: + self._load_from_disk_lazy() + + if self._memory_table is None: + return None + + mask = self._memory_table["__entry_id"] == entry_id + filtered = self._memory_table.filter(mask) + + if len(filtered) == 0: + return None + + self._last_access = datetime.now() + return filtered.to_arrow() + + def get_all_entries(self) -> pa.Table | None: + """Get all entries for this source.""" + with self._lock: + self._load_from_disk_lazy() + + if self._memory_table is None: + return None + + self._last_access = datetime.now() + return self._memory_table.to_arrow() + + def get_all_entries_as_polars(self) -> pl.LazyFrame | None: + """Get all entries as a Polars LazyFrame.""" + with self._lock: + self._load_from_disk_lazy() + + if self._memory_table is None: + return None + + self._last_access = datetime.now() + return self._memory_table.lazy() + + def sync_to_disk(self) -> None: + """Sync dirty partitions to disk using efficient Parquet files.""" + with self._lock: + if not self._dirty or self._memory_table is None: + return + + logger.debug(f"Syncing {self.source_name}/{self.source_id} to disk") + + # Only sync dirty partitions + for partition_key in self._dirty_partitions: + try: + # Get all entries for this partition + partition_mask = ( + self._memory_table["__entry_id"].str.slice( + 0, self.partition_prefix_length + ) + == partition_key + ) + partition_data = self._memory_table.filter(partition_mask) + + if len(partition_data) == 0: + continue + + logger.debug(f"Syncing partition {partition_key}:") + logger.debug(f" Rows: {len(partition_data)}") + logger.debug(f" Columns: {partition_data.columns}") + logger.debug( + f" Sample __entry_id values: 
{partition_data['__entry_id'].head(3).to_list()}" + ) + + # Ensure partition directory exists + partition_dir = self.source_dir / f"prefix_{partition_key}" + partition_dir.mkdir(parents=True, exist_ok=True) + + # Write entire partition to single Parquet file + partition_path = ( + partition_dir / f"partition_{partition_key}.parquet" + ) + arrow_table = partition_data.to_arrow() + + logger.debug( + f" Arrow table columns before write: {arrow_table.column_names}" + ) + logger.debug(f" Arrow table shape: {arrow_table.shape}") + + pq.write_table(arrow_table, partition_path) + + # Verify what was written + verification_table = pq.read_table(partition_path) + logger.debug( + f" Verification - columns after write: {verification_table.column_names}" + ) + logger.debug(f" Verification - shape: {verification_table.shape}") + + entry_count = len(set(partition_data["__entry_id"].to_list())) + logger.debug( + f"Wrote partition {partition_key} with {entry_count} entries ({len(partition_data)} rows)" + ) + + except Exception as e: + logger.error(f"Failed to write partition {partition_key}: {e}") + import traceback + + logger.error(f"Traceback: {traceback.format_exc()}") + + # Clear dirty markers + self._dirty_partitions.clear() + self._dirty = False + + def is_loaded(self) -> bool: + """Check if this cache is loaded in memory.""" + return self._loaded + + def get_last_access(self) -> datetime: + """Get the last access time.""" + return self._last_access + + def unload(self) -> None: + """Unload from memory (after syncing if dirty).""" + with self._lock: + if self._dirty: + self.sync_to_disk() + + self._memory_table = None + self._loaded = False + self._memory_entries.clear() + # Keep _disk_entries for reference + + def entry_exists(self, entry_id: str) -> bool: + """Check if an entry exists (in memory or on disk).""" + with self._lock: + self._load_from_disk_lazy() + return entry_id in self._memory_entries or entry_id in self._disk_entries + + def list_entries(self) -> set[str]: + """List all entry IDs in this source.""" + with self._lock: + self._load_from_disk_lazy() + return self._memory_entries | self._disk_entries + + def get_stats(self) -> dict[str, Any]: + """Get statistics for this cache.""" + with self._lock: + return { + "source_name": self.source_name, + "source_id": self.source_id, + "loaded": self._loaded, + "dirty": self._dirty, + "memory_entries": len(self._memory_entries), + "disk_entries": len(self._disk_entries), + "memory_rows": len(self._memory_table) + if self._memory_table is not None + else 0, + "last_access": self._last_access.isoformat(), + } + + +class ParquetArrowDataStore: + """ + Lazy-loading, append-only Arrow data store with entry_id partitioning. + + Features: + - Lazy loading: Only loads source data when first accessed + - Separate memory management per source_name/source_id + - Entry_id partitioning: Multiple entries per Parquet file based on prefix + - Configurable duplicate entry_id handling (error or overwrite) + - Automatic cache eviction for memory management + - Single-row constraint: Each record must contain exactly one row + """ + + def __init__( + self, + base_path: str | Path, + sync_interval_seconds: int = 300, # 5 minutes default + auto_sync: bool = True, + max_loaded_sources: int = 100, + cache_eviction_hours: int = 2, + duplicate_entry_behavior: str = "error", + partition_prefix_length: int = 2, + ): + """ + Initialize the ParquetArrowDataStore. 
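# A minimal usage sketch (not part of the patch) for the store being initialized here,
# using add_record/get_record defined further down in this class; the directory and
# values are illustrative only.
import pyarrow as pa

store = ParquetArrowDataStore(
    base_path="./arrow_store",                 # assumed scratch directory
    duplicate_entry_behavior="overwrite",
    partition_prefix_length=2,
)
record = pa.table({"value": [3.14], "label": ["example"]})   # exactly one row
store.add_record("experiments", "dataset_A", "0123456789abcdef", record)
roundtrip = store.get_record("experiments", "dataset_A", "0123456789abcdef")
assert roundtrip is not None and len(roundtrip) == 1
store.shutdown()                               # flushes dirty partitions to Parquet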
+ + Args: + base_path: Directory path for storing Parquet files + sync_interval_seconds: How often to sync dirty caches to disk + auto_sync: Whether to automatically sync on a timer + max_loaded_sources: Maximum number of source caches to keep in memory + cache_eviction_hours: Hours of inactivity before evicting from memory + duplicate_entry_behavior: How to handle duplicate entry_ids: + - 'error': Raise ValueError when entry_id already exists + - 'overwrite': Replace existing entry with new data + partition_prefix_length: Number of characters from entry_id to use for partitioning (default 2) + """ + self.base_path = Path(base_path) + self.base_path.mkdir(parents=True, exist_ok=True) + self.sync_interval = sync_interval_seconds + self.auto_sync = auto_sync + self.max_loaded_sources = max_loaded_sources + self.cache_eviction_hours = cache_eviction_hours + self.partition_prefix_length = max( + 1, min(8, partition_prefix_length) + ) # Clamp between 1-8 + + # Validate duplicate behavior + if duplicate_entry_behavior not in ["error", "overwrite"]: + raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") + self.duplicate_entry_behavior = duplicate_entry_behavior + + # Cache management + self._source_caches: dict[str, SourceCache] = {} # key: "source_name:source_id" + self._global_lock = threading.RLock() + + # Record metadata (always in memory for fast lookups) + self._record_metadata: dict[str, RecordMetadata] = {} + self._load_metadata_index() + + # Sync management + self._sync_timer: threading.Timer | None = None + self._shutdown = False + + # Start auto-sync and cleanup if enabled + if self.auto_sync: + self._start_sync_timer() + + logger.info(f"Initialized lazy ParquetArrowDataStore at {base_path}") + + def _get_source_key(self, source_name: str, source_id: str) -> str: + """Generate key for source cache.""" + return f"{source_name}:{source_id}" + + def _get_record_key(self, source_name: str, source_id: str, entry_id: str) -> str: + """Generate unique key for a record.""" + return f"{source_name}:{source_id}:{entry_id}" + + def _load_metadata_index(self) -> None: + """Load metadata index from disk (lightweight - just file paths and timestamps).""" + logger.info("Loading metadata index...") + + if not self.base_path.exists(): + return + + for source_name_dir in self.base_path.iterdir(): + if not source_name_dir.is_dir(): + continue + + source_name = source_name_dir.name + + for source_id_dir in source_name_dir.iterdir(): + if not source_id_dir.is_dir(): + continue + + source_id = source_id_dir.name + + # Scan partition directories for parquet files + for partition_dir in source_id_dir.iterdir(): + if not partition_dir.is_dir() or not ( + partition_dir.name.startswith("entry_id=") + or partition_dir.name.startswith("prefix_") + ): + continue + + for parquet_file in partition_dir.glob("partition_*.parquet"): + try: + # Read the parquet file to extract entry IDs + table = pq.read_table(parquet_file) + if "__entry_id" in table.column_names: + entry_ids = set(table.column("__entry_id").to_pylist()) + + # Get file stats + stat = parquet_file.stat() + created_at = datetime.fromtimestamp(stat.st_ctime) + updated_at = datetime.fromtimestamp(stat.st_mtime) + + for entry_id in entry_ids: + record_key = self._get_record_key( + source_name, source_id, entry_id + ) + self._record_metadata[record_key] = RecordMetadata( + source_name=source_name, + source_id=source_id, + entry_id=entry_id, + created_at=created_at, + updated_at=updated_at, + schema_hash="unknown", # Will be computed if 
needed + parquet_path=str(parquet_file), + ) + except Exception as e: + logger.error( + f"Failed to read metadata from {parquet_file}: {e}" + ) + + logger.info(f"Loaded metadata for {len(self._record_metadata)} records") + + def _get_or_create_source_cache( + self, source_name: str, source_id: str + ) -> SourceCache: + """Get or create a source cache, handling eviction if needed.""" + source_key = self._get_source_key(source_name, source_id) + + with self._global_lock: + if source_key not in self._source_caches: + # Check if we need to evict old caches + if len(self._source_caches) >= self.max_loaded_sources: + self._evict_old_caches() + + # Create new cache with partition configuration + self._source_caches[source_key] = SourceCache( + source_name, source_id, self.base_path, self.partition_prefix_length + ) + logger.debug(f"Created cache for {source_key}") + + return self._source_caches[source_key] + + def _evict_old_caches(self) -> None: + """Evict old caches based on last access time.""" + cutoff_time = datetime.now() - timedelta(hours=self.cache_eviction_hours) + + to_evict = [] + for source_key, cache in self._source_caches.items(): + if cache.get_last_access() < cutoff_time: + to_evict.append(source_key) + + for source_key in to_evict: + cache = self._source_caches.pop(source_key) + cache.unload() # This will sync if dirty + logger.debug(f"Evicted cache for {source_key}") + + def _compute_schema_hash(self, table: pa.Table) -> str: + """Compute a hash of the table schema.""" + import hashlib + + schema_str = str(table.schema) + return hashlib.sha256(schema_str.encode()).hexdigest()[:16] + + def _add_system_columns( + self, table: pa.Table, metadata: RecordMetadata + ) -> pa.Table: + """Add system columns to track record metadata.""" + # Keep all system columns for self-describing data + system_columns = [ + ("__source_name", pa.array([metadata.source_name] * len(table))), + ("__source_id", pa.array([metadata.source_id] * len(table))), + ("__entry_id", pa.array([metadata.entry_id] * len(table))), + ("__created_at", pa.array([metadata.created_at] * len(table))), + ("__updated_at", pa.array([metadata.updated_at] * len(table))), + ("__schema_hash", pa.array([metadata.schema_hash] * len(table))), + ] + + # Combine user columns + system columns in consistent order + new_columns = list(table.columns) + [col[1] for col in system_columns] + new_names = table.column_names + [col[0] for col in system_columns] + + result = pa.table(new_columns, names=new_names) + logger.debug( + f"Added system columns: {len(table.columns)} -> {len(result.columns)} columns" + ) + return result + + def _remove_system_columns(self, table: pa.Table) -> pa.Table: + """Remove system columns to get original user data.""" + system_cols = [ + "__source_name", + "__source_id", + "__entry_id", + "__created_at", + "__updated_at", + "__schema_hash", + ] + user_columns = [name for name in table.column_names if name not in system_cols] + return table.select(user_columns) + + def add_record( + self, source_name: str, source_id: str, entry_id: str, arrow_data: pa.Table + ) -> pa.Table: + """ + Add or update a record (append-only operation). 
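# Illustration (not part of the patch): the system-column round trip performed by
# _add_system_columns/_remove_system_columns above. The store, metadata values, and
# table contents are hypothetical, and private methods are called only to show the
# stored schema.
import pyarrow as pa
from datetime import datetime

store = ParquetArrowDataStore("./arrow_store", auto_sync=False)
meta = RecordMetadata(
    source_name="experiments", source_id="dataset_A", entry_id="0123456789abcdef",
    created_at=datetime.now(), updated_at=datetime.now(), schema_hash="deadbeef",
)
user_table = pa.table({"value": [1.0]})
stored = store._add_system_columns(user_table, meta)
# stored.column_names == ["value", "__source_name", "__source_id", "__entry_id",
#                         "__created_at", "__updated_at", "__schema_hash"]
assert store._remove_system_columns(stored).column_names == ["value"]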
+ + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_id: Unique identifier for this record (typically 32-char hash) + arrow_data: The Arrow table data to store (MUST contain exactly 1 row) + + Returns: + The original arrow_data table + + Raises: + ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' + ValueError: If arrow_data contains more than 1 row + ValueError: If arrow_data schema doesn't match existing data for this source + """ + # CRITICAL: Enforce single-row constraint + if len(arrow_data) != 1: + raise ValueError( + f"Each record must contain exactly 1 row, got {len(arrow_data)} rows. " + f"This constraint ensures that for each source_name/source_id combination, " + f"there is only one valid entry per entry_id." + ) + + # Validate entry_id format (assuming 8+ char identifier) + if not entry_id or len(entry_id) < 8: + raise ValueError( + f"entry_id must be at least 8 characters long, got: '{entry_id}'" + ) + + # Check if this source already has data and validate schema compatibility + cache = self._get_or_create_source_cache(source_name, source_id) + + # Load existing data to check schema compatibility + cache._load_from_disk_lazy() + + if cache._memory_table is not None: + # Extract user columns from existing data (remove system columns) + existing_arrow = cache._memory_table.to_arrow() + existing_user_data = self._remove_system_columns(existing_arrow) + + # Check if schemas match + existing_schema = existing_user_data.schema + new_schema = arrow_data.schema + + if not existing_schema.equals(new_schema): + existing_cols = existing_user_data.column_names + new_cols = arrow_data.column_names + + logger.error(f"Schema mismatch for {source_name}/{source_id}:") + logger.error(f" Existing user columns: {existing_cols}") + logger.error(f" New user columns: {new_cols}") + logger.error(f" Missing in new: {set(existing_cols) - set(new_cols)}") + logger.error(f" Extra in new: {set(new_cols) - set(existing_cols)}") + + raise ValueError( + f"Schema mismatch for {source_name}/{source_id}. " + f"Existing data has columns {existing_cols}, " + f"but new data has columns {new_cols}. " + f"All records in a source must have the same schema." + ) + + now = datetime.now() + record_key = self._get_record_key(source_name, source_id, entry_id) + + # Check for existing entry + existing_metadata = self._record_metadata.get(record_key) + entry_exists = existing_metadata is not None + + if entry_exists and self.duplicate_entry_behavior == "error": + raise ValueError( + f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " + f"Use duplicate_entry_behavior='overwrite' to allow updates." 
+ ) + + # Create/update metadata + schema_hash = self._compute_schema_hash(arrow_data) + metadata = RecordMetadata( + source_name=source_name, + source_id=source_id, + entry_id=entry_id, + created_at=existing_metadata.created_at if existing_metadata else now, + updated_at=now, + schema_hash=schema_hash, + ) + + # Add system columns + table_with_metadata = self._add_system_columns(arrow_data, metadata) + + # Get or create source cache and add entry + allow_overwrite = self.duplicate_entry_behavior == "overwrite" + + try: + cache.add_entry(entry_id, table_with_metadata, allow_overwrite) + except ValueError as e: + # Re-raise with more context + raise ValueError(f"Failed to add record: {e}") + + # Update metadata + self._record_metadata[record_key] = metadata + + action = "Updated" if entry_exists else "Added" + logger.info(f"{action} record {record_key} with {len(arrow_data)} rows") + return arrow_data + + def get_record( + self, source_name: str, source_id: str, entry_id: str + ) -> pa.Table | None: + """Retrieve a specific record.""" + record_key = self._get_record_key(source_name, source_id, entry_id) + + if record_key not in self._record_metadata: + return None + + cache = self._get_or_create_source_cache(source_name, source_id) + table = cache.get_entry(entry_id) + + if table is None: + return None + + return self._remove_system_columns(table) + + def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + """Retrieve all records for a given source as a single Arrow table.""" + cache = self._get_or_create_source_cache(source_name, source_id) + table = cache.get_all_entries() + + if table is None: + return None + + return self._remove_system_columns(table) + + def get_all_records_as_polars( + self, source_name: str, source_id: str + ) -> pl.LazyFrame | None: + """Retrieve all records for a given source as a Polars LazyFrame.""" + cache = self._get_or_create_source_cache(source_name, source_id) + lazy_frame = cache.get_all_entries_as_polars() + + if lazy_frame is None: + return None + + # Remove system columns + system_cols = ["__entry_id", "__created_at", "__updated_at", "__schema_hash"] + user_columns = [col for col in lazy_frame.columns if col not in system_cols] + + return lazy_frame.select(user_columns) + + def _sync_all_dirty_caches(self) -> None: + """Sync all dirty caches to disk.""" + with self._global_lock: + dirty_count = 0 + for cache in self._source_caches.values(): + if cache._dirty: + cache.sync_to_disk() + dirty_count += 1 + + if dirty_count > 0: + logger.info(f"Synced {dirty_count} dirty caches to disk") + + def _start_sync_timer(self) -> None: + """Start the automatic sync timer.""" + if self._shutdown: + return + + self._sync_timer = threading.Timer( + self.sync_interval, self._sync_and_reschedule + ) + self._sync_timer.daemon = True + self._sync_timer.start() + + def _sync_and_reschedule(self) -> None: + """Sync dirty caches and reschedule.""" + try: + self._sync_all_dirty_caches() + self._evict_old_caches() + except Exception as e: + logger.error(f"Auto-sync failed: {e}") + finally: + if not self._shutdown: + self._start_sync_timer() + + def force_sync(self) -> None: + """Manually trigger a sync of all dirty caches.""" + self._sync_all_dirty_caches() + + def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: + """Check if a specific entry exists.""" + record_key = self._get_record_key(source_name, source_id, entry_id) + + # Check metadata first (fast) + if record_key in self._record_metadata: + return True + + # If not 
in metadata, check if source cache knows about it + source_key = self._get_source_key(source_name, source_id) + if source_key in self._source_caches: + cache = self._source_caches[source_key] + return cache.entry_exists(entry_id) + + # Not loaded and not in metadata - doesn't exist + return False + + def list_entries(self, source_name: str, source_id: str) -> set[str]: + """List all entry IDs for a specific source.""" + cache = self._get_or_create_source_cache(source_name, source_id) + return cache.list_entries() + + def list_sources(self) -> set[tuple[str, str]]: + """List all (source_name, source_id) combinations.""" + sources = set() + + # From metadata + for metadata in self._record_metadata.values(): + sources.add((metadata.source_name, metadata.source_id)) + + return sources + + def get_stats(self) -> dict[str, Any]: + """Get comprehensive statistics about the data store.""" + with self._global_lock: + loaded_caches = len(self._source_caches) + dirty_caches = sum( + 1 for cache in self._source_caches.values() if cache._dirty + ) + + cache_stats = [cache.get_stats() for cache in self._source_caches.values()] + + return { + "total_records": len(self._record_metadata), + "loaded_source_caches": loaded_caches, + "dirty_caches": dirty_caches, + "max_loaded_sources": self.max_loaded_sources, + "sync_interval": self.sync_interval, + "auto_sync": self.auto_sync, + "cache_eviction_hours": self.cache_eviction_hours, + "base_path": str(self.base_path), + "duplicate_entry_behavior": self.duplicate_entry_behavior, + "partition_prefix_length": self.partition_prefix_length, + "cache_details": cache_stats, + } + + def shutdown(self) -> None: + """Shutdown the data store, ensuring all data is synced.""" + logger.info("Shutting down ParquetArrowDataStore...") + self._shutdown = True + + if self._sync_timer: + self._sync_timer.cancel() + + # Final sync of all caches + self._sync_all_dirty_caches() + + logger.info("Shutdown complete") + + def __del__(self): + """Ensure cleanup on destruction.""" + if not self._shutdown: + self.shutdown() + + +# Example usage and testing +def demo_single_row_constraint(): + """Demonstrate the single-row constraint in the ParquetArrowDataStore.""" + import tempfile + import random + from datetime import timedelta + + def create_single_row_record(entry_id: str, value: float | None = None) -> pa.Table: + """Create a single-row Arrow table.""" + if value is None: + value = random.uniform(0, 100) + + return pa.table( + { + "entry_id": [entry_id], + "timestamp": [datetime.now()], + "value": [value], + "category": [random.choice(["A", "B", "C"])], + } + ) + + def create_multi_row_record(entry_id: str, num_rows: int = 3) -> pa.Table: + """Create a multi-row Arrow table (should be rejected).""" + return pa.table( + { + "entry_id": [entry_id] * num_rows, + "timestamp": [ + datetime.now() + timedelta(seconds=i) for i in range(num_rows) + ], + "value": [random.uniform(0, 100) for _ in range(num_rows)], + "category": [random.choice(["A", "B", "C"]) for _ in range(num_rows)], + } + ) + + print("Testing Single-Row Constraint...") + + with tempfile.TemporaryDirectory() as temp_dir: + store = ParquetArrowDataStore( + base_path=temp_dir, + sync_interval_seconds=10, + auto_sync=False, # Manual sync for testing + duplicate_entry_behavior="overwrite", + ) + + try: + print("\n=== Testing Valid Single-Row Records ===") + + # Test 1: Add valid single-row records + valid_entries = [ + "entry_001_abcdef1234567890abcdef1234567890", + "entry_002_abcdef1234567890abcdef1234567890", + 
"entry_003_abcdef1234567890abcdef1234567890", + ] + + for i, entry_id in enumerate(valid_entries): + data = create_single_row_record(entry_id, value=100.0 + i) + result = store.add_record("experiments", "dataset_A", entry_id, data) + print( + f"✓ Added single-row record {entry_id[:16]}... (value: {100.0 + i})" + ) + + print(f"\nTotal records stored: {len(store._record_metadata)}") + + print("\n=== Testing Invalid Multi-Row Records ===") + + # Test 2: Try to add multi-row record (should fail) + invalid_entry = "entry_004_abcdef1234567890abcdef1234567890" + try: + invalid_data = create_multi_row_record(invalid_entry, num_rows=3) + store.add_record( + "experiments", "dataset_A", invalid_entry, invalid_data + ) + print("✗ ERROR: Multi-row record was accepted!") + except ValueError as e: + print(f"✓ Correctly rejected multi-row record: {str(e)[:80]}...") + + # Test 3: Try to add empty record (should fail) + empty_entry = "entry_005_abcdef1234567890abcdef1234567890" + try: + empty_data = pa.table({"col1": pa.array([], type=pa.int64())}) + store.add_record("experiments", "dataset_A", empty_entry, empty_data) + print("✗ ERROR: Empty record was accepted!") + except ValueError as e: + print(f"✓ Correctly rejected empty record: {str(e)[:80]}...") + + print("\n=== Testing Retrieval ===") + + # Test 4: Retrieve records + retrieved = store.get_record("experiments", "dataset_A", valid_entries[0]) + if retrieved and len(retrieved) == 1: + print(f"✓ Retrieved single record: {len(retrieved)} row") + print(f" Value: {retrieved.column('value')[0].as_py()}") + else: + print("✗ Failed to retrieve record or wrong size") + + # Test 5: Get all records + all_records = store.get_all_records("experiments", "dataset_A") + if all_records: + print(f"✓ Retrieved all records: {len(all_records)} rows total") + unique_entries = len(set(all_records.column("entry_id").to_pylist())) + print(f" Unique entries: {unique_entries}") + + # Verify each entry appears exactly once + entry_counts = {} + for entry_id in all_records.column("entry_id").to_pylist(): + entry_counts[entry_id] = entry_counts.get(entry_id, 0) + 1 + + all_single = all(count == 1 for count in entry_counts.values()) + if all_single: + print( + "✓ Each entry appears exactly once (single-row constraint maintained)" + ) + else: + print("✗ Some entries appear multiple times!") + + print("\n=== Testing Overwrite Behavior ===") + + # Test 6: Overwrite existing single-row record + overwrite_data = create_single_row_record(valid_entries[0], value=999.0) + store.add_record( + "experiments", "dataset_A", valid_entries[0], overwrite_data + ) + print(f"✓ Overwrote existing record") + + # Verify overwrite + updated_record = store.get_record( + "experiments", "dataset_A", valid_entries[0] + ) + if updated_record and updated_record.column("value")[0].as_py() == 999.0: + print( + f"✓ Overwrite successful: new value = {updated_record.column('value')[0].as_py()}" + ) + + # Sync and show final stats + store.force_sync() + stats = store.get_stats() + print(f"\n=== Final Statistics ===") + print(f"Total records: {stats['total_records']}") + print(f"Loaded caches: {stats['loaded_source_caches']}") + print(f"Dirty caches: {stats['dirty_caches']}") + + finally: + store.shutdown() + + print("\n✓ Single-row constraint testing completed successfully!") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + demo_single_row_constraint() diff --git a/src/orcabridge/store/types.py b/src/orcabridge/store/types.py index 32092a0..912a2d1 100644 --- 
a/src/orcabridge/store/types.py +++ b/src/orcabridge/store/types.py @@ -1,7 +1,8 @@ from typing import Protocol, runtime_checkable -from orcabridge.types import Packet +from orcabridge.types import Tag, Packet import pyarrow as pa +import polars as pl @runtime_checkable @@ -26,21 +27,32 @@ def retrieve_memoized( @runtime_checkable -class ArrowBasedDataStore(Protocol): +class ArrowDataStore(Protocol): """ Protocol for data stores that can memoize and retrieve packets. This is used to define the interface for data stores like DirDataStore. """ def __init__(self, *args, **kwargs) -> None: ... - def memoize( + + def add_record( self, - function_name: str, - function_hash: str, - packet: pa.Table, - output_packet: pa.Table, + source_name: str, + source_id: str, + entry_id: str, + arrow_data: pa.Table, ) -> pa.Table: ... - def retrieve_memoized( - self, function_name: str, function_hash: str, packet: Packet - ) -> Packet | None: ... + def get_record( + self, source_name: str, source_id: str, entry_id: str + ) -> pa.Table | None: ... + + def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + """Retrieve all records for a given source as a single table.""" + ... + + def get_all_records_as_polars( + self, source_name: str, source_id: str + ) -> pl.LazyFrame | None: + """Retrieve all records for a given source as a single Polars DataFrame.""" + ... diff --git a/src/orcabridge/types/__init__.py b/src/orcabridge/types/__init__.py index f82cfc9..08c8ae4 100644 --- a/src/orcabridge/types/__init__.py +++ b/src/orcabridge/types/__init__.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Protocol from typing_extensions import TypeAlias +from .core import TypeSpec, TypeHandler SUPPORTED_PYTHON_TYPES = (str, int, float, bool, bytes) @@ -17,7 +18,7 @@ # the top level tag is a mapping from string keys to values that can be a string or # an arbitrary depth of nested list of strings or None -Tag: TypeAlias = Mapping[str, TagValue] +Tag: TypeAlias = dict[str, TagValue] # a pathset is a path or an arbitrary depth of nested list of paths PathSet: TypeAlias = PathLike | Collection[PathLike | None] From 09ca3632c5241ff567e137795265ff027ee7c7be Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Mon, 16 Jun 2025 09:23:54 +0000 Subject: [PATCH 18/28] feat: add working implementation of semantic arrow hasher --- src/orcabridge/hashing/__init__.py | 6 + src/orcabridge/hashing/defaults.py | 10 + .../hashing/function_info_extractors.py | 65 ++++- .../hashing/semantic_arrow_hasher.py | 263 ++++++++++++++++++ 4 files changed, 330 insertions(+), 14 deletions(-) create mode 100644 src/orcabridge/hashing/semantic_arrow_hasher.py diff --git a/src/orcabridge/hashing/__init__.py b/src/orcabridge/hashing/__init__.py index e3e2445..a91b7f4 100644 --- a/src/orcabridge/hashing/__init__.py +++ b/src/orcabridge/hashing/__init__.py @@ -13,16 +13,22 @@ from .defaults import get_default_composite_file_hasher, get_default_object_hasher from .types import ( FileHasher, + PacketHasher, + ArrowPacketHasher, ObjectHasher, StringCacher, + FunctionInfoExtractor, CompositeFileHasher, ) __all__ = [ "FileHasher", + "PacketHasher", + "ArrowPacketHasher", "StringCacher", "ObjectHasher", "CompositeFileHasher", + "FunctionInfoExtractor", "hash_file", "hash_pathset", "hash_packet", diff --git a/src/orcabridge/hashing/defaults.py b/src/orcabridge/hashing/defaults.py index 3faca77..6f9abf3 100644 --- a/src/orcabridge/hashing/defaults.py +++ b/src/orcabridge/hashing/defaults.py @@ -6,6 +6,7 @@ from orcabridge.hashing.object_hashers import ObjectHasher from orcabridge.hashing.object_hashers import LegacyObjectHasher from orcabridge.hashing.function_info_extractors import FunctionInfoExtractorFactory +from orcabridge.hashing.semantic_arrow_hasher import SemanticArrowHasher, PathHasher def get_default_composite_file_hasher(with_cache=True) -> CompositeFileHasher: @@ -31,3 +32,12 @@ def get_default_object_hasher() -> ObjectHasher: return LegacyObjectHasher( char_count=32, function_info_extractor=function_info_extractor ) + + +def get_default_semantic_arrow_hasher( + chunk_size: int = 8192, handle_missing: str = "error" +) -> SemanticArrowHasher: + hasher = SemanticArrowHasher(chunk_size=chunk_size, handle_missing=handle_missing) + # register semantic hasher for Path + hasher.register_semantic_hasher("Path", PathHasher()) + return hasher diff --git a/src/orcabridge/hashing/function_info_extractors.py b/src/orcabridge/hashing/function_info_extractors.py index 74be127..4f9bb58 100644 --- a/src/orcabridge/hashing/function_info_extractors.py +++ b/src/orcabridge/hashing/function_info_extractors.py @@ -1,6 +1,8 @@ from .types import FunctionInfoExtractor from collections.abc import Callable from typing import Any, Literal +from orcabridge.types import TypeSpec +import inspect class FunctionNameExtractor: @@ -8,15 +10,16 @@ class FunctionNameExtractor: Extractor that only uses the function name for information extraction. """ - def extract_function_info(self, func: Callable[..., Any]) -> dict[str, Any]: - """ - Extracts information from the function based on its name. - """ + def extract_function_info( + self, + func: Callable[..., Any], + function_name: str | None = None, + input_types: TypeSpec | None = None, + output_types: TypeSpec | None = None, + ) -> dict[str, Any]: if not callable(func): raise TypeError("Provided object is not callable") - - # Use the function's name as the hash - function_name = func.__name__ if hasattr(func, "__name__") else str(func) + function_name = function_name or getattr(func, "__name__", str(func)) return {"name": function_name} @@ -25,16 +28,50 @@ class FunctionSignatureExtractor: Extractor that uses the function signature for information extraction. 
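# A hypothetical illustration (not part of the patch) of what the signature-based
# extractor returns, following the implementation below; the module name depends on
# where the function is defined (e.g. "__main__" in a script).
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor

extractor = FunctionSignatureExtractor(include_module=True, include_defaults=True)
info = extractor.extract_function_info(scale)
# info == {"module": "__main__", "name": "scale",
#          "params": "x: float, factor: float = 2.0", "returns": <class 'float'>}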
""" - def extract_function_info(self, func: Callable[..., Any]) -> dict[str, Any]: - """ - Extracts information from the function based on its signature. - """ + def __init__(self, include_module: bool = True, include_defaults: bool = True): + self.include_module = include_module + self.include_defaults = include_defaults + + # FIXME: Fix this implementation!! + # BUG: Currently this is not using the input_types and output_types parameters + def extract_function_info( + self, + func: Callable[..., Any], + function_name: str | None = None, + input_types: TypeSpec | None = None, + output_types: TypeSpec | None = None, + ) -> dict[str, Any]: if not callable(func): raise TypeError("Provided object is not callable") - # Use the function's signature as the hash - function_signature = str(func.__code__) - return {"signature": function_signature} + sig = inspect.signature(func) + + # Build the signature string + parts = {} + + # Add module if requested + if self.include_module and hasattr(func, "__module__"): + parts["module"] = func.__module__ + + # Add function name + parts["name"] = function_name or func.__name__ + + # Add parameters + param_strs = [] + for name, param in sig.parameters.items(): + param_str = str(param) + if not self.include_defaults and "=" in param_str: + param_str = param_str.split("=")[0].strip() + + param_strs.append(param_str) + + parts["params"] = ", ".join(param_strs) + + # Add return annotation if present + if sig.return_annotation is not inspect.Signature.empty: + parts["returns"] = sig.return_annotation + + return parts class FunctionInfoExtractorFactory: diff --git a/src/orcabridge/hashing/semantic_arrow_hasher.py b/src/orcabridge/hashing/semantic_arrow_hasher.py new file mode 100644 index 0000000..311c823 --- /dev/null +++ b/src/orcabridge/hashing/semantic_arrow_hasher.py @@ -0,0 +1,263 @@ +import hashlib +import os +from typing import Any, Protocol +from abc import ABC, abstractmethod +import pyarrow as pa +import pyarrow.ipc as ipc +from io import BytesIO + + +class SemanticTypeHasher(Protocol): + """Abstract base class for semantic type-specific hashers.""" + + @abstractmethod + def hash_column(self, column: pa.Array) -> bytes: + """Hash a column with this semantic type and return the hash bytes.""" + pass + + +class PathHasher(SemanticTypeHasher): + """Hasher for Path semantic type columns - hashes file contents.""" + + def __init__(self, chunk_size: int = 8192, handle_missing: str = "error"): + """ + Initialize PathHasher. 
+ + Args: + chunk_size: Size of chunks to read files in bytes + handle_missing: How to handle missing files ('error', 'skip', 'null_hash') + """ + self.chunk_size = chunk_size + self.handle_missing = handle_missing + + def _hash_file_content(self, file_path: str) -> str: + """Hash the content of a single file and return hex string.""" + import os + + try: + if not os.path.exists(file_path): + if self.handle_missing == "error": + raise FileNotFoundError(f"File not found: {file_path}") + elif self.handle_missing == "skip": + return hashlib.sha256(b"").hexdigest() + elif self.handle_missing == "null_hash": + return hashlib.sha256(b"").hexdigest() + + hasher = hashlib.sha256() + + # Read file in chunks to handle large files efficiently + with open(file_path, "rb") as f: + while chunk := f.read(self.chunk_size): + hasher.update(chunk) + + return hasher.hexdigest() + + except (IOError, OSError, PermissionError) as e: + if self.handle_missing == "error": + raise IOError(f"Cannot read file {file_path}: {e}") + else: # skip or null_hash + error_msg = f"" + return hashlib.sha256(error_msg.encode("utf-8")).hexdigest() + + def hash_column(self, column: pa.Array) -> pa.Array: + """ + Replace path column with file content hashes. + Returns a new array where each path is replaced with its file content hash. + """ + + # Convert to python list for processing + paths = column.to_pylist() + + # Hash each file's content individually + content_hashes = [] + for path in paths: + if path is not None: + # Normalize path for consistency + normalized_path = os.path.normpath(str(path)) + file_content_hash = self._hash_file_content(normalized_path) + content_hashes.append(file_content_hash) + else: + content_hashes.append(None) # Preserve nulls + + # Return new array with content hashes instead of paths + return pa.array(content_hashes) + + +class SemanticArrowHasher: + """ + Stable hasher for Arrow tables with semantic type support. + + This hasher: + 1. Processes columns with special semantic types using dedicated hashers + 2. Sorts columns by name for deterministic ordering + 3. Uses Arrow IPC format for stable serialization + 4. Computes final hash of the processed packet + """ + + def __init__(self, chunk_size: int = 8192, handle_missing: str = "error"): + """ + Initialize SemanticArrowHasher. 
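# A minimal sketch (not part of the patch): hashing a plain table with hash_table()
# defined below, and applying PathHasher directly to a path column. The file name is
# hypothetical; with handle_missing="null_hash" a missing file hashes like empty
# content instead of raising.
import pyarrow as pa

hasher = SemanticArrowHasher()
digest = hasher.hash_table(pa.table({"b": [2], "a": ["x"]}))
same = hasher.hash_table(pa.table({"a": ["x"], "b": [2]}))
assert digest == same            # column order is normalized before hashing

path_hasher = PathHasher(handle_missing="null_hash")
content_hashes = path_hasher.hash_column(pa.array(["data.csv", None]))
# -> per-row SHA-256 of file contents (nulls preserved), not of the path strings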
+ + Args: + chunk_size: Size of chunks to read files in bytes + handle_missing: How to handle missing files ('error', 'skip', 'null_hash') + """ + self.chunk_size = chunk_size + self.handle_missing = handle_missing + self.semantic_type_hashers: dict[str, SemanticTypeHasher] = {} + + def register_semantic_hasher(self, semantic_type: str, hasher: SemanticTypeHasher): + """Register a custom hasher for a semantic type.""" + self.semantic_type_hashers[semantic_type] = hasher + + def _get_semantic_type(self, field: pa.Field) -> str | None: + """Extract semantic_type from field metadata.""" + if field.metadata is None: + return None + + metadata = field.metadata + if b"semantic_type" in metadata: + return metadata[b"semantic_type"].decode("utf-8") + elif "semantic_type" in metadata: + return metadata["semantic_type"] + + return None + + def _create_hash_column( + self, original_column: pa.Array, hash_bytes: bytes, original_field: pa.Field + ) -> tuple[pa.Array, pa.Field]: + """Create a new column containing the hash bytes.""" + # Create array of hash bytes (one hash value repeated for each row) + hash_value = hash_bytes.hex() # Convert to hex string for readability + hash_array = pa.array([hash_value] * len(original_column)) + + # Create new field with modified metadata + new_metadata = dict(original_field.metadata) if original_field.metadata else {} + new_metadata["original_semantic_type"] = new_metadata.get( + "semantic_type", "unknown" + ) + new_metadata["semantic_type"] = "hash" + new_metadata["hash_algorithm"] = "sha256" + + new_field = pa.field( + original_field.name, + pa.string(), # Hash stored as string + nullable=original_field.nullable, + metadata=new_metadata, + ) + + return hash_array, new_field + + def _process_table_columns(self, table: pa.Table) -> pa.Table: + # TODO: add copy of table-level metadata to the new table + """Process table columns, replacing semantic type columns with their hashes.""" + new_columns = [] + new_fields = [] + + for i, field in enumerate(table.schema): + column = table.column(i) + semantic_type = self._get_semantic_type(field) + + if semantic_type in self.semantic_type_hashers: + # Hash the column using the appropriate semantic hasher + hasher = self.semantic_type_hashers[semantic_type] + hash_bytes = hasher.hash_column(column) + + # Replace column with hash + hash_column, hash_field = self._create_hash_column( + column, hash_bytes, field + ) + new_columns.append(hash_column) + new_fields.append(hash_field) + else: + # Keep original column + new_columns.append(column) + new_fields.append(field) + + # Create new table with processed columns + new_schema = pa.schema(new_fields) + return pa.table(new_columns, schema=new_schema) + + def _sort_table_columns(self, table: pa.Table) -> pa.Table: + """Sort table columns by field name for deterministic ordering.""" + # Get column indices sorted by field name + sorted_indices = sorted( + range(len(table.schema)), key=lambda i: table.schema.field(i).name + ) + + # Reorder columns + sorted_columns = [table.column(i) for i in sorted_indices] + sorted_fields = [table.schema.field(i) for i in sorted_indices] + + sorted_schema = pa.schema(sorted_fields) + return pa.table(sorted_columns, schema=sorted_schema) + + def _serialize_table_ipc(self, table: pa.Table) -> bytes: + """Serialize table using Arrow IPC format for stable binary representation.""" + buffer = BytesIO() + + # Use IPC stream format for deterministic serialization + with ipc.new_stream(buffer, table.schema) as writer: + writer.write_table(table) + + return 
buffer.getvalue() + + def hash_table(self, table: pa.Table, algorithm: str = "sha256") -> str: + """ + Compute stable hash of Arrow table. + + Args: + table: Arrow table to hash + algorithm: Hash algorithm to use ('sha256', 'md5', etc.) + + Returns: + Hex string of the computed hash + """ + # Step 1: Process columns with semantic types + processed_table = self._process_table_columns(table) + + # Step 2: Sort columns by name for deterministic ordering + sorted_table = self._sort_table_columns(processed_table) + + # Step 3: Serialize using Arrow IPC format + serialized_bytes = self._serialize_table_ipc(sorted_table) + + # Step 4: Compute final hash + hasher = hashlib.new(algorithm) + hasher.update(serialized_bytes) + + return hasher.hexdigest() + + def hash_table_with_metadata( + self, table: pa.Table, algorithm: str = "sha256" + ) -> dict[str, Any]: + """ + Compute hash with additional metadata about the process. + + Returns: + Dictionary containing hash, metadata, and processing info + """ + processed_columns = [] + + # Track processing steps + for i, field in enumerate(table.schema): + semantic_type = self._get_semantic_type(field) + column_info = { + "name": field.name, + "original_type": str(field.type), + "semantic_type": semantic_type, + "processed": semantic_type in self.semantic_type_hashers, + } + processed_columns.append(column_info) + + # Compute hash + table_hash = self.hash_table(table, algorithm) + + return { + "hash": table_hash, + "algorithm": algorithm, + "num_rows": len(table), + "num_columns": len(table.schema), + "processed_columns": processed_columns, + "column_order": [field.name for field in table.schema], + } From 2f548b997e6c13388defec3597cf4dfe57ce0267 Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Mon, 16 Jun 2025 09:24:52 +0000 Subject: [PATCH 19/28] refactor: update function info extractor --- src/orcabridge/hashing/types.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/orcabridge/hashing/types.py b/src/orcabridge/hashing/types.py index ddefe1f..4822433 100644 --- a/src/orcabridge/hashing/types.py +++ b/src/orcabridge/hashing/types.py @@ -5,7 +5,9 @@ from typing import Any, Protocol, runtime_checkable import uuid -from orcabridge.types import Packet, PathLike, PathSet +from orcabridge.types import Packet, PathLike, PathSet, TypeSpec + +import pyarrow as pa @runtime_checkable @@ -97,11 +99,17 @@ class SemanticHasher(Protocol): @runtime_checkable class PacketHasher(Protocol): - """Protocol for hashing packets (collections of pathsets).""" + """Protocol for hashing packets.""" def hash_packet(self, packet: Packet) -> str: ... +class ArrowPacketHasher: + """Protocol for hashing arrow packets.""" + + def hash_arrow_packet(self, packet: pa.Table) -> str: ... + + @runtime_checkable class StringCacher(Protocol): """Protocol for caching string key value pairs.""" @@ -124,4 +132,10 @@ class CompositeFileHasher(FileHasher, PathSetHasher, PacketHasher, Protocol): class FunctionInfoExtractor(Protocol): """Protocol for extracting function information.""" - def extract_function_info(self, func: Callable[..., Any]) -> dict[str, Any]: ... + def extract_function_info( + self, + func: Callable[..., Any], + function_name: str | None = None, + input_types: TypeSpec | None = None, + output_types: TypeSpec | None = None, + ) -> dict[str, Any]: ... From 65bc3fc03857753d3fe7236067277af24adaf68d Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Mon, 16 Jun 2025 09:25:18 +0000 Subject: [PATCH 20/28] refactor: update identity structure for function pod --- src/orcabridge/pod/core.py | 329 ++++++++++++++++++++++++++----------- 1 file changed, 236 insertions(+), 93 deletions(-) diff --git a/src/orcabridge/pod/core.py b/src/orcabridge/pod/core.py index ea9e92d..0ee8f39 100644 --- a/src/orcabridge/pod/core.py +++ b/src/orcabridge/pod/core.py @@ -3,6 +3,7 @@ import pickle import warnings from abc import abstractmethod +import pyarrow as pa import sys from collections.abc import Callable, Collection, Iterable, Iterator, Sequence from typing import ( @@ -14,12 +15,14 @@ from orcabridge.base import Operation from orcabridge.hashing import ( ObjectHasher, + ArrowPacketHasher, + FunctionInfoExtractor, get_function_signature, hash_function, get_default_object_hasher, ) from orcabridge.mappers import Join -from orcabridge.store import DataStore, NoOpDataStore +from orcabridge.store import DataStore, ArrowDataStore, NoOpDataStore from orcabridge.streams import SyncStream, SyncStreamFromGenerator from orcabridge.types import Packet, PathSet, PodFunction, Tag from orcabridge.types.default import default_registry @@ -105,6 +108,12 @@ class Pod(Operation): the pods act as pure functions which is a necessary condition to guarantee reproducibility. """ + def __init__( + self, error_handling: Literal["raise", "ignore", "warn"] = "raise", **kwargs + ): + super().__init__(**kwargs) + self.error_handling = error_handling + def process_stream(self, *streams: SyncStream) -> list[SyncStream]: """ Prepare the incoming streams for execution in the pod. This default implementation @@ -124,6 +133,37 @@ def __call__(self, *streams: SyncStream, **kwargs) -> SyncStream: stream = self.process_stream(*streams) return super().__call__(*stream, **kwargs) + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet]: ... 
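# A hypothetical sketch (not part of the patch) of the call/forward split introduced
# here: subclasses implement call() for a single (tag, packet) pair and inherit
# streaming plus error handling from forward() below. Operation may require further
# methods (e.g. keys()), omitted in this sketch.
class UppercasePod(Pod):
    def call(self, tag, packet):
        # pure, per-packet transformation; no stream handling or I/O here
        return tag, {key: str(value).upper() for key, value in packet.items()}

# pod = UppercasePod(error_handling="warn")   # bad packets are warned about and skipped
# for tag, packet in pod(stream): ...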
+ + def forward(self, *streams: SyncStream) -> SyncStream: + # if multiple streams are provided, join them + if len(streams) > 1: + raise ValueError("Multiple streams should be joined before calling forward") + if len(streams) == 0: + raise ValueError("No streams provided to forward") + stream = streams[0] + + def generator() -> Iterator[tuple[Tag, Packet]]: + n_computed = 0 + for tag, packet in stream: + try: + tag, output_packet = self.call(tag, packet) + n_computed += 1 + logger.info(f"Computed item {n_computed}") + yield tag, output_packet + + except Exception as e: + logger.error(f"Error processing packet {packet}: {e}") + if self.error_handling == "raise": + raise e + elif self.error_handling == "ignore": + continue + elif self.error_handling == "warn": + warnings.warn(f"Error processing packet {packet}: {e}") + continue + + return SyncStreamFromGenerator(generator) + class FunctionPod(Pod): """ @@ -495,23 +535,23 @@ class TypedFunctionPod(Pod): def __init__( self, - function: PodFunction, - output_keys: Collection[str] | None = None, + function: Callable[..., Any], + output_keys: str | Collection[str] | None = None, function_name=None, input_types: TypeSpec | None = None, output_types: TypeSpec | Sequence[type] | None = None, - data_store: DataStore | None = None, - function_hasher: ObjectHasher | None = None, label: str | None = None, - skip_memoization_lookup: bool = False, - skip_memoization: bool = False, - error_handling: Literal["raise", "ignore", "warn"] = "raise", packet_type_registry=None, + function_info_extractor: FunctionInfoExtractor | None = None, **kwargs, ) -> None: super().__init__(label=label, **kwargs) self.function = function - self.output_keys = output_keys or [] + if output_keys is None: + output_keys = [] + if isinstance(output_keys, str): + output_keys = [output_keys] + self.output_keys = output_keys if function_name is None: if hasattr(self.function, "__name__"): function_name = getattr(self.function, "__name__") @@ -519,19 +559,13 @@ def __init__( raise ValueError( "function_name must be provided if function has no __name__ attribute" ) - self.function_name = function_name - self.data_store = data_store if data_store is not None else NoOpDataStore() - if function_hasher is None: - function_hasher = get_default_object_hasher() - self.function_hasher = function_hasher - self.skip_memoization_lookup = skip_memoization_lookup - self.skip_memoization = skip_memoization - self.error_handling = error_handling + if packet_type_registry is None: packet_type_registry = default_registry self.registry = packet_type_registry + self.function_info_extractor = function_info_extractor # extract input and output types from the function signature function_input_types, function_output_types = extract_function_data_types( @@ -562,19 +596,154 @@ def keys( tag_keys, _ = stream[0].keys() return tag_keys, tuple(self.output_keys) + def call(self, tag, packet) -> tuple[Tag, Packet]: + output_values: list["PathSet"] = [] + + values = self.function(**packet) + + if len(self.output_keys) == 0: + output_values = [] + elif len(self.output_keys) == 1: + output_values = [values] # type: ignore + elif isinstance(values, Iterable): + output_values = list(values) # type: ignore + elif len(self.output_keys) > 1: + raise ValueError( + "Values returned by function must be a pathlike or a sequence of pathlikes" + ) + + if len(output_values) != len(self.output_keys): + raise ValueError( + f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by 
function {len(output_values)}" + ) + + output_packet: Packet = {k: v for k, v in zip(self.output_keys, output_values)} + return tag, output_packet + + def identity_structure(self, *streams) -> Any: + # construct identity structure for the function + # if function_info_extractor is available, use that but substitute the function_name + if self.function_info_extractor is not None: + function_info = self.function_info_extractor.extract_function_info( + self.function, + function_name=self.function_name, + input_types=self.function_input_types, + output_types=self.function_output_types, + ) + else: + # use basic information only + function_info = { + "name": self.function_name, + "input_types": self.function_input_types, + "output_types": self.function_output_types, + } + function_info["output_keys"] = tuple(self.output_keys) + + return ( + self.__class__.__name__, + function_info, + ) + tuple(streams) + + +class CachedFunctionPod(Pod): + def __init__( + self, + function_pod: TypedFunctionPod, + object_hasher: ObjectHasher, + packet_hasher: ArrowPacketHasher, + result_store: ArrowDataStore, + tag_store: ArrowDataStore | None = None, + label: str | None = None, + skip_memoization_lookup: bool = False, + skip_memoization: bool = False, + skip_tag_record: bool = False, + error_handling: Literal["raise", "ignore", "warn"] = "raise", + **kwargs, + ) -> None: + super().__init__(label=label, error_handling=error_handling, **kwargs) + self.function_pod = function_pod + + self.object_hasher = object_hasher + self.packet_hasher = packet_hasher + self.result_store = result_store + self.tag_store = tag_store + + self.skip_memoization_lookup = skip_memoization_lookup + self.skip_memoization = skip_memoization + self.skip_tag_record = skip_tag_record + + # TODO: consider making this dynamic + self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) + + def get_packet_key(self, packet: Packet) -> str: + return self.packet_hasher.hash_arrow_packet( + self.function_pod.input_converter.to_arrow_table(packet) + ) + + # TODO: prepare a separate str and repr methods + def __repr__(self) -> str: + return f"Cached:{self.function_pod}" + + def keys( + self, *streams: SyncStream + ) -> tuple[Collection[str] | None, Collection[str] | None]: + return self.function_pod.keys(*streams) + def is_memoized(self, packet: Packet) -> bool: return self.retrieve_memoized(packet) is not None + def add_tag_record(self, tag: Tag, packet: Packet) -> Tag: + """ + Record the tag for the packet in the record store. + This is used to keep track of the tags associated with memoized packets. + """ + + return self._add_tag_record_with_packet_key(tag, self.get_packet_key(packet)) + + def _add_tag_record_with_packet_key(self, tag: Tag, packet_key: str) -> Tag: + if self.tag_store is None: + raise ValueError("Recording of tag requires tag_store but none provided") + + tag = tag.copy() # ensure we don't modify the original tag + tag["__packet_key"] = packet_key + + # convert tag to arrow table + table = pa.Table.from_pylist([tag]) + + entry_hash = self.packet_hasher.hash_arrow_packet(table) + + # TODO: add error handling + self.tag_store.add_record( + self.function_pod.function_name, self.function_pod_hash, entry_hash, table + ) + + return tag + def retrieve_memoized(self, packet: Packet) -> Packet | None: """ Retrieve a memoized packet from the data store. Returns None if no memoized packet is found. 
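# A hypothetical wiring sketch (not part of the patch): caching a TypedFunctionPod
# with an ArrowDataStore and the default object hasher. The ParquetArrowDataStore
# import path and `arrow_packet_hasher` (any ArrowPacketHasher implementation) are
# assumptions, not part of this module.
from pathlib import Path
from orcabridge.store.arrow_data_stores import ParquetArrowDataStore  # assumed path

def count_lines(text_file: Path) -> int:
    with open(text_file) as f:
        return sum(1 for _ in f)

pod = TypedFunctionPod(count_lines, output_keys="line_count")
cached = CachedFunctionPod(
    function_pod=pod,
    object_hasher=get_default_object_hasher(),
    packet_hasher=arrow_packet_hasher,          # placeholder ArrowPacketHasher
    result_store=ParquetArrowDataStore("./results"),
    skip_tag_record=True,                       # no tag_store configured here
)
# Repeated input packets are looked up by hash and skip recomputation.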
""" - return self.data_store.retrieve_memoized( - self.function_name, - self.content_hash(char_count=16), - self.input_converter.to_arrow_table(packet), + return self._retrieve_memoized_by_hash(self.get_packet_key(packet)) + + def _retrieve_memoized_by_hash(self, packet_hash: str) -> Packet | None: + """ + Retrieve a memoized result packet from the data store, looking up by hash + Returns None if no memoized packet is found. + """ + arrow_table = self.result_store.get_record( + self.function_pod.function_name, + self.function_pod_hash, + packet_hash, ) + if arrow_table is None: + return None + packets = self.function_pod.output_converter.from_arrow_table(arrow_table) + # since memoizing single packet, it should only contain one packet + assert len(packets) == 1, ( + f"Memoizing single packet return {len(packets)} packets!" + ) + return packets[0] def memoize( self, @@ -585,81 +754,55 @@ def memoize( Memoize the output packet in the data store. Returns the memoized packet. """ - return self.data_store.memoize( - self.function_name, - self.content_hash(char_count=16), # identity of this function pod - packet, - output_packet, - ) + return self._memoize_by_hash(self.get_packet_key(packet), output_packet) - def forward(self, *streams: SyncStream) -> SyncStream: - # if multiple streams are provided, join them - if len(streams) > 1: - raise ValueError("Multiple streams should be joined before calling forward") - if len(streams) == 0: - raise ValueError("No streams provided to forward") - stream = streams[0] - - def generator() -> Iterator[tuple[Tag, Packet]]: - n_computed = 0 - for tag, packet in stream: - output_values: list["PathSet"] = [] - try: - if not self.skip_memoization_lookup: - memoized_packet = self.retrieve_memoized(packet) - else: - memoized_packet = None - if memoized_packet is not None: - logger.info("Memoized packet found, skipping computation") - yield tag, memoized_packet - continue - values = self.function(**packet) - - if len(self.output_keys) == 0: - output_values = [] - elif len(self.output_keys) == 1: - output_values = [values] # type: ignore - elif isinstance(values, Iterable): - output_values = list(values) # type: ignore - elif len(self.output_keys) > 1: - raise ValueError( - "Values returned by function must be a pathlike or a sequence of pathlikes" - ) - - if len(output_values) != len(self.output_keys): - raise ValueError( - f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" - ) - except Exception as e: - logger.error(f"Error processing packet {packet}: {e}") - if self.error_handling == "raise": - raise e - elif self.error_handling == "ignore": - continue - elif self.error_handling == "warn": - warnings.warn(f"Error processing packet {packet}: {e}") - continue - - output_packet: Packet = { - k: v for k, v in zip(self.output_keys, output_values) - } + def _memoize_by_hash(self, packet_hash: str, output_packet: Packet) -> Packet: + """ + Memoize the output packet in the data store, looking up by hash. + Returns the memoized packet. + """ + packets = self.function_pod.output_converter.from_arrow_table( + self.result_store.add_record( + self.function_pod.function_name, + self.function_pod_hash, + packet_hash, + self.function_pod.output_converter.to_arrow_table(output_packet), + ) + ) + # since memoizing single packet, it should only contain one packet + assert len(packets) == 1, ( + f"Memoizing single packet return {len(packets)} packets!" 
+ ) + return packets[0] + + def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet]: + packet_key = "" + if ( + not self.skip_tag_record + or not self.skip_memoization_lookup + or not self.skip_memoization + ): + packet_key = self.get_packet_key(packet) + + if not self.skip_tag_record and self.tag_store is not None: + self._add_tag_record_with_packet_key(tag, packet_key) + + if not self.skip_memoization_lookup: + memoized_packet = self._retrieve_memoized_by_hash(packet_key) + else: + memoized_packet = None + if memoized_packet is not None: + logger.info("Memoized packet found, skipping computation") + return tag, memoized_packet - if not self.skip_memoization: - # output packet may be modified by the memoization process - # e.g. if the output is a file, the path may be changed - output_packet = self.memoize(packet, output_packet) # type: ignore + tag, output_packet = self.function_pod.call(tag, packet) - n_computed += 1 - logger.info(f"Computed item {n_computed}") - yield tag, output_packet + if not self.skip_memoization: + # output packet may be modified by the memoization process + # e.g. if the output is a file, the path may be changed + output_packet = self.memoize(packet, output_packet) # type: ignore - return SyncStreamFromGenerator(generator) + return tag, output_packet def identity_structure(self, *streams) -> Any: - function_hash_value = self.function_hasher.hash_to_hex(self.function) - - return ( - self.__class__.__name__, - function_hash_value, - tuple(self.output_keys), - ) + tuple(streams) + return self.function_pod.identity_structure(*streams) From 241c7706eb3c8a20ee65012632c8d4ce653875aa Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 18 Jun 2025 00:12:38 +0000 Subject: [PATCH 21/28] refactor: move core modules into subpackage core --- src/orcabridge/__init__.py | 14 +- src/orcabridge/core/__init__.py | 0 src/orcabridge/{ => core}/base.py | 379 +++++++++----- .../{mappers.py => core/operators.py} | 466 +++++++++++------- src/orcabridge/{ => core}/sources.py | 42 +- src/orcabridge/{ => core}/streams.py | 20 +- src/orcabridge/{ => core}/tracker.py | 6 +- src/orcabridge/pod/core.py | 36 +- src/orcabridge/{ => store}/file.py | 0 src/orcabridge/utils/stream_utils.py | 57 ++- 10 files changed, 651 insertions(+), 369 deletions(-) create mode 100644 src/orcabridge/core/__init__.py rename src/orcabridge/{ => core}/base.py (52%) rename src/orcabridge/{mappers.py => core/operators.py} (77%) rename src/orcabridge/{ => core}/sources.py (95%) rename src/orcabridge/{ => core}/streams.py (84%) rename src/orcabridge/{ => core}/tracker.py (93%) rename src/orcabridge/{ => store}/file.py (100%) diff --git a/src/orcabridge/__init__.py b/src/orcabridge/__init__.py index 6da00a9..12ba536 100644 --- a/src/orcabridge/__init__.py +++ b/src/orcabridge/__init__.py @@ -1,9 +1,11 @@ -from . import hashing, mappers, pod, sources, store, streams -from .mappers import Join, MapPackets, MapTags, packet, tag +from .core import operators, sources, streams +from .core.streams import SyncStreamFromLists, SyncStreamFromGenerator +from . 
import hashing, pod, store +from .core.operators import Join, MapPackets, MapTags, packet, tag from .pod import FunctionPod, function_pod -from .sources import GlobSource +from .core.sources import GlobSource from .store import DirDataStore, SafeDirDataStore -from .pipeline import GraphTracker +from .pipeline.pipeline import GraphTracker DEFAULT_TRACKER = GraphTracker() DEFAULT_TRACKER.activate() @@ -13,8 +15,7 @@ "hashing", "store", "pod", - "dir_data_store", - "mappers", + "operators", "streams", "sources", "MapTags", @@ -29,4 +30,5 @@ "SafeDirDataStore", "DEFAULT_TRACKER", "SyncStreamFromLists", + "SyncStreamFromGenerator", ] diff --git a/src/orcabridge/core/__init__.py b/src/orcabridge/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/orcabridge/base.py b/src/orcabridge/core/base.py similarity index 52% rename from src/orcabridge/base.py rename to src/orcabridge/core/base.py index 6a73048..7c025e6 100644 --- a/src/orcabridge/base.py +++ b/src/orcabridge/core/base.py @@ -2,42 +2,39 @@ import threading from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterator -from typing import Any +from typing import Any, TypeVar, Hashable + from orcabridge.hashing import HashableMixin -from orcabridge.types import Packet, Tag +from orcabridge.types import Packet, Tag, TypeSpec +from orcabridge.utils.stream_utils import get_typespec + +import logging + +logger = logging.getLogger(__name__) -class Operation(ABC, HashableMixin): + +class Kernel(ABC, HashableMixin): """ - Operation defines a generic operation that can be performed on a stream of data. - It is a base class for all operations that can be performed on a collection of streams + Kernel defines the fundamental unit of computation that can be performed on zero, one or more streams of data. + It is the base class for all computations and transformations that can be performed on a collection of streams (including an empty collection). - The operation is defined as a callable that takes a collection of streams as input - and returns a new stream as output. - Each invocation of the operation is assigned a unique ID. The corresponding invocation - information is stored as Invocation object and attached to the output stream. + A kernel is defined as a callable that takes a (possibly empty) collection of streams as the input + and returns a new stream as output (note that output stream is always singular). + Each "invocation" of the kernel on a collection of streams is assigned a unique ID. + The corresponding invocation information is stored as Invocation object and attached to the output stream + for computational graph tracking. """ def __init__(self, label: str | None = None, **kwargs) -> None: super().__init__(**kwargs) self._label = label - def keys( - self, *streams: "SyncStream" - ) -> tuple[Collection[str] | None, Collection[str] | None]: - """ - Returns the keys of the operation. - The first list contains the keys of the tags, and the second list contains the keys of the packets. - The keys are returned if it is feasible to do so, otherwise a tuple - (None, None) is returned to signify that the keys are not known. - """ - return None, None - @property def label(self) -> str: """ - Returns a human-readable label for this operation. + Returns a human-readable label for this kernel. Default implementation returns the provided label or class name if no label was provided. 
""" if self._label: @@ -45,35 +42,36 @@ def label(self) -> str: return self.__class__.__name__ @label.setter - def label(self, value: str) -> None: - self._label = value - - def identity_structure(self, *streams: "SyncStream") -> Any: - # Default implementation of identity_structure for the operation only - # concerns the operation class and the streams if present. Subclasses of - # Operations should override this method to provide a more meaningful - # representation of the operation. - return (self.__class__.__name__,) + tuple(streams) + def label(self, label: str) -> None: + self._label = label def __call__(self, *streams: "SyncStream", **kwargs) -> "SyncStream": - # trigger call on source if passed as stream - + # Special handling of Source: trigger call on source if passed as stream normalized_streams = [ stream() if isinstance(stream, Source) else stream for stream in streams ] + output_stream = self.forward(*normalized_streams, **kwargs) # create an invocation instance invocation = Invocation(self, normalized_streams) - # label the output_stream with the invocation information + # label the output_stream with the invocation that produced the stream output_stream.invocation = invocation - # register the invocation with active trackers + # register the invocation to all active trackers active_trackers = Tracker.get_active_trackers() for tracker in active_trackers: tracker.record(invocation) return output_stream + @abstractmethod + def forward(self, *streams: "SyncStream") -> "SyncStream": + """ + Trigger the main computation of the kernel on a collection of streams. + This method is called when the kernel is invoked with a collection of streams. + Subclasses should override this method to provide the kernel with its unique behavior + """ + def __repr__(self): return self.__class__.__name__ @@ -82,27 +80,86 @@ def __str__(self): return f"{self.__class__.__name__}({self._label})" return self.__class__.__name__ + def identity_structure(self, *streams: "SyncStream") -> Any: + # Default implementation of identity_structure for the kernel only + # concerns the kernel class and the streams if present. Subclasses of + # Kernels should override this method to provide a more meaningful + # representation of the kernel. Note that kernel must provide the notion + # of identity under possibly two distinct contexts: + # 1) identity of the kernel in itself when invoked without any stream + # 2) identity of the specific invocation of the kernel with a collection of streams + # While the latter technically corresponds to the identity of the invocation and not + # the kernel, only kernel can provide meaningful information as to the uniqueness of + # the invocation as only kernel would know if / how the input stream(s) alter the identity + # of the invocation. For example, if the kernel corresponds to an commutative computation + # and therefore kernel K(x, y) == K(y, x), then the identity structure must reflect the + # equivalence of the two by returning the same identity structure for both invocations. + # This can be achieved, for example, by returning a set over the streams instead of a tuple. + logger.warning( + f"Identity structure not implemented for {self.__class__.__name__}" + ) + return (self.__class__.__name__,) + tuple(streams) + + def keys( + self, *streams: "SyncStream", trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the keys of the kernel output. 
+ The first list contains the keys of the tags, and the second list contains the keys of the packets. + If trigger_run is False (default), the keys are returned only if it is feasible to do so without triggering + the chain of computations. If trigger_run is True, underlying computation may get triggered if doing so + would allow for the keys to be determined. Returns None for either part of the keys cannot be inferred. + + This should be overridden by the subclass if subclass can provide smarter inference based on the specific + implementation of the subclass and input streams. + """ + if not trigger_run: + return None, None + + # resolve to actually executing the stream to fetch the first element + tag, packet = next(iter(self(*streams))) + return tuple(tag.keys()), tuple(packet.keys()) + + def types( + self, *streams: "SyncStream", trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + Returns the tag and packet typespec of the kernel output. + Each typespec consists of mapping from field name to Python type. + If trigger_run is False (default), the typespec info is returned only if it is feasible to do so without triggering + the chain of computation. If trigger_run is True, underlying computation may get triggered if doing so + would allow for the typespec to be determined. Returns None for either part of the typespec cannot be inferred. + """ + if not trigger_run: + return None, None + + tag, packet = next(iter(self(*streams))) + return get_typespec(tag), get_typespec(packet) + def claims_unique_tags( - self, *streams: "SyncStream", trigger_run: bool = True - ) -> bool: + self, *streams: "SyncStream", trigger_run: bool = False + ) -> bool | None: """ - Returns True if the operation claims that it has unique tags, False otherwise. - This method is useful for checking if the operation can be used as a source - for other operations that require unique tags. + Returns True if the kernel claims that it has unique tags, False otherwise. + False indicates that it can be inferred that the kernel does not have unique tags + based on the input streams and the kernel's implementation. None indicates that + whether it is unique or not cannot be determined with certainty. + If trigger_run is True, the kernel may trigger the computation to verify + the uniqueness of tags. If trigger_run is False, the kernel will return + None if it cannot determine the uniqueness of tags without triggering the computation. + This method is useful for checking if the kernel can be used as a source + for other kernels that require unique tags. Subclasses should override this method if it can provide reasonable check/guarantee - of unique tags. The default implementation returns False, meaning that the operation - does not claim to have unique tags. + of unique tags. The default implementation returns False, meaning that the kernel + does not claim to have unique tags, even if turns out to be unique. """ - return False - - @abstractmethod - def forward(self, *streams: "SyncStream") -> "SyncStream": ... + return None class Tracker(ABC): """ - A tracker is a class that can track the invocations of operations. Only "active" trackers - participate in tracking and its `record` method gets called on each invocation of an operation. + A tracker is a class that can track the invocations of kernels. Only "active" trackers + participate in tracking and its `record` method gets called on each invocation of a kernel. Multiple trackers can be active at any time. 
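For illustration, a tracker that merely logs invocations could be as small as the sketch below (hypothetical subclass; record is the only method shown as abstract in this hunk, and activate() is assumed to be inherited from the Tracker base class, as its use on DEFAULT_TRACKER suggests):

    import logging

    from orcabridge.core.base import Invocation, Tracker

    logger = logging.getLogger(__name__)

    class LoggingTracker(Tracker):  # hypothetical tracker, for illustration only
        """Log every kernel invocation instead of building a graph."""

        def record(self, invocation: Invocation) -> None:
            logger.info("recorded %s", invocation)

    # once activated, the tracker receives every subsequent kernel invocation
    # tracker = LoggingTracker()
    # tracker.activate()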
""" @@ -146,37 +203,32 @@ def record(self, invocation: "Invocation") -> None: ... # This is NOT an abstract class, but rather a concrete class that -# represents an invocation of an operation on a collection of streams. +# represents an invocation of a kernel on a collection of streams. class Invocation(HashableMixin): """ - This class represents an invocation of an operation on a collection of streams. - It contains the operation and the streams that were used in the invocation. + This class represents an invocation of a kernel on a collection of streams. + It contains the kernel and the streams that were used in the invocation. Note that the collection of streams may be empty, in which case the invocation - likely corresponds to a source operation. + likely corresponds to a source kernel. """ def __init__( self, - operation: Operation, - # TODO: technically this should be Stream to stay consistent with Stream interface + kernel: Kernel, + # TODO: technically this should be Stream to stay consistent with Stream interface. Update to Stream when AsyncStream is implemented streams: Collection["SyncStream"], ) -> None: - self.operation = operation + self.kernel = kernel self.streams = streams def __hash__(self) -> int: return super().__hash__() def __repr__(self) -> str: - return f"Invocation({self.operation}, ID:{hash(self)})" - - def keys(self) -> tuple[Collection[str] | None, Collection[str] | None]: - return self.operation.keys(*self.streams) + return f"Invocation(kernel={self.kernel}, streams={self.streams})" - def identity_structure(self) -> int: - # Identity of an invocation is entirely dependend on - # the operation's identity structure upon invocation - return self.operation.identity_structure(*self.streams) + def __str__(self) -> str: + return f"Invocation[ID:{self.__hash__()}]({self.kernel}, {self.streams})" def __eq__(self, other: Any) -> bool: if not isinstance(other, Invocation): @@ -187,22 +239,36 @@ def __lt__(self, other: Any) -> bool: if not isinstance(other, Invocation): return NotImplemented - if self.operation == other.operation: + if self.kernel == other.kernel: return hash(self) < hash(other) - # otherwise, order by the operation - return hash(self.operation) < hash(other.operation) + # otherwise, order by the kernel + return hash(self.kernel) < hash(other.kernel) + + # Pass-through implementations: these methods are implemented by "passing-through" the methods logic, + # simply invoking the corresopnding methods on the underlying kernel with the input streams - def claims_unique_tags(self, trigger_run: bool = True) -> bool: + def claims_unique_tags(self, trigger_run: bool = True) -> bool | None: """ Returns True if the invocation has unique tags, False otherwise. This method is useful for checking if the invocation can be used as a source - for other operations that require unique tags. None is returned if the + for other kernels that require unique tags. None is returned if the uniqueness of tags cannot be determined. - Note that uniqueness is best thought of as a "claim" by the operation + Note that uniqueness is best thought of as a "claim" by the kernel that it has unique tags. The actual uniqueness can only be verified by iterating over the streams and checking the tags. 
""" - return self.operation.claims_unique_tags(*self.streams, trigger_run=trigger_run) + return self.kernel.claims_unique_tags(*self.streams, trigger_run=trigger_run) + + def keys(self) -> tuple[Collection[str] | None, Collection[str] | None]: + return self.kernel.keys(*self.streams) + + def types(self) -> tuple[TypeSpec | None, TypeSpec | None]: + return self.kernel.types(*self.streams) + + def identity_structure(self) -> int: + # Identity of an invocation is entirely determined by the + # the kernel's identity structure upon invocation + return self.kernel.identity_structure(*self.streams) class Stream(ABC, HashableMixin): @@ -210,8 +276,8 @@ class Stream(ABC, HashableMixin): A stream is a collection of tagged-packets that are generated by an operation. The stream is iterable and can be used to access the packets in the stream. - A stream has propery `invocation` that is an instance of Invocation that generated the stream. - This may be None if the stream is not generated by an operation. + A stream has property `invocation` that is an instance of Invocation that generated the stream. + This may be None if the stream is not generated by a kernel (i.e. directly instantiated by a user). """ def __init__(self, label: str | None = None, **kwargs) -> None: @@ -219,16 +285,6 @@ def __init__(self, label: str | None = None, **kwargs) -> None: self._invocation: Invocation | None = None self._label = label - def identity_structure(self) -> Any: - """ - Identity structure of a stream is deferred to the identity structure - of the associated invocation, if present. - A bare stream without invocation has no well-defined identity structure. - """ - if self.invocation is not None: - return self.invocation.identity_structure() - return super().identity_structure() - @property def label(self) -> str: """ @@ -240,11 +296,20 @@ def label(self) -> str: if self._label is None: if self.invocation is not None: # use the invocation operation label - return self.invocation.operation.label + return self.invocation.kernel.label else: return self.__class__.__name__ return self._label + @label.setter + def label(self, label: str) -> None: + """ + Sets a human-readable label for this stream. + """ + if not isinstance(label, str): + raise TypeError("label must be a string") + self._label = label + @property def invocation(self) -> Invocation | None: return self._invocation @@ -255,7 +320,37 @@ def invocation(self, value: Invocation) -> None: raise TypeError("invocation field must be an instance of Invocation") self._invocation = value - def keys(self) -> tuple[Collection[str] | None, Collection[str] | None]: + @abstractmethod + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + raise NotImplementedError("Subclasses must implement __iter__ method") + + def flow(self) -> Collection[tuple[Tag, Packet]]: + """ + Flow everything through the stream, returning the entire collection of + (Tag, Packet) as a collection. This will tigger any upstream computation of the stream. + """ + return list(self) + + # --------------------- Recursive methods --------------------------- + # These methods form a step in the multi-class recursive invocation that follows the pattern of + # Stream -> Invocation -> Kernel -> Stream ... 
-> Invocation -> Kernel + # Most of the method logic would be found in Kernel's implementation of the method with + # Stream and Invocation simply serving as recursive steps + + def identity_structure(self) -> Any: + """ + Identity structure of a stream is deferred to the identity structure + of the associated invocation, if present. + A bare stream without invocation has no well-defined identity structure. + Specialized stream subclasses should override this method to provide more meaningful identity structure + """ + if self.invocation is not None: + return self.invocation.identity_structure() + return super().identity_structure() + + def keys( + self, *, trigger_run=False + ) -> tuple[Collection[str] | None, Collection[str] | None]: """ Returns the keys of the stream. The first list contains the keys of the tags, and the second list contains the keys of the packets. @@ -275,7 +370,28 @@ def keys(self) -> tuple[Collection[str] | None, Collection[str] | None]: tag, packet = next(iter(self)) return list(tag.keys()), list(packet.keys()) - def claims_unique_tags(self) -> bool: + def types(self, *, trigger_run=False) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + Returns the keys of the stream. + The first list contains the keys of the tags, and the second list contains the keys of the packets. + The keys are returned on based-effort basis, and this invocation may trigger the + upstream computation of the stream. + Furthermore, the keys are not guaranteed to be identical across all packets in the stream. + This method is useful for inferring the keys of the stream without having to iterate + over the entire stream. + """ + tag_types, packet_types = None, None + if self.invocation is not None: + # if the stream is generated by an operation, use the keys from the invocation + tag_types, packet_types = self.invocation.types() + if not trigger_run or (tag_types is not None and packet_types is not None): + return tag_types, packet_types + # otherwise, use the keys from the first packet in the stream + # note that this may be computationally expensive + tag, packet = next(iter(self)) + return tag_types or get_typespec(tag), packet_types or get_typespec(packet) + + def claims_unique_tags(self, *, trigger_run=False) -> bool | None: """ Returns True if the stream has unique tags, False otherwise. This method is useful for checking if the stream can be used as a source @@ -285,19 +401,8 @@ def claims_unique_tags(self) -> bool: the information about unique tags. """ if self.invocation is not None: - return self.invocation.claims_unique_tags() - return False - - @abstractmethod - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - raise NotImplementedError("Subclasses must implement __iter__ method") - - def flow(self) -> Collection[tuple[Tag, Packet]]: - """ - Flow everything through the stream, returning the entire collection of - (Tag, Packet) as a collection. This will tigger any upstream computation of the stream. - """ - return list(self) + return self.invocation.claims_unique_tags(trigger_run=trigger_run) + return None class SyncStream(Stream): @@ -307,29 +412,6 @@ class SyncStream(Stream): will have to wait for the stream to finish before proceeding. """ - def claims_unique_tags(self, *, trigger_run=True) -> bool: - """ - For synchronous streams, if the stream is generated by an operation, the invocation - is consulted first to see if the uniqueness of tags can be determined without iterating over the stream. 
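The keys() and types() accessors above return field names and field-name-to-type mappings (TypeSpec) for tags and packets respectively; a purely hypothetical illustration of their shapes:

    from pathlib import Path

    # hypothetical stream whose tags carry a session id and whose packets carry
    # a recording path and a trial count
    tag_keys, packet_keys = stream.keys(trigger_run=True)
    # e.g. (["session"], ["recording", "n_trials"])
    tag_types, packet_types = stream.types(trigger_run=True)
    # e.g. ({"session": str}, {"recording": Path, "n_trials": int})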
- If uniqueness cannot be determined from the invocation and if trigger_run is True, uniqueness is checked - by iterating over all elements and verifying uniqueness. - Consequently, this may trigger upstream computations and can be expensive. - If trigger_run is False, the method will return None if the uniqueness cannot be determined. - Since this consults the invocation, the resulting value is ultimately a claim and not a guarantee - of uniqueness. If guarantee of uniquess is required, then use has_unique_tags method - """ - result = super().claims_unique_tags() - if result is not None or not trigger_run: - return result - - # If the uniqueness cannot be determined from the invocation, iterate over the stream - unique_tags = set() - for idx, (tag, packet) in enumerate(self): - if tag in unique_tags: - return False - unique_tags.add(tag) - return True - def head(self, n: int = 5) -> None: """ Print the first n elements of the stream. @@ -359,7 +441,7 @@ def __rshift__( are returned in a new stream. """ # TODO: remove just in time import - from .mappers import MapPackets + from .operators import MapPackets if isinstance(transformer, dict): return MapPackets(transformer)(self) @@ -376,29 +458,52 @@ def __mul__(self, other: "SyncStream") -> "SyncStream": Returns a new stream that is the result joining with the other stream """ # TODO: remove just in time import - from .mappers import Join + from .operators import Join if not isinstance(other, SyncStream): raise TypeError("other must be a SyncStream") return Join()(self, other) + def claims_unique_tags(self, *, trigger_run=False) -> bool | None: + """ + For synchronous streams, if the stream is generated by an operation, the invocation + is consulted first to see if the uniqueness of tags can be determined without iterating over the stream. + If uniqueness cannot be determined from the invocation and if trigger_run is True, uniqueness is checked + by iterating over all elements and verifying uniqueness. + Consequently, this may trigger upstream computations and can be expensive. + If trigger_run is False, the method will return None if the uniqueness cannot be determined. + Since this consults the invocation, the resulting value is ultimately a claim and not a guarantee + of uniqueness. If guarantee of uniquess is required, then use has_unique_tags method + """ + result = super().claims_unique_tags(trigger_run=trigger_run) + if not trigger_run or result is not None: + return result -class Mapper(Operation): + # If the uniqueness cannot be determined from the invocation, iterate over the stream + unique_tags = set() + for tag, _ in self: + if tag in unique_tags: + return False + unique_tags.add(tag) + return True + + +class Operator(Kernel): """ A Mapper is an operation that does NOT generate new file content. - It is used to control the flow of data in the pipeline without modifying or creating new data (file). + It is used to control the flow of data in the pipeline without modifying or creating data content. """ -class Source(Operation, SyncStream): +class Source(Kernel, SyncStream): """ A base class for all sources in the system. A source can be seen as a special - type of Operation that takes no input and produces a stream of packets. - For convenience, the source itself is also a stream and thus can be used - as an input to other operations directly. - However, note that Source is still best thought of as an Operation that + type of kernel that takes no input and produces a stream of packets. 
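The __rshift__ and __mul__ overloads above give SyncStream a small composition shorthand; with hypothetical stream names:

    renamed = stream >> {"raw_path": "data_file"}  # same as MapPackets({"raw_path": "data_file"})(stream)
    joined = stream_a * stream_b                   # same as Join()(stream_a, stream_b)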
+ For convenience, the source itself can act as a stream and thus can be used + as an input to other kernels directly. + However, note that a source is still best thought of as a kernel that produces a stream of packets, rather than a stream itself. On almost all occasions, - Source acts as an Operation. + a source acts as a kernel. """ def __init__(self, label: str | None = None, **kwargs) -> None: @@ -406,4 +511,18 @@ def __init__(self, label: str | None = None, **kwargs) -> None: self._invocation = None def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + """ + Simple iter method that allows for Source object to act as a stream. + """ yield from self() + + # TODO: consider adding stream-like behavior for determining keys and types + def keys( + self, *streams: "SyncStream", trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + return Kernel.keys(self, *streams, trigger_run=trigger_run) + + def types( + self, *streams: "SyncStream", trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + return Kernel.types(self, *streams, trigger_run=trigger_run) diff --git a/src/orcabridge/mappers.py b/src/orcabridge/core/operators.py similarity index 77% rename from src/orcabridge/mappers.py rename to src/orcabridge/core/operators.py index 4ced7ee..04a2795 100644 --- a/src/orcabridge/mappers.py +++ b/src/orcabridge/core/operators.py @@ -4,21 +4,22 @@ from typing import Any -from orcabridge.base import Mapper, SyncStream +from orcabridge.core.base import Operator, SyncStream from orcabridge.hashing import function_content_hash, hash_function -from orcabridge.streams import SyncStreamFromGenerator +from orcabridge.core.streams import SyncStreamFromGenerator from orcabridge.utils.stream_utils import ( batch_packet, batch_tags, check_packet_compatibility, join_tags, + fill_missing, + merge_typespecs, ) -from orcabridge.utils.stream_utils import fill_missing -from .types import Packet, Tag +from orcabridge.types import Packet, Tag, TypeSpec -class Repeat(Mapper): +class Repeat(Operator): """ A Mapper that repeats the packets in the stream a specified number of times. The repeat count is the number of times to repeat each packet. @@ -32,12 +33,28 @@ def __init__(self, repeat_count: int) -> None: raise ValueError("repeat_count must be non-negative") self.repeat_count = repeat_count + def forward(self, *streams: SyncStream) -> SyncStream: + if len(streams) != 1: + raise ValueError("Repeat operation requires exactly one stream") + + stream = streams[0] + + def generator() -> Iterator[tuple[Tag, Packet]]: + for tag, packet in stream: + for _ in range(self.repeat_count): + yield tag, packet + + return SyncStreamFromGenerator(generator) + + def __repr__(self) -> str: + return f"Repeat(count={self.repeat_count})" + def identity_structure(self, *streams) -> tuple[str, int, set[SyncStream]]: # Join does not depend on the order of the streams -- convert it onto a set return (self.__class__.__name__, self.repeat_count, set(streams)) def keys( - self, *streams: SyncStream + self, *streams: SyncStream, trigger_run=False ) -> tuple[Collection[str] | None, Collection[str] | None]: """ Repeat does not alter the keys of the stream. 
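Repeat, defined above, simply re-emits each element of its single input stream; with a hypothetical stream:

    repeated = Repeat(3)(stream)
    # every (tag, packet) pair of `stream` appears three consecutive times in `repeated`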
@@ -46,27 +63,23 @@ def keys( raise ValueError("Repeat operation requires exactly one stream") stream = streams[0] - return stream.keys() + return stream.keys(trigger_run=trigger_run) - def forward(self, *streams: SyncStream) -> SyncStream: + def types( + self, *streams: SyncStream, trigger_run=False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + Repeat does not alter the types of the stream. + """ if len(streams) != 1: raise ValueError("Repeat operation requires exactly one stream") stream = streams[0] - - def generator() -> Iterator[tuple[Tag, Packet]]: - for tag, packet in stream: - for _ in range(self.repeat_count): - yield tag, packet - - return SyncStreamFromGenerator(generator) - - def __repr__(self) -> str: - return f"Repeat(count={self.repeat_count})" + return stream.types(trigger_run=trigger_run) def claims_unique_tags( - self, *streams: SyncStream, trigger_run: bool = True - ) -> bool: + self, *streams: SyncStream, trigger_run: bool = False + ) -> bool | None: if len(streams) != 1: raise ValueError( "Repeat operation only supports operating on a single input stream" @@ -78,13 +91,28 @@ def claims_unique_tags( ) -class Merge(Mapper): +class Merge(Operator): + def forward(self, *streams: SyncStream) -> SyncStream: + tag_keys, packet_keys = self.keys(*streams) + + def generator() -> Iterator[tuple[Tag, Packet]]: + for tag, packet in chain(*streams): + # fill missing keys with None + tag = fill_missing(tag, tag_keys) + packet = fill_missing(packet, packet_keys) + yield tag, packet + + return SyncStreamFromGenerator(generator) + + def __repr__(self) -> str: + return "Merge()" + def identity_structure(self, *streams): # Merge does not depend on the order of the streams -- convert it onto a set return (self.__class__.__name__, set(streams)) def keys( - self, *streams: SyncStream + self, *streams: SyncStream, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: """ Merge does not alter the keys of the stream. @@ -96,7 +124,7 @@ def keys( merged_packet_keys = set() for stream in streams: - tag_keys, packet_keys = stream.keys() + tag_keys, packet_keys = stream.keys(trigger_run=trigger_run) if tag_keys is not None: merged_tag_keys.update(set(tag_keys)) if packet_keys is not None: @@ -104,28 +132,41 @@ def keys( return list(merged_tag_keys), list(merged_packet_keys) - def forward(self, *streams: SyncStream) -> SyncStream: - tag_keys, packet_keys = self.keys(*streams) + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + Merge does not alter the types of the stream. 
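Merge, defined above, chains its input streams end to end and pads missing fields; with hypothetical streams:

    merged = Merge()(stream_a, stream_b)
    # elements of both streams flow through in sequence; any tag or packet key present
    # in one stream but not the other is filled with None via fill_missing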
+ """ + if len(streams) < 2: + raise ValueError("Merge operation requires at least two streams") - def generator() -> Iterator[tuple[Tag, Packet]]: - for tag, packet in chain(*streams): - # fill missing keys with None - tag = fill_missing(tag, tag_keys) - packet = fill_missing(packet, packet_keys) - yield tag, packet + merged_tag_types: TypeSpec | None = {} + merged_packet_types: TypeSpec | None = {} - return SyncStreamFromGenerator(generator) + for stream in streams: + if merged_tag_types is None and merged_packet_types is None: + break + tag_types, packet_types = stream.types(trigger_run=trigger_run) + if merged_tag_types is not None and tag_types is not None: + merged_tag_types.update(tag_types) + else: + merged_tag_types = None + if merged_tag_types is not None and packet_types is not None: + merged_packet_types.update(packet_types) + else: + merged_tag_types = None - def __repr__(self) -> str: - return "Merge()" + return merged_tag_types, merged_packet_types def claims_unique_tags( self, *streams: SyncStream, trigger_run: bool = True - ) -> bool: + ) -> bool | None: """ Merge operation can only claim unique tags if all input streams have unique tags AND the tag keys are not identical across all streams. """ + # TODO: update implementation if len(streams) < 2: raise ValueError("Merge operation requires at least two streams") # Check if all streams have unique tags @@ -146,14 +187,16 @@ def claims_unique_tags( return True -class Join(Mapper): +class Join(Operator): def identity_structure(self, *streams): # Join does not depend on the order of the streams -- convert it onto a set return (self.__class__.__name__, set(streams)) - def keys(self, *streams: SyncStream) -> tuple[Collection[str], Collection[str]]: + def types( + self, *streams: SyncStream, trigger_run=False + ) -> tuple[TypeSpec | None, TypeSpec | None]: """ - Returns the keys of the operation. + Returns the types of the operation. The first list contains the keys of the tags, and the second list contains the keys of the packets. The keys are returned if it is feasible to do so, otherwise a tuple (None, None) is returned to signify that the keys are not known. 
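Join, whose types() is defined here, pairs elements across exactly two streams by tag compatibility. A sketch of the intended behavior, assuming join_tags merges two tags when their shared keys agree and returns None otherwise, and that compatible packets are merged field-wise:

    # hypothetical elements, for illustration
    # left stream:  ({"subject": "A", "day": 1}, {"raw": "a1.dat"})
    # right stream: ({"day": 1},                 {"model": "m.pt"})
    joined = Join()(left_stream, right_stream)
    # expected element: ({"subject": "A", "day": 1}, {"raw": "a1.dat", "model": "m.pt"})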
@@ -162,15 +205,14 @@ def keys(self, *streams: SyncStream) -> tuple[Collection[str], Collection[str]]: raise ValueError("Join operation requires exactly two streams") left_stream, right_stream = streams - left_tag_keys, left_packet_keys = left_stream.keys() - right_tag_keys, right_packet_keys = right_stream.keys() + left_tag_types, left_packet_types = left_stream.types(trigger_run=False) + right_tag_types, right_packet_types = right_stream.types(trigger_run=False) - joined_tag_keys = list(set(left_tag_keys or []) | set(right_tag_keys or [])) - joined_packet_keys = list( - set(left_packet_keys or []) | set(right_packet_keys or []) - ) + # TODO: do error handling when merge fails + joined_tag_types = merge_typespecs(left_tag_types, right_tag_types) + joined_packet_types = merge_typespecs(left_packet_types, right_packet_types) - return joined_tag_keys, joined_packet_keys + return joined_tag_types, joined_packet_types def forward(self, *streams: SyncStream) -> SyncStream: """ @@ -182,7 +224,7 @@ def forward(self, *streams: SyncStream) -> SyncStream: left_stream, right_stream = streams - def generator(): + def generator() -> Iterator[tuple[Tag, Packet]]: for left_tag, left_packet in left_stream: for right_tag, right_packet in right_stream: if (joined_tag := join_tags(left_tag, right_tag)) is not None: @@ -198,37 +240,10 @@ def __repr__(self) -> str: return "Join()" -class FirstMatch(Mapper): - def identity_structure(self, *streams: SyncStream) -> tuple[str, set[SyncStream]]: - # Join does not depend on the order of the streams -- convert it onto a set - return (self.__class__.__name__, set(streams)) - - def keys( - self, *streams: SyncStream - ) -> tuple[Collection[str] | None, Collection[str] | None]: - """ - Returns the keys of the operation. - The first list contains the keys of the tags, and the second list contains the keys of the packets. - The keys are returned if it is feasible to do so, otherwise a tuple - (None, None) is returned to signify that the keys are not known. - """ - if len(streams) != 2: - raise ValueError("FirstMatch operation requires exactly two streams") - - left_stream, right_stream = streams - left_tag_keys, left_packet_keys = left_stream.keys() - right_tag_keys, right_packet_keys = right_stream.keys() - - joined_tag_keys = list(set(left_tag_keys or []) | set(right_tag_keys or [])) - joined_packet_keys = list( - set(left_packet_keys or []) | set(right_packet_keys or []) - ) - - return joined_tag_keys, joined_packet_keys - +class FirstMatch(Operator): def forward(self, *streams: SyncStream) -> SyncStream: """ - Joins two streams together based on their tags. + Joins two streams together based on their tags, returning at most one match for each tag. The resulting stream will contain all the tags from both streams. """ if len(streams) != 2: @@ -265,8 +280,71 @@ def generator(): def __repr__(self) -> str: return "MatchUpToN()" + def identity_structure(self, *streams: SyncStream) -> tuple[str, set[SyncStream]]: + # Join does not depend on the order of the streams -- convert it onto a set + return (self.__class__.__name__, set(streams)) + + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the keys of the operation. + The first list contains the keys of the tags, and the second list contains the keys of the packets. + The keys are returned if it is feasible to do so, otherwise a tuple + (None, None) is returned to signify that the keys are not known. 
+ """ + if len(streams) != 2: + raise ValueError("FirstMatch operation requires exactly two streams") + + left_stream, right_stream = streams + left_tag_keys, left_packet_keys = left_stream.keys(trigger_run=trigger_run) + right_tag_keys, right_packet_keys = right_stream.keys(trigger_run=trigger_run) + + # if any of the components return None -> resolve to default operation + if ( + left_tag_keys is None + or right_tag_keys is None + or left_packet_keys is None + or right_packet_keys is None + ): + return super().keys(*streams, trigger_run=trigger_run) + + joined_tag_keys = list(set(left_tag_keys) | set(right_tag_keys)) + joined_packet_keys = list(set(left_packet_keys) | set(right_packet_keys)) -class MapPackets(Mapper): + return joined_tag_keys, joined_packet_keys + + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + Returns the typespecs of tag and packet. + """ + if len(streams) != 2: + raise ValueError("FirstMatch operation requires exactly two streams") + + left_stream, right_stream = streams + left_tag_types, left_packet_types = left_stream.types(trigger_run=trigger_run) + right_tag_types, right_packet_types = right_stream.types( + trigger_run=trigger_run + ) + + # if any of the components return None -> resolve to default operation + if ( + left_tag_types is None + or right_tag_types is None + or left_packet_types is None + or right_packet_types is None + ): + return super().types(*streams, trigger_run=trigger_run) + + joined_tag_types = merge_typespecs(left_tag_types, right_tag_types) + joined_packet_types = merge_typespecs(left_packet_types, right_packet_types) + + return joined_tag_types, joined_packet_types + + +class MapPackets(Operator): """ A Mapper that maps the keys of the packet in the stream to new keys. The mapping is done using a dictionary that maps old keys to new keys. @@ -279,8 +357,37 @@ def __init__(self, key_map: dict[str, str], drop_unmapped: bool = True) -> None: self.key_map = key_map self.drop_unmapped = drop_unmapped + def forward(self, *streams: SyncStream) -> SyncStream: + if len(streams) != 1: + raise ValueError("MapPackets operation requires exactly one stream") + + stream = streams[0] + + def generator(): + for tag, packet in stream: + if self.drop_unmapped: + packet = { + v: packet[k] for k, v in self.key_map.items() if k in packet + } + else: + packet = {self.key_map.get(k, k): v for k, v in packet.items()} + yield tag, packet + + return SyncStreamFromGenerator(generator) + + def __repr__(self) -> str: + map_repr = ", ".join([f"{k} ⇒ {v}" for k, v in self.key_map.items()]) + return f"packets({map_repr})" + + def identity_structure(self, *streams): + return ( + self.__class__.__name__, + self.key_map, + self.drop_unmapped, + ) + tuple(streams) + def keys( - self, *streams: SyncStream + self, *streams: SyncStream, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: """ Returns the keys of the operation. 
@@ -291,9 +398,9 @@ def keys( raise ValueError("MapPackets operation requires exactly one stream") stream = streams[0] - tag_keys, packet_keys = stream.keys() + tag_keys, packet_keys = stream.keys(trigger_run=trigger_run) if tag_keys is None or packet_keys is None: - return None, None + return super().keys(trigger_run=trigger_run) if self.drop_unmapped: # If drop_unmapped is True, we only keep the keys that are in the mapping @@ -305,37 +412,36 @@ def keys( return tag_keys, mapped_packet_keys - def forward(self, *streams: SyncStream) -> SyncStream: + def types( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[TypeSpec | None, TypeSpec | None]: + """ + Returns the types of the operation. + The first list contains the types of the tags, and the second list contains the types of the packets. + The types are inferred based on the first (tag, packet) pair in the stream. + """ if len(streams) != 1: raise ValueError("MapPackets operation requires exactly one stream") stream = streams[0] + tag_types, packet_types = stream.types(trigger_run=trigger_run) + if tag_types is None or packet_types is None: + return super().types(trigger_run=trigger_run) - def generator(): - for tag, packet in stream: - if self.drop_unmapped: - packet = { - v: packet[k] for k, v in self.key_map.items() if k in packet - } - else: - packet = {self.key_map.get(k, k): v for k, v in packet.items()} - yield tag, packet - - return SyncStreamFromGenerator(generator) - - def __repr__(self) -> str: - map_repr = ", ".join([f"{k} ⇒ {v}" for k, v in self.key_map.items()]) - return f"packets({map_repr})" + if self.drop_unmapped: + # If drop_unmapped is True, we only keep the keys that are in the mapping + mapped_packet_types = { + self.key_map[k]: v for k, v in packet_types.items() if k in self.key_map + } + else: + mapped_packet_types = { + self.key_map.get(k, k): v for k, v in packet_types.items() + } - def identity_structure(self, *streams): - return ( - self.__class__.__name__, - self.key_map, - self.drop_unmapped, - ) + tuple(streams) + return tag_types, mapped_packet_types -class DefaultTag(Mapper): +class DefaultTag(Operator): """ A Mapper that adds a default tag to the packets in the stream. The default tag is added to all packets in the stream. If the @@ -346,22 +452,6 @@ def __init__(self, default_tag: Tag) -> None: super().__init__() self.default_tag = default_tag - def keys( - self, *streams: SyncStream - ) -> tuple[Collection[str] | None, Collection[str] | None]: - """ - Returns the keys of the operation. - The first list contains the keys of the tags, and the second list contains the keys of the packets. - The keys are inferred based on the first (tag, packet) pair in the stream. - """ - if len(streams) != 1: - raise ValueError("DefaultTag operation requires exactly one stream") - - stream = streams[0] - tag_keys, packet_keys = stream.keys() - tag_keys = list(set(tag_keys or []) | set(self.default_tag.keys())) - return tag_keys, packet_keys - def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 1: raise ValueError("DefaultTag operation requires exactly one stream") @@ -377,22 +467,8 @@ def generator() -> Iterator[tuple[Tag, Packet]]: def __repr__(self) -> str: return f"DefaultTag({self.default_tag})" - -class MapTags(Mapper): - """ - A Mapper that maps the tags of the packet in the stream to new tags. Packet remains unchanged. - The mapping is done using a dictionary that maps old tags to new tags. 
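DefaultTag and MapTags (being redefined in this hunk) both act on tags only and leave packets untouched; hypothetical usage:

    tagged = DefaultTag({"dataset": "pilot"})(stream)       # ensure every element carries a "dataset" tag key
    relabeled = MapTags({"subject_id": "subject"})(stream)  # rename a tag key; unmapped tag keys are dropped by default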
- If a tag is not in the mapping, it will be dropped from the element unless - drop_unmapped=False, in which case unmapped tags will be retained. - """ - - def __init__(self, key_map: dict[str, str], drop_unmapped: bool = True) -> None: - super().__init__() - self.key_map = key_map - self.drop_unmapped = drop_unmapped - def keys( - self, *streams: SyncStream + self, *streams: SyncStream, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: """ Returns the keys of the operation. @@ -400,20 +476,28 @@ def keys( The keys are inferred based on the first (tag, packet) pair in the stream. """ if len(streams) != 1: - raise ValueError("MapTags operation requires exactly one stream") + raise ValueError("DefaultTag operation requires exactly one stream") stream = streams[0] - tag_keys, packet_keys = stream.keys() + tag_keys, packet_keys = stream.keys(trigger_run=trigger_run) if tag_keys is None or packet_keys is None: - return None, None + return super().keys(trigger_run=trigger_run) + tag_keys = list(set(tag_keys) | set(self.default_tag.keys())) + return tag_keys, packet_keys - if self.drop_unmapped: - # If drop_unmapped is True, we only keep the keys that are in the mapping - mapped_tag_keys = [self.key_map[k] for k in tag_keys if k in self.key_map] - else: - mapped_tag_keys = [self.key_map.get(k, k) for k in tag_keys] - return mapped_tag_keys, packet_keys +class MapTags(Operator): + """ + A Mapper that maps the tags of the packet in the stream to new tags. Packet remains unchanged. + The mapping is done using a dictionary that maps old tags to new tags. + If a tag is not in the mapping, it will be dropped from the element unless + drop_unmapped=False, in which case unmapped tags will be retained. + """ + + def __init__(self, key_map: dict[str, str], drop_unmapped: bool = True) -> None: + super().__init__() + self.key_map = key_map + self.drop_unmapped = drop_unmapped def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 1: @@ -442,8 +526,32 @@ def identity_structure(self, *streams): self.drop_unmapped, ) + tuple(streams) + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the keys of the operation. + The first list contains the keys of the tags, and the second list contains the keys of the packets. + The keys are inferred based on the first (tag, packet) pair in the stream. + """ + if len(streams) != 1: + raise ValueError("MapTags operation requires exactly one stream") + + stream = streams[0] + tag_keys, packet_keys = stream.keys(trigger_run=trigger_run) + if tag_keys is None or packet_keys is None: + return super().keys(trigger_run=trigger_run) + + if self.drop_unmapped: + # If drop_unmapped is True, we only keep the keys that are in the mapping + mapped_tag_keys = [self.key_map[k] for k in tag_keys if k in self.key_map] + else: + mapped_tag_keys = [self.key_map.get(k, k) for k in tag_keys] -class Filter(Mapper): + return mapped_tag_keys, packet_keys + + +class Filter(Operator): """ A Mapper that filters the packets in the stream based on a predicate function. Predicate function should take two arguments: the tag and the packet, both as dictionaries. @@ -454,18 +562,6 @@ def __init__(self, predicate: Callable[[Tag, Packet], bool]): super().__init__() self.predicate = predicate - def keys( - self, *streams: SyncStream - ) -> tuple[Collection[str] | None, Collection[str] | None]: - """ - Filter does not alter the keys of the stream. 
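Filter, defined above, applies its (tag, packet) predicate to every element, presumably keeping only those for which it returns True (the forward body is elided in this hunk); hypothetical usage:

    filtered = Filter(lambda tag, packet: tag.get("day") == 1)(stream)
    # only elements whose tag has day == 1 pass through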
- """ - if len(streams) != 1: - raise ValueError("Filter operation requires exactly one stream") - - stream = streams[0] - return stream.keys() - def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 1: raise ValueError("Filter operation requires exactly one stream") @@ -488,8 +584,20 @@ def identity_structure(self, *streams): function_content_hash(self.predicate), ) + tuple(streams) + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Filter does not alter the keys of the stream. + """ + if len(streams) != 1: + raise ValueError("Filter operation requires exactly one stream") + + stream = streams[0] + return stream.keys(trigger_run=trigger_run) + -class Transform(Mapper): +class Transform(Operator): """ A Mapper that transforms the packets in the stream based on a transformation function. The transformation function should take two arguments: the tag and the packet, both as dictionaries. @@ -522,7 +630,7 @@ def identity_structure(self, *streams): ) + tuple(streams) -class Batch(Mapper): +class Batch(Operator): """ A Mapper that batches the packets in the stream based on a batch size. The batch size is the number of packets to include in each batch. @@ -543,18 +651,6 @@ def __init__( self.tag_processor = tag_processor self.drop_last = drop_last - def keys( - self, *streams: SyncStream - ) -> tuple[Collection[str] | None, Collection[str] | None]: - """ - Batch does not alter the keys of the stream. - """ - if len(streams) != 1: - raise ValueError("Batch operation requires exactly one stream") - - stream = streams[0] - return stream.keys() - def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 1: raise ValueError("Batch operation requires exactly one stream") @@ -590,8 +686,20 @@ def identity_structure(self, *streams): self.drop_last, ) + tuple(streams) + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Batch does not alter the keys of the stream. 
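Batch, introduced above, combines consecutive elements into fixed-size batches. Assuming the batch size is the first constructor argument, as the docstring suggests (the full signature is elided here):

    batched = Batch(4, drop_last=True)(stream)  # hypothetical stream
    # groups of 4 consecutive elements are combined into one batched element;
    # with drop_last=True a trailing incomplete batch is discarded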
+ """ + if len(streams) != 1: + raise ValueError("Batch operation requires exactly one stream") + + stream = streams[0] + return stream.keys(trigger_run=trigger_run) -class GroupBy(Mapper): + +class GroupBy(Operator): def __init__( self, group_keys: Collection[str] | None = None, @@ -604,12 +712,6 @@ def __init__( self.reduce_keys = reduce_keys self.selection_function = selection_function - def identity_structure(self, *streams: SyncStream) -> Any: - struct = (self.__class__.__name__, self.group_keys, self.reduce_keys) - if self.selection_function is not None: - struct += (hash_function(self.selection_function),) - return struct + tuple(streams) - def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 1: raise ValueError("GroupBy operation requires exactly one stream") @@ -640,7 +742,7 @@ def generator() -> Iterator[tuple[Tag, Packet]]: # create a new tag that combines the group keys # if reduce_keys is True, we only keep the group keys as a singular value - new_tag = {} + new_tag: Tag = {} if self.reduce_keys: new_tag = {k: key[i] for i, k in enumerate(group_keys)} remaining_keys = set(stream_keys) - set(group_keys) @@ -651,15 +753,21 @@ def generator() -> Iterator[tuple[Tag, Packet]]: if k not in new_tag: new_tag[k] = [t.get(k, None) for t, _ in packets] # combine all packets into a single packet - combined_packet = { + combined_packet: Packet = { k: [p.get(k, None) for _, p in packets] for k in packet_keys } yield new_tag, combined_packet return SyncStreamFromGenerator(generator) + def identity_structure(self, *streams: SyncStream) -> Any: + struct = (self.__class__.__name__, self.group_keys, self.reduce_keys) + if self.selection_function is not None: + struct += (hash_function(self.selection_function),) + return struct + tuple(streams) + -class CacheStream(Mapper): +class CacheStream(Operator): """ A Mapper that caches the packets in the stream, thus avoiding upstream recomputation. The cache is filled the first time the stream is iterated over. diff --git a/src/orcabridge/sources.py b/src/orcabridge/core/sources.py similarity index 95% rename from src/orcabridge/sources.py rename to src/orcabridge/core/sources.py index 71758bd..235372c 100644 --- a/src/orcabridge/sources.py +++ b/src/orcabridge/core/sources.py @@ -3,9 +3,9 @@ from pathlib import Path from typing import Any, Literal -from orcabridge.base import Source +from orcabridge.core.base import Source from orcabridge.hashing import hash_function -from orcabridge.streams import SyncStream, SyncStreamFromGenerator +from orcabridge.core.streams import SyncStream, SyncStreamFromGenerator from orcabridge.types import Packet, Tag @@ -69,24 +69,6 @@ def __init__( self.tag_function: Callable[[PathLike], Tag] = tag_function self.tag_function_hash_mode = tag_function_hash_mode - def keys( - self, *streams: SyncStream - ) -> tuple[Collection[str] | None, Collection[str] | None]: - """ - Returns the keys of the stream. The keys are the names of the packets - in the stream. The keys are used to identify the packets in the stream. - If expected_keys are provided, they will be used instead of the default keys. - """ - if len(streams) != 0: - raise ValueError( - "GlobSource does not support forwarding streams. " - "It generates its own stream from the file system." 
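Returning to GroupBy above: all elements sharing the same values for group_keys collapse into a single element whose remaining tag keys and packet fields are collected into lists. Hypothetical usage:

    grouped = GroupBy(group_keys=["subject"], reduce_keys=True)(stream)
    # one output element per distinct "subject" value; with reduce_keys=True the new tag
    # keeps the group key as a scalar while the other tag keys and all packet fields
    # become lists gathered across the grouped elements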
- ) - - if self.expected_tag_keys is not None: - return tuple(self.expected_tag_keys), (self.name,) - return super().keys() - def forward(self, *streams: SyncStream) -> SyncStream: if len(streams) != 0: raise ValueError( @@ -126,9 +108,27 @@ def identity_structure(self, *streams) -> Any: tag_function_hash, ) + tuple(streams) + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + """ + Returns the keys of the stream. The keys are the names of the packets + in the stream. The keys are used to identify the packets in the stream. + If expected_keys are provided, they will be used instead of the default keys. + """ + if len(streams) != 0: + raise ValueError( + "GlobSource does not support forwarding streams. " + "It generates its own stream from the file system." + ) + + if self.expected_tag_keys is not None: + return tuple(self.expected_tag_keys), (self.name,) + return super().keys(trigger_run=trigger_run) + def claims_unique_tags( self, *streams: "SyncStream", trigger_run: bool = True - ) -> bool: + ) -> bool | None: if len(streams) != 0: raise ValueError( "GlobSource does not support forwarding streams. " diff --git a/src/orcabridge/streams.py b/src/orcabridge/core/streams.py similarity index 84% rename from src/orcabridge/streams.py rename to src/orcabridge/core/streams.py index 03100c7..4f4f3c3 100644 --- a/src/orcabridge/streams.py +++ b/src/orcabridge/core/streams.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Collection, Iterator -from orcabridge.base import SyncStream +from orcabridge.core.base import SyncStream from orcabridge.types import Packet, Tag @@ -31,9 +31,11 @@ def __init__( "Either tags and packets or paired must be provided to SyncStreamFromLists" ) - def keys(self) -> tuple[Collection[str] | None, Collection[str] | None]: + def keys( + self, *, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: if self.tag_keys is None or self.packet_keys is None: - return super().keys() + return super().keys(trigger_run=trigger_run) # If the keys are already set, return them return self.tag_keys.copy(), self.packet_keys.copy() @@ -58,11 +60,13 @@ def __init__( self.packet_keys = packet_keys self.generator_factory = generator_factory - def keys(self) -> tuple[Collection[str] | None, Collection[str] | None]: + def __iter__(self) -> Iterator[tuple[Tag, Packet]]: + yield from self.generator_factory() + + def keys( + self, *, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: if self.tag_keys is None or self.packet_keys is None: - return super().keys() + return super().keys(trigger_run=trigger_run) # If the keys are already set, return them return self.tag_keys.copy(), self.packet_keys.copy() - - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - yield from self.generator_factory() diff --git a/src/orcabridge/tracker.py b/src/orcabridge/core/tracker.py similarity index 93% rename from src/orcabridge/tracker.py rename to src/orcabridge/core/tracker.py index e8224a2..f30984d 100644 --- a/src/orcabridge/tracker.py +++ b/src/orcabridge/core/tracker.py @@ -1,4 +1,4 @@ -from orcabridge.base import Invocation, Operation, Tracker +from orcabridge.core.base import Invocation, Kernel, Tracker import networkx as nx import matplotlib.pyplot as plt @@ -13,14 +13,14 @@ class GraphTracker(Tracker): def __init__(self) -> None: super().__init__() - self.invocation_lut: dict[Operation, list[Invocation]] = {} + self.invocation_lut: dict[Kernel, 
list[Invocation]] = {} def record(self, invocation: Invocation) -> None: invocation_list = self.invocation_lut.setdefault(invocation.operation, []) if invocation not in invocation_list: invocation_list.append(invocation) - def reset(self) -> dict[Operation, list[Invocation]]: + def reset(self) -> dict[Kernel, list[Invocation]]: """ Reset the tracker and return the recorded invocations. """ diff --git a/src/orcabridge/pod/core.py b/src/orcabridge/pod/core.py index 0ee8f39..396b4d9 100644 --- a/src/orcabridge/pod/core.py +++ b/src/orcabridge/pod/core.py @@ -12,7 +12,7 @@ ) from orcabridge.types.registry import PacketConverter -from orcabridge.base import Operation +from orcabridge.core.base import Kernel from orcabridge.hashing import ( ObjectHasher, ArrowPacketHasher, @@ -21,13 +21,13 @@ hash_function, get_default_object_hasher, ) -from orcabridge.mappers import Join +from orcabridge.core.operators import Join from orcabridge.store import DataStore, ArrowDataStore, NoOpDataStore -from orcabridge.streams import SyncStream, SyncStreamFromGenerator -from orcabridge.types import Packet, PathSet, PodFunction, Tag +from orcabridge.core.streams import SyncStream, SyncStreamFromGenerator +from orcabridge.types import Packet, PathSet, PodFunction, Tag, TypeSpec + from orcabridge.types.default import default_registry from orcabridge.types.inference import ( - TypeSpec, extract_function_data_types, verify_against_typespec, check_typespec_compatibility, @@ -100,7 +100,7 @@ def decorator(func) -> FunctionPod: return decorator -class Pod(Operation): +class Pod(Kernel): """ An (abstract) base class for all pods. A pod can be seen as a special type of operation that only operates on the packet content without reading tags. Consequently, no operation @@ -217,10 +217,10 @@ def __repr__(self) -> str: return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" def keys( - self, *streams: SyncStream + self, *streams: SyncStream, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: stream = self.process_stream(*streams) - tag_keys, _ = stream[0].keys() + tag_keys, _ = stream[0].keys(trigger_run=trigger_run) return tag_keys, tuple(self.output_keys) def is_memoized(self, packet: Packet) -> bool: @@ -589,13 +589,6 @@ def __repr__(self) -> str: func_sig = get_function_signature(self.function) return f"FunctionPod:{func_sig} ⇒ {self.output_keys}" - def keys( - self, *streams: SyncStream - ) -> tuple[Collection[str] | None, Collection[str] | None]: - stream = self.process_stream(*streams) - tag_keys, _ = stream[0].keys() - return tag_keys, tuple(self.output_keys) - def call(self, tag, packet) -> tuple[Tag, Packet]: output_values: list["PathSet"] = [] @@ -644,6 +637,13 @@ def identity_structure(self, *streams) -> Any: function_info, ) + tuple(streams) + def keys( + self, *streams: SyncStream, trigger_run: bool = False + ) -> tuple[Collection[str] | None, Collection[str] | None]: + stream = self.process_stream(*streams) + tag_keys, _ = stream[0].keys(trigger_run=trigger_run) + return tag_keys, tuple(self.output_keys) + class CachedFunctionPod(Pod): def __init__( @@ -685,9 +685,9 @@ def __repr__(self) -> str: return f"Cached:{self.function_pod}" def keys( - self, *streams: SyncStream + self, *streams: SyncStream, trigger_run: bool = False ) -> tuple[Collection[str] | None, Collection[str] | None]: - return self.function_pod.keys(*streams) + return self.function_pod.keys(*streams, trigger_run=trigger_run) def is_memoized(self, packet: Packet) -> bool: return self.retrieve_memoized(packet) is not 
None @@ -704,7 +704,7 @@ def _add_tag_record_with_packet_key(self, tag: Tag, packet_key: str) -> Tag: if self.tag_store is None: raise ValueError("Recording of tag requires tag_store but none provided") - tag = tag.copy() # ensure we don't modify the original tag + tag = dict(tag) # ensure we don't modify the original tag tag["__packet_key"] = packet_key # convert tag to arrow table diff --git a/src/orcabridge/file.py b/src/orcabridge/store/file.py similarity index 100% rename from src/orcabridge/file.py rename to src/orcabridge/store/file.py diff --git a/src/orcabridge/utils/stream_utils.py b/src/orcabridge/utils/stream_utils.py index 611e94e..a762b06 100644 --- a/src/orcabridge/utils/stream_utils.py +++ b/src/orcabridge/utils/stream_utils.py @@ -3,14 +3,62 @@ """ from collections.abc import Collection, Mapping -from typing import TypeVar +from typing import TypeVar, Hashable, Any -from orcabridge.types import Packet, Tag +from orcabridge.types import Packet, Tag, TypeSpec -K = TypeVar("K") + +K = TypeVar("K", bound=Hashable) V = TypeVar("V") +def get_typespec(dict: Mapping) -> TypeSpec: + """ + Returns a TypeSpec for the given dictionary. + The TypeSpec is a mapping from field name to Python type. + """ + return {key: type(value) for key, value in dict.items()} + + +def get_compatible_type(type1: Any, type2: Any) -> Any: + if type1 is type2: + return type1 + if issubclass(type1, type2): + return type2 + if issubclass(type2, type1): + return type1 + raise TypeError(f"Types {type1} and {type2} are not compatible") + + +def merge_dicts(left: dict[K, V], right: dict[K, V]) -> dict[K, V]: + merged = left.copy() + for key, right_value in right.items(): + if key in merged: + if merged[key] != right_value: + raise ValueError( + f"Conflicting values for key '{key}': {merged[key]} vs {right_value}" + ) + else: + merged[key] = right_value + return merged + + +def merge_typespecs(left: TypeSpec | None, right: TypeSpec | None) -> TypeSpec | None: + if left is None: + return right + if right is None: + return left + # Merge the two TypeSpecs but raise an error if conflicts in types are found + merged = dict(left) + for key, right_type in right.items(): + merged[key] = ( + get_compatible_type(merged[key], right_type) + if key in merged + else right_type + ) + return merged + + def common_elements(*values) -> Collection[str]: """ Returns the common keys between all lists of values. The identified common elements are @@ -26,10 +74,11 @@ def common_elements(*values) -> Collection[str]: return common_keys -def join_tags(tag1: Mapping[K, V], tag2: Mapping[K, V]) -> Mapping[K, V] | None: +def join_tags(tag1: Mapping[K, V], tag2: Mapping[K, V]) -> dict[K, V] | None: """ Joins two tags together. If the tags have the same key, the value must be the same or None will be returned. """ + # create a dict copy of tag1 joined_tag = dict(tag1) for k, v in tag2.items(): if k in joined_tag and joined_tag[k] != v: From 136e9a1645c133c04b8ce949831dfc636d4e8b94 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 18 Jun 2025 01:20:55 +0000 Subject: [PATCH 22/28] refactor: use mapping instead of dict --- src/orcabridge/types/__init__.py | 4 ++-- src/orcabridge/types/core.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/orcabridge/types/__init__.py b/src/orcabridge/types/__init__.py index 08c8ae4..f372259 100644 --- a/src/orcabridge/types/__init__.py +++ b/src/orcabridge/types/__init__.py @@ -18,7 +18,7 @@ # the top level tag is a mapping from string keys to values that can be a string or # an arbitrary depth of nested list of strings or None -Tag: TypeAlias = dict[str, TagValue] +Tag: TypeAlias = Mapping[str, TagValue] # a pathset is a path or an arbitrary depth of nested list of paths PathSet: TypeAlias = PathLike | Collection[PathLike | None] @@ -34,7 +34,7 @@ # a packet is a mapping from string keys to data values -Packet: TypeAlias = dict[str, DataValue] +Packet: TypeAlias = Mapping[str, DataValue] # a batch is a tuple of a tag and a list of packets Batch: TypeAlias = tuple[Tag, Collection[Packet]] diff --git a/src/orcabridge/types/core.py b/src/orcabridge/types/core.py index be7c3d0..5822f87 100644 --- a/src/orcabridge/types/core.py +++ b/src/orcabridge/types/core.py @@ -1,4 +1,4 @@ -from typing import Protocol, Any, TypeAlias +from typing import Protocol, Any, TypeAlias, Mapping import pyarrow as pa from dataclasses import dataclass @@ -15,7 +15,9 @@ class TypeInfo: DataType: TypeAlias = type -TypeSpec: TypeAlias = dict[str, DataType] # Mapping of parameter names to their types +TypeSpec: TypeAlias = Mapping[ + str, DataType +] # Mapping of parameter names to their types class TypeHandler(Protocol): From d5e7dc4e4a168f45e1f253f11589714acb7ee4df Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 18 Jun 2025 01:21:10 +0000 Subject: [PATCH 23/28] feat: add draft parquet arrow dataset --- src/orcabridge/store/arrow_data_stores.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/orcabridge/store/arrow_data_stores.py b/src/orcabridge/store/arrow_data_stores.py index 475e506..b3f1c8e 100644 --- a/src/orcabridge/store/arrow_data_stores.py +++ b/src/orcabridge/store/arrow_data_stores.py @@ -11,7 +11,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta import logging -from collections import defaultdict # Module-level logger From 83bce8d886e4755135ad3a221f80106664367b00 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 18 Jun 2025 08:05:27 +0000 Subject: [PATCH 24/28] wip: typed pod with storage --- src/orcabridge/core/tracker.py | 16 +- src/orcabridge/hashing/__init__.py | 11 +- src/orcabridge/hashing/defaults.py | 6 +- .../hashing/semantic_arrow_hasher.py | 6 + src/orcabridge/hashing/types.py | 5 +- src/orcabridge/pipeline/pipeline.py | 730 ++++++++++++++++++ src/orcabridge/pod/core.py | 80 +- src/orcabridge/store/arrow_data_stores.py | 204 ++++- src/orcabridge/store/types.py | 26 + 9 files changed, 1028 insertions(+), 56 deletions(-) create mode 100644 src/orcabridge/pipeline/pipeline.py diff --git a/src/orcabridge/core/tracker.py b/src/orcabridge/core/tracker.py index f30984d..b36e095 100644 --- a/src/orcabridge/core/tracker.py +++ b/src/orcabridge/core/tracker.py @@ -16,7 +16,7 @@ def __init__(self) -> None: self.invocation_lut: dict[Kernel, list[Invocation]] = {} def record(self, invocation: Invocation) -> None: - invocation_list = self.invocation_lut.setdefault(invocation.operation, []) + invocation_list = self.invocation_lut.setdefault(invocation.kernel, []) if invocation not in invocation_list: invocation_list.append(invocation) @@ -30,16 +30,16 @@ def reset(self) -> dict[Kernel, list[Invocation]]: def generate_namemap(self) -> dict[Invocation, str]: namemap = {} - for operation, invocations in self.invocation_lut.items(): - # if only one entry present, use the operation name alone - if operation.label is not None: - node_label = operation.label + for kernel, invocations in self.invocation_lut.items(): + # if only one entry present, use the kernel name alone + if kernel.label is not None: + node_label = kernel.label else: - node_label = str(operation) + node_label = str(kernel) if len(invocations) == 1: namemap[invocations[0]] = node_label continue - # if multiple entries, use the operation name and index + # if multiple entries, use the kernel name and index for idx, invocation in enumerate(invocations): namemap[invocation] = f"{node_label}_{idx}" return namemap @@ -48,7 +48,7 @@ def generate_graph(self): G = nx.DiGraph() # Add edges for each invocation - for operation, invocations in self.invocation_lut.items(): + for kernel, invocations in self.invocation_lut.items(): for invocation in invocations: for upstream in invocation.streams: # if upstream.invocation is not in the graph, add it diff --git a/src/orcabridge/hashing/__init__.py b/src/orcabridge/hashing/__init__.py index a91b7f4..98a15da 100644 --- a/src/orcabridge/hashing/__init__.py +++ b/src/orcabridge/hashing/__init__.py @@ -10,11 +10,15 @@ hash_to_int, hash_to_uuid, ) -from .defaults import get_default_composite_file_hasher, get_default_object_hasher +from .defaults import ( + get_default_composite_file_hasher, + get_default_object_hasher, + get_default_arrow_hasher, +) from .types import ( FileHasher, PacketHasher, - ArrowPacketHasher, + ArrowHasher, ObjectHasher, StringCacher, FunctionInfoExtractor, @@ -24,7 +28,7 @@ __all__ = [ "FileHasher", "PacketHasher", - "ArrowPacketHasher", + "ArrowHasher", "StringCacher", "ObjectHasher", "CompositeFileHasher", @@ -41,4 +45,5 @@ "HashableMixin", "get_default_composite_file_hasher", "get_default_object_hasher", + "get_default_arrow_hasher", ] diff --git a/src/orcabridge/hashing/defaults.py b/src/orcabridge/hashing/defaults.py index 6f9abf3..5a5d587 100644 --- a/src/orcabridge/hashing/defaults.py +++ b/src/orcabridge/hashing/defaults.py @@ -1,6 +1,6 @@ # A collection of utility function that provides a "default" implementation of hashers. 
# This is often used as the fallback hasher in the library code. -from orcabridge.hashing.types import CompositeFileHasher +from orcabridge.hashing.types import CompositeFileHasher, ArrowHasher from orcabridge.hashing.file_hashers import PathLikeHasherFactory from orcabridge.hashing.string_cachers import InMemoryCacher from orcabridge.hashing.object_hashers import ObjectHasher @@ -34,9 +34,9 @@ def get_default_object_hasher() -> ObjectHasher: ) -def get_default_semantic_arrow_hasher( +def get_default_arrow_hasher( chunk_size: int = 8192, handle_missing: str = "error" -) -> SemanticArrowHasher: +) -> ArrowHasher: hasher = SemanticArrowHasher(chunk_size=chunk_size, handle_missing=handle_missing) # register semantic hasher for Path hasher.register_semantic_hasher("Path", PathHasher()) diff --git a/src/orcabridge/hashing/semantic_arrow_hasher.py b/src/orcabridge/hashing/semantic_arrow_hasher.py index 311c823..f3682ed 100644 --- a/src/orcabridge/hashing/semantic_arrow_hasher.py +++ b/src/orcabridge/hashing/semantic_arrow_hasher.py @@ -5,6 +5,7 @@ import pyarrow as pa import pyarrow.ipc as ipc from io import BytesIO +import polars as pl class SemanticTypeHasher(Protocol): @@ -213,12 +214,17 @@ def hash_table(self, table: pa.Table, algorithm: str = "sha256") -> str: Returns: Hex string of the computed hash """ + # Step 1: Process columns with semantic types processed_table = self._process_table_columns(table) # Step 2: Sort columns by name for deterministic ordering sorted_table = self._sort_table_columns(processed_table) + # normalize all string to large strings by passing through polars + # TODO: consider cleaner approach in the future + sorted_table = pl.DataFrame(sorted_table).to_arrow() + # Step 3: Serialize using Arrow IPC format serialized_bytes = self._serialize_table_ipc(sorted_table) diff --git a/src/orcabridge/hashing/types.py b/src/orcabridge/hashing/types.py index 4822433..b986941 100644 --- a/src/orcabridge/hashing/types.py +++ b/src/orcabridge/hashing/types.py @@ -104,10 +104,11 @@ class PacketHasher(Protocol): def hash_packet(self, packet: Packet) -> str: ... -class ArrowPacketHasher: +@runtime_checkable +class ArrowHasher(Protocol): """Protocol for hashing arrow packets.""" - def hash_arrow_packet(self, packet: pa.Table) -> str: ... + def hash_table(self, table: pa.Table) -> str: ... @runtime_checkable diff --git a/src/orcabridge/pipeline/pipeline.py b/src/orcabridge/pipeline/pipeline.py new file mode 100644 index 0000000..f203bdb --- /dev/null +++ b/src/orcabridge/pipeline/pipeline.py @@ -0,0 +1,730 @@ +import json +import logging +import pickle +import sys +import time +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +import networkx as nx +import pandas as pd + +from orcabridge.core.base import Invocation, Kernel +from orcabridge.hashing import hash_to_hex +from orcabridge.core.tracker import GraphTracker + +logger = logging.getLogger(__name__) + + +class SerializationError(Exception): + """Raised when pipeline cannot be serialized""" + + pass + + +class Pipeline(GraphTracker): + """ + Enhanced pipeline that tracks operations and provides queryable views. + Replaces the old Tracker with better persistence and view capabilities. 
+ """ + + def __init__(self, name: str | None = None): + super().__init__() + self.name = name or f"pipeline_{id(self)}" + self._view_registry: dict[str, "PipelineView"] = {} + self._cache_dir = Path(".pipeline_cache") / self.name + self._cache_dir.mkdir(parents=True, exist_ok=True) + + # Core Pipeline Operations + def save(self, path: Path | str) -> None: + """Save complete pipeline state - named functions only""" + path = Path(path) + + # Validate serializability first + self._validate_serializable() + + state = { + "name": self.name, + "invocation_lut": self.invocation_lut, + "metadata": { + "created_at": time.time(), + "python_version": sys.version_info[:2], + "orcabridge_version": "0.1.0", # You can make this dynamic + }, + } + + # Atomic write + temp_path = path.with_suffix(".tmp") + try: + with open(temp_path, "wb") as f: + pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL) + temp_path.replace(path) + logger.info(f"Pipeline '{self.name}' saved to {path}") + except Exception: + if temp_path.exists(): + temp_path.unlink() + raise + + @classmethod + def load(cls, path: Path | str) -> "Pipeline": + """Load complete pipeline state""" + path = Path(path) + + with open(path, "rb") as f: + state = pickle.load(f) + + pipeline = cls(state["name"]) + pipeline.invocation_lut = state["invocation_lut"] + + logger.info(f"Pipeline '{pipeline.name}' loaded from {path}") + return pipeline + + def _validate_serializable(self) -> None: + """Ensure pipeline contains only serializable operations""" + issues = [] + + for operation, invocations in self.invocation_lut.items(): + # Check for lambda functions + if hasattr(operation, "function"): + func = getattr(operation, "function", None) + if func and hasattr(func, "__name__") and func.__name__ == "<lambda>": + issues.append(f"Lambda function in {operation.__class__.__name__}") + + # Test actual serializability + try: + pickle.dumps(operation) + except Exception as e: + issues.append(f"Non-serializable operation {operation}: {e}") + + if issues: + raise SerializationError( + "Pipeline contains non-serializable elements:\n" + + "\n".join(f" - {issue}" for issue in issues) + + "\n\nOnly named functions are supported for serialization."
+ ) + + # View Management + def as_view( + self, renderer: "ViewRenderer", view_id: str | None = None, **kwargs + ) -> "PipelineView": + """Get a view of this pipeline using the specified renderer""" + view_id = ( + view_id + or f"{renderer.__class__.__name__.lower()}_{len(self._view_registry)}" + ) + + if view_id not in self._view_registry: + self._view_registry[view_id] = renderer.create_view( + self, view_id=view_id, **kwargs + ) + return self._view_registry[view_id] + + def as_dataframe(self, view_id: str = "default", **kwargs) -> "PandasPipelineView": + """Convenience method for pandas DataFrame view""" + return self.as_view(PandasViewRenderer(), view_id=view_id, **kwargs) + + def as_graph(self) -> nx.DiGraph: + """Get the computation graph""" + return self.generate_graph() + + # Combined save/load with views + def save_with_views(self, base_path: Path | str) -> dict[str, Path]: + """Save pipeline and all its views together""" + base_path = Path(base_path) + base_path.mkdir(parents=True, exist_ok=True) + + saved_files = {} + + # Save pipeline itself + pipeline_path = base_path / "pipeline.pkl" + self.save(pipeline_path) + saved_files["pipeline"] = pipeline_path + + # Save all views + for view_id, view in self._view_registry.items(): + view_path = base_path / f"view_{view_id}.pkl" + view.save(view_path, include_pipeline=False) + saved_files[f"view_{view_id}"] = view_path + + # Save manifest + manifest = { + "pipeline_file": "pipeline.pkl", + "views": { + view_id: f"view_{view_id}.pkl" for view_id in self._view_registry.keys() + }, + "created_at": time.time(), + "pipeline_name": self.name, + } + + manifest_path = base_path / "manifest.json" + with open(manifest_path, "w") as f: + json.dump(manifest, f, indent=2) + saved_files["manifest"] = manifest_path + + return saved_files + + @classmethod + def load_with_views( + cls, base_path: Path | str + ) -> tuple["Pipeline", dict[str, "PipelineView"]]: + """Load pipeline and all its views""" + base_path = Path(base_path) + + # Load manifest + manifest_path = base_path / "manifest.json" + with open(manifest_path, "r") as f: + manifest = json.load(f) + + # Load pipeline + pipeline_path = base_path / manifest["pipeline_file"] + pipeline = cls.load(pipeline_path) + + # Load views with appropriate renderers + renderers = { + "PandasPipelineView": PandasViewRenderer(), + "DataJointPipelineView": DataJointViewRenderer(None), # Would need schema + } + + views = {} + for view_id, view_file in manifest["views"].items(): + view_path = base_path / view_file + + # Load view data to determine type + with open(view_path, "rb") as f: + view_data = pickle.load(f) + + # Find appropriate renderer + view_type = view_data.get("view_type", "PandasPipelineView") + if view_type in renderers and renderers[view_type].can_load_view(view_data): + # Load with appropriate view class + if view_type == "PandasPipelineView": + view = PandasPipelineView.load(view_path, pipeline) + else: + view = DataJointPipelineView.load(view_path, pipeline) + else: + # Default to pandas view + view = PandasPipelineView.load(view_path, pipeline) + + views[view_id] = view + pipeline._view_registry[view_id] = view + + return pipeline, views + + def get_stats(self) -> dict[str, Any]: + """Get pipeline statistics""" + total_operations = len(self.invocation_lut) + total_invocations = sum(len(invs) for invs in self.invocation_lut.values()) + + operation_types = {} + for operation in self.invocation_lut.keys(): + op_type = operation.__class__.__name__ + operation_types[op_type] = 
operation_types.get(op_type, 0) + 1 + + return { + "name": self.name, + "total_operations": total_operations, + "total_invocations": total_invocations, + "operation_types": operation_types, + "views": list(self._view_registry.keys()), + } + + +# View Renderer Protocol +@runtime_checkable +class ViewRenderer(Protocol): + """Protocol for all view renderers - uses structural typing""" + + def create_view( + self, pipeline: "Pipeline", view_id: str, **kwargs + ) -> "PipelineView": + """Create a view for the given pipeline""" + ... + + def can_load_view(self, view_data: dict[str, Any]) -> bool: + """Check if this renderer can load the given view data""" + ... + + +class PandasViewRenderer: + """Renderer for pandas DataFrame views""" + + def create_view( + self, pipeline: "Pipeline", view_id: str, **kwargs + ) -> "PandasPipelineView": + return PandasPipelineView(pipeline, view_id=view_id, **kwargs) + + def can_load_view(self, view_data: dict[str, Any]) -> bool: + return view_data.get("view_type") == "PandasPipelineView" + + +class DataJointViewRenderer: + """Renderer for DataJoint views""" + + def __init__(self, schema): + self.schema = schema + + def create_view( + self, pipeline: "Pipeline", view_id: str, **kwargs + ) -> "DataJointPipelineView": + return DataJointPipelineView(pipeline, self.schema, view_id=view_id, **kwargs) + + def can_load_view(self, view_data: dict[str, Any]) -> bool: + return view_data.get("view_type") == "DataJointPipelineView" + + +# Base class for all views +class PipelineView(ABC): + """Base class for all pipeline views""" + + def __init__(self, pipeline: Pipeline, view_id: str): + self.pipeline = pipeline + self.view_id = view_id + self._cache_dir = pipeline._cache_dir / "views" + self._cache_dir.mkdir(parents=True, exist_ok=True) + + @abstractmethod + def save(self, path: Path | str, include_pipeline: bool = True) -> None: + """Save the view""" + pass + + @classmethod + @abstractmethod + def load(cls, path: Path | str, pipeline: Pipeline | None = None) -> "PipelineView": + """Load the view""" + pass + + def _compute_pipeline_hash(self) -> str: + """Compute hash of current pipeline state for validation""" + pipeline_state = [] + for operation, invocations in self.pipeline.invocation_lut.items(): + for invocation in invocations: + pipeline_state.append(invocation.content_hash()) + return hash_to_hex(sorted(pipeline_state)) + + +# Pandas DataFrame-like view +class PandasPipelineView(PipelineView): + """ + Provides a pandas DataFrame-like interface to pipeline metadata. + Focuses on tag information for querying and filtering. 
+ """ + + def __init__( + self, + pipeline: Pipeline, + view_id: str = "pandas_view", + max_records: int = 10000, + sample_size: int = 100, + ): + super().__init__(pipeline, view_id) + self.max_records = max_records + self.sample_size = sample_size + self._cached_data: pd.DataFrame | None = None + self._build_options = {"max_records": max_records, "sample_size": sample_size} + self._hash_to_data_map: dict[str, Any] = {} + + @property + def df(self) -> pd.DataFrame: + """Access the underlying DataFrame, building if necessary""" + if self._cached_data is None: + # Try to load from cache first + cache_path = self._cache_dir / f"{self.view_id}.pkl" + if cache_path.exists(): + try: + loaded_view = self.load(cache_path, self.pipeline) + if self._is_cache_valid(loaded_view): + self._cached_data = loaded_view._cached_data + self._hash_to_data_map = loaded_view._hash_to_data_map + logger.info(f"Loaded view '{self.view_id}' from cache") + return self._cached_data + except Exception as e: + logger.warning(f"Failed to load cached view: {e}") + + # Build from scratch + logger.info(f"Building view '{self.view_id}' from pipeline") + self._cached_data = self._build_metadata() + + # Auto-save after building + try: + self.save(cache_path, include_pipeline=False) + except Exception as e: + logger.warning(f"Failed to cache view: {e}") + + return self._cached_data + + def _build_metadata(self) -> pd.DataFrame: + """Build the metadata DataFrame from pipeline operations""" + metadata_records = [] + total_records = 0 + + for operation, invocations in self.pipeline.invocation_lut.items(): + if total_records >= self.max_records: + logger.warning(f"Hit max_records limit ({self.max_records})") + break + + for invocation in invocations: + try: + # Get sample of outputs, not all + records = self._extract_metadata_from_invocation( + invocation, operation + ) + for record in records: + metadata_records.append(record) + total_records += 1 + if total_records >= self.max_records: + break + + if total_records >= self.max_records: + break + + except Exception as e: + logger.warning(f"Skipping {operation.__class__.__name__}: {e}") + # Create placeholder record + placeholder = self._create_placeholder_record(invocation, operation) + metadata_records.append(placeholder) + total_records += 1 + + if not metadata_records: + # Return empty DataFrame with basic structure + return pd.DataFrame( + columns=[ + "operation_name", + "operation_hash", + "invocation_id", + "created_at", + "packet_keys", + ] + ) + + return pd.DataFrame(metadata_records) + + def _extract_metadata_from_invocation( + self, invocation: Invocation, operation: Kernel + ) -> list[dict[str, Any]]: + """Extract metadata records from a single invocation""" + records = [] + + # Try to get sample outputs from the invocation + try: + # This is tricky - we need to reconstruct the output stream + # For now, we'll create a basic record from what we know + base_record = { + "operation_name": operation.label or operation.__class__.__name__, + "operation_hash": invocation.content_hash(), + "invocation_id": hash(invocation), + "created_at": time.time(), + "operation_type": operation.__class__.__name__, + } + + # Try to get tag and packet info from the operation + try: + tag_keys, packet_keys = invocation.keys() + base_record.update( + { + "tag_keys": list(tag_keys) if tag_keys else [], + "packet_keys": list(packet_keys) if packet_keys else [], + } + ) + except Exception: + base_record.update( + { + "tag_keys": [], + "packet_keys": [], + } + ) + + records.append(base_record) + + 
except Exception as e: + logger.debug(f"Could not extract detailed metadata from {operation}: {e}") + records.append(self._create_placeholder_record(invocation, operation)) + + return records + + def _create_placeholder_record( + self, invocation: Invocation, operation: Kernel + ) -> dict[str, Any]: + """Create a placeholder record when extraction fails""" + return { + "operation_name": operation.label or operation.__class__.__name__, + "operation_hash": invocation.content_hash(), + "invocation_id": hash(invocation), + "created_at": time.time(), + "operation_type": operation.__class__.__name__, + "tag_keys": [], + "packet_keys": [], + "is_placeholder": True, + } + + # DataFrame-like interface + def __getitem__(self, condition) -> "FilteredPipelineView": + """Enable pandas-like filtering: view[condition]""" + df = self.df + if isinstance(condition, pd.Series): + filtered_df = df[condition] + elif callable(condition): + filtered_df = df[condition(df)] + else: + filtered_df = df[condition] + + return FilteredPipelineView(self.pipeline, filtered_df, self._hash_to_data_map) + + def query(self, expr: str) -> "FilteredPipelineView": + """SQL-like querying: view.query('operation_name == "MyOperation"')""" + df = self.df + filtered_df = df.query(expr) + return FilteredPipelineView(self.pipeline, filtered_df, self._hash_to_data_map) + + def groupby(self, *args, **kwargs) -> "GroupedPipelineView": + """Group operations similar to pandas groupby""" + df = self.df + grouped = df.groupby(*args, **kwargs) + return GroupedPipelineView(self.pipeline, grouped, self._hash_to_data_map) + + def head(self, n: int = 5) -> pd.DataFrame: + """Return first n rows""" + return self.df.head(n) + + def info(self) -> None: + """Display DataFrame info""" + return self.df.info() + + def describe(self) -> pd.DataFrame: + """Generate descriptive statistics""" + return self.df.describe() + + # Persistence methods + def save(self, path: Path | str, include_pipeline: bool = True) -> None: + """Save view, optionally with complete pipeline state""" + path = Path(path) + + # Build the view data if not cached + df = self.df + + view_data = { + "view_id": self.view_id, + "view_type": self.__class__.__name__, + "dataframe": df, + "build_options": self._build_options, + "hash_to_data_map": self._hash_to_data_map, + "created_at": time.time(), + "pipeline_hash": self._compute_pipeline_hash(), + } + + if include_pipeline: + view_data["pipeline_state"] = { + "name": self.pipeline.name, + "invocation_lut": self.pipeline.invocation_lut, + } + view_data["has_pipeline"] = True + else: + view_data["pipeline_name"] = self.pipeline.name + view_data["has_pipeline"] = False + + with open(path, "wb") as f: + pickle.dump(view_data, f, protocol=pickle.HIGHEST_PROTOCOL) + + @classmethod + def load( + cls, path: Path | str, pipeline: Pipeline | None = None + ) -> "PandasPipelineView": + """Load view, reconstructing pipeline if needed""" + with open(path, "rb") as f: + view_data = pickle.load(f) + + # Handle pipeline reconstruction + if view_data["has_pipeline"]: + pipeline = Pipeline(view_data["pipeline_state"]["name"]) + pipeline.invocation_lut = view_data["pipeline_state"]["invocation_lut"] + elif pipeline is None: + raise ValueError( + "View was saved without pipeline state. " + "You must provide a pipeline parameter." 
+ ) + + # Reconstruct view + build_options = view_data.get("build_options", {}) + view = cls( + pipeline, + view_id=view_data["view_id"], + max_records=build_options.get("max_records", 10000), + sample_size=build_options.get("sample_size", 100), + ) + view._cached_data = view_data["dataframe"] + view._hash_to_data_map = view_data.get("hash_to_data_map", {}) + + return view + + def _is_cache_valid(self, cached_view: "PandasPipelineView") -> bool: + """Check if cached view is still valid""" + try: + cached_hash = getattr(cached_view, "_pipeline_hash", None) + current_hash = self._compute_pipeline_hash() + return cached_hash == current_hash + except Exception: + return False + + def invalidate(self) -> None: + """Force re-rendering on next access""" + self._cached_data = None + cache_path = self._cache_dir / f"{self.view_id}.pkl" + if cache_path.exists(): + cache_path.unlink() + + +class FilteredPipelineView: + """Represents a filtered subset of pipeline metadata""" + + def __init__( + self, pipeline: Pipeline, filtered_df: pd.DataFrame, data_map: dict[str, Any] + ): + self.pipeline = pipeline + self.df = filtered_df + self._data_map = data_map + + def __getitem__(self, condition): + """Further filtering""" + further_filtered = self.df[condition] + return FilteredPipelineView(self.pipeline, further_filtered, self._data_map) + + def query(self, expr: str): + """Apply additional query""" + further_filtered = self.df.query(expr) + return FilteredPipelineView(self.pipeline, further_filtered, self._data_map) + + def to_pandas(self) -> pd.DataFrame: + """Convert to regular pandas DataFrame""" + return self.df.copy() + + def head(self, n: int = 5) -> pd.DataFrame: + """Return first n rows""" + return self.df.head(n) + + def __len__(self) -> int: + return len(self.df) + + def __repr__(self) -> str: + return f"FilteredPipelineView({len(self.df)} records)" + + +class GroupedPipelineView: + """Represents grouped pipeline metadata""" + + def __init__(self, pipeline: Pipeline, grouped_df, data_map: dict[str, Any]): + self.pipeline = pipeline + self.grouped = grouped_df + self._data_map = data_map + + def apply(self, func): + """Apply function to each group""" + return self.grouped.apply(func) + + def agg(self, *args, **kwargs): + """Aggregate groups""" + return self.grouped.agg(*args, **kwargs) + + def size(self): + """Get group sizes""" + return self.grouped.size() + + def get_group(self, name): + """Get specific group""" + group_df = self.grouped.get_group(name) + return FilteredPipelineView(self.pipeline, group_df, self._data_map) + + +# Basic DataJoint View (simplified implementation) +class DataJointPipelineView(PipelineView): + """ + Basic DataJoint view - creates tables for pipeline operations + This is a simplified version - you can expand based on your existing DJ code + """ + + def __init__(self, pipeline: Pipeline, schema, view_id: str = "dj_view"): + super().__init__(pipeline, view_id) + self.schema = schema + self._tables = {} + + def save(self, path: Path | str, include_pipeline: bool = True) -> None: + """Save DataJoint view metadata""" + view_data = { + "view_id": self.view_id, + "view_type": self.__class__.__name__, + "schema_database": self.schema.database, + "table_names": list(self._tables.keys()), + "created_at": time.time(), + } + + if include_pipeline: + view_data["pipeline_state"] = { + "name": self.pipeline.name, + "invocation_lut": self.pipeline.invocation_lut, + } + view_data["has_pipeline"] = True + + with open(path, "wb") as f: + pickle.dump(view_data, f) + + @classmethod + 
def load( + cls, path: Path | str, pipeline: Pipeline | None = None + ) -> "DataJointPipelineView": + """Load DataJoint view""" + with open(path, "rb") as f: + view_data = pickle.load(f) + + # This would need actual DataJoint schema reconstruction + # For now, return a basic instance + if pipeline is None: + raise ValueError("Pipeline required for DataJoint view loading") + + # You'd need to reconstruct the schema here + view = cls(pipeline, None, view_id=view_data["view_id"]) # schema=None for now + return view + + def generate_tables(self): + """Generate DataJoint tables from pipeline - placeholder implementation""" + # This would use your existing DataJoint generation logic + # from your dj/tracker.py file + pass + + +# Utility functions +def validate_pipeline_serializability(pipeline: Pipeline) -> None: + """Helper to check if pipeline can be saved""" + try: + pipeline._validate_serializable() + print("✅ Pipeline is ready for serialization") + + # Additional performance warnings + stats = pipeline.get_stats() + if stats["total_invocations"] > 1000: + print( + f"⚠️ Large pipeline ({stats['total_invocations']} invocations) - views may be slow to build" + ) + + except SerializationError as e: + print("❌ Pipeline cannot be serialized:") + print(str(e)) + print("\n💡 Convert lambda functions to named functions:") + print(" lambda x: x > 0.8 → def filter_func(x): return x > 0.8") + + +def create_example_pipeline() -> Pipeline: + """Create an example pipeline for testing""" + from orcabridge import GlobSource, function_pod + + @function_pod + def example_function(input_file): + return f"processed_{input_file}" + + pipeline = Pipeline("example") + + with pipeline: + # This would need actual operations to be meaningful + # source = GlobSource('data', './test_data', '*.txt')() + # results = source >> example_function + pass + + return pipeline diff --git a/src/orcabridge/pod/core.py b/src/orcabridge/pod/core.py index 396b4d9..ffca708 100644 --- a/src/orcabridge/pod/core.py +++ b/src/orcabridge/pod/core.py @@ -10,16 +10,18 @@ Any, Literal, ) + from orcabridge.types.registry import PacketConverter from orcabridge.core.base import Kernel from orcabridge.hashing import ( ObjectHasher, - ArrowPacketHasher, + ArrowHasher, FunctionInfoExtractor, get_function_signature, hash_function, get_default_object_hasher, + get_default_arrow_hasher, ) from orcabridge.core.operators import Join from orcabridge.store import DataStore, ArrowDataStore, NoOpDataStore @@ -33,6 +35,7 @@ check_typespec_compatibility, ) from orcabridge.types.registry import is_packet_supported +import polars as pl logger = logging.getLogger(__name__) @@ -363,10 +366,15 @@ def identity_structure(self, *streams) -> Any: def typed_function_pod( - output_keys: Collection[str] | None = None, + output_keys: str | Collection[str] | None = None, function_name: str | None = None, - **kwargs: Any, -) -> Callable[..., "TypedFunctionPod"]: + label: str | None = None, + result_store: ArrowDataStore | None = None, + tag_store: ArrowDataStore | None = None, + object_hasher: ObjectHasher | None = None, + arrow_hasher: ArrowHasher | None = None, + **kwargs, +) -> Callable[..., "TypedFunctionPod | CachedFunctionPod"]: """ Decorator that wraps a function in a FunctionPod instance. 
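Illustrative usage sketch for the expanded typed_function_pod signature shown in the hunk above, assuming the import paths implied by this patch layout (orcabridge.pod.core for the decorator, orcabridge.store.arrow_data_stores for ParquetArrowDataStore) and a hypothetical function name. When a result_store is supplied the decorator wraps the TypedFunctionPod in the cached variant; without one it returns the plain TypedFunctionPod.

from orcabridge.pod.core import typed_function_pod
from orcabridge.store.arrow_data_stores import ParquetArrowDataStore

# Hypothetical base path for memoized results
result_store = ParquetArrowDataStore("./pod_results")

@typed_function_pod(output_keys=["summary"], result_store=result_store)
def summarize(data_path: str) -> str:
    # Must be a named function; the decorator rejects lambdas
    return f"summary of {data_path}"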
@@ -379,7 +387,7 @@ def typed_function_pod( FunctionPod instance wrapping the decorated function """ - def decorator(func) -> TypedFunctionPod: + def decorator(func) -> TypedFunctionPod | CachedFunctionPod: if func.__name__ == "<lambda>": raise ValueError("Lambda functions cannot be used with function_pod") @@ -398,14 +406,28 @@ def decorator(func) -> TypedFunctionPod: setattr(func, "__name__", new_function_name) setattr(func, "__qualname__", new_function_name) - # Create the FunctionPod + # Create a simple typed function pod pod = TypedFunctionPod( function=func, output_keys=output_keys, function_name=function_name or base_function_name, + label=label, **kwargs, ) + if result_store is not None: + pod = CachedFunctionPod( + function_pod=pod, + object_hasher=object_hasher + if object_hasher is not None + else get_default_object_hasher(), + arrow_hasher=arrow_hasher + if arrow_hasher is not None + else get_default_arrow_hasher(), + result_store=result_store, + tag_store=tag_store, + ) + return pod return decorator @@ -650,7 +672,7 @@ def __init__( self, function_pod: TypedFunctionPod, object_hasher: ObjectHasher, - packet_hasher: ArrowPacketHasher, + arrow_hasher: ArrowHasher, result_store: ArrowDataStore, tag_store: ArrowDataStore | None = None, label: str | None = None, @@ -664,7 +686,7 @@ def __init__( self.function_pod = function_pod self.object_hasher = object_hasher - self.packet_hasher = packet_hasher + self.arrow_hasher = arrow_hasher self.result_store = result_store self.tag_store = tag_store @@ -676,7 +698,7 @@ def __init__( self.function_pod_hash = self.object_hasher.hash_to_hex(self.function_pod) def get_packet_key(self, packet: Packet) -> str: - return self.packet_hasher.hash_arrow_packet( + return self.arrow_hasher.hash_table( self.function_pod.input_converter.to_arrow_table(packet) ) @@ -710,12 +732,20 @@ def _add_tag_record_with_packet_key(self, tag: Tag, packet_key: str) -> Tag: # convert tag to arrow table table = pa.Table.from_pylist([tag]) - entry_hash = self.packet_hasher.hash_arrow_packet(table) + entry_hash = self.arrow_hasher.hash_table(table) # TODO: add error handling - self.tag_store.add_record( - self.function_pod.function_name, self.function_pod_hash, entry_hash, table + # check if record already exists: + retrieved_table = self.tag_store.get_record( + self.function_pod.function_name, self.function_pod_hash, entry_hash ) + if retrieved_table is None: + self.tag_store.add_record( + self.function_pod.function_name, + self.function_pod_hash, + entry_hash, + table, + ) return tag @@ -804,5 +834,31 @@ def call(self, tag: Tag, packet: Packet) -> tuple[Tag, Packet]: return tag, output_packet + def get_all_entries_with_tags(self) -> pl.LazyFrame | None: + """ + Retrieve all entries from the tag store with their associated tags. + Returns a DataFrame with columns for tag and packet key.
+ """ + if self.tag_store is None: + raise ValueError("Tag store is not set, cannot retrieve entries") + + tag_records = self.tag_store.get_all_records_as_polars( + self.function_pod.function_name, self.function_pod_hash + ) + if tag_records is None: + return None + result_packets = self.result_store.get_records_by_ids_as_polars( + self.function_pod.function_name, + self.function_pod_hash, + tag_records.collect()["__packet_key"], + preserve_input_order=True, + ) + if result_packets is None: + return None + + return pl.concat([tag_records, result_packets], how="horizontal").drop( + ["__packet_key"] + ) + def identity_structure(self, *streams) -> Any: return self.function_pod.identity_structure(*streams) diff --git a/src/orcabridge/store/arrow_data_stores.py b/src/orcabridge/store/arrow_data_stores.py index b3f1c8e..2a2cc50 100644 --- a/src/orcabridge/store/arrow_data_stores.py +++ b/src/orcabridge/store/arrow_data_stores.py @@ -1,17 +1,13 @@ import pyarrow as pa import pyarrow.parquet as pq -import pyarrow.dataset as ds import polars as pl -import os -import json import threading -import time from pathlib import Path from typing import Any, cast from dataclasses import dataclass from datetime import datetime, timedelta import logging - +from orcabridge.store.types import DuplicateError # Module-level logger logger = logging.getLogger(__name__) @@ -199,7 +195,7 @@ def add_entry( # Ensure column order matches if existing_cols != new_cols: - logger.debug(f"Reordering columns to match existing schema") + logger.debug("Reordering columns to match existing schema") polars_table = polars_table.select(existing_cols) # Add new entry @@ -388,6 +384,15 @@ class ParquetArrowDataStore: - Single-row constraint: Each record must contain exactly one row """ + _system_columns = [ + "__source_name", + "__source_id", + "__entry_id", + "__created_at", + "__updated_at", + "__schema_hash", + ] + def __init__( self, base_path: str | Path, @@ -558,13 +563,28 @@ def _add_system_columns( ) -> pa.Table: """Add system columns to track record metadata.""" # Keep all system columns for self-describing data + # Use large_string for all string columns + large_string_type = pa.large_string() + system_columns = [ - ("__source_name", pa.array([metadata.source_name] * len(table))), - ("__source_id", pa.array([metadata.source_id] * len(table))), - ("__entry_id", pa.array([metadata.entry_id] * len(table))), + ( + "__source_name", + pa.array([metadata.source_name] * len(table), type=large_string_type), + ), + ( + "__source_id", + pa.array([metadata.source_id] * len(table), type=large_string_type), + ), + ( + "__entry_id", + pa.array([metadata.entry_id] * len(table), type=large_string_type), + ), ("__created_at", pa.array([metadata.created_at] * len(table))), ("__updated_at", pa.array([metadata.updated_at] * len(table))), - ("__schema_hash", pa.array([metadata.schema_hash] * len(table))), + ( + "__schema_hash", + pa.array([metadata.schema_hash] * len(table), type=large_string_type), + ), ] # Combine user columns + system columns in consistent order @@ -579,16 +599,7 @@ def _add_system_columns( def _remove_system_columns(self, table: pa.Table) -> pa.Table: """Remove system columns to get original user data.""" - system_cols = [ - "__source_name", - "__source_id", - "__entry_id", - "__created_at", - "__updated_at", - "__schema_hash", - ] - user_columns = [name for name in table.column_names if name not in system_cols] - return table.select(user_columns) + return table.drop(self._system_columns) def add_record( self, source_name: 
str, source_id: str, entry_id: str, arrow_data: pa.Table @@ -610,6 +621,9 @@ def add_record( ValueError: If arrow_data contains more than 1 row ValueError: If arrow_data schema doesn't match existing data for this source """ + # normalize arrow_data to conform to polars string. TODO: consider a clearner approach + arrow_data = pl.DataFrame(arrow_data).to_arrow() + # CRITICAL: Enforce single-row constraint if len(arrow_data) != 1: raise ValueError( @@ -664,7 +678,7 @@ def add_record( entry_exists = existing_metadata is not None if entry_exists and self.duplicate_entry_behavior == "error": - raise ValueError( + raise DuplicateError( f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " f"Use duplicate_entry_behavior='overwrite' to allow updates." ) @@ -716,7 +730,9 @@ def get_record( return self._remove_system_columns(table) - def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + def get_all_records( + self, source_name: str, source_id: str, _keep_system_columns: bool = False + ) -> pa.Table | None: """Retrieve all records for a given source as a single Arrow table.""" cache = self._get_or_create_source_cache(source_name, source_id) table = cache.get_all_entries() @@ -724,10 +740,12 @@ def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: if table is None: return None + if _keep_system_columns: + return table return self._remove_system_columns(table) def get_all_records_as_polars( - self, source_name: str, source_id: str + self, source_name: str, source_id: str, _keep_system_columns: bool = False ) -> pl.LazyFrame | None: """Retrieve all records for a given source as a Polars LazyFrame.""" cache = self._get_or_create_source_cache(source_name, source_id) @@ -736,11 +754,141 @@ def get_all_records_as_polars( if lazy_frame is None: return None - # Remove system columns - system_cols = ["__entry_id", "__created_at", "__updated_at", "__schema_hash"] - user_columns = [col for col in lazy_frame.columns if col not in system_cols] + if _keep_system_columns: + return lazy_frame + + return lazy_frame.drop(self._system_columns) + + def get_records_by_ids( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pa.Table | None: + """ + Retrieve multiple records by their entry_ids as a single Arrow table. + + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_ids: Entry IDs to retrieve. Can be: + - list[str]: List of entry ID strings + - pl.Series: Polars Series containing entry IDs + - pa.Array: PyArrow Array containing entry IDs + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + preserve_input_order: If True, return results in the same order as input entry_ids, + with null rows for missing entries. If False, return in storage order. 
+ + Returns: + Arrow table containing all found records, or None if no records found + When preserve_input_order=True, table length equals input length + When preserve_input_order=False, records are in storage order + """ + # Get Polars result using the Polars method + polars_result = self.get_records_by_ids_as_polars( + source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order + ) + + if polars_result is None: + return None + + # Convert to Arrow table + return polars_result.collect().to_arrow() + + def get_records_by_ids_as_polars( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + """ + Retrieve multiple records by their entry_ids as a Polars LazyFrame. + + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_ids: Entry IDs to retrieve. Can be: + - list[str]: List of entry ID strings + - pl.Series: Polars Series containing entry IDs + - pa.Array: PyArrow Array containing entry IDs + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + preserve_input_order: If True, return results in the same order as input entry_ids, + with null rows for missing entries. If False, return in storage order. + + Returns: + Polars LazyFrame containing all found records, or None if no records found + When preserve_input_order=True, frame length equals input length + When preserve_input_order=False, records are in storage order (existing behavior) + """ + # Convert input to Polars Series + if isinstance(entry_ids, list): + if not entry_ids: + return None + entry_ids_series = pl.Series("entry_id", entry_ids) + elif isinstance(entry_ids, pl.Series): + if len(entry_ids) == 0: + return None + entry_ids_series = entry_ids + elif isinstance(entry_ids, pa.Array): + if len(entry_ids) == 0: + return None + entry_ids_series = pl.Series( + "entry_id", entry_ids + ) # Direct from Arrow array + else: + raise TypeError( + f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" + ) + + cache = self._get_or_create_source_cache(source_name, source_id) + lazy_frame = cache.get_all_entries_as_polars() + + if lazy_frame is None: + return None + + # Define system columns that are always excluded (except optionally __entry_id) + system_cols = [ + "__source_name", + "__source_id", + "__created_at", + "__updated_at", + "__schema_hash", + ] + + # Add __entry_id to system columns if we don't want it in the result + if add_entry_id_column is False: + system_cols.append("__entry_id") + + # Handle input order preservation vs filtering + if preserve_input_order: + # Create ordered DataFrame with input IDs and join to preserve order with nulls + ordered_df = pl.DataFrame({"__entry_id": entry_ids_series}).lazy() + # Join with all data to get results in input order with nulls for missing + result_frame = ordered_df.join(lazy_frame, on="__entry_id", how="left") + else: + # Standard filtering approach for storage order -- should be faster in general + result_frame = lazy_frame.filter( + pl.col("__entry_id").is_in(entry_ids_series) + ) + + # Apply column selection (same for both paths) + result_frame = result_frame.drop(system_cols) + + # Rename __entry_id column if custom name provided + if isinstance(add_entry_id_column, str): + result_frame 
= result_frame.rename({"__entry_id": add_entry_id_column}) - return lazy_frame.select(user_columns) + return result_frame def _sync_all_dirty_caches(self) -> None: """Sync all dirty caches to disk.""" @@ -911,7 +1059,7 @@ def create_multi_row_record(entry_id: str, num_rows: int = 3) -> pa.Table: for i, entry_id in enumerate(valid_entries): data = create_single_row_record(entry_id, value=100.0 + i) - result = store.add_record("experiments", "dataset_A", entry_id, data) + store.add_record("experiments", "dataset_A", entry_id, data) print( f"✓ Added single-row record {entry_id[:16]}... (value: {100.0 + i})" ) diff --git a/src/orcabridge/store/types.py b/src/orcabridge/store/types.py index 912a2d1..444149d 100644 --- a/src/orcabridge/store/types.py +++ b/src/orcabridge/store/types.py @@ -5,6 +5,10 @@ import polars as pl +class DuplicateError(ValueError): + pass + + @runtime_checkable class DataStore(Protocol): """ @@ -56,3 +60,25 @@ def get_all_records_as_polars( ) -> pl.LazyFrame | None: """Retrieve all records for a given source as a single Polars DataFrame.""" ... + + def get_records_by_ids( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pa.Table | None: + """Retrieve records by entry IDs as a single table.""" + ... + + def get_records_by_ids_as_polars( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + """Retrieve records by entry IDs as a single Polars DataFrame.""" + ... From 12679b7bf489ead2df756e4d37733bd9cfe2758a Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 18 Jun 2025 20:26:24 +0000 Subject: [PATCH 25/28] perf: improved import time by lazyloading nonessential dependencies --- src/orcabridge/__init__.py | 2 +- src/orcabridge/core/tracker.py | 5 ++-- src/orcabridge/hashing/string_cachers.py | 31 +++++++++++------------- src/orcabridge/pod/core.py | 4 +++ src/orcabridge/types/inference.py | 4 ++- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/orcabridge/__init__.py b/src/orcabridge/__init__.py index 12ba536..a84492a 100644 --- a/src/orcabridge/__init__.py +++ b/src/orcabridge/__init__.py @@ -5,7 +5,7 @@ from .pod import FunctionPod, function_pod from .core.sources import GlobSource from .store import DirDataStore, SafeDirDataStore -from .pipeline.pipeline import GraphTracker +from .core.tracker import GraphTracker DEFAULT_TRACKER = GraphTracker() DEFAULT_TRACKER.activate() diff --git a/src/orcabridge/core/tracker.py b/src/orcabridge/core/tracker.py index b36e095..6e3afa9 100644 --- a/src/orcabridge/core/tracker.py +++ b/src/orcabridge/core/tracker.py @@ -1,6 +1,4 @@ from orcabridge.core.base import Invocation, Kernel, Tracker -import networkx as nx -import matplotlib.pyplot as plt class GraphTracker(Tracker): @@ -45,6 +43,7 @@ def generate_namemap(self) -> dict[Invocation, str]: return namemap def generate_graph(self): + import networkx as nx G = nx.DiGraph() # Add edges for each invocation @@ -59,6 +58,8 @@ def generate_graph(self): return G def draw_graph(self): + import networkx as nx + import matplotlib.pyplot as plt G = self.generate_graph() labels = self.generate_namemap() diff --git a/src/orcabridge/hashing/string_cachers.py b/src/orcabridge/hashing/string_cachers.py index 817aa44..6ae69a6 100644 --- a/src/orcabridge/hashing/string_cachers.py +++ 
b/src/orcabridge/hashing/string_cachers.py @@ -4,22 +4,14 @@ import sqlite3 import threading from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TYPE_CHECKING from orcabridge.hashing.types import StringCacher logger = logging.getLogger(__name__) -REDIS_AVAILABLE = False if TYPE_CHECKING: import redis -else: - try: - import redis - - REDIS_AVAILABLE = True - except ImportError: - redis = None class TransferCacher(StringCacher): @@ -612,9 +604,14 @@ def __init__( password: Redis password (used if connection is None) socket_timeout: Socket timeout in seconds """ - if not REDIS_AVAILABLE: - raise ImportError("redis package is required for RedisCacher") - + # TODO: cleanup the redis use pattern + try: + import redis + self._redis_module = redis + except ImportError as e: + raise ImportError( + "Redis module not found. Please install the 'redis' package." + ) from e self.key_prefix = key_prefix self._connection_failed = False self._lock = threading.RLock() @@ -623,7 +620,7 @@ def __init__( if connection is not None: self.redis = connection else: - self.redis = redis.Redis( + self.redis = self._redis_module.Redis( host=host, port=port, db=db, @@ -657,7 +654,7 @@ def _test_connection(self) -> None: f"Redis connection established successfully with prefix '{self.key_prefix}'" ) - except (redis.RedisError, redis.ConnectionError) as e: + except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: logging.error(f"Failed to establish Redis connection: {e}") raise RuntimeError(f"Redis connection test failed: {e}") @@ -689,7 +686,7 @@ def get_cached(self, cache_key: str) -> str | None: return str(result) - except (redis.RedisError, redis.ConnectionError) as e: + except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: self._handle_redis_error("get", e) return None @@ -707,7 +704,7 @@ def set_cached(self, cache_key: str, value: str) -> None: self.redis.set(self._get_prefixed_key(cache_key), value) - except (redis.RedisError, redis.ConnectionError) as e: + except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: self._handle_redis_error("set", e) def clear_cache(self) -> None: @@ -721,7 +718,7 @@ def clear_cache(self) -> None: if keys: self.redis.delete(*list(keys)) # type: ignore[arg-type] - except (redis.RedisError, redis.ConnectionError) as e: + except (self._redis_module.RedisError, self._redis_module.ConnectionError) as e: self._handle_redis_error("clear", e) def is_connected(self) -> bool: diff --git a/src/orcabridge/pod/core.py b/src/orcabridge/pod/core.py index ffca708..364bc90 100644 --- a/src/orcabridge/pod/core.py +++ b/src/orcabridge/pod/core.py @@ -754,6 +754,7 @@ def retrieve_memoized(self, packet: Packet) -> Packet | None: Retrieve a memoized packet from the data store. Returns None if no memoized packet is found. """ + logger.info("Retrieving memoized packet") return self._retrieve_memoized_by_hash(self.get_packet_key(packet)) def _retrieve_memoized_by_hash(self, packet_hash: str) -> Packet | None: @@ -761,6 +762,7 @@ def _retrieve_memoized_by_hash(self, packet_hash: str) -> Packet | None: Retrieve a memoized result packet from the data store, looking up by hash Returns None if no memoized packet is found. """ + logger.info(f"Retrieving memoized packet with hash {packet_hash}") arrow_table = self.result_store.get_record( self.function_pod.function_name, self.function_pod_hash, @@ -784,6 +786,7 @@ def memoize( Memoize the output packet in the data store. 
Returns the memoized packet. """ + logger.info("Memoizing packet") return self._memoize_by_hash(self.get_packet_key(packet), output_packet) def _memoize_by_hash(self, packet_hash: str, output_packet: Packet) -> Packet: @@ -791,6 +794,7 @@ def _memoize_by_hash(self, packet_hash: str, output_packet: Packet) -> Packet: Memoize the output packet in the data store, looking up by hash. Returns the memoized packet. """ + logger.info(f"Memoizing packet with hash {packet_hash}") packets = self.function_pod.output_converter.from_arrow_table( self.result_store.add_record( self.function_pod.function_name, diff --git a/src/orcabridge/types/inference.py b/src/orcabridge/types/inference.py index 72a54de..2f18f39 100644 --- a/src/orcabridge/types/inference.py +++ b/src/orcabridge/types/inference.py @@ -6,7 +6,7 @@ from .core import TypeSpec import inspect import logging -from beartype.door import is_bearable, is_subhint + logger = logging.getLogger(__name__) @@ -14,6 +14,7 @@ def verify_against_typespec(packet: dict, typespec: TypeSpec) -> bool: """Verify that the dictionary's types match the expected types in the typespec.""" + from beartype.door import is_bearable # verify that packet contains no keys not in typespec if set(packet.keys()) - set(typespec.keys()): logger.warning( @@ -38,6 +39,7 @@ def verify_against_typespec(packet: dict, typespec: TypeSpec) -> bool: def check_typespec_compatibility( incoming_types: TypeSpec, receiving_types: TypeSpec ) -> bool: + from beartype.door import is_subhint for key, type_info in incoming_types.items(): if key not in receiving_types: logger.warning(f"Key '{key}' not found in parameter types.") From 727701047e10d941c5f1c8a026d6858987a3b1eb Mon Sep 17 00:00:00 2001 From: "Edgar Y. Walker" Date: Wed, 18 Jun 2025 20:26:36 +0000 Subject: [PATCH 26/28] feat: add mock and in memory arrow data store --- src/orcabridge/store/arrow_data_stores.py | 509 ++++++++++++++++++++++ 1 file changed, 509 insertions(+) diff --git a/src/orcabridge/store/arrow_data_stores.py b/src/orcabridge/store/arrow_data_stores.py index 2a2cc50..1e866d5 100644 --- a/src/orcabridge/store/arrow_data_stores.py +++ b/src/orcabridge/store/arrow_data_stores.py @@ -13,6 +13,515 @@ logger = logging.getLogger(__name__) +class MockArrowDataStore: + """ + Mock Arrow data store for testing purposes. + This class simulates the behavior of ParquetArrowDataStore without actually saving anything. + It is useful for unit tests where you want to avoid filesystem dependencies. + """ + + def __init__(self): + logger.info("Initialized MockArrowDataStore") + + def add_record(self, + source_name: str, + source_id: str, + entry_id: str, + arrow_data: pa.Table) -> pa.Table: + """Add a record to the mock store.""" + return arrow_data + + def get_record(self, source_name: str, + source_id: str, + entry_id: str) -> pa.Table | None: + """Get a specific record.""" + return None + + def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + """Retrieve all records for a given source as a single table.""" + return None + + def get_all_records_as_polars( + self, source_name: str, source_id: str + ) -> pl.LazyFrame | None: + """Retrieve all records for a given source as a single Polars LazyFrame.""" + return None + + def get_records_by_ids( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pa.Table | None: + """ + Retrieve records by entry IDs as a single table. 
+ + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_ids: Entry IDs to retrieve. Can be: + - list[str]: List of entry ID strings + - pl.Series: Polars Series containing entry IDs + - pa.Array: PyArrow Array containing entry IDs + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + preserve_input_order: If True, return results in the same order as input entry_ids, + with null rows for missing entries. If False, return in storage order. + + Returns: + Arrow table containing all found records, or None if no records found + """ + return None + + def get_records_by_ids_as_polars( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + return None + + + + +class InMemoryArrowDataStore: + """ + In-memory Arrow data store for testing purposes. + This class simulates the behavior of ParquetArrowDataStore without actual file I/O. + It is useful for unit tests where you want to avoid filesystem dependencies. + + Uses dict of dict of Arrow tables for efficient storage and retrieval. + """ + + def __init__(self, duplicate_entry_behavior: str = "error"): + """ + Initialize the InMemoryArrowDataStore. + + Args: + duplicate_entry_behavior: How to handle duplicate entry_ids: + - 'error': Raise ValueError when entry_id already exists + - 'overwrite': Replace existing entry with new data + """ + # Validate duplicate behavior + if duplicate_entry_behavior not in ["error", "overwrite"]: + raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") + self.duplicate_entry_behavior = duplicate_entry_behavior + + # Store Arrow tables: {source_key: {entry_id: arrow_table}} + self._in_memory_store: dict[str, dict[str, pa.Table]] = {} + logger.info(f"Initialized InMemoryArrowDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'") + + def _get_source_key(self, source_name: str, source_id: str) -> str: + """Generate key for source storage.""" + return f"{source_name}:{source_id}" + + def add_record( + self, + source_name: str, + source_id: str, + entry_id: str, + arrow_data: pa.Table, + ) -> pa.Table: + """ + Add a record to the in-memory store. + + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_id: Unique identifier for this record + arrow_data: The Arrow table data to store + + Returns: + The original arrow_data table + + Raises: + ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' + """ + source_key = self._get_source_key(source_name, source_id) + + # Initialize source if it doesn't exist + if source_key not in self._in_memory_store: + self._in_memory_store[source_key] = {} + + local_data = self._in_memory_store[source_key] + + # Check for duplicate entry + if entry_id in local_data and self.duplicate_entry_behavior == "error": + raise ValueError( + f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " + f"Use duplicate_entry_behavior='overwrite' to allow updates." 
+ ) + + # Store the record + local_data[entry_id] = arrow_data + + action = "Updated" if entry_id in local_data else "Added" + logger.debug(f"{action} record {entry_id} in {source_key}") + return arrow_data + + def get_record( + self, source_name: str, source_id: str, entry_id: str + ) -> pa.Table | None: + """Get a specific record.""" + source_key = self._get_source_key(source_name, source_id) + local_data = self._in_memory_store.get(source_key, {}) + return local_data.get(entry_id) + + def get_all_records(self, source_name: str, source_id: str) -> pa.Table | None: + """Retrieve all records for a given source as a single table.""" + source_key = self._get_source_key(source_name, source_id) + local_data = self._in_memory_store.get(source_key, {}) + + if not local_data: + return None + + tables_with_keys = [] + for key, table in local_data.items(): + # Add entry_id column to each table + key_array = pa.array([key] * len(table), type=pa.string()) + table_with_key = table.add_column(0, "__entry_id", key_array) + tables_with_keys.append(table_with_key) + + # Concatenate all tables + if tables_with_keys: + return pa.concat_tables(tables_with_keys) + return None + + def get_all_records_as_polars( + self, source_name: str, source_id: str + ) -> pl.LazyFrame | None: + """Retrieve all records for a given source as a single Polars LazyFrame.""" + all_records = self.get_all_records(source_name, source_id) + if all_records is None: + return None + return pl.LazyFrame(all_records) + + def get_records_by_ids( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pa.Table | None: + """ + Retrieve records by entry IDs as a single table. + + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_ids: Entry IDs to retrieve. Can be: + - list[str]: List of entry ID strings + - pl.Series: Polars Series containing entry IDs + - pa.Array: PyArrow Array containing entry IDs + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + preserve_input_order: If True, return results in the same order as input entry_ids, + with null rows for missing entries. If False, return in storage order. 
+ + Returns: + Arrow table containing all found records, or None if no records found + """ + # Convert input to list of strings for consistency + if isinstance(entry_ids, list): + if not entry_ids: + return None + entry_ids_list = entry_ids + elif isinstance(entry_ids, pl.Series): + if len(entry_ids) == 0: + return None + entry_ids_list = entry_ids.to_list() + elif isinstance(entry_ids, pa.Array): + if len(entry_ids) == 0: + return None + entry_ids_list = entry_ids.to_pylist() + else: + raise TypeError( + f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" + ) + + source_key = self._get_source_key(source_name, source_id) + local_data = self._in_memory_store.get(source_key, {}) + + if not local_data: + return None + + # Collect matching tables + found_tables = [] + found_entry_ids = [] + + if preserve_input_order: + # Preserve input order, include nulls for missing entries + first_table_schema = None + + for entry_id in entry_ids_list: + if entry_id in local_data: + table = local_data[entry_id] + # Add entry_id column + key_array = pa.array([entry_id] * len(table), type=pa.string()) + table_with_key = table.add_column(0, "__entry_id", key_array) + found_tables.append(table_with_key) + found_entry_ids.append(entry_id) + + # Store schema for creating null rows + if first_table_schema is None: + first_table_schema = table_with_key.schema + else: + # Create a null row with the same schema as other tables + if first_table_schema is not None: + # Create null row + null_data = {} + for field in first_table_schema: + if field.name == "__entry_id": + null_data[field.name] = pa.array([entry_id], type=field.type) + else: + # Create null array with proper type + null_array = pa.array([None], type=field.type) + null_data[field.name] = null_array + + null_table = pa.table(null_data, schema=first_table_schema) + found_tables.append(null_table) + found_entry_ids.append(entry_id) + else: + # Storage order (faster) - only include existing entries + for entry_id in entry_ids_list: + if entry_id in local_data: + table = local_data[entry_id] + # Add entry_id column + key_array = pa.array([entry_id] * len(table), type=pa.string()) + table_with_key = table.add_column(0, "__entry_id", key_array) + found_tables.append(table_with_key) + found_entry_ids.append(entry_id) + + if not found_tables: + return None + + # Concatenate all found tables + if len(found_tables) == 1: + combined_table = found_tables[0] + else: + combined_table = pa.concat_tables(found_tables) + + # Handle entry_id column based on add_entry_id_column parameter + if add_entry_id_column is False: + # Remove the __entry_id column + column_names = combined_table.column_names + if "__entry_id" in column_names: + indices_to_keep = [i for i, name in enumerate(column_names) if name != "__entry_id"] + combined_table = combined_table.select(indices_to_keep) + elif isinstance(add_entry_id_column, str): + # Rename __entry_id to custom name + schema = combined_table.schema + new_names = [add_entry_id_column if name == "__entry_id" else name for name in schema.names] + combined_table = combined_table.rename_columns(new_names) + # If add_entry_id_column is True, keep __entry_id as is + + return combined_table + + def get_records_by_ids_as_polars( + self, + source_name: str, + source_id: str, + entry_ids: list[str] | pl.Series | pa.Array, + add_entry_id_column: bool | str = False, + preserve_input_order: bool = False, + ) -> pl.LazyFrame | None: + """ + Retrieve records by entry IDs as a single Polars LazyFrame. 
+ + Args: + source_name: Name of the data source + source_id: ID of the specific dataset within the source + entry_ids: Entry IDs to retrieve. Can be: + - list[str]: List of entry ID strings + - pl.Series: Polars Series containing entry IDs + - pa.Array: PyArrow Array containing entry IDs + add_entry_id_column: Control entry ID column inclusion: + - False: Don't include entry ID column (default) + - True: Include entry ID column as "__entry_id" + - str: Include entry ID column with custom name + preserve_input_order: If True, return results in the same order as input entry_ids, + with null rows for missing entries. If False, return in storage order. + + Returns: + Polars LazyFrame containing all found records, or None if no records found + """ + # Get Arrow result and convert to Polars + arrow_result = self.get_records_by_ids( + source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order + ) + + if arrow_result is None: + return None + + # Convert to Polars LazyFrame + return pl.LazyFrame(arrow_result) + + def save_to_parquet(self, base_path: str | Path) -> None: + """ + Save all data to Parquet files in a directory structure. + + Directory structure: base_path/source_name/source_id/data.parquet + + Args: + base_path: Base directory path where to save the Parquet files + """ + base_path = Path(base_path) + base_path.mkdir(parents=True, exist_ok=True) + + saved_count = 0 + + for source_key, local_data in self._in_memory_store.items(): + if not local_data: + continue + + # Parse source_name and source_id from the key + if ":" not in source_key: + logger.warning(f"Invalid source key format: {source_key}, skipping") + continue + + source_name, source_id = source_key.split(":", 1) + + # Create directory structure + source_dir = base_path / source_name / source_id + source_dir.mkdir(parents=True, exist_ok=True) + + # Combine all tables for this source with entry_id column + tables_with_keys = [] + for entry_id, table in local_data.items(): + # Add entry_id column to each table + key_array = pa.array([entry_id] * len(table), type=pa.string()) + table_with_key = table.add_column(0, "__entry_id", key_array) + tables_with_keys.append(table_with_key) + + # Concatenate all tables + if tables_with_keys: + combined_table = pa.concat_tables(tables_with_keys) + + # Save as Parquet file + parquet_path = source_dir / "data.parquet" + import pyarrow.parquet as pq + pq.write_table(combined_table, parquet_path) + + saved_count += 1 + logger.debug(f"Saved {len(combined_table)} records for {source_key} to {parquet_path}") + + logger.info(f"Saved {saved_count} sources to Parquet files in {base_path}") + + def load_from_parquet(self, base_path: str | Path) -> None: + """ + Load data from Parquet files with the expected directory structure. 
+ + Expected structure: base_path/source_name/source_id/data.parquet + + Args: + base_path: Base directory path containing the Parquet files + """ + base_path = Path(base_path) + + if not base_path.exists(): + logger.warning(f"Base path {base_path} does not exist") + return + + # Clear existing data + self._in_memory_store.clear() + + loaded_count = 0 + + # Traverse directory structure: source_name/source_id/ + for source_name_dir in base_path.iterdir(): + if not source_name_dir.is_dir(): + continue + + source_name = source_name_dir.name + + for source_id_dir in source_name_dir.iterdir(): + if not source_id_dir.is_dir(): + continue + + source_id = source_id_dir.name + source_key = self._get_source_key(source_name, source_id) + + # Look for Parquet files in this directory + parquet_files = list(source_id_dir.glob("*.parquet")) + + if not parquet_files: + logger.debug(f"No Parquet files found in {source_id_dir}") + continue + + # Load all Parquet files and combine them + all_records = [] + + for parquet_file in parquet_files: + try: + import pyarrow.parquet as pq + table = pq.read_table(parquet_file) + + # Validate that __entry_id column exists + if "__entry_id" not in table.column_names: + logger.warning(f"Parquet file {parquet_file} missing __entry_id column, skipping") + continue + + all_records.append(table) + logger.debug(f"Loaded {len(table)} records from {parquet_file}") + + except Exception as e: + logger.error(f"Failed to load Parquet file {parquet_file}: {e}") + continue + + # Process all records for this source + if all_records: + # Combine all tables + if len(all_records) == 1: + combined_table = all_records[0] + else: + combined_table = pa.concat_tables(all_records) + + # Split back into individual records by entry_id + local_data = {} + entry_ids = combined_table.column("__entry_id").to_pylist() + + # Group records by entry_id + entry_id_groups = {} + for i, entry_id in enumerate(entry_ids): + if entry_id not in entry_id_groups: + entry_id_groups[entry_id] = [] + entry_id_groups[entry_id].append(i) + + # Extract each entry_id's records + for entry_id, indices in entry_id_groups.items(): + # Take rows for this entry_id and remove __entry_id column + entry_table = combined_table.take(indices) + + # Remove __entry_id column + column_names = entry_table.column_names + if "__entry_id" in column_names: + indices_to_keep = [i for i, name in enumerate(column_names) if name != "__entry_id"] + entry_table = entry_table.select(indices_to_keep) + + local_data[entry_id] = entry_table + + self._in_memory_store[source_key] = local_data + loaded_count += 1 + + record_count = len(combined_table) + unique_entries = len(entry_id_groups) + logger.debug(f"Loaded {record_count} records ({unique_entries} unique entries) for {source_key}") + + logger.info(f"Loaded {loaded_count} sources from Parquet files in {base_path}") + + # Log summary of loaded data + total_records = sum(len(local_data) for local_data in self._in_memory_store.values()) + logger.info(f"Total records loaded: {total_records}") + @dataclass class RecordMetadata: """Metadata for a stored record.""" From 0ce5aa3d94aa823d9c107e0f22efa51ccdb9ea75 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 18 Jun 2025 21:06:49 +0000 Subject: [PATCH 27/28] test: update redis test to work with new lazy loading --- src/orcabridge/hashing/string_cachers.py | 20 ++-- .../test_string_cacher/test_redis_cacher.py | 105 +++++++----------- 2 files changed, 52 insertions(+), 73 deletions(-) diff --git a/src/orcabridge/hashing/string_cachers.py b/src/orcabridge/hashing/string_cachers.py index 6ae69a6..0db7af4 100644 --- a/src/orcabridge/hashing/string_cachers.py +++ b/src/orcabridge/hashing/string_cachers.py @@ -13,6 +13,14 @@ if TYPE_CHECKING: import redis +def _get_redis(): + """Lazy import for Redis to avoid circular dependencies.""" + try: + import redis + return redis + except ImportError as e: + return None + class TransferCacher(StringCacher): """ @@ -605,13 +613,9 @@ def __init__( socket_timeout: Socket timeout in seconds """ # TODO: cleanup the redis use pattern - try: - import redis - self._redis_module = redis - except ImportError as e: - raise ImportError( - "Redis module not found. Please install the 'redis' package." - ) from e + self._redis_module = _get_redis() + if self._redis_module is None: + raise ImportError("Could not import Redis module. redis package is required for RedisCacher") self.key_prefix = key_prefix self._connection_failed = False self._lock = threading.RLock() @@ -648,7 +652,7 @@ def _test_connection(self) -> None: self.redis.delete(test_key) if result != "test": - raise redis.RedisError("Failed to verify key access") + raise self._redis_module.RedisError("Failed to verify key access") logging.info( f"Redis connection established successfully with prefix '{self.key_prefix}'" diff --git a/tests/test_hashing/test_string_cacher/test_redis_cacher.py b/tests/test_hashing/test_string_cacher/test_redis_cacher.py index ac04b82..060fb61 100644 --- a/tests/test_hashing/test_string_cacher/test_redis_cacher.py +++ b/tests/test_hashing/test_string_cacher/test_redis_cacher.py @@ -1,12 +1,15 @@ """Tests for RedisCacher using mocked Redis.""" -from typing import cast -from unittest.mock import patch +from typing import cast, TYPE_CHECKING +from unittest.mock import patch, MagicMock import pytest from orcabridge.hashing.string_cachers import RedisCacher +if TYPE_CHECKING: + import redis + # Mock Redis exceptions class MockRedisError(Exception): @@ -65,15 +68,25 @@ def keys(self, pattern): return [key for key in self.data.keys() if key.startswith(prefix)] return [key for key in self.data.keys() if key == pattern] +class MockRedisModule: + ConnectionError = MockConnectionError + RedisError = MockRedisError + Redis = MagicMock(return_value=MockRedis()) # Simple one-liner! 
+ + + +def mock_get_redis(): + return MockRedisModule + +def mock_no_redis(): + return None + + class TestRedisCacher: """Test cases for RedisCacher with mocked Redis.""" - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_basic_operations(self): """Test basic get/set/clear operations.""" mock_redis = MockRedis() @@ -100,7 +113,7 @@ def test_basic_operations(self): assert cacher.get_cached("key1") is None assert cacher.get_cached("key2") is None - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_key_prefixing(self): """Test that keys are properly prefixed.""" mock_redis = MockRedis() @@ -115,7 +128,7 @@ def test_key_prefixing(self): # But retrieval should work without prefix assert cacher.get_cached("key1") == "value1" - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_connection_initialization_success(self): """Test successful connection initialization.""" mock_redis = MockRedis() @@ -130,11 +143,7 @@ def test_connection_initialization_success(self): assert mock_redis.ping_called assert cacher.is_connected() - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_connection_initialization_failure(self): """Test connection initialization failure.""" mock_redis = MockRedis(fail_connection=True) @@ -142,17 +151,15 @@ def test_connection_initialization_failure(self): with pytest.raises(RuntimeError, match="Redis connection test failed"): RedisCacher(connection=mock_redis, key_prefix="test:") - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.Redis") - def test_new_connection_creation(self, mock_redis_class): + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) + def test_new_connection_creation(self): """Test creation of new Redis connection when none provided.""" - mock_instance = MockRedis() - mock_redis_class.return_value = mock_instance - cacher = RedisCacher(host="localhost", port=6379, db=0, key_prefix="test:") # Verify Redis was called with correct parameters - mock_redis_class.assert_called_once_with( + # Get the mock module to verify calls + mock_module = mock_get_redis() + mock_module.Redis.assert_called_with( host="localhost", port=6379, db=0, @@ -164,11 +171,7 @@ def test_new_connection_creation(self, mock_redis_class): assert cacher.is_connected() - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_graceful_failure_on_operations(self): """Test graceful failure when Redis operations fail during use.""" mock_redis = MockRedis() @@ -190,11 +193,7 @@ def test_graceful_failure_on_operations(self): 
mock_log.assert_called_once() assert "Redis get failed" in str(mock_log.call_args) - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_set_failure_handling(self): """Test handling of set operation failures.""" mock_redis = MockRedis() @@ -209,11 +208,7 @@ def test_set_failure_handling(self): assert "Redis set failed" in str(mock_log.call_args) assert not cacher.is_connected() - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_clear_cache_failure_handling(self): """Test handling of clear cache operation failures.""" mock_redis = MockRedis() @@ -231,7 +226,7 @@ def test_clear_cache_failure_handling(self): assert "Redis clear failed" in str(mock_log.call_args) assert not cacher.is_connected() - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_clear_cache_with_pattern_matching(self): """Test that clear_cache only removes keys with the correct prefix.""" mock_redis = MockRedis() @@ -249,11 +244,7 @@ def test_clear_cache_with_pattern_matching(self): assert "test:key2" not in mock_redis.data assert "other:key1" in mock_redis.data # Should remain - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_connection_reset(self): """Test connection reset functionality.""" mock_redis = MockRedis() @@ -274,11 +265,7 @@ def test_connection_reset(self): # Check that the reset message was logged (it should be the last call) mock_log.assert_called_with("Redis connection successfully reset") - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_connection_reset_failure(self): """Test connection reset failure handling.""" mock_redis = MockRedis() @@ -300,11 +287,7 @@ def test_connection_reset_failure(self): "Failed to reset Redis connection: Redis connection test failed: Connection failed" ) - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_error_logging_only_once(self): """Test that errors are only logged once per failure.""" mock_redis = MockRedis() @@ -322,7 +305,7 @@ def test_error_logging_only_once(self): # Should only log the first error assert mock_log.call_count == 1 - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) + 
@patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_default_key_prefix(self): """Test default key prefix behavior.""" mock_redis = MockRedis() @@ -337,15 +320,11 @@ def test_default_key_prefix(self): def test_redis_not_available(self): """Test behavior when redis package is not available.""" - with patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", False): + with patch("orcabridge.hashing.string_cachers._get_redis", mock_no_redis): with pytest.raises(ImportError, match="redis package is required"): RedisCacher() - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_connection_test_key_access_failure(self): """Test failure when connection test can't create/access test key.""" @@ -361,7 +340,7 @@ def get(self, key): with pytest.raises(RuntimeError, match="Redis connection test failed"): RedisCacher(connection=mock_redis, key_prefix="test:") - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_thread_safety(self): """Test thread safety of Redis operations.""" import threading @@ -419,11 +398,7 @@ def worker(thread_id: int): expected = f"thread{thread_id}_value{i}" assert result == expected - @patch("orcabridge.hashing.string_cachers.REDIS_AVAILABLE", True) - @patch("orcabridge.hashing.string_cachers.redis.RedisError", MockRedisError) - @patch( - "orcabridge.hashing.string_cachers.redis.ConnectionError", MockConnectionError - ) + @patch("orcabridge.hashing.string_cachers._get_redis", mock_get_redis) def test_operations_after_connection_failure(self): """Test that operations return None/do nothing after connection failure.""" mock_redis = MockRedis() From e8fc7a9fc0ff81a2118eb5292a33c32c5e581626 Mon Sep 17 00:00:00 2001 From: "Edgar Y. 
Walker" Date: Wed, 18 Jun 2025 21:07:09 +0000 Subject: [PATCH 28/28] refactor: remote nonfunctional tests for stream operations --- tests/test_streams_operations/__init__.py | 0 tests/test_streams_operations/conftest.py | 204 ------ .../test_mappers/__init__.py | 0 .../test_mappers/test_batch.py | 290 -------- .../test_mappers/test_cache_stream.py | 299 -------- .../test_mappers/test_default_tag.py | 260 ------- .../test_mappers/test_filter.py | 325 --------- .../test_mappers/test_first_match.py | 244 ------- .../test_mappers/test_group_by.py | 298 -------- .../test_mappers/test_join.py | 198 ------ .../test_mappers/test_map_packets.py | 273 -------- .../test_mappers/test_map_tags.py | 330 --------- .../test_mappers/test_merge.py | 208 ------ .../test_mappers/test_repeat.py | 186 ----- .../test_mappers/test_transform.py | 363 ---------- .../test_mappers/test_utility_functions.py | 248 ------- .../test_pipelines/__init__.py | 0 .../test_pipelines/test_basic_pipelines.py | 542 --------------- .../test_pipelines/test_recursive_features.py | 637 ------------------ .../test_pods/__init__.py | 0 .../test_pods/test_function_pod.py | 305 --------- .../test_pods/test_function_pod_datastore.py | 403 ----------- .../test_pods/test_pod_base.py | 274 -------- .../test_sources/__init__.py | 0 .../test_sources/test_glob_source.py | 325 --------- .../test_streams/__init__.py | 0 .../test_streams/test_base_classes.py | 514 -------------- .../test_sync_stream_implementations.py | 578 ---------------- 28 files changed, 7304 deletions(-) delete mode 100644 tests/test_streams_operations/__init__.py delete mode 100644 tests/test_streams_operations/conftest.py delete mode 100644 tests/test_streams_operations/test_mappers/__init__.py delete mode 100644 tests/test_streams_operations/test_mappers/test_batch.py delete mode 100644 tests/test_streams_operations/test_mappers/test_cache_stream.py delete mode 100644 tests/test_streams_operations/test_mappers/test_default_tag.py delete mode 100644 tests/test_streams_operations/test_mappers/test_filter.py delete mode 100644 tests/test_streams_operations/test_mappers/test_first_match.py delete mode 100644 tests/test_streams_operations/test_mappers/test_group_by.py delete mode 100644 tests/test_streams_operations/test_mappers/test_join.py delete mode 100644 tests/test_streams_operations/test_mappers/test_map_packets.py delete mode 100644 tests/test_streams_operations/test_mappers/test_map_tags.py delete mode 100644 tests/test_streams_operations/test_mappers/test_merge.py delete mode 100644 tests/test_streams_operations/test_mappers/test_repeat.py delete mode 100644 tests/test_streams_operations/test_mappers/test_transform.py delete mode 100644 tests/test_streams_operations/test_mappers/test_utility_functions.py delete mode 100644 tests/test_streams_operations/test_pipelines/__init__.py delete mode 100644 tests/test_streams_operations/test_pipelines/test_basic_pipelines.py delete mode 100644 tests/test_streams_operations/test_pipelines/test_recursive_features.py delete mode 100644 tests/test_streams_operations/test_pods/__init__.py delete mode 100644 tests/test_streams_operations/test_pods/test_function_pod.py delete mode 100644 tests/test_streams_operations/test_pods/test_function_pod_datastore.py delete mode 100644 tests/test_streams_operations/test_pods/test_pod_base.py delete mode 100644 tests/test_streams_operations/test_sources/__init__.py delete mode 100644 tests/test_streams_operations/test_sources/test_glob_source.py delete mode 100644 
tests/test_streams_operations/test_streams/__init__.py delete mode 100644 tests/test_streams_operations/test_streams/test_base_classes.py delete mode 100644 tests/test_streams_operations/test_streams/test_sync_stream_implementations.py diff --git a/tests/test_streams_operations/__init__.py b/tests/test_streams_operations/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_streams_operations/conftest.py b/tests/test_streams_operations/conftest.py deleted file mode 100644 index b6420a3..0000000 --- a/tests/test_streams_operations/conftest.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Shared fixtures for streams and operations testing. -""" - -import tempfile -import json -import numpy as np -from pathlib import Path -from typing import Any, Iterator -import pytest - -from orcabridge.types import Tag, Packet -from orcabridge.streams import SyncStreamFromLists -from orcabridge.store import DirDataStore - - -@pytest.fixture -def temp_dir(): - """Create a temporary directory for testing.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield Path(tmpdir) - - -@pytest.fixture -def sample_tags(): - """Sample tags for testing.""" - return [ - {"file_name": "day1", "session": "morning"}, - {"file_name": "day2", "session": "afternoon"}, - {"file_name": "day3", "session": "evening"}, - ] - - -@pytest.fixture -def sample_packets(): - """Sample packets for testing.""" - return [ - {"txt_file": "data/day1.txt", "metadata": "meta1.json"}, - {"txt_file": "data/day2.txt", "metadata": "meta2.json"}, - {"txt_file": "data/day3.txt", "metadata": "meta3.json"}, - ] - - -@pytest.fixture -def sample_stream(sample_tags, sample_packets): - """Create a sample stream from tags and packets.""" - return SyncStreamFromLists( - tags=sample_tags, - packets=sample_packets, - tag_keys=["file_name", "session"], - packet_keys=["txt_file", "metadata"], - ) - - -@pytest.fixture -def empty_stream() -> SyncStreamFromLists: - """Create an empty stream.""" - return SyncStreamFromLists(paired=[]) - - -@pytest.fixture -def single_item_stream() -> SyncStreamFromLists: - """Create a stream with a single item.""" - return SyncStreamFromLists(tags=[{"name": "single"}], packets=[{"data": "value"}]) - - -@pytest.fixture -def test_files(temp_dir) -> dict[str, Any]: - """Create test files for source testing.""" - # Create text files - txt_dir = temp_dir / "txt_files" - txt_dir.mkdir() - - txt_files = [] - for i, day in enumerate(["day1", "day2", "day3"], 1): - txt_file = txt_dir / f"{day}.txt" - txt_file.write_text(f"Content for {day}\n" * (i * 5)) - txt_files.append(txt_file) - - # Create binary files with numpy arrays - bin_dir = temp_dir / "bin_files" - bin_dir.mkdir() - - bin_files = [] - for i, session in enumerate(["session_day1", "session_day2"], 1): - bin_file = bin_dir / f"{session}.bin" - data = np.random.rand(10 * i).astype(np.float64) - bin_file.write_bytes(data.tobytes()) - bin_files.append(bin_file) - - # Create json files - json_dir = temp_dir / "json_files" - json_dir.mkdir() - - json_files = [] - for i, info in enumerate(["info_day1", "info_day2"], 1): - json_file = json_dir / f"{info}.json" - data = {"lines": i * 5, "day": f"day{i}", "processed": False} - json_file.write_text(json.dumps(data)) - json_files.append(json_file) - - return { - "txt_dir": txt_dir, - "txt_files": txt_files, - "bin_dir": bin_dir, - "bin_files": bin_files, - "json_dir": json_dir, - "json_files": json_files, - } - - -@pytest.fixture -def data_store(temp_dir) -> DirDataStore: - """Create a test data store.""" - store_dir 
= temp_dir / "data_store" - return DirDataStore(store_dir=store_dir) - - -# Sample functions for FunctionPod testing - - -def sample_function_no_output(input_file: str) -> None: - """Sample function that takes input but returns nothing.""" - pass - - -def sample_function_single_output(input_file: str) -> str: - """Sample function that returns a single output.""" - return str(Path(input_file).with_suffix(".processed")) - - -def sample_function_multiple_outputs(input_file: str) -> tuple[str, str]: - """Sample function that returns multiple outputs.""" - base = Path(input_file).stem - return f"{base}_output1.txt", f"{base}_output2.txt" - - -def sample_function_with_error(input_file: str) -> str: - """Sample function that raises an error.""" - raise ValueError("Intentional error for testing") - - -def count_lines_function(txt_file: str) -> int: - """Function that counts lines in a text file.""" - with open(txt_file, "r") as f: - return len(f.readlines()) - - -def compute_stats_function(bin_file: str, temp_dir: str | None = None) -> str: - """Function that computes statistics on binary data.""" - import tempfile - - with open(bin_file, "rb") as f: - data = np.frombuffer(f.read(), dtype=np.float64) - - stats = { - "mean": float(np.mean(data)), - "std": float(np.std(data)), - "min": float(np.min(data)), - "max": float(np.max(data)), - "count": len(data), - } - - if temp_dir is None: - output_file = Path(tempfile.mkdtemp()) / "stats.json" - else: - output_file = Path(temp_dir) / "stats.json" - - with open(output_file, "w") as f: - json.dump(stats, f) - - return str(output_file) - - -# Predicate functions for Filter testing - - -def filter_by_session_morning(tag: Tag, packet: Packet) -> bool: - """Filter predicate that keeps only morning sessions.""" - return tag.get("session") == "morning" - - -def filter_by_filename_pattern(tag: Tag, packet: Packet) -> bool: - """Filter predicate that keeps files matching a pattern.""" - return "day1" in tag.get("file_name", "") # type: ignore - - -# Transform functions - - -def transform_add_prefix(tag: Tag, packet: Packet) -> tuple[Tag, Packet]: - """Transform that adds prefix to file_name tag.""" - new_tag = tag.copy() - if "file_name" in new_tag: - new_tag["file_name"] = f"prefix_{new_tag['file_name']}" - return new_tag, packet - - -def transform_rename_keys(tag: Tag, packet: Packet) -> tuple[Tag, Packet]: - """Transform that renames packet keys.""" - new_packet = packet.copy() - if "txt_file" in new_packet: - new_packet["content"] = new_packet.pop("txt_file") - return tag, new_packet diff --git a/tests/test_streams_operations/test_mappers/__init__.py b/tests/test_streams_operations/test_mappers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_streams_operations/test_mappers/test_batch.py b/tests/test_streams_operations/test_mappers/test_batch.py deleted file mode 100644 index b30701e..0000000 --- a/tests/test_streams_operations/test_mappers/test_batch.py +++ /dev/null @@ -1,290 +0,0 @@ -"""Tests for Batch mapper functionality.""" - -import pytest -from orcabridge.mappers import Batch -from orcabridge.streams import SyncStreamFromLists - - -class TestBatch: - """Test cases for Batch mapper.""" - - def test_batch_basic(self, sample_tags, sample_packets): - """Test basic batch functionality.""" - stream = SyncStreamFromLists(sample_tags, sample_packets) - batch = Batch(2, drop_last=False) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have 2 batches: [packet1, packet2] and [packet3] - assert 
len(result) == 2 - - batch1_tag, batch1_packet = result[0] - batch2_tag, batch2_packet = result[1] - - # First batch should have 2 items - assert len(batch1_packet["txt_file"]) == 2 - for k, v in batch1_packet.items(): - assert v == [p[k] for p in sample_packets[:2]] - - assert len(batch2_packet["txt_file"]) == 1 - for k, v in batch2_packet.items(): - assert v == [p[k] for p in sample_packets[2:]] - - def test_batch_exact_division(self): - """Test batch when stream length divides evenly by batch size.""" - packets = [1, 2, 3, 4, 5, 6] - tags = ["a", "b", "c", "d", "e", "f"] - - stream = SyncStreamFromLists(packets, tags) - batch = Batch(3) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have exactly 2 batches - assert len(result) == 2 - - batch1_packet, _ = result[0] - batch2_packet, _ = result[1] - - assert len(batch1_packet) == 3 - assert len(batch2_packet) == 3 - assert list(batch1_packet) == [1, 2, 3] - assert list(batch2_packet) == [4, 5, 6] - - def test_batch_size_one(self, sample_packets, sample_tags): - """Test batch with size 1.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - batch = Batch(1) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have same number of batches as original packets - assert len(result) == len(sample_packets) - - for i, (batch_packet, batch_tag) in enumerate(result): - assert len(batch_packet) == 1 - assert list(batch_packet) == [sample_packets[i]] - - def test_batch_larger_than_stream(self, sample_packets, sample_tags): - """Test batch size larger than stream.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - batch = Batch(10) # Larger than sample_packets length - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have exactly 1 batch with all packets - assert len(result) == 1 - - batch_packet, batch_tag = result[0] - assert len(batch_packet) == len(sample_packets) - assert list(batch_packet) == sample_packets - - def test_batch_empty_stream(self): - """Test batch with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - batch = Batch(3) - batched_stream = batch(empty_stream) - - result = list(batched_stream) - assert len(result) == 0 - - def test_batch_preserves_packet_types(self): - """Test that batch preserves different packet types.""" - packets = [PacketType("data1"), {"key": "value"}, [1, 2, 3], 42, "string"] - tags = ["type1", "type2", "type3", "type4", "type5"] - - stream = SyncStreamFromLists(packets, tags) - batch = Batch(2) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have 3 batches: [2, 2, 1] - assert len(result) == 3 - - # Check first batch - batch1_packet, _ = result[0] - batch1_list = list(batch1_packet) - assert batch1_list[0] == PacketType("data1") - assert batch1_list[1] == {"key": "value"} - - # Check second batch - batch2_packet, _ = result[1] - batch2_list = list(batch2_packet) - assert batch2_list[0] == [1, 2, 3] - assert batch2_list[1] == 42 - - # Check third batch - batch3_packet, _ = result[2] - batch3_list = list(batch3_packet) - assert batch3_list[0] == "string" - - def test_batch_tag_handling(self, sample_packets, sample_tags): - """Test how batch handles tags.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - batch = Batch(2) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Each batch should have some representation of the constituent tags - for batch_packet, batch_tag in result: - assert batch_tag is not None - # The exact 
format depends on implementation - - def test_batch_maintains_order(self): - """Test that batch maintains packet order within batches.""" - packets = [f"packet_{i}" for i in range(10)] - tags = [f"tag_{i}" for i in range(10)] - - stream = SyncStreamFromLists(packets, tags) - batch = Batch(3) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have 4 batches: [3, 3, 3, 1] - assert len(result) == 4 - - # Check order within each batch - all_packets = [] - for batch_packet, _ in result: - all_packets.extend(list(batch_packet)) - - assert all_packets == packets - - def test_batch_large_stream(self): - """Test batch with large stream.""" - packets = [f"packet_{i}" for i in range(1000)] - tags = [f"tag_{i}" for i in range(1000)] - - stream = SyncStreamFromLists(packets, tags) - batch = Batch(50) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have exactly 20 batches of 50 each - assert len(result) == 20 - - for i, (batch_packet, _) in enumerate(result): - assert len(batch_packet) == 50 - expected_packets = packets[i * 50 : (i + 1) * 50] - assert list(batch_packet) == expected_packets - - def test_batch_invalid_size(self): - """Test batch with invalid size.""" - with pytest.raises(ValueError): - Batch(0) - - with pytest.raises(ValueError): - Batch(-1) - - with pytest.raises(TypeError): - Batch(3.5) - - with pytest.raises(TypeError): - Batch("3") - - def test_batch_chaining(self, sample_packets, sample_tags): - """Test chaining batch operations.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - - # First batch: size 2 - batch1 = Batch(2) - stream1 = batch1(stream) - - # Second batch: size 1 (batch the batches) - batch2 = Batch(1) - stream2 = batch2(stream1) - - result = list(stream2) - - # Each item should be a batch containing a single batch - for batch_packet, _ in result: - assert len(batch_packet) == 1 - # The contained item should itself be a batch - - def test_batch_with_generator_stream(self): - """Test batch with generator-based stream.""" - - def packet_generator(): - for i in range(7): - yield f"packet_{i}", f"tag_{i}" - - from orcabridge.stream import SyncStreamFromGenerator - - stream = SyncStreamFromGenerator(packet_generator()) - - batch = Batch(3) - batched_stream = batch(stream) - - result = list(batched_stream) - - # Should have 3 batches: [3, 3, 1] - assert len(result) == 3 - - batch1_packet, _ = result[0] - batch2_packet, _ = result[1] - batch3_packet, _ = result[2] - - assert len(batch1_packet) == 3 - assert len(batch2_packet) == 3 - assert len(batch3_packet) == 1 - - def test_batch_memory_efficiency(self): - """Test that batch doesn't consume excessive memory.""" - # Create a large stream - packets = [f"packet_{i}" for i in range(10000)] - tags = [f"tag_{i}" for i in range(10000)] - - stream = SyncStreamFromLists(packets, tags) - batch = Batch(100) - batched_stream = batch(stream) - - # Process one batch at a time to test memory efficiency - batch_count = 0 - for batch_packet, _ in batched_stream: - batch_count += 1 - assert len(batch_packet) <= 100 - if batch_count == 50: # Stop early to avoid processing everything - break - - assert batch_count == 50 - - def test_batch_with_none_packets(self): - """Test batch with None packets.""" - packets = [1, None, 3, None, 5, None] - tags = ["num1", "null1", "num3", "null2", "num5", "null3"] - - stream = SyncStreamFromLists(packets, tags) - batch = Batch(2) - batched_stream = batch(stream) - - result = list(batched_stream) - - assert len(result) == 3 - - # Check 
that None values are preserved - all_packets = [] - for batch_packet, _ in result: - all_packets.extend(list(batch_packet)) - - assert all_packets == packets - - def test_batch_pickle(self): - """Test that Batch mapper is pickleable.""" - import pickle - from orcabridge.mappers import Batch - - batch = Batch(batch_size=3) - pickled = pickle.dumps(batch) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, Batch) - assert unpickled.batch_size == batch.batch_size diff --git a/tests/test_streams_operations/test_mappers/test_cache_stream.py b/tests/test_streams_operations/test_mappers/test_cache_stream.py deleted file mode 100644 index feefb61..0000000 --- a/tests/test_streams_operations/test_mappers/test_cache_stream.py +++ /dev/null @@ -1,299 +0,0 @@ -""" -Test module for CacheStream mapper. - -This module tests the CacheStream mapper functionality, which provides -caching capabilities to avoid upstream recomputation by storing stream data -in memory after the first iteration. -""" - -import pytest -from unittest.mock import Mock - -from orcabridge.base import SyncStream -from orcabridge.mapper import CacheStream -from orcabridge.stream import SyncStreamFromLists - - -@pytest.fixture -def cache_mapper(): - """Create a CacheStream mapper instance.""" - return CacheStream() - - -@pytest.fixture -def sample_stream_data(): - """Sample stream data for testing.""" - return [ - ({"id": 1}, {"value": 10}), - ({"id": 2}, {"value": 20}), - ({"id": 3}, {"value": 30}), - ] - - -@pytest.fixture -def sample_stream(sample_stream_data): - """Create a sample stream.""" - tags, packets = zip(*sample_stream_data) - return SyncStreamFromLists(list(tags), list(packets)) - - -class TestCacheStream: - """Test cases for CacheStream mapper.""" - - def test_cache_initialization(self, cache_mapper): - """Test that CacheStream initializes with empty cache.""" - assert cache_mapper.cache == [] - assert cache_mapper.is_cached is False - - def test_repr(self, cache_mapper): - """Test CacheStream string representation.""" - assert repr(cache_mapper) == "CacheStream(active:False)" - - # After caching - cache_mapper.is_cached = True - assert repr(cache_mapper) == "CacheStream(active:True)" - - def test_first_iteration_caches_data(self, cache_mapper, sample_stream): - """Test that first iteration through stream caches the data.""" - cached_stream = cache_mapper(sample_stream) - - # Initially not cached - assert not cache_mapper.is_cached - assert len(cache_mapper.cache) == 0 - - # Iterate through stream - result = list(cached_stream) - - # After iteration, should be cached - assert cache_mapper.is_cached - assert len(cache_mapper.cache) == 3 - assert cache_mapper.cache == [ - ({"id": 1}, {"value": 10}), - ({"id": 2}, {"value": 20}), - ({"id": 3}, {"value": 30}), - ] - - # Result should match original stream - assert result == [ - ({"id": 1}, {"value": 10}), - ({"id": 2}, {"value": 20}), - ({"id": 3}, {"value": 30}), - ] - - def test_subsequent_iterations_use_cache(self, cache_mapper, sample_stream): - """Test that subsequent iterations use cached data.""" - cached_stream = cache_mapper(sample_stream) - - # First iteration - first_result = list(cached_stream) - assert cache_mapper.is_cached - - # Create new stream from same mapper (simulates reuse) - second_cached_stream = cache_mapper() # No input streams for cached version - second_result = list(second_cached_stream) - - # Results should be identical - assert first_result == second_result - assert second_result == 
[ - ({"id": 1}, {"value": 10}), - ({"id": 2}, {"value": 20}), - ({"id": 3}, {"value": 30}), - ] - - def test_clear_cache(self, cache_mapper, sample_stream): - """Test cache clearing functionality.""" - cached_stream = cache_mapper(sample_stream) - - # Cache some data - list(cached_stream) - assert cache_mapper.is_cached - assert len(cache_mapper.cache) > 0 - - # Clear cache - cache_mapper.clear_cache() - assert not cache_mapper.is_cached - assert len(cache_mapper.cache) == 0 - - def test_multiple_streams_error_when_not_cached(self, cache_mapper, sample_stream): - """Test that providing multiple streams raises error when not cached.""" - stream2 = SyncStreamFromLists([{"id": 4}], [{"value": 40}]) - - with pytest.raises( - ValueError, match="CacheStream operation requires exactly one stream" - ): - cache_mapper(sample_stream, stream2) - - def test_no_streams_when_cached(self, cache_mapper, sample_stream): - """Test that cached stream can be called without input streams.""" - # First, cache some data - cached_stream = cache_mapper(sample_stream) - list(cached_stream) # This caches the data - - # Now call without streams (should use cache) - cached_only_stream = cache_mapper() - result = list(cached_only_stream) - - assert result == [ - ({"id": 1}, {"value": 10}), - ({"id": 2}, {"value": 20}), - ({"id": 3}, {"value": 30}), - ] - - def test_empty_stream_caching(self, cache_mapper): - """Test caching behavior with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - cached_stream = cache_mapper(empty_stream) - - result = list(cached_stream) - - assert result == [] - assert cache_mapper.is_cached - assert cache_mapper.cache == [] - - def test_identity_structure(self, cache_mapper, sample_stream): - """Test that CacheStream has unique identity structure.""" - # CacheStream should return None for identity structure - # to treat every instance as different - assert cache_mapper.identity_structure(sample_stream) is None - - def test_avoids_upstream_recomputation(self, cache_mapper): - """Test that CacheStream avoids upstream recomputation.""" - # Create a mock stream that tracks how many times it's iterated - iteration_count = {"count": 0} - - def counting_generator(): - iteration_count["count"] += 1 - yield ({"id": 1}, {"value": 10}) - yield ({"id": 2}, {"value": 20}) - - mock_stream = Mock(spec=SyncStream) - mock_stream.__iter__ = counting_generator - - cached_stream = cache_mapper(mock_stream) - - # First iteration should call upstream - list(cached_stream) - assert iteration_count["count"] == 1 - - # Second iteration should use cache (not call upstream) - second_cached_stream = cache_mapper() - list(second_cached_stream) - assert iteration_count["count"] == 1 # Should still be 1 - - def test_cache_with_different_data_types(self, cache_mapper): - """Test caching with various data types.""" - complex_data = [ - ({"id": 1, "type": "string"}, {"data": "hello", "numbers": [1, 2, 3]}), - ({"id": 2, "type": "dict"}, {"data": {"nested": True}, "numbers": None}), - ({"id": 3, "type": "boolean"}, {"data": True, "numbers": 42}), - ] - - tags, packets = zip(*complex_data) - stream = SyncStreamFromLists(list(tags), list(packets)) - cached_stream = cache_mapper(stream) - - result = list(cached_stream) - - assert result == complex_data - assert cache_mapper.is_cached - assert cache_mapper.cache == complex_data - - def test_multiple_cache_instances(self, sample_stream): - """Test that different CacheStream instances have separate caches.""" - cache1 = CacheStream() - cache2 = CacheStream() - - # Cache 
in first instance - cached_stream1 = cache1(sample_stream) - list(cached_stream1) - - # Second instance should not be cached - assert cache1.is_cached - assert not cache2.is_cached - assert len(cache1.cache) == 3 - assert len(cache2.cache) == 0 - - def test_keys_method(self, cache_mapper, sample_stream): - """Test that CacheStream passes through keys correctly.""" - # CacheStream should inherit keys from input stream - tag_keys, packet_keys = cache_mapper.keys(sample_stream) - original_tag_keys, original_packet_keys = sample_stream.keys() - - assert tag_keys == original_tag_keys - assert packet_keys == original_packet_keys - - def test_chaining_with_cache(self, cache_mapper, sample_stream): - """Test chaining CacheStream with other operations.""" - from orcabridge.mapper import Filter - - # Chain cache with filter - filter_mapper = Filter(lambda tag, packet: tag["id"] > 1) - - # Cache first, then filter - cached_stream = cache_mapper(sample_stream) - filtered_stream = filter_mapper(cached_stream) - - result = list(filtered_stream) - - assert len(result) == 2 # Should have filtered out id=1 - assert result == [ - ({"id": 2}, {"value": 20}), - ({"id": 3}, {"value": 30}), - ] - - # Cache should still be populated with original data - assert cache_mapper.is_cached - assert len(cache_mapper.cache) == 3 - - def test_cache_persistence_across_multiple_outputs( - self, cache_mapper, sample_stream - ): - """Test that cache persists when creating multiple output streams.""" - # First stream - stream1 = cache_mapper(sample_stream) - result1 = list(stream1) - - # Second stream from same cache - stream2 = cache_mapper() - result2 = list(stream2) - - # Third stream from same cache - stream3 = cache_mapper() - result3 = list(stream3) - - # All results should be identical - assert result1 == result2 == result3 - assert len(result1) == 3 - - def test_error_handling_during_caching(self, cache_mapper): - """Test error handling when upstream stream raises exception.""" - - def error_generator(): - yield ({"id": 1}, {"value": 10}) - raise ValueError("Upstream error") - - mock_stream = Mock(spec=SyncStream) - mock_stream.__iter__ = error_generator - - cached_stream = cache_mapper(mock_stream) - - # Should propagate the error and not cache partial data - with pytest.raises(ValueError, match="Upstream error"): - list(cached_stream) - - # Cache should remain empty after error - assert not cache_mapper.is_cached - assert len(cache_mapper.cache) == 0 - - def test_cache_stream_pickle(self): - """Test that CacheStream mapper is pickleable.""" - import pickle - from orcabridge.mappers import CacheStream - - cache_stream = CacheStream() - pickled = pickle.dumps(cache_stream) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, CacheStream) - assert unpickled.__class__.__name__ == "CacheStream" diff --git a/tests/test_streams_operations/test_mappers/test_default_tag.py b/tests/test_streams_operations/test_mappers/test_default_tag.py deleted file mode 100644 index 281002b..0000000 --- a/tests/test_streams_operations/test_mappers/test_default_tag.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Tests for DefaultTag mapper functionality.""" - -import pytest -from orcabridge.base import PacketType -from orcabridge.mapper import DefaultTag -from orcabridge.stream import SyncStreamFromLists - - -class TestDefaultTag: - """Test cases for DefaultTag mapper.""" - - def test_default_tag_basic(self, sample_packets): - """Test basic default tag functionality.""" - tags = 
["existing1", None, "existing2"] - - stream = SyncStreamFromLists(sample_packets, tags) - default_tag = DefaultTag("default_value") - result_stream = default_tag(stream) - - result = list(result_stream) - - expected_tags = ["existing1", "default_value", "existing2"] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == sample_packets - assert actual_tags == expected_tags - - def test_default_tag_all_none(self, sample_packets): - """Test default tag when all tags are None.""" - tags = [None, None, None] - - stream = SyncStreamFromLists(sample_packets, tags) - default_tag = DefaultTag("fallback") - result_stream = default_tag(stream) - - result = list(result_stream) - - expected_tags = ["fallback", "fallback", "fallback"] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - - def test_default_tag_no_none(self, sample_packets, sample_tags): - """Test default tag when no tags are None.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - default_tag = DefaultTag("unused_default") - result_stream = default_tag(stream) - - result = list(result_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - # Should remain unchanged - assert actual_packets == sample_packets - assert actual_tags == sample_tags - - def test_default_tag_empty_stream(self): - """Test default tag with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - default_tag = DefaultTag("default") - result_stream = default_tag(empty_stream) - - result = list(result_stream) - assert len(result) == 0 - - def test_default_tag_different_types(self): - """Test default tag with different default value types.""" - packets = ["data1", "data2", "data3"] - tags = [None, "existing", None] - - # Test with string default - stream1 = SyncStreamFromLists(packets, tags) - default_tag1 = DefaultTag("string_default") - result1 = list(default_tag1(stream1)) - - expected_tags1 = ["string_default", "existing", "string_default"] - actual_tags1 = [tag for _, tag in result1] - assert actual_tags1 == expected_tags1 - - # Test with numeric default - stream2 = SyncStreamFromLists(packets, tags) - default_tag2 = DefaultTag(42) - result2 = list(default_tag2(stream2)) - - expected_tags2 = [42, "existing", 42] - actual_tags2 = [tag for _, tag in result2] - assert actual_tags2 == expected_tags2 - - def test_default_tag_empty_string_vs_none(self): - """Test default tag distinguishes between empty string and None.""" - packets = ["data1", "data2", "data3"] - tags = [None, "", None] # Empty string vs None - - stream = SyncStreamFromLists(packets, tags) - default_tag = DefaultTag("default") - result_stream = default_tag(stream) - - result = list(result_stream) - - # Empty string should be preserved, None should be replaced - expected_tags = ["default", "", "default"] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - - def test_default_tag_preserves_packets(self): - """Test that default tag preserves all packet types.""" - packets = [PacketType("data1"), {"key": "value"}, [1, 2, 3], 42, "string"] - tags = [None, None, "existing", None, None] - - stream = SyncStreamFromLists(packets, tags) - default_tag = DefaultTag("default") - result_stream = default_tag(stream) - - result = list(result_stream) - - actual_packets = [packet for packet, _ in result] - expected_tags = ["default", "default", "existing", "default", "default"] - actual_tags = [tag for _, tag in 
result] - - assert actual_packets == packets - assert actual_tags == expected_tags - - def test_default_tag_with_complex_default(self): - """Test default tag with complex default value.""" - packets = ["data1", "data2"] - tags = [None, "existing"] - - default_value = {"type": "default", "timestamp": 12345} - - stream = SyncStreamFromLists(packets, tags) - default_tag = DefaultTag(default_value) - result_stream = default_tag(stream) - - result = list(result_stream) - - expected_tags = [default_value, "existing"] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - assert actual_tags[0] is default_value # Should be the same object - - def test_default_tag_chaining(self, sample_packets): - """Test chaining multiple default tag operations.""" - tags = [None, "middle", None] - - stream = SyncStreamFromLists(sample_packets, tags) - - # First default tag - default_tag1 = DefaultTag("first_default") - stream1 = default_tag1(stream) - - # Create new stream with some None tags again - intermediate_result = list(stream1) - new_tags = [ - None if tag == "first_default" else tag for _, tag in intermediate_result - ] - new_packets = [packet for packet, _ in intermediate_result] - - stream2 = SyncStreamFromLists(new_packets, new_tags) - default_tag2 = DefaultTag("second_default") - stream3 = default_tag2(stream2) - - final_result = list(stream3) - - # The "middle" tag should be preserved - actual_tags = [tag for _, tag in final_result] - assert "middle" in actual_tags - assert "second_default" in actual_tags - - def test_default_tag_maintains_order(self): - """Test that default tag maintains packet order.""" - packets = [f"packet_{i}" for i in range(10)] - tags = [None if i % 2 == 0 else f"tag_{i}" for i in range(10)] - - stream = SyncStreamFromLists(packets, tags) - default_tag = DefaultTag("even_default") - result_stream = default_tag(stream) - - result = list(result_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == packets - - # Check that even indices got default tags, odd indices kept original - for i in range(10): - if i % 2 == 0: - assert actual_tags[i] == "even_default" - else: - assert actual_tags[i] == f"tag_{i}" - - def test_default_tag_with_callable_default(self): - """Test default tag with callable default (if supported).""" - packets = ["data1", "data2", "data3"] - tags = [None, "existing", None] - - # Simple callable that returns a counter - class DefaultGenerator: - def __init__(self): - self.count = 0 - - def __call__(self): - self.count += 1 - return f"default_{self.count}" - - # If the implementation supports callable defaults - try: - default_gen = DefaultGenerator() - stream = SyncStreamFromLists(packets, tags) - default_tag = DefaultTag(default_gen) - result_stream = default_tag(stream) - - result = list(result_stream) - actual_tags = [tag for _, tag in result] - - # This would only work if DefaultTag supports callable defaults - # Otherwise this test should be skipped or modified - assert "existing" in actual_tags - except (TypeError, AttributeError): - # If callable defaults are not supported, that's fine - pass - - def test_default_tag_large_stream(self): - """Test default tag with large stream.""" - packets = [f"packet_{i}" for i in range(1000)] - tags = [None if i % 3 == 0 else f"tag_{i}" for i in range(1000)] - - stream = SyncStreamFromLists(packets, tags) - default_tag = DefaultTag("bulk_default") - result_stream = default_tag(stream) - - result = 
list(result_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert len(actual_packets) == 1000 - assert len(actual_tags) == 1000 - - # Check that every third tag was replaced - for i in range(1000): - if i % 3 == 0: - assert actual_tags[i] == "bulk_default" - else: - assert actual_tags[i] == f"tag_{i}" - - def test_default_tag_pickle(self): - """Test that DefaultTag mapper is pickleable.""" - import pickle - from orcabridge.mappers import DefaultTag - - default_tag = DefaultTag({"default": "test"}) - pickled = pickle.dumps(default_tag) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, DefaultTag) - assert unpickled.default_tag == default_tag.default_tag diff --git a/tests/test_streams_operations/test_mappers/test_filter.py b/tests/test_streams_operations/test_mappers/test_filter.py deleted file mode 100644 index b16049d..0000000 --- a/tests/test_streams_operations/test_mappers/test_filter.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Tests for Filter mapper functionality.""" - -import pytest - -from orcabridge.base import PacketType - from orcabridge.mapper import Filter - from orcabridge.stream import SyncStreamFromLists - - -class TestFilter: - """Test cases for Filter mapper.""" - - def test_filter_basic(self, simple_predicate): - """Test basic filter functionality.""" - packets = [1, 2, 3, 4, 5, 6] - tags = ["odd", "even", "odd", "even", "odd", "even"] - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(simple_predicate) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - # Should keep only even numbers - expected_packets = [2, 4, 6] - expected_tags = ["even", "even", "even"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_filter_none_match(self, sample_packets, sample_tags): - """Test filter when no packets match.""" - - def never_matches(packet, tag): - return False - - stream = SyncStreamFromLists(sample_packets, sample_tags) - filter_mapper = Filter(never_matches) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - assert len(result) == 0 - - def test_filter_all_match(self, sample_packets, sample_tags): - """Test filter when all packets match.""" - - def always_matches(packet, tag): - return True - - stream = SyncStreamFromLists(sample_packets, sample_tags) - filter_mapper = Filter(always_matches) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == sample_packets - assert actual_tags == sample_tags - - def test_filter_empty_stream(self, simple_predicate): - """Test filter with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - filter_mapper = Filter(simple_predicate) - filtered_stream = filter_mapper(empty_stream) - - result = list(filtered_stream) - assert len(result) == 0 - - def test_filter_string_predicate(self): - """Test filter with string-based predicate.""" - packets = ["apple", "banana", "cherry", "date", "elderberry"] - tags = ["fruit1", "fruit2", "fruit3", "fruit4", "fruit5"] - - def starts_with_vowel(packet, tag): - return isinstance(packet, str) and packet[0].lower() in "aeiou" - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(starts_with_vowel) - 
filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - expected_packets = ["apple", "elderberry"] - expected_tags = ["fruit1", "fruit5"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_filter_tag_based_predicate(self): - """Test filter using tag information.""" - packets = [10, 20, 30, 40, 50] - tags = ["small", "medium", "large", "huge", "enormous"] - - def tag_length_filter(packet, tag): - return len(tag) <= 5 - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(tag_length_filter) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - expected_packets = [10, 40] # "small" and "huge" have <= 5 chars - expected_tags = ["small", "huge"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_filter_complex_predicate(self): - """Test filter with complex predicate.""" - packets = [ - {"value": 5, "type": "A", "active": True}, - {"value": 15, "type": "B", "active": False}, - {"value": 25, "type": "A", "active": True}, - {"value": 35, "type": "C", "active": True}, - {"value": 45, "type": "A", "active": False}, - ] - tags = ["item1", "item2", "item3", "item4", "item5"] - - def complex_predicate(packet, tag): - return ( - isinstance(packet, dict) - and packet.get("type") == "A" - and packet.get("active", False) - and packet.get("value", 0) > 10 - ) - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(complex_predicate) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - # Only the third item matches all conditions - expected_packets = [{"value": 25, "type": "A", "active": True}] - expected_tags = ["item3"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_filter_with_none_packets(self): - """Test filter with None packets.""" - packets = [None, "data", None, "more_data", None] - tags = ["empty1", "full1", "empty2", "full2", "empty3"] - - def not_none(packet, tag): - return packet is not None - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(not_none) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - expected_packets = ["data", "more_data"] - expected_tags = ["full1", "full2"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_filter_preserves_packet_types(self): - """Test that filter preserves packet types.""" - packets = [PacketType("data1"), [1, 2, 3], {"key": "value"}, "string", 42] - tags = ["type1", "type2", "type3", "type4", "type5"] - - def is_container(packet, tag): - return isinstance(packet, (list, dict)) - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(is_container) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - expected_packets = [[1, 2, 3], {"key": "value"}] - expected_tags = ["type2", "type3"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - assert 
isinstance(actual_packets[0], list) - assert isinstance(actual_packets[1], dict) - - def test_filter_maintains_order(self): - """Test that filter maintains packet order.""" - packets = [f"packet_{i}" for i in range(20)] - tags = [f"tag_{i}" for i in range(20)] - - def keep_even_indices(packet, tag): - # Extract index from packet name - index = int(packet.split("_")[1]) - return index % 2 == 0 - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(keep_even_indices) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - expected_packets = [f"packet_{i}" for i in range(0, 20, 2)] - expected_tags = [f"tag_{i}" for i in range(0, 20, 2)] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_filter_predicate_exception(self, sample_packets, sample_tags): - """Test filter when predicate raises exception.""" - - def error_predicate(packet, tag): - if packet == sample_packets[1]: # Error on second packet - raise ValueError("Predicate error") - return True - - stream = SyncStreamFromLists(sample_packets, sample_tags) - filter_mapper = Filter(error_predicate) - filtered_stream = filter_mapper(stream) - - # Should propagate the exception - with pytest.raises(ValueError): - list(filtered_stream) - - def test_filter_with_lambda(self): - """Test filter with lambda predicate.""" - packets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - tags = [f"num_{i}" for i in packets] - - stream = SyncStreamFromLists(packets, tags) - filter_mapper = Filter(lambda p, t: p % 3 == 0) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - expected_packets = [3, 6, 9] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_filter_chaining(self): - """Test chaining multiple filter operations.""" - packets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - tags = [f"num_{i}" for i in packets] - - stream = SyncStreamFromLists(packets, tags) - - # First filter: keep even numbers - filter1 = Filter(lambda p, t: p % 2 == 0) - stream1 = filter1(stream) - - # Second filter: keep numbers > 4 - filter2 = Filter(lambda p, t: p > 4) - stream2 = filter2(stream1) - - result = list(stream2) - - expected_packets = [6, 8, 10] # Even numbers > 4 - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_filter_with_generator_stream(self): - """Test filter with generator-based stream.""" - - def packet_generator(): - for i in range(20): - yield i, f"tag_{i}" - - from orcabridge.stream import SyncStreamFromGenerator - - stream = SyncStreamFromGenerator(packet_generator()) - - def is_prime(packet, tag): - if packet < 2: - return False - for i in range(2, int(packet**0.5) + 1): - if packet % i == 0: - return False - return True - - filter_mapper = Filter(is_prime) - filtered_stream = filter_mapper(stream) - - result = list(filtered_stream) - - # Prime numbers under 20: 2, 3, 5, 7, 11, 13, 17, 19 - expected_packets = [2, 3, 5, 7, 11, 13, 17, 19] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_filter_pickle(self): - """Test that Filter mapper is pickleable.""" - import pickle - from orcabridge.mappers import Filter - - def is_even(tag, packet): - return packet % 2 == 0 - - filter_mapper = Filter(is_even) - pickled = pickle.dumps(filter_mapper) - unpickled = pickle.loads(pickled) - - # Test that 
unpickled mapper works the same - assert isinstance(unpickled, Filter) - assert unpickled.__class__.__name__ == "Filter" diff --git a/tests/test_streams_operations/test_mappers/test_first_match.py b/tests/test_streams_operations/test_mappers/test_first_match.py deleted file mode 100644 index b282ebc..0000000 --- a/tests/test_streams_operations/test_mappers/test_first_match.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Tests for FirstMatch mapper functionality.""" - -import pytest -from orcabridge.base import PacketType -from orcabridge.mapper import FirstMatch -from orcabridge.stream import SyncStreamFromLists - - -class TestFirstMatch: - """Test cases for FirstMatch mapper.""" - - def test_first_match_basic(self, simple_predicate): - """Test basic first match functionality.""" - packets = [1, 2, 3, 4, 5] - tags = ["odd", "even", "odd", "even", "odd"] - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(simple_predicate) - result_stream = first_match(stream) - - result = list(result_stream) - - # Should find the first packet that matches the predicate - assert len(result) == 1 - packet, tag = result[0] - assert packet == 2 # First even number - assert tag == "even" - - def test_first_match_no_match(self, sample_packets, sample_tags): - """Test first match when no packet matches.""" - - def never_matches(packet, tag): - return False - - stream = SyncStreamFromLists(sample_packets, sample_tags) - first_match = FirstMatch(never_matches) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 0 - - def test_first_match_all_match(self, sample_packets, sample_tags): - """Test first match when all packets match.""" - - def always_matches(packet, tag): - return True - - stream = SyncStreamFromLists(sample_packets, sample_tags) - first_match = FirstMatch(always_matches) - result_stream = first_match(stream) - - result = list(result_stream) - - # Should return only the first packet - assert len(result) == 1 - packet, tag = result[0] - assert packet == sample_packets[0] - assert tag == sample_tags[0] - - def test_first_match_empty_stream(self, simple_predicate): - """Test first match with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - first_match = FirstMatch(simple_predicate) - result_stream = first_match(empty_stream) - - result = list(result_stream) - assert len(result) == 0 - - def test_first_match_string_predicate(self): - """Test first match with string-based predicate.""" - packets = ["apple", "banana", "cherry", "date"] - tags = ["fruit1", "fruit2", "fruit3", "fruit4"] - - def starts_with_c(packet, tag): - return isinstance(packet, str) and packet.startswith("c") - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(starts_with_c) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 1 - packet, tag = result[0] - assert packet == "cherry" - assert tag == "fruit3" - - def test_first_match_tag_based_predicate(self): - """Test first match using tag information.""" - packets = [10, 20, 30, 40] - tags = ["small", "medium", "large", "huge"] - - def tag_contains_e(packet, tag): - return "e" in tag - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(tag_contains_e) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 1 - packet, tag = result[0] - assert packet == 20 # "medium" contains 'e' - assert tag == "medium" - - def test_first_match_complex_predicate(self): - """Test first match with complex 
predicate.""" - packets = [ - {"value": 5, "type": "A"}, - {"value": 15, "type": "B"}, - {"value": 25, "type": "A"}, - {"value": 35, "type": "C"}, - ] - tags = ["item1", "item2", "item3", "item4"] - - def complex_predicate(packet, tag): - return ( - isinstance(packet, dict) - and packet.get("value", 0) > 10 - and packet.get("type") == "A" - ) - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(complex_predicate) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 1 - packet, tag = result[0] - assert packet == {"value": 25, "type": "A"} - assert tag == "item3" - - def test_first_match_with_none_packets(self): - """Test first match with None packets.""" - packets = [None, "data", None, "more_data"] - tags = ["empty1", "full1", "empty2", "full2"] - - def not_none(packet, tag): - return packet is not None - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(not_none) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 1 - packet, tag = result[0] - assert packet == "data" - assert tag == "full1" - - def test_first_match_preserves_packet_types(self): - """Test that first match preserves packet types.""" - packets = [PacketType("data1"), [1, 2, 3], {"key": "value"}, 42] - tags = ["str", "list", "dict", "int"] - - def is_list(packet, tag): - return isinstance(packet, list) - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(is_list) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 1 - packet, tag = result[0] - assert packet == [1, 2, 3] - assert tag == "list" - assert isinstance(packet, list) - - def test_first_match_predicate_exception(self, sample_packets, sample_tags): - """Test first match when predicate raises exception.""" - - def error_predicate(packet, tag): - if packet == sample_packets[1]: # Error on second packet - raise ValueError("Predicate error") - return packet == sample_packets[2] # Match third packet - - stream = SyncStreamFromLists(sample_packets, sample_tags) - first_match = FirstMatch(error_predicate) - result_stream = first_match(stream) - - # The behavior here depends on implementation - # It might propagate the exception or skip the problematic packet - with pytest.raises(ValueError): - list(result_stream) - - def test_first_match_with_generator_stream(self): - """Test first match with generator-based stream.""" - - def packet_generator(): - for i in range(10): - yield f"packet_{i}", f"tag_{i}" - - from orcabridge.stream import SyncStreamFromGenerator - - stream = SyncStreamFromGenerator(packet_generator()) - - def find_packet_5(packet, tag): - return packet == "packet_5" - - first_match = FirstMatch(find_packet_5) - result_stream = first_match(stream) - - result = list(result_stream) - assert len(result) == 1 - packet, tag = result[0] - assert packet == "packet_5" - assert tag == "tag_5" - - def test_first_match_early_termination(self): - """Test that first match terminates early and doesn't process remaining packets.""" - processed_packets = [] - - def tracking_predicate(packet, tag): - processed_packets.append(packet) - return packet == "target" - - packets = ["a", "b", "target", "c", "d"] - tags = ["tag1", "tag2", "tag3", "tag4", "tag5"] - - stream = SyncStreamFromLists(packets, tags) - first_match = FirstMatch(tracking_predicate) - result_stream = first_match(stream) - - result = list(result_stream) - - # Should have found the target - assert len(result) == 1 - assert 
result[0][0] == "target" - - # Should have stopped processing after finding the target - assert processed_packets == ["a", "b", "target"] - - def test_first_match_pickle(self): - """Test that FirstMatch mapper is pickleable.""" - import pickle - from orcabridge.mappers import FirstMatch - - first_match = FirstMatch() - pickled = pickle.dumps(first_match) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, FirstMatch) - assert unpickled.__class__.__name__ == "FirstMatch" diff --git a/tests/test_streams_operations/test_mappers/test_group_by.py b/tests/test_streams_operations/test_mappers/test_group_by.py deleted file mode 100644 index 1594498..0000000 --- a/tests/test_streams_operations/test_mappers/test_group_by.py +++ /dev/null @@ -1,298 +0,0 @@ -"""Tests for GroupBy mapper functionality.""" - -import pytest -import pickle -from orcabridge.mappers import GroupBy -from orcabridge.streams import SyncStreamFromLists - - -class TestGroupBy: - """Test cases for GroupBy mapper.""" - - def test_group_by_basic(self): - """Test basic groupby functionality.""" - tags = [ - {"category": "A", "id": "1"}, - {"category": "B", "id": "2"}, - {"category": "A", "id": "3"}, - {"category": "B", "id": "4"}, - ] - packets = [ - {"value": "data/item1.txt", "name": "metadata/item1.json"}, - {"value": "data/item2.txt", "name": "metadata/item2.json"}, - {"value": "data/item3.txt", "name": "metadata/item3.json"}, - {"value": "data/item4.txt", "name": "metadata/item4.json"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy(group_keys=["category"]) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - # Should have 2 groups (A and B) - assert len(results) == 2 - - # Check that all groups are present - categories_found = [] - for tag, _ in results: - categories_found.extend(tag["category"]) - categories = set(categories_found) - assert categories == {"A", "B"} - - # Check grouped data structure - # With reduce_keys=False (default), everything should be lists including group keys - for tag, packet in results: - if tag["category"] == ["A", "A"]: # Group key is also a list - assert tag["id"] == ["1", "3"] # IDs for category A - assert packet["value"] == [ - "data/item1.txt", - "data/item3.txt", - ] # Values for category A - assert packet["name"] == ["metadata/item1.json", "metadata/item3.json"] - elif tag["category"] == ["B", "B"]: # Group key is also a list - assert tag["id"] == ["2", "4"] # IDs for category B - assert packet["value"] == [ - "data/item2.txt", - "data/item4.txt", - ] # Values for category B - assert packet["name"] == ["metadata/item2.json", "metadata/item4.json"] - - def test_group_by_reduce_keys(self): - """Test groupby with reduce_keys=True.""" - tags = [ - {"category": "A", "id": "1", "extra": "x1"}, - {"category": "A", "id": "2", "extra": "x2"}, - {"category": "B", "id": "3", "extra": "x3"}, - ] - packets = [ - {"value": "data/item1.txt"}, - {"value": "data/item2.txt"}, - {"value": "data/item3.txt"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy(group_keys=["category"], reduce_keys=True) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - for tag, packet in results: - if tag["category"] == "A": - # With reduce_keys=True, group keys become singular values - assert tag["category"] == "A" - # Non-group keys become lists - assert tag["id"] == ["1", "2"] - assert tag["extra"] == ["x1", "x2"] - elif tag["category"] == "B": - 
assert tag["category"] == "B" - assert tag["id"] == ["3"] - assert tag["extra"] == ["x3"] - - def test_group_by_no_group_keys(self): - """Test groupby without specifying group_keys (uses all tag keys).""" - tags = [ - {"category": "A", "id": "1"}, - {"category": "A", "id": "1"}, # Duplicate - {"category": "B", "id": "2"}, - ] - packets = [ - {"value": "data/item1.txt"}, - {"value": "data/item2.txt"}, - {"value": "data/item3.txt"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy() # No group_keys specified - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - # Should group by all tag keys (category, id) - assert len(results) == 2 # (A,1) and (B,2) - - # Extract group keys, accounting for lists in the results - group_keys = set() - for tag, _ in results: - # When reduce_keys=False, all values are lists - category_list = tag["category"] - id_list = tag["id"] - # Since this groups by exact matches, each group should have same values - # We'll take the first value from each list to represent the group - group_keys.add((category_list[0], id_list[0])) - assert group_keys == {("A", "1"), ("B", "2")} - - def test_group_by_with_selection_function(self): - """Test groupby with selection function.""" - tags = [ - {"category": "A", "priority": "1"}, - {"category": "A", "priority": "2"}, - {"category": "A", "priority": "3"}, - ] - packets = [ - {"value": "data/item1.txt"}, - {"value": "data/item2.txt"}, - {"value": "data/item3.txt"}, - ] - - # Selection function that only keeps items with priority >= 2 - def select_high_priority(grouped_items): - return [int(tag["priority"]) >= 2 for tag, packet in grouped_items] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy( - group_keys=["category"], selection_function=select_high_priority - ) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - assert len(results) == 1 - tag, packet = results[0] - - # Should only have priority 2 and 3 items - assert tag["priority"] == ["2", "3"] - assert packet["value"] == ["data/item2.txt", "data/item3.txt"] - - def test_group_by_empty_stream(self): - """Test groupby with empty stream.""" - stream = SyncStreamFromLists( - tags=[], packets=[], tag_keys=["category", "id"], packet_keys=["value"] - ) - group_by = GroupBy(group_keys=["category"]) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - assert len(results) == 0 - - def test_group_by_single_item(self): - """Test groupby with single item.""" - tags = [{"category": "A", "id": "1"}] - packets = [{"value": "data/item1.txt"}] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy(group_keys=["category"]) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - assert len(results) == 1 - tag, packet = results[0] - assert tag["category"] == [ - "A" - ] # With reduce_keys=False, even single values become lists - assert tag["id"] == ["1"] - assert packet["value"] == ["data/item1.txt"] - - def test_group_by_missing_group_keys(self): - """Test groupby when some items don't have the group keys.""" - tags = [ - {"category": "A", "id": "1"}, - {"id": "2"}, # Missing category - {"category": "A", "id": "3"}, - ] - packets = [ - {"value": "data/item1.txt"}, - {"value": "data/item2.txt"}, - {"value": "data/item3.txt"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy(group_keys=["category"]) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - # Should have 
2 groups: category="A" and category=None - assert len(results) == 2 - - categories = set() - for tag, _ in results: - # When reduce_keys=False, all values are lists - category_list = tag.get("category", [None]) - if category_list and category_list != [None]: - categories.add(category_list[0]) - else: - categories.add(None) - assert categories == {"A", None} - - def test_group_by_selection_function_filters_all(self): - """Test groupby where selection function filters out all items.""" - tags = [ - {"category": "A", "priority": "1"}, - {"category": "A", "priority": "2"}, - ] - packets = [ - {"value": "data/item1.txt"}, - {"value": "data/item2.txt"}, - ] - - # Selection function that filters out everything - def select_none(grouped_items): - return [False] * len(grouped_items) - - stream = SyncStreamFromLists(tags=tags, packets=packets) - group_by = GroupBy(group_keys=["category"], selection_function=select_none) - grouped_stream = group_by(stream) - - results = list(grouped_stream) - - # Should have no results since everything was filtered out - assert len(results) == 0 - - def test_group_by_multiple_streams_error(self): - """Test that GroupBy raises error with multiple streams.""" - stream1 = SyncStreamFromLists(tags=[{"a": "1"}], packets=[{"b": "file.txt"}]) - stream2 = SyncStreamFromLists(tags=[{"c": "3"}], packets=[{"d": "file2.txt"}]) - - group_by = GroupBy(group_keys=["a"]) - - with pytest.raises(ValueError, match="exactly one stream"): - list(group_by(stream1, stream2)) - - def test_group_by_pickle(self): - """Test that GroupBy mapper is pickleable.""" - # Test basic GroupBy - group_by = GroupBy(group_keys=["category"]) - pickled = pickle.dumps(group_by) - unpickled = pickle.loads(pickled) - - assert unpickled.group_keys == group_by.group_keys - assert unpickled.reduce_keys == group_by.reduce_keys - assert unpickled.selection_function == group_by.selection_function - - # Test with reduce_keys - group_by_reduce = GroupBy(group_keys=["category"], reduce_keys=True) - pickled_reduce = pickle.dumps(group_by_reduce) - unpickled_reduce = pickle.loads(pickled_reduce) - - assert unpickled_reduce.group_keys == group_by_reduce.group_keys - assert unpickled_reduce.reduce_keys == group_by_reduce.reduce_keys - - def test_group_by_identity_structure(self): - """Test GroupBy identity_structure method.""" - stream = SyncStreamFromLists(tags=[{"a": "1"}], packets=[{"b": "file.txt"}]) - - # Test without selection function - group_by1 = GroupBy(group_keys=["category"]) - structure1 = group_by1.identity_structure(stream) - assert structure1[0] == "GroupBy" - assert structure1[1] == ["category"] - assert not structure1[2] # reduce_keys - - # Test with reduce_keys - group_by2 = GroupBy(group_keys=["category"], reduce_keys=True) - structure2 = group_by2.identity_structure(stream) - assert structure2[2] # reduce_keys - - # Different group_keys should have different structures - group_by3 = GroupBy(group_keys=["other"]) - structure3 = group_by3.identity_structure(stream) - assert structure1 != structure3 - - def test_group_by_repr(self): - """Test GroupBy string representation.""" - group_by = GroupBy(group_keys=["category"], reduce_keys=True) - repr_str = repr(group_by) - # Should contain class name and key parameters - assert "GroupBy" in repr_str diff --git a/tests/test_streams_operations/test_mappers/test_join.py b/tests/test_streams_operations/test_mappers/test_join.py deleted file mode 100644 index 7b60571..0000000 --- a/tests/test_streams_operations/test_mappers/test_join.py +++ /dev/null @@ -1,198 +0,0 
@@ -"""Tests for Join mapper functionality.""" - -import pytest -import pickle -from orcabridge.mappers import Join -from orcabridge.streams import SyncStreamFromLists - - -class TestJoin: - """Test cases for Join mapper.""" - - def test_join_basic(self, sample_packets, sample_tags): - """Test basic join functionality.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - join = Join() - joined_stream = join(stream) - - # Join should collect all packets into a single packet - packets = list(joined_stream) - - assert len(packets) == 1 - joined_packet, joined_tag = packets[0] - - # The joined packet should contain all original packets - assert len(joined_packet) == len(sample_packets) - assert list(joined_packet) == sample_packets - - def test_join_empty_stream(self): - """Test join with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - join = Join() - joined_stream = join(empty_stream) - - packets = list(joined_stream) - - assert len(packets) == 1 - joined_packet, _ = packets[0] - assert len(joined_packet) == 0 - assert list(joined_packet) == [] - - def test_join_single_packet(self): - """Test join with single packet stream.""" - packets = ["single_packet"] - tags = ["single_tag"] - stream = SyncStreamFromLists(packets, tags) - - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - assert len(result) == 1 - - joined_packet, joined_tag = result[0] - assert len(joined_packet) == 1 - assert list(joined_packet) == ["single_packet"] - - def test_join_preserves_packet_types(self): - """Test that join preserves different packet types.""" - packets = [PacketType("data1"), {"key": "value"}, [1, 2, 3], 42, "string"] - tags = ["type1", "type2", "type3", "type4", "type5"] - - stream = SyncStreamFromLists(packets, tags) - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - assert len(result) == 1 - - joined_packet, _ = result[0] - assert len(joined_packet) == 5 - - joined_list = list(joined_packet) - assert joined_list[0] == PacketType("data1") - assert joined_list[1] == {"key": "value"} - assert joined_list[2] == [1, 2, 3] - assert joined_list[3] == 42 - assert joined_list[4] == "string" - - def test_join_maintains_order(self): - """Test that join maintains packet order.""" - packets = [f"packet_{i}" for i in range(10)] - tags = [f"tag_{i}" for i in range(10)] - - stream = SyncStreamFromLists(packets, tags) - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - joined_packet, _ = result[0] - - assert list(joined_packet) == packets - - def test_join_tag_handling(self, sample_packets, sample_tags): - """Test how join handles tags.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - _, joined_tag = result[0] - - # The joined tag should be a collection of original tags - # (implementation-specific behavior) - assert joined_tag is not None - - def test_join_large_stream(self): - """Test join with large stream.""" - packets = [f"packet_{i}" for i in range(1000)] - tags = [f"tag_{i}" for i in range(1000)] - - stream = SyncStreamFromLists(packets, tags) - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - assert len(result) == 1 - - joined_packet, _ = result[0] - assert len(joined_packet) == 1000 - assert list(joined_packet) == packets - - def test_join_nested_structures(self): - """Test join with nested data structures.""" - packets = [{"nested": {"data": 1}}, [1, [2, 3], 4], ((1, 2), 
(3, 4))] - tags = ["dict", "list", "tuple"] - - stream = SyncStreamFromLists(packets, tags) - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - joined_packet, _ = result[0] - - joined_list = list(joined_packet) - assert joined_list[0] == {"nested": {"data": 1}} - assert joined_list[1] == [1, [2, 3], 4] - assert joined_list[2] == ((1, 2), (3, 4)) - - def test_join_with_none_packets(self): - """Test join with None packets.""" - packets = ["data1", None, "data2", None] - tags = ["tag1", "tag2", "tag3", "tag4"] - - stream = SyncStreamFromLists(packets, tags) - join = Join() - joined_stream = join(stream) - - result = list(joined_stream) - joined_packet, _ = result[0] - - joined_list = list(joined_packet) - assert joined_list == ["data1", None, "data2", None] - - def test_join_chaining(self, sample_packets, sample_tags): - """Test chaining join operations.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - - # First join - join1 = Join() - joined_stream1 = join1(stream) - - # Second join (should join the already joined result) - join2 = Join() - joined_stream2 = join2(joined_stream1) - - result = list(joined_stream2) - assert len(result) == 1 - - # The result should be a packet containing one element (the previous join result) - final_packet, _ = result[0] - assert len(final_packet) == 1 - - def test_join_memory_efficiency(self): - """Test that join doesn't consume excessive memory for large streams.""" - # This is more of a performance test, but we can check basic functionality - packets = [f"packet_{i}" for i in range(10000)] - tags = [f"tag_{i}" for i in range(10000)] - - stream = SyncStreamFromLists(packets, tags) - join = Join() - joined_stream = join(stream) - - # Just verify it completes without issues - result = list(joined_stream) - assert len(result) == 1 - - joined_packet, _ = result[0] - assert len(joined_packet) == 10000 - - def test_join_pickle(self): - """Test that Join mapper is pickleable.""" - join = Join() - pickled = pickle.dumps(join) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, Join) - assert unpickled.__class__.__name__ == "Join" diff --git a/tests/test_streams_operations/test_mappers/test_map_packets.py b/tests/test_streams_operations/test_mappers/test_map_packets.py deleted file mode 100644 index da278de..0000000 --- a/tests/test_streams_operations/test_mappers/test_map_packets.py +++ /dev/null @@ -1,273 +0,0 @@ -"""Tests for MapPackets mapper functionality.""" - -import pytest -from orcabridge.base import PacketType -from orcabridge.mapper import MapPackets -from orcabridge.stream import SyncStreamFromLists - - -class TestMapPackets: - """Test cases for MapPackets mapper.""" - - def test_map_packets_basic(self, sample_packets, sample_tags): - """Test basic map packets functionality.""" - - def add_suffix(packet): - return f"{packet}_mapped" - - stream = SyncStreamFromLists(sample_packets, sample_tags) - map_packets = MapPackets(add_suffix) - mapped_stream = map_packets(stream) - - result_packets = [] - result_tags = [] - for packet, tag in mapped_stream: - result_packets.append(packet) - result_tags.append(tag) - - # Packets should be transformed, tags unchanged - expected_packets = [f"{p}_mapped" for p in sample_packets] - assert result_packets == expected_packets - assert result_tags == sample_tags - - def test_map_packets_numeric_transformation(self): - """Test map packets with numeric transformation.""" - packets = [1, 2, 3, 4, 5] - tags = ["num1", 
"num2", "num3", "num4", "num5"] - - def square(packet): - return packet**2 - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(square) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [1, 4, 9, 16, 25] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == tags - - def test_map_packets_type_conversion(self): - """Test map packets with type conversion.""" - packets = ["1", "2", "3", "4"] - tags = ["str1", "str2", "str3", "str4"] - - def str_to_int(packet): - return int(packet) - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(str_to_int) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [1, 2, 3, 4] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - assert all(isinstance(p, int) for p in actual_packets) - - def test_map_packets_complex_transformation(self): - """Test map packets with complex data transformation.""" - packets = [ - {"name": "alice", "age": 25}, - {"name": "bob", "age": 30}, - {"name": "charlie", "age": 35}, - ] - tags = ["person1", "person2", "person3"] - - def create_description(packet): - return f"{packet['name']} is {packet['age']} years old" - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(create_description) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [ - "alice is 25 years old", - "bob is 30 years old", - "charlie is 35 years old", - ] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_map_packets_identity_function(self, sample_packets, sample_tags): - """Test map packets with identity function.""" - - def identity(packet): - return packet - - stream = SyncStreamFromLists(sample_packets, sample_tags) - map_packets = MapPackets(identity) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == sample_packets - assert actual_tags == sample_tags - - def test_map_packets_empty_stream(self): - """Test map packets with empty stream.""" - - def dummy_transform(packet): - return packet * 2 - - empty_stream = SyncStreamFromLists([], []) - map_packets = MapPackets(dummy_transform) - mapped_stream = map_packets(empty_stream) - - result = list(mapped_stream) - assert len(result) == 0 - - def test_map_packets_with_none_values(self): - """Test map packets with None values.""" - packets = [1, None, 3, None, 5] - tags = ["num1", "null1", "num3", "null2", "num5"] - - def handle_none(packet): - return 0 if packet is None else packet * 2 - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(handle_none) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [2, 0, 6, 0, 10] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_map_packets_exception_handling(self): - """Test map packets when transformation function raises exception.""" - packets = [1, 2, "invalid", 4] - tags = ["num1", "num2", "str1", "num4"] - - def divide_by_packet(packet): - return 10 / packet # Will fail on "invalid" - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(divide_by_packet) - mapped_stream = 
map_packets(stream) - - # Should raise exception when processing "invalid" - with pytest.raises(TypeError): - list(mapped_stream) - - def test_map_packets_preserves_order(self): - """Test that map packets preserves packet order.""" - packets = [f"packet_{i}" for i in range(100)] - tags = [f"tag_{i}" for i in range(100)] - - def add_prefix(packet): - return f"mapped_{packet}" - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(add_prefix) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [f"mapped_packet_{i}" for i in range(100)] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_map_packets_with_lambda(self, sample_packets, sample_tags): - """Test map packets with lambda function.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - map_packets = MapPackets(lambda x: f"λ({x})") - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [f"λ({p})" for p in sample_packets] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_map_packets_chaining(self, sample_packets, sample_tags): - """Test chaining multiple map packets operations.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - - # First transformation - map1 = MapPackets(lambda x: f"first_{x}") - stream1 = map1(stream) - - # Second transformation - map2 = MapPackets(lambda x: f"second_{x}") - stream2 = map2(stream1) - - result = list(stream2) - - expected_packets = [f"second_first_{p}" for p in sample_packets] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_map_packets_with_packet_type(self): - """Test map packets with PacketType objects.""" - packets = [PacketType("data1"), PacketType("data2")] - tags = ["type1", "type2"] - - def extract_data(packet): - return packet.data if hasattr(packet, "data") else str(packet) - - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(extract_data) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - actual_packets = [packet for packet, _ in result] - - # Should extract string representation or data - assert len(actual_packets) == 2 - assert all(isinstance(p, str) for p in actual_packets) - - def test_map_packets_stateful_transformation(self): - """Test map packets with stateful transformation.""" - packets = [1, 2, 3, 4, 5] - tags = ["n1", "n2", "n3", "n4", "n5"] - - class Counter: - def __init__(self): - self.count = 0 - - def transform(self, packet): - self.count += 1 - return (packet, self.count) - - counter = Counter() - stream = SyncStreamFromLists(packets, tags) - map_packets = MapPackets(counter.transform) - mapped_stream = map_packets(stream) - - result = list(mapped_stream) - - expected_packets = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)] - actual_packets = [packet for packet, _ in result] - - assert actual_packets == expected_packets - - def test_map_packets_pickle(self): - """Test that MapPackets mapper is pickleable.""" - import pickle - from orcabridge.mappers import MapPackets - - # MapPackets takes a key mapping, not a transformation function - key_map = {"old_key": "new_key", "data": "value"} - map_packets = MapPackets(key_map) - pickled = pickle.dumps(map_packets) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, MapPackets) - assert unpickled.key_map == map_packets.key_map diff 
--git a/tests/test_streams_operations/test_mappers/test_map_tags.py b/tests/test_streams_operations/test_mappers/test_map_tags.py deleted file mode 100644 index a8e185a..0000000 --- a/tests/test_streams_operations/test_mappers/test_map_tags.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Tests for MapTags mapper functionality.""" - -import pytest -from orcabridge.base import PacketType -from orcabridge.mapper import MapTags -from orcabridge.stream import SyncStreamFromLists - - -class TestMapTags: - """Test cases for MapTags mapper.""" - - def test_map_tags_basic(self, sample_packets, sample_tags): - """Test basic map tags functionality.""" - - def add_prefix(tag): - return f"mapped_{tag}" - - stream = SyncStreamFromLists(sample_packets, sample_tags) - map_tags = MapTags(add_prefix) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = [f"mapped_{t}" for t in sample_tags] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - # Packets should be unchanged, tags transformed - assert actual_packets == sample_packets - assert actual_tags == expected_tags - - def test_map_tags_type_conversion(self, sample_packets): - """Test map tags with type conversion.""" - tags = ["1", "2", "3"] - - def str_to_int(tag): - return int(tag) - - stream = SyncStreamFromLists(sample_packets, tags) - map_tags = MapTags(str_to_int) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = [1, 2, 3] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - assert all(isinstance(t, int) for t in actual_tags) - - def test_map_tags_complex_transformation(self): - """Test map tags with complex transformation.""" - packets = ["data1", "data2", "data3"] - tags = [ - {"type": "string", "length": 5}, - {"type": "string", "length": 5}, - {"type": "string", "length": 5}, - ] - - def extract_type(tag): - if isinstance(tag, dict): - return tag.get("type", "unknown") - return str(tag) - - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(extract_type) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = ["string", "string", "string"] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - - def test_map_tags_identity_function(self, sample_packets, sample_tags): - """Test map tags with identity function.""" - - def identity(tag): - return tag - - stream = SyncStreamFromLists(sample_packets, sample_tags) - map_tags = MapTags(identity) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == sample_packets - assert actual_tags == sample_tags - - def test_map_tags_empty_stream(self): - """Test map tags with empty stream.""" - - def dummy_transform(tag): - return f"transformed_{tag}" - - empty_stream = SyncStreamFromLists([], []) - map_tags = MapTags(dummy_transform) - mapped_stream = map_tags(empty_stream) - - result = list(mapped_stream) - assert len(result) == 0 - - def test_map_tags_with_none_values(self, sample_packets): - """Test map tags with None values.""" - tags = ["tag1", None, "tag3"] - - def handle_none(tag): - return "NULL_TAG" if tag is None else tag.upper() - - stream = SyncStreamFromLists(sample_packets, tags) - map_tags = MapTags(handle_none) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = ["TAG1", "NULL_TAG", "TAG3"] - actual_tags = [tag 
for _, tag in result] - - assert actual_tags == expected_tags - - def test_map_tags_exception_handling(self, sample_packets): - """Test map tags when transformation function raises exception.""" - tags = ["valid", "also_valid", 123] # 123 will cause error in upper() - - def to_upper(tag): - return tag.upper() # Will fail on integer - - stream = SyncStreamFromLists(sample_packets, tags) - map_tags = MapTags(to_upper) - mapped_stream = map_tags(stream) - - # Should raise exception when processing integer tag - with pytest.raises(AttributeError): - list(mapped_stream) - - def test_map_tags_preserves_packets(self): - """Test that map tags preserves all packet types.""" - packets = [PacketType("data1"), {"key": "value"}, [1, 2, 3], 42, "string"] - tags = ["type1", "type2", "type3", "type4", "type5"] - - def add_suffix(tag): - return f"{tag}_processed" - - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(add_suffix) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - actual_packets = [packet for packet, _ in result] - expected_tags = [f"{t}_processed" for t in tags] - actual_tags = [tag for _, tag in result] - - assert actual_packets == packets - assert actual_tags == expected_tags - - def test_map_tags_maintains_order(self): - """Test that map tags maintains packet order.""" - packets = [f"packet_{i}" for i in range(100)] - tags = [f"tag_{i}" for i in range(100)] - - def reverse_tag(tag): - return tag[::-1] # Reverse the string - - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(reverse_tag) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = [f"{i}_gat" for i in range(100)] # "tag_i" reversed - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == packets - assert actual_tags == expected_tags - - def test_map_tags_with_lambda(self, sample_packets, sample_tags): - """Test map tags with lambda function.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - map_tags = MapTags(lambda t: f"λ({t})") - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = [f"λ({t})" for t in sample_tags] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - - def test_map_tags_chaining(self, sample_packets, sample_tags): - """Test chaining multiple map tags operations.""" - stream = SyncStreamFromLists(sample_packets, sample_tags) - - # First transformation - map1 = MapTags(lambda t: f"first_{t}") - stream1 = map1(stream) - - # Second transformation - map2 = MapTags(lambda t: f"second_{t}") - stream2 = map2(stream1) - - result = list(stream2) - - expected_tags = [f"second_first_{t}" for t in sample_tags] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == sample_packets - assert actual_tags == expected_tags - - def test_map_tags_stateful_transformation(self): - """Test map tags with stateful transformation.""" - packets = ["a", "b", "c", "d", "e"] - tags = ["tag1", "tag2", "tag3", "tag4", "tag5"] - - class TagCounter: - def __init__(self): - self.count = 0 - - def transform(self, tag): - self.count += 1 - return f"{tag}_#{self.count}" - - counter = TagCounter() - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(counter.transform) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = ["tag1_#1", "tag2_#2", "tag3_#3", "tag4_#4", "tag5_#5"] - actual_tags = [tag for _, tag 
in result] - - assert actual_tags == expected_tags - - def test_map_tags_with_complex_types(self): - """Test map tags with complex tag types.""" - packets = ["data1", "data2", "data3"] - tags = [ - {"id": 1, "category": "A"}, - {"id": 2, "category": "B"}, - {"id": 3, "category": "A"}, - ] - - def extract_category(tag): - if isinstance(tag, dict): - return f"cat_{tag.get('category', 'unknown')}" - return str(tag) - - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(extract_category) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - expected_tags = ["cat_A", "cat_B", "cat_A"] - actual_tags = [tag for _, tag in result] - - assert actual_tags == expected_tags - - def test_map_tags_preserves_tag_references(self): - """Test that map tags doesn't break tag references when not needed.""" - packets = ["data1", "data2"] - shared_tag = {"shared": "reference"} - tags = [shared_tag, shared_tag] - - def conditional_transform(tag): - # Only transform if it's a string - if isinstance(tag, str): - return f"transformed_{tag}" - return tag # Keep dict unchanged - - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(conditional_transform) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - actual_tags = [tag for _, tag in result] - - # Both tags should still reference the same object - assert actual_tags[0] is shared_tag - assert actual_tags[1] is shared_tag - assert actual_tags[0] is actual_tags[1] - - def test_map_tags_large_stream(self): - """Test map tags with large stream.""" - packets = [f"packet_{i}" for i in range(1000)] - tags = [f"tag_{i}" for i in range(1000)] - - def add_hash(tag): - return f"{tag}_{hash(tag) % 1000}" - - stream = SyncStreamFromLists(packets, tags) - map_tags = MapTags(add_hash) - mapped_stream = map_tags(stream) - - result = list(mapped_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert len(actual_packets) == 1000 - assert len(actual_tags) == 1000 - assert actual_packets == packets - - # All tags should have been transformed - assert all( - "_" in tag and tag != f"tag_{i}" for i, tag in enumerate(actual_tags) - ) - - def test_map_tags_pickle(self): - """Test that MapTags mapper is pickleable.""" - import pickle - from orcabridge.mappers import MapTags - - # MapTags takes a key mapping, not a transformation function - key_map = {"old_tag": "new_tag", "category": "type"} - map_tags = MapTags(key_map) - pickled = pickle.dumps(map_tags) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, MapTags) - assert unpickled.key_map == map_tags.key_map diff --git a/tests/test_streams_operations/test_mappers/test_merge.py b/tests/test_streams_operations/test_mappers/test_merge.py deleted file mode 100644 index fc315d6..0000000 --- a/tests/test_streams_operations/test_mappers/test_merge.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Tests for Merge mapper functionality.""" - -import pickle -import pytest -from orcabridge.base import PacketType -from orcabridge.mappers import Merge -from orcabridge.streams import SyncStreamFromLists - - -class TestMerge: - """Test cases for Merge mapper.""" - - def test_merge_two_streams(self, sample_packets, sample_tags): - """Test merging two streams.""" - # Create two streams - stream1 = SyncStreamFromLists(sample_packets[:2], sample_tags[:2]) - stream2 = SyncStreamFromLists(sample_packets[2:], sample_tags[2:]) - - merge = Merge() - merged_stream = merge(stream1, stream2) 
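# A minimal sketch of the Merge call pattern the tests above exercise, assuming
# the orcabridge.mappers.Merge and orcabridge.streams.SyncStreamFromLists APIs
# exactly as they appear in this suite; the tag/packet values below are
# illustrative placeholders, not data from the deleted fixtures.
from orcabridge.mappers import Merge
from orcabridge.streams import SyncStreamFromLists

left = SyncStreamFromLists(tags=[{"id": "1"}], packets=[{"data": "a.txt"}])
right = SyncStreamFromLists(tags=[{"id": "2"}], packets=[{"data": "b.txt"}])

# Merge is variadic: a single call accepts any number of upstream streams and
# yields every element from each of them; these tests only assert on counts
# and set membership, so cross-stream ordering is treated as unspecified.
merged = Merge()(left, right)
assert len(list(merged)) == 2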
- - packets = [] - tags = [] - for packet, tag in merged_stream: - packets.append(packet) - tags.append(tag) - - # Should contain all packets from both streams - assert len(packets) == 3 - assert set(packets) == set(sample_packets) - assert set(tags) == set(sample_tags) - - def test_merge_multiple_streams(self, sample_packets, sample_tags): - """Test merging multiple streams.""" - # Create three streams with one packet each - streams = [] - for i in range(3): - stream = SyncStreamFromLists([sample_packets[i]], [sample_tags[i]]) - streams.append(stream) - - merge = Merge() - merged_stream = merge(*streams) - - packets = [] - tags = [] - for packet, tag in merged_stream: - packets.append(packet) - tags.append(tag) - - assert len(packets) == 3 - assert set(packets) == set(sample_packets) - assert set(tags) == set(sample_tags) - - def test_merge_empty_streams(self): - """Test merging with empty streams.""" - empty1 = SyncStreamFromLists([], []) - empty2 = SyncStreamFromLists([], []) - - merge = Merge() - merged_stream = merge(empty1, empty2) - - packets = list(merged_stream) - assert len(packets) == 0 - - def test_merge_one_empty_one_full(self, sample_stream): - """Test merging empty stream with full stream.""" - empty_stream = SyncStreamFromLists([], []) - - merge = Merge() - merged_stream = merge(sample_stream, empty_stream) - - packets = list(merged_stream) - original_packets = list(sample_stream) - - assert len(packets) == len(original_packets) - # Order might be different, so check sets - assert set(packets) == set(original_packets) - - def test_merge_different_lengths(self): - """Test merging streams of different lengths.""" - packets1 = ["a", "b"] - tags1 = ["tag1", "tag2"] - packets2 = ["c", "d", "e", "f"] - tags2 = ["tag3", "tag4", "tag5", "tag6"] - - stream1 = SyncStreamFromLists(packets1, tags1) - stream2 = SyncStreamFromLists(packets2, tags2) - - merge = Merge() - merged_stream = merge(stream1, stream2) - - packets = [] - tags = [] - for packet, tag in merged_stream: - packets.append(packet) - tags.append(tag) - - assert len(packets) == 6 - assert set(packets) == set(packets1 + packets2) - assert set(tags) == set(tags1 + tags2) - - def test_merge_single_stream(self, sample_stream): - """Test merge with single stream.""" - merge = Merge() - merged_stream = merge(sample_stream) - - packets = list(merged_stream) - original_packets = list(sample_stream) - - assert packets == original_packets - - def test_merge_preserves_packet_types(self): - """Test that merge preserves different packet types.""" - packets1 = [PacketType("data1"), {"key1": "value1"}] - tags1 = ["str1", "dict1"] - packets2 = [[1, 2], 42] - tags2 = ["list1", "int1"] - - stream1 = SyncStreamFromLists(packets1, tags1) - stream2 = SyncStreamFromLists(packets2, tags2) - - merge = Merge() - merged_stream = merge(stream1, stream2) - - result_packets = [] - for packet, _ in merged_stream: - result_packets.append(packet) - - assert len(result_packets) == 4 - assert set(result_packets) == set(packets1 + packets2) - - def test_merge_order_independence(self, sample_packets, sample_tags): - """Test that merge order doesn't affect final result set.""" - stream1 = SyncStreamFromLists(sample_packets[:2], sample_tags[:2]) - stream2 = SyncStreamFromLists(sample_packets[2:], sample_tags[2:]) - - merge = Merge() - - # Merge in one order - merged1 = merge(stream1, stream2) - packets1 = set(p for p, _ in merged1) - - # Merge in reverse order (need to recreate streams) - stream1_new = SyncStreamFromLists(sample_packets[:2], sample_tags[:2]) - 
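# The deleted suites construct SyncStreamFromLists both positionally (packets
# first in this file, tags first in the pipeline tests further below) and via
# keywords; a minimal sketch of the keyword form used by the Repeat and
# utility-function tests, which sidesteps the ordering question. Key names are
# illustrative placeholders.
from orcabridge.streams import SyncStreamFromLists

stream = SyncStreamFromLists(
    tags=[{"session": "a"}, {"session": "b"}],
    packets=[{"txt_file": "day1.txt"}, {"txt_file": "day2.txt"}],
)
# Those tests then iterate the stream as (tag, packet) pairs.
for tag, packet in stream:
    assert "session" in tag and "txt_file" in packet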
stream2_new = SyncStreamFromLists(sample_packets[2:], sample_tags[2:]) - merged2 = merge(stream2_new, stream1_new) - packets2 = set(p for p, _ in merged2) - - assert packets1 == packets2 - - def test_merge_with_duplicate_packets(self): - """Test merging streams with duplicate packets.""" - packets1 = ["a", "b"] - tags1 = ["tag1", "tag2"] - packets2 = ["a", "c"] # "a" appears in both streams - tags2 = ["tag3", "tag4"] - - stream1 = SyncStreamFromLists(packets1, tags1) - stream2 = SyncStreamFromLists(packets2, tags2) - - merge = Merge() - merged_stream = merge(stream1, stream2) - - packets = [] - for packet, _ in merged_stream: - packets.append(packet) - - # Should include duplicates - assert len(packets) == 4 - assert packets.count("a") == 2 - assert "b" in packets - assert "c" in packets - - def test_merge_no_streams_error(self): - """Test that merge with no streams raises an error.""" - merge = Merge() - - with pytest.raises(TypeError): - merge() - - def test_merge_large_number_of_streams(self): - """Test merging a large number of streams.""" - streams = [] - all_packets = [] - - for i in range(10): - packets = [f"packet_{i}"] - tags = [f"tag_{i}"] - streams.append(SyncStreamFromLists(packets, tags)) - all_packets.extend(packets) - - merge = Merge() - merged_stream = merge(*streams) - - result_packets = [] - for packet, _ in merged_stream: - result_packets.append(packet) - - assert len(result_packets) == 10 - assert set(result_packets) == set(all_packets) - """Test that Merge mapper is pickleable.""" - merge = Merge() - pickled = pickle.dumps(merge) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, Merge) - assert unpickled.__class__.__name__ == "Merge" diff --git a/tests/test_streams_operations/test_mappers/test_repeat.py b/tests/test_streams_operations/test_mappers/test_repeat.py deleted file mode 100644 index b8a4a98..0000000 --- a/tests/test_streams_operations/test_mappers/test_repeat.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Tests for Repeat mapper functionality.""" - -import pytest -import pickle -from orcabridge.mappers import Repeat - - -class TestRepeat: - """Test cases for Repeat mapper.""" - - def test_repeat_basic(self, sample_stream): - """Test basic repeat functionality.""" - repeat = Repeat(3) - repeated_stream = repeat(sample_stream) - - packets = list(repeated_stream) - - # Should have 3 times the original packets - assert len(packets) == 9 # 3 original * 3 repeats - - # Check that each packet appears 3 times consecutively - original_packets = list(sample_stream) - expected_packets = [] - for packet in original_packets: - expected_packets.extend([packet] * 3) - - assert packets == expected_packets - - def test_repeat_zero(self, sample_stream): - """Test repeat with count 0.""" - repeat = Repeat(0) - repeated_stream = repeat(sample_stream) - - packets = list(repeated_stream) - assert len(packets) == 0 - - def test_repeat_one(self, sample_stream): - """Test repeat with count 1.""" - repeat = Repeat(1) - repeated_stream = repeat(sample_stream) - - packets = list(repeated_stream) - original_packets = list(sample_stream) - - assert packets == original_packets - - def test_repeat_with_tags(self, sample_packets, sample_tags): - """Test repeat preserves tags correctly.""" - from orcabridge.streams import SyncStreamFromLists - - stream = SyncStreamFromLists(tags=sample_tags, packets=sample_packets) - repeat = Repeat(2) - repeated_stream = repeat(stream) - - packets = [] - tags = [] - for tag, packet in repeated_stream: 
- packets.append(packet) - tags.append(tag) - - # Each packet should appear twice with its corresponding tag - assert len(packets) == 6 # 3 original * 2 repeats - assert len(tags) == 6 - - # Check pattern: [p1,p1,p2,p2,p3,p3] with [t1,t1,t2,t2,t3,t3] - expected_packets = [] - expected_tags = [] - for p, t in zip(sample_packets, sample_tags): - expected_packets.extend([p, p]) - expected_tags.extend([t, t]) - - assert packets == expected_packets - assert tags == expected_tags - - def test_repeat_with_empty_stream(self): - """Test repeat with empty stream.""" - from orcabridge.streams import SyncStreamFromLists - - empty_stream = SyncStreamFromLists(tags=[], packets=[]) - repeat = Repeat(5) - repeated_stream = repeat(empty_stream) - - packets = list(repeated_stream) - assert len(packets) == 0 - - def test_repeat_large_count(self, sample_stream): - """Test repeat with large count.""" - repeat = Repeat(100) - repeated_stream = repeat(sample_stream) - - packets = list(repeated_stream) - assert len(packets) == 300 # 3 original * 100 repeats - - def test_repeat_negative_count(self): - """Test repeat with negative count raises error.""" - with pytest.raises(ValueError): - Repeat(-1) - - def test_repeat_non_integer_count(self): - """Test repeat with non-integer count.""" - with pytest.raises(TypeError): - Repeat(3.5) - - with pytest.raises(TypeError): - Repeat("3") - - def test_repeat_preserves_packet_types(self, sample_stream): - """Test that repeat preserves different packet types.""" - # Create stream with mixed packet types - from orcabridge.streams import SyncStreamFromLists - - packets = [ - {"data": "data1"}, - {"key": "value"}, - {"items": ["a", "b", "c"]}, - {"number": "42"}, - ] - tags = [{"type": "str"}, {"type": "dict"}, {"type": "list"}, {"type": "int"}] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - repeat = Repeat(2) - repeated_stream = repeat(stream) - - result_packets = [] - for tag, packet in repeated_stream: - result_packets.append(packet) - - expected = [ - {"data": "data1"}, - {"data": "data1"}, - {"key": "value"}, - {"key": "value"}, - {"items": ["a", "b", "c"]}, - {"items": ["a", "b", "c"]}, - {"number": "42"}, - {"number": "42"}, - ] - - assert result_packets == expected - - def test_repeat_chaining(self, sample_stream): - """Test chaining multiple repeat operations.""" - repeat1 = Repeat(2) - repeat2 = Repeat(3) - - # Apply first repeat - stream1 = repeat1(sample_stream) - # Apply second repeat - stream2 = repeat2(stream1) - - packets = list(stream2) - - # Should have 3 original * 2 * 3 = 18 packets - assert len(packets) == 18 - - # Each original packet should appear 6 times consecutively - original_packets = list(sample_stream) - expected = [] - for packet in original_packets: - expected.extend([packet] * 6) - - assert packets == expected - - def test_repeat_pickle(self): - """Test that Repeat mapper is pickleable.""" - repeat = Repeat(5) - - # Test pickle/unpickle - pickled = pickle.dumps(repeat) - unpickled = pickle.loads(pickled) - - # Verify the unpickled mapper has the same properties - assert unpickled.repeat_count == repeat.repeat_count - - # Test that the unpickled mapper works correctly - from orcabridge.streams import SyncStreamFromLists - - tags = [{"id": "1"}, {"id": "2"}] - packets = [{"data": "file1.txt"}, {"data": "file2.txt"}] - stream = SyncStreamFromLists(tags=tags, packets=packets) - - original_results = list(repeat(stream)) - unpickled_results = list(unpickled(stream)) - - assert original_results == unpickled_results - assert 
len(original_results) == 10 # 2 * 5 repeats diff --git a/tests/test_streams_operations/test_mappers/test_transform.py b/tests/test_streams_operations/test_mappers/test_transform.py deleted file mode 100644 index 5971fd2..0000000 --- a/tests/test_streams_operations/test_mappers/test_transform.py +++ /dev/null @@ -1,363 +0,0 @@ -"""Tests for Transform mapper functionality.""" - -import pytest -from orcabridge.mappers import Transform -from orcabridge.streams import SyncStreamFromLists - - -class TestTransform: - """Test cases for Transform mapper.""" - - def test_transform_basic(self, simple_transform): - """Test basic transform functionality.""" - packets = ["hello", "world", "test"] - tags = ["greeting", "noun", "action"] - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(simple_transform) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = ["HELLO", "WORLD", "TEST"] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == tags # Tags should be preserved - - def test_transform_with_tag_modification(self): - """Test transform that modifies both packet and tag.""" - packets = [1, 2, 3, 4, 5] - tags = ["num1", "num2", "num3", "num4", "num5"] - - def double_and_prefix_tag(packet, tag): - return packet * 2, f"doubled_{tag}" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(double_and_prefix_tag) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [2, 4, 6, 8, 10] - expected_tags = [ - "doubled_num1", - "doubled_num2", - "doubled_num3", - "doubled_num4", - "doubled_num5", - ] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_transform_packet_only(self, sample_packets, sample_tags): - """Test transform that only modifies packets.""" - - def add_prefix(packet, tag): - return f"transformed_{packet}", tag - - stream = SyncStreamFromLists(sample_packets, sample_tags) - transform_mapper = Transform(add_prefix) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [f"transformed_{p}" for p in sample_packets] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == sample_tags - - def test_transform_tag_only(self, sample_packets, sample_tags): - """Test transform that only modifies tags.""" - - def add_tag_suffix(packet, tag): - return packet, f"{tag}_processed" - - stream = SyncStreamFromLists(sample_packets, sample_tags) - transform_mapper = Transform(add_tag_suffix) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_tags = [f"{t}_processed" for t in sample_tags] - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == sample_packets - assert actual_tags == expected_tags - - def test_transform_empty_stream(self): - """Test transform with empty stream.""" - - def dummy_transform(packet, tag): - return packet, tag - - empty_stream = SyncStreamFromLists([], []) - transform_mapper = Transform(dummy_transform) - transformed_stream = transform_mapper(empty_stream) - - result = list(transformed_stream) - assert 
len(result) == 0 - - def test_transform_type_conversion(self): - """Test transform with type conversion.""" - packets = ["1", "2", "3", "4", "5"] - tags = ["str1", "str2", "str3", "str4", "str5"] - - def str_to_int_with_tag(packet, tag): - return int(packet), f"int_{tag}" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(str_to_int_with_tag) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [1, 2, 3, 4, 5] - expected_tags = ["int_str1", "int_str2", "int_str3", "int_str4", "int_str5"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - assert all(isinstance(p, int) for p in actual_packets) - - def test_transform_complex_data(self): - """Test transform with complex data structures.""" - packets = [ - {"name": "alice", "age": 25}, - {"name": "bob", "age": 30}, - {"name": "charlie", "age": 35}, - ] - tags = ["person1", "person2", "person3"] - - def enrich_person_data(packet, tag): - enriched = packet.copy() - enriched["category"] = "adult" if packet["age"] >= 30 else "young" - return enriched, f"enriched_{tag}" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(enrich_person_data) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [ - {"name": "alice", "age": 25, "category": "young"}, - {"name": "bob", "age": 30, "category": "adult"}, - {"name": "charlie", "age": 35, "category": "adult"}, - ] - expected_tags = ["enriched_person1", "enriched_person2", "enriched_person3"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_transform_with_none_values(self): - """Test transform with None values.""" - packets = [1, None, 3, None, 5] - tags = ["num1", "null1", "num3", "null2", "num5"] - - def handle_none_transform(packet, tag): - if packet is None: - return "MISSING", f"missing_{tag}" - else: - return packet * 2, f"doubled_{tag}" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(handle_none_transform) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [2, "MISSING", 6, "MISSING", 10] - expected_tags = [ - "doubled_num1", - "missing_null1", - "doubled_num3", - "missing_null2", - "doubled_num5", - ] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_transform_preserves_order(self): - """Test that transform preserves packet order.""" - packets = [f"packet_{i}" for i in range(100)] - tags = [f"tag_{i}" for i in range(100)] - - def add_index(packet, tag): - index = int(packet.split("_")[1]) - return f"indexed_{index}_{packet}", f"indexed_{tag}" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(add_index) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [f"indexed_{i}_packet_{i}" for i in range(100)] - expected_tags = [f"indexed_tag_{i}" for i in range(100)] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def 
test_transform_exception_handling(self): - """Test transform when transformation function raises exception.""" - packets = [1, 2, "invalid", 4] - tags = ["num1", "num2", "str1", "num4"] - - def divide_transform(packet, tag): - return 10 / packet, f"divided_{tag}" # Will fail on "invalid" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(divide_transform) - transformed_stream = transform_mapper(stream) - - # Should raise exception when processing "invalid" - with pytest.raises(TypeError): - list(transformed_stream) - - def test_transform_with_lambda(self): - """Test transform with lambda function.""" - packets = [1, 2, 3, 4, 5] - tags = ["a", "b", "c", "d", "e"] - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(lambda p, t: (p**2, t.upper())) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [1, 4, 9, 16, 25] - expected_tags = ["A", "B", "C", "D", "E"] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_transform_chaining(self): - """Test chaining multiple transform operations.""" - packets = [1, 2, 3, 4, 5] - tags = ["num1", "num2", "num3", "num4", "num5"] - - stream = SyncStreamFromLists(packets, tags) - - # First transformation: double the packet - transform1 = Transform(lambda p, t: (p * 2, f"doubled_{t}")) - stream1 = transform1(stream) - - # Second transformation: add 10 to packet - transform2 = Transform(lambda p, t: (p + 10, f"added_{t}")) - stream2 = transform2(stream1) - - result = list(stream2) - - expected_packets = [12, 14, 16, 18, 20] # (original * 2) + 10 - expected_tags = [ - "added_doubled_num1", - "added_doubled_num2", - "added_doubled_num3", - "added_doubled_num4", - "added_doubled_num5", - ] - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_transform_with_packet_type(self): - """Test transform with PacketType objects.""" - packets = [PacketType("data1"), PacketType("data2")] - tags = ["type1", "type2"] - - def extract_and_modify(packet, tag): - data = str(packet) # Convert to string - return f"extracted_{data}", f"processed_{tag}" - - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(extract_and_modify) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert len(actual_packets) == 2 - assert all("extracted_" in p for p in actual_packets) - assert actual_tags == ["processed_type1", "processed_type2"] - - def test_transform_stateful(self): - """Test transform with stateful transformation.""" - packets = [1, 2, 3, 4, 5] - tags = ["n1", "n2", "n3", "n4", "n5"] - - class StatefulTransform: - def __init__(self): - self.counter = 0 - - def transform(self, packet, tag): - self.counter += 1 - return (packet + self.counter, f"{tag}_step_{self.counter}") - - stateful = StatefulTransform() - stream = SyncStreamFromLists(packets, tags) - transform_mapper = Transform(stateful.transform) - transformed_stream = transform_mapper(stream) - - result = list(transformed_stream) - - expected_packets = [2, 4, 6, 8, 10] # packet + step_number - expected_tags = [ - "n1_step_1", - "n2_step_2", - "n3_step_3", - "n4_step_4", - "n5_step_5", - ] 
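# A compact sketch of the stateful-Transform pattern tested above, assuming
# Transform simply invokes the supplied callable on every element as these
# tests do. Argument order follows the (packet, tag) convention used in this
# file; the pipeline tests further below pass (tag, packet) instead, so treat
# the ordering here as an assumption rather than the library contract.
from orcabridge.mappers import Transform
from orcabridge.streams import SyncStreamFromLists


class Numberer:
    """Carries a counter across elements so each tag records its 1-based position."""

    def __init__(self) -> None:
        self.n = 0

    def __call__(self, packet, tag):
        self.n += 1
        return packet, f"{tag}_#{self.n}"


stream = SyncStreamFromLists(tags=["t1", "t2"], packets=["a", "b"])
numbered = Transform(Numberer())(stream)
# Under these assumptions, packets pass through untouched and the tags become
# "t1_#1" and "t2_#2".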
- - actual_packets = [packet for packet, _ in result] - actual_tags = [tag for _, tag in result] - - assert actual_packets == expected_packets - assert actual_tags == expected_tags - - def test_transform_pickle(self): - """Test that Transform mapper is pickleable.""" - import pickle - from orcabridge.mappers import Transform - - def add_prefix(tag, packet): - new_tag = {**tag, "prefix": "test"} - new_packet = {**packet, "processed": True} - return new_tag, new_packet - - transform = Transform(add_prefix) - pickled = pickle.dumps(transform) - unpickled = pickle.loads(pickled) - - # Test that unpickled mapper works the same - assert isinstance(unpickled, Transform) - assert unpickled.__class__.__name__ == "Transform" diff --git a/tests/test_streams_operations/test_mappers/test_utility_functions.py b/tests/test_streams_operations/test_mappers/test_utility_functions.py deleted file mode 100644 index 9cae09e..0000000 --- a/tests/test_streams_operations/test_mappers/test_utility_functions.py +++ /dev/null @@ -1,248 +0,0 @@ -"""Tests for utility functions tag() and packet().""" - -from orcabridge.mappers import tag, packet -from orcabridge.streams import SyncStreamFromLists - - -class TestUtilityFunctions: - """Test cases for tag() and packet() utility functions.""" - - def test_tag_function_basic(self): - """Test basic tag() function functionality.""" - tags = [ - {"old_key": "value1", "other": "data1"}, - {"old_key": "value2", "other": "data2"}, - ] - packets = [ - {"data": "packet1"}, - {"data": "packet2"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - tag_mapper = tag({"old_key": "new_key"}) - transformed_stream = tag_mapper(stream) - - results = list(transformed_stream) - - assert len(results) == 2 - for (result_tag, result_packet), original_packet in zip(results, packets): - # Tag should be transformed - assert "new_key" in result_tag - assert "old_key" not in result_tag # old key dropped by default - assert result_tag["new_key"] in ["value1", "value2"] - - # Packet should be unchanged - assert result_packet == original_packet - - def test_tag_function_keep_unmapped(self): - """Test tag() function with drop_unmapped=False.""" - tags = [ - {"old_key": "value1", "keep_this": "data1"}, - ] - packets = [ - {"data": "packet1"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - tag_mapper = tag({"old_key": "new_key"}, drop_unmapped=False) - transformed_stream = tag_mapper(stream) - - results = list(transformed_stream) - - assert len(results) == 1 - result_tag, result_packet = results[0] - - # Should have both mapped and unmapped keys - assert result_tag["new_key"] == "value1" - assert result_tag["keep_this"] == "data1" - - def test_packet_function_basic(self): - """Test basic packet() function functionality.""" - tags = [ - {"tag_data": "tag1"}, - {"tag_data": "tag2"}, - ] - packets = [ - {"old_key": "value1", "other": "data1"}, - {"old_key": "value2", "other": "data2"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - packet_mapper = packet({"old_key": "new_key"}) - transformed_stream = packet_mapper(stream) - - results = list(transformed_stream) - - assert len(results) == 2 - for (result_tag, result_packet), original_tag in zip(results, tags): - # Tag should be unchanged - assert result_tag == original_tag - - # Packet should be transformed - assert "new_key" in result_packet - assert "old_key" not in result_packet # old key dropped by default - assert result_packet["new_key"] in ["value1", "value2"] - - def 
test_packet_function_keep_unmapped(self): - """Test packet() function with drop_unmapped=False.""" - tags = [ - {"tag_data": "tag1"}, - ] - packets = [ - {"old_key": "value1", "keep_this": "data1"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - packet_mapper = packet({"old_key": "new_key"}, drop_unmapped=False) - transformed_stream = packet_mapper(stream) - - results = list(transformed_stream) - - assert len(results) == 1 - result_tag, result_packet = results[0] - - # Should have both mapped and unmapped keys - assert result_packet["new_key"] == "value1" - assert result_packet["keep_this"] == "data1" - - def test_tag_function_empty_mapping(self): - """Test tag() function with empty mapping.""" - tags = [ - {"key1": "value1", "key2": "value2"}, - ] - packets = [ - {"data": "packet1"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - tag_mapper = tag({}) # Empty mapping - transformed_stream = tag_mapper(stream) - - results = list(transformed_stream) - - assert len(results) == 1 - result_tag, result_packet = results[0] - - # With empty mapping and drop_unmapped=True (default), all keys should be dropped - assert result_tag == {} - assert result_packet == packets[0] # Packet unchanged - - def test_packet_function_empty_mapping(self): - """Test packet() function with empty mapping.""" - tags = [ - {"tag_data": "tag1"}, - ] - packets = [ - {"key1": "value1", "key2": "value2"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - packet_mapper = packet({}) # Empty mapping - transformed_stream = packet_mapper(stream) - - results = list(transformed_stream) - - assert len(results) == 1 - result_tag, result_packet = results[0] - - # With empty mapping and drop_unmapped=True (default), all keys should be dropped - assert result_tag == tags[0] # Tag unchanged - assert result_packet == {} - - def test_tag_function_chaining(self): - """Test chaining multiple tag() transformations.""" - tags = [ - {"a": "value1", "b": "value2", "c": "value3"}, - ] - packets = [ - {"data": "packet1"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - - # Chain transformations - tag_mapper1 = tag({"a": "new_a"}, drop_unmapped=False) - tag_mapper2 = tag({"b": "new_b"}, drop_unmapped=False) - - transformed_stream = tag_mapper2(tag_mapper1(stream)) - - results = list(transformed_stream) - - assert len(results) == 1 - result_tag, result_packet = results[0] - - # Should have transformations from both mappers - assert result_tag["new_a"] == "value1" - assert result_tag["new_b"] == "value2" - assert result_tag["c"] == "value3" # Unchanged - - def test_packet_function_chaining(self): - """Test chaining multiple packet() transformations.""" - tags = [ - {"tag_data": "tag1"}, - ] - packets = [ - {"a": "value1", "b": "value2", "c": "value3"}, - ] - - stream = SyncStreamFromLists(tags=tags, packets=packets) - - # Chain transformations - packet_mapper1 = packet({"a": "new_a"}, drop_unmapped=False) - packet_mapper2 = packet({"b": "new_b"}, drop_unmapped=False) - - transformed_stream = packet_mapper2(packet_mapper1(stream)) - - results = list(transformed_stream) - - assert len(results) == 1 - result_tag, result_packet = results[0] - - # Should have transformations from both mappers - assert result_packet["new_a"] == "value1" - assert result_packet["new_b"] == "value2" - assert result_packet["c"] == "value3" # Unchanged - - def test_utility_functions_pickle(self): - """Test that utility functions tag() and packet() are pickleable.""" - import pickle - - # Test tag() 
function - tag_mapper = tag({"old_key": "new_key"}) - pickled_tag = pickle.dumps(tag_mapper) - unpickled_tag = pickle.loads(pickled_tag) - - # Test that unpickled tag mapper works - assert callable(unpickled_tag) - - # Test packet() function - packet_mapper = packet({"old_key": "new_key"}) - pickled_packet = pickle.dumps(packet_mapper) - unpickled_packet = pickle.loads(pickled_packet) - - # Test that unpickled packet mapper works - assert callable(unpickled_packet) - - def test_utility_functions_with_complex_streams(self, sample_stream): - """Test utility functions with complex streams from fixtures.""" - # Test tag() with sample stream - tag_mapper = tag({"file_name": "filename"}, drop_unmapped=False) - transformed_stream = tag_mapper(sample_stream) - - results = list(transformed_stream) - - for result_tag, _ in results: - assert "filename" in result_tag - assert result_tag["filename"] in ["day1", "day2", "day3"] - assert "session" in result_tag # Kept because drop_unmapped=False - - # Test packet() with sample stream - packet_mapper = packet({"txt_file": "text_file"}, drop_unmapped=False) - transformed_stream = packet_mapper(sample_stream) - - results = list(transformed_stream) - - for _, result_packet in results: - assert "text_file" in result_packet - assert "data" in result_packet["text_file"] - assert "metadata" in result_packet # Kept because drop_unmapped=False diff --git a/tests/test_streams_operations/test_pipelines/__init__.py b/tests/test_streams_operations/test_pipelines/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py b/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py deleted file mode 100644 index 494fec3..0000000 --- a/tests/test_streams_operations/test_pipelines/test_basic_pipelines.py +++ /dev/null @@ -1,542 +0,0 @@ -""" -Test module for basic pipeline operations. - -This module tests fundamental pipeline construction and execution, -including chaining operations, combining multiple streams, and -basic data flow patterns as demonstrated in the notebooks. 
-""" - -import pytest -import tempfile -from pathlib import Path - -from orcabridge.base import SyncStream -from orcabridge.streams import SyncStreamFromLists -from orcabridge.mappers import ( - Join, - Merge, - Filter, - Transform, - MapPackets, - MapTags, - Repeat, - DefaultTag, - Batch, - FirstMatch, -) -from orcabridge.sources import GlobSource -from orcabridge.pod import FunctionPod - - -@pytest.fixture -def temp_files(): - """Create temporary files for testing.""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create test files - files = {} - for i in range(1, 4): - file_path = temp_path / f"test_{i}.txt" - content = f"Content of file {i}\nLine 2 of file {i}" - with open(file_path, "w") as f: - f.write(content) - files[f"test_{i}.txt"] = file_path - - yield temp_path, files - - -@pytest.fixture -def sample_user_data(): - """Sample user data for pipeline testing.""" - return [ - ({"user_id": 1, "session": "a"}, {"name": "Alice", "age": 25, "score": 85}), - ({"user_id": 2, "session": "a"}, {"name": "Bob", "age": 30, "score": 92}), - ({"user_id": 3, "session": "b"}, {"name": "Charlie", "age": 28, "score": 78}), - ({"user_id": 1, "session": "b"}, {"name": "Alice", "age": 25, "score": 88}), - ] - - -@pytest.fixture -def sample_metadata(): - """Sample metadata for joining.""" - return [ - ({"user_id": 1}, {"department": "Engineering", "level": "Senior"}), - ({"user_id": 2}, {"department": "Marketing", "level": "Junior"}), - ({"user_id": 3}, {"department": "Engineering", "level": "Mid"}), - ] - - -class TestBasicPipelineConstruction: - """Test basic pipeline construction patterns.""" - - def test_simple_linear_pipeline(self, sample_user_data): - """Test simple linear pipeline with chained operations.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Build pipeline: filter -> transform -> map packets - pipeline = ( - source_stream - >> Filter(lambda tag, packet: packet["age"] >= 28) - >> Transform( - lambda tag, packet: (tag, {**packet, "category": "experienced"}) - ) - >> MapPackets({"name": "full_name", "score": "performance"}) - ) - - result = list(pipeline) - - # Should have filtered out users under 28 - assert len(result) == 3 - - # Check transformations applied - for tag, packet in result: - assert packet["age"] >= 28 - assert packet["category"] == "experienced" - assert "full_name" in packet - assert "performance" in packet - assert "name" not in packet # Should be mapped - assert "score" not in packet # Should be mapped - - def test_pipeline_with_join(self, sample_user_data, sample_metadata): - """Test pipeline with join operation.""" - # Create streams - user_tags, user_packets = zip(*sample_user_data) - meta_tags, meta_packets = zip(*sample_metadata) - - user_stream = SyncStreamFromLists(list(user_tags), list(user_packets)) - meta_stream = SyncStreamFromLists(list(meta_tags), list(meta_packets)) - - # Join streams on user_id - joined = Join()(user_stream, meta_stream) - result = list(joined) - - # Should have joined records where user_id matches - assert len(result) >= 2 # At least Alice and Bob should match - - # Check that joined data has both user and metadata info - for tag, packet in result: - assert "user_id" in tag - assert "name" in packet # From user data - assert "department" in packet # From metadata - - def test_pipeline_with_merge(self, sample_user_data): - """Test pipeline with merge operation.""" - tags, packets = zip(*sample_user_data) - - # Split data into two 
streams - stream1 = SyncStreamFromLists(list(tags[:2]), list(packets[:2])) - stream2 = SyncStreamFromLists(list(tags[2:]), list(packets[2:])) - - # Merge streams - merged = Merge()(stream1, stream2) - result = list(merged) - - # Should have all items from both streams - assert len(result) == 4 - - # Order might be different but all data should be present - result_user_ids = [tag["user_id"] for tag, packet in result] - expected_user_ids = [tag["user_id"] for tag, packet in sample_user_data] - assert sorted(result_user_ids) == sorted(expected_user_ids) - - def test_pipeline_with_batch_processing(self, sample_user_data): - """Test pipeline with batch processing.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Create batches of size 2 - batched = Batch(batch_size=2)(source_stream) - result = list(batched) - - # Should have 2 batches (4 items / 2 per batch) - assert len(result) == 2 - - # Each result should be a batch - for tag, packet in result: - assert isinstance(packet, list) - assert len(packet) == 2 - # Tag should be batch representation of individual tags - assert isinstance(tag, dict) - - def test_pipeline_with_repeat_operation(self, sample_user_data): - """Test pipeline with repeat operation.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists( - list(tags[:2]), list(packets[:2]) - ) # Use first 2 items - - # Repeat each item 3 times - repeated = Repeat(repeat_count=3)(source_stream) - result = list(repeated) - - # Should have 6 items total (2 original * 3 repeats) - assert len(result) == 6 - - # Check that items are correctly repeated - assert result[0] == result[1] == result[2] # First item repeated - assert result[3] == result[4] == result[5] # Second item repeated - - def test_complex_multi_stage_pipeline(self, sample_user_data, sample_metadata): - """Test complex pipeline with multiple stages and branches.""" - # Create source streams - user_tags, user_packets = zip(*sample_user_data) - meta_tags, meta_packets = zip(*sample_metadata) - - user_stream = SyncStreamFromLists(list(user_tags), list(user_packets)) - meta_stream = SyncStreamFromLists(list(meta_tags), list(meta_packets)) - - # Complex pipeline: - # 1. Add default tags to user stream - # 2. Join with metadata - # 3. Filter by age and score - # 4. 
Transform and map fields - pipeline = ( - DefaultTag({"source": "user_system"})(user_stream) - * meta_stream # Join operation - >> Filter(lambda tag, packet: packet["age"] >= 25 and packet["score"] >= 80) - >> Transform( - lambda tag, packet: ( - {**tag, "processed": True}, - {**packet, "grade": "A" if packet["score"] >= 90 else "B"}, - ) - ) - >> MapPackets({"name": "employee_name", "department": "dept"}) - ) - - result = list(pipeline) - - # Verify complex transformations - for tag, packet in result: - assert tag["source"] == "user_system" - assert tag["processed"] is True - assert packet["age"] >= 25 - assert packet["score"] >= 80 - assert packet["grade"] in ["A", "B"] - assert "employee_name" in packet - assert "dept" in packet - - def test_pipeline_error_propagation(self, sample_user_data): - """Test that errors propagate correctly through pipeline.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Create pipeline with operation that will fail - def failing_transform(tag, packet): - if packet["age"] > 29: - raise ValueError("Age too high!") - return tag, packet - - pipeline = source_stream >> Transform(failing_transform) - - # Should propagate the error - with pytest.raises(ValueError, match="Age too high!"): - list(pipeline) - - def test_pipeline_with_empty_stream(self): - """Test pipeline behavior with empty streams.""" - empty_stream = SyncStreamFromLists([], []) - - # Apply operations to empty stream - pipeline = ( - empty_stream - >> Filter(lambda tag, packet: True) - >> Transform(lambda tag, packet: (tag, {**packet, "processed": True})) - ) - - result = list(pipeline) - assert result == [] - - def test_pipeline_with_first_match(self, sample_user_data, sample_metadata): - """Test pipeline with FirstMatch operation.""" - user_tags, user_packets = zip(*sample_user_data) - meta_tags, meta_packets = zip(*sample_metadata) - - user_stream = SyncStreamFromLists(list(user_tags), list(user_packets)) - meta_stream = SyncStreamFromLists(list(meta_tags), list(meta_packets)) - - # Use FirstMatch instead of Join - matched = FirstMatch()(user_stream, meta_stream) - result = list(matched) - - # FirstMatch should consume items from both streams - assert len(result) <= len(sample_user_data) - - # Each result should have matched data - for tag, packet in result: - assert "user_id" in tag - assert "name" in packet or "department" in packet - - -class TestPipelineDataFlow: - """Test data flow patterns in pipelines.""" - - def test_data_preservation_through_pipeline(self, sample_user_data): - """Test that data is correctly preserved through transformations.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Track original data - original_user_ids = [tag["user_id"] for tag, packet in sample_user_data] - original_names = [packet["name"] for tag, packet in sample_user_data] - - # Pipeline that shouldn't lose data - pipeline = ( - source_stream - >> MapTags({"user_id": "id"}) # Rename tag field - >> MapPackets({"name": "username"}) # Rename packet field - ) - - result = list(pipeline) - - # Check data preservation - result_ids = [tag["id"] for tag, packet in result] - result_names = [packet["username"] for tag, packet in result] - - assert sorted(result_ids) == sorted(original_user_ids) - assert sorted(result_names) == sorted(original_names) - - def test_data_aggregation_pipeline(self, sample_user_data): - """Test pipeline that aggregates data.""" - tags, packets = 
zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Aggregate by session - def aggregate_by_session(tag, packet): - return {"session": tag["session"]}, { - "users": [packet["name"]], - "avg_score": packet["score"], - "count": 1, - } - - # Transform and then batch by session (simplified aggregation) - pipeline = source_stream >> Transform(aggregate_by_session) - - result = list(pipeline) - - # Should have transformed all items - assert len(result) == len(sample_user_data) - - # Check session-based grouping - sessions = [tag["session"] for tag, packet in result] - assert "a" in sessions - assert "b" in sessions - - def test_conditional_processing_pipeline(self, sample_user_data): - """Test pipeline with conditional processing branches.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Split into high and low performers - high_performers = ( - source_stream - >> Filter(lambda tag, packet: packet["score"] >= 85) - >> Transform( - lambda tag, packet: ( - {**tag, "category": "high"}, - {**packet, "bonus": packet["score"] * 0.1}, - ) - ) - ) - - low_performers = ( - source_stream - >> Filter(lambda tag, packet: packet["score"] < 85) - >> Transform( - lambda tag, packet: ( - {**tag, "category": "low"}, - {**packet, "training": True}, - ) - ) - ) - - # Merge results - combined = Merge()(high_performers, low_performers) - result = list(combined) - - # Check that all items are categorized - categories = [tag["category"] for tag, packet in result] - assert "high" in categories - assert "low" in categories - - # Check conditional processing - for tag, packet in result: - if tag["category"] == "high": - assert "bonus" in packet - assert packet["score"] >= 85 - else: - assert "training" in packet - assert packet["score"] < 85 - - -class TestPipelineWithSources: - """Test pipelines starting from sources.""" - - def test_pipeline_from_glob_source(self, temp_files): - """Test pipeline starting from GlobSource.""" - temp_dir, files = temp_files - - # Create source - source = GlobSource(str(temp_dir / "*.txt")) - - # Build pipeline - pipeline = ( - source - >> Transform( - lambda tag, packet: ( - {**tag, "processed": True}, - {**packet, "line_count": len(packet["content"].split("\n"))}, - ) - ) - >> Filter(lambda tag, packet: packet["line_count"] >= 2) - ) - - result = list(pipeline) - - # Should have all files (each has 2 lines) - assert len(result) == 3 - - # Check processing - for tag, packet in result: - assert tag["processed"] is True - assert packet["line_count"] == 2 - assert "path" in tag - - def test_pipeline_with_function_pod(self, sample_user_data): - """Test pipeline with FunctionPod processing.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Create processing function - def enrich_user_data(tag, packet): - """Add computed fields to user data.""" - return tag, { - **packet, - "age_group": "young" if packet["age"] < 30 else "mature", - "performance": "excellent" if packet["score"] >= 90 else "good", - } - - # Create pod - processor = FunctionPod(enrich_user_data) - - # Build pipeline - pipeline = ( - source_stream - >> processor - >> Filter(lambda tag, packet: packet["performance"] == "excellent") - ) - - result = list(pipeline) - - # Check processing - for tag, packet in result: - assert packet["performance"] == "excellent" - assert packet["age_group"] in ["young", "mature"] - assert packet["score"] >= 90 - - -class 
TestPipelineOptimization: - """Test pipeline optimization and efficiency.""" - - def test_pipeline_lazy_evaluation(self, sample_user_data): - """Test that pipeline operations are lazily evaluated.""" - call_log = [] - - def logging_transform(tag, packet): - call_log.append(f"processing_{tag['user_id']}") - return tag, packet - - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Build pipeline but don't execute - pipeline = ( - source_stream - >> Transform(logging_transform) - >> Filter(lambda tag, packet: packet["age"] >= 28) - ) - - # No processing should have happened yet - assert call_log == [] - - # Start consuming pipeline - iterator = iter(pipeline) - next(iterator) - - # Now some processing should have happened - assert len(call_log) >= 1 - - def test_pipeline_memory_efficiency(self): - """Test pipeline memory efficiency with large data.""" - - def large_data_generator(): - for i in range(1000): - yield ({"id": i}, {"value": i * 2, "data": f"item_{i}"}) - - # Create pipeline that processes large stream - from orcabridge.stream import SyncStreamFromGenerator - - source = SyncStreamFromGenerator(large_data_generator) - pipeline = ( - source - >> Filter(lambda tag, packet: tag["id"] % 10 == 0) # Keep every 10th item - >> Transform(lambda tag, packet: (tag, {**packet, "filtered": True})) - ) - - # Process in chunks - count = 0 - for tag, packet in pipeline: - assert packet["filtered"] is True - assert tag["id"] % 10 == 0 - count += 1 - - if count >= 10: # Don't process all items - break - - assert count == 10 - - def test_pipeline_error_recovery(self, sample_user_data): - """Test pipeline behavior with partial errors.""" - tags, packets = zip(*sample_user_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - def sometimes_failing_transform(tag, packet): - if packet["name"] == "Bob": # Fail for Bob - raise ValueError("Bob processing failed") - return tag, {**packet, "processed": True} - - # This pipeline will fail partway through - pipeline = source_stream >> Transform(sometimes_failing_transform) - - # Should fail when reaching Bob - with pytest.raises(ValueError, match="Bob processing failed"): - list(pipeline) - - def test_pipeline_reusability(self, sample_user_data): - """Test that pipeline components can be reused.""" - # Create reusable operations - age_filter = Filter(lambda tag, packet: packet["age"] >= 28) - score_transform = Transform( - lambda tag, packet: ( - tag, - {**packet, "grade": "A" if packet["score"] >= 90 else "B"}, - ) - ) - - tags, packets = zip(*sample_user_data) - stream1 = SyncStreamFromLists(list(tags[:2]), list(packets[:2])) - stream2 = SyncStreamFromLists(list(tags[2:]), list(packets[2:])) - - # Apply same operations to different streams - pipeline1 = stream1 >> age_filter >> score_transform - pipeline2 = stream2 >> age_filter >> score_transform - - result1 = list(pipeline1) - result2 = list(pipeline2) - - # Both should work independently - for tag, packet in result1 + result2: - if len([tag, packet]) > 0: # If any results - assert packet["age"] >= 28 - assert packet["grade"] in ["A", "B"] diff --git a/tests/test_streams_operations/test_pipelines/test_recursive_features.py b/tests/test_streams_operations/test_pipelines/test_recursive_features.py deleted file mode 100644 index 89a2646..0000000 --- a/tests/test_streams_operations/test_pipelines/test_recursive_features.py +++ /dev/null @@ -1,637 +0,0 @@ -""" -Test module for recursive features and advanced pipeline patterns. 
- -This module tests advanced orcabridge features including recursive stream -operations, label chaining, length operations, source invocation patterns, -and complex pipeline compositions as demonstrated in the notebooks. -""" - -import pytest -import tempfile -from pathlib import Path -from unittest.mock import Mock, patch - -from orcabridge.base import SyncStream, Operation -from orcabridge.streams import SyncStreamFromLists, SyncStreamFromGenerator -from orcabridge.mappers import ( - Join, - Merge, - Filter, - Transform, - MapPackets, - MapTags, - Repeat, - DefaultTag, - Batch, - CacheStream, -) -from orcabridge.sources import GlobSource -from orcabridge.pod import FunctionPod - - -@pytest.fixture -def hierarchical_data(): - """Hierarchical data for testing recursive operations.""" - return [ - ( - {"level": 1, "parent": None, "id": "root"}, - {"name": "Root", "children": ["a", "b"]}, - ), - ( - {"level": 2, "parent": "root", "id": "a"}, - {"name": "Node A", "children": ["a1", "a2"]}, - ), - ( - {"level": 2, "parent": "root", "id": "b"}, - {"name": "Node B", "children": ["b1"]}, - ), - ({"level": 3, "parent": "a", "id": "a1"}, {"name": "Leaf A1", "children": []}), - ({"level": 3, "parent": "a", "id": "a2"}, {"name": "Leaf A2", "children": []}), - ({"level": 3, "parent": "b", "id": "b1"}, {"name": "Leaf B1", "children": []}), - ] - - -@pytest.fixture -def temp_nested_files(): - """Create nested file structure for testing.""" - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create nested directory structure - (temp_path / "level1").mkdir() - (temp_path / "level1" / "level2").mkdir() - - files = {} - - # Root level files - for i in range(3): - file_path = temp_path / f"root_{i}.txt" - with open(file_path, "w") as f: - f.write(f"Root file {i}") - files[f"root_{i}"] = file_path - - # Level 1 files - for i in range(2): - file_path = temp_path / "level1" / f"l1_{i}.txt" - with open(file_path, "w") as f: - f.write(f"Level 1 file {i}") - files[f"l1_{i}"] = file_path - - # Level 2 files - file_path = temp_path / "level1" / "level2" / "l2_0.txt" - with open(file_path, "w") as f: - f.write("Level 2 file") - files["l2_0"] = file_path - - yield temp_path, files - - -class TestRecursiveStreamOperations: - """Test recursive and self-referential stream operations.""" - - def test_recursive_stream_processing(self, hierarchical_data): - """Test recursive processing of hierarchical data.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - def process_level(stream, max_level=3): - """Recursively process each level.""" - - def level_processor(tag, packet): - level = tag["level"] - if level < max_level: - # Add processing marker - return tag, {**packet, f"processed_level_{level}": True} - else: - # Leaf nodes get different processing - return tag, {**packet, "is_leaf": True} - - return Transform(level_processor)(stream) - - # Apply recursive processing - processed = process_level(source_stream) - result = list(processed) - - # Check that different levels are processed differently - for tag, packet in result: - level = tag["level"] - if level < 3: - assert f"processed_level_{level}" in packet - else: - assert packet["is_leaf"] is True - - def test_recursive_stream_expansion(self, hierarchical_data): - """Test recursive expansion of stream data.""" - # Start with root nodes only - root_data = [item for item in hierarchical_data if item[0]["parent"] is None] - tags, packets = zip(*root_data) - root_stream = 
SyncStreamFromLists(list(tags), list(packets)) - - def expand_children(tag, packet): - """Generate child nodes for each parent.""" - children = packet.get("children", []) - for child_id in children: - # Find child data from hierarchical_data - for h_tag, h_packet in hierarchical_data: - if h_tag["id"] == child_id: - yield h_tag, h_packet - break - - # Create expanding pod - expander = FunctionPod(expand_children) - expanded = expander(root_stream) - result = list(expanded) - - # Should have expanded to include all children - assert len(result) >= 2 # At least the immediate children - - def test_recursive_filtering_cascade(self, hierarchical_data): - """Test recursive filtering that cascades through levels.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Create a cascade of filters for each level - level_1_filter = Filter(lambda tag, packet: tag["level"] == 1) - level_2_filter = Filter(lambda tag, packet: tag["level"] <= 2) - level_3_filter = Filter(lambda tag, packet: tag["level"] <= 3) - - # Apply filters recursively - def recursive_filter(stream, current_level=1): - if current_level == 1: - filtered = level_1_filter(stream) - elif current_level == 2: - filtered = level_2_filter(stream) - else: - filtered = level_3_filter(stream) - - return filtered - - # Test each level - level_1_result = list(recursive_filter(source_stream, 1)) - level_2_result = list(recursive_filter(source_stream, 2)) - level_3_result = list(recursive_filter(source_stream, 3)) - - assert len(level_1_result) == 1 # Only root - assert len(level_2_result) == 3 # Root + level 2 nodes - assert len(level_3_result) == 6 # All nodes - - def test_self_referential_stream_operations(self, hierarchical_data): - """Test operations that reference the stream itself.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Cache the stream for self-reference - cache = CacheStream() - cached_stream = cache(source_stream) - - # Consume the cache - list(cached_stream) - - # Now create operations that reference the cached data - def find_parent_info(tag, packet): - parent_id = tag.get("parent") - if parent_id: - # Look up parent in cached stream - for cached_tag, cached_packet in cache.cache: - if cached_tag["id"] == parent_id: - return tag, { - **packet, - "parent_name": cached_packet["name"], - "parent_level": cached_tag["level"], - } - return tag, {**packet, "parent_name": None, "parent_level": None} - - # Apply parent lookup - enriched = Transform(find_parent_info)(cached_stream) - result = list(enriched) - - # Check parent information was added - for tag, packet in result: - if tag["parent"] is not None: - assert packet["parent_name"] is not None - assert packet["parent_level"] is not None - - -class TestLabelAndLengthOperations: - """Test label manipulation and length operations.""" - - def test_label_chaining_operations(self, hierarchical_data): - """Test chaining operations with label tracking.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists( - list(tags), list(packets), label="hierarchical_source" - ) - - # Create labeled operations - filter_op = Filter(lambda tag, packet: tag["level"] <= 2) - transform_op = Transform( - lambda tag, packet: (tag, {**packet, "processed": True}) - ) - - # Apply operations and track labels - filtered = filter_op(source_stream) - assert filtered.label.startswith("Filter_") - - transformed = transform_op(filtered) - assert 
transformed.label.startswith("Transform_") - - # Check that invocation chain is maintained - result = list(transformed) - assert len(result) == 3 # Root + 2 level-2 nodes - - def test_stream_length_operations(self, hierarchical_data): - """Test operations that depend on stream length.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - def length_dependent_transform(tag, packet): - # This would need to know stream length - # For simulation, we'll use a mock length - stream_length = 6 # Known length of hierarchical_data - return tag, { - **packet, - "relative_position": tag["level"] / 3, # Relative to max level - "is_majority_level": tag["level"] == 3, # Most nodes are level 3 - } - - processed = Transform(length_dependent_transform)(source_stream) - result = list(processed) - - # Check length-dependent calculations - for tag, packet in result: - assert "relative_position" in packet - assert "is_majority_level" in packet - if tag["level"] == 3: - assert packet["is_majority_level"] is True - - def test_dynamic_label_generation(self, hierarchical_data): - """Test dynamic label generation based on stream content.""" - tags, packets = zip(*hierarchical_data) - - # Create streams with content-based labels - def create_labeled_stream(data, label_func): - stream_tags, stream_packets = zip(*data) - label = label_func(data) - return SyncStreamFromLists( - list(stream_tags), list(stream_packets), label=label - ) - - # Different labeling strategies - level_1_data = [item for item in hierarchical_data if item[0]["level"] == 1] - level_2_data = [item for item in hierarchical_data if item[0]["level"] == 2] - level_3_data = [item for item in hierarchical_data if item[0]["level"] == 3] - - stream_1 = create_labeled_stream( - level_1_data, lambda data: f"level_1_stream_{len(data)}_items" - ) - stream_2 = create_labeled_stream( - level_2_data, lambda data: f"level_2_stream_{len(data)}_items" - ) - stream_3 = create_labeled_stream( - level_3_data, lambda data: f"level_3_stream_{len(data)}_items" - ) - - assert stream_1.label == "level_1_stream_1_items" - assert stream_2.label == "level_2_stream_2_items" - assert stream_3.label == "level_3_stream_3_items" - - -class TestSourceInvocationPatterns: - """Test advanced source invocation and composition patterns.""" - - def test_multiple_source_composition(self, temp_nested_files): - """Test composing multiple sources with different patterns.""" - temp_path, files = temp_nested_files - - # Create different sources for different levels - root_source = GlobSource(str(temp_path / "*.txt"), label="root_files") - level1_source = GlobSource( - str(temp_path / "level1" / "*.txt"), label="level1_files" - ) - level2_source = GlobSource( - str(temp_path / "level1" / "level2" / "*.txt"), label="level2_files" - ) - - # Compose sources - all_sources = Merge()(root_source, level1_source, level2_source) - result = list(all_sources) - - # Should have files from all levels - assert len(result) >= 6 # 3 root + 2 level1 + 1 level2 - - # Check that files from different levels are included - paths = [tag["path"] for tag, packet in result] - assert any("root_" in str(path) for path in paths) - assert any("l1_" in str(path) for path in paths) - assert any("l2_" in str(path) for path in paths) - - def test_conditional_source_invocation(self, temp_nested_files): - """Test conditional source invocation based on data content.""" - temp_path, files = temp_nested_files - - def conditional_source_factory(condition): - """Create source 
based on condition.""" - if condition == "root": - return GlobSource(str(temp_path / "*.txt")) - elif condition == "nested": - return GlobSource(str(temp_path / "**" / "*.txt")) - else: - return SyncStreamFromLists([], []) # Empty stream - - # Test different conditions - root_stream = conditional_source_factory("root") - nested_stream = conditional_source_factory("nested") - empty_stream = conditional_source_factory("other") - - root_result = list(root_stream) - nested_result = list(nested_stream) - empty_result = list(empty_stream) - - assert len(root_result) == 3 # Only root files - assert len(nested_result) >= 6 # All files recursively - assert len(empty_result) == 0 - - def test_recursive_source_generation(self, temp_nested_files): - """Test recursive generation of sources.""" - temp_path, files = temp_nested_files - - def recursive_file_processor(tag, packet): - """Process file and potentially generate more sources.""" - file_path = Path(tag["path"]) - - # If this is a directory-like file, yield info about subdirectories - if "level1" in str(file_path.parent): - # This file is in level1, so it knows about level2 - yield tag, {**packet, "has_subdirs": True, "subdir_count": 1} - else: - yield tag, {**packet, "has_subdirs": False, "subdir_count": 0} - - # Start with root source - root_source = GlobSource(str(temp_path / "*.txt")) - - # Apply recursive processing - processor = FunctionPod(recursive_file_processor) - processed = processor(root_source) - result = list(processed) - - # Check recursive information - for tag, packet in result: - assert "has_subdirs" in packet - assert "subdir_count" in packet - - def test_source_caching_and_reuse(self, temp_nested_files): - """Test caching and reusing source results.""" - temp_path, files = temp_nested_files - - # Create cached source - source = GlobSource(str(temp_path / "*.txt")) - cache = CacheStream() - cached_source = cache(source) - - # First consumption - result1 = list(cached_source) - - # Verify caching worked - assert cache.is_cached - assert len(cache.cache) == 3 - - # Create new operations using cached source - filter_op = Filter(lambda tag, packet: "root_1" in str(tag["path"])) - transform_op = Transform(lambda tag, packet: (tag, {**packet, "reused": True})) - - # Apply operations to cached source - filtered = filter_op(cache()) # Use cached version - transformed = transform_op(cache()) # Use cached version again - - filter_result = list(filtered) - transform_result = list(transformed) - - # Both should work independently using cached data - assert len(filter_result) == 1 # Only root_1 file - assert len(transform_result) == 3 # All files with reused flag - - for tag, packet in transform_result: - assert packet["reused"] is True - - -class TestComplexPipelinePatterns: - """Test complex pipeline patterns and compositions.""" - - def test_branching_and_merging_pipeline(self, hierarchical_data): - """Test pipeline that branches and merges back together.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Create branches for different processing paths - branch_a = ( - source_stream - >> Filter(lambda tag, packet: tag["level"] <= 2) - >> Transform( - lambda tag, packet: (tag, {**packet, "branch": "A", "priority": "high"}) - ) - ) - - branch_b = ( - source_stream - >> Filter(lambda tag, packet: tag["level"] == 3) - >> Transform( - lambda tag, packet: (tag, {**packet, "branch": "B", "priority": "low"}) - ) - ) - - # Merge branches back together - merged = Merge()(branch_a, 
branch_b) - result = list(merged) - - # Should have all original items but with branch processing - assert len(result) == 6 - - # Check branch assignments - branches = [packet["branch"] for tag, packet in result] - assert "A" in branches - assert "B" in branches - - def test_multi_level_pipeline_composition(self, hierarchical_data): - """Test multi-level pipeline composition with nested operations.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Level 1: Basic filtering and transformation - level1_pipeline = ( - source_stream - >> Filter(lambda tag, packet: len(packet["name"]) > 5) - >> Transform( - lambda tag, packet: (tag, {**packet, "level1_processed": True}) - ) - ) - - # Level 2: Advanced processing based on level 1 - level2_pipeline = ( - level1_pipeline - >> MapTags({"level": "hierarchy_level", "id": "node_id"}) - >> MapPackets({"name": "node_name", "children": "child_nodes"}) - ) - - # Level 3: Final aggregation and summary - level3_pipeline = level2_pipeline >> Transform( - lambda tag, packet: ( - tag, - { - **packet, - "final_processed": True, - "child_count": len(packet["child_nodes"]), - "has_children": len(packet["child_nodes"]) > 0, - }, - ) - ) - - result = list(level3_pipeline) - - # Check multi-level processing - for tag, packet in result: - assert packet["level1_processed"] is True - assert packet["final_processed"] is True - assert "hierarchy_level" in tag - assert "node_name" in packet - assert "child_count" in packet - - def test_pipeline_with_feedback_loop(self, hierarchical_data): - """Test pipeline pattern that simulates feedback loops.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - # Create a cache to simulate feedback - feedback_cache = CacheStream() - - # First pass: process and cache - first_pass = ( - source_stream - >> Transform(lambda tag, packet: (tag, {**packet, "pass": 1})) - >> feedback_cache - ) - - # Consume first pass to populate cache - first_result = list(first_pass) - - # Second pass: use cached data for enrichment - def enrich_with_feedback(tag, packet): - # Use cached data to enrich current item - related_items = [] - for cached_tag, cached_packet in feedback_cache.cache: - if ( - cached_tag["level"] == tag["level"] - and cached_tag["id"] != tag["id"] - ): - related_items.append(cached_packet["name"]) - - return tag, { - **packet, - "pass": 2, - "related_items": related_items, - "relation_count": len(related_items), - } - - second_pass = Transform(enrich_with_feedback)(feedback_cache()) - second_result = list(second_pass) - - # Check feedback enrichment - for tag, packet in second_result: - assert packet["pass"] == 2 - assert "related_items" in packet - assert "relation_count" in packet - - def test_pipeline_error_handling_and_recovery(self, hierarchical_data): - """Test pipeline error handling and recovery patterns.""" - tags, packets = zip(*hierarchical_data) - source_stream = SyncStreamFromLists(list(tags), list(packets)) - - def potentially_failing_operation(tag, packet): - # Fail on specific condition - if tag["id"] == "a1": # Fail on specific node - raise ValueError("Processing failed for a1") - return tag, {**packet, "processed": True} - - # Create error-tolerant pipeline - def error_tolerant_transform(tag, packet): - try: - return potentially_failing_operation(tag, packet) - except ValueError: - # Recovery: mark as failed but continue - return tag, {**packet, "processed": False, "error": True} - - pipeline 
= Transform(error_tolerant_transform)(source_stream) - result = list(pipeline) - - # Should have processed all items despite error - assert len(result) == 6 - - # Check error handling - failed_items = [ - item for tag, packet in result for item in [packet] if packet.get("error") - ] - successful_items = [ - item - for tag, packet in result - for item in [packet] - if packet.get("processed") - ] - - assert len(failed_items) == 1 # One failed item - assert len(successful_items) == 5 # Five successful items - - def test_dynamic_pipeline_construction(self, hierarchical_data): - """Test dynamic construction of pipelines based on data characteristics.""" - tags, packets = zip(*hierarchical_data) - - def build_dynamic_pipeline(data): - """Build pipeline based on data characteristics.""" - # Analyze data - levels = set(tag["level"] for tag, packet in data) - max_level = max(levels) - has_children = any(len(packet["children"]) > 0 for tag, packet in data) - - # Build pipeline dynamically - base_stream = SyncStreamFromLists( - [tag for tag, packet in data], [packet for tag, packet in data] - ) - - operations = [base_stream] - - # Add level-specific processing - if max_level > 2: - operations.append( - Transform( - lambda tag, packet: (tag, {**packet, "is_deep_hierarchy": True}) - ) - ) - - # Add child processing if needed - if has_children: - operations.append( - Transform( - lambda tag, packet: ( - tag, - { - **packet, - "child_info": f"has_{len(packet['children'])}_children", - }, - ) - ) - ) - - # Chain operations - pipeline = operations[0] - for op in operations[1:]: - if isinstance(op, Transform): - pipeline = op(pipeline) - - return pipeline - - # Build and execute dynamic pipeline - dynamic_pipeline = build_dynamic_pipeline(hierarchical_data) - result = list(dynamic_pipeline) - - # Check dynamic processing - for tag, packet in result: - assert "is_deep_hierarchy" in packet # Should be added due to max_level > 2 - assert "child_info" in packet # Should be added due to has_children diff --git a/tests/test_streams_operations/test_pods/__init__.py b/tests/test_streams_operations/test_pods/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_streams_operations/test_pods/test_function_pod.py b/tests/test_streams_operations/test_pods/test_function_pod.py deleted file mode 100644 index b7171f1..0000000 --- a/tests/test_streams_operations/test_pods/test_function_pod.py +++ /dev/null @@ -1,305 +0,0 @@ -"""Tests for FunctionPod functionality.""" - -import pytest -from orcabridge.pod import FunctionPod -from orcabridge.streams import SyncStreamFromLists - - -class TestFunctionPod: - """Test cases for FunctionPod.""" - - def test_function_pod_no_output(self, sample_stream, func_no_output): - """Test function pod with function that has no output.""" - pod = FunctionPod(func_no_output) - result_stream = pod(sample_stream) - - result = list(result_stream) - - # Should produce no output - assert len(result) == 0 - - def test_function_pod_single_output(self, sample_stream, func_single_output): - """Test function pod with function that has single output.""" - pod = FunctionPod(func_single_output) - result_stream = pod(sample_stream) - - result = list(result_stream) - - # Should produce one output per input - original_packets = list(sample_stream) - assert len(result) == len(original_packets) - - for i, (packet, tag) in enumerate(result): - expected_packet = f"processed_{original_packets[i][0]}" - assert packet == expected_packet - - def test_function_pod_multiple_outputs(self, 
sample_stream, func_multiple_outputs): - """Test function pod with function that has multiple outputs.""" - pod = FunctionPod(func_multiple_outputs) - result_stream = pod(sample_stream) - - result = list(result_stream) - - # Should produce two outputs per input - original_packets = list(sample_stream) - assert len(result) == len(original_packets) * 2 - - # Check that we get pairs of outputs - for i in range(0, len(result), 2): - original_idx = i // 2 - original_packet = original_packets[original_idx][0] - - # First output should be the packet itself - assert result[i][0] == original_packet - # Second output should be uppercased - assert result[i + 1][0] == str(original_packet).upper() - - def test_function_pod_error_function(self, sample_stream, func_with_error): - """Test function pod with function that raises error.""" - pod = FunctionPod(func_with_error) - result_stream = pod(sample_stream) - - # Should raise error when processing - with pytest.raises(ValueError, match="Function error"): - list(result_stream) - - def test_function_pod_with_datastore(self, func_single_output, data_store): - """Test function pod with datastore integration.""" - - # Create a function that uses the datastore - def datastore_function(inputs, datastore): - packet, tag = inputs[0] - # Store and retrieve from datastore - datastore["processed_count"] = datastore.get("processed_count", 0) + 1 - return f"item_{datastore['processed_count']}_{packet}" - - pod = FunctionPod(datastore_function, datastore=data_store) - - packets = ["a", "b", "c"] - tags = ["tag1", "tag2", "tag3"] - stream = SyncStreamFromLists(packets, tags) - - result_stream = pod(stream) - result = list(result_stream) - - # Should use datastore to track processing - expected = [("item_1_a", "tag1"), ("item_2_b", "tag2"), ("item_3_c", "tag3")] - assert result == expected - assert data_store["processed_count"] == 3 - - def test_function_pod_different_input_counts(self): - """Test function pod with functions expecting different input counts.""" - - # Function expecting 1 input - def single_input_func(inputs): - packet, tag = inputs[0] - return f"single_{packet}" - - # Function expecting 2 inputs - def double_input_func(inputs): - if len(inputs) < 2: - return None # Not enough inputs - packet1, tag1 = inputs[0] - packet2, tag2 = inputs[1] - return f"combined_{packet1}_{packet2}" - - packets = ["a", "b", "c", "d"] - tags = ["t1", "t2", "t3", "t4"] - stream = SyncStreamFromLists(packets, tags) - - # Test single input function - pod1 = FunctionPod(single_input_func) - result1 = list(pod1(stream)) - - assert len(result1) == 4 - assert result1[0][0] == "single_a" - assert result1[1][0] == "single_b" - - # Test double input function (if supported) - # This behavior depends on FunctionPod implementation - try: - pod2 = FunctionPod(double_input_func, input_count=2) - stream2 = SyncStreamFromLists(packets, tags) - result2 = list(pod2(stream2)) - - # Should produce fewer outputs since it needs 2 inputs per call - assert len(result2) <= len(packets) - - except (TypeError, AttributeError): - # FunctionPod might not support configurable input counts - pass - - def test_function_pod_with_none_outputs(self, sample_stream): - """Test function pod with function that sometimes returns None.""" - - def conditional_function(inputs): - packet, tag = inputs[0] - # Only process strings - if isinstance(packet, str): - return f"processed_{packet}" - return None # Skip non-strings - - # Mix of string and non-string packets - packets = ["hello", 42, "world", None, "test"] - tags = 
["str1", "int1", "str2", "null1", "str3"] - stream = SyncStreamFromLists(packets, tags) - - pod = FunctionPod(conditional_function) - result_stream = pod(stream) - result = list(result_stream) - - # Should only process string packets - string_packets = [p for p in packets if isinstance(p, str)] - assert len(result) == len(string_packets) - - for packet, _ in result: - assert packet.startswith("processed_") - - def test_function_pod_stateful_function(self, data_store): - """Test function pod with stateful function using datastore.""" - - def stateful_function(inputs, datastore): - packet, tag = inputs[0] - - # Keep running total - if "total" not in datastore: - datastore["total"] = 0 - if "count" not in datastore: - datastore["count"] = 0 - - if isinstance(packet, (int, float)): - datastore["total"] += packet - datastore["count"] += 1 - avg = datastore["total"] / datastore["count"] - return f"avg_so_far_{avg:.2f}" - - return None - - packets = [10, 20, 30, 40] - tags = ["n1", "n2", "n3", "n4"] - stream = SyncStreamFromLists(packets, tags) - - pod = FunctionPod(stateful_function, datastore=data_store) - result_stream = pod(stream) - result = list(result_stream) - - # Should produce running averages - assert len(result) == 4 - assert result[0][0] == "avg_so_far_10.00" # 10/1 - assert result[1][0] == "avg_so_far_15.00" # (10+20)/2 - assert result[2][0] == "avg_so_far_20.00" # (10+20+30)/3 - assert result[3][0] == "avg_so_far_25.00" # (10+20+30+40)/4 - - def test_function_pod_generator_output(self, sample_stream): - """Test function pod with function that yields multiple outputs.""" - - def generator_function(inputs): - packet, tag = inputs[0] - # Yield multiple outputs for each input - for i in range(3): - yield f"{packet}_part_{i}" - - pod = FunctionPod(generator_function) - result_stream = pod(sample_stream) - result = list(result_stream) - - # Should produce 3 outputs per input - original_packets = list(sample_stream) - assert len(result) == len(original_packets) * 3 - - # Check pattern of outputs - for i, (packet, tag) in enumerate(result): - original_idx = i // 3 - part_idx = i % 3 - original_packet = original_packets[original_idx][0] - expected_packet = f"{original_packet}_part_{part_idx}" - assert packet == expected_packet - - def test_function_pod_complex_data_transformation(self): - """Test function pod with complex data transformation.""" - - def json_processor(inputs): - packet, tag = inputs[0] - - if isinstance(packet, dict): - # Extract all values and create separate outputs - for key, value in packet.items(): - yield f"{key}={value}" - else: - yield f"non_dict_{packet}" - - packets = [ - {"name": "Alice", "age": 30}, - "simple_string", - {"x": 1, "y": 2, "z": 3}, - ] - tags = ["person", "text", "coordinates"] - stream = SyncStreamFromLists(packets, tags) - - pod = FunctionPod(json_processor) - result_stream = pod(stream) - result = list(result_stream) - - # Should extract dict entries - result_packets = [packet for packet, _ in result] - - assert "name=Alice" in result_packets - assert "age=30" in result_packets - assert "non_dict_simple_string" in result_packets - assert "x=1" in result_packets - assert "y=2" in result_packets - assert "z=3" in result_packets - - def test_function_pod_empty_stream(self, func_single_output): - """Test function pod with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - pod = FunctionPod(func_single_output) - result_stream = pod(empty_stream) - - result = list(result_stream) - assert len(result) == 0 - - def 
test_function_pod_large_stream(self, func_single_output): - """Test function pod with large stream.""" - packets = [f"packet_{i}" for i in range(1000)] - tags = [f"tag_{i}" for i in range(1000)] - stream = SyncStreamFromLists(packets, tags) - - pod = FunctionPod(func_single_output) - result_stream = pod(stream) - - # Process stream lazily to test memory efficiency - count = 0 - for packet, tag in result_stream: - count += 1 - if count == 100: # Stop early - break - - assert count == 100 - - def test_function_pod_chaining(self, func_single_output): - """Test chaining function pods.""" - - def second_processor(inputs): - packet, tag = inputs[0] - return f"second_{packet}" - - packets = ["a", "b", "c"] - tags = ["t1", "t2", "t3"] - stream = SyncStreamFromLists(packets, tags) - - # Chain two function pods - pod1 = FunctionPod(func_single_output) - pod2 = FunctionPod(second_processor) - - intermediate_stream = pod1(stream) - final_stream = pod2(intermediate_stream) - result = list(final_stream) - - # Should apply both transformations - expected = [ - ("second_processed_a", "t1"), - ("second_processed_b", "t2"), - ("second_processed_c", "t3"), - ] - assert result == expected diff --git a/tests/test_streams_operations/test_pods/test_function_pod_datastore.py b/tests/test_streams_operations/test_pods/test_function_pod_datastore.py deleted file mode 100644 index e3a2fa4..0000000 --- a/tests/test_streams_operations/test_pods/test_function_pod_datastore.py +++ /dev/null @@ -1,403 +0,0 @@ -""" -Test module for FunctionPod datastore integration. - -This module tests FunctionPod functionality when working with datastore operations, -including storage, retrieval, and state management across pod invocations. -""" - -import pytest -import tempfile -import os -from pathlib import Path - -from orcabridge.pod import FunctionPod -from orcabridge.stream import SyncStreamFromLists - - -@pytest.fixture -def temp_datastore(): - """Create a temporary datastore directory.""" - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def sample_stream_data(): - """Sample stream data for testing.""" - return [ - ({"file_id": 1}, {"content": "Hello World", "metadata": {"type": "text"}}), - ({"file_id": 2}, {"content": "Python Code", "metadata": {"type": "code"}}), - ( - {"file_id": 3}, - {"content": "Data Analysis", "metadata": {"type": "analysis"}}, - ), - ] - - -@pytest.fixture -def sample_stream(sample_stream_data): - """Create a sample stream.""" - tags, packets = zip(*sample_stream_data) - return SyncStreamFromLists(list(tags), list(packets)) - - -class TestFunctionPodDatastore: - """Test cases for FunctionPod datastore integration.""" - - def test_datastore_saving_function(self, temp_datastore, sample_stream): - """Test FunctionPod with function that saves data to datastore.""" - - def save_to_datastore(tag, packet, datastore): - """Save packet content to datastore.""" - file_id = tag["file_id"] - content = packet["content"] - - # Create file path - file_path = datastore / f"file_{file_id}.txt" - - # Save content - with open(file_path, "w") as f: - f.write(content) - - # Return tag and packet with file path - return tag, {**packet, "saved_path": str(file_path)} - - # Create pod with datastore - pod = FunctionPod(save_to_datastore, datastore=temp_datastore) - - # Process stream - result_stream = pod(sample_stream) - result = list(result_stream) - - # Check results - assert len(result) == 3 - - # Verify files were created - for i, (tag, packet) in enumerate(result, 1): - 
expected_path = temp_datastore / f"file_{i}.txt" - assert expected_path.exists() - - # Verify content - with open(expected_path, "r") as f: - saved_content = f.read() - - original_content = sample_stream_data[i - 1][1]["content"] - assert saved_content == original_content - - # Verify packet contains path - assert "saved_path" in packet - assert packet["saved_path"] == str(expected_path) - - def test_datastore_loading_function(self, temp_datastore): - """Test FunctionPod with function that loads data from datastore.""" - - # First, create some test files - test_files = { - "file1.txt": "Content of file 1", - "file2.txt": "Content of file 2", - "file3.txt": "Content of file 3", - } - - for filename, content in test_files.items(): - file_path = temp_datastore / filename - with open(file_path, "w") as f: - f.write(content) - - def load_from_datastore(tag, packet, datastore): - """Load content from datastore based on filename in packet.""" - filename = packet["filename"] - file_path = datastore / filename - - if file_path.exists(): - with open(file_path, "r") as f: - content = f.read() - return tag, {**packet, "content": content, "loaded": True} - else: - return tag, {**packet, "content": None, "loaded": False} - - # Create input stream with filenames - tags = [{"request_id": i} for i in range(1, 4)] - packets = [{"filename": f"file{i}.txt"} for i in range(1, 4)] - input_stream = SyncStreamFromLists(tags, packets) - - # Create pod with datastore - pod = FunctionPod(load_from_datastore, datastore=temp_datastore) - - # Process stream - result_stream = pod(input_stream) - result = list(result_stream) - - # Check results - assert len(result) == 3 - - for i, (tag, packet) in enumerate(result): - assert packet["loaded"] is True - assert packet["content"] == f"Content of file {i + 1}" - assert packet["filename"] == f"file{i + 1}.txt" - - def test_datastore_with_stateful_operations(self, temp_datastore): - """Test FunctionPod with stateful operations using datastore.""" - - def stateful_counter(tag, packet, datastore): - """Maintain a counter in datastore across invocations.""" - counter_file = datastore / "counter.txt" - - # Read current counter value - if counter_file.exists(): - with open(counter_file, "r") as f: - count = int(f.read().strip()) - else: - count = 0 - - # Increment counter - count += 1 - - # Save new counter value - with open(counter_file, "w") as f: - f.write(str(count)) - - return tag, {**packet, "sequence_number": count} - - # Create multiple input streams to test state persistence - tags1 = [{"batch": 1, "item": i} for i in range(3)] - packets1 = [{"data": f"item_{i}"} for i in range(3)] - stream1 = SyncStreamFromLists(tags1, packets1) - - tags2 = [{"batch": 2, "item": i} for i in range(2)] - packets2 = [{"data": f"item_{i}"} for i in range(2)] - stream2 = SyncStreamFromLists(tags2, packets2) - - # Create pod with datastore - pod = FunctionPod(stateful_counter, datastore=temp_datastore) - - # Process first stream - result1 = list(pod(stream1)) - - # Process second stream (should continue counting) - result2 = list(pod(stream2)) - - # Check that counter state persisted across streams - expected_sequences1 = [1, 2, 3] - expected_sequences2 = [4, 5] - - for i, (tag, packet) in enumerate(result1): - assert packet["sequence_number"] == expected_sequences1[i] - - for i, (tag, packet) in enumerate(result2): - assert packet["sequence_number"] == expected_sequences2[i] - - def test_datastore_error_handling(self, temp_datastore): - """Test error handling when datastore operations fail.""" 
- - def failing_datastore_operation(tag, packet, datastore): - """Function that tries to access non-existent file.""" - nonexistent_file = datastore / "nonexistent.txt" - - # This should raise an exception - with open(nonexistent_file, "r") as f: - content = f.read() - - return tag, {**packet, "content": content} - - tags = [{"id": 1}] - packets = [{"data": "test"}] - stream = SyncStreamFromLists(tags, packets) - - pod = FunctionPod(failing_datastore_operation, datastore=temp_datastore) - result_stream = pod(stream) - - # Should propagate the file not found error - with pytest.raises(FileNotFoundError): - list(result_stream) - - def test_datastore_with_subdirectories(self, temp_datastore): - """Test FunctionPod with datastore operations using subdirectories.""" - - def organize_by_type(tag, packet, datastore): - """Organize files by type in subdirectories.""" - file_type = packet["type"] - content = packet["content"] - file_id = tag["id"] - - # Create subdirectory - type_dir = datastore / file_type - type_dir.mkdir(exist_ok=True) - - # Save file in subdirectory - file_path = type_dir / f"{file_id}.txt" - with open(file_path, "w") as f: - f.write(content) - - return tag, {**packet, "organized_path": str(file_path)} - - # Create input with different types - tags = [{"id": f"file_{i}"} for i in range(4)] - packets = [ - {"type": "documents", "content": "Document content 1"}, - {"type": "images", "content": "Image metadata 1"}, - {"type": "documents", "content": "Document content 2"}, - {"type": "code", "content": "Python code"}, - ] - stream = SyncStreamFromLists(tags, packets) - - pod = FunctionPod(organize_by_type, datastore=temp_datastore) - result = list(pod(stream)) - - # Check that subdirectories were created - assert (temp_datastore / "documents").exists() - assert (temp_datastore / "images").exists() - assert (temp_datastore / "code").exists() - - # Check that files were saved in correct subdirectories - assert (temp_datastore / "documents" / "file_0.txt").exists() - assert (temp_datastore / "images" / "file_1.txt").exists() - assert (temp_datastore / "documents" / "file_2.txt").exists() - assert (temp_datastore / "code" / "file_3.txt").exists() - - def test_datastore_without_datastore_param(self): - """Test that function without datastore parameter works normally.""" - - def simple_function(tag, packet): - """Function that doesn't use datastore.""" - return tag, {**packet, "processed": True} - - # This should work even though we don't provide datastore - pod = FunctionPod(simple_function) - - tags = [{"id": 1}] - packets = [{"data": "test"}] - stream = SyncStreamFromLists(tags, packets) - - result = list(pod(stream)) - assert len(result) == 1 - assert result[0][1]["processed"] is True - - def test_datastore_metadata_operations(self, temp_datastore): - """Test FunctionPod with metadata tracking in datastore.""" - - def track_processing_metadata(tag, packet, datastore): - """Track processing metadata for each item.""" - import time - import json - - item_id = tag["id"] - processing_time = time.time() - - # Create metadata entry - metadata = { - "item_id": item_id, - "processed_at": processing_time, - "original_data": packet["data"], - "processing_status": "completed", - } - - # Save metadata - metadata_file = datastore / f"metadata_{item_id}.json" - with open(metadata_file, "w") as f: - json.dump(metadata, f) - - return tag, {**packet, "metadata_file": str(metadata_file)} - - tags = [{"id": f"item_{i}"} for i in range(3)] - packets = [{"data": f"data_{i}"} for i in range(3)] - stream = 
SyncStreamFromLists(tags, packets) - - pod = FunctionPod(track_processing_metadata, datastore=temp_datastore) - result = list(pod(stream)) - - # Check that metadata files were created - for i in range(3): - metadata_file = temp_datastore / f"metadata_item_{i}.json" - assert metadata_file.exists() - - # Verify metadata content - import json - - with open(metadata_file, "r") as f: - metadata = json.load(f) - - assert metadata["item_id"] == f"item_{i}" - assert metadata["original_data"] == f"data_{i}" - assert metadata["processing_status"] == "completed" - assert "processed_at" in metadata - - def test_datastore_with_generator_function(self, temp_datastore): - """Test FunctionPod with generator function that uses datastore.""" - - def split_and_save(tag, packet, datastore): - """Split content and save each part separately.""" - content = packet["content"] - parts = content.split() - base_id = tag["id"] - - for i, part in enumerate(parts): - part_id = f"{base_id}_part_{i}" - - # Save part to datastore - part_file = datastore / f"{part_id}.txt" - with open(part_file, "w") as f: - f.write(part) - - # Yield new tag-packet pair - new_tag = {**tag, "part_id": part_id, "part_index": i} - new_packet = {"part_content": part, "saved_to": str(part_file)} - yield new_tag, new_packet - - tags = [{"id": "doc1"}] - packets = [{"content": "Hello World Python Programming"}] - stream = SyncStreamFromLists(tags, packets) - - pod = FunctionPod(split_and_save, datastore=temp_datastore) - result = list(pod(stream)) - - # Should have 4 parts - assert len(result) == 4 - - expected_parts = ["Hello", "World", "Python", "Programming"] - for i, (tag, packet) in enumerate(result): - assert tag["part_index"] == i - assert packet["part_content"] == expected_parts[i] - - # Check that file was saved - saved_file = Path(packet["saved_to"]) - assert saved_file.exists() - - with open(saved_file, "r") as f: - saved_content = f.read() - assert saved_content == expected_parts[i] - - def test_datastore_path_validation(self, temp_datastore): - """Test that datastore path is properly validated and accessible.""" - - def check_datastore_access(tag, packet, datastore): - """Function that checks datastore accessibility.""" - # Check if datastore is a Path object - assert isinstance(datastore, Path) - - # Check if datastore directory exists and is writable - assert datastore.exists() - assert datastore.is_dir() - - # Test writing and reading - test_file = datastore / "access_test.txt" - with open(test_file, "w") as f: - f.write("test") - - with open(test_file, "r") as f: - content = f.read() - - assert content == "test" - - # Clean up - test_file.unlink() - - return tag, {**packet, "datastore_accessible": True} - - tags = [{"id": 1}] - packets = [{"data": "test"}] - stream = SyncStreamFromLists(tags, packets) - - pod = FunctionPod(check_datastore_access, datastore=temp_datastore) - result = list(pod(stream)) - - assert result[0][1]["datastore_accessible"] is True diff --git a/tests/test_streams_operations/test_pods/test_pod_base.py b/tests/test_streams_operations/test_pods/test_pod_base.py deleted file mode 100644 index ab69e82..0000000 --- a/tests/test_streams_operations/test_pods/test_pod_base.py +++ /dev/null @@ -1,274 +0,0 @@ -"""Tests for base Pod functionality.""" - -import pytest -from orcabridge.pod import Pod -from orcabridge.streams import SyncStreamFromLists - - -class TestPodBase: - """Test cases for base Pod class.""" - - def test_pod_creation(self): - """Test basic pod creation.""" - pod = Pod() - assert pod is not None - - 
def test_pod_call_interface(self, sample_stream): - """Test that pod implements callable interface.""" - pod = Pod() - - # Base Pod should be callable, but might not do anything useful - # This tests the interface exists - try: - result_stream = pod(sample_stream) - # If it succeeds, result should be a stream - assert hasattr(result_stream, "__iter__") - except NotImplementedError: - # Base Pod might not implement __call__ - pass - - def test_pod_with_empty_stream(self): - """Test pod with empty stream.""" - empty_stream = SyncStreamFromLists([], []) - pod = Pod() - - try: - result_stream = pod(empty_stream) - result = list(result_stream) - # If implemented, should handle empty stream - assert isinstance(result, list) - except NotImplementedError: - # Base Pod might not implement functionality - pass - - def test_pod_inheritance(self): - """Test that Pod can be inherited.""" - - class CustomPod(Pod): - def __call__(self, stream): - # Simple pass-through implementation - for packet, tag in stream: - yield packet, tag - - custom_pod = CustomPod() - packets = ["data1", "data2", "data3"] - tags = ["tag1", "tag2", "tag3"] - - stream = SyncStreamFromLists(packets, tags) - result_stream = custom_pod(stream) - result = list(result_stream) - - expected = list(zip(packets, tags)) - assert result == expected - - def test_pod_chaining(self): - """Test chaining pods together.""" - - class AddPrefixPod(Pod): - def __init__(self, prefix): - self.prefix = prefix - - def __call__(self, stream): - for packet, tag in stream: - yield f"{self.prefix}_{packet}", tag - - class AddSuffixPod(Pod): - def __init__(self, suffix): - self.suffix = suffix - - def __call__(self, stream): - for packet, tag in stream: - yield f"{packet}_{self.suffix}", tag - - packets = ["data1", "data2"] - tags = ["tag1", "tag2"] - stream = SyncStreamFromLists(packets, tags) - - # Chain two pods - prefix_pod = AddPrefixPod("PRE") - suffix_pod = AddSuffixPod("SUF") - - intermediate_stream = prefix_pod(stream) - final_stream = suffix_pod(intermediate_stream) - - result = list(final_stream) - - expected = [("PRE_data1_SUF", "tag1"), ("PRE_data2_SUF", "tag2")] - assert result == expected - - def test_pod_error_handling(self): - """Test pod error handling.""" - - class ErrorPod(Pod): - def __call__(self, stream): - for i, (packet, tag) in enumerate(stream): - if i == 1: # Error on second item - raise ValueError("Test error") - yield packet, tag - - packets = ["data1", "data2", "data3"] - tags = ["tag1", "tag2", "tag3"] - stream = SyncStreamFromLists(packets, tags) - - error_pod = ErrorPod() - result_stream = error_pod(stream) - - # Should raise error when processing second item - with pytest.raises(ValueError, match="Test error"): - list(result_stream) - - def test_pod_stateful_processing(self): - """Test pod with stateful processing.""" - - class CounterPod(Pod): - def __init__(self): - self.count = 0 - - def __call__(self, stream): - for packet, tag in stream: - self.count += 1 - yield (packet, self.count), tag - - packets = ["a", "b", "c"] - tags = ["t1", "t2", "t3"] - stream = SyncStreamFromLists(packets, tags) - - counter_pod = CounterPod() - result_stream = counter_pod(stream) - result = list(result_stream) - - expected = [(("a", 1), "t1"), (("b", 2), "t2"), (("c", 3), "t3")] - assert result == expected - - def test_pod_multiple_outputs_per_input(self): - """Test pod that produces multiple outputs per input.""" - - class DuplicatorPod(Pod): - def __call__(self, stream): - for packet, tag in stream: - yield f"{packet}_copy1", f"{tag}_1" - 
yield f"{packet}_copy2", f"{tag}_2" - - packets = ["data1", "data2"] - tags = ["tag1", "tag2"] - stream = SyncStreamFromLists(packets, tags) - - duplicator_pod = DuplicatorPod() - result_stream = duplicator_pod(stream) - result = list(result_stream) - - expected = [ - ("data1_copy1", "tag1_1"), - ("data1_copy2", "tag1_2"), - ("data2_copy1", "tag2_1"), - ("data2_copy2", "tag2_2"), - ] - assert result == expected - - def test_pod_filtering(self): - """Test pod that filters items.""" - - class FilterPod(Pod): - def __init__(self, predicate): - self.predicate = predicate - - def __call__(self, stream): - for packet, tag in stream: - if self.predicate(packet, tag): - yield packet, tag - - packets = [1, 2, 3, 4, 5] - tags = ["odd", "even", "odd", "even", "odd"] - stream = SyncStreamFromLists(packets, tags) - - # Filter for even numbers - def is_even(packet, tag): - return packet % 2 == 0 - - filter_pod = FilterPod(is_even) - result_stream = filter_pod(stream) - result = list(result_stream) - - expected = [(2, "even"), (4, "even")] - assert result == expected - - def test_pod_transformation(self): - """Test pod that transforms data.""" - - class TransformPod(Pod): - def __init__(self, transform_func): - self.transform_func = transform_func - - def __call__(self, stream): - for packet, tag in stream: - new_packet, new_tag = self.transform_func(packet, tag) - yield new_packet, new_tag - - packets = ["hello", "world"] - tags = ["greeting", "noun"] - stream = SyncStreamFromLists(packets, tags) - - def uppercase_transform(packet, tag): - return packet.upper(), tag.upper() - - transform_pod = TransformPod(uppercase_transform) - result_stream = transform_pod(stream) - result = list(result_stream) - - expected = [("HELLO", "GREETING"), ("WORLD", "NOUN")] - assert result == expected - - def test_pod_aggregation(self): - """Test pod that aggregates data.""" - - class SumPod(Pod): - def __call__(self, stream): - total = 0 - count = 0 - for packet, tag in stream: - if isinstance(packet, (int, float)): - total += packet - count += 1 - - if count > 0: - yield total, f"sum_of_{count}_items" - - packets = [1, 2, 3, 4, 5] - tags = ["n1", "n2", "n3", "n4", "n5"] - stream = SyncStreamFromLists(packets, tags) - - sum_pod = SumPod() - result_stream = sum_pod(stream) - result = list(result_stream) - - expected = [(15, "sum_of_5_items")] - assert result == expected - - def test_pod_with_complex_data(self): - """Test pod with complex data structures.""" - - class ExtractorPod(Pod): - def __call__(self, stream): - for packet, tag in stream: - if isinstance(packet, dict): - for key, value in packet.items(): - yield value, f"{tag}_{key}" - else: - yield packet, tag - - packets = [{"a": 1, "b": 2}, "simple_string", {"x": 10, "y": 20, "z": 30}] - tags = ["dict1", "str1", "dict2"] - stream = SyncStreamFromLists(packets, tags) - - extractor_pod = ExtractorPod() - result_stream = extractor_pod(stream) - result = list(result_stream) - - # Should extract dict values as separate items - assert len(result) == 6 # 2 + 1 + 3 - assert (1, "dict1_a") in result - assert (2, "dict1_b") in result - assert ("simple_string", "str1") in result - assert (10, "dict2_x") in result - assert (20, "dict2_y") in result - assert (30, "dict2_z") in result diff --git a/tests/test_streams_operations/test_sources/__init__.py b/tests/test_streams_operations/test_sources/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_streams_operations/test_sources/test_glob_source.py 
b/tests/test_streams_operations/test_sources/test_glob_source.py deleted file mode 100644 index 62875a0..0000000 --- a/tests/test_streams_operations/test_sources/test_glob_source.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Tests for GlobSource functionality.""" - -import pytest -import os -from pathlib import Path -from orcabridge.sources import GlobSource - - -class TestGlobSource: - """Test cases for GlobSource.""" - - def test_glob_source_basic(self, test_files, temp_dir): - """Test basic glob source functionality.""" - # Create a glob pattern for txt files - pattern = os.path.join(temp_dir, "*.txt") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find all txt files - txt_files = [f for f in test_files if f.endswith(".txt")] - assert len(result) == len(txt_files) - - # Check that all found files are actual files - for file_content, file_path in result: - assert os.path.isfile(file_path) - assert file_path.endswith(".txt") - assert isinstance(file_content, str) # Text content - - def test_glob_source_specific_pattern(self, test_files, temp_dir): - """Test glob source with specific pattern.""" - # Look for files starting with "file1" - pattern = os.path.join(temp_dir, "file1*") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find only file1.txt - assert len(result) == 1 - file_content, file_path = result[0] - assert "file1.txt" in file_path - assert file_content == "Content of file 1" - - def test_glob_source_binary_files(self, test_files, temp_dir): - """Test glob source with binary files.""" - # Look for binary files - pattern = os.path.join(temp_dir, "*.bin") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find all binary files - bin_files = [f for f in test_files if f.endswith(".bin")] - assert len(result) == len(bin_files) - - for file_content, file_path in result: - assert file_path.endswith(".bin") - assert isinstance(file_content, bytes) # Binary content - - def test_glob_source_json_files(self, test_files, temp_dir): - """Test glob source with JSON files.""" - pattern = os.path.join(temp_dir, "*.json") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find all JSON files - json_files = [f for f in test_files if f.endswith(".json")] - assert len(result) == len(json_files) - - for file_content, file_path in result: - assert file_path.endswith(".json") - # Content should be the raw JSON string - assert '"key"' in file_content - - def test_glob_source_no_matches(self, temp_dir): - """Test glob source when pattern matches no files.""" - pattern = os.path.join(temp_dir, "*.nonexistent") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - assert len(result) == 0 - - def test_glob_source_recursive_pattern(self, temp_dir): - """Test glob source with recursive pattern.""" - # Create subdirectory with files - subdir = os.path.join(temp_dir, "subdir") - os.makedirs(subdir, exist_ok=True) - - sub_file = os.path.join(subdir, "sub_file.txt") - with open(sub_file, "w") as f: - f.write("Subdirectory content") - - # Use recursive pattern - pattern = os.path.join(temp_dir, "**", "*.txt") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find files in both root and subdirectory - txt_files = [file_path for _, file_path in result] - - # Check that we found files in subdirectory - sub_files = [f for f in txt_files if "subdir" in f] - assert len(sub_files) > 0 - - # 
Verify content of subdirectory file - sub_result = [ - (content, path) for content, path in result if "sub_file.txt" in path - ] - assert len(sub_result) == 1 - assert sub_result[0][0] == "Subdirectory content" - - def test_glob_source_absolute_vs_relative_paths(self, test_files, temp_dir): - """Test glob source with both absolute and relative paths.""" - # Test with absolute path - abs_pattern = os.path.join(os.path.abspath(temp_dir), "*.txt") - abs_source = GlobSource(abs_pattern) - abs_stream = abs_source() - abs_result = list(abs_stream) - - # Test with relative path (if possible) - current_dir = os.getcwd() - try: - os.chdir(temp_dir) - rel_pattern = "*.txt" - rel_source = GlobSource(rel_pattern) - rel_stream = rel_source() - rel_result = list(rel_stream) - - # Should find the same number of files - assert len(abs_result) == len(rel_result) - - finally: - os.chdir(current_dir) - - def test_glob_source_empty_directory(self, temp_dir): - """Test glob source in empty directory.""" - empty_dir = os.path.join(temp_dir, "empty_subdir") - os.makedirs(empty_dir, exist_ok=True) - - pattern = os.path.join(empty_dir, "*") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - assert len(result) == 0 - - def test_glob_source_large_directory(self, temp_dir): - """Test glob source with many files.""" - # Create many files - for i in range(50): - file_path = os.path.join(temp_dir, f"bulk_file_{i:03d}.txt") - with open(file_path, "w") as f: - f.write(f"Content of bulk file {i}") - - pattern = os.path.join(temp_dir, "bulk_file_*.txt") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - assert len(result) == 50 - - # Check that files are properly ordered (if implementation sorts) - file_paths = [file_path for _, file_path in result] - for i, file_path in enumerate(file_paths): - if "bulk_file_000.txt" in file_path: - # Found the first file, check content - content = [content for content, path in result if path == file_path][0] - assert "Content of bulk file 0" in content - - def test_glob_source_special_characters_in_filenames(self, temp_dir): - """Test glob source with special characters in filenames.""" - # Create files with special characters - special_files = [ - "file with spaces.txt", - "file-with-dashes.txt", - "file_with_underscores.txt", - "file.with.dots.txt", - ] - - for filename in special_files: - file_path = os.path.join(temp_dir, filename) - with open(file_path, "w") as f: - f.write(f"Content of {filename}") - - pattern = os.path.join(temp_dir, "file*.txt") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find all special files plus any existing test files - found_files = [os.path.basename(file_path) for _, file_path in result] - - for special_file in special_files: - assert special_file in found_files - - def test_glob_source_mixed_file_types(self, test_files, temp_dir): - """Test glob source that matches multiple file types.""" - # Pattern that matches both txt and json files - pattern = os.path.join(temp_dir, "file*") - - source = GlobSource(pattern) - stream = source() - - result = list(stream) - - # Should find both text and json files - file_extensions = [os.path.splitext(file_path)[1] for _, file_path in result] - - assert ".txt" in file_extensions - assert ".json" in file_extensions - - def test_glob_source_case_sensitivity(self, temp_dir): - """Test glob source case sensitivity.""" - # Create files with different cases - files = ["Test.TXT", "test.txt", "TEST.txt"] - - for 
filename in files: - file_path = os.path.join(temp_dir, filename) - with open(file_path, "w") as f: - f.write(f"Content of {filename}") - - # Test exact case match - pattern = os.path.join(temp_dir, "test.txt") - source = GlobSource(pattern) - stream = source() - result = list(stream) - - # Should find at least the exact match - found_files = [os.path.basename(file_path) for _, file_path in result] - assert "test.txt" in found_files - - def test_glob_source_symlinks(self, temp_dir): - """Test glob source with symbolic links (if supported).""" - # Create a regular file - original_file = os.path.join(temp_dir, "original.txt") - with open(original_file, "w") as f: - f.write("Original content") - - try: - # Create a symbolic link - link_file = os.path.join(temp_dir, "link.txt") - os.symlink(original_file, link_file) - - pattern = os.path.join(temp_dir, "*.txt") - source = GlobSource(pattern) - stream = source() - result = list(stream) - - # Should find both original and link - file_paths = [file_path for _, file_path in result] - original_found = any("original.txt" in path for path in file_paths) - link_found = any("link.txt" in path for path in file_paths) - - assert original_found - # Link behavior depends on implementation - - except (OSError, NotImplementedError): - # Symlinks not supported on this system - pass - - def test_glob_source_error_handling(self, temp_dir): - """Test glob source error handling.""" - # Test with invalid pattern - invalid_pattern = "/nonexistent/path/*.txt" - - source = GlobSource(invalid_pattern) - stream = source() - - # Should handle gracefully (empty result or specific error) - try: - result = list(stream) - # If no error, should be empty - assert len(result) == 0 - except (OSError, FileNotFoundError): - # Expected error for invalid path - pass - - def test_glob_source_file_permissions(self, temp_dir): - """Test glob source with files of different permissions.""" - # Create a file and try to change permissions - restricted_file = os.path.join(temp_dir, "restricted.txt") - with open(restricted_file, "w") as f: - f.write("Restricted content") - - try: - # Try to make file unreadable - os.chmod(restricted_file, 0o000) - - pattern = os.path.join(temp_dir, "restricted.txt") - source = GlobSource(pattern) - stream = source() - - # Should handle permission errors gracefully - try: - result = list(stream) - # If successful, content might be empty or error - except PermissionError: - # Expected for restricted files - pass - - finally: - # Restore permissions for cleanup - try: - os.chmod(restricted_file, 0o644) - except: - pass diff --git a/tests/test_streams_operations/test_streams/__init__.py b/tests/test_streams_operations/test_streams/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_streams_operations/test_streams/test_base_classes.py b/tests/test_streams_operations/test_streams/test_base_classes.py deleted file mode 100644 index 7eecab3..0000000 --- a/tests/test_streams_operations/test_streams/test_base_classes.py +++ /dev/null @@ -1,514 +0,0 @@ -""" -Test module for base Stream and SyncStream classes. - -This module tests the fundamental stream functionality including -iteration, flow operations, labeling, key management, and invocation tracking. 
-""" - -from collections.abc import Collection, Iterator -import pytest -from unittest.mock import Mock, MagicMock -from abc import ABC - -from orcabridge.base import Stream, SyncStream, Operation, Invocation -from orcabridge.mappers import Join -from orcabridge.streams import SyncStreamFromLists, SyncStreamFromGenerator -from orcabridge.types import Tag, Packet - - -class ConcreteStream(Stream): - """Concrete Stream implementation for testing.""" - - def __init__(self, data: Collection[tuple[Tag, Packet]], label=None): - super().__init__(label=label) - self.data = data - - def __iter__(self): - return iter(self.data) - - -class ConcreteSyncStream(SyncStream): - """Concrete SyncStream implementation for testing.""" - - def __init__(self, data: Collection[tuple[Tag, Packet]], label=None): - super().__init__(label=label) - self.data = data - - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - return iter(self.data) - - -@pytest.fixture -def sample_stream_data(): - """Sample stream data for testing.""" - return [ - ({"id": 1, "type": "text"}, {"content": "Hello", "size": 5}), - ({"id": 2, "type": "text"}, {"content": "World", "size": 5}), - ({"id": 3, "type": "number"}, {"value": 42, "unit": "count"}), - ] - - -@pytest.fixture -def sample_tags_packets(sample_stream_data): - """Extract tags and packets from sample data.""" - tags, packets = zip(*sample_stream_data) - return list(tags), list(packets) - - -class TestStreamBase: - """Test cases for base Stream class.""" - - def test_stream_labels(self, sample_stream_data): - """Test Stream initialization with and without label.""" - - # Without label - stream = ConcreteStream(sample_stream_data) - assert stream.label == "ConcreteStream", ( - f"Label should default to class name {stream.__class__.__name__} but got {stream.label}" - ) - assert stream.invocation is None - - # With label - labeled_stream = ConcreteStream(sample_stream_data, label="test_stream") - assert labeled_stream.label == "test_stream" - - def test_stream_iteration(self, sample_stream_data): - """Test that Stream can be iterated over.""" - stream = ConcreteStream(sample_stream_data) - - result = list(stream) - assert result == sample_stream_data - - # Test multiple iterations - result2 = list(stream) - assert result2 == sample_stream_data - - def test_stream_flow(self, sample_stream_data): - """Test Stream.flow() method.""" - stream = ConcreteStream(sample_stream_data) - - flowed = stream.flow() - assert flowed == sample_stream_data - assert isinstance(flowed, list) - - def test_stream_identity_structure(self, sample_stream_data): - """Test Stream identity structure.""" - stream = ConcreteStream(sample_stream_data) - - # Default identity structure for uninvoked stream should be None - identity = stream.identity_structure() - # TODO: consider alternative behavior for identity structure for streams - assert identity is None - - def test_stream_keys_default(self, sample_stream_data): - """Test Stream keys method default behavior.""" - stream = ConcreteStream(sample_stream_data) - - tag_keys, packet_keys = stream.keys() - # Default implementation will be based on the first sample from the stream - assert tag_keys is not None and set(tag_keys) == set(["id", "type"]) - - assert packet_keys is not None and set(packet_keys) == set(["content", "size"]) - - def test_stream_repr(self, sample_stream_data): - """Test Stream string representation.""" - stream = ConcreteStream(sample_stream_data, label="test_stream") - - repr_str = repr(stream) - assert "ConcreteStream" in repr_str - 
assert "test_stream" in repr_str - - -class TestSyncStreamBase: - """Test cases for SyncStream base class.""" - - def test_syncstream_initialization(self, sample_stream_data): - """Test SyncStream initialization.""" - sync_stream = ConcreteSyncStream(sample_stream_data) - - assert isinstance(sync_stream, Stream) - assert isinstance(sync_stream, SyncStream) - - def test_syncstream_rshift_operator_dict(self, sample_stream_data): - """Test SyncStream >> operator with dictionary mapping.""" - sync_stream = SyncStreamFromLists(paired=sample_stream_data) - - # Test with dictionary (should use MapPackets) - mapping = {"content": "text", "size": "length"} - mapped_stream = sync_stream >> mapping - - assert isinstance(mapped_stream, SyncStream) - result = list(mapped_stream) - - # Check that mapping was applied - for (tag, packet), (ref_tag, ref_packet) in zip(result, sample_stream_data): - if "content" in ref_packet: - assert "text" in packet - assert packet["text"] == ref_packet["content"] - if "size" in ref_packet: - assert "length" in packet - assert packet["length"] == ref_packet["size"] - - def test_syncstream_rshift_operator_callable(self, sample_stream_data): - """Test SyncStream >> operator with callable transformer.""" - sync_stream = SyncStreamFromLists(paired=sample_stream_data) - - def add_processed_flag(stream): - """Add processed flag to all packets.""" - - def generator(): - for tag, packet in stream: - yield tag, {**packet, "processed": True} - - return SyncStreamFromGenerator(generator) - - transformed = sync_stream >> add_processed_flag - result = list(transformed) - - # Check that all packets have processed flag - for _, packet in result: - assert packet["processed"] is True - - def test_syncstream_mul_operator(self, sample_tags_packets): - """Test SyncStream * operator for joining streams.""" - tags1, packets1 = sample_tags_packets - stream1 = SyncStreamFromLists(tags1[:2], packets1[:2]) - - tags2 = [{"id": 1, "category": "A"}, {"id": 2, "category": "B"}] - packets2 = [{"priority": "high"}, {"priority": "low"}] - stream2 = SyncStreamFromLists(tags2, packets2) - - # Test join operation - joined = stream1 * stream2 - - assert joined.invocation is not None and isinstance( - joined.invocation.operation, Join - ), ( - f"* operator should be resulting from an Join object invocation but got {type(joined)}" - ) - result = list(joined) - - # Should have joined results where tags match - assert len(result) >= 0 # Exact count depends on tag matching logic - - def test_syncstream_mul_operator_type_error(self, sample_tags_packets): - """Test SyncStream * operator with invalid type.""" - tags, packets = sample_tags_packets - sync_stream = SyncStreamFromLists(tags, packets) - - with pytest.raises(TypeError, match="other must be a SyncStream"): - sync_stream * "not_a_stream" # type: ignore - - def test_syncstream_rshift_invalid_type(self, sample_tags_packets): - """Test SyncStream >> operator with invalid transformer type.""" - tags, packets = sample_tags_packets - sync_stream = SyncStreamFromLists(tags, packets) - - # Should handle non-dict, non-callable gracefully or raise appropriate error - with pytest.raises((TypeError, AttributeError)): - sync_stream >> 123 # type: ignore - - def test_syncstream_chaining_operations(self, sample_tags_packets): - """Test chaining multiple SyncStream operations.""" - tags, packets = sample_tags_packets - sync_stream = SyncStreamFromLists(tags, packets) - - # Chain multiple transformations - def add_flag(stream): - def generator(): - for tag, packet in 
stream: - yield tag, {**packet, "chained": True} - - return SyncStreamFromGenerator(generator) - - def add_counter(stream): - def generator(): - for i, (tag, packet) in enumerate(stream): - yield tag, {**packet, "counter": i} - - return SyncStreamFromGenerator(generator) - - result_stream = sync_stream >> add_flag >> add_counter - result = list(result_stream) - - # Check that both transformations were applied - for i, (tag, packet) in enumerate(result): - assert packet["chained"] is True - assert packet["counter"] == i - - -class TestSyncStreamFromLists: - """Test cases for SyncStreamFromLists implementation.""" - - def test_creation_from_lists(self, sample_tags_packets): - """Test SyncStreamFromLists creation.""" - tags, packets = sample_tags_packets - stream = SyncStreamFromLists(tags, packets) - - assert isinstance(stream, SyncStream) - result = list(stream) - - expected = list(zip(tags, packets)) - assert result == expected - - def test_creation_with_mismatched_lengths(self): - """Test SyncStreamFromLists with mismatched tag/packet lengths.""" - tags = [{"id": "1"}, {"id": "2"}] - packets = [{"data": "a"}] # One less packet - - # If strict (default), should raise a ValueError - with pytest.raises(ValueError): - stream = SyncStreamFromLists(tags, packets, strict=True) - - # If not strict, should handle gracefully and create based on the shortest length - stream = SyncStreamFromLists(tags, packets, strict=False) - result = list(stream) - - assert len(result) == 1 - assert result[0] == ({"id": "1"}, {"data": "a"}) - - def test_empty_lists(self): - """Test SyncStreamFromLists with empty lists.""" - stream = SyncStreamFromLists([], []) - result = list(stream) - - assert result == [] - - def test_keys_inference(self, sample_tags_packets): - """Test key inference from tag and packet data.""" - tags, packets = sample_tags_packets - stream = SyncStreamFromLists(tags, packets) - - tag_keys, packet_keys = stream.keys() - - # Should infer keys from the first element - expected_tag_keys = set() - expected_packet_keys = set() - - if tags: - expected_tag_keys.update(tags[0].keys()) - if packets: - expected_packet_keys.update(packets[0].keys()) - - assert tag_keys is not None and set(tag_keys) == expected_tag_keys - assert packet_keys is not None and set(packet_keys) == expected_packet_keys - - def test_multiple_iterations(self, sample_tags_packets): - """Test that SyncStreamFromLists can be iterated multiple times.""" - tags, packets = sample_tags_packets - stream = SyncStreamFromLists(tags, packets) - - result1 = list(stream) - result2 = list(stream) - - assert result1 == result2 - assert len(result1) == len(tags) - - -class TestSyncStreamFromGenerator: - """Test cases for SyncStreamFromGenerator implementation.""" - - def test_creation_from_generator(self, sample_stream_data): - """Test SyncStreamFromGenerator creation.""" - - def generator(): - for item in sample_stream_data: - yield item - - stream = SyncStreamFromGenerator(generator) - assert isinstance(stream, SyncStream) - - result = list(stream) - assert result == sample_stream_data - - def test_generator_multiple_iterations(self, sample_stream_data): - """Test that generator-based streams can be iterated multiple times""" - - def generator(): - for item in sample_stream_data: - yield item - - stream = SyncStreamFromGenerator(generator) - - # First iteration should work - result1 = list(stream) - assert result1 == sample_stream_data - - # Second iteration should work (new iterator instance) - result2 = list(stream) - assert result2 == 
sample_stream_data - - def test_empty_generator(self): - """Test SyncStreamFromGenerator with empty generator.""" - - def empty_generator(): - return - yield # This line is never reached - - stream = SyncStreamFromGenerator(empty_generator) - result = list(stream) - - assert result == [] - - def test_generator_with_exception(self): - """Test SyncStreamFromGenerator with generator that raises exception.""" - - def failing_generator(): - yield ({"id": "1"}, {"data": "ok"}) - raise ValueError("Generator failed") - - stream = SyncStreamFromGenerator(failing_generator) - - # Should propagate the exception - with pytest.raises(ValueError, match="Generator failed"): - list(stream) - - def test_lazy_evaluation(self): - """Test that SyncStreamFromGenerator is lazily evaluated.""" - call_count = {"count": 0} - - def counting_generator(): - call_count["count"] += 1 - yield ({"id": "1"}, {"data": "test"}) - - stream = SyncStreamFromGenerator(counting_generator) - - # Generator should not be called until iteration starts - assert call_count["count"] == 0 - - # Start iteration - iterator = iter(stream) - next(iterator) - - # Now generator should have been called - assert call_count["count"] == 1 - - def test_inferred_keys_with_generator(self): - """Test key inference with generator streams.""" - - def sample_generator(): - yield ({"id": "1", "type": "A"}, {"value": "10", "name": "test"}) - yield ({"id": "2", "type": "B"}, {"value": "20", "size": "5"}) - - stream = SyncStreamFromGenerator(sample_generator) - - # Keys should be inferred from generated data - tag_keys, packet_keys = stream.keys() - - # Note: This depends on implementation - may need to consume stream - # to infer keys, or may return None - if tag_keys is not None: - assert "id" in tag_keys - assert "type" in tag_keys - - if packet_keys is not None: - assert "value" in packet_keys - - def test_specified_keys_with_generator(self): - """Test key inference with generator streams.""" - - def sample_generator(): - yield ({"id": "1", "type": "A"}, {"value": "10", "name": "test"}) - yield ({"id": "2", "type": "B"}, {"value": "20", "size": "5"}) - - # Specify keys explicitly -- it need not match the actual content - stream = SyncStreamFromGenerator( - sample_generator, tag_keys=["id"], packet_keys=["group"] - ) - - # Keys should be based on what was specified at the construction - tag_keys, packet_keys = stream.keys() - - # Note: This depends on implementation - may need to consume stream - # to infer keys, or may return None - if tag_keys is not None: - assert "id" in tag_keys - assert "type" not in tag_keys - - if packet_keys is not None: - assert "value" not in packet_keys - assert "group" in packet_keys - - -class TestStreamIntegration: - """Integration tests for stream functionality.""" - - def test_stream_composition(self, sample_tags_packets): - """Test composing different stream types.""" - tags, packets = sample_tags_packets - - # Create streams from different sources - list_stream = SyncStreamFromLists(tags[:2], packets[:2]) - - def gen_func(): - yield tags[2], packets[2] - - gen_stream = SyncStreamFromGenerator(gen_func) - - # Both should work similarly - list_result = list(list_stream) - gen_result = list(gen_stream) - - assert len(list_result) == 2 - assert len(gen_result) == 1 - - # Combine results - all_data = list_result + gen_result - assert len(all_data) == 3 - - def test_stream_with_complex_data(self): - """Test streams with complex nested data.""" - complex_tags = [ - {"id": 1, "metadata": {"type": "nested", "level": 1}}, - 
{"id": 2, "metadata": {"type": "nested", "level": 2}}, - ] - complex_packets = [ - {"data": {"values": [1, 2, 3], "config": {"enabled": True}}}, - {"data": {"values": [4, 5, 6], "config": {"enabled": False}}}, - ] - - stream = SyncStreamFromLists(complex_tags, complex_packets) - result = list(stream) - - assert len(result) == 2 - - # Verify complex data is preserved - tag, packet = result[0] - assert tag["metadata"]["type"] == "nested" - assert packet["data"]["values"] == [1, 2, 3] - assert packet["data"]["config"]["enabled"] is True - - def test_stream_memory_efficiency(self): - """Test that generator streams don't consume excessive memory.""" - - def large_generator(): - for i in range(1000): - yield ({"id": i}, {"value": i * 2}) - - stream = SyncStreamFromGenerator(large_generator) - - # Process in chunks to test memory efficiency - count = 0 - for tag, packet in stream: - count += 1 - if count > 10: # Just test first few items - break - - assert count == 11 # Processed 11 items - - def test_stream_error_propagation(self, sample_tags_packets): - """Test that errors in stream data are properly propagated.""" - tags, packets = sample_tags_packets - - # Create stream with invalid data - invalid_tags = tags + [None] # Add invalid tag - invalid_packets = packets + [{"data": "valid"}] - - stream = SyncStreamFromLists(invalid_tags, invalid_packets) - - # Should handle None tags gracefully or raise appropriate error - result = list(stream) - - # The None tag should be included as-is - assert len(result) == 4 - assert result[-1] == (None, {"data": "valid"}) diff --git a/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py b/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py deleted file mode 100644 index 4aaca57..0000000 --- a/tests/test_streams_operations/test_streams/test_sync_stream_implementations.py +++ /dev/null @@ -1,578 +0,0 @@ -""" -Test module for SyncStream concrete implementations. - -This module tests the specific implementations of SyncStream including -SyncStreamFromLists and SyncStreamFromGenerator, focusing on their unique -behaviors, performance characteristics, and edge cases. 
-""" - -import pytest -from unittest.mock import Mock, patch -import gc - -from orcabridge.streams import SyncStreamFromLists, SyncStreamFromGenerator -from orcabridge.base import SyncStream - - -@pytest.fixture -def sample_data(): - """Sample data for stream testing.""" - return [ - ({"id": 1, "type": "doc"}, {"content": "Hello", "size": 5}), - ({"id": 2, "type": "doc"}, {"content": "World", "size": 5}), - ({"id": 3, "type": "img"}, {"pixels": 1920 * 1080, "format": "png"}), - ] - - -@pytest.fixture -def sample_tags_packets(sample_data): - """Extract tags and packets separately.""" - tags, packets = zip(*sample_data) - return list(tags), list(packets) - - -class TestSyncStreamFromLists: - """Comprehensive tests for SyncStreamFromLists implementation.""" - - def test_basic_creation_and_iteration(self, sample_tags_packets): - """Test basic creation and iteration functionality.""" - tags, packets = sample_tags_packets - stream = SyncStreamFromLists(tags, packets) - - # Test basic properties - assert isinstance(stream, SyncStream) - - # Test iteration - result = list(stream) - expected = list(zip(tags, packets)) - assert result == expected - - def test_creation_with_empty_lists(self): - """Test creation with empty tag and packet lists.""" - stream = SyncStreamFromLists([], []) - - result = list(stream) - assert result == [] - - # Test keys with empty stream - tag_keys, packet_keys = stream.keys() - assert tag_keys == [] - assert packet_keys == [] - - def test_creation_with_single_item(self): - """Test creation with single tag-packet pair.""" - tags = [{"id": 1}] - packets = [{"data": "test"}] - stream = SyncStreamFromLists(tags, packets) - - result = list(stream) - assert result == [({"id": 1}, {"data": "test"})] - - def test_mismatched_list_lengths(self): - """Test behavior with different length tag and packet lists.""" - tags = [{"id": 1}, {"id": 2}, {"id": 3}] - packets = [{"data": "a"}, {"data": "b"}] # Shorter list - - stream = SyncStreamFromLists(tags, packets) - result = list(stream) - - # Should zip to shortest length - assert len(result) == 2 - assert result == [ - ({"id": 1}, {"data": "a"}), - ({"id": 2}, {"data": "b"}), - ] - - def test_keys_inference_comprehensive(self): - """Test comprehensive key inference from data.""" - tags = [ - {"id": 1, "type": "A", "category": "test"}, - {"id": 2, "type": "B"}, # Missing category - {"id": 3, "category": "prod", "extra": "value"}, # Missing type, has extra - ] - packets = [ - {"data": "hello", "size": 5, "meta": {"info": "test"}}, - {"data": "world", "count": 10}, # Missing size, meta; has count - {"size": 3, "format": "json"}, # Missing data; has format - ] - - stream = SyncStreamFromLists(tags, packets) - tag_keys, packet_keys = stream.keys() - - # Should include all keys found across all items - expected_tag_keys = {"id", "type", "category", "extra"} - expected_packet_keys = {"data", "size", "meta", "count", "format"} - - assert set(tag_keys) == expected_tag_keys - assert set(packet_keys) == expected_packet_keys - - def test_multiple_iterations_consistency(self, sample_tags_packets): - """Test that multiple iterations return consistent results.""" - tags, packets = sample_tags_packets - stream = SyncStreamFromLists(tags, packets) - - # Multiple iterations should be identical - result1 = list(stream) - result2 = list(stream) - result3 = list(stream) - - assert result1 == result2 == result3 - assert len(result1) == len(tags) - - def test_iteration_with_generators_as_input(self): - """Test creation with generator inputs (should work since 
converted to lists).""" - - def tag_gen(): - for i in range(3): - yield {"id": i} - - def packet_gen(): - for i in range(3): - yield {"value": i * 10} - - # Should accept generators and convert them - stream = SyncStreamFromLists(list(tag_gen()), list(packet_gen())) - result = list(stream) - - assert len(result) == 3 - assert result[0] == ({"id": 0}, {"value": 0}) - assert result[1] == ({"id": 1}, {"value": 10}) - assert result[2] == ({"id": 2}, {"value": 20}) - - def test_memory_efficiency_large_lists(self): - """Test memory efficiency with large lists.""" - # Create large but not excessive lists - size = 1000 - tags = [{"id": i} for i in range(size)] - packets = [{"value": i * 2} for i in range(size)] - - stream = SyncStreamFromLists(tags, packets) - - # Should be able to iterate without memory issues - count = 0 - for tag, packet in stream: - count += 1 - assert tag["id"] == packet["value"] // 2 - - assert count == size - - def test_data_types_preservation(self): - """Test that various data types are preserved correctly.""" - tags = [ - {"int": 42, "float": 3.14, "str": "hello"}, - {"bool": True, "none": None, "list": [1, 2, 3]}, - {"dict": {"nested": "value"}, "tuple": (1, 2)}, - ] - packets = [ - {"complex": 1 + 2j, "bytes": b"binary", "set": {1, 2, 3}}, - {"lambda": lambda x: x * 2}, # Function objects - {"custom": {"deep": {"nesting": {"value": 123}}}}, - ] - - stream = SyncStreamFromLists(tags, packets) - result = list(stream) - - # Verify data type preservation - assert result[0][0]["int"] == 42 - assert result[0][0]["float"] == 3.14 - assert result[0][1]["complex"] == 1 + 2j - assert result[0][1]["bytes"] == b"binary" - - assert result[1][0]["bool"] is True - assert result[1][0]["none"] is None - assert callable(result[1][1]["lambda"]) - - assert result[2][0]["dict"]["nested"] == "value" - assert result[2][1]["custom"]["deep"]["nesting"]["value"] == 123 - - def test_mutable_data_safety(self): - """Test that mutable data doesn't cause unexpected sharing.""" - shared_dict = {"shared": "value"} - tags = [{"ref": shared_dict}, {"ref": shared_dict}] - packets = [{"data": "a"}, {"data": "b"}] - - stream = SyncStreamFromLists(tags, packets) - result = list(stream) - - # Modify the shared dict - shared_dict["shared"] = "modified" - - # The stream results should reflect the change (references preserved) - assert result[0][0]["ref"]["shared"] == "modified" - assert result[1][0]["ref"]["shared"] == "modified" - - def test_label_and_metadata(self, sample_tags_packets): - """Test stream labeling and metadata handling.""" - tags, packets = sample_tags_packets - - # Test with custom label - stream = SyncStreamFromLists(tags, packets, label="test_stream") - assert stream.label == "test_stream" - - # Test default label generation - stream_auto = SyncStreamFromLists(tags, packets) - assert "SyncStreamFromLists_" in stream_auto.label - - -class TestSyncStreamFromGenerator: - """Comprehensive tests for SyncStreamFromGenerator implementation.""" - - def test_basic_creation_and_iteration(self, sample_data): - """Test basic creation and iteration functionality.""" - - def generator(): - for item in sample_data: - yield item - - stream = SyncStreamFromGenerator(generator) - assert isinstance(stream, SyncStream) - - result = list(stream) - assert result == sample_data - - def test_empty_generator(self): - """Test with generator that yields nothing.""" - - def empty_gen(): - return - yield # Never reached - - stream = SyncStreamFromGenerator(empty_gen) - result = list(stream) - assert result == [] - - def 
test_single_item_generator(self): - """Test with generator that yields single item.""" - - def single_gen(): - yield ({"id": 1}, {"data": "test"}) - - stream = SyncStreamFromGenerator(single_gen) - result = list(stream) - assert result == [({"id": 1}, {"data": "test"})] - - def test_generator_exhaustion(self, sample_data): - """Test that generators are exhausted after iteration.""" - - def generator(): - for item in sample_data: - yield item - - stream = SyncStreamFromGenerator(generator) - - # First iteration consumes generator - result1 = list(stream) - assert result1 == sample_data - - # Second iteration gets empty results (generator exhausted) - result2 = list(stream) - assert result2 == [] - - def test_lazy_evaluation(self): - """Test that generator evaluation is lazy.""" - call_log = [] - - def tracking_generator(): - call_log.append("generator_started") - for i in range(3): - call_log.append(f"yielding_{i}") - yield ({"id": i}, {"value": i * 10}) - call_log.append("generator_finished") - - stream = SyncStreamFromGenerator(tracking_generator) - - # Generator should not have started yet - assert call_log == [] - - # Start iteration but don't consume everything - iterator = iter(stream) - next(iterator) - - # Should have started and yielded first item - assert "generator_started" in call_log - assert "yielding_0" in call_log - assert "yielding_1" not in call_log - - def test_generator_with_exception(self): - """Test generator that raises exception during iteration.""" - - def failing_generator(): - yield ({"id": 1}, {"data": "ok"}) - yield ({"id": 2}, {"data": "ok"}) - raise ValueError("Something went wrong") - yield ({"id": 3}, {"data": "never_reached"}) - - stream = SyncStreamFromGenerator(failing_generator) - - # Should propagate exception - with pytest.raises(ValueError, match="Something went wrong"): - list(stream) - - def test_generator_partial_consumption(self, sample_data): - """Test partial consumption of generator.""" - - def generator(): - for item in sample_data: - yield item - - stream = SyncStreamFromGenerator(generator) - - # Consume only part of the stream - iterator = iter(stream) - first_item = next(iterator) - second_item = next(iterator) - - assert first_item == sample_data[0] - assert second_item == sample_data[1] - - # Rest of generator should still be available - remaining = list(iterator) - assert remaining == sample_data[2:] - - def test_generator_with_infinite_sequence(self): - """Test generator with infinite sequence (partial consumption).""" - - def infinite_generator(): - i = 0 - while True: - yield ({"id": i}, {"value": i * i}) - i += 1 - - stream = SyncStreamFromGenerator(infinite_generator) - - # Consume just first few items - iterator = iter(stream) - results = [] - for _ in range(5): - results.append(next(iterator)) - - assert len(results) == 5 - assert results[0] == ({"id": 0}, {"value": 0}) - assert results[4] == ({"id": 4}, {"value": 16}) - - def test_generator_with_complex_logic(self): - """Test generator with complex internal logic.""" - - def complex_generator(): - # Generator with state and complex logic - state = {"count": 0, "filter_odd": True} - - for i in range(10): - state["count"] += 1 - - if state["filter_odd"] and i % 2 == 1: - continue # Skip odd numbers initially - - if i == 6: # Change behavior mid-stream - state["filter_odd"] = False - - yield ({"id": i, "count": state["count"]}, {"value": i * 2}) - - stream = SyncStreamFromGenerator(complex_generator) - result = list(stream) - - # Should have skipped odds initially, then included 
them - ids = [item[0]["id"] for item in result] - assert 0 in ids and 2 in ids and 4 in ids and 6 in ids # Evens - assert 1 not in ids and 3 not in ids and 5 not in ids # Early odds skipped - assert 7 in ids and 8 in ids and 9 in ids # Later odds included - - def test_keys_inference_limitation(self): - """Test that key inference may be limited for generators.""" - - def generator(): - yield ({"id": 1, "type": "A"}, {"data": "hello", "size": 5}) - yield ({"id": 2, "type": "B"}, {"data": "world", "count": 10}) - - stream = SyncStreamFromGenerator(generator) - - # Keys might not be available without consuming stream - tag_keys, packet_keys = stream.keys() - - # Implementation-dependent: might be None or inferred - if tag_keys is not None: - assert isinstance(tag_keys, (list, tuple, set)) - if packet_keys is not None: - assert isinstance(packet_keys, (list, tuple, set)) - - def test_memory_efficiency(self): - """Test memory efficiency of generator streams.""" - - def memory_efficient_generator(): - # Generate large number of items without storing them all - for i in range(10000): - yield ({"id": i}, {"value": i * 2}) - - stream = SyncStreamFromGenerator(memory_efficient_generator) - - # Process in chunks to verify memory efficiency - count = 0 - for tag, packet in stream: - count += 1 - assert tag["id"] == packet["value"] // 2 - - if count >= 100: # Don't process all 10k items in test - break - - assert count == 100 - - def test_generator_function_vs_generator_object(self, sample_data): - """Test creation with generator function vs generator object.""" - - def gen_function(): - for item in sample_data: - yield item - - # Test with generator function (should work) - stream1 = SyncStreamFromGenerator(gen_function) - result1 = list(stream1) - - # Test with generator object (should work) - gen_object = gen_function() - stream2 = SyncStreamFromGenerator(lambda: gen_object) - result2 = list(stream2) - - assert result1 == sample_data - assert result2 == sample_data - - def test_label_and_metadata(self, sample_data): - """Test stream labeling and metadata handling.""" - - def generator(): - for item in sample_data: - yield item - - # Test with custom label - stream = SyncStreamFromGenerator(generator, label="test_gen_stream") - assert stream.label == "test_gen_stream" - - # Test default label generation - stream_auto = SyncStreamFromGenerator(generator) - assert "SyncStreamFromGenerator_" in stream_auto.label - - -class TestStreamImplementationComparison: - """Tests comparing different stream implementations.""" - - def test_equivalent_output(self, sample_data): - """Test that both implementations produce equivalent output for same data.""" - tags, packets = zip(*sample_data) - - # Create streams from same data using different implementations - list_stream = SyncStreamFromLists(list(tags), list(packets)) - - def generator(): - for item in sample_data: - yield item - - gen_stream = SyncStreamFromGenerator(generator) - - # Results should be identical - list_result = list(list_stream) - gen_result = list(gen_stream) - - assert list_result == gen_result == sample_data - - def test_multiple_iteration_behavior(self, sample_data): - """Test different behavior in multiple iterations.""" - tags, packets = zip(*sample_data) - - list_stream = SyncStreamFromLists(list(tags), list(packets)) - - def generator(): - for item in sample_data: - yield item - - gen_stream = SyncStreamFromGenerator(generator) - - # List stream should support multiple iterations - list_result1 = list(list_stream) - list_result2 = 
list(list_stream) - assert list_result1 == list_result2 - - # Generator stream should only work once - gen_result1 = list(gen_stream) - gen_result2 = list(gen_stream) - assert gen_result1 == sample_data - assert gen_result2 == [] # Exhausted - - def test_performance_characteristics(self): - """Test performance characteristics of different implementations.""" - import time - - size = 1000 - tags = [{"id": i} for i in range(size)] - packets = [{"value": i * 2} for i in range(size)] - - # Time list-based stream creation and consumption - start = time.time() - list_stream = SyncStreamFromLists(tags, packets) - list_result = list(list_stream) - list_time = time.time() - start - - # Time generator-based stream creation and consumption - def generator(): - for tag, packet in zip(tags, packets): - yield tag, packet - - start = time.time() - gen_stream = SyncStreamFromGenerator(generator) - gen_result = list(gen_stream) - gen_time = time.time() - start - - # Results should be equivalent - assert list_result == gen_result - - # Both should complete in reasonable time (implementation dependent) - assert list_time < 1.0 # Should be fast - assert gen_time < 1.0 # Should be fast - - def test_error_handling_consistency(self): - """Test that error handling is consistent between implementations.""" - - def failing_generator(): - yield ({"id": 1}, {"data": "ok"}) - raise RuntimeError("Generator error") - - # Generator stream should propagate error - gen_stream = SyncStreamFromGenerator(failing_generator) - with pytest.raises(RuntimeError, match="Generator error"): - list(gen_stream) - - # List stream with problematic data - tags = [{"id": 1}, None] # None tag might cause issues - packets = [{"data": "ok"}, {"data": "also_ok"}] - - list_stream = SyncStreamFromLists(tags, packets) - result = list(list_stream) # Should handle None gracefully - - assert len(result) == 2 - assert result[1] == (None, {"data": "also_ok"}) - - def test_integration_with_operations(self, sample_data): - """Test that both stream types work equivalently with operations.""" - from orcabridge.mapper import Filter - - tags, packets = zip(*sample_data) - - # Create equivalent streams - list_stream = SyncStreamFromLists(list(tags), list(packets)) - - def generator(): - for item in sample_data: - yield item - - gen_stream = SyncStreamFromGenerator(generator) - - # Apply same operation to both - filter_op = Filter(lambda tag, packet: tag["id"] > 1) - - filtered_list = filter_op(list_stream) - filtered_gen = filter_op(gen_stream) - - list_result = list(filtered_list) - gen_result = list(filtered_gen) - - # Results should be equivalent - assert list_result == gen_result - assert len(list_result) == 2 # Should have filtered out id=1