Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/test_aws_opensearch_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
from pydantic import SecretStr

from vectordb_bench.backend.clients.aws_opensearch.cli import optional_secret_str


class TestOptionalSecretStr:
"""Test cases for the optional_secret_str helper function."""

def test_none_input_returns_none(self):
"""Test that None input returns None."""
result = optional_secret_str(None)
assert result is None

def test_string_input_returns_secret_str(self):
"""Test that string input returns SecretStr."""
test_password = "my_secret_password"
result = optional_secret_str(test_password)

assert isinstance(result, SecretStr)
assert result.get_secret_value() == test_password

def test_empty_string_returns_secret_str(self):
"""Test that empty string returns SecretStr with empty value."""
result = optional_secret_str("")

assert isinstance(result, SecretStr)
assert result.get_secret_value() == ""

@pytest.mark.parametrize("test_input,expected_value", [
("password123", "password123"),
("", ""),
("special!@#$%^&*()chars", "special!@#$%^&*()chars"),
(" spaces ", " spaces "),
("unicode_ñáéíóú", "unicode_ñáéíóú"),
])
def test_various_string_inputs(self, test_input, expected_value):
"""Test various string inputs return SecretStr with correct values."""
result = optional_secret_str(test_input)

assert isinstance(result, SecretStr)
assert result.get_secret_value() == expected_value

def test_none_vs_empty_string_difference(self):
"""Test that None and empty string are handled differently."""
none_result = optional_secret_str(None)
empty_result = optional_secret_str("")

assert none_result is None
assert isinstance(empty_result, SecretStr)
assert empty_result.get_secret_value() == ""

def test_return_type_annotations(self):
"""Test that return types match the function signature."""
# Test None case
none_result = optional_secret_str(None)
assert none_result is None

# Test string case
string_result = optional_secret_str("test")
assert isinstance(string_result, SecretStr)
11 changes: 8 additions & 3 deletions vectordb_bench/backend/clients/aws_opensearch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
log = logging.getLogger(__name__)


def optional_secret_str(value: str | None) -> SecretStr | None:
"""Convert string to SecretStr, handling None gracefully."""
return None if value is None else SecretStr(value)


class AWSOpenSearchTypedDict(TypedDict):
host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
port: Annotated[int, click.option("--port", type=int, default=80, help="Db Port")]
user: Annotated[str, click.option("--user", type=str, help="Db User")]
password: Annotated[str, click.option("--password", type=str, help="Db password")]
user: Annotated[str | None, click.option("--user", type=str, help="Db User")]
password: Annotated[str | None, click.option("--password", type=str, help="Db password")]
number_of_shards: Annotated[
int,
click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1),
Expand Down Expand Up @@ -188,7 +193,7 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
host=parameters["host"],
port=parameters["port"],
user=parameters["user"],
password=SecretStr(parameters["password"]),
password=optional_secret_str(parameters["password"]),
),
db_case_config=AWSOpenSearchIndexConfig(
number_of_shards=parameters["number_of_shards"],
Expand Down
22 changes: 18 additions & 4 deletions vectordb_bench/backend/clients/aws_opensearch/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from enum import Enum

from pydantic import BaseModel, SecretStr
from pydantic import BaseModel, SecretStr, validator

from ..api import DBCaseConfig, DBConfig, MetricType

Expand All @@ -11,13 +11,15 @@
class AWSOpenSearchConfig(DBConfig, BaseModel):
host: str = ""
port: int = 80
user: str = ""
password: SecretStr = ""
user: str | None = None
password: SecretStr | None = None

def to_dict(self) -> dict:
use_ssl = self.port == 443
http_auth = (
(self.user, self.password.get_secret_value()) if len(self.user) != 0 and len(self.password) != 0 else ()
(self.user, self.password.get_secret_value())
if self.user is not None and self.password is not None and len(self.user) != 0 and len(self.password) != 0
else ()
)
return {
"hosts": [{"host": self.host, "port": self.port}],
Expand All @@ -30,6 +32,18 @@ def to_dict(self) -> dict:
"timeout": 600,
}

@validator("*")
def not_empty_field(cls, v: any, field: any):
if (
field.name in cls.common_short_configs()
or field.name in cls.common_long_configs()
or field.name in ["user", "password", "host"]
):
return v
if isinstance(v, str | SecretStr) and len(v) == 0:
raise ValueError("Empty string!")
return v


class AWSOS_Engine(Enum):
faiss = "faiss"
Expand Down
Loading