diff --git a/tests/test_aws_opensearch_cli.py b/tests/test_aws_opensearch_cli.py new file mode 100644 index 000000000..6f11622cb --- /dev/null +++ b/tests/test_aws_opensearch_cli.py @@ -0,0 +1,61 @@ +import pytest +from pydantic import SecretStr + +from vectordb_bench.backend.clients.aws_opensearch.cli import optional_secret_str + + +class TestOptionalSecretStr: + """Test cases for the optional_secret_str helper function.""" + + def test_none_input_returns_none(self): + """Test that None input returns None.""" + result = optional_secret_str(None) + assert result is None + + def test_string_input_returns_secret_str(self): + """Test that string input returns SecretStr.""" + test_password = "my_secret_password" + result = optional_secret_str(test_password) + + assert isinstance(result, SecretStr) + assert result.get_secret_value() == test_password + + def test_empty_string_returns_secret_str(self): + """Test that empty string returns SecretStr with empty value.""" + result = optional_secret_str("") + + assert isinstance(result, SecretStr) + assert result.get_secret_value() == "" + + @pytest.mark.parametrize("test_input,expected_value", [ + ("password123", "password123"), + ("", ""), + ("special!@#$%^&*()chars", "special!@#$%^&*()chars"), + (" spaces ", " spaces "), + ("unicode_ñáéíóú", "unicode_ñáéíóú"), + ]) + def test_various_string_inputs(self, test_input, expected_value): + """Test various string inputs return SecretStr with correct values.""" + result = optional_secret_str(test_input) + + assert isinstance(result, SecretStr) + assert result.get_secret_value() == expected_value + + def test_none_vs_empty_string_difference(self): + """Test that None and empty string are handled differently.""" + none_result = optional_secret_str(None) + empty_result = optional_secret_str("") + + assert none_result is None + assert isinstance(empty_result, SecretStr) + assert empty_result.get_secret_value() == "" + + def test_return_type_annotations(self): + """Test that return types match the function signature.""" + # Test None case + none_result = optional_secret_str(None) + assert none_result is None + + # Test string case + string_result = optional_secret_str("test") + assert isinstance(string_result, SecretStr) \ No newline at end of file diff --git a/vectordb_bench/backend/clients/aws_opensearch/cli.py b/vectordb_bench/backend/clients/aws_opensearch/cli.py index a3ddb8712..5bc80a687 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/cli.py +++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py @@ -17,11 +17,16 @@ log = logging.getLogger(__name__) +def optional_secret_str(value: str | None) -> SecretStr | None: + """Convert string to SecretStr, handling None gracefully.""" + return None if value is None else SecretStr(value) + + class AWSOpenSearchTypedDict(TypedDict): host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] port: Annotated[int, click.option("--port", type=int, default=80, help="Db Port")] - user: Annotated[str, click.option("--user", type=str, help="Db User")] - password: Annotated[str, click.option("--password", type=str, help="Db password")] + user: Annotated[str | None, click.option("--user", type=str, help="Db User")] + password: Annotated[str | None, click.option("--password", type=str, help="Db password")] number_of_shards: Annotated[ int, click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1), @@ -188,7 +193,7 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): host=parameters["host"], port=parameters["port"], user=parameters["user"], - password=SecretStr(parameters["password"]), + password=optional_secret_str(parameters["password"]), ), db_case_config=AWSOpenSearchIndexConfig( number_of_shards=parameters["number_of_shards"], diff --git a/vectordb_bench/backend/clients/aws_opensearch/config.py b/vectordb_bench/backend/clients/aws_opensearch/config.py index ff87f66bf..5ab63010d 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/config.py +++ b/vectordb_bench/backend/clients/aws_opensearch/config.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from pydantic import BaseModel, SecretStr +from pydantic import BaseModel, SecretStr, validator from ..api import DBCaseConfig, DBConfig, MetricType @@ -11,13 +11,15 @@ class AWSOpenSearchConfig(DBConfig, BaseModel): host: str = "" port: int = 80 - user: str = "" - password: SecretStr = "" + user: str | None = None + password: SecretStr | None = None def to_dict(self) -> dict: use_ssl = self.port == 443 http_auth = ( - (self.user, self.password.get_secret_value()) if len(self.user) != 0 and len(self.password) != 0 else () + (self.user, self.password.get_secret_value()) + if self.user is not None and self.password is not None and len(self.user) != 0 and len(self.password) != 0 + else () ) return { "hosts": [{"host": self.host, "port": self.port}], @@ -30,6 +32,18 @@ def to_dict(self) -> dict: "timeout": 600, } + @validator("*") + def not_empty_field(cls, v: any, field: any): + if ( + field.name in cls.common_short_configs() + or field.name in cls.common_long_configs() + or field.name in ["user", "password", "host"] + ): + return v + if isinstance(v, str | SecretStr) and len(v) == 0: + raise ValueError("Empty string!") + return v + class AWSOS_Engine(Enum): faiss = "faiss"