diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index bc91d48f79..f707d6a0bc 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -18,6 +18,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Literal from typing import Optional from pydantic import alias_generators @@ -80,6 +81,14 @@ class OAuth2Auth(BaseModelWithConfig): expires_at: Optional[int] = None expires_in: Optional[int] = None audience: Optional[str] = None + token_endpoint_auth_method: Optional[ + Literal[ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + ] = "client_secret_basic" class ServiceAccountCredential(BaseModelWithConfig): diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py index cc315bd29e..843f1152b6 100644 --- a/src/google/adk/auth/oauth2_credential_util.py +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -82,6 +82,7 @@ def create_oauth2_session( scope=" ".join(scopes), redirect_uri=auth_credential.oauth2.redirect_uri, state=auth_credential.oauth2.state, + token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method, ), token_endpoint, ) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index f1fd607ff5..952ede61da 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -25,6 +25,36 @@ from google.adk.auth.auth_schemes import OpenIdConnectWithConfig from google.adk.auth.oauth2_credential_util import create_oauth2_session from google.adk.auth.oauth2_credential_util import update_credential_with_tokens +import pytest + + +@pytest.fixture +def openid_connect_scheme(): + """Fixture providing a standard OpenIdConnectWithConfig scheme.""" + return OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url="https://example.com/.well-known/openid_configuration", + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + + +def create_oauth2_auth_credential(token_endpoint_auth_method=None): + """Helper function to create OAuth2Auth credential with optional token_endpoint_auth_method.""" + oauth2_auth = OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ) + if token_endpoint_auth_method is not None: + oauth2_auth.token_endpoint_auth_method = token_endpoint_auth_method + + return AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=oauth2_auth, + ) class TestOAuth2CredentialUtil: @@ -122,6 +152,68 @@ def test_create_oauth2_session_missing_credentials(self): assert client is None assert token_endpoint is None + def test_create_oauth2_session_with_token_endpoint_auth_method( + self, openid_connect_scheme + ): + """Test create_oauth2_session with token_endpoint_auth_method specified.""" + credential = create_oauth2_auth_credential( + token_endpoint_auth_method="client_secret_post" + ) + + client, token_endpoint = create_oauth2_session( + openid_connect_scheme, credential + ) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + assert client.token_endpoint_auth_method == "client_secret_post" + + def test_create_oauth2_session_with_default_token_endpoint_auth_method( + self, openid_connect_scheme + ): + """Test create_oauth2_session with default token_endpoint_auth_method.""" + credential = create_oauth2_auth_credential() + + client, token_endpoint = create_oauth2_session( + openid_connect_scheme, credential + ) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + assert client.token_endpoint_auth_method == "client_secret_basic" + + def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method( + self, + ): + """Test create_oauth2_session with OAuth2 scheme and token_endpoint_auth_method.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + token_endpoint_auth_method="client_secret_jwt", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.token_endpoint_auth_method == "client_secret_jwt" + def test_update_credential_with_tokens(self): """Test update_credential_with_tokens function.""" credential = AuthCredential(