google-adk 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/a2a/__init__.py +13 -0
- google/adk/a2a/converters/__init__.py +13 -0
- google/adk/a2a/converters/part_converter.py +166 -0
- google/adk/agents/invocation_context.py +2 -0
- google/adk/agents/llm_agent.py +1 -6
- google/adk/agents/run_config.py +11 -0
- google/adk/auth/auth_credential.py +5 -0
- google/adk/auth/auth_handler.py +22 -96
- google/adk/auth/auth_preprocessor.py +3 -3
- google/adk/auth/auth_tool.py +46 -0
- google/adk/auth/credential_manager.py +265 -0
- google/adk/auth/credential_service/__init__.py +13 -0
- google/adk/auth/credential_service/base_credential_service.py +75 -0
- google/adk/auth/credential_service/in_memory_credential_service.py +64 -0
- google/adk/auth/exchanger/__init__.py +23 -0
- google/adk/auth/exchanger/base_credential_exchanger.py +57 -0
- google/adk/auth/exchanger/credential_exchanger_registry.py +58 -0
- google/adk/auth/exchanger/oauth2_credential_exchanger.py +104 -0
- google/adk/auth/exchanger/service_account_credential_exchanger.py +104 -0
- google/adk/auth/oauth2_credential_util.py +107 -0
- google/adk/auth/refresher/__init__.py +21 -0
- google/adk/auth/refresher/base_credential_refresher.py +74 -0
- google/adk/auth/refresher/credential_refresher_registry.py +59 -0
- google/adk/auth/refresher/oauth2_credential_refresher.py +154 -0
- google/adk/cli/agent_graph.py +34 -32
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/main-JAAWEV7F.js +92 -0
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +10 -0
- google/adk/cli/cli_deploy.py +80 -21
- google/adk/cli/cli_tools_click.py +132 -61
- google/adk/cli/fast_api.py +46 -41
- google/adk/cli/utils/agent_loader.py +15 -2
- google/adk/code_executors/container_code_executor.py +10 -6
- google/adk/code_executors/vertex_ai_code_executor.py +8 -2
- google/adk/evaluation/_eval_set_results_manager_utils.py +44 -0
- google/adk/evaluation/_eval_sets_manager_utils.py +108 -0
- google/adk/evaluation/eval_metrics.py +0 -5
- google/adk/evaluation/eval_result.py +12 -7
- google/adk/evaluation/eval_set_results_manager.py +6 -1
- google/adk/evaluation/gcs_eval_set_results_manager.py +121 -0
- google/adk/evaluation/gcs_eval_sets_manager.py +196 -0
- google/adk/evaluation/local_eval_set_results_manager.py +6 -18
- google/adk/evaluation/local_eval_sets_manager.py +27 -78
- google/adk/flows/llm_flows/basic.py +9 -0
- google/adk/models/anthropic_llm.py +1 -1
- google/adk/models/gemini_llm_connection.py +2 -0
- google/adk/models/google_llm.py +57 -16
- google/adk/models/lite_llm.py +2 -1
- google/adk/platform/__init__.py +13 -0
- google/adk/platform/internal/__init__.py +15 -0
- google/adk/platform/internal/thread.py +30 -0
- google/adk/platform/thread.py +31 -0
- google/adk/runners.py +8 -2
- google/adk/sessions/in_memory_session_service.py +12 -1
- google/adk/sessions/vertex_ai_session_service.py +71 -50
- google/adk/tools/__init__.py +2 -0
- google/adk/tools/_automatic_function_calling_util.py +1 -0
- google/adk/tools/_forwarding_artifact_service.py +96 -0
- google/adk/tools/_function_parameter_parse_util.py +1 -0
- google/adk/tools/agent_tool.py +5 -39
- google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -2
- google/adk/tools/authenticated_function_tool.py +107 -0
- google/adk/tools/base_authenticated_tool.py +107 -0
- google/adk/tools/bigquery/bigquery_credentials.py +6 -4
- google/adk/tools/bigquery/bigquery_tool.py +22 -9
- google/adk/tools/bigquery/bigquery_toolset.py +9 -3
- google/adk/tools/bigquery/client.py +7 -3
- google/adk/tools/bigquery/config.py +46 -0
- google/adk/tools/bigquery/metadata_tool.py +114 -91
- google/adk/tools/bigquery/query_tool.py +141 -23
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +7 -4
- google/adk/tools/google_search_tool.py +0 -1
- google/adk/tools/mcp_tool/__init__.py +6 -0
- google/adk/tools/mcp_tool/mcp_session_manager.py +271 -149
- google/adk/tools/mcp_tool/mcp_tool.py +79 -22
- google/adk/tools/mcp_tool/mcp_toolset.py +32 -29
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +3 -3
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +56 -33
- google/adk/tools/retrieval/files_retrieval.py +7 -1
- google/adk/tools/url_context_tool.py +61 -0
- google/adk/tools/vertex_ai_search_tool.py +13 -2
- google/adk/utils/feature_decorator.py +175 -0
- google/adk/version.py +1 -1
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/METADATA +10 -2
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/RECORD +89 -58
- google/adk/cli/browser/main-CS5OLUMF.js +0 -91
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/WHEEL +0 -0
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,104 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Credential fetcher for Google Service Account."""
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
import google.auth
|
22
|
+
from google.auth.transport.requests import Request
|
23
|
+
from google.oauth2 import service_account
|
24
|
+
from typing_extensions import override
|
25
|
+
|
26
|
+
from ...utils.feature_decorator import experimental
|
27
|
+
from ..auth_credential import AuthCredential
|
28
|
+
from ..auth_credential import AuthCredentialTypes
|
29
|
+
from ..auth_schemes import AuthScheme
|
30
|
+
from .base_credential_exchanger import BaseCredentialExchanger
|
31
|
+
|
32
|
+
|
33
|
+
@experimental
class ServiceAccountCredentialExchanger(BaseCredentialExchanger):
  """Exchanges Google Service Account credentials for an access token.

  Uses Application Default Credentials if `use_default_credential = True`.
  Otherwise, uses the service account credential provided in the auth
  credential.
  """

  @override
  async def exchange(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> AuthCredential:
    """Exchanges the service account auth credential for an access token.

    If `use_default_credential` is True, Application Default Credentials are
    used, and any explicit `service_account_credential` is ignored. Otherwise
    the explicit `service_account_credential` is used. The configured `scopes`
    are passed in both cases.

    Args:
      auth_credential: The credential to exchange. Must be of type
        SERVICE_ACCOUNT.
      auth_scheme: The authentication scheme. Not used by this exchanger.

    Returns:
      An AuthCredential of type OAUTH2. If the resulting google-auth credential
      can be serialized (it implements `to_json()`), the JSON is stored in
      `google_oauth2_json`. Otherwise (for example service account or Compute
      Engine credentials) the freshly obtained access token is stored in
      `oauth2.access_token`.

    Raises:
      ValueError: If service account credentials are missing or invalid, or if
        the credential exchange or refresh fails.
    """
    # Local import keeps this change self-contained; OAuth2Auth is defined in
    # the same module as AuthCredential.
    from ..auth_credential import OAuth2Auth

    if auth_credential is None:
      raise ValueError("Credential cannot be None.")

    if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT:
      raise ValueError("Credential is not a service account credential.")

    if auth_credential.service_account is None:
      raise ValueError(
          "Service account credentials are missing. Please provide them."
      )

    config = auth_credential.service_account
    if (
        config.service_account_credential is None
        and not config.use_default_credential
    ):
      raise ValueError(
          "Service account credentials are invalid. Please set the"
          " service_account_credential field or set `use_default_credential ="
          " True` to use application default credential in a hosted service"
          " like Google Cloud Run."
      )

    try:
      if config.use_default_credential:
        # Pass the configured scopes. Per the google.auth.default docs they
        # are applied only when the ADC credential type needs them (e.g. a
        # service account key file). Without them such credentials cannot
        # be refreshed.
        credentials, _ = google.auth.default(scopes=config.scopes)
      else:
        credentials = service_account.Credentials.from_service_account_info(
            config.service_account_credential.model_dump(), scopes=config.scopes
        )

      # Refresh credentials to ensure we have a valid access token.
      credentials.refresh(Request())

      # Only some google-auth credential classes (e.g. authorized-user
      # credentials) implement `to_json()`. Service account and Compute
      # Engine credentials do not, so for those expose the access token.
      if hasattr(credentials, "to_json"):
        return AuthCredential(
            auth_type=AuthCredentialTypes.OAUTH2,
            google_oauth2_json=credentials.to_json(),
        )
      return AuthCredential(
          auth_type=AuthCredentialTypes.OAUTH2,
          oauth2=OAuth2Auth(access_token=credentials.token),
      )
    except Exception as e:
      raise ValueError(f"Failed to exchange service account token: {e}") from e
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
import logging
|
18
|
+
from typing import Optional
|
19
|
+
from typing import Tuple
|
20
|
+
|
21
|
+
from fastapi.openapi.models import OAuth2
|
22
|
+
|
23
|
+
from ..utils.feature_decorator import experimental
|
24
|
+
from .auth_credential import AuthCredential
|
25
|
+
from .auth_schemes import AuthScheme
|
26
|
+
from .auth_schemes import OpenIdConnectWithConfig
|
27
|
+
|
28
|
+
try:
|
29
|
+
from authlib.integrations.requests_client import OAuth2Session
|
30
|
+
from authlib.oauth2.rfc6749 import OAuth2Token
|
31
|
+
|
32
|
+
AUTHLIB_AVIALABLE = True
|
33
|
+
except ImportError:
|
34
|
+
AUTHLIB_AVIALABLE = False
|
35
|
+
|
36
|
+
|
37
|
+
logger = logging.getLogger("google_adk." + __name__)
|
38
|
+
|
39
|
+
|
40
|
+
@experimental
def create_oauth2_session(
    auth_scheme: AuthScheme,
    auth_credential: AuthCredential,
) -> Tuple[Optional[OAuth2Session], Optional[str]]:
  """Create an OAuth2 session for token operations.

  Args:
    auth_scheme: The authentication scheme configuration. Supported types are
      `OpenIdConnectWithConfig` and `OAuth2` with an authorization-code flow.
    auth_credential: The authentication credential. It must carry an OAuth2
      client ID and client secret.

  Returns:
    Tuple of (OAuth2Session, token_endpoint), or (None, None) if no session can
    be created. That happens when authlib is not installed, the scheme type is
    unsupported, the scheme has no token endpoint, or the credential lacks a
    client ID or secret.
  """
  # OAuth2Session is only bound when the authlib import at module load
  # succeeded; without it, referencing the name would raise NameError.
  if not AUTHLIB_AVIALABLE:
    return None, None

  if isinstance(auth_scheme, OpenIdConnectWithConfig):
    token_endpoint = getattr(auth_scheme, "token_endpoint", None)
    if not token_endpoint:
      return None, None
    # Guard against an unset (None) scope list, which would break the join.
    scopes = list(auth_scheme.scopes or [])
  elif isinstance(auth_scheme, OAuth2):
    flow = auth_scheme.flows.authorizationCode
    if not flow or not flow.tokenUrl:
      return None, None
    token_endpoint = flow.tokenUrl
    scopes = list((flow.scopes or {}).keys())
  else:
    return None, None

  if (
      not auth_credential
      or not auth_credential.oauth2
      or not auth_credential.oauth2.client_id
      or not auth_credential.oauth2.client_secret
  ):
    return None, None

  return (
      OAuth2Session(
          auth_credential.oauth2.client_id,
          auth_credential.oauth2.client_secret,
          scope=" ".join(scopes),
          redirect_uri=auth_credential.oauth2.redirect_uri,
          state=auth_credential.oauth2.state,
      ),
      token_endpoint,
  )
|
88
|
+
|
89
|
+
|
90
|
+
@experimental
def update_credential_with_tokens(
    auth_credential: AuthCredential, tokens: OAuth2Token
) -> None:
  """Update the credential in place with new tokens.

  An authorization server may omit a refresh token from a refresh response
  (RFC 6749, section 6). In that case the refresh token already stored on the
  credential is kept instead of being cleared.

  Args:
    auth_credential: The authentication credential to update. Its `oauth2`
      field is mutated.
    tokens: The OAuth2Token object containing new token information.
  """
  oauth2 = auth_credential.oauth2
  oauth2.access_token = tokens.get("access_token")
  # Clearing the stored refresh token here would make every later refresh
  # attempt fail, so only replace it when a new one was issued.
  oauth2.refresh_token = tokens.get("refresh_token") or oauth2.refresh_token
  expires_at = tokens.get("expires_at")
  oauth2.expires_at = int(expires_at) if expires_at else None
  expires_in = tokens.get("expires_in")
  oauth2.expires_in = int(expires_in) if expires_in else None
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Credential refresher module."""
|
16
|
+
|
17
|
+
from .base_credential_refresher import BaseCredentialRefresher
|
18
|
+
|
19
|
+
__all__ = [
|
20
|
+
"BaseCredentialRefresher",
|
21
|
+
]
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Base credential refresher interface."""
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
import abc
|
20
|
+
from typing import Optional
|
21
|
+
|
22
|
+
from google.adk.auth.auth_credential import AuthCredential
|
23
|
+
from google.adk.auth.auth_schemes import AuthScheme
|
24
|
+
from google.adk.utils.feature_decorator import experimental
|
25
|
+
|
26
|
+
|
27
|
+
class CredentialRefresherError(Exception):
  """Raised when refreshing a credential fails.

  Serves as the common base exception for credential refresh errors.
  """
|
29
|
+
|
30
|
+
|
31
|
+
@experimental
class BaseCredentialRefresher(abc.ABC):
  """Abstract interface implemented by every credential refresher.

  A refresher decides whether a credential is expired (or otherwise needs
  renewing) and renews it when asked.
  """

  @abc.abstractmethod
  async def is_refresh_needed(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> bool:
    """Reports whether `auth_credential` should be refreshed.

    Args:
      auth_credential: The credential to inspect.
      auth_scheme: The authentication scheme. Optional, because some
        refreshers do not need it.

    Returns:
      True when the credential needs a refresh, False otherwise.
    """
    ...

  @abc.abstractmethod
  async def refresh(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> AuthCredential:
    """Renews `auth_credential` if it needs renewing.

    Args:
      auth_credential: The credential to renew.
      auth_scheme: The authentication scheme. Optional, because some
        refreshers do not need it.

    Returns:
      The refreshed credential.

    Raises:
      CredentialRefresherError: If the refresh fails.
    """
    ...
|
@@ -0,0 +1,59 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Credential refresher registry."""
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from typing import Dict
|
20
|
+
from typing import Optional
|
21
|
+
|
22
|
+
from google.adk.auth.auth_credential import AuthCredentialTypes
|
23
|
+
from google.adk.utils.feature_decorator import experimental
|
24
|
+
|
25
|
+
from .base_credential_refresher import BaseCredentialRefresher
|
26
|
+
|
27
|
+
|
28
|
+
@experimental
class CredentialRefresherRegistry:
  """Maps each credential type to the refresher instance that handles it."""

  def __init__(self):
    # One refresher per credential type; registering a type again replaces
    # the previous entry.
    self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = (
        dict()
    )

  def register(
      self,
      credential_type: AuthCredentialTypes,
      refresher_instance: BaseCredentialRefresher,
  ) -> None:
    """Associates `refresher_instance` with `credential_type`.

    Args:
      credential_type: The credential type the refresher handles.
      refresher_instance: The refresher to use for that credential type.
    """
    self._refreshers.update({credential_type: refresher_instance})

  def get_refresher(
      self, credential_type: AuthCredentialTypes
  ) -> Optional[BaseCredentialRefresher]:
    """Looks up the refresher registered for `credential_type`.

    Args:
      credential_type: The credential type to look up.

    Returns:
      The registered refresher, or None if nothing is registered for the type.
    """
    refresher = self._refreshers.get(credential_type)
    return refresher
