|
56 | 56 | from mcp import ClientSession |
57 | 57 | from mcp.client.sse import sse_client |
58 | 58 | from mcp.client.streamable_http import streamablehttp_client |
| 59 | +from pydantic import ValidationError |
59 | 60 | from sqlalchemy import and_, or_, select |
60 | 61 | from sqlalchemy.exc import IntegrityError |
61 | 62 | from sqlalchemy.orm import Session |
@@ -266,6 +267,10 @@ class GatewayConnectionError(GatewayError): |
266 | 267 | """ |
267 | 268 |
|
268 | 269 |
|
| 270 | +class OAuthToolValidationError(GatewayConnectionError): |
| 271 | + """Raised when tool validation fails during OAuth-driven fetch.""" |
| 272 | + |
| 273 | + |
269 | 274 | class GatewayService: # pylint: disable=too-many-instance-attributes |
270 | 275 | """Service for managing federated gateways. |
271 | 276 |
|
@@ -986,6 +991,10 @@ async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_e |
986 | 991 |
|
987 | 992 | return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts} |
988 | 993 |
|
| 994 | + except GatewayConnectionError as gce: |
| 995 | + # Surface validation or depth-related failures directly to the user |
| 996 | + logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}") |
| 997 | + raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}") |
989 | 998 | except Exception as e: |
990 | 999 | logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}") |
991 | 1000 | raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}") |
@@ -3146,6 +3155,66 @@ async def _publish_event(self, event: Dict[str, Any]) -> None: |
3146 | 3155 | """ |
3147 | 3156 | await self._event_service.publish_event(event) |
3148 | 3157 |
|
| 3158 | + def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> list[ToolCreate]: |
| 3159 | + """Validate tools individually with richer logging and error aggregation. |
| 3160 | +
|
| 3161 | + Args: |
| 3162 | + tools: list of tool dicts |
| 3163 | + context: caller context, e.g. "oauth" to tailor errors/messages |
| 3164 | + """ |
| 3165 | + valid_tools: list[ToolCreate] = [] |
| 3166 | + validation_errors: list[str] = [] |
| 3167 | + |
| 3168 | + for i, tool_dict in enumerate(tools): |
| 3169 | + tool_name = tool_dict.get("name", f"unknown_tool_{i}") |
| 3170 | + try: |
| 3171 | + logger.debug(f"Validating tool: {tool_name}") |
| 3172 | + validated_tool = ToolCreate.model_validate(tool_dict) |
| 3173 | + valid_tools.append(validated_tool) |
| 3174 | + logger.debug(f"Tool '{tool_name}' validated successfully") |
| 3175 | + except ValidationError as e: |
| 3176 | + error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}" |
| 3177 | + logger.error(error_msg) |
| 3178 | + logger.debug(f"Failed tool schema: {tool_dict}") |
| 3179 | + validation_errors.append(error_msg) |
| 3180 | + except ValueError as e: |
| 3181 | + if "JSON structure exceeds maximum depth" in str(e): |
| 3182 | + error_msg = ( |
| 3183 | + f"Tool '{tool_name}' schema too deeply nested. " |
| 3184 | + f"Current depth limit: {settings.validation_max_json_depth}" |
| 3185 | + ) |
| 3186 | + logger.error(error_msg) |
| 3187 | + logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable") |
| 3188 | + else: |
| 3189 | + error_msg = f"ValueError for tool '{tool_name}': {str(e)}" |
| 3190 | + logger.error(error_msg) |
| 3191 | + validation_errors.append(error_msg) |
| 3192 | + except Exception as e: # pragma: no cover - defensive |
| 3193 | + error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}" |
| 3194 | + logger.error(error_msg, exc_info=True) |
| 3195 | + validation_errors.append(error_msg) |
| 3196 | + |
| 3197 | + if validation_errors: |
| 3198 | + logger.warning( |
| 3199 | + f"Tool validation completed with {len(validation_errors)} error(s). " |
| 3200 | + f"Successfully validated {len(valid_tools)} tool(s)." |
| 3201 | + ) |
| 3202 | + for err in validation_errors[:3]: |
| 3203 | + logger.debug(f"Validation error: {err}") |
| 3204 | + |
| 3205 | + if not valid_tools and validation_errors: |
| 3206 | + if context == "oauth": |
| 3207 | + raise OAuthToolValidationError( |
| 3208 | + f"OAuth tool fetch failed: all {len(tools)} tools failed validation. " |
| 3209 | + f"First error: {validation_errors[0][:200]}" |
| 3210 | + ) |
| 3211 | + raise GatewayConnectionError( |
| 3212 | + f"Failed to fetch tools: All {len(tools)} tools failed validation. " |
| 3213 | + f"First error: {validation_errors[0][:200]}" |
| 3214 | + ) |
| 3215 | + |
| 3216 | + return valid_tools |
| 3217 | + |
3149 | 3218 | async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None): |
3150 | 3219 | """Connect to an MCP server running with SSE transport, skipping URL validation. |
3151 | 3220 |
|
@@ -3173,9 +3242,11 @@ async def _connect_to_sse_server_without_validation(self, server_url: str, authe |
3173 | 3242 |
|
3174 | 3243 | response = await session.list_tools() |
3175 | 3244 | tools = response.tools |
3176 | | - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] |
| 3245 | + tools = [ |
| 3246 | + tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools |
| 3247 | + ] |
3177 | 3248 |
|
3178 | | - tools = [ToolCreate.model_validate(tool) for tool in tools] |
| 3249 | + tools = self._validate_tools(tools, context="oauth") |
3179 | 3250 | if tools: |
3180 | 3251 | logger.info(f"Fetched {len(tools)} tools from gateway") |
3181 | 3252 | # Fetch resources if supported |
@@ -3316,9 +3387,11 @@ def get_httpx_client_factory( |
3316 | 3387 |
|
3317 | 3388 | response = await session.list_tools() |
3318 | 3389 | tools = response.tools |
3319 | | - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] |
| 3390 | + tools = [ |
| 3391 | + tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools |
| 3392 | + ] |
3320 | 3393 |
|
3321 | | - tools = [ToolCreate.model_validate(tool) for tool in tools] |
| 3394 | + tools = self._validate_tools(tools) |
3322 | 3395 | if tools: |
3323 | 3396 | logger.info(f"Fetched {len(tools)} tools from gateway") |
3324 | 3397 | # Fetch resources if supported |
@@ -3456,9 +3529,11 @@ def get_httpx_client_factory( |
3456 | 3529 |
|
3457 | 3530 | response = await session.list_tools() |
3458 | 3531 | tools = response.tools |
3459 | | - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] |
| 3532 | + tools = [ |
| 3533 | + tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools |
| 3534 | + ] |
3460 | 3535 |
|
3461 | | - tools = [ToolCreate.model_validate(tool) for tool in tools] |
| 3536 | + tools = self._validate_tools(tools) |
3462 | 3537 | for tool in tools: |
3463 | 3538 | tool.request_type = "STREAMABLEHTTP" |
3464 | 3539 | if tools: |
|
0 commit comments