Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class OAuth2Auth(BaseModelWithConfig):
expires_at: Optional[int] = None
expires_in: Optional[int] = None
audience: Optional[str] = None
token_endpoint_auth_method: Optional[str] = "client_secret_basic"


class ServiceAccountCredential(BaseModelWithConfig):
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/auth/oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def create_oauth2_session(
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
token_endpoint_auth_method=auth_credential.oauth2.token_endpoint_auth_method,
),
token_endpoint,
)
Expand Down
87 changes: 87 additions & 0 deletions tests/unittests/auth/test_oauth2_credential_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,93 @@ def test_create_oauth2_session_missing_credentials(self):
assert client is None
assert token_endpoint is None

def test_create_oauth2_session_with_token_endpoint_auth_method(self):
"""Test create_oauth2_session with token_endpoint_auth_method specified."""
scheme = OpenIdConnectWithConfig(
type_="openIdConnect",
openId_connect_url=(
"https://example.com/.well-known/openid_configuration"
),
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token",
scopes=["openid", "profile"],
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uri="https://example.com/callback",
state="test_state",
token_endpoint_auth_method="client_secret_post",
),
)

client, token_endpoint = create_oauth2_session(scheme, credential)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.client_id == "test_client_id"
assert client.client_secret == "test_client_secret"
assert client.token_endpoint_auth_method == "client_secret_post"

def test_create_oauth2_session_with_default_token_endpoint_auth_method(self):
"""Test create_oauth2_session with default token_endpoint_auth_method (None)."""
scheme = OpenIdConnectWithConfig(
type_="openIdConnect",
openId_connect_url=(
"https://example.com/.well-known/openid_configuration"
),
authorization_endpoint="https://example.com/auth",
token_endpoint="https://example.com/token",
scopes=["openid", "profile"],
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uri="https://example.com/callback",
state="test_state",
),
)

client, token_endpoint = create_oauth2_session(scheme, credential)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.client_id == "test_client_id"
assert client.client_secret == "test_client_secret"
assert client.token_endpoint_auth_method == "client_secret_basic"

def test_create_oauth2_session_oauth2_scheme_with_token_endpoint_auth_method(
self,
):
"""Test create_oauth2_session with OAuth2 scheme and token_endpoint_auth_method."""
flows = OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="https://example.com/auth",
tokenUrl="https://example.com/token",
scopes={"read": "Read access", "write": "Write access"},
)
)
scheme = OAuth2(type_="oauth2", flows=flows)
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="test_client_id",
client_secret="test_client_secret",
redirect_uri="https://example.com/callback",
token_endpoint_auth_method="client_secret_jwt",
),
)

client, token_endpoint = create_oauth2_session(scheme, credential)

assert client is not None
assert token_endpoint == "https://example.com/token"
assert client.token_endpoint_auth_method == "client_secret_jwt"

def test_update_credential_with_tokens(self):
"""Test update_credential_with_tokens function."""
credential = AuthCredential(
Expand Down