|
@@ -0,0 +1,154 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""OAuth2 credential refresher implementation."""
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
import json
|
20
|
+
import logging
|
21
|
+
from typing import Optional
|
22
|
+
|
23
|
+
from google.adk.auth.auth_credential import AuthCredential
|
24
|
+
from google.adk.auth.auth_schemes import AuthScheme
|
25
|
+
from google.adk.auth.oauth2_credential_util import create_oauth2_session
|
26
|
+
from google.adk.auth.oauth2_credential_util import update_credential_with_tokens
|
27
|
+
from google.adk.utils.feature_decorator import experimental
|
28
|
+
from google.auth.transport.requests import Request
|
29
|
+
from google.oauth2.credentials import Credentials
|
30
|
+
from typing_extensions import override
|
31
|
+
|
32
|
+
from .base_credential_refresher import BaseCredentialRefresher
|
33
|
+
|
34
|
+
try:
|
35
|
+
from authlib.oauth2.rfc6749 import OAuth2Token
|
36
|
+
|
37
|
+
AUTHLIB_AVIALABLE = True
|
38
|
+
except ImportError:
|
39
|
+
AUTHLIB_AVIALABLE = False
|
40
|
+
|
41
|
+
logger = logging.getLogger("google_adk." + __name__)
|
42
|
+
|
43
|
+
|
44
|
+
@experimental
class OAuth2CredentialRefresher(BaseCredentialRefresher):
  """Refreshes OAuth2 credentials including Google OAuth2 JSON credentials."""

  @staticmethod
  def _is_token_expired(auth_credential: AuthCredential) -> bool:
    """Returns whether the plain OAuth2 token on `auth_credential` is expired.

    Relies on authlib's `OAuth2Token.is_expired()`, which returns None when no
    expiry information is available. That case is treated as "not expired".
    Callers must ensure authlib is available and `auth_credential.oauth2` is
    set.
    """
    return bool(
        OAuth2Token({
            "expires_at": auth_credential.oauth2.expires_at,
            "expires_in": auth_credential.oauth2.expires_in,
        }).is_expired()
    )

  @override
  async def is_refresh_needed(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> bool:
    """Check if the OAuth2 credential needs to be refreshed.

    Args:
      auth_credential: The OAuth2 credential to check.
      auth_scheme: The OAuth2 authentication scheme. Optional for Google
        OAuth2 JSON credentials, but required for plain OAuth2 credentials.

    Returns:
      True if the credential needs to be refreshed, False otherwise. It is
      also False when the credential cannot be parsed or authlib is missing.
    """
    # Handle Google OAuth2 credentials (from service account exchange).
    if auth_credential.google_oauth2_json:
      try:
        google_credential = Credentials.from_authorized_user_info(
            json.loads(auth_credential.google_oauth2_json)
        )
        return google_credential.expired and bool(
            google_credential.refresh_token
        )
      except Exception as e:
        logger.warning("Failed to parse Google OAuth2 JSON credential: %s", e)
        return False

    # Handle regular OAuth2 credentials.
    if auth_credential.oauth2 and auth_scheme:
      if not AUTHLIB_AVIALABLE:
        return False
      return self._is_token_expired(auth_credential)

    return False

  @override
  async def refresh(
      self,
      auth_credential: AuthCredential,
      auth_scheme: Optional[AuthScheme] = None,
  ) -> AuthCredential:
    """Refresh the OAuth2 credential in place.

    If the refresh fails, or is not possible, the original credential is
    returned unchanged and the problem is logged.

    Args:
      auth_credential: The OAuth2 credential to refresh. It is mutated in
        place when the refresh succeeds.
      auth_scheme: The OAuth2 authentication scheme. Optional for Google
        OAuth2 JSON credentials, but required for plain OAuth2 credentials.

    Returns:
      The refreshed credential. This is the same object that was passed in.
    """
    # Handle Google OAuth2 credentials (from service account exchange).
    if auth_credential.google_oauth2_json:
      try:
        google_credential = Credentials.from_authorized_user_info(
            json.loads(auth_credential.google_oauth2_json)
        )
        if google_credential.expired and google_credential.refresh_token:
          google_credential.refresh(Request())
          auth_credential.google_oauth2_json = google_credential.to_json()
          logger.info("Successfully refreshed Google OAuth2 JSON credential")
      except Exception as e:
        # TODO reconsider whether we should raise error when refresh failed.
        logger.error("Failed to refresh Google OAuth2 JSON credential: %s", e)

    # Handle regular OAuth2 credentials.
    elif auth_credential.oauth2 and auth_scheme:
      if not AUTHLIB_AVIALABLE:
        return auth_credential

      if not self._is_token_expired(auth_credential):
        return auth_credential

      # Without a refresh token the token endpoint call cannot succeed.
      if not auth_credential.oauth2.refresh_token:
        logger.warning("No refresh token available for OAuth2 token refresh")
        return auth_credential

      client, token_endpoint = create_oauth2_session(
          auth_scheme, auth_credential
      )
      if not client:
        logger.warning("Could not create OAuth2 session for token refresh")
        return auth_credential

      try:
        tokens = client.refresh_token(
            url=token_endpoint,
            refresh_token=auth_credential.oauth2.refresh_token,
        )
        update_credential_with_tokens(auth_credential, tokens)
        logger.debug("Successfully refreshed OAuth2 tokens")
      except Exception as e:
        # TODO reconsider whether we should raise error when refresh failed.
        logger.error("Failed to refresh OAuth2 tokens: %s", e)
        # Return original credential on failure.
        return auth_credential

    return auth_credential
|
google/adk/cli/agent_graph.py
CHANGED
@@ -64,11 +64,11 @@ async def build_graph(
|
|
64
64
|
if isinstance(tool_or_agent, BaseAgent):
|
65
65
|
# Added Workflow Agent checks for different agent types
|
66
66
|
if isinstance(tool_or_agent, SequentialAgent):
|
67
|
-
return tool_or_agent.name +
|
67
|
+
return tool_or_agent.name + ' (Sequential Agent)'
|
68
68
|
elif isinstance(tool_or_agent, LoopAgent):
|
69
|
-
return tool_or_agent.name +
|
69
|
+
return tool_or_agent.name + ' (Loop Agent)'
|
70
70
|
elif isinstance(tool_or_agent, ParallelAgent):
|
71
|
-
return tool_or_agent.name +
|
71
|
+
return tool_or_agent.name + ' (Parallel Agent)'
|
72
72
|
else:
|
73
73
|
return tool_or_agent.name
|
74
74
|
elif isinstance(tool_or_agent, BaseTool):
|
@@ -144,49 +144,53 @@ async def build_graph(
|
|
144
144
|
)
|
145
145
|
return False
|
146
146
|
|
147
|
-
def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
|
147
|
+
async def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
|
148
148
|
if isinstance(agent, LoopAgent):
|
149
149
|
# Draw the edge from the parent agent to the first sub-agent
|
150
|
-
|
150
|
+
if parent_agent:
|
151
|
+
draw_edge(parent_agent.name, agent.sub_agents[0].name)
|
151
152
|
length = len(agent.sub_agents)
|
152
|
-
|
153
|
+
curr_length = 0
|
153
154
|
# Draw the edges between the sub-agents
|
154
155
|
for sub_agent_int_sequential in agent.sub_agents:
|
155
|
-
build_graph(child, sub_agent_int_sequential, highlight_pairs)
|
156
|
+
await build_graph(child, sub_agent_int_sequential, highlight_pairs)
|
156
157
|
# Draw the edge between the current sub-agent and the next one
|
157
158
|
# If it's the last sub-agent, draw an edge to the first one to indicating a loop
|
158
159
|
draw_edge(
|
159
|
-
agent.sub_agents[
|
160
|
+
agent.sub_agents[curr_length].name,
|
160
161
|
agent.sub_agents[
|
161
|
-
0 if
|
162
|
+
0 if curr_length == length - 1 else curr_length + 1
|
162
163
|
].name,
|
163
164
|
)
|
164
|
-
|
165
|
+
curr_length += 1
|
165
166
|
elif isinstance(agent, SequentialAgent):
|
166
167
|
# Draw the edge from the parent agent to the first sub-agent
|
167
|
-
|
168
|
+
if parent_agent:
|
169
|
+
draw_edge(parent_agent.name, agent.sub_agents[0].name)
|
168
170
|
length = len(agent.sub_agents)
|
169
|
-
|
171
|
+
curr_length = 0
|
170
172
|
|
171
173
|
# Draw the edges between the sub-agents
|
172
174
|
for sub_agent_int_sequential in agent.sub_agents:
|
173
|
-
build_graph(child, sub_agent_int_sequential, highlight_pairs)
|
175
|
+
await build_graph(child, sub_agent_int_sequential, highlight_pairs)
|
174
176
|
# Draw the edge between the current sub-agent and the next one
|
175
177
|
# If it's the last sub-agent, don't draw an edge to avoid a loop
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
178
|
+
if curr_length != length - 1:
|
179
|
+
draw_edge(
|
180
|
+
agent.sub_agents[curr_length].name,
|
181
|
+
agent.sub_agents[curr_length + 1].name,
|
182
|
+
)
|
183
|
+
curr_length += 1
|
181
184
|
|
182
185
|
elif isinstance(agent, ParallelAgent):
|
183
186
|
# Draw the edge from the parent agent to every sub-agent
|
184
187
|
for sub_agent in agent.sub_agents:
|
185
|
-
build_graph(child, sub_agent, highlight_pairs)
|
186
|
-
|
188
|
+
await build_graph(child, sub_agent, highlight_pairs)
|
189
|
+
if parent_agent:
|
190
|
+
draw_edge(parent_agent.name, sub_agent.name)
|
187
191
|
else:
|
188
192
|
for sub_agent in agent.sub_agents:
|
189
|
-
build_graph(child, sub_agent, highlight_pairs)
|
193
|
+
await build_graph(child, sub_agent, highlight_pairs)
|
190
194
|
draw_edge(agent.name, sub_agent.name)
|
191
195
|
|
192
196
|
child.attr(
|
@@ -196,21 +200,20 @@ async def build_graph(
|
|
196
200
|
fontcolor=light_gray,
|
197
201
|
)
|
198
202
|
|
199
|
-
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
|
203
|
+
async def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
|
200
204
|
name = get_node_name(tool_or_agent)
|
201
205
|
shape = get_node_shape(tool_or_agent)
|
202
206
|
caption = get_node_caption(tool_or_agent)
|
203
|
-
|
204
|
-
child = None
|
207
|
+
as_cluster = should_build_agent_cluster(tool_or_agent)
|
205
208
|
if highlight_pairs:
|
206
209
|
for highlight_tuple in highlight_pairs:
|
207
210
|
if name in highlight_tuple:
|
208
211
|
# if in highlight, draw highlight node
|
209
|
-
if
|
212
|
+
if as_cluster:
|
210
213
|
cluster = graphviz.Digraph(
|
211
214
|
name='cluster_' + name
|
212
215
|
) # adding "cluster_" to the name makes the graph render as a cluster subgraph
|
213
|
-
build_cluster(cluster, agent, name)
|
216
|
+
await build_cluster(cluster, agent, name)
|
214
217
|
graph.subgraph(cluster)
|
215
218
|
else:
|
216
219
|
graph.node(
|
@@ -224,12 +227,12 @@ async def build_graph(
|
|
224
227
|
)
|
225
228
|
return
|
226
229
|
# if not in highlight, draw non-highlight node
|
227
|
-
if
|
230
|
+
if as_cluster:
|
228
231
|
|
229
232
|
cluster = graphviz.Digraph(
|
230
233
|
name='cluster_' + name
|
231
234
|
) # adding "cluster_" to the name makes the graph render as a cluster subgraph
|
232
|
-
build_cluster(cluster, agent, name)
|
235
|
+
await build_cluster(cluster, agent, name)
|
233
236
|
graph.subgraph(cluster)
|
234
237
|
|
235
238
|
else:
|
@@ -264,10 +267,9 @@ async def build_graph(
|
|
264
267
|
else:
|
265
268
|
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
|
266
269
|
|
267
|
-
draw_node(agent)
|
270
|
+
await draw_node(agent)
|
268
271
|
for sub_agent in agent.sub_agents:
|
269
|
-
|
270
|
-
build_graph(graph, sub_agent, highlight_pairs, agent)
|
272
|
+
await build_graph(graph, sub_agent, highlight_pairs, agent)
|
271
273
|
if not should_build_agent_cluster(
|
272
274
|
sub_agent
|
273
275
|
) and not should_build_agent_cluster(
|
@@ -276,7 +278,7 @@ async def build_graph(
|
|
276
278
|
draw_edge(agent.name, sub_agent.name)
|
277
279
|
if isinstance(agent, LlmAgent):
|
278
280
|
for tool in await agent.canonical_tools():
|
279
|
-
draw_node(tool)
|
281
|
+
await draw_node(tool)
|
280
282
|
draw_edge(agent.name, get_node_name(tool))
|
281
283
|
|
282
284
|
|