From ce9d985cd3970cfddea11b68a5a3f3a86718bbc6 Mon Sep 17 00:00:00 2001 From: Eli Shteinman <7198754@gmail.com> Date: Fri, 13 Feb 2026 10:46:42 +0200 Subject: [PATCH 1/3] feat: add CLI flags and tests for transport security configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add --disable-dns-rebinding-protection, --allowed-hosts, --allowed-origins CLI flags - Move transport security from module-level init to main() (after argparse) - Apply transport security only for SSE and streamable-http transports (not stdio) - Env vars (POSTGRES_MCP_*) override CLI flags when both are set - Add comprehensive test suite: 10 scenarios × 2 transports = 20 tests Co-Authored-By: Claude Opus 4.6 --- src/postgres_mcp/server.py | 57 ++++--- tests/unit/test_transport_security.py | 217 ++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 23 deletions(-) create mode 100644 tests/unit/test_transport_security.py diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index cc2b9a19..f644843b 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -35,29 +35,8 @@ from .sql import obfuscate_password from .top_queries import TopQueriesCalc -# Initialize FastMCP with transport security settings -# Configure to allow requests from localhost, 127.0.0.1, and Docker container names -# These can be customized via environment variables: -# - MCP_ENABLE_DNS_REBINDING_PROTECTION: Set to "false" to disable (default: "true") -# - MCP_ALLOWED_HOSTS: Comma-separated list of allowed hosts (default: localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*) -# - MCP_ALLOWED_ORIGINS: Comma-separated list of allowed origins (default: empty, allows any origin) - -dns_rebinding_protection = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION", "true").lower() == "true" - -default_allowed_hosts = "localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*" -allowed_hosts_str = os.environ.get("MCP_ALLOWED_HOSTS", default_allowed_hosts) -allowed_hosts = [host.strip() for host in allowed_hosts_str.split(",") if host.strip()] - -allowed_origins_str = os.environ.get("MCP_ALLOWED_ORIGINS", "") -allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",") if origin.strip()] - -transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=dns_rebinding_protection, - allowed_hosts=allowed_hosts, - allowed_origins=allowed_origins, -) - -mcp = FastMCP("postgres-mcp", transport_security=transport_security) +# Initialize FastMCP with default settings +mcp = FastMCP("postgres-mcp") # Constants PG_STAT_STATEMENTS = "pg_stat_statements" @@ -618,6 +597,24 @@ async def main(): default=8000, help="Port for streamable HTTP server (default: 8000)", ) + parser.add_argument( + "--disable-dns-rebinding-protection", + action="store_true", + default=False, + help="Disable DNS rebinding protection (not recommended for production)", + ) + parser.add_argument( + "--allowed-hosts", + type=str, + default=None, + help="Comma-separated allowed Host header values for DNS rebinding protection (e.g. 'localhost:*,127.0.0.1:*')", + ) + parser.add_argument( + "--allowed-origins", + type=str, + default=None, + help="Comma-separated allowed Origin header values for DNS rebinding protection (e.g. 'http://localhost:*')", + ) args = parser.parse_args() @@ -678,6 +675,20 @@ async def main(): logger.warning("Signal handling not supported on Windows") pass + # Apply transport security settings (SSE and streamable-http only) + if args.transport in ("sse", "streamable-http"): + dns_env = os.environ.get("POSTGRES_MCP_DNS_REBINDING_PROTECTION") + protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection + hosts = os.environ.get("POSTGRES_MCP_ALLOWED_HOSTS", args.allowed_hosts) + origins = os.environ.get("POSTGRES_MCP_ALLOWED_ORIGINS", args.allowed_origins) + + if protection_off or hosts or origins: + mcp.settings.transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=not protection_off, + **{"allowed_hosts": [h.strip() for h in hosts.split(",") if h.strip()]} if hosts else {}, + **{"allowed_origins": [o.strip() for o in origins.split(",") if o.strip()]} if origins else {}, + ) + # Run the server with the selected transport (always async) if args.transport == "stdio": await mcp.run_stdio_async() diff --git a/tests/unit/test_transport_security.py b/tests/unit/test_transport_security.py new file mode 100644 index 00000000..d763ad4e --- /dev/null +++ b/tests/unit/test_transport_security.py @@ -0,0 +1,217 @@ +import sys +from unittest.mock import AsyncMock +from unittest.mock import patch + +import pytest + +_TRANSPORT_MOCK_MAP = { + "sse": "postgres_mcp.server.mcp.run_sse_async", + "streamable-http": "postgres_mcp.server.mcp.run_streamable_http_async", +} + + +@pytest.mark.parametrize("transport", ["sse", "streamable-http"]) +class TestTransportSecurityIntegration: + @pytest.fixture(autouse=True) + def _preserve_mcp_state(self): + from postgres_mcp.server import mcp + + original_argv = sys.argv + original_security = mcp.settings.transport_security + yield + sys.argv = original_argv + mcp.settings.transport_security = original_security + + @pytest.mark.asyncio + async def test_disable_dns_rebinding_via_cli_flag(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--disable-dns-rebinding-protection", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + + @pytest.mark.asyncio + async def test_disable_dns_rebinding_via_env(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"POSTGRES_MCP_DNS_REBINDING_PROTECTION": "false"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + + @pytest.mark.asyncio + async def test_allowed_hosts_via_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-hosts", + "localhost:*,127.0.0.1:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_allowed_hosts_env_overrides_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-hosts", + "cli-host:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"POSTGRES_MCP_ALLOWED_HOSTS": "env-host:*"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert "env-host:*" in mcp.settings.transport_security.allowed_hosts + assert "cli-host:*" not in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_allowed_origins_via_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-origins", + "http://localhost:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "http://localhost:*" in mcp.settings.transport_security.allowed_origins + + @pytest.mark.asyncio + async def test_allowed_origins_env_overrides_cli(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--allowed-origins", + "http://cli-origin:*", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"POSTGRES_MCP_ALLOWED_ORIGINS": "http://env-origin:*"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert "http://env-origin:*" in mcp.settings.transport_security.allowed_origins + assert "http://cli-origin:*" not in mcp.settings.transport_security.allowed_origins + + @pytest.mark.asyncio + async def test_env_protection_true_overrides_cli_disable(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + "--disable-dns-rebinding-protection", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + patch.dict("os.environ", {"POSTGRES_MCP_DNS_REBINDING_PROTECTION": "true"}), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + + @pytest.mark.asyncio + async def test_default_defers_to_fastmcp(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + f"--transport={transport}", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + + @pytest.mark.asyncio + async def test_database_url_after_flags_not_consumed(self, transport: str): + from postgres_mcp.server import main + from postgres_mcp.server import mcp + + sys.argv = [ + "postgres_mcp", + f"--transport={transport}", + "--allowed-hosts", + "localhost:*,my-gateway:8080", + "--allowed-origins", + "http://localhost:*,http://my-gateway:*", + "postgresql://user:password@localhost/db", + ] + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), + ): + await main() + assert mcp.settings.transport_security is not None + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "my-gateway:8080" in mcp.settings.transport_security.allowed_hosts From 6c5357b29c797cb22b05c4d5ac81d169b1d29d07 Mon Sep 17 00:00:00 2001 From: Eli Shteinman <7198754@gmail.com> Date: Fri, 13 Feb 2026 10:50:46 +0200 Subject: [PATCH 2/3] fix: use MCP_* env var prefix instead of POSTGRES_MCP_* Align with the shorter MCP_* naming convention used in the original PR. Co-Authored-By: Claude Opus 4.6 --- src/postgres_mcp/server.py | 6 +++--- tests/unit/test_transport_security.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index f644843b..3c559d6d 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -677,10 +677,10 @@ async def main(): # Apply transport security settings (SSE and streamable-http only) if args.transport in ("sse", "streamable-http"): - dns_env = os.environ.get("POSTGRES_MCP_DNS_REBINDING_PROTECTION") + dns_env = os.environ.get("MCP_DNS_REBINDING_PROTECTION") protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection - hosts = os.environ.get("POSTGRES_MCP_ALLOWED_HOSTS", args.allowed_hosts) - origins = os.environ.get("POSTGRES_MCP_ALLOWED_ORIGINS", args.allowed_origins) + hosts = os.environ.get("MCP_ALLOWED_HOSTS", args.allowed_hosts) + origins = os.environ.get("MCP_ALLOWED_ORIGINS", args.allowed_origins) if protection_off or hosts or origins: mcp.settings.transport_security = TransportSecuritySettings( diff --git a/tests/unit/test_transport_security.py b/tests/unit/test_transport_security.py index d763ad4e..945a3ab5 100644 --- a/tests/unit/test_transport_security.py +++ b/tests/unit/test_transport_security.py @@ -56,7 +56,7 @@ async def test_disable_dns_rebinding_via_env(self, transport: str): with ( patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), - patch.dict("os.environ", {"POSTGRES_MCP_DNS_REBINDING_PROTECTION": "false"}), + patch.dict("os.environ", {"MCP_DNS_REBINDING_PROTECTION": "false"}), ): await main() assert mcp.settings.transport_security is not None @@ -100,7 +100,7 @@ async def test_allowed_hosts_env_overrides_cli(self, transport: str): with ( patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), - patch.dict("os.environ", {"POSTGRES_MCP_ALLOWED_HOSTS": "env-host:*"}), + patch.dict("os.environ", {"MCP_ALLOWED_HOSTS": "env-host:*"}), ): await main() assert mcp.settings.transport_security is not None @@ -144,7 +144,7 @@ async def test_allowed_origins_env_overrides_cli(self, transport: str): with ( patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), - patch.dict("os.environ", {"POSTGRES_MCP_ALLOWED_ORIGINS": "http://env-origin:*"}), + patch.dict("os.environ", {"MCP_ALLOWED_ORIGINS": "http://env-origin:*"}), ): await main() assert mcp.settings.transport_security is not None @@ -166,7 +166,7 @@ async def test_env_protection_true_overrides_cli_disable(self, transport: str): with ( patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), - patch.dict("os.environ", {"POSTGRES_MCP_DNS_REBINDING_PROTECTION": "true"}), + patch.dict("os.environ", {"MCP_DNS_REBINDING_PROTECTION": "true"}), ): await main() assert mcp.settings.transport_security is not None From dd887e55282fa5670974d958a191af2f12c26cac Mon Sep 17 00:00:00 2001 From: Eli Shteinman <7198754@gmail.com> Date: Fri, 13 Feb 2026 11:00:40 +0200 Subject: [PATCH 3/3] fix: align env var names with original PR and address review feedback - Rename MCP_DNS_REBINDING_PROTECTION to MCP_ENABLE_DNS_REBINDING_PROTECTION - Add monkeypatch fixture to clear MCP_* env vars in tests - Remove coupling to FastMCP upstream defaults in test_default_defers_to_fastmcp - Update README with CLI flags documentation table Co-Authored-By: Claude Opus 4.6 --- README.md | 12 ++++++++---- src/postgres_mcp/server.py | 2 +- tests/unit/test_transport_security.py | 15 +++++++++++---- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b6f91c7b..5939737c 100644 --- a/README.md +++ b/README.md @@ -231,11 +231,15 @@ To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode Postgres MCP Pro includes DNS rebinding protection to secure the server against certain types of attacks. By default, the server allows connections from common local and Docker hostnames. -You can customize this behavior using environment variables: +Transport security applies only to network transports (`sse` and `streamable-http`), not `stdio`. -- **`MCP_ENABLE_DNS_REBINDING_PROTECTION`**: Controls whether DNS rebinding protection is enabled. Set to `false` to disable. Default: `true`. -- **`MCP_ALLOWED_HOSTS`**: Comma-separated list of allowed host patterns. Default: `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*`. -- **`MCP_ALLOWED_ORIGINS`**: Comma-separated list of allowed origins. Default: empty (allows any origin). +You can customize this behavior using CLI flags or environment variables (env vars take precedence over CLI flags): + +| CLI Flag | Environment Variable | Description | Default | +|---|---|---|---| +| `--disable-dns-rebinding-protection` | `MCP_ENABLE_DNS_REBINDING_PROTECTION` | Enable/disable DNS rebinding protection | Enabled | +| `--allowed-hosts` | `MCP_ALLOWED_HOSTS` | Comma-separated allowed host patterns | `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*` | +| `--allowed-origins` | `MCP_ALLOWED_ORIGINS` | Comma-separated allowed origins | Empty (allows any origin) | For example, to restrict allowed hosts in your configuration: diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 3c559d6d..ad407804 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -677,7 +677,7 @@ async def main(): # Apply transport security settings (SSE and streamable-http only) if args.transport in ("sse", "streamable-http"): - dns_env = os.environ.get("MCP_DNS_REBINDING_PROTECTION") + dns_env = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION") protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection hosts = os.environ.get("MCP_ALLOWED_HOSTS", args.allowed_hosts) origins = os.environ.get("MCP_ALLOWED_ORIGINS", args.allowed_origins) diff --git a/tests/unit/test_transport_security.py b/tests/unit/test_transport_security.py index 945a3ab5..a54ae666 100644 --- a/tests/unit/test_transport_security.py +++ b/tests/unit/test_transport_security.py @@ -9,15 +9,23 @@ "streamable-http": "postgres_mcp.server.mcp.run_streamable_http_async", } +_MCP_ENV_KEYS = [ + "MCP_ENABLE_DNS_REBINDING_PROTECTION", + "MCP_ALLOWED_HOSTS", + "MCP_ALLOWED_ORIGINS", +] + @pytest.mark.parametrize("transport", ["sse", "streamable-http"]) class TestTransportSecurityIntegration: @pytest.fixture(autouse=True) - def _preserve_mcp_state(self): + def _preserve_mcp_state(self, monkeypatch: pytest.MonkeyPatch): from postgres_mcp.server import mcp original_argv = sys.argv original_security = mcp.settings.transport_security + for key in _MCP_ENV_KEYS: + monkeypatch.delenv(key, raising=False) yield sys.argv = original_argv mcp.settings.transport_security = original_security @@ -56,7 +64,7 @@ async def test_disable_dns_rebinding_via_env(self, transport: str): with ( patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), - patch.dict("os.environ", {"MCP_DNS_REBINDING_PROTECTION": "false"}), + patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "false"}), ): await main() assert mcp.settings.transport_security is not None @@ -166,7 +174,7 @@ async def test_env_protection_true_overrides_cli_disable(self, transport: str): with ( patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()), - patch.dict("os.environ", {"MCP_DNS_REBINDING_PROTECTION": "true"}), + patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "true"}), ): await main() assert mcp.settings.transport_security is not None @@ -190,7 +198,6 @@ async def test_default_defers_to_fastmcp(self, transport: str): await main() assert mcp.settings.transport_security is not None assert mcp.settings.transport_security.enable_dns_rebinding_protection is True - assert "localhost:*" in mcp.settings.transport_security.allowed_hosts @pytest.mark.asyncio async def test_database_url_after_flags_not_consumed(self, transport: str):