Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,15 @@ To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode

Postgres MCP Pro includes DNS rebinding protection to secure the server against certain types of attacks.
By default, the server allows connections from common local and Docker hostnames.
You can customize this behavior using environment variables:
Transport security applies only to network transports (`sse` and `streamable-http`), not `stdio`.

- **`MCP_ENABLE_DNS_REBINDING_PROTECTION`**: Controls whether DNS rebinding protection is enabled. Set to `false` to disable. Default: `true`.
- **`MCP_ALLOWED_HOSTS`**: Comma-separated list of allowed host patterns. Default: `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*`.
- **`MCP_ALLOWED_ORIGINS`**: Comma-separated list of allowed origins. Default: empty (allows any origin).
You can customize this behavior using CLI flags or environment variables (env vars take precedence over CLI flags):

| CLI Flag | Environment Variable | Description | Default |
|---|---|---|---|
| `--disable-dns-rebinding-protection` | `MCP_ENABLE_DNS_REBINDING_PROTECTION` | Enable/disable DNS rebinding protection | Enabled |
| `--allowed-hosts` | `MCP_ALLOWED_HOSTS` | Comma-separated allowed host patterns | `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*` |
| `--allowed-origins` | `MCP_ALLOWED_ORIGINS` | Comma-separated allowed origins | Empty (allows any origin) |

For example, to restrict allowed hosts in your configuration:

Expand Down
57 changes: 34 additions & 23 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,8 @@
from .sql import obfuscate_password
from .top_queries import TopQueriesCalc

# Initialize FastMCP with transport security settings
# Configure to allow requests from localhost, 127.0.0.1, and Docker container names
# These can be customized via environment variables:
# - MCP_ENABLE_DNS_REBINDING_PROTECTION: Set to "false" to disable (default: "true")
# - MCP_ALLOWED_HOSTS: Comma-separated list of allowed hosts (default: localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*)
# - MCP_ALLOWED_ORIGINS: Comma-separated list of allowed origins (default: empty, allows any origin)

dns_rebinding_protection = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION", "true").lower() == "true"

default_allowed_hosts = "localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*"
allowed_hosts_str = os.environ.get("MCP_ALLOWED_HOSTS", default_allowed_hosts)
allowed_hosts = [host.strip() for host in allowed_hosts_str.split(",") if host.strip()]

allowed_origins_str = os.environ.get("MCP_ALLOWED_ORIGINS", "")
allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",") if origin.strip()]

transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=dns_rebinding_protection,
allowed_hosts=allowed_hosts,
allowed_origins=allowed_origins,
)

mcp = FastMCP("postgres-mcp", transport_security=transport_security)
# Initialize FastMCP with default settings
mcp = FastMCP("postgres-mcp")

# Constants
PG_STAT_STATEMENTS = "pg_stat_statements"
Expand Down Expand Up @@ -618,6 +597,24 @@ async def main():
default=8000,
help="Port for streamable HTTP server (default: 8000)",
)
parser.add_argument(
"--disable-dns-rebinding-protection",
action="store_true",
default=False,
help="Disable DNS rebinding protection (not recommended for production)",
)
parser.add_argument(
"--allowed-hosts",
type=str,
default=None,
help="Comma-separated allowed Host header values for DNS rebinding protection (e.g. 'localhost:*,127.0.0.1:*')",
)
parser.add_argument(
"--allowed-origins",
type=str,
default=None,
help="Comma-separated allowed Origin header values for DNS rebinding protection (e.g. 'http://localhost:*')",
)

args = parser.parse_args()

Expand Down Expand Up @@ -678,6 +675,20 @@ async def main():
logger.warning("Signal handling not supported on Windows")
pass

# Apply transport security settings (SSE and streamable-http only)
if args.transport in ("sse", "streamable-http"):
dns_env = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION")
protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection
hosts = os.environ.get("MCP_ALLOWED_HOSTS", args.allowed_hosts)
origins = os.environ.get("MCP_ALLOWED_ORIGINS", args.allowed_origins)

if protection_off or hosts or origins:
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=not protection_off,
**{"allowed_hosts": [h.strip() for h in hosts.split(",") if h.strip()]} if hosts else {},
**{"allowed_origins": [o.strip() for o in origins.split(",") if o.strip()]} if origins else {},
)

# Run the server with the selected transport (always async)
if args.transport == "stdio":
await mcp.run_stdio_async()
Expand Down
224 changes: 224 additions & 0 deletions tests/unit/test_transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import sys
from unittest.mock import AsyncMock
from unittest.mock import patch

import pytest

_TRANSPORT_MOCK_MAP = {
"sse": "postgres_mcp.server.mcp.run_sse_async",
"streamable-http": "postgres_mcp.server.mcp.run_streamable_http_async",
}

_MCP_ENV_KEYS = [
"MCP_ENABLE_DNS_REBINDING_PROTECTION",
"MCP_ALLOWED_HOSTS",
"MCP_ALLOWED_ORIGINS",
]


@pytest.mark.parametrize("transport", ["sse", "streamable-http"])
class TestTransportSecurityIntegration:
@pytest.fixture(autouse=True)
def _preserve_mcp_state(self, monkeypatch: pytest.MonkeyPatch):
from postgres_mcp.server import mcp

original_argv = sys.argv
original_security = mcp.settings.transport_security
for key in _MCP_ENV_KEYS:
monkeypatch.delenv(key, raising=False)
yield
sys.argv = original_argv
mcp.settings.transport_security = original_security

@pytest.mark.asyncio
async def test_disable_dns_rebinding_via_cli_flag(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--disable-dns-rebinding-protection",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is False

@pytest.mark.asyncio
async def test_disable_dns_rebinding_via_env(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "false"}),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is False

@pytest.mark.asyncio
async def test_allowed_hosts_via_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-hosts",
"localhost:*,127.0.0.1:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert "localhost:*" in mcp.settings.transport_security.allowed_hosts
assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts

@pytest.mark.asyncio
async def test_allowed_hosts_env_overrides_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-hosts",
"cli-host:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ALLOWED_HOSTS": "env-host:*"}),
):
await main()
assert mcp.settings.transport_security is not None
assert "env-host:*" in mcp.settings.transport_security.allowed_hosts
assert "cli-host:*" not in mcp.settings.transport_security.allowed_hosts

@pytest.mark.asyncio
async def test_allowed_origins_via_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-origins",
"http://localhost:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert "http://localhost:*" in mcp.settings.transport_security.allowed_origins

@pytest.mark.asyncio
async def test_allowed_origins_env_overrides_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-origins",
"http://cli-origin:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ALLOWED_ORIGINS": "http://env-origin:*"}),
):
await main()
assert mcp.settings.transport_security is not None
assert "http://env-origin:*" in mcp.settings.transport_security.allowed_origins
assert "http://cli-origin:*" not in mcp.settings.transport_security.allowed_origins

@pytest.mark.asyncio
async def test_env_protection_true_overrides_cli_disable(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--disable-dns-rebinding-protection",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "true"}),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is True

@pytest.mark.asyncio
async def test_default_defers_to_fastmcp(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is True

@pytest.mark.asyncio
async def test_database_url_after_flags_not_consumed(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
f"--transport={transport}",
"--allowed-hosts",
"localhost:*,my-gateway:8080",
"--allowed-origins",
"http://localhost:*,http://my-gateway:*",
"postgresql://user:password@localhost/db",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert "localhost:*" in mcp.settings.transport_security.allowed_hosts
assert "my-gateway:8080" in mcp.settings.transport_security.allowed_hosts