diff --git a/README.md b/README.md index b6f91c7b..5939737c 100644 --- a/README.md +++ b/README.md @@ -231,11 +231,15 @@ To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode Postgres MCP Pro includes DNS rebinding protection to secure the server against certain types of attacks. By default, the server allows connections from common local and Docker hostnames. -You can customize this behavior using environment variables: +Transport security applies only to network transports (`sse` and `streamable-http`), not `stdio`. -- **`MCP_ENABLE_DNS_REBINDING_PROTECTION`**: Controls whether DNS rebinding protection is enabled. Set to `false` to disable. Default: `true`. -- **`MCP_ALLOWED_HOSTS`**: Comma-separated list of allowed host patterns. Default: `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*`. -- **`MCP_ALLOWED_ORIGINS`**: Comma-separated list of allowed origins. Default: empty (allows any origin). +You can customize this behavior using CLI flags or environment variables (env vars take precedence over CLI flags): + +| CLI Flag | Environment Variable | Description | Default | +|---|---|---|---| +| `--disable-dns-rebinding-protection` | `MCP_ENABLE_DNS_REBINDING_PROTECTION` | Enable/disable DNS rebinding protection | Enabled | +| `--allowed-hosts` | `MCP_ALLOWED_HOSTS` | Comma-separated allowed host patterns | `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*` | +| `--allowed-origins` | `MCP_ALLOWED_ORIGINS` | Comma-separated allowed origins | Empty (allows any origin) | For example, to restrict allowed hosts in your configuration: diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index cc2b9a19..ad407804 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -35,29 +35,8 @@ from .sql import obfuscate_password from .top_queries import TopQueriesCalc -# Initialize FastMCP with transport security settings -# Configure to allow requests from localhost, 127.0.0.1, and Docker container names -# These can be customized via environment variables: -# - MCP_ENABLE_DNS_REBINDING_PROTECTION: Set to "false" to disable (default: "true") -# - MCP_ALLOWED_HOSTS: Comma-separated list of allowed hosts (default: localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*) -# - MCP_ALLOWED_ORIGINS: Comma-separated list of allowed origins (default: empty, allows any origin) - -dns_rebinding_protection = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION", "true").lower() == "true" - -default_allowed_hosts = "localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*" -allowed_hosts_str = os.environ.get("MCP_ALLOWED_HOSTS", default_allowed_hosts) -allowed_hosts = [host.strip() for host in allowed_hosts_str.split(",") if host.strip()] - -allowed_origins_str = os.environ.get("MCP_ALLOWED_ORIGINS", "") -allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",") if origin.strip()] - -transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=dns_rebinding_protection, - allowed_hosts=allowed_hosts, - allowed_origins=allowed_origins, -) - -mcp = FastMCP("postgres-mcp", transport_security=transport_security) +# Initialize FastMCP with default settings +mcp = FastMCP("postgres-mcp") # Constants PG_STAT_STATEMENTS = "pg_stat_statements" @@ -618,6 +597,24 @@ async def main(): default=8000, help="Port for streamable HTTP server (default: 8000)", ) + parser.add_argument( + "--disable-dns-rebinding-protection", + action="store_true", + default=False, + help="Disable DNS rebinding protection (not recommended for production)", + ) + parser.add_argument( + "--allowed-hosts", + type=str, + default=None, + help="Comma-separated allowed Host header values for DNS rebinding protection (e.g. 'localhost:*,127.0.0.1:*')", + ) + parser.add_argument( + "--allowed-origins", + type=str, + default=None, + help="Comma-separated allowed Origin header values for DNS rebinding protection (e.g. 'http://localhost:*')", + ) args = parser.parse_args() @@ -678,6 +675,20 @@ async def main(): logger.warning("Signal handling not supported on Windows") pass + # Apply transport security settings (SSE and streamable-http only) + if args.transport in ("sse", "streamable-http"): + dns_env = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION") + protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection + hosts = os.environ.get("MCP_ALLOWED_HOSTS", args.allowed_hosts) + origins = os.environ.get("MCP_ALLOWED_ORIGINS", args.allowed_origins) + + if protection_off or hosts or origins: + mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=not protection_off, + **{"allowed_hosts": [h.strip() for h in hosts.split(",") if h.strip()]} if hosts else {}, + **{"allowed_origins": [o.strip() for o in origins.split(",") if o.strip()]} if origins else {}, + ) + # Run the server with the selected transport (always async) if args.transport == "stdio": await mcp.run_stdio_async() diff --git a/tests/unit/test_transport_security.py b/tests/unit/test_transport_security.py new file mode 100644 index 00000000..a54ae666 --- /dev/null +++ b/tests/unit/test_transport_security.py @@ -0,0 +1,224 @@ +import sys +from unittest.mock import AsyncMock +from unittest.mock import patch + +import pytest + +_TRANSPORT_MOCK_MAP = { + "sse": "postgres_mcp.server.mcp.run_sse_async", + "streamable-http": "postgres_mcp.server.mcp.run_streamable_http_async", +} + +_MCP_ENV_KEYS = [ + "MCP_ENABLE_DNS_REBINDING_PROTECTION", + "MCP_ALLOWED_HOSTS", + "MCP_ALLOWED_ORIGINS", +] + + +@pytest.mark.parametrize("transport", ["sse", "streamable-http"]) +class TestTransportSecurityIntegration: + @pytest.fixture(autouse=True) + def _preserve_mcp_state(self, monkeypatch: pytest.MonkeyPatch): + from postgres_mcp.server import mcp + + original_argv = sys.argv + original_security = mcp.settings.transport_security + for key in _MCP_ENV_KEYS: + monkeypatch.delenv(key, raising=False) + yield + sys.argv = original_argv + mcp.settings.transport_security = original_security + + @pytest.mark.asyncio + async def test_disable_dns_rebinding_via_cli_flag(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--disable-dns-rebinding-protection", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + + @pytest.mark.asyncio + async def test_disable_dns_rebinding_via_env(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "false"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + + @pytest.mark.asyncio + async def test_allowed_hosts_via_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-hosts", + "localhost:*,127.0.0.1:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_allowed_hosts_env_overrides_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-hosts", + "cli-host:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ALLOWED_HOSTS": "env-host:*"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert "env-host:*" in mcp.settings.transport_security.allowed_hosts + assert "cli-host:*" not in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_allowed_origins_via_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-origins", + "http://localhost:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "http://localhost:*" in mcp.settings.transport_security.allowed_origins + + @pytest.mark.asyncio + async def test_allowed_origins_env_overrides_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-origins", + "http://cli-origin:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ALLOWED_ORIGINS": "http://env-origin:*"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert "http://env-origin:*" in mcp.settings.transport_security.allowed_origins + assert "http://cli-origin:*" not in mcp.settings.transport_security.allowed_origins + + @pytest.mark.asyncio + async def test_env_protection_true_overrides_cli_disable(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--disable-dns-rebinding-protection", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "true"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + + @pytest.mark.asyncio + async def test_default_defers_to_fastmcp(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + + @pytest.mark.asyncio + async def test_database_url_after_flags_not_consumed(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + f"--transport={transport}", + "--allowed-hosts", + "localhost:*,my-gateway:8080", + "--allowed-origins", + "http://localhost:*,http://my-gateway:*", + "postgresql://user:password@localhost/db", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "my-gateway:8080" in mcp.settings.transport_security.allowed_hosts