diff --git a/kinde_fastapi/framework/fastapi_framework.py b/kinde_fastapi/framework/fastapi_framework.py index 8dfe0668..04fb09d0 100644 --- a/kinde_fastapi/framework/fastapi_framework.py +++ b/kinde_fastapi/framework/fastapi_framework.py @@ -55,7 +55,6 @@ def start(self) -> None: """ Start the framework. This method initializes any necessary FastAPI components and registers Kinde routes. - This method initializes any necessary FastAPI components and registers Kinde routes. """ if not self._initialized: # Add framework middleware @@ -131,7 +130,15 @@ async def get_current_user(): @self.app.get("/login") async def login(request: Request): """Redirect to Kinde login page.""" - url=await self._oauth.login() + # Build login options from query parameters + login_options = {} + + # Check for invitation_code in query parameters + invitation_code = request.query_params.get('invitation_code') + if invitation_code: + login_options['invitation_code'] = invitation_code + + url = await self._oauth.login(login_options) self._logger.warning(f"[Login] Session is: {request.session}") return RedirectResponse(url=url) diff --git a/kinde_flask/framework/flask_framework.py b/kinde_flask/framework/flask_framework.py index 2b2b9246..a0ce4cdb 100644 --- a/kinde_flask/framework/flask_framework.py +++ b/kinde_flask/framework/flask_framework.py @@ -127,7 +127,16 @@ def _register_kinde_routes(self) -> None: def login(): """Redirect to Kinde login page.""" loop = asyncio.get_event_loop() - login_url = loop.run_until_complete(self._oauth.login()) + + # Build login options from query parameters + login_options = {} + + # Check for invitation_code in query parameters + invitation_code = request.args.get('invitation_code') + if invitation_code: + login_options['invitation_code'] = invitation_code + + login_url = loop.run_until_complete(self._oauth.login(login_options)) return redirect(login_url) # Callback route diff --git a/kinde_sdk/auth/login_options.py b/kinde_sdk/auth/login_options.py index 7ca43da4..3c87ec06 100644 --- a/kinde_sdk/auth/login_options.py +++ b/kinde_sdk/auth/login_options.py @@ -27,6 +27,10 @@ class LoginOptions: PLAN_INTEREST = "plan_interest" PRICING_TABLE_KEY = "pricing_table_key" + # Invitation parameters + INVITATION_CODE = "invitation_code" + IS_INVITATION = "is_invitation" + # Additional parameters container AUTH_PARAMS = "auth_params" SUPPORT_RE_AUTH = "supports_reauth" \ No newline at end of file diff --git a/kinde_sdk/auth/oauth.py b/kinde_sdk/auth/oauth.py index 1c22eb99..174a438b 100644 --- a/kinde_sdk/auth/oauth.py +++ b/kinde_sdk/auth/oauth.py @@ -280,6 +280,9 @@ async def generate_auth_url( # Registration params LoginOptions.PLAN_INTEREST: "plan_interest", LoginOptions.PRICING_TABLE_KEY: "pricing_table_key", + # Invitation params + LoginOptions.INVITATION_CODE: "invitation_code", + LoginOptions.IS_INVITATION: "is_invitation", # Re-authentication support LoginOptions.SUPPORT_RE_AUTH: "supports_reauth", } @@ -307,10 +310,19 @@ async def generate_auth_url( # Handle boolean parameters if option_key == LoginOptions.IS_CREATE_ORG or option_key == LoginOptions.HAS_SUCCESS_PAGE: search_params[param_key] = "true" if login_options[option_key] else "false" + elif option_key == LoginOptions.IS_INVITATION: + # Only add is_invitation if it's truthy + if login_options[option_key]: + search_params[param_key] = "true" else: # Use string representation for query params search_params[param_key] = str(login_options[option_key]) + # Handle invitation code: automatically set is_invitation to "true" when invitation_code is present + if LoginOptions.INVITATION_CODE in login_options and login_options[LoginOptions.INVITATION_CODE]: + if LoginOptions.IS_INVITATION not in login_options or not login_options[LoginOptions.IS_INVITATION]: + search_params["is_invitation"] = "true" + # Add additional auth parameters if LoginOptions.AUTH_PARAMS in login_options and isinstance(login_options[LoginOptions.AUTH_PARAMS], dict): for key, value in login_options[LoginOptions.AUTH_PARAMS].items(): diff --git a/testv2/testv2_auth/test_invitation_code.py b/testv2/testv2_auth/test_invitation_code.py new file mode 100644 index 00000000..86215a80 --- /dev/null +++ b/testv2/testv2_auth/test_invitation_code.py @@ -0,0 +1,204 @@ +import unittest +import asyncio +from unittest.mock import patch, MagicMock +from urllib.parse import urlparse, parse_qs + +from kinde_sdk.auth.oauth import OAuth +from kinde_sdk.auth.enums import IssuerRouteTypes +from kinde_sdk.auth.login_options import LoginOptions + + +def run_async(coro): + """Helper function to run async tests""" + return asyncio.run(coro) + + +class TestInvitationCode(unittest.TestCase): + """Tests for invitation code support in OAuth login flow.""" + + @patch("requests.get") + def setUp(self, mock_get): + """Set up test fixtures.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "authorization_endpoint": "https://example.com/oauth2/auth", + "token_endpoint": "https://example.com/oauth2/token", + "end_session_endpoint": "https://example.com/logout", + "userinfo_endpoint": "https://example.com/oauth2/userinfo", + } + mock_get.return_value = mock_response + + self.oauth = OAuth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="http://localhost:8000/callback", + host="https://test.kinde.com", + ) + self.mock_storage = MagicMock() + self.oauth._session_manager = MagicMock() + self.oauth._session_manager.storage_manager = self.mock_storage + self.oauth.auth_url = "https://example.com/oauth2/auth" + + # -- LoginOptions constants -- + + def test_login_options_has_invitation_code_constant(self): + """INVITATION_CODE constant is defined on LoginOptions.""" + self.assertEqual(LoginOptions.INVITATION_CODE, "invitation_code") + + def test_login_options_has_is_invitation_constant(self): + """IS_INVITATION constant is defined on LoginOptions.""" + self.assertEqual(LoginOptions.IS_INVITATION, "is_invitation") + + # -- generate_auth_url: invitation_code -- + + def test_invitation_code_appears_in_auth_url(self): + """invitation_code is forwarded as a query parameter.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: "abc123"} + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["invitation_code"], ["abc123"]) + + def test_invitation_code_auto_sets_is_invitation(self): + """is_invitation is automatically set to 'true' when invitation_code is present.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: "abc123"} + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_invitation_code_with_explicit_is_invitation_true(self): + """Explicit is_invitation=True is honoured alongside invitation_code.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={ + LoginOptions.INVITATION_CODE: "abc123", + LoginOptions.IS_INVITATION: True, + } + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["invitation_code"], ["abc123"]) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_invitation_code_with_explicit_is_invitation_false(self): + """When invitation_code is present but is_invitation is explicitly False, + the auto-set logic still adds is_invitation='true'.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={ + LoginOptions.INVITATION_CODE: "abc123", + LoginOptions.IS_INVITATION: False, + } + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_is_invitation_alone_without_invitation_code(self): + """is_invitation=True can be set independently of invitation_code.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={LoginOptions.IS_INVITATION: True} + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["is_invitation"], ["true"]) + self.assertNotIn("invitation_code", params) + + def test_is_invitation_false_alone_not_in_url(self): + """is_invitation=False without invitation_code produces no is_invitation param.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={LoginOptions.IS_INVITATION: False} + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("is_invitation", params) + + def test_no_invitation_params_by_default(self): + """Neither invitation_code nor is_invitation appear when not requested.""" + result = run_async(self.oauth.generate_auth_url(login_options={})) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("invitation_code", params) + self.assertNotIn("is_invitation", params) + + def test_invitation_code_empty_string_not_set(self): + """An empty invitation_code does not add is_invitation.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: ""} + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("is_invitation", params) + + def test_invitation_code_none_not_set(self): + """invitation_code=None is ignored.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: None} + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("invitation_code", params) + self.assertNotIn("is_invitation", params) + + # -- login() wrapper -- + + def test_login_passes_invitation_code(self): + """login() forwards invitation_code to the generated URL.""" + url = run_async( + self.oauth.login( + login_options={LoginOptions.INVITATION_CODE: "inv_xyz"} + ) + ) + params = parse_qs(urlparse(url).query) + self.assertEqual(params["invitation_code"], ["inv_xyz"]) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_login_without_invitation_code(self): + """login() without invitation options produces no invitation params.""" + url = run_async(self.oauth.login()) + params = parse_qs(urlparse(url).query) + self.assertNotIn("invitation_code", params) + self.assertNotIn("is_invitation", params) + + # -- register() wrapper -- + + def test_register_passes_invitation_code(self): + """register() forwards invitation_code to the generated URL.""" + url = run_async( + self.oauth.register( + login_options={LoginOptions.INVITATION_CODE: "inv_reg"} + ) + ) + params = parse_qs(urlparse(url).query) + self.assertEqual(params["invitation_code"], ["inv_reg"]) + self.assertEqual(params["is_invitation"], ["true"]) + + # -- Coexistence with other params -- + + def test_invitation_code_with_org_code(self): + """invitation_code works alongside org_code.""" + result = run_async( + self.oauth.generate_auth_url( + login_options={ + LoginOptions.INVITATION_CODE: "abc123", + LoginOptions.ORG_CODE: "org_456", + } + ) + ) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["invitation_code"], ["abc123"]) + self.assertEqual(params["is_invitation"], ["true"]) + self.assertEqual(params["org_code"], ["org_456"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/testv2/testv2_auth/test_oauth.py b/testv2/testv2_auth/test_oauth.py index 3c84bce6..72483a07 100644 --- a/testv2/testv2_auth/test_oauth.py +++ b/testv2/testv2_auth/test_oauth.py @@ -131,6 +131,110 @@ def mock_get_side_effect(key): self.oauth._framework = self.mock_framework + # -- Invitation code tests -- + + def test_generate_auth_url_with_invitation_code(self): + """invitation_code is included as a query parameter in the auth URL.""" + result = run_async(self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: "abc123"} + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["invitation_code"], ["abc123"]) + + def test_generate_auth_url_invitation_code_auto_sets_is_invitation(self): + """is_invitation is automatically set to 'true' when invitation_code is present.""" + result = run_async(self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: "abc123"} + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_generate_auth_url_invitation_code_with_explicit_is_invitation(self): + """Explicit is_invitation=True is honoured alongside invitation_code.""" + result = run_async(self.oauth.generate_auth_url( + login_options={ + LoginOptions.INVITATION_CODE: "abc123", + LoginOptions.IS_INVITATION: True, + } + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["invitation_code"], ["abc123"]) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_generate_auth_url_invitation_code_overrides_false_is_invitation(self): + """invitation_code forces is_invitation='true' even when explicitly set to False.""" + result = run_async(self.oauth.generate_auth_url( + login_options={ + LoginOptions.INVITATION_CODE: "abc123", + LoginOptions.IS_INVITATION: False, + } + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_generate_auth_url_is_invitation_alone(self): + """is_invitation=True works independently of invitation_code.""" + result = run_async(self.oauth.generate_auth_url( + login_options={LoginOptions.IS_INVITATION: True} + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["is_invitation"], ["true"]) + self.assertNotIn("invitation_code", params) + + def test_generate_auth_url_no_invitation_params_by_default(self): + """Neither invitation_code nor is_invitation appear when not specified.""" + result = run_async(self.oauth.generate_auth_url(login_options={})) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("invitation_code", params) + self.assertNotIn("is_invitation", params) + + def test_generate_auth_url_empty_invitation_code_ignored(self): + """An empty invitation_code does not add is_invitation.""" + result = run_async(self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: ""} + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("is_invitation", params) + + def test_generate_auth_url_none_invitation_code_ignored(self): + """invitation_code=None is ignored entirely.""" + result = run_async(self.oauth.generate_auth_url( + login_options={LoginOptions.INVITATION_CODE: None} + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertNotIn("invitation_code", params) + self.assertNotIn("is_invitation", params) + + def test_login_passes_invitation_code(self): + """login() forwards invitation_code to the generated auth URL.""" + url = run_async(self.oauth.login( + login_options={LoginOptions.INVITATION_CODE: "inv_xyz"} + )) + params = parse_qs(urlparse(url).query) + self.assertEqual(params["invitation_code"], ["inv_xyz"]) + self.assertEqual(params["is_invitation"], ["true"]) + + def test_login_without_invitation_code(self): + """login() without invitation options produces no invitation params.""" + url = run_async(self.oauth.login()) + params = parse_qs(urlparse(url).query) + self.assertNotIn("invitation_code", params) + self.assertNotIn("is_invitation", params) + + def test_invitation_code_coexists_with_org_code(self): + """invitation_code and org_code can be used together.""" + result = run_async(self.oauth.generate_auth_url( + login_options={ + LoginOptions.INVITATION_CODE: "abc123", + LoginOptions.ORG_CODE: "org_456", + } + )) + params = parse_qs(urlparse(result["url"]).query) + self.assertEqual(params["invitation_code"], ["abc123"]) + self.assertEqual(params["is_invitation"], ["true"]) + self.assertEqual(params["org_code"], ["org_456"]) + + class TestOAuthMethodSignatures(unittest.TestCase): """Test OAuth method signatures to verify the fix for incorrect request parameter passing.""" diff --git a/testv2/testv2_framework/__init__.py b/testv2/testv2_framework/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testv2/testv2_framework/test_fastapi_framework.py b/testv2/testv2_framework/test_fastapi_framework.py new file mode 100644 index 00000000..36f07678 --- /dev/null +++ b/testv2/testv2_framework/test_fastapi_framework.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import FastAPI +from starlette.testclient import TestClient +from starlette.middleware.sessions import SessionMiddleware + +from kinde_fastapi.framework.fastapi_framework import FastAPIFramework + + +class TestFastAPIFramework(unittest.TestCase): + """Tests for the FastAPI framework implementation.""" + + def setUp(self): + self.app = FastAPI() + self.app.add_middleware(SessionMiddleware, secret_key="test-secret") + + self.framework = FastAPIFramework(app=self.app) + + # Mock OAuth so no real auth happens + self.mock_oauth = MagicMock() + self.mock_oauth.login = AsyncMock(return_value="https://kinde.example.com/authorize") + self.framework._oauth = self.mock_oauth + + # Register routes + self.framework._register_kinde_routes() + + self.client = TestClient(self.app, follow_redirects=False) + + def test_login_with_invitation_code(self): + """invitation_code query param is forwarded to oauth.login().""" + resp = self.client.get("/login?invitation_code=inv_abc123") + + self.mock_oauth.login.assert_called_once() + login_options = self.mock_oauth.login.call_args[0][0] + self.assertEqual(login_options["invitation_code"], "inv_abc123") + + def test_login_without_invitation_code(self): + """No invitation_code means oauth.login() gets an empty dict.""" + resp = self.client.get("/login") + + self.mock_oauth.login.assert_called_once() + login_options = self.mock_oauth.login.call_args[0][0] + self.assertNotIn("invitation_code", login_options) + + def test_login_with_empty_invitation_code(self): + """An empty invitation_code query param is not forwarded.""" + resp = self.client.get("/login?invitation_code=") + + self.mock_oauth.login.assert_called_once() + login_options = self.mock_oauth.login.call_args[0][0] + self.assertNotIn("invitation_code", login_options) + + def test_login_redirects_to_oauth_url(self): + """The route returns a redirect to the URL from oauth.login().""" + resp = self.client.get("/login?invitation_code=inv_xyz") + + self.assertEqual(resp.status_code, 307) + self.assertEqual(resp.headers["location"], "https://kinde.example.com/authorize") + + +if __name__ == "__main__": + unittest.main() diff --git a/testv2/testv2_framework/test_flask_framework.py b/testv2/testv2_framework/test_flask_framework.py new file mode 100644 index 00000000..0b311244 --- /dev/null +++ b/testv2/testv2_framework/test_flask_framework.py @@ -0,0 +1,66 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, patch +from flask import Flask + +from kinde_flask.framework.flask_framework import FlaskFramework + + +class TestFlaskFramework(unittest.TestCase): + """Tests for the Flask framework implementation.""" + + def setUp(self): + self.app = Flask(__name__) + self.app.config["SECRET_KEY"] = "test-secret" + self.app.config["TESTING"] = True + + self.framework = FlaskFramework(app=self.app) + + # Mock OAuth so no real auth happens + self.mock_oauth = MagicMock() + self.mock_oauth.login = AsyncMock(return_value="https://kinde.example.com/authorize") + self.framework._oauth = self.mock_oauth + + # Register routes + self.framework._initialized = False + self.framework._register_kinde_routes() + + self.client = self.app.test_client() + + def test_login_with_invitation_code(self): + """invitation_code query param is forwarded to oauth.login().""" + with self.app.test_request_context(): + resp = self.client.get("/login?invitation_code=inv_abc123") + + self.mock_oauth.login.assert_called_once() + login_options = self.mock_oauth.login.call_args[0][0] + self.assertEqual(login_options["invitation_code"], "inv_abc123") + + def test_login_without_invitation_code(self): + """No invitation_code means oauth.login() gets an empty dict.""" + with self.app.test_request_context(): + resp = self.client.get("/login") + + self.mock_oauth.login.assert_called_once() + login_options = self.mock_oauth.login.call_args[0][0] + self.assertNotIn("invitation_code", login_options) + + def test_login_with_empty_invitation_code(self): + """An empty invitation_code query param is not forwarded.""" + with self.app.test_request_context(): + resp = self.client.get("/login?invitation_code=") + + self.mock_oauth.login.assert_called_once() + login_options = self.mock_oauth.login.call_args[0][0] + self.assertNotIn("invitation_code", login_options) + + def test_login_redirects_to_oauth_url(self): + """The route returns a redirect to the URL from oauth.login().""" + with self.app.test_request_context(): + resp = self.client.get("/login?invitation_code=inv_xyz") + + self.assertEqual(resp.status_code, 302) + self.assertEqual(resp.headers["Location"], "https://kinde.example.com/authorize") + + +if __name__ == "__main__": + unittest.main()