diff --git a/.env.example b/.env.example index e7399ed..f2db461 100644 --- a/.env.example +++ b/.env.example @@ -1,27 +1,66 @@ -# OpenAI API Configuration -OPENAI_API_KEY= +# GitHub Configuration (required) +APP_NAME_GITHUB=your_app_name +APP_CLIENT_ID_GITHUB=your_client_id +APP_CLIENT_SECRET_GITHUB=your_client_secret -# LangChain Configuration -LANGCHAIN_TRACING_V2=true +PRIVATE_KEY_BASE64_GITHUB=your_private_key_base64 +WEBHOOK_SECRET_GITHUB=your_webhook_secret + +# AI Provider Selection +AI_PROVIDER=openai # Options: openai, bedrock, vertex_ai + +# Common AI Settings (defaults for all agents) +AI_MAX_TOKENS=4096 +AI_TEMPERATURE=0.1 + +# OpenAI Configuration (when AI_PROVIDER=openai) +OPENAI_API_KEY=your_openai_api_key_here +OPENAI_MODEL=gpt-4.1-mini # Optional, defaults to gpt-4.1-mini + +# AWS Bedrock Configuration (when AI_PROVIDER=bedrock) +# BEDROCK_REGION=us-east-1 +# BEDROCK_MODEL_ID=anthropic.claude-3-sonnet-20240229-v1:0 +# AWS_ACCESS_KEY_ID=your_aws_access_key # Optional, can use AWS profile instead +# AWS_SECRET_ACCESS_KEY=your_aws_secret_key # Optional, can use AWS profile instead +# AWS_PROFILE=your_aws_profile # Optional, alternative to access keys + +# Google Vertex AI Configuration (when AI_PROVIDER=vertex_ai) +# GCP_PROJECT_ID=your-gcp-project-id +# GCP_LOCATION=us-central1 +# VERTEX_AI_MODEL=gemini-pro # Options: gemini-pro, gemini-1.5-pro, claude-3-opus@20240229, etc. +# GCP_SERVICE_ACCOUNT_KEY_BASE64=your_base64_encoded_service_account_key # Optional, can use ADC instead + +# Engine Agent Configuration +AI_ENGINE_MAX_TOKENS=8000 # Default: 8000 +AI_ENGINE_TEMPERATURE=0.1 + +# Feasibility Agent Configuration +AI_FEASIBILITY_MAX_TOKENS=4096 +AI_FEASIBILITY_TEMPERATURE=0.1 + +# Acknowledgment Agent Configuration +AI_ACKNOWLEDGMENT_MAX_TOKENS=2000 +AI_ACKNOWLEDGMENT_TEMPERATURE=0.1 + +# LangSmith Configuration +LANGCHAIN_TRACING_V2=false LANGCHAIN_ENDPOINT=https://api.smith.langchain.com LANGCHAIN_API_KEY= -LANGCHAIN_PROJECT= - -# AWS Configuration -AWS_ACCESS_KEY_ID= -AWS_SECRET_ACCESS_KEY= +LANGCHAIN_PROJECT=watchflow-dev -# Application Configuration -ENVIRONMENT=development +# CORS Configuration CORS_HEADERS=["*"] -CORS_ORIGINS='["http://localhost:3000", "http://127.0.0.1:3000"]' +CORS_ORIGINS=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5500", "https://warestack.github.io", "https://watchflow.dev"] + +# Repository Configuration +REPO_CONFIG_BASE_PATH=.watchflow +REPO_CONFIG_RULES_FILE=rules.yaml -# GitHub OAuth Configuration -APP_NAME_GITHUB= -CLIENT_ID_GITHUB= -CLIENT_SECRET_GITHUB= -PRIVATE_KEY_BASE64_GITHUB= -REDIRECT_URI_GITHUB=http://localhost:3000 +# Logging Configuration +LOG_LEVEL=INFO +LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s +LOG_FILE_PATH= -# GitHub Webhook Configuration -WEBHOOK_SECRET_GITHUB= +# Development Settings +DEBUG=false +ENVIRONMENT=development diff --git a/env.example b/env.example deleted file mode 100644 index 0e05de0..0000000 --- a/env.example +++ /dev/null @@ -1,80 +0,0 @@ -# ============================================================================= -# WATCHFLOW CONFIGURATION - AI Provider Abstraction Example -# ============================================================================= - -# GitHub Configuration (required) -APP_NAME_GITHUB=your_app_name -APP_CLIENT_ID_GITHUB=your_client_id -APP_CLIENT_SECRET_GITHUB=your_client_secret -PRIVATE_KEY_BASE64_GITHUB=your_private_key_base64 -REDIRECT_URI_GITHUB=your_redirect_uri -WEBHOOK_SECRET_GITHUB=your_webhook_secret - -# 
============================================================================= -# AI PROVIDER CONFIGURATION (NEW - PR #18) -# ============================================================================= - -# AI Provider Selection -AI_PROVIDER=openai # Options: openai, bedrock, garden - -# Common AI Settings (defaults for all agents) -AI_MODEL=gpt-4.1-mini -AI_MAX_TOKENS=4096 -AI_TEMPERATURE=0.1 - -# OpenAI Configuration (when AI_PROVIDER=openai) -OPENAI_API_KEY=your_openai_api_key_here - -# AWS Bedrock Configuration (when AI_PROVIDER=bedrock) -# BEDROCK_REGION=us-east-1 -# BEDROCK_MODEL_ID=anthropic.claude-3-sonnet-20240229-v1:0 -# AWS_ACCESS_KEY_ID=your_aws_access_key -# AWS_SECRET_ACCESS_KEY=your_aws_secret_key - -# GCP Model Garden Configuration (when AI_PROVIDER=garden) -# GCP_PROJECT_ID=your-gcp-project-id -# GCP_LOCATION=us-central1 -# GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account-key.json - -# ============================================================================= -# PER-AGENT AI CONFIGURATION (NEW - PR #18 Enhancement) -# ============================================================================= - -# Engine Agent Configuration -AI_ENGINE_MAX_TOKENS=2000 -AI_ENGINE_TEMPERATURE=0.1 - -# Feasibility Agent Configuration -AI_FEASIBILITY_MAX_TOKENS=4096 -AI_FEASIBILITY_TEMPERATURE=0.1 - -# Acknowledgment Agent Configuration -AI_ACKNOWLEDGMENT_MAX_TOKENS=2000 -AI_ACKNOWLEDGMENT_TEMPERATURE=0.1 - -# ============================================================================= -# EXISTING CONFIGURATION -# ============================================================================= - -# LangSmith Configuration -LANGCHAIN_TRACING_V2=false -LANGCHAIN_ENDPOINT=https://api.smith.langchain.com -LANGCHAIN_API_KEY= -LANGCHAIN_PROJECT=watchflow-dev - -# CORS Configuration -CORS_HEADERS=["*"] -CORS_ORIGINS=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5500", "https://warestack.github.io", "https://watchflow.dev"] - -# Repository Configuration -REPO_CONFIG_BASE_PATH=.watchflow -REPO_CONFIG_RULES_FILE=rules.yaml - -# Logging Configuration -LOG_LEVEL=INFO -LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s -LOG_FILE_PATH= - -# Development Settings -DEBUG=false -ENVIRONMENT=development \ No newline at end of file diff --git a/src/agents/acknowledgment_agent/agent.py b/src/agents/acknowledgment_agent/agent.py index bee7984..ec862aa 100644 --- a/src/agents/acknowledgment_agent/agent.py +++ b/src/agents/acknowledgment_agent/agent.py @@ -11,7 +11,7 @@ from src.agents.acknowledgment_agent.models import AcknowledgmentContext, AcknowledgmentEvaluation from src.agents.acknowledgment_agent.prompts import create_evaluation_prompt, get_system_prompt from src.agents.base import AgentResult, BaseAgent -from src.core.ai import get_chat_model +from src.providers import get_chat_model logger = logging.getLogger(__name__) diff --git a/src/agents/base.py b/src/agents/base.py index 5857aed..3e15735 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import Any, TypeVar -from src.core.ai import get_chat_model +from src.providers import get_chat_model logger = logging.getLogger(__name__) diff --git a/src/agents/engine_agent/models.py b/src/agents/engine_agent/models.py index 16af0a8..910189b 100644 --- a/src/agents/engine_agent/models.py +++ b/src/agents/engine_agent/models.py @@ -41,7 +41,11 @@ class LLMEvaluationResponse(BaseModel): is_violated: bool = Field(description="Whether the rule is 
violated") message: str = Field(description="Explanation of the violation or why the rule passed") - details: dict[str, Any] = Field(description="Detailed reasoning and metadata", default_factory=dict) + details: dict[str, Any] = Field( + description="Detailed reasoning and metadata", + default_factory=dict, + json_schema_extra={"additionalProperties": False}, + ) how_to_fix: str | None = Field(description="Specific instructions on how to fix the violation", default=None) diff --git a/src/agents/engine_agent/nodes.py b/src/agents/engine_agent/nodes.py index d74da04..3abf542 100644 --- a/src/agents/engine_agent/nodes.py +++ b/src/agents/engine_agent/nodes.py @@ -24,7 +24,7 @@ create_validation_strategy_prompt, get_llm_evaluation_system_prompt, ) -from src.core.ai import get_chat_model +from src.providers import get_chat_model from src.rules.validators import VALIDATOR_REGISTRY logger = logging.getLogger(__name__) diff --git a/src/agents/feasibility_agent/nodes.py b/src/agents/feasibility_agent/nodes.py index fd21271..0270cc2 100644 --- a/src/agents/feasibility_agent/nodes.py +++ b/src/agents/feasibility_agent/nodes.py @@ -6,7 +6,7 @@ from src.agents.feasibility_agent.models import FeasibilityAnalysis, FeasibilityState, YamlGeneration from src.agents.feasibility_agent.prompts import RULE_FEASIBILITY_PROMPT, YAML_GENERATION_PROMPT -from src.core.ai import get_chat_model +from src.providers import get_chat_model logger = logging.getLogger(__name__) diff --git a/src/core/ai.py b/src/core/ai.py deleted file mode 100644 index cd3b754..0000000 --- a/src/core/ai.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Provider-agnostic AI chat model factory. - -This module provides a simple interface to the AI provider system. -For complex provider logic, see src.core.ai_providers and src.integrations. -""" - -from __future__ import annotations - -from src.core.ai_providers.factory import get_chat_model as _get_chat_model - - -def get_chat_model( - *, - provider: str | None = None, - model: str | None = None, - max_tokens: int | None = None, - temperature: float | None = None, - agent: str | None = None, - **kwargs, -): - """ - Return a chat model client based on configuration. - - Args: - provider: AI provider name (openai, bedrock, vertex_ai) - model: Model name/ID - max_tokens: Override max tokens (takes precedence over agent config) - temperature: Override temperature (takes precedence over agent config) - agent: Agent name for per-agent configuration ('engine_agent', 'feasibility_agent', 'acknowledgment_agent') - **kwargs: Additional provider-specific parameters - - Providers: - - "openai": uses OpenAI API - - "bedrock": uses AWS Bedrock (supports both standard and Anthropic inference profiles) - - "vertex_ai": uses GCP Vertex AI - """ - return _get_chat_model( - provider=provider, model=model, max_tokens=max_tokens, temperature=temperature, agent=agent, **kwargs - ) diff --git a/src/core/ai_providers/__init__.py b/src/core/ai_providers/__init__.py deleted file mode 100644 index 65c467e..0000000 --- a/src/core/ai_providers/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -AI Providers module for managing different AI service providers. - -This module provides a unified interface for accessing various AI providers -including OpenAI, AWS Bedrock, and GCP Model Garden. 
-""" - -from .base import BaseAIProvider -from .bedrock_provider import BedrockProvider -from .factory import get_ai_provider -from .garden_provider import GardenProvider -from .openai_provider import OpenAIProvider - -__all__ = [ - "BaseAIProvider", - "OpenAIProvider", - "BedrockProvider", - "GardenProvider", - "get_ai_provider", -] diff --git a/src/core/ai_providers/bedrock_provider.py b/src/core/ai_providers/bedrock_provider.py deleted file mode 100644 index 4e227bc..0000000 --- a/src/core/ai_providers/bedrock_provider.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -AWS Bedrock AI Provider implementation. - -This provider handles both standard Bedrock models and Anthropic models -requiring inference profiles. -""" - -from typing import Any - -from src.integrations.aws_bedrock import get_bedrock_client - -from .base import BaseAIProvider - - -class BedrockProvider(BaseAIProvider): - """AWS Bedrock AI Provider with hybrid client support.""" - - def get_chat_model(self) -> Any: - """Get Bedrock chat model using appropriate client.""" - # Get the appropriate Bedrock client (uses config directly) - client = get_bedrock_client() - - return client - - def supports_structured_output(self) -> bool: - """Bedrock supports structured output.""" - return True - - def get_provider_name(self) -> str: - """Get provider name.""" - return "bedrock" - - def get_model_info(self) -> dict[str, Any]: - """Get enhanced model information.""" - info = super().get_model_info() - model_id = self.kwargs.get("model_id", self.model) - info.update( - { - "model_id": model_id, - "supports_inference_profiles": model_id.startswith("anthropic."), - } - ) - return info diff --git a/src/core/ai_providers/factory.py b/src/core/ai_providers/factory.py deleted file mode 100644 index b644b36..0000000 --- a/src/core/ai_providers/factory.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -AI Provider Factory. - -This module provides a factory function to create the appropriate -AI provider based on configuration. -""" - -from typing import Any - -from src.core.config import config - -from .base import BaseAIProvider -from .bedrock_provider import BedrockProvider -from .garden_provider import GardenProvider -from .openai_provider import OpenAIProvider - - -def get_ai_provider( - provider: str | None = None, - model: str | None = None, - max_tokens: int | None = None, - temperature: float | None = None, - agent: str | None = None, - **kwargs, -) -> BaseAIProvider: - """ - Get the appropriate AI provider based on configuration. 
- - Args: - provider: AI provider name (openai, bedrock, vertex_ai) - model: Model name/ID - max_tokens: Maximum tokens to generate - temperature: Sampling temperature - agent: Agent name for per-agent configuration - **kwargs: Additional provider-specific parameters - - Returns: - Configured AI provider instance - """ - # Use config defaults if not provided - provider = provider or config.ai.provider or "openai" - - # Get model with fallbacks handled by config - if not model: - model = config.ai.get_model_for_provider(provider) - - # Determine tokens and temperature with precedence: explicit params > agent config > global config - if max_tokens is not None: - tokens = max_tokens - else: - tokens = config.ai.get_max_tokens_for_agent(agent) - - if temperature is not None: - temp = temperature - else: - temp = config.ai.get_temperature_for_agent(agent) - - # Create provider-specific parameters - provider_kwargs = kwargs.copy() - - if provider.lower() == "openai": - provider_kwargs.update( - { - "api_key": config.ai.api_key, - } - ) - return OpenAIProvider(model=model, max_tokens=tokens, temperature=temp, **provider_kwargs) - - elif provider.lower() == "bedrock": - return BedrockProvider( - model=model, - max_tokens=tokens, - temperature=temp, - ) - - elif provider.lower() in ["garden", "model_garden", "gcp"]: - return GardenProvider( - model=model, - max_tokens=tokens, - temperature=temp, - ) - - else: - raise ValueError(f"Unsupported AI provider: {provider}") - - -def get_chat_model( - provider: str | None = None, - model: str | None = None, - max_tokens: int | None = None, - temperature: float | None = None, - agent: str | None = None, - **kwargs, -) -> Any: - """ - Get a chat model instance using the appropriate provider. - - This is a convenience function that creates a provider and returns its chat model. - """ - provider_instance = get_ai_provider( - provider=provider, model=model, max_tokens=max_tokens, temperature=temperature, agent=agent, **kwargs - ) - - return provider_instance.get_chat_model() diff --git a/src/core/ai_providers/garden_provider.py b/src/core/ai_providers/garden_provider.py deleted file mode 100644 index c9dd617..0000000 --- a/src/core/ai_providers/garden_provider.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -GCP Model Garden Provider implementation. 
-""" - -from typing import Any - -from src.integrations.gcp_garden import get_garden_client - -from .base import BaseAIProvider - - -class GardenProvider(BaseAIProvider): - """GCP Model Garden Provider.""" - - def get_chat_model(self) -> Any: - """Get Model Garden chat model.""" - return get_garden_client() - - def supports_structured_output(self) -> bool: - """Model Garden supports structured output.""" - return True - - def get_provider_name(self) -> str: - """Get provider name.""" - return "garden" diff --git a/src/core/config.py b/src/core/config.py index 505d1b8..2d38720 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -36,7 +36,7 @@ class AIConfig: # Provider-specific model fields openai_model: str | None = None bedrock_model_id: str | None = None - model_garden_model: str | None = None + vertex_ai_model: str | None = None # Optional provider-specific fields # AWS Bedrock bedrock_region: str | None = None @@ -60,9 +60,9 @@ def get_model_for_provider(self, provider: str) -> str: return self.openai_model or "gpt-4.1-mini" elif provider == "bedrock": return self.bedrock_model_id or "anthropic.claude-3-sonnet-20240229-v1:0" - elif provider in ["garden", "model_garden", "gcp"]: - # Support both Gemini and Claude models in Model Garden - return self.model_garden_model or "gemini-pro" + elif provider in ["vertex_ai", "garden", "model_garden", "gcp", "vertex", "vertexai"]: + # Support both Gemini and Claude models in Vertex AI + return self.vertex_ai_model or "gemini-pro" else: return "gpt-4.1-mini" # Ultimate fallback @@ -142,7 +142,7 @@ def __init__(self): # Provider-specific model fields openai_model=os.getenv("OPENAI_MODEL"), bedrock_model_id=os.getenv("BEDROCK_MODEL_ID"), - model_garden_model=os.getenv("MODEL_GARDEN_MODEL"), + vertex_ai_model=os.getenv("VERTEX_AI_MODEL") or os.getenv("MODEL_GARDEN_MODEL"), # Support legacy name # AWS Bedrock configuration bedrock_region=os.getenv("BEDROCK_REGION"), aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), @@ -238,12 +238,12 @@ def validate(self) -> bool: errors.append("OPENAI_API_KEY is required for OpenAI provider") if self.ai.provider == "bedrock": # Bedrock credentials are read from AWS environment/IMDS; encourage region/model hints - if not self.ai.bedrock_model_id and not self.ai.model: - errors.append("BEDROCK_MODEL_ID or AI_MODEL is required for Bedrock provider") - if self.ai.provider in {"garden", "vertex", "vertexai", "model_garden"}: - # Vertex typically uses ADC; project/location optional but recommended - if not self.ai.model: - errors.append("AI_MODEL is required for GCP Garden/Vertex provider") + if not self.ai.bedrock_model_id: + errors.append("BEDROCK_MODEL_ID is required for Bedrock provider") + if self.ai.provider in {"vertex_ai", "garden", "vertex", "vertexai", "model_garden", "gcp"}: + # Vertex AI typically uses ADC; project/location optional but recommended + if not self.ai.vertex_ai_model: + errors.append("VERTEX_AI_MODEL is required for Google Vertex AI provider") if errors: raise ValueError(f"Configuration errors: {', '.join(errors)}") diff --git a/src/integrations/gcp_garden.py b/src/integrations/gcp_garden.py index 0804c30..c9d5c72 100644 --- a/src/integrations/gcp_garden.py +++ b/src/integrations/gcp_garden.py @@ -1,8 +1,8 @@ """ -GCP Vertex AI integration for AI model access. +GCP Model Garden integration for AI model access. -This module handles Google Cloud Platform Vertex AI API interactions -for AI model access through Model Garden. 
+This module handles Google Cloud Platform Model Garden API interactions +for AI model access, supporting both Google (Gemini) and third-party (Claude) models. """ from __future__ import annotations @@ -16,7 +16,7 @@ def get_garden_client() -> Any: """ Get GCP Model Garden client for accessing both Google and third-party models. - + Returns: Model Garden client instance """ @@ -27,48 +27,46 @@ def get_garden_client() -> Any: def get_model_garden_client() -> Any: """ Get GCP Model Garden client for accessing both Google and third-party models. - + This client provides access to models from various providers through Google's Model Garden marketplace, including: - Google models: gemini-1.0-pro, gemini-1.5-pro, gemini-2.0-flash-exp - Third-party models: Claude, Llama, etc. (when available) - + Returns: Model Garden client instance """ # Get GCP credentials from config project_id = config.ai.gcp_project - location = config.ai.gcp_location or 'us-central1' + location = config.ai.gcp_location or "us-central1" service_account_key_base64 = config.ai.gcp_service_account_key_base64 - model = config.ai.get_model_for_provider('garden') - + model = config.ai.get_model_for_provider("garden") + if not project_id: - raise ValueError( - "GCP project ID required for Model Garden. Set GCP_PROJECT_ID in config" - ) + raise ValueError("GCP project ID required for Model Garden. Set GCP_PROJECT_ID in config") # Handle base64 encoded service account key if service_account_key_base64: import base64 import tempfile - + try: # Decode the base64 key - key_data = base64.b64decode(service_account_key_base64).decode('utf-8') - + key_data = base64.b64decode(service_account_key_base64).decode("utf-8") + # Create a temporary file with the key - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write(key_data) credentials_path = f.name - + # Set the environment variable for Google Cloud to use - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path - + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path + except Exception as e: raise ValueError(f"Failed to decode GCP service account key: {e}") from e # Check if it's a Claude model - if 'claude' in model.lower(): + if "claude" in model.lower(): return get_claude_model_garden_client(project_id, location, model) else: return get_gemini_model_garden_client(project_id, location, model) @@ -77,12 +75,15 @@ def get_model_garden_client() -> Any: def get_claude_model_garden_client(project_id: str, location: str, model: str) -> Any: """ Get Claude model via GCP Model Garden using Anthropic Vertex SDK. - + + Note: The AnthropicVertex SDK is used for Claude models in Model Garden, + even though the provider is called "garden" in our configuration. 
+ Args: project_id: GCP project ID location: GCP location/region model: Model name (e.g., claude-3-opus@20240229) - + Returns: Claude client instance """ @@ -94,9 +95,9 @@ def get_claude_model_garden_client(project_id: str, location: str, model: str) - "Install with: pip install 'anthropic[vertex]'" ) from e - # Create Anthropic Vertex client + # Create Anthropic Vertex client (this is the SDK class name for Model Garden) client = AnthropicVertex(region=location, project_id=project_id) - + # Wrap it to match LangChain interface return ClaudeModelGardenWrapper(client, model) @@ -104,12 +105,15 @@ def get_claude_model_garden_client(project_id: str, location: str, model: str) - def get_gemini_model_garden_client(project_id: str, location: str, model: str) -> Any: """ Get Gemini model via GCP Model Garden using LangChain. - + + Note: ChatVertexAI is the LangChain class name for Model Garden models, + even though the provider is called "garden" in our configuration. + Args: project_id: GCP project ID location: GCP location/region model: Model name (e.g., gemini-pro) - + Returns: Gemini client instance """ @@ -123,7 +127,7 @@ def get_gemini_model_garden_client(project_id: str, location: str, model: str) - # Try multiple Gemini model names in order of preference model_candidates = [model, "gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash"] - + for candidate_model in model_candidates: try: return ChatVertexAI( @@ -136,7 +140,7 @@ def get_gemini_model_garden_client(project_id: str, location: str, model: str) - continue # Try next model else: raise # Re-raise if it's not a model not found error - + # If all models fail, raise an error raise RuntimeError( f"None of the Gemini models are available in your GCP project. " @@ -149,45 +153,44 @@ class ClaudeModelGardenWrapper: """ Wrapper for Claude Model Garden client to match LangChain interface. """ - + def __init__(self, client, model: str): self.client = client self.model = model - + async def ainvoke(self, messages, **kwargs): """Async invoke method.""" # Convert LangChain messages to Anthropic format anthropic_messages = [] for msg in messages: - if hasattr(msg, 'content'): + if hasattr(msg, "content"): content = msg.content role = "user" if msg.type == "human" else "assistant" else: content = str(msg) role = "user" - - anthropic_messages.append({ - "role": role, - "content": content - }) - + + anthropic_messages.append({"role": role, "content": content}) + # Call Claude API response = self.client.messages.create( model=self.model, messages=anthropic_messages, - max_tokens=kwargs.get('max_tokens', 4096), - temperature=kwargs.get('temperature', 0.1), + max_tokens=kwargs.get("max_tokens", 4096), + temperature=kwargs.get("temperature", 0.1), ) - + # Convert response to LangChain format from langchain_core.messages import AIMessage + return AIMessage(content=response.content[0].text) - + def invoke(self, messages, **kwargs): """Sync invoke method.""" import asyncio + return asyncio.run(self.ainvoke(messages, **kwargs)) - + def with_structured_output(self, schema, **kwargs): """Structured output method.""" # For now, return self and handle structured output in ainvoke diff --git a/src/providers/__init__.py b/src/providers/__init__.py new file mode 100644 index 0000000..9b7d5cd --- /dev/null +++ b/src/providers/__init__.py @@ -0,0 +1,17 @@ +""" +AI Provider package for managing different AI service providers. + +This package provides a unified interface for accessing various AI providers +including OpenAI, AWS Bedrock, and Google Vertex AI. 
+ +The main entry point is the factory functions: +- get_provider() - Get a provider instance +- get_chat_model() - Get a ready-to-use chat model +""" + +from src.providers.factory import get_chat_model, get_provider + +__all__ = [ + "get_provider", + "get_chat_model", +] diff --git a/src/core/ai_providers/base.py b/src/providers/base_provider.py similarity index 92% rename from src/core/ai_providers/base.py rename to src/providers/base_provider.py index 0646705..66ed3eb 100644 --- a/src/core/ai_providers/base.py +++ b/src/providers/base_provider.py @@ -1,7 +1,7 @@ """ Base AI Provider interface. -This module defines the base interface that all AI providers must implement. +This module defines the abstract base class that all AI providers must implement. """ from abc import ABC, abstractmethod diff --git a/src/providers/bedrock_provider.py b/src/providers/bedrock_provider.py new file mode 100644 index 0000000..1a12bf6 --- /dev/null +++ b/src/providers/bedrock_provider.py @@ -0,0 +1,244 @@ +""" +AWS Bedrock AI Provider implementation. + +This provider handles AWS Bedrock API interactions, including both +standard langchain-aws clients and the Anthropic Bedrock client +for inference profile support. All integration logic is consolidated here. +""" + +from __future__ import annotations + +import os +from typing import Any + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatGeneration, ChatResult + +from src.core.config import config +from src.providers.base_provider import BaseAIProvider + + +class BedrockProvider(BaseAIProvider): + """AWS Bedrock AI Provider with hybrid client support.""" + + def get_chat_model(self) -> Any: + """Get Bedrock chat model using appropriate client.""" + model_id = self.model + + # Check if this is already an inference profile ID + if model_id.startswith("us.") or model_id.startswith("global.") or model_id.startswith("arn:"): + return self._get_anthropic_inference_profile_client(model_id) + + # Try to find an inference profile for this model + inference_profile = self._find_inference_profile(model_id) + if inference_profile: + return self._get_anthropic_inference_profile_client(inference_profile) + + # Fallback to direct model access + if self._is_anthropic_model(model_id): + # For Anthropic models, try standard client first (supports structured output) + try: + return self._get_standard_bedrock_client() + except Exception: + # If standard client fails, fall back to Anthropic client + client = self._get_anthropic_bedrock_client() + return self._wrap_anthropic_client(client, model_id) + else: + # Use standard client for other models + return self._get_standard_bedrock_client() + + def supports_structured_output(self) -> bool: + """Bedrock supports structured output.""" + return True + + def get_provider_name(self) -> str: + """Get provider name.""" + return "bedrock" + + def _get_anthropic_bedrock_client(self) -> Any: + """Get Anthropic Bedrock client for models requiring inference profiles.""" + try: + from anthropic import AnthropicBedrock + except ImportError as e: + raise RuntimeError( + "Anthropic Bedrock client requires 'anthropic' package. 
Install with: pip install anthropic" + ) from e + + aws_access_key = config.ai.aws_access_key_id + aws_secret_key = config.ai.aws_secret_access_key + aws_region = config.ai.bedrock_region or "us-east-1" + aws_profile = config.ai.aws_profile + + if aws_profile: + os.environ["AWS_PROFILE"] = aws_profile + + client_kwargs = { + "aws_region": aws_region, + "aws_profile": aws_profile, + } + + if aws_access_key and aws_secret_key: + client_kwargs.update( + { + "aws_access_key": aws_access_key, + "aws_secret_key": aws_secret_key, + } + ) + + return AnthropicBedrock(**client_kwargs) + + def _get_standard_bedrock_client(self) -> Any: + """Get standard langchain-aws Bedrock client for on-demand models.""" + try: + from langchain_aws import ChatBedrock + except ImportError as e: + raise RuntimeError( + "Standard Bedrock client requires 'langchain-aws' package. Install with: pip install langchain-aws" + ) from e + + aws_access_key = config.ai.aws_access_key_id + aws_secret_key = config.ai.aws_secret_access_key + aws_region = config.ai.bedrock_region or "us-east-1" + aws_profile = config.ai.aws_profile + + if aws_profile: + os.environ["AWS_PROFILE"] = aws_profile + + client_kwargs = { + "model_id": self.model, + "region_name": aws_region, + } + + if self.model.startswith("arn:") or self.model.startswith("us.") or self.model.startswith("global."): + if "anthropic" in self.model.lower(): + client_kwargs["provider"] = "anthropic" + elif "amazon" in self.model.lower(): + client_kwargs["provider"] = "amazon" + elif "meta" in self.model.lower(): + client_kwargs["provider"] = "meta" + + if aws_access_key and aws_secret_key: + client_kwargs.update( + { + "aws_access_key_id": aws_access_key, + "aws_secret_access_key": aws_secret_key, + } + ) + + return ChatBedrock(**client_kwargs) + + def _is_anthropic_model(self, model_id: str) -> bool: + """Check if a model ID is an Anthropic model.""" + return model_id.startswith("anthropic.") + + def _find_inference_profile(self, model_id: str) -> str | None: + """Find an inference profile that contains the specified model.""" + try: + import boto3 + + aws_region = config.ai.bedrock_region or "us-east-1" + aws_access_key = config.ai.aws_access_key_id + aws_secret_key = config.ai.aws_secret_access_key + + client_kwargs = {"region_name": aws_region} + if aws_access_key and aws_secret_key: + client_kwargs.update({"aws_access_key_id": aws_access_key, "aws_secret_access_key": aws_secret_key}) + + bedrock = boto3.client("bedrock", **client_kwargs) + response = bedrock.list_inference_profiles() + profiles = response.get("inferenceProfiles", []) + + for profile in profiles: + profile_name = profile.get("name", "") + profile_arn = profile.get("arn", "") + + if any(keyword in profile_name.lower() for keyword in ["claude", "anthropic", "general", "default"]): + if "anthropic" in model_id.lower() or "claude" in model_id.lower(): + return profile_arn + elif any(keyword in profile_name.lower() for keyword in ["amazon", "titan", "nova"]): + if "amazon" in model_id.lower() or "titan" in model_id.lower() or "nova" in model_id.lower(): + return profile_arn + elif any(keyword in profile_name.lower() for keyword in ["meta", "llama"]): + if "meta" in model_id.lower() or "llama" in model_id.lower(): + return profile_arn + + return None + except Exception: + return None + + def _get_anthropic_inference_profile_client(self, inference_profile_id: str) -> Any: + """Get Anthropic client configured for inference profile models.""" + client = self._get_anthropic_bedrock_client() + return 
self._wrap_anthropic_client(client, inference_profile_id) + + def _wrap_anthropic_client(self, client: Any, model_id: str) -> Any: + """Wrap Anthropic Bedrock client to be langchain-compatible.""" + + class AnthropicBedrockWrapper(BaseChatModel): + """Wrapper for Anthropic Bedrock client to be langchain-compatible.""" + + anthropic_client: Any + model_id: str + max_tokens: int + temperature: float + + def __init__(self, anthropic_client: Any, model_id: str, max_tokens: int, temperature: float): + super().__init__( + anthropic_client=anthropic_client, + model_id=model_id, + max_tokens=max_tokens, + temperature=temperature, + ) + + @property + def _llm_type(self) -> str: + return "anthropic_bedrock" + + def with_structured_output(self, output_model: Any) -> Any: + """Add structured output support.""" + return self + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: Any | None = None, + ) -> ChatResult: + """Generate a response using the Anthropic client.""" + anthropic_messages = [] + for msg in messages: + if msg.type == "human": + role = "user" + elif msg.type == "ai": + role = "assistant" + elif msg.type == "system": + role = "user" + else: + role = "user" + + anthropic_messages.append({"role": role, "content": msg.content}) + + response = self.anthropic_client.messages.create( + model=self.model_id, + max_tokens=self.max_tokens, + temperature=self.temperature, + messages=anthropic_messages, + ) + + content = response.content[0].text if response.content else "" + message = BaseMessage(content=content, type="assistant") + generation = ChatGeneration(message=message) + + return ChatResult(generations=[generation]) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: Any | None = None, + ) -> ChatResult: + """Async generate using the Anthropic client.""" + return self._generate(messages, stop, run_manager) + + return AnthropicBedrockWrapper(client, model_id, self.max_tokens, self.temperature) diff --git a/src/providers/factory.py b/src/providers/factory.py new file mode 100644 index 0000000..ebb745e --- /dev/null +++ b/src/providers/factory.py @@ -0,0 +1,135 @@ +""" +AI Provider Factory. + +This module provides factory functions to create the appropriate +AI provider based on configuration using a simple mapping approach. +""" + +from typing import Any + +from src.core.config import config +from src.providers.base_provider import BaseAIProvider +from src.providers.bedrock_provider import BedrockProvider +from src.providers.openai_provider import OpenAIProvider +from src.providers.vertex_ai_provider import VertexAIProvider + +# Provider mapping - canonical names to provider classes +PROVIDER_MAP: dict[str, type[BaseAIProvider]] = { + "openai": OpenAIProvider, + "bedrock": BedrockProvider, + "vertex_ai": VertexAIProvider, + # Legacy aliases for backward compatibility + "garden": VertexAIProvider, + "model_garden": VertexAIProvider, + "gcp": VertexAIProvider, + "vertex": VertexAIProvider, + "vertexai": VertexAIProvider, +} + + +def get_provider( + provider: str | None = None, + model: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + agent: str | None = None, + **kwargs: Any, +) -> BaseAIProvider: + """ + Get the appropriate AI provider based on configuration. 
+ + Args: + provider: AI provider name (openai, bedrock, vertex_ai) + model: Model name/ID + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + agent: Agent name for per-agent configuration + **kwargs: Additional provider-specific parameters + + Returns: + Configured AI provider instance + + Raises: + ValueError: If provider is not supported + """ + # Use config defaults if not provided + provider_name = provider or config.ai.provider or "openai" + provider_name = provider_name.lower() + + # Get provider class from mapping + provider_class = PROVIDER_MAP.get(provider_name) + if not provider_class: + supported = ", ".join( + sorted(set(PROVIDER_MAP.keys()) - {"garden", "model_garden", "gcp", "vertex", "vertexai"}) + ) + raise ValueError(f"Unsupported AI provider: {provider_name}. Supported: {supported}") + + # Get model with fallbacks handled by config + if not model: + # Normalize provider name for config lookup (use canonical name) + canonical_provider = ( + "vertex_ai" if provider_name in ["garden", "model_garden", "gcp", "vertex", "vertexai"] else provider_name + ) + model = config.ai.get_model_for_provider(canonical_provider) + + # Determine tokens and temperature with precedence: explicit params > agent config > global config + if max_tokens is not None: + tokens = max_tokens + else: + tokens = config.ai.get_max_tokens_for_agent(agent) + + if temperature is not None: + temp = temperature + else: + temp = config.ai.get_temperature_for_agent(agent) + + # Prepare provider-specific kwargs + provider_kwargs = kwargs.copy() + + # Add provider-specific config + if provider_class == OpenAIProvider: + provider_kwargs["api_key"] = config.ai.api_key + + # Instantiate provider + return provider_class( + model=model, + max_tokens=tokens, + temperature=temp, + **provider_kwargs, + ) + + +def get_chat_model( + provider: str | None = None, + model: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + agent: str | None = None, + **kwargs: Any, +) -> Any: + """ + Get a chat model instance using the appropriate provider. + + This is a convenience function that creates a provider and returns its chat model. + + Args: + provider: AI provider name (openai, bedrock, vertex_ai) + model: Model name/ID + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + agent: Agent name for per-agent configuration + **kwargs: Additional provider-specific parameters + + Returns: + Ready-to-use chat model instance + """ + provider_instance = get_provider( + provider=provider, + model=model, + max_tokens=max_tokens, + temperature=temperature, + agent=agent, + **kwargs, + ) + + return provider_instance.get_chat_model() diff --git a/src/core/ai_providers/openai_provider.py b/src/providers/openai_provider.py similarity index 64% rename from src/core/ai_providers/openai_provider.py rename to src/providers/openai_provider.py index 2505048..369d49a 100644 --- a/src/core/ai_providers/openai_provider.py +++ b/src/providers/openai_provider.py @@ -1,10 +1,12 @@ """ OpenAI AI Provider implementation. + +This provider handles OpenAI API integration directly. """ from typing import Any -from .base import BaseAIProvider +from src.providers.base_provider import BaseAIProvider class OpenAIProvider(BaseAIProvider): @@ -19,7 +21,13 @@ def get_chat_model(self) -> Any: "OpenAI provider requires 'langchain-openai' package. 
Install with: pip install langchain-openai" ) from e - return ChatOpenAI(model=self.model, max_tokens=self.max_tokens, temperature=self.temperature, **self.kwargs) + return ChatOpenAI( + model=self.model, + max_tokens=self.max_tokens, + temperature=self.temperature, + api_key=self.kwargs.get("api_key"), + **{k: v for k, v in self.kwargs.items() if k != "api_key"}, + ) def supports_structured_output(self) -> bool: """OpenAI supports structured output.""" diff --git a/src/providers/vertex_ai_provider.py b/src/providers/vertex_ai_provider.py new file mode 100644 index 0000000..77f6fd1 --- /dev/null +++ b/src/providers/vertex_ai_provider.py @@ -0,0 +1,142 @@ +""" +Google Vertex AI Provider implementation. + +This provider handles Google Cloud Platform Vertex AI (Model Garden) API interactions +for AI model access, supporting both Google (Gemini) and third-party (Claude) models. +All integration logic is consolidated here. +""" + +from __future__ import annotations + +import base64 +import os +import tempfile +from typing import Any + +from src.core.config import config +from src.providers.base_provider import BaseAIProvider + + +class VertexAIProvider(BaseAIProvider): + """Google Vertex AI Provider (Model Garden).""" + + def get_chat_model(self) -> Any: + """Get Vertex AI chat model.""" + project_id = config.ai.gcp_project + location = config.ai.gcp_location or "us-central1" + service_account_key_base64 = config.ai.gcp_service_account_key_base64 + + if not project_id: + raise ValueError("GCP project ID required for Vertex AI. Set GCP_PROJECT_ID in config") + + # Handle base64 encoded service account key + if service_account_key_base64: + try: + key_data = base64.b64decode(service_account_key_base64).decode("utf-8") + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write(key_data) + credentials_path = f.name + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path + except Exception as e: + raise ValueError(f"Failed to decode GCP service account key: {e}") from e + + # Check if it's a Claude model + if "claude" in self.model.lower(): + return self._get_claude_client(project_id, location, self.model) + else: + return self._get_gemini_client(project_id, location, self.model) + + def supports_structured_output(self) -> bool: + """Vertex AI supports structured output.""" + return True + + def get_provider_name(self) -> str: + """Get provider name.""" + return "vertex_ai" + + def _get_claude_client(self, project_id: str, location: str, model: str) -> Any: + """Get Claude model via Vertex AI using Anthropic Vertex SDK.""" + try: + from anthropic import AnthropicVertex + except ImportError as e: + raise RuntimeError( + "Claude Vertex AI client requires 'anthropic[vertex]' package. " + "Install with: pip install 'anthropic[vertex]'" + ) from e + + client = AnthropicVertex(region=location, project_id=project_id) + return self._ClaudeVertexWrapper(client, model) + + def _get_gemini_client(self, project_id: str, location: str, model: str) -> Any: + """Get Gemini model via Vertex AI using LangChain.""" + try: + from langchain_google_vertexai import ChatVertexAI + except ImportError as e: + raise RuntimeError( + "Gemini Vertex AI client requires 'langchain-google-vertexai' package. 
" + "Install with: pip install langchain-google-vertexai" + ) from e + + # Try multiple Gemini model names in order of preference + model_candidates = [model, "gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash"] + + for candidate_model in model_candidates: + try: + return ChatVertexAI( + model=candidate_model, + project=project_id, + location=location, + ) + except Exception as e: + if "not found" in str(e).lower() or "404" in str(e): + continue + else: + raise + + raise RuntimeError( + f"None of the Gemini models are available in your GCP project. " + f"Tried: {', '.join(model_candidates)}. " + f"Please check your GCP project configuration and model access." + ) + + class _ClaudeVertexWrapper: + """Wrapper for Claude Vertex AI client to match LangChain interface.""" + + def __init__(self, client: Any, model: str): + self.client = client + self.model = model + + async def ainvoke(self, messages: list[Any], **kwargs: Any) -> Any: + """Async invoke method.""" + from langchain_core.messages import AIMessage + + anthropic_messages = [] + for msg in messages: + if hasattr(msg, "content"): + content = msg.content + role = "user" if msg.type == "human" else "assistant" + else: + content = str(msg) + role = "user" + + anthropic_messages.append({"role": role, "content": content}) + + response = self.client.messages.create( + model=self.model, + messages=anthropic_messages, + max_tokens=kwargs.get("max_tokens", 4096), + temperature=kwargs.get("temperature", 0.1), + ) + + return AIMessage(content=response.content[0].text) + + def invoke(self, messages: list[Any], **kwargs: Any) -> Any: + """Sync invoke method.""" + import asyncio + + return asyncio.run(self.ainvoke(messages, **kwargs)) + + def with_structured_output(self, schema: Any, **kwargs: Any) -> Any: + """Structured output method.""" + self._output_schema = schema + return self diff --git a/tests/unit/test_rule_engine_agent.py b/tests/unit/test_rule_engine_agent.py index 02dac37..d4cd66d 100644 --- a/tests/unit/test_rule_engine_agent.py +++ b/tests/unit/test_rule_engine_agent.py @@ -584,8 +584,8 @@ class TestDynamicHowToFix: """Test dynamic how-to-fix message generation.""" @pytest.mark.asyncio - @patch("src.agents.engine_agent.nodes.ChatOpenAI") - async def test_dynamic_how_to_fix_generation(self, mock_chat_openai): + @patch("src.agents.engine_agent.nodes.get_chat_model") + async def test_dynamic_how_to_fix_generation(self, mock_get_chat_model): """Test dynamic how-to-fix message generation.""" from src.agents.engine_agent.nodes import _generate_dynamic_how_to_fix @@ -596,7 +596,7 @@ async def test_dynamic_how_to_fix_generation(self, mock_chat_openai): how_to_fix="Add the 'security' and 'review' labels to this pull request" ) mock_llm.with_structured_output.return_value = mock_structured_llm - mock_chat_openai.return_value = mock_llm + mock_get_chat_model.return_value = mock_llm # Test data rule_desc = RuleDescription( @@ -615,13 +615,13 @@ async def test_dynamic_how_to_fix_generation(self, mock_chat_openai): assert "labels" in result @pytest.mark.asyncio - @patch("src.agents.engine_agent.nodes.ChatOpenAI") - async def test_dynamic_how_to_fix_fallback(self, mock_chat_openai): + @patch("src.agents.engine_agent.nodes.get_chat_model") + async def test_dynamic_how_to_fix_fallback(self, mock_get_chat_model): """Test dynamic how-to-fix message generation with fallback.""" from src.agents.engine_agent.nodes import _generate_dynamic_how_to_fix # Mock LLM error - mock_chat_openai.side_effect = Exception("LLM error") + mock_get_chat_model.side_effect = 
Exception("LLM error") # Test data rule_desc = RuleDescription(