diff --git a/kinde-core/src/main/java/com/kinde/constants/KindeConstants.java b/kinde-core/src/main/java/com/kinde/constants/KindeConstants.java index ace55dca..7864ba59 100644 --- a/kinde-core/src/main/java/com/kinde/constants/KindeConstants.java +++ b/kinde-core/src/main/java/com/kinde/constants/KindeConstants.java @@ -7,5 +7,6 @@ public class KindeConstants { public final static String ORG_CODE = "org_code"; public final static String LANG = "lang"; public final static String ORG_NAME = "org_name"; + public final static String CONNECTION_ID = "connection_id"; public final static String SCOPE = "openid,email,profile"; } diff --git a/kinde-core/src/main/java/com/kinde/session/KindeRequestParameters.java b/kinde-core/src/main/java/com/kinde/session/KindeRequestParameters.java index 3e75252f..b138edd9 100644 --- a/kinde-core/src/main/java/com/kinde/session/KindeRequestParameters.java +++ b/kinde-core/src/main/java/com/kinde/session/KindeRequestParameters.java @@ -5,4 +5,5 @@ public class KindeRequestParameters { public final static String HAS_SUCCESS_PAGE = "has_success_page"; public final static String LANG = "lang"; public final static String ORG_CODE = "org_code"; + public final static String CONNECTION_ID = "connection_id"; } diff --git a/kinde-core/src/main/java/com/kinde/token/BaseToken.java b/kinde-core/src/main/java/com/kinde/token/BaseToken.java index 613f196e..77cb90ee 100644 --- a/kinde-core/src/main/java/com/kinde/token/BaseToken.java +++ b/kinde-core/src/main/java/com/kinde/token/BaseToken.java @@ -1,6 +1,5 @@ package com.kinde.token; -import com.google.inject.Inject; import com.kinde.accounts.KindeAccountsClient; import com.kinde.accounts.dto.PermissionDto; import com.kinde.accounts.dto.RoleDto; @@ -91,6 +90,33 @@ public Object getClaim(String key) { return this.signedJWT.getJWTClaimsSet().getClaim(key); } + @Override + @SneakyThrows + public String getConnectionId() { + if (this.signedJWT == null) { + return null; + } + + // First, try direct connection_id claim + Object connectionId = getClaim("connection_id"); + if (connectionId instanceof String) { + return (String) connectionId; + } + + // Then, try nested ext_provider.connection_id structure + Object extProvider = getClaim("ext_provider"); + if (extProvider instanceof Map) { + @SuppressWarnings("unchecked") + Map extProviderMap = (Map) extProvider; + Object nestedConnectionId = extProviderMap.get("connection_id"); + if (nestedConnectionId instanceof String) { + return (String) nestedConnectionId; + } + } + + return null; + } + @SuppressWarnings("unchecked") @Override public List getPermissions() { diff --git a/kinde-core/src/main/java/com/kinde/token/KindeToken.java b/kinde-core/src/main/java/com/kinde/token/KindeToken.java index 469f6a81..3627cec3 100644 --- a/kinde-core/src/main/java/com/kinde/token/KindeToken.java +++ b/kinde-core/src/main/java/com/kinde/token/KindeToken.java @@ -17,6 +17,17 @@ public interface KindeToken { Object getClaim(String key); + /** + * Gets the connection ID from the token. + * This method checks for connection_id in the token claims, including nested structures + * like ext_provider.connection_id for external identity providers. + * + * @return The connection ID string, or null if not found + */ + default String getConnectionId() { + return null; + } + List getPermissions(); /** diff --git a/kinde-core/src/test/java/com/kinde/session/ConnectionIdTest.java b/kinde-core/src/test/java/com/kinde/session/ConnectionIdTest.java new file mode 100644 index 00000000..35ddd12c --- /dev/null +++ b/kinde-core/src/test/java/com/kinde/session/ConnectionIdTest.java @@ -0,0 +1,167 @@ +package com.kinde.session; + +import com.kinde.KindeClient; +import com.kinde.KindeClientBuilder; +import com.kinde.KindeClientSession; +import com.kinde.authorization.AuthorizationType; +import com.kinde.authorization.AuthorizationUrl; +import com.kinde.client.KindeCoreGuiceTestModule; +import com.kinde.guice.KindeEnvironmentSingleton; +import com.kinde.guice.KindeGuiceSingleton; +import com.kinde.token.KindeTokenGuiceTestModule; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class ConnectionIdTest { + + @BeforeEach + public void setUp() { + KindeGuiceSingleton.fin(); + KindeEnvironmentSingleton.fin(); + KindeEnvironmentSingleton.init(KindeEnvironmentSingleton.State.TEST); + + KindeGuiceSingleton.init( + new KindeCoreGuiceTestModule(), + new KindeTokenGuiceTestModule()); + } + + @Test + @DisplayName("authorizationUrlWithParameters should include connection_id when provided") + public void testAuthorizationUrlWithConnectionId() { + KindeClient kindeClient = KindeClientBuilder.builder() + .domain("http://localhost:8089") + .clientId("test") + .clientSecret("test") + .redirectUri("http://localhost:8080/") + .build(); + + KindeClientSession kindeClientSession = kindeClient.initClientSession("test", null); + + Map parameters = new HashMap<>(); + parameters.put(KindeRequestParameters.CONNECTION_ID, "conn_123456789"); + + AuthorizationUrl authorizationUrl = kindeClientSession.authorizationUrlWithParameters(parameters); + + assertNotNull(authorizationUrl); + assertNotNull(authorizationUrl.getUrl()); + String urlString = authorizationUrl.getUrl().toString(); + assertTrue(urlString.contains("connection_id=conn_123456789"), + "URL should contain connection_id parameter. URL: " + urlString); + } + + @Test + @DisplayName("login should support connection_id via authorizationUrlWithParameters") + public void testLoginWithConnectionId() { + KindeClient kindeClient = KindeClientBuilder.builder() + .domain("http://localhost:8089") + .clientId("test") + .clientSecret("test") + .redirectUri("http://localhost:8080/") + .build(); + + KindeClientSession kindeClientSession = kindeClient.initClientSession("test", null); + + Map parameters = new HashMap<>(); + parameters.put("supports_reauth", "true"); + parameters.put(KindeRequestParameters.CONNECTION_ID, "conn_social_google"); + + AuthorizationUrl authorizationUrl = kindeClientSession.authorizationUrlWithParameters(parameters); + + assertNotNull(authorizationUrl); + assertNotNull(authorizationUrl.getUrl()); + String urlString = authorizationUrl.getUrl().toString(); + assertTrue(urlString.contains("connection_id=conn_social_google"), + "URL should contain connection_id parameter. URL: " + urlString); + assertTrue(urlString.contains("supports_reauth=true"), + "URL should contain supports_reauth parameter. URL: " + urlString); + } + + @Test + @DisplayName("register should support connection_id via authorizationUrlWithParameters") + public void testRegisterWithConnectionId() { + KindeClient kindeClient = KindeClientBuilder.builder() + .domain("http://localhost:8089") + .clientId("test") + .clientSecret("test") + .redirectUri("http://localhost:8080/") + .build(); + + KindeClientSession kindeClientSession = kindeClient.initClientSession("test", null); + + Map parameters = new HashMap<>(); + parameters.put("prompt", "create"); + parameters.put(KindeRequestParameters.CONNECTION_ID, "conn_enterprise_saml"); + + AuthorizationUrl authorizationUrl = kindeClientSession.authorizationUrlWithParameters(parameters); + + assertNotNull(authorizationUrl); + assertNotNull(authorizationUrl.getUrl()); + String urlString = authorizationUrl.getUrl().toString(); + assertTrue(urlString.contains("connection_id=conn_enterprise_saml"), + "URL should contain connection_id parameter. URL: " + urlString); + assertTrue(urlString.contains("prompt=create"), + "URL should contain prompt parameter. URL: " + urlString); + } + + @Test + @DisplayName("connection_id should work with CODE grant type") + public void testConnectionIdWithCodeGrant() { + KindeClient kindeClient = KindeClientBuilder.builder() + .domain("http://localhost:8089") + .clientId("test") + .clientSecret("test") + .redirectUri("http://localhost:8080/") + .grantType(AuthorizationType.CODE) + .build(); + + KindeClientSession kindeClientSession = kindeClient.initClientSession("test", null); + + Map parameters = new HashMap<>(); + parameters.put(KindeRequestParameters.CONNECTION_ID, "conn_123456789"); + + AuthorizationUrl authorizationUrl = kindeClientSession.authorizationUrlWithParameters(parameters); + + assertNotNull(authorizationUrl); + assertNotNull(authorizationUrl.getUrl()); + assertNotNull(authorizationUrl.getCodeVerifier(), "Code verifier should be present for CODE grant type"); + String urlString = authorizationUrl.getUrl().toString(); + assertTrue(urlString.contains("connection_id=conn_123456789"), + "URL should contain connection_id parameter. URL: " + urlString); + } + + @Test + @DisplayName("connection_id should work with other parameters like org_code and lang") + public void testConnectionIdWithOtherParameters() { + KindeClient kindeClient = KindeClientBuilder.builder() + .domain("http://localhost:8089") + .clientId("test") + .clientSecret("test") + .redirectUri("http://localhost:8080/") + .orgCode("ORG123") + .lang("en") + .build(); + + KindeClientSession kindeClientSession = kindeClient.initClientSession("test", null); + + Map parameters = new HashMap<>(); + parameters.put(KindeRequestParameters.CONNECTION_ID, "conn_123456789"); + + AuthorizationUrl authorizationUrl = kindeClientSession.authorizationUrlWithParameters(parameters); + + assertNotNull(authorizationUrl); + assertNotNull(authorizationUrl.getUrl()); + String urlString = authorizationUrl.getUrl().toString(); + assertTrue(urlString.contains("connection_id=conn_123456789"), + "URL should contain connection_id parameter. URL: " + urlString); + assertTrue(urlString.contains("org_code=ORG123"), + "URL should contain org_code parameter. URL: " + urlString); + assertTrue(urlString.contains("lang=en"), + "URL should contain lang parameter. URL: " + urlString); + } +} diff --git a/kinde-core/src/test/java/com/kinde/token/ConnectionIdTokenTest.java b/kinde-core/src/test/java/com/kinde/token/ConnectionIdTokenTest.java new file mode 100644 index 00000000..a41c0964 --- /dev/null +++ b/kinde-core/src/test/java/com/kinde/token/ConnectionIdTokenTest.java @@ -0,0 +1,124 @@ +package com.kinde.token; + +import com.kinde.token.jwt.JwtGenerator; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class ConnectionIdTokenTest { + + @Test + @DisplayName("getConnectionId should return connection_id from token when present as direct claim") + public void testGetConnectionIdDirectClaim() throws Exception { + String connectionId = "conn_123456789"; + String tokenString = JwtGenerator.generateIDTokenWithConnectionId(connectionId); + + KindeToken kindeToken = IDToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertEquals(connectionId, kindeToken.getConnectionId(), + "getConnectionId() should return the connection_id from the token"); + } + + @Test + @DisplayName("getConnectionId should return connection_id from ext_provider nested structure") + public void testGetConnectionIdFromExtProvider() throws Exception { + String connectionId = "conn_enterprise_saml_789"; + String tokenString = JwtGenerator.generateIDTokenWithExtProviderConnectionId(connectionId); + + KindeToken kindeToken = IDToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertEquals(connectionId, kindeToken.getConnectionId(), + "getConnectionId() should return the connection_id from ext_provider.connection_id"); + } + + @Test + @DisplayName("getConnectionId should return null when connection_id is not present") + public void testGetConnectionIdWhenNotPresent() throws Exception { + String tokenString = JwtGenerator.generateIDToken(); + + KindeToken kindeToken = IDToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertNull(kindeToken.getConnectionId(), + "getConnectionId() should return null when connection_id is not in the token"); + } + + @Test + @DisplayName("getConnectionId should prefer direct connection_id over nested ext_provider.connection_id") + public void testGetConnectionIdPreferDirectOverNested() throws Exception { + // Create a token with both direct and nested connection_id to test preference + String directConnectionId = "conn_direct_123"; + String nestedConnectionId = "conn_nested_456"; + + String tokenString = JwtGenerator.generateIDTokenWithBothConnectionIds(directConnectionId, nestedConnectionId); + + KindeToken kindeToken = IDToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertEquals(directConnectionId, kindeToken.getConnectionId(), + "getConnectionId() should prefer direct connection_id over ext_provider.connection_id"); + } + + @Test + @DisplayName("getConnectionId should work with AccessToken") + public void testGetConnectionIdWithAccessToken() throws Exception { + String connectionId = "conn_access_token_123"; + String tokenString = JwtGenerator.generateIDTokenWithConnectionId(connectionId); + + // AccessToken uses the same BaseToken implementation + KindeToken kindeToken = AccessToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertEquals(connectionId, kindeToken.getConnectionId(), + "getConnectionId() should work with AccessToken"); + } + + @Test + @DisplayName("getConnectionId should return null for invalid token") + public void testGetConnectionIdWithInvalidToken() throws Exception { + String tokenString = "invalid.token.string"; + + KindeToken kindeToken = IDToken.init(tokenString, false); + + assertNotNull(kindeToken); + assertFalse(kindeToken.valid()); + assertNull(kindeToken.getConnectionId(), + "getConnectionId() should return null for invalid tokens"); + } + + @Test + @DisplayName("getConnectionId should handle null ext_provider gracefully") + public void testGetConnectionIdWithNullExtProvider() throws Exception { + // Token without ext_provider should work fine + String tokenString = JwtGenerator.generateIDToken(); + + KindeToken kindeToken = IDToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertNull(kindeToken.getConnectionId(), + "getConnectionId() should handle missing ext_provider gracefully"); + } + + @Test + @DisplayName("getConnectionId should handle ext_provider without connection_id gracefully") + public void testGetConnectionIdWithExtProviderButNoConnectionId() throws Exception { + // This test verifies that if ext_provider exists but doesn't have connection_id, it returns null + String tokenString = JwtGenerator.generateIDTokenWithExtProviderButNoConnectionId(); + + KindeToken kindeToken = IDToken.init(tokenString, true); + + assertNotNull(kindeToken); + assertTrue(kindeToken.valid()); + assertNull(kindeToken.getConnectionId(), + "getConnectionId() should return null when ext_provider exists but has no connection_id"); + } +} diff --git a/kinde-core/src/test/java/com/kinde/token/jwt/JwtGenerator.java b/kinde-core/src/test/java/com/kinde/token/jwt/JwtGenerator.java index 596a159c..d7c28b24 100644 --- a/kinde-core/src/test/java/com/kinde/token/jwt/JwtGenerator.java +++ b/kinde-core/src/test/java/com/kinde/token/jwt/JwtGenerator.java @@ -230,4 +230,160 @@ public static String refreshToken() { signedJWT.sign(signer); return signedJWT.serialize(); } + + @SneakyThrows + public static String generateIDTokenWithConnectionId(String connectionId) { + RSAKey rsaJWK = new RSAKeyGenerator(2048) + .keyID("123") + .generate(); + + JWSSigner signer = new RSASSASigner(rsaJWK); + Date now = new Date(); + + Map featureFlags = new HashMap<>(); + featureFlags.put("test_str","test_str"); + featureFlags.put("test_integer",Integer.valueOf(1)); + featureFlags.put("test_boolean",Boolean.valueOf(false)); + + JWTClaimsSet jwtClaims = new JWTClaimsSet.Builder() + .issuer("https://openid.net") + .subject("test") + .audience(Arrays.asList("https://kinde.com")) + .expirationTime(new Date(now.getTime() + 1000*60*10)) + .notBeforeTime(now) + .issueTime(now) + .claim("permissions",Arrays.asList("test1","test1")) + .claim("org_codes",Arrays.asList("test1","test1")) + .claim("feature_flags",featureFlags) + .claim("connection_id", connectionId) + .jwtID(UUID.randomUUID().toString()) + .build(); + + SignedJWT signedJWT = new SignedJWT( + new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), + jwtClaims); + + signedJWT.sign(signer); + return signedJWT.serialize(); + } + + @SneakyThrows + public static String generateIDTokenWithExtProviderConnectionId(String connectionId) { + RSAKey rsaJWK = new RSAKeyGenerator(2048) + .keyID("123") + .generate(); + + JWSSigner signer = new RSASSASigner(rsaJWK); + Date now = new Date(); + + Map featureFlags = new HashMap<>(); + featureFlags.put("test_str","test_str"); + featureFlags.put("test_integer",Integer.valueOf(1)); + featureFlags.put("test_boolean",Boolean.valueOf(false)); + + Map extProvider = new HashMap<>(); + extProvider.put("connection_id", connectionId); + + JWTClaimsSet jwtClaims = new JWTClaimsSet.Builder() + .issuer("https://openid.net") + .subject("test") + .audience(Arrays.asList("https://kinde.com")) + .expirationTime(new Date(now.getTime() + 1000*60*10)) + .notBeforeTime(now) + .issueTime(now) + .claim("permissions",Arrays.asList("test1","test1")) + .claim("org_codes",Arrays.asList("test1","test1")) + .claim("feature_flags",featureFlags) + .claim("ext_provider", extProvider) + .jwtID(UUID.randomUUID().toString()) + .build(); + + SignedJWT signedJWT = new SignedJWT( + new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), + jwtClaims); + + signedJWT.sign(signer); + return signedJWT.serialize(); + } + + @SneakyThrows + public static String generateIDTokenWithBothConnectionIds(String directConnectionId, String nestedConnectionId) { + RSAKey rsaJWK = new RSAKeyGenerator(2048) + .keyID("123") + .generate(); + + JWSSigner signer = new RSASSASigner(rsaJWK); + Date now = new Date(); + + Map featureFlags = new HashMap<>(); + featureFlags.put("test_str","test_str"); + featureFlags.put("test_integer",Integer.valueOf(1)); + featureFlags.put("test_boolean",Boolean.valueOf(false)); + + Map extProvider = new HashMap<>(); + extProvider.put("connection_id", nestedConnectionId); + + JWTClaimsSet jwtClaims = new JWTClaimsSet.Builder() + .issuer("https://openid.net") + .subject("test") + .audience(Arrays.asList("https://kinde.com")) + .expirationTime(new Date(now.getTime() + 1000*60*10)) + .notBeforeTime(now) + .issueTime(now) + .claim("permissions",Arrays.asList("test1","test1")) + .claim("org_codes",Arrays.asList("test1","test1")) + .claim("feature_flags",featureFlags) + .claim("connection_id", directConnectionId) + .claim("ext_provider", extProvider) + .jwtID(UUID.randomUUID().toString()) + .build(); + + SignedJWT signedJWT = new SignedJWT( + new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), + jwtClaims); + + signedJWT.sign(signer); + return signedJWT.serialize(); + } + + @SneakyThrows + public static String generateIDTokenWithExtProviderButNoConnectionId() { + RSAKey rsaJWK = new RSAKeyGenerator(2048) + .keyID("123") + .generate(); + + JWSSigner signer = new RSASSASigner(rsaJWK); + Date now = new Date(); + + Map featureFlags = new HashMap<>(); + featureFlags.put("test_str","test_str"); + featureFlags.put("test_integer",Integer.valueOf(1)); + featureFlags.put("test_boolean",Boolean.valueOf(false)); + + // Create ext_provider with other fields but no connection_id + Map extProvider = new HashMap<>(); + extProvider.put("provider", "google"); + extProvider.put("provider_id", "12345"); + + JWTClaimsSet jwtClaims = new JWTClaimsSet.Builder() + .issuer("https://openid.net") + .subject("test") + .audience(Arrays.asList("https://kinde.com")) + .expirationTime(new Date(now.getTime() + 1000*60*10)) + .notBeforeTime(now) + .issueTime(now) + .claim("permissions",Arrays.asList("test1","test1")) + .claim("org_codes",Arrays.asList("test1","test1")) + .claim("feature_flags",featureFlags) + .claim("ext_provider", extProvider) + .jwtID(UUID.randomUUID().toString()) + .build(); + + SignedJWT signedJWT = new SignedJWT( + new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), + jwtClaims); + + signedJWT.sign(signer); + return signedJWT.serialize(); + } } diff --git a/kinde-j2ee/src/main/java/com/kinde/filter/KindeAuthenticationFilter.java b/kinde-j2ee/src/main/java/com/kinde/filter/KindeAuthenticationFilter.java index 6872b57e..b267b1e5 100644 --- a/kinde-j2ee/src/main/java/com/kinde/filter/KindeAuthenticationFilter.java +++ b/kinde-j2ee/src/main/java/com/kinde/filter/KindeAuthenticationFilter.java @@ -16,9 +16,12 @@ import java.nio.charset.StandardCharsets; import java.security.Principal; import java.util.Base64; +import java.util.HashMap; +import java.util.Map; import static com.kinde.constants.KindeConstants.*; import static com.kinde.constants.KindeJ2eeConstants.*; +import com.kinde.session.KindeRequestParameters; public abstract class KindeAuthenticationFilter implements Filter { protected void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain, KindeAuthenticationAction kindeAuthenticationAction) throws IOException, ServletException { @@ -56,15 +59,50 @@ protected void doFilter(ServletRequest servletRequest, ServletResponse servletRe if (userPrincipal == null || authorizationUrl == null) { // Redirect to the OAuth provider's authorization page KindeClientSession kindeClientSession = createKindeClientSession(req); + + // Build parameters map for connection_id support + Map parameters = new HashMap<>(); + String connectionId = req.getParameter(com.kinde.constants.KindeConstants.CONNECTION_ID); + if (connectionId != null && !connectionId.isEmpty()) { + parameters.put(KindeRequestParameters.CONNECTION_ID, connectionId); + } + if (kindeAuthenticationAction == KindeAuthenticationAction.LOGIN) { - authorizationUrl = kindeClientSession.login(); + if (parameters.isEmpty()) { + authorizationUrl = kindeClientSession.login(); + } else { + Map loginParams = new HashMap<>(parameters); + loginParams.put("supports_reauth", "true"); + authorizationUrl = kindeClientSession.authorizationUrlWithParameters(loginParams); + } } else if (kindeAuthenticationAction == KindeAuthenticationAction.REGISTER) { - authorizationUrl = kindeClientSession.register(); + if (parameters.isEmpty()) { + authorizationUrl = kindeClientSession.register(); + } else { + Map registerParams = new HashMap<>(parameters); + registerParams.put("prompt", "create"); + registerParams.put("supports_reauth", "true"); + authorizationUrl = kindeClientSession.authorizationUrlWithParameters(registerParams); + } } else if (kindeAuthenticationAction == KindeAuthenticationAction.CREATE_ORG) { if (req.getParameter(ORG_NAME) == null) { - throw new ServletException("Must proved org_name query parameter to create an organisation."); + throw new ServletException("Must provide org_name query parameter to create an organisation."); } - authorizationUrl = kindeClientSession.createOrg(req.getParameter(ORG_NAME)); + if (parameters.isEmpty()) { + authorizationUrl = kindeClientSession.createOrg(req.getParameter(ORG_NAME)); + } else { + Map createOrgParams = new HashMap<>(parameters); + createOrgParams.put("prompt", "create"); + createOrgParams.put("is_create_org", Boolean.TRUE.toString()); + createOrgParams.put("org_name", req.getParameter(ORG_NAME)); + authorizationUrl = kindeClientSession.authorizationUrlWithParameters(createOrgParams); + } + } else { + throw new ServletException("Unknown authentication action: " + kindeAuthenticationAction); + } + + if (authorizationUrl == null) { + throw new ServletException("Failed to generate authorization URL"); } req.getSession().setAttribute(AUTHORIZATION_URL,authorizationUrl); resp.sendRedirect(authorizationUrl.getUrl().toString()); @@ -85,9 +123,7 @@ protected void doFilter(ServletRequest servletRequest, ServletResponse servletRe throw new ServletException("OAuth token exchange failed", e); } } else { - if (userPrincipal == null) { - throw new ServletException("Authentication failure as the user principal has not been set correctly"); - } + // userPrincipal is not null here (otherwise we'd be in the first if block) HttpServletRequest wrappedRequest = new KindeHttpRequestWrapper(req, userPrincipal); filterChain.doFilter(wrappedRequest,servletResponse); } diff --git a/kinde-j2ee/src/main/java/com/kinde/servlet/KindeAuthenticationServlet.java b/kinde-j2ee/src/main/java/com/kinde/servlet/KindeAuthenticationServlet.java index 0ca81bcd..123cdf74 100644 --- a/kinde-j2ee/src/main/java/com/kinde/servlet/KindeAuthenticationServlet.java +++ b/kinde-j2ee/src/main/java/com/kinde/servlet/KindeAuthenticationServlet.java @@ -17,9 +17,12 @@ import java.nio.charset.StandardCharsets; import java.security.Principal; import java.util.Base64; +import java.util.HashMap; +import java.util.Map; import static com.kinde.constants.KindeConstants.*; import static com.kinde.constants.KindeJ2eeConstants.*; +import com.kinde.session.KindeRequestParameters; @Slf4j public class KindeAuthenticationServlet extends HttpServlet { @@ -53,16 +56,51 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp, KindeAuth } // Redirect to the OAuth provider's authorization page KindeClientSession kindeClientSession = createKindeClientSession(req); + + // Build parameters map for connection_id support + Map parameters = new HashMap<>(); + String connectionId = req.getParameter(com.kinde.constants.KindeConstants.CONNECTION_ID); + if (connectionId != null && !connectionId.isEmpty()) { + parameters.put(KindeRequestParameters.CONNECTION_ID, connectionId); + } + AuthorizationUrl authorizationUrl = null; if (kindeAuthenticationAction == KindeAuthenticationAction.LOGIN) { - authorizationUrl = kindeClientSession.login(); + if (parameters.isEmpty()) { + authorizationUrl = kindeClientSession.login(); + } else { + Map loginParams = new HashMap<>(parameters); + loginParams.put("supports_reauth", "true"); + authorizationUrl = kindeClientSession.authorizationUrlWithParameters(loginParams); + } } else if (kindeAuthenticationAction == KindeAuthenticationAction.REGISTER) { - authorizationUrl = kindeClientSession.register(); + if (parameters.isEmpty()) { + authorizationUrl = kindeClientSession.register(); + } else { + Map registerParams = new HashMap<>(parameters); + registerParams.put("prompt", "create"); + registerParams.put("supports_reauth", "true"); + authorizationUrl = kindeClientSession.authorizationUrlWithParameters(registerParams); + } } else if (kindeAuthenticationAction == KindeAuthenticationAction.CREATE_ORG) { if (req.getParameter(ORG_NAME) == null) { throw new ServletException("Must provide org_name query parameter to create an organisation."); } - authorizationUrl = kindeClientSession.createOrg(req.getParameter(ORG_NAME)); + if (parameters.isEmpty()) { + authorizationUrl = kindeClientSession.createOrg(req.getParameter(ORG_NAME)); + } else { + Map createOrgParams = new HashMap<>(parameters); + createOrgParams.put("prompt", "create"); + createOrgParams.put("is_create_org", Boolean.TRUE.toString()); + createOrgParams.put("org_name", req.getParameter(ORG_NAME)); + authorizationUrl = kindeClientSession.authorizationUrlWithParameters(createOrgParams); + } + } else { + throw new ServletException("Unknown authentication action: " + kindeAuthenticationAction); + } + + if (authorizationUrl == null) { + throw new ServletException("Failed to generate authorization URL"); } req.getSession().setAttribute(AUTHORIZATION_URL,authorizationUrl); req.getSession().setAttribute(POST_LOGIN_URL,postLoginUrl); diff --git a/kinde-j2ee/src/test/java/com/kinde/filter/ConnectionIdFilterTest.java b/kinde-j2ee/src/test/java/com/kinde/filter/ConnectionIdFilterTest.java new file mode 100644 index 00000000..5cda376a --- /dev/null +++ b/kinde-j2ee/src/test/java/com/kinde/filter/ConnectionIdFilterTest.java @@ -0,0 +1,151 @@ +package com.kinde.filter; + +import com.kinde.KindeClient; +import com.kinde.KindeClientBuilder; +import com.kinde.KindeClientSession; +import com.kinde.authorization.AuthorizationType; +import com.kinde.authorization.AuthorizationUrl; +import com.kinde.constants.KindeAuthenticationAction; +import com.kinde.servlet.KindeSingleton; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.HttpSession; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.net.URL; +import java.security.Principal; +import java.util.Map; + +import static com.kinde.constants.KindeConstants.*; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.*; + +public class ConnectionIdFilterTest { + + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain filterChain; + + @Mock + private HttpSession session; + + @Mock + private KindeSingleton kindeSingleton; + + @Mock + private KindeClientBuilder kindeClientBuilder; + + @Mock + private KindeClient kindeClient; + + @Mock + private KindeClientSession kindeClientSession; + + @Mock + private AuthorizationUrl authorizationUrl; + + private KindeLoginFilter filter; + private MockedStatic kindeSingletonStatic; + + @BeforeEach + public void setUp() { + MockitoAnnotations.openMocks(this); + filter = new KindeLoginFilter(); + + // Static mocking for KindeSingleton + kindeSingletonStatic = Mockito.mockStatic(KindeSingleton.class); + kindeSingletonStatic.when(KindeSingleton::getInstance).thenReturn(kindeSingleton); + + when(request.getSession()).thenReturn(session); + when(request.getRequestURL()).thenReturn(new StringBuffer("http://localhost:8080/test")); + when(session.getAttribute(anyString())).thenReturn(null); + when(request.getParameter("code")).thenReturn(null); + when(request.getParameter("error")).thenReturn(null); + } + + @AfterEach + public void tearDown() { + if (kindeSingletonStatic != null) { + kindeSingletonStatic.close(); + } + } + + @Test + @DisplayName("Filter should include connection_id in authorization URL when provided as request parameter") + public void testFilterWithConnectionId() throws ServletException, IOException { + // Setup + String connectionId = "conn_123456789"; + when(request.getParameter(CONNECTION_ID)).thenReturn(connectionId); + when(request.getParameter("org_code")).thenReturn(null); + when(request.getParameter("lang")).thenReturn(null); + + // Mock KindeSingleton chain + when(kindeSingleton.getKindeClientBuilder()).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.redirectUri(anyString())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.grantType(any(AuthorizationType.class))).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.orgCode(any())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.lang(any())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.scopes(anyString())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.build()).thenReturn(kindeClient); + when(kindeClient.clientSession()).thenReturn(kindeClientSession); + when(kindeClientSession.authorizationUrlWithParameters(any())).thenReturn(authorizationUrl); + when(authorizationUrl.getUrl()).thenReturn(new URL("http://example.com/auth?connection_id=" + connectionId)); + + // Execute + filter.doFilter(request, response, filterChain); + + // Verify + verify(kindeClientSession).authorizationUrlWithParameters(argThat((Map params) -> + params.containsKey(CONNECTION_ID) && + params.get(CONNECTION_ID).equals(connectionId) && + params.containsKey("supports_reauth") + )); + verify(response).sendRedirect(anyString()); + } + + @Test + @DisplayName("Filter should work without connection_id when not provided") + public void testFilterWithoutConnectionId() throws ServletException, IOException { + // Setup + when(request.getParameter(CONNECTION_ID)).thenReturn(null); + when(request.getParameter("org_code")).thenReturn(null); + when(request.getParameter("lang")).thenReturn(null); + + // Mock KindeSingleton chain + when(kindeSingleton.getKindeClientBuilder()).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.redirectUri(anyString())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.grantType(any(AuthorizationType.class))).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.orgCode(any())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.lang(any())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.scopes(anyString())).thenReturn(kindeClientBuilder); + when(kindeClientBuilder.build()).thenReturn(kindeClient); + when(kindeClient.clientSession()).thenReturn(kindeClientSession); + when(kindeClientSession.login()).thenReturn(authorizationUrl); + when(authorizationUrl.getUrl()).thenReturn(new URL("http://example.com/auth")); + + // Execute + filter.doFilter(request, response, filterChain); + + // Verify - should use login() method when no connection_id + verify(kindeClientSession).login(); + verify(response).sendRedirect(anyString()); + } +}