|
56 | 56 | from mcp import ClientSession |
57 | 57 | from mcp.client.sse import sse_client |
58 | 58 | from mcp.client.streamable_http import streamablehttp_client |
| 59 | +from pydantic import ValidationError |
59 | 60 | from sqlalchemy import and_, or_, select |
60 | 61 | from sqlalchemy.exc import IntegrityError |
61 | 62 | from sqlalchemy.orm import Session |
@@ -272,6 +273,10 @@ class GatewayConnectionError(GatewayError): |
272 | 273 | """ |
273 | 274 |
|
274 | 275 |
|
| 276 | +class OAuthToolValidationError(GatewayConnectionError): |
| 277 | + """Raised when tool validation fails during OAuth-driven fetch.""" |
| 278 | + |
| 279 | + |
275 | 280 | class GatewayService: # pylint: disable=too-many-instance-attributes |
276 | 281 | """Service for managing federated gateways. |
277 | 282 |
|
@@ -1110,6 +1115,10 @@ async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_e |
1110 | 1115 |
|
1111 | 1116 | return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts} |
1112 | 1117 |
|
| 1118 | + except GatewayConnectionError as gce: |
| 1119 | + # Surface validation or depth-related failures directly to the user |
| 1120 | + logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}") |
| 1121 | + raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}") |
1113 | 1122 | except Exception as e: |
1114 | 1123 | logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}") |
1115 | 1124 | raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}") |
@@ -3514,6 +3523,73 @@ async def _publish_event(self, event: Dict[str, Any]) -> None: |
3514 | 3523 | """ |
3515 | 3524 | await self._event_service.publish_event(event) |
3516 | 3525 |
|
| 3526 | + def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> list[ToolCreate]: |
| 3527 | + """Validate tools individually with richer logging and error aggregation. |
| 3528 | +
|
| 3529 | + Args: |
| 3530 | + tools: list of tool dicts |
| 3531 | + context: caller context, e.g. "oauth" to tailor errors/messages |
| 3532 | +
|
| 3533 | + Returns: |
| 3534 | + list[ToolCreate]: List of successfully validated tools |
| 3535 | +
|
| 3536 | + Raises: |
| 3537 | + OAuthToolValidationError: If all tools fail validation in OAuth context |
| 3538 | + GatewayConnectionError: If all tools fail validation in default context |
| 3539 | + """ |
| 3540 | + valid_tools: list[ToolCreate] = [] |
| 3541 | + validation_errors: list[str] = [] |
| 3542 | + |
| 3543 | + for i, tool_dict in enumerate(tools): |
| 3544 | + tool_name = tool_dict.get("name", f"unknown_tool_{i}") |
| 3545 | + try: |
| 3546 | + logger.debug(f"Validating tool: {tool_name}") |
| 3547 | + validated_tool = ToolCreate.model_validate(tool_dict) |
| 3548 | + valid_tools.append(validated_tool) |
| 3549 | + logger.debug(f"Tool '{tool_name}' validated successfully") |
| 3550 | + except ValidationError as e: |
| 3551 | + error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}" |
| 3552 | + logger.error(error_msg) |
| 3553 | + logger.debug(f"Failed tool schema: {tool_dict}") |
| 3554 | + validation_errors.append(error_msg) |
| 3555 | + except ValueError as e: |
| 3556 | + if "JSON structure exceeds maximum depth" in str(e): |
| 3557 | + error_msg = ( |
| 3558 | + f"Tool '{tool_name}' schema too deeply nested. " |
| 3559 | + f"Current depth limit: {settings.validation_max_json_depth}" |
| 3560 | + ) |
| 3561 | + logger.error(error_msg) |
| 3562 | + logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable") |
| 3563 | + else: |
| 3564 | + error_msg = f"ValueError for tool '{tool_name}': {str(e)}" |
| 3565 | + logger.error(error_msg) |
| 3566 | + validation_errors.append(error_msg) |
| 3567 | + except Exception as e: # pragma: no cover - defensive |
| 3568 | + error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}" |
| 3569 | + logger.error(error_msg, exc_info=True) |
| 3570 | + validation_errors.append(error_msg) |
| 3571 | + |
| 3572 | + if validation_errors: |
| 3573 | + logger.warning( |
| 3574 | + f"Tool validation completed with {len(validation_errors)} error(s). " |
| 3575 | + f"Successfully validated {len(valid_tools)} tool(s)." |
| 3576 | + ) |
| 3577 | + for err in validation_errors[:3]: |
| 3578 | + logger.debug(f"Validation error: {err}") |
| 3579 | + |
| 3580 | + if not valid_tools and validation_errors: |
| 3581 | + if context == "oauth": |
| 3582 | + raise OAuthToolValidationError( |
| 3583 | + f"OAuth tool fetch failed: all {len(tools)} tools failed validation. " |
| 3584 | + f"First error: {validation_errors[0][:200]}" |
| 3585 | + ) |
| 3586 | + raise GatewayConnectionError( |
| 3587 | + f"Failed to fetch tools: All {len(tools)} tools failed validation. " |
| 3588 | + f"First error: {validation_errors[0][:200]}" |
| 3589 | + ) |
| 3590 | + |
| 3591 | + return valid_tools |
| 3592 | + |
3517 | 3593 | async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None): |
3518 | 3594 | """Connect to an MCP server running with SSE transport, skipping URL validation. |
3519 | 3595 |
|
@@ -3541,9 +3617,11 @@ async def _connect_to_sse_server_without_validation(self, server_url: str, authe |
3541 | 3617 |
|
3542 | 3618 | response = await session.list_tools() |
3543 | 3619 | tools = response.tools |
3544 | | - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] |
| 3620 | + tools = [ |
| 3621 | + tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools |
| 3622 | + ] |
3545 | 3623 |
|
3546 | | - tools = [ToolCreate.model_validate(tool) for tool in tools] |
| 3624 | + tools = self._validate_tools(tools, context="oauth") |
3547 | 3625 | if tools: |
3548 | 3626 | logger.info(f"Fetched {len(tools)} tools from gateway") |
3549 | 3627 | # Fetch resources if supported |
@@ -3684,9 +3762,11 @@ def get_httpx_client_factory( |
3684 | 3762 |
|
3685 | 3763 | response = await session.list_tools() |
3686 | 3764 | tools = response.tools |
3687 | | - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] |
| 3765 | + tools = [ |
| 3766 | + tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools |
| 3767 | + ] |
3688 | 3768 |
|
3689 | | - tools = [ToolCreate.model_validate(tool) for tool in tools] |
| 3769 | + tools = self._validate_tools(tools) |
3690 | 3770 | if tools: |
3691 | 3771 | logger.info(f"Fetched {len(tools)} tools from gateway") |
3692 | 3772 | # Fetch resources if supported |
@@ -3824,9 +3904,11 @@ def get_httpx_client_factory( |
3824 | 3904 |
|
3825 | 3905 | response = await session.list_tools() |
3826 | 3906 | tools = response.tools |
3827 | | - tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools] |
| 3907 | + tools = [ |
| 3908 | + tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools |
| 3909 | + ] |
3828 | 3910 |
|
3829 | | - tools = [ToolCreate.model_validate(tool) for tool in tools] |
| 3911 | + tools = self._validate_tools(tools) |
3830 | 3912 | for tool in tools: |
3831 | 3913 | tool.request_type = "STREAMABLEHTTP" |
3832 | 3914 | if tools: |
|
0 commit comments