Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mcpgateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,17 @@ def validate_database(self) -> None:
validation_max_description_length: int = 8192 # 8KB
validation_max_template_length: int = 65536 # 64KB
validation_max_content_length: int = 1048576 # 1MB
validation_max_json_depth: int = 10
validation_max_json_depth: int = Field(
default=int(os.getenv("VALIDATION_MAX_JSON_DEPTH", "30")),
description=(
"Maximum allowed JSON nesting depth for tool/resource schemas. "
"Increased from 10 to 30 for compatibility with deeply nested schemas "
"like Notion MCP (issue #1542). Override with VALIDATION_MAX_JSON_DEPTH "
"environment variable. Minimum: 1, Maximum: 100"
),
ge=1,
le=100,
)
validation_max_url_length: int = 2048
validation_max_rpc_param_size: int = 262144 # 256KB

Expand Down
94 changes: 88 additions & 6 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from pydantic import ValidationError
from sqlalchemy import and_, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -272,6 +273,10 @@ class GatewayConnectionError(GatewayError):
"""


class OAuthToolValidationError(GatewayConnectionError):
"""Raised when tool validation fails during OAuth-driven fetch."""


class GatewayService: # pylint: disable=too-many-instance-attributes
"""Service for managing federated gateways.

Expand Down Expand Up @@ -1110,6 +1115,10 @@ async def fetch_tools_after_oauth(self, db: Session, gateway_id: str, app_user_e

return {"capabilities": capabilities, "tools": tools, "resources": resources, "prompts": prompts}

except GatewayConnectionError as gce:
# Surface validation or depth-related failures directly to the user
logger.error(f"GatewayConnectionError during OAuth fetch for {gateway_id}: {gce}")
raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(gce)}")
except Exception as e:
logger.error(f"Failed to fetch tools after OAuth for gateway {gateway_id}: {e}")
raise GatewayConnectionError(f"Failed to fetch tools after OAuth: {str(e)}")
Expand Down Expand Up @@ -3514,6 +3523,73 @@ async def _publish_event(self, event: Dict[str, Any]) -> None:
"""
await self._event_service.publish_event(event)

def _validate_tools(self, tools: list[dict[str, Any]], context: str = "default") -> list[ToolCreate]:
"""Validate tools individually with richer logging and error aggregation.

Args:
tools: list of tool dicts
context: caller context, e.g. "oauth" to tailor errors/messages

Returns:
list[ToolCreate]: List of successfully validated tools

Raises:
OAuthToolValidationError: If all tools fail validation in OAuth context
GatewayConnectionError: If all tools fail validation in default context
"""
valid_tools: list[ToolCreate] = []
validation_errors: list[str] = []

for i, tool_dict in enumerate(tools):
tool_name = tool_dict.get("name", f"unknown_tool_{i}")
try:
logger.debug(f"Validating tool: {tool_name}")
validated_tool = ToolCreate.model_validate(tool_dict)
valid_tools.append(validated_tool)
logger.debug(f"Tool '{tool_name}' validated successfully")
except ValidationError as e:
error_msg = f"Validation failed for tool '{tool_name}': {e.errors()}"
logger.error(error_msg)
logger.debug(f"Failed tool schema: {tool_dict}")
validation_errors.append(error_msg)
except ValueError as e:
if "JSON structure exceeds maximum depth" in str(e):
error_msg = (
f"Tool '{tool_name}' schema too deeply nested. "
f"Current depth limit: {settings.validation_max_json_depth}"
)
logger.error(error_msg)
logger.warning("Consider increasing VALIDATION_MAX_JSON_DEPTH environment variable")
else:
error_msg = f"ValueError for tool '{tool_name}': {str(e)}"
logger.error(error_msg)
validation_errors.append(error_msg)
except Exception as e: # pragma: no cover - defensive
error_msg = f"Unexpected error validating tool '{tool_name}': {type(e).__name__}: {str(e)}"
logger.error(error_msg, exc_info=True)
validation_errors.append(error_msg)

if validation_errors:
logger.warning(
f"Tool validation completed with {len(validation_errors)} error(s). "
f"Successfully validated {len(valid_tools)} tool(s)."
)
for err in validation_errors[:3]:
logger.debug(f"Validation error: {err}")

if not valid_tools and validation_errors:
if context == "oauth":
raise OAuthToolValidationError(
f"OAuth tool fetch failed: all {len(tools)} tools failed validation. "
f"First error: {validation_errors[0][:200]}"
)
raise GatewayConnectionError(
f"Failed to fetch tools: All {len(tools)} tools failed validation. "
f"First error: {validation_errors[0][:200]}"
)

return valid_tools

async def _connect_to_sse_server_without_validation(self, server_url: str, authentication: Optional[Dict[str, str]] = None):
"""Connect to an MCP server running with SSE transport, skipping URL validation.

Expand Down Expand Up @@ -3541,9 +3617,11 @@ async def _connect_to_sse_server_without_validation(self, server_url: str, authe

response = await session.list_tools()
tools = response.tools
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
tools = [
tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools
]

tools = [ToolCreate.model_validate(tool) for tool in tools]
tools = self._validate_tools(tools, context="oauth")
if tools:
logger.info(f"Fetched {len(tools)} tools from gateway")
# Fetch resources if supported
Expand Down Expand Up @@ -3684,9 +3762,11 @@ def get_httpx_client_factory(

response = await session.list_tools()
tools = response.tools
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
tools = [
tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools
]

tools = [ToolCreate.model_validate(tool) for tool in tools]
tools = self._validate_tools(tools)
if tools:
logger.info(f"Fetched {len(tools)} tools from gateway")
# Fetch resources if supported
Expand Down Expand Up @@ -3824,9 +3904,11 @@ def get_httpx_client_factory(

response = await session.list_tools()
tools = response.tools
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
tools = [
tool.model_dump(by_alias=True, exclude_none=True, exclude_unset=True) for tool in tools
]

tools = [ToolCreate.model_validate(tool) for tool in tools]
tools = self._validate_tools(tools)
for tool in tools:
tool.request_type = "STREAMABLEHTTP"
if tools:
Expand Down
Loading