diff --git a/docs/usage.rst b/docs/usage.rst index a6b3e55..c84cc47 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -19,3 +19,20 @@ Or using a context manager: from quads_lib import QuadsApi with QuadsApi(username, password, base_url) as quads: hosts = quads.get_hosts() + +TLS Certificate Verification +---------------------------- + +By default, TLS certificate verification is disabled (``verify=False``) for backward compatibility. +You can control certificate verification using the ``verify`` parameter: + +.. code-block:: python + + # Disable verification (default) + quads = QuadsApi(username, password, base_url, verify=False) + + # Enable verification with default CA bundle + quads = QuadsApi(username, password, base_url, verify=True) + + # Use a custom CA bundle file + quads = QuadsApi(username, password, base_url, verify="/path/to/ca-bundle.pem") diff --git a/src/quads_lib/base.py b/src/quads_lib/base.py index be245e2..0f6ae35 100644 --- a/src/quads_lib/base.py +++ b/src/quads_lib/base.py @@ -1,5 +1,6 @@ from json import JSONDecodeError from typing import Optional +from typing import Union from urllib.parse import urljoin from requests import Session @@ -16,10 +17,29 @@ class QuadsBase: Base class for the Quads API """ - def __init__(self, username: str, password: str, base_url: str): + def __init__( + self, + username: str, + password: str, + base_url: str, + verify: Union[bool, str] = False, + ): + """ + Initialize QuadsBase. + + Args: + username: Username for QUADS authentication + password: Password for QUADS authentication + base_url: Base URL for the QUADS API + verify: Controls TLS certificate verification. Can be: + - False: Disable certificate verification (default, for backward compatibility) + - True: Enable verification using default CA bundle + - str: Path to a custom CA bundle file + """ self.username = username self.password = password self.base_url = urljoin(base_url, "api/v3/") + self.verify = verify self.session = Session() retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504]) self.session.mount("http://", HTTPAdapter(max_retries=retries)) @@ -40,7 +60,7 @@ def _make_request(self, method: str, endpoint: str, data: Optional[dict] = None) method, urljoin(self.base_url, endpoint), json=data, - verify=False, + verify=self.verify, ) if _response.status_code == 500: raise APIServerException("Check the flask server logs") diff --git a/src/quads_lib/quads.py b/src/quads_lib/quads.py index 0d22892..591801d 100644 --- a/src/quads_lib/quads.py +++ b/src/quads_lib/quads.py @@ -20,7 +20,7 @@ def register(self) -> dict: def login(self) -> dict: endpoint = urljoin(self.base_url, "login") - _response = self.session.post(endpoint, auth=self.auth, verify=False) + _response = self.session.post(endpoint, auth=self.auth, verify=self.verify) json_response = _response.json() if json_response.get("status_code") == 201: self.token = json_response.get("auth_token") diff --git a/tests/test_quads.py b/tests/test_quads.py index c910a25..cea5671 100644 --- a/tests/test_quads.py +++ b/tests/test_quads.py @@ -1782,9 +1782,15 @@ def test_create_self_assignment_limit_reached(self, mock_request, mock_print): class TestQuadsBase: + @pytest.fixture(autouse=True) + def setup(self): + self.username = "test_user" + self.password = "test_pass" + self.base_url = "http://test.com" + @pytest.fixture def quads_base(self): - return QuadsBase(username="test_user", password="test_pass", base_url="http://test.com") + return QuadsBase(username=self.username, password=self.password, base_url=self.base_url) def test_context_manager_enter(self, quads_base): quads_base.login = Mock() @@ -1800,3 +1806,92 @@ def test_context_manager_exit(self, quads_base): quads_base.logout.assert_called_once() quads_base.session.close.assert_called_once() + + @patch("requests.Session.request") + def test_verify_default_false(self, mock_request): + """Test that verify defaults to False for backward compatibility.""" + api = QuadsApi(self.username, self.password, self.base_url) + expected_response = {"hosts": []} + mock_response = Mock() + mock_response.json.return_value = expected_response + mock_request.return_value = mock_response + + api.get_hosts() + + mock_request.assert_called_once() + # Check that verify=False is passed + assert mock_request.call_args[1]["verify"] is False + + def test_verify_default_false_login(self): + """Test that verify defaults to False in login method.""" + api = QuadsApi(self.username, self.password, self.base_url) + expected_response = {"status_code": 201, "auth_token": "token"} + mock_response = Mock() + mock_response.json.return_value = expected_response + api.session.post = Mock(return_value=mock_response) + + api.login() + + api.session.post.assert_called_once() + # Check that verify=False is passed + assert api.session.post.call_args[1]["verify"] is False + + @patch("requests.Session.request") + def test_verify_true(self, mock_request): + """Test that verify=True is passed correctly.""" + api = QuadsApi(self.username, self.password, self.base_url, verify=True) + expected_response = {"hosts": []} + mock_response = Mock() + mock_response.json.return_value = expected_response + mock_request.return_value = mock_response + + api.get_hosts() + + mock_request.assert_called_once() + # Check that verify=True is passed + assert mock_request.call_args[1]["verify"] is True + + def test_verify_true_login(self): + """Test that verify=True is passed correctly in login method.""" + api = QuadsApi(self.username, self.password, self.base_url, verify=True) + expected_response = {"status_code": 201, "auth_token": "token"} + mock_response = Mock() + mock_response.json.return_value = expected_response + api.session.post = Mock(return_value=mock_response) + + api.login() + + api.session.post.assert_called_once() + # Check that verify=True is passed + assert api.session.post.call_args[1]["verify"] is True + + @patch("requests.Session.request") + def test_verify_custom_ca_bundle(self, mock_request): + """Test that a custom CA bundle path is passed correctly.""" + custom_ca = "/path/to/ca-bundle.pem" + api = QuadsApi(self.username, self.password, self.base_url, verify=custom_ca) + expected_response = {"hosts": []} + mock_response = Mock() + mock_response.json.return_value = expected_response + mock_request.return_value = mock_response + + api.get_hosts() + + mock_request.assert_called_once() + # Check that custom CA bundle path is passed + assert mock_request.call_args[1]["verify"] == custom_ca + + def test_verify_custom_ca_bundle_login(self): + """Test that a custom CA bundle path is passed correctly in login method.""" + custom_ca = "/path/to/ca-bundle.pem" + api = QuadsApi(self.username, self.password, self.base_url, verify=custom_ca) + expected_response = {"status_code": 201, "auth_token": "token"} + mock_response = Mock() + mock_response.json.return_value = expected_response + api.session.post = Mock(return_value=mock_response) + + api.login() + + api.session.post.assert_called_once() + # Check that custom CA bundle path is passed + assert api.session.post.call_args[1]["verify"] == custom_ca