google-adk-extras 0.1.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +31 -1
- google_adk_extras/adk_builder.py +1030 -0
- google_adk_extras/artifacts/__init__.py +25 -12
- google_adk_extras/artifacts/base_custom_artifact_service.py +148 -11
- google_adk_extras/artifacts/local_folder_artifact_service.py +133 -13
- google_adk_extras/artifacts/s3_artifact_service.py +135 -19
- google_adk_extras/artifacts/sql_artifact_service.py +109 -10
- google_adk_extras/credentials/__init__.py +34 -0
- google_adk_extras/credentials/base_custom_credential_service.py +113 -0
- google_adk_extras/credentials/github_oauth2_credential_service.py +213 -0
- google_adk_extras/credentials/google_oauth2_credential_service.py +216 -0
- google_adk_extras/credentials/http_basic_auth_credential_service.py +388 -0
- google_adk_extras/credentials/jwt_credential_service.py +345 -0
- google_adk_extras/credentials/microsoft_oauth2_credential_service.py +250 -0
- google_adk_extras/credentials/x_oauth2_credential_service.py +240 -0
- google_adk_extras/custom_agent_loader.py +156 -0
- google_adk_extras/enhanced_adk_web_server.py +137 -0
- google_adk_extras/enhanced_fastapi.py +470 -0
- google_adk_extras/enhanced_runner.py +38 -0
- google_adk_extras/memory/__init__.py +30 -13
- google_adk_extras/memory/base_custom_memory_service.py +37 -5
- google_adk_extras/memory/sql_memory_service.py +105 -19
- google_adk_extras/memory/yaml_file_memory_service.py +115 -22
- google_adk_extras/sessions/__init__.py +29 -13
- google_adk_extras/sessions/base_custom_session_service.py +133 -11
- google_adk_extras/sessions/sql_session_service.py +127 -16
- google_adk_extras/sessions/yaml_file_session_service.py +122 -14
- google_adk_extras-0.2.3.dist-info/METADATA +302 -0
- google_adk_extras-0.2.3.dist-info/RECORD +37 -0
- google_adk_extras/py.typed +0 -0
- google_adk_extras-0.1.1.dist-info/METADATA +0 -175
- google_adk_extras-0.1.1.dist-info/RECORD +0 -25
- {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/WHEEL +0 -0
- {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@
|
|
3
3
|
import json
|
4
4
|
import base64
|
5
5
|
from typing import Optional, List
|
6
|
-
from datetime import datetime
|
6
|
+
from datetime import datetime, timezone
|
7
7
|
|
8
8
|
try:
|
9
9
|
import boto3
|
@@ -19,7 +19,12 @@ from .base_custom_artifact_service import BaseCustomArtifactService
|
|
19
19
|
|
20
20
|
|
21
21
|
class S3ArtifactService(BaseCustomArtifactService):
|
22
|
-
"""S3-compatible artifact service implementation.
|
22
|
+
"""S3-compatible artifact service implementation.
|
23
|
+
|
24
|
+
This service stores artifacts in AWS S3 or S3-compatible storage services.
|
25
|
+
It supports versioning and works with any S3-compatible service including
|
26
|
+
AWS S3, MinIO, Google Cloud Storage, etc.
|
27
|
+
"""
|
23
28
|
|
24
29
|
def __init__(
|
25
30
|
self,
|
@@ -33,12 +38,12 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
33
38
|
"""Initialize the S3 artifact service.
|
34
39
|
|
35
40
|
Args:
|
36
|
-
bucket_name: S3 bucket name
|
37
|
-
endpoint_url: S3 endpoint URL (for non-AWS S3 services)
|
38
|
-
region_name: AWS region name
|
39
|
-
aws_access_key_id: AWS access key ID
|
40
|
-
aws_secret_access_key: AWS secret access key
|
41
|
-
prefix: Prefix for artifact storage paths
|
41
|
+
bucket_name: S3 bucket name.
|
42
|
+
endpoint_url: S3 endpoint URL (for non-AWS S3 services like MinIO).
|
43
|
+
region_name: AWS region name.
|
44
|
+
aws_access_key_id: AWS access key ID.
|
45
|
+
aws_secret_access_key: AWS secret access key.
|
46
|
+
prefix: Prefix for artifact storage paths. Defaults to "adk-artifacts".
|
42
47
|
"""
|
43
48
|
super().__init__()
|
44
49
|
self.bucket_name = bucket_name
|
@@ -50,7 +55,14 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
50
55
|
self.s3_client = None
|
51
56
|
|
52
57
|
async def _initialize_impl(self) -> None:
|
53
|
-
"""Initialize the S3 client.
|
58
|
+
"""Initialize the S3 client.
|
59
|
+
|
60
|
+
Creates the S3 client and verifies the bucket exists (or creates it).
|
61
|
+
|
62
|
+
Raises:
|
63
|
+
RuntimeError: If S3 initialization fails.
|
64
|
+
NoCredentialsError: If AWS credentials are not found.
|
65
|
+
"""
|
54
66
|
try:
|
55
67
|
# Create S3 client
|
56
68
|
self.s3_client = boto3.client(
|
@@ -88,22 +100,61 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
88
100
|
self.s3_client = None
|
89
101
|
|
90
102
|
def _get_artifact_key(self, app_name: str, user_id: str, session_id: str, filename: str) -> str:
|
91
|
-
"""Generate S3 key for artifact metadata.
|
103
|
+
"""Generate S3 key for artifact metadata.
|
104
|
+
|
105
|
+
Args:
|
106
|
+
app_name: The name of the application.
|
107
|
+
user_id: The ID of the user.
|
108
|
+
session_id: The ID of the session.
|
109
|
+
filename: The name of the file.
|
110
|
+
|
111
|
+
Returns:
|
112
|
+
S3 key for the metadata file.
|
113
|
+
"""
|
92
114
|
return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.json"
|
93
115
|
|
94
116
|
def _get_artifact_data_key(self, app_name: str, user_id: str, session_id: str, filename: str, version: int) -> str:
|
95
|
-
"""Generate S3 key for artifact data.
|
117
|
+
"""Generate S3 key for artifact data.
|
118
|
+
|
119
|
+
Args:
|
120
|
+
app_name: The name of the application.
|
121
|
+
user_id: The ID of the user.
|
122
|
+
session_id: The ID of the session.
|
123
|
+
filename: The name of the file.
|
124
|
+
version: The version number.
|
125
|
+
|
126
|
+
Returns:
|
127
|
+
S3 key for the data file.
|
128
|
+
"""
|
96
129
|
return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.v{version}.data"
|
97
130
|
|
98
131
|
def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
|
99
|
-
"""Extract blob data and mime type from a Part.
|
132
|
+
"""Extract blob data and mime type from a Part.
|
133
|
+
|
134
|
+
Args:
|
135
|
+
part: The Part object containing the blob data.
|
136
|
+
|
137
|
+
Returns:
|
138
|
+
A tuple of (data, mime_type).
|
139
|
+
|
140
|
+
Raises:
|
141
|
+
ValueError: If the part type is not supported.
|
142
|
+
"""
|
100
143
|
if part.inline_data:
|
101
144
|
return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
|
102
145
|
else:
|
103
146
|
raise ValueError("Only inline_data parts are supported")
|
104
147
|
|
105
148
|
def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
|
106
|
-
"""Create a Part from blob data and mime type.
|
149
|
+
"""Create a Part from blob data and mime type.
|
150
|
+
|
151
|
+
Args:
|
152
|
+
data: The binary data.
|
153
|
+
mime_type: The MIME type of the data.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
A Part object containing the blob data.
|
157
|
+
"""
|
107
158
|
blob = types.Blob(data=data, mime_type=mime_type)
|
108
159
|
return types.Part(inline_data=blob)
|
109
160
|
|
@@ -116,7 +167,22 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
116
167
|
filename: str,
|
117
168
|
artifact: types.Part,
|
118
169
|
) -> int:
|
119
|
-
"""Implementation of artifact saving.
|
170
|
+
"""Implementation of artifact saving.
|
171
|
+
|
172
|
+
Args:
|
173
|
+
app_name: The name of the application.
|
174
|
+
user_id: The ID of the user.
|
175
|
+
session_id: The ID of the session.
|
176
|
+
filename: The name of the file to save.
|
177
|
+
artifact: The artifact to save.
|
178
|
+
|
179
|
+
Returns:
|
180
|
+
The version number of the saved artifact.
|
181
|
+
|
182
|
+
Raises:
|
183
|
+
RuntimeError: If saving the artifact fails.
|
184
|
+
ValueError: If the artifact type is not supported.
|
185
|
+
"""
|
120
186
|
try:
|
121
187
|
# Extract blob data
|
122
188
|
data, mime_type = self._serialize_blob(artifact)
|
@@ -155,7 +221,7 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
155
221
|
version_info = {
|
156
222
|
"version": version,
|
157
223
|
"mime_type": mime_type,
|
158
|
-
"created_at": datetime.
|
224
|
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
159
225
|
"data_key": data_key
|
160
226
|
}
|
161
227
|
metadata["versions"].append(version_info)
|
@@ -180,7 +246,22 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
180
246
|
filename: str,
|
181
247
|
version: Optional[int] = None,
|
182
248
|
) -> Optional[types.Part]:
|
183
|
-
"""Implementation of artifact loading.
|
249
|
+
"""Implementation of artifact loading.
|
250
|
+
|
251
|
+
Args:
|
252
|
+
app_name: The name of the application.
|
253
|
+
user_id: The ID of the user.
|
254
|
+
session_id: The ID of the session.
|
255
|
+
filename: The name of the file to load.
|
256
|
+
version: Optional version number to load. If not provided,
|
257
|
+
the latest version will be loaded.
|
258
|
+
|
259
|
+
Returns:
|
260
|
+
The loaded artifact if found, None otherwise.
|
261
|
+
|
262
|
+
Raises:
|
263
|
+
RuntimeError: If loading the artifact fails.
|
264
|
+
"""
|
184
265
|
try:
|
185
266
|
# Load metadata
|
186
267
|
metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
|
@@ -237,7 +318,19 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
237
318
|
user_id: str,
|
238
319
|
session_id: str,
|
239
320
|
) -> List[str]:
|
240
|
-
"""Implementation of artifact key listing.
|
321
|
+
"""Implementation of artifact key listing.
|
322
|
+
|
323
|
+
Args:
|
324
|
+
app_name: The name of the application.
|
325
|
+
user_id: The ID of the user.
|
326
|
+
session_id: The ID of the session.
|
327
|
+
|
328
|
+
Returns:
|
329
|
+
A list of artifact keys (filenames).
|
330
|
+
|
331
|
+
Raises:
|
332
|
+
RuntimeError: If listing artifact keys fails.
|
333
|
+
"""
|
241
334
|
try:
|
242
335
|
# List objects with the prefix
|
243
336
|
prefix = f"{self.prefix}/{app_name}/{user_id}/{session_id}/"
|
@@ -269,7 +362,17 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
269
362
|
session_id: str,
|
270
363
|
filename: str,
|
271
364
|
) -> None:
|
272
|
-
"""Implementation of artifact deletion.
|
365
|
+
"""Implementation of artifact deletion.
|
366
|
+
|
367
|
+
Args:
|
368
|
+
app_name: The name of the application.
|
369
|
+
user_id: The ID of the user.
|
370
|
+
session_id: The ID of the session.
|
371
|
+
filename: The name of the file to delete.
|
372
|
+
|
373
|
+
Raises:
|
374
|
+
RuntimeError: If deleting the artifact fails.
|
375
|
+
"""
|
273
376
|
try:
|
274
377
|
# Load metadata to find all version files
|
275
378
|
metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
|
@@ -304,7 +407,20 @@ class S3ArtifactService(BaseCustomArtifactService):
|
|
304
407
|
session_id: str,
|
305
408
|
filename: str,
|
306
409
|
) -> List[int]:
|
307
|
-
"""Implementation of version listing.
|
410
|
+
"""Implementation of version listing.
|
411
|
+
|
412
|
+
Args:
|
413
|
+
app_name: The name of the application.
|
414
|
+
user_id: The ID of the user.
|
415
|
+
session_id: The ID of the session.
|
416
|
+
filename: The name of the file to list versions for.
|
417
|
+
|
418
|
+
Returns:
|
419
|
+
A list of version numbers.
|
420
|
+
|
421
|
+
Raises:
|
422
|
+
RuntimeError: If listing versions fails.
|
423
|
+
"""
|
308
424
|
try:
|
309
425
|
metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
|
310
426
|
|
@@ -43,7 +43,12 @@ class SQLArtifactModel(Base):
|
|
43
43
|
|
44
44
|
|
45
45
|
class SQLArtifactService(BaseCustomArtifactService):
|
46
|
-
"""SQL-based artifact service implementation.
|
46
|
+
"""SQL-based artifact service implementation.
|
47
|
+
|
48
|
+
This service stores artifacts in a SQL database using SQLAlchemy.
|
49
|
+
It supports various SQL databases including SQLite, PostgreSQL, and MySQL.
|
50
|
+
Artifacts are stored with full versioning support.
|
51
|
+
"""
|
47
52
|
|
48
53
|
def __init__(self, database_url: str):
|
49
54
|
"""Initialize the SQL artifact service.
|
@@ -57,7 +62,11 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
57
62
|
self.session_local: Optional[object] = None
|
58
63
|
|
59
64
|
async def _initialize_impl(self) -> None:
|
60
|
-
"""Initialize the database connection and create tables.
|
65
|
+
"""Initialize the database connection and create tables.
|
66
|
+
|
67
|
+
Raises:
|
68
|
+
RuntimeError: If database initialization fails.
|
69
|
+
"""
|
61
70
|
try:
|
62
71
|
self.engine = create_engine(self.database_url)
|
63
72
|
Base.metadata.create_all(self.engine)
|
@@ -77,13 +86,30 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
77
86
|
self.session_local = None
|
78
87
|
|
79
88
|
def _get_db_session(self):
|
80
|
-
"""Get a database session.
|
89
|
+
"""Get a database session.
|
90
|
+
|
91
|
+
Returns:
|
92
|
+
A database session object.
|
93
|
+
|
94
|
+
Raises:
|
95
|
+
RuntimeError: If the service is not initialized.
|
96
|
+
"""
|
81
97
|
if not self.session_local:
|
82
98
|
raise RuntimeError("Service not initialized")
|
83
99
|
return self.session_local()
|
84
100
|
|
85
101
|
def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
|
86
|
-
"""Extract blob data and mime type from a Part.
|
102
|
+
"""Extract blob data and mime type from a Part.
|
103
|
+
|
104
|
+
Args:
|
105
|
+
part: The Part object containing the blob data.
|
106
|
+
|
107
|
+
Returns:
|
108
|
+
A tuple of (data, mime_type).
|
109
|
+
|
110
|
+
Raises:
|
111
|
+
ValueError: If the part type is not supported.
|
112
|
+
"""
|
87
113
|
if part.inline_data:
|
88
114
|
return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
|
89
115
|
else:
|
@@ -93,7 +119,15 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
93
119
|
raise ValueError("Only inline_data parts are supported")
|
94
120
|
|
95
121
|
def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
|
96
|
-
"""Create a Part from blob data and mime type.
|
122
|
+
"""Create a Part from blob data and mime type.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
data: The binary data.
|
126
|
+
mime_type: The MIME type of the data.
|
127
|
+
|
128
|
+
Returns:
|
129
|
+
A Part object containing the blob data.
|
130
|
+
"""
|
97
131
|
blob = types.Blob(data=data, mime_type=mime_type)
|
98
132
|
return types.Part(inline_data=blob)
|
99
133
|
|
@@ -106,7 +140,22 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
106
140
|
filename: str,
|
107
141
|
artifact: types.Part,
|
108
142
|
) -> int:
|
109
|
-
"""Implementation of artifact saving.
|
143
|
+
"""Implementation of artifact saving.
|
144
|
+
|
145
|
+
Args:
|
146
|
+
app_name: The name of the application.
|
147
|
+
user_id: The ID of the user.
|
148
|
+
session_id: The ID of the session.
|
149
|
+
filename: The name of the file to save.
|
150
|
+
artifact: The artifact to save.
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
The version number of the saved artifact.
|
154
|
+
|
155
|
+
Raises:
|
156
|
+
RuntimeError: If saving the artifact fails.
|
157
|
+
ValueError: If the artifact type is not supported.
|
158
|
+
"""
|
110
159
|
db_session = self._get_db_session()
|
111
160
|
try:
|
112
161
|
# Extract blob data
|
@@ -153,7 +202,22 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
153
202
|
filename: str,
|
154
203
|
version: Optional[int] = None,
|
155
204
|
) -> Optional[types.Part]:
|
156
|
-
"""Implementation of artifact loading.
|
205
|
+
"""Implementation of artifact loading.
|
206
|
+
|
207
|
+
Args:
|
208
|
+
app_name: The name of the application.
|
209
|
+
user_id: The ID of the user.
|
210
|
+
session_id: The ID of the session.
|
211
|
+
filename: The name of the file to load.
|
212
|
+
version: Optional version number to load. If not provided,
|
213
|
+
the latest version will be loaded.
|
214
|
+
|
215
|
+
Returns:
|
216
|
+
The loaded artifact if found, None otherwise.
|
217
|
+
|
218
|
+
Raises:
|
219
|
+
RuntimeError: If loading the artifact fails.
|
220
|
+
"""
|
157
221
|
db_session = self._get_db_session()
|
158
222
|
try:
|
159
223
|
query = db_session.query(SQLArtifactModel).filter(
|
@@ -188,7 +252,19 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
188
252
|
user_id: str,
|
189
253
|
session_id: str,
|
190
254
|
) -> List[str]:
|
191
|
-
"""Implementation of artifact key listing.
|
255
|
+
"""Implementation of artifact key listing.
|
256
|
+
|
257
|
+
Args:
|
258
|
+
app_name: The name of the application.
|
259
|
+
user_id: The ID of the user.
|
260
|
+
session_id: The ID of the session.
|
261
|
+
|
262
|
+
Returns:
|
263
|
+
A list of artifact keys (filenames).
|
264
|
+
|
265
|
+
Raises:
|
266
|
+
RuntimeError: If listing artifact keys fails.
|
267
|
+
"""
|
192
268
|
db_session = self._get_db_session()
|
193
269
|
try:
|
194
270
|
# Get distinct filenames
|
@@ -212,7 +288,17 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
212
288
|
session_id: str,
|
213
289
|
filename: str,
|
214
290
|
) -> None:
|
215
|
-
"""Implementation of artifact deletion.
|
291
|
+
"""Implementation of artifact deletion.
|
292
|
+
|
293
|
+
Args:
|
294
|
+
app_name: The name of the application.
|
295
|
+
user_id: The ID of the user.
|
296
|
+
session_id: The ID of the session.
|
297
|
+
filename: The name of the file to delete.
|
298
|
+
|
299
|
+
Raises:
|
300
|
+
RuntimeError: If deleting the artifact fails.
|
301
|
+
"""
|
216
302
|
db_session = self._get_db_session()
|
217
303
|
try:
|
218
304
|
# Delete all versions of the artifact
|
@@ -238,7 +324,20 @@ class SQLArtifactService(BaseCustomArtifactService):
|
|
238
324
|
session_id: str,
|
239
325
|
filename: str,
|
240
326
|
) -> List[int]:
|
241
|
-
"""Implementation of version listing.
|
327
|
+
"""Implementation of version listing.
|
328
|
+
|
329
|
+
Args:
|
330
|
+
app_name: The name of the application.
|
331
|
+
user_id: The ID of the user.
|
332
|
+
session_id: The ID of the session.
|
333
|
+
filename: The name of the file to list versions for.
|
334
|
+
|
335
|
+
Returns:
|
336
|
+
A list of version numbers.
|
337
|
+
|
338
|
+
Raises:
|
339
|
+
RuntimeError: If listing versions fails.
|
340
|
+
"""
|
242
341
|
db_session = self._get_db_session()
|
243
342
|
try:
|
244
343
|
versions = db_session.query(SQLArtifactModel.version).filter(
|
@@ -0,0 +1,34 @@
|
|
1
|
+
"""Custom credential service implementations for Google ADK.

The core credential services are imported eagerly when this package is
imported. ``JWTCredentialService`` is optional (it requires PyJWT): its import
is wrapped in a ``try`` block so that a missing or broken third-party
dependency does not make the whole package fail to import. When that import
fails, ``JWTCredentialService`` is bound to ``None`` and is left out of
``__all__``.
"""

from .base_custom_credential_service import BaseCustomCredentialService
from .google_oauth2_credential_service import GoogleOAuth2CredentialService
from .github_oauth2_credential_service import GitHubOAuth2CredentialService
from .microsoft_oauth2_credential_service import MicrosoftOAuth2CredentialService
from .x_oauth2_credential_service import XOAuth2CredentialService
from .http_basic_auth_credential_service import (
    HTTPBasicAuthCredentialService,
    HTTPBasicAuthWithCredentialsService,
)

# Optional: JWT (requires PyJWT). The except clause is intentionally broad:
# besides ImportError, a transitive dependency may fail with another exception
# type at import time, and the rest of the package should remain usable.
try:
    from .jwt_credential_service import JWTCredentialService  # type: ignore
except Exception:  # ImportError or transitive import errors
    JWTCredentialService = None  # type: ignore

__all__ = [
    "BaseCustomCredentialService",
    "GoogleOAuth2CredentialService",
    "GitHubOAuth2CredentialService",
    "MicrosoftOAuth2CredentialService",
    "XOAuth2CredentialService",
    "HTTPBasicAuthCredentialService",
    "HTTPBasicAuthWithCredentialsService",
]

# Only advertise the JWT service when its import actually succeeded.
if JWTCredentialService is not None:
    __all__.append("JWTCredentialService")
|
@@ -0,0 +1,113 @@
|
|
1
|
+
"""Base class for custom credential services."""
|
2
|
+
|
3
|
+
import abc
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
from google.adk.auth.credential_service.base_credential_service import BaseCredentialService, CallbackContext
|
7
|
+
from google.adk.auth import AuthConfig, AuthCredential
|
8
|
+
|
9
|
+
|
10
|
+
class BaseCustomCredentialService(BaseCredentialService, abc.ABC):
    """Abstract foundation for project-specific credential services.

    Concrete subclasses supply ``_initialize_impl`` and the two storage hooks
    ``load_credential`` / ``save_credential``; they may also override
    ``_cleanup_impl``. This base class tracks whether setup has run, so that
    ``initialize`` and ``cleanup`` can be called repeatedly without re-running
    the subclass hooks, and it offers ``_check_initialized`` as a guard for
    subclass methods that need a ready service.
    """

    def __init__(self):
        """Set up base state; a new service starts out uninitialized."""
        super().__init__()
        # Flipped to True by initialize() and back to False by cleanup().
        self._initialized = False

    async def initialize(self) -> None:
        """Run the one-time setup for this service.

        ``_initialize_impl`` is awaited only while the service is not yet
        initialized; subsequent calls return immediately. If
        ``_initialize_impl`` raises, the flag is left unset and the exception
        propagates, so the call can be retried.

        Raises:
            RuntimeError: If initialization fails (raised by the subclass's
                ``_initialize_impl``).
        """
        if self._initialized:
            return
        await self._initialize_impl()
        # Mark the service ready only after setup finished without raising.
        self._initialized = True

    @abc.abstractmethod
    async def _initialize_impl(self) -> None:
        """Perform subclass-specific setup (validation, connections, ...).

        Called at most once per initialize/cleanup cycle by ``initialize``.

        Raises:
            RuntimeError: If initialization fails.
        """
        ...

    async def cleanup(self) -> None:
        """Release resources held by the service.

        A no-op when the service is not initialized. Otherwise awaits
        ``_cleanup_impl`` and then clears the initialized flag, allowing the
        service to be initialized again later. If ``_cleanup_impl`` raises,
        the service remains marked as initialized.
        """
        if not self._initialized:
            return
        await self._cleanup_impl()
        self._initialized = False

    async def _cleanup_impl(self) -> None:
        """Subclass hook for releasing resources; the default does nothing."""

    def _check_initialized(self) -> None:
        """Guard that the service is ready for use.

        Raises:
            RuntimeError: If ``initialize`` has not completed (or ``cleanup``
                has since been called).
        """
        if self._initialized:
            return
        raise RuntimeError(
            f"{self.__class__.__name__} must be initialized before use. "
            "Call await service.initialize() first."
        )

    @abc.abstractmethod
    async def load_credential(
        self,
        auth_config: AuthConfig,
        callback_context: CallbackContext,
    ) -> Optional[AuthCredential]:
        """Fetch the stored credential matching ``auth_config``.

        Args:
            auth_config: Auth configuration holding the auth scheme and
                credential information; ``auth_config.get_credential_key`` is
                used to build the lookup key.
            callback_context: Context of the current invocation in which a
                tool is requesting the credential.

        Returns:
            Optional[AuthCredential]: The stored credential, or None when no
            credential is found.
        """
        ...

    @abc.abstractmethod
    async def save_credential(
        self,
        auth_config: AuthConfig,
        callback_context: CallbackContext,
    ) -> None:
        """Persist the exchanged auth credential carried by ``auth_config``.

        Args:
            auth_config: Auth configuration holding the auth scheme and
                credential information; ``auth_config.get_credential_key`` is
                used to build the storage key.
            callback_context: Context of the current invocation in which a
                tool is saving the credential.
        """
        ...
|