google-adk-extras 0.1.1__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. google_adk_extras/__init__.py +31 -1
  2. google_adk_extras/adk_builder.py +1030 -0
  3. google_adk_extras/artifacts/__init__.py +25 -12
  4. google_adk_extras/artifacts/base_custom_artifact_service.py +148 -11
  5. google_adk_extras/artifacts/local_folder_artifact_service.py +133 -13
  6. google_adk_extras/artifacts/s3_artifact_service.py +135 -19
  7. google_adk_extras/artifacts/sql_artifact_service.py +109 -10
  8. google_adk_extras/credentials/__init__.py +34 -0
  9. google_adk_extras/credentials/base_custom_credential_service.py +113 -0
  10. google_adk_extras/credentials/github_oauth2_credential_service.py +213 -0
  11. google_adk_extras/credentials/google_oauth2_credential_service.py +216 -0
  12. google_adk_extras/credentials/http_basic_auth_credential_service.py +388 -0
  13. google_adk_extras/credentials/jwt_credential_service.py +345 -0
  14. google_adk_extras/credentials/microsoft_oauth2_credential_service.py +250 -0
  15. google_adk_extras/credentials/x_oauth2_credential_service.py +240 -0
  16. google_adk_extras/custom_agent_loader.py +170 -0
  17. google_adk_extras/enhanced_adk_web_server.py +137 -0
  18. google_adk_extras/enhanced_fastapi.py +507 -0
  19. google_adk_extras/enhanced_runner.py +38 -0
  20. google_adk_extras/memory/__init__.py +30 -13
  21. google_adk_extras/memory/base_custom_memory_service.py +37 -5
  22. google_adk_extras/memory/sql_memory_service.py +105 -19
  23. google_adk_extras/memory/yaml_file_memory_service.py +115 -22
  24. google_adk_extras/sessions/__init__.py +29 -13
  25. google_adk_extras/sessions/base_custom_session_service.py +133 -11
  26. google_adk_extras/sessions/sql_session_service.py +127 -16
  27. google_adk_extras/sessions/yaml_file_session_service.py +122 -14
  28. google_adk_extras-0.2.5.dist-info/METADATA +302 -0
  29. google_adk_extras-0.2.5.dist-info/RECORD +37 -0
  30. google_adk_extras/py.typed +0 -0
  31. google_adk_extras-0.1.1.dist-info/METADATA +0 -175
  32. google_adk_extras-0.1.1.dist-info/RECORD +0 -25
  33. {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.5.dist-info}/WHEEL +0 -0
  34. {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.5.dist-info}/licenses/LICENSE +0 -0
  35. {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.5.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@
3
3
  import json
4
4
  import base64
5
5
  from typing import Optional, List
6
- from datetime import datetime
6
+ from datetime import datetime, timezone
7
7
 
8
8
  try:
9
9
  import boto3
@@ -19,7 +19,12 @@ from .base_custom_artifact_service import BaseCustomArtifactService
19
19
 
20
20
 
21
21
  class S3ArtifactService(BaseCustomArtifactService):
22
- """S3-compatible artifact service implementation."""
22
+ """S3-compatible artifact service implementation.
23
+
24
+ This service stores artifacts in AWS S3 or S3-compatible storage services.
25
+ It supports versioning and works with any S3-compatible service including
26
+ AWS S3, MinIO, Google Cloud Storage, etc.
27
+ """
23
28
 
24
29
  def __init__(
25
30
  self,
@@ -33,12 +38,12 @@ class S3ArtifactService(BaseCustomArtifactService):
33
38
  """Initialize the S3 artifact service.
34
39
 
35
40
  Args:
36
- bucket_name: S3 bucket name
37
- endpoint_url: S3 endpoint URL (for non-AWS S3 services)
38
- region_name: AWS region name
39
- aws_access_key_id: AWS access key ID
40
- aws_secret_access_key: AWS secret access key
41
- prefix: Prefix for artifact storage paths
41
+ bucket_name: S3 bucket name.
42
+ endpoint_url: S3 endpoint URL (for non-AWS S3 services like MinIO).
43
+ region_name: AWS region name.
44
+ aws_access_key_id: AWS access key ID.
45
+ aws_secret_access_key: AWS secret access key.
46
+ prefix: Prefix for artifact storage paths. Defaults to "adk-artifacts".
42
47
  """
43
48
  super().__init__()
44
49
  self.bucket_name = bucket_name
@@ -50,7 +55,14 @@ class S3ArtifactService(BaseCustomArtifactService):
50
55
  self.s3_client = None
51
56
 
52
57
  async def _initialize_impl(self) -> None:
53
- """Initialize the S3 client."""
58
+ """Initialize the S3 client.
59
+
60
+ Creates the S3 client and verifies the bucket exists (or creates it).
61
+
62
+ Raises:
63
+ RuntimeError: If S3 initialization fails.
64
+ NoCredentialsError: If AWS credentials are not found.
65
+ """
54
66
  try:
55
67
  # Create S3 client
56
68
  self.s3_client = boto3.client(
@@ -88,22 +100,61 @@ class S3ArtifactService(BaseCustomArtifactService):
88
100
  self.s3_client = None
89
101
 
90
102
  def _get_artifact_key(self, app_name: str, user_id: str, session_id: str, filename: str) -> str:
91
- """Generate S3 key for artifact metadata."""
103
+ """Generate S3 key for artifact metadata.
104
+
105
+ Args:
106
+ app_name: The name of the application.
107
+ user_id: The ID of the user.
108
+ session_id: The ID of the session.
109
+ filename: The name of the file.
110
+
111
+ Returns:
112
+ S3 key for the metadata file.
113
+ """
92
114
  return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.json"
93
115
 
94
116
  def _get_artifact_data_key(self, app_name: str, user_id: str, session_id: str, filename: str, version: int) -> str:
95
- """Generate S3 key for artifact data."""
117
+ """Generate S3 key for artifact data.
118
+
119
+ Args:
120
+ app_name: The name of the application.
121
+ user_id: The ID of the user.
122
+ session_id: The ID of the session.
123
+ filename: The name of the file.
124
+ version: The version number.
125
+
126
+ Returns:
127
+ S3 key for the data file.
128
+ """
96
129
  return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.v{version}.data"
97
130
 
98
131
  def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
99
- """Extract blob data and mime type from a Part."""
132
+ """Extract blob data and mime type from a Part.
133
+
134
+ Args:
135
+ part: The Part object containing the blob data.
136
+
137
+ Returns:
138
+ A tuple of (data, mime_type).
139
+
140
+ Raises:
141
+ ValueError: If the part type is not supported.
142
+ """
100
143
  if part.inline_data:
101
144
  return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
102
145
  else:
103
146
  raise ValueError("Only inline_data parts are supported")
104
147
 
105
148
  def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
106
- """Create a Part from blob data and mime type."""
149
+ """Create a Part from blob data and mime type.
150
+
151
+ Args:
152
+ data: The binary data.
153
+ mime_type: The MIME type of the data.
154
+
155
+ Returns:
156
+ A Part object containing the blob data.
157
+ """
107
158
  blob = types.Blob(data=data, mime_type=mime_type)
108
159
  return types.Part(inline_data=blob)
109
160
 
@@ -116,7 +167,22 @@ class S3ArtifactService(BaseCustomArtifactService):
116
167
  filename: str,
117
168
  artifact: types.Part,
118
169
  ) -> int:
119
- """Implementation of artifact saving."""
170
+ """Implementation of artifact saving.
171
+
172
+ Args:
173
+ app_name: The name of the application.
174
+ user_id: The ID of the user.
175
+ session_id: The ID of the session.
176
+ filename: The name of the file to save.
177
+ artifact: The artifact to save.
178
+
179
+ Returns:
180
+ The version number of the saved artifact.
181
+
182
+ Raises:
183
+ RuntimeError: If saving the artifact fails.
184
+ ValueError: If the artifact type is not supported.
185
+ """
120
186
  try:
121
187
  # Extract blob data
122
188
  data, mime_type = self._serialize_blob(artifact)
@@ -155,7 +221,7 @@ class S3ArtifactService(BaseCustomArtifactService):
155
221
  version_info = {
156
222
  "version": version,
157
223
  "mime_type": mime_type,
158
- "created_at": datetime.utcnow().isoformat(),
224
+ "created_at": datetime.now(timezone.utc).isoformat(),
159
225
  "data_key": data_key
160
226
  }
161
227
  metadata["versions"].append(version_info)
@@ -180,7 +246,22 @@ class S3ArtifactService(BaseCustomArtifactService):
180
246
  filename: str,
181
247
  version: Optional[int] = None,
182
248
  ) -> Optional[types.Part]:
183
- """Implementation of artifact loading."""
249
+ """Implementation of artifact loading.
250
+
251
+ Args:
252
+ app_name: The name of the application.
253
+ user_id: The ID of the user.
254
+ session_id: The ID of the session.
255
+ filename: The name of the file to load.
256
+ version: Optional version number to load. If not provided,
257
+ the latest version will be loaded.
258
+
259
+ Returns:
260
+ The loaded artifact if found, None otherwise.
261
+
262
+ Raises:
263
+ RuntimeError: If loading the artifact fails.
264
+ """
184
265
  try:
185
266
  # Load metadata
186
267
  metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
@@ -237,7 +318,19 @@ class S3ArtifactService(BaseCustomArtifactService):
237
318
  user_id: str,
238
319
  session_id: str,
239
320
  ) -> List[str]:
240
- """Implementation of artifact key listing."""
321
+ """Implementation of artifact key listing.
322
+
323
+ Args:
324
+ app_name: The name of the application.
325
+ user_id: The ID of the user.
326
+ session_id: The ID of the session.
327
+
328
+ Returns:
329
+ A list of artifact keys (filenames).
330
+
331
+ Raises:
332
+ RuntimeError: If listing artifact keys fails.
333
+ """
241
334
  try:
242
335
  # List objects with the prefix
243
336
  prefix = f"{self.prefix}/{app_name}/{user_id}/{session_id}/"
@@ -269,7 +362,17 @@ class S3ArtifactService(BaseCustomArtifactService):
269
362
  session_id: str,
270
363
  filename: str,
271
364
  ) -> None:
272
- """Implementation of artifact deletion."""
365
+ """Implementation of artifact deletion.
366
+
367
+ Args:
368
+ app_name: The name of the application.
369
+ user_id: The ID of the user.
370
+ session_id: The ID of the session.
371
+ filename: The name of the file to delete.
372
+
373
+ Raises:
374
+ RuntimeError: If deleting the artifact fails.
375
+ """
273
376
  try:
274
377
  # Load metadata to find all version files
275
378
  metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
@@ -304,7 +407,20 @@ class S3ArtifactService(BaseCustomArtifactService):
304
407
  session_id: str,
305
408
  filename: str,
306
409
  ) -> List[int]:
307
- """Implementation of version listing."""
410
+ """Implementation of version listing.
411
+
412
+ Args:
413
+ app_name: The name of the application.
414
+ user_id: The ID of the user.
415
+ session_id: The ID of the session.
416
+ filename: The name of the file to list versions for.
417
+
418
+ Returns:
419
+ A list of version numbers.
420
+
421
+ Raises:
422
+ RuntimeError: If listing versions fails.
423
+ """
308
424
  try:
309
425
  metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
310
426
 
@@ -43,7 +43,12 @@ class SQLArtifactModel(Base):
43
43
 
44
44
 
45
45
  class SQLArtifactService(BaseCustomArtifactService):
46
- """SQL-based artifact service implementation."""
46
+ """SQL-based artifact service implementation.
47
+
48
+ This service stores artifacts in a SQL database using SQLAlchemy.
49
+ It supports various SQL databases including SQLite, PostgreSQL, and MySQL.
50
+ Artifacts are stored with full versioning support.
51
+ """
47
52
 
48
53
  def __init__(self, database_url: str):
49
54
  """Initialize the SQL artifact service.
@@ -57,7 +62,11 @@ class SQLArtifactService(BaseCustomArtifactService):
57
62
  self.session_local: Optional[object] = None
58
63
 
59
64
  async def _initialize_impl(self) -> None:
60
- """Initialize the database connection and create tables."""
65
+ """Initialize the database connection and create tables.
66
+
67
+ Raises:
68
+ RuntimeError: If database initialization fails.
69
+ """
61
70
  try:
62
71
  self.engine = create_engine(self.database_url)
63
72
  Base.metadata.create_all(self.engine)
@@ -77,13 +86,30 @@ class SQLArtifactService(BaseCustomArtifactService):
77
86
  self.session_local = None
78
87
 
79
88
  def _get_db_session(self):
80
- """Get a database session."""
89
+ """Get a database session.
90
+
91
+ Returns:
92
+ A database session object.
93
+
94
+ Raises:
95
+ RuntimeError: If the service is not initialized.
96
+ """
81
97
  if not self.session_local:
82
98
  raise RuntimeError("Service not initialized")
83
99
  return self.session_local()
84
100
 
85
101
  def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
86
- """Extract blob data and mime type from a Part."""
102
+ """Extract blob data and mime type from a Part.
103
+
104
+ Args:
105
+ part: The Part object containing the blob data.
106
+
107
+ Returns:
108
+ A tuple of (data, mime_type).
109
+
110
+ Raises:
111
+ ValueError: If the part type is not supported.
112
+ """
87
113
  if part.inline_data:
88
114
  return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
89
115
  else:
@@ -93,7 +119,15 @@ class SQLArtifactService(BaseCustomArtifactService):
93
119
  raise ValueError("Only inline_data parts are supported")
94
120
 
95
121
  def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
96
- """Create a Part from blob data and mime type."""
122
+ """Create a Part from blob data and mime type.
123
+
124
+ Args:
125
+ data: The binary data.
126
+ mime_type: The MIME type of the data.
127
+
128
+ Returns:
129
+ A Part object containing the blob data.
130
+ """
97
131
  blob = types.Blob(data=data, mime_type=mime_type)
98
132
  return types.Part(inline_data=blob)
99
133
 
@@ -106,7 +140,22 @@ class SQLArtifactService(BaseCustomArtifactService):
106
140
  filename: str,
107
141
  artifact: types.Part,
108
142
  ) -> int:
109
- """Implementation of artifact saving."""
143
+ """Implementation of artifact saving.
144
+
145
+ Args:
146
+ app_name: The name of the application.
147
+ user_id: The ID of the user.
148
+ session_id: The ID of the session.
149
+ filename: The name of the file to save.
150
+ artifact: The artifact to save.
151
+
152
+ Returns:
153
+ The version number of the saved artifact.
154
+
155
+ Raises:
156
+ RuntimeError: If saving the artifact fails.
157
+ ValueError: If the artifact type is not supported.
158
+ """
110
159
  db_session = self._get_db_session()
111
160
  try:
112
161
  # Extract blob data
@@ -153,7 +202,22 @@ class SQLArtifactService(BaseCustomArtifactService):
153
202
  filename: str,
154
203
  version: Optional[int] = None,
155
204
  ) -> Optional[types.Part]:
156
- """Implementation of artifact loading."""
205
+ """Implementation of artifact loading.
206
+
207
+ Args:
208
+ app_name: The name of the application.
209
+ user_id: The ID of the user.
210
+ session_id: The ID of the session.
211
+ filename: The name of the file to load.
212
+ version: Optional version number to load. If not provided,
213
+ the latest version will be loaded.
214
+
215
+ Returns:
216
+ The loaded artifact if found, None otherwise.
217
+
218
+ Raises:
219
+ RuntimeError: If loading the artifact fails.
220
+ """
157
221
  db_session = self._get_db_session()
158
222
  try:
159
223
  query = db_session.query(SQLArtifactModel).filter(
@@ -188,7 +252,19 @@ class SQLArtifactService(BaseCustomArtifactService):
188
252
  user_id: str,
189
253
  session_id: str,
190
254
  ) -> List[str]:
191
- """Implementation of artifact key listing."""
255
+ """Implementation of artifact key listing.
256
+
257
+ Args:
258
+ app_name: The name of the application.
259
+ user_id: The ID of the user.
260
+ session_id: The ID of the session.
261
+
262
+ Returns:
263
+ A list of artifact keys (filenames).
264
+
265
+ Raises:
266
+ RuntimeError: If listing artifact keys fails.
267
+ """
192
268
  db_session = self._get_db_session()
193
269
  try:
194
270
  # Get distinct filenames
@@ -212,7 +288,17 @@ class SQLArtifactService(BaseCustomArtifactService):
212
288
  session_id: str,
213
289
  filename: str,
214
290
  ) -> None:
215
- """Implementation of artifact deletion."""
291
+ """Implementation of artifact deletion.
292
+
293
+ Args:
294
+ app_name: The name of the application.
295
+ user_id: The ID of the user.
296
+ session_id: The ID of the session.
297
+ filename: The name of the file to delete.
298
+
299
+ Raises:
300
+ RuntimeError: If deleting the artifact fails.
301
+ """
216
302
  db_session = self._get_db_session()
217
303
  try:
218
304
  # Delete all versions of the artifact
@@ -238,7 +324,20 @@ class SQLArtifactService(BaseCustomArtifactService):
238
324
  session_id: str,
239
325
  filename: str,
240
326
  ) -> List[int]:
241
- """Implementation of version listing."""
327
+ """Implementation of version listing.
328
+
329
+ Args:
330
+ app_name: The name of the application.
331
+ user_id: The ID of the user.
332
+ session_id: The ID of the session.
333
+ filename: The name of the file to list versions for.
334
+
335
+ Returns:
336
+ A list of version numbers.
337
+
338
+ Raises:
339
+ RuntimeError: If listing versions fails.
340
+ """
242
341
  db_session = self._get_db_session()
243
342
  try:
244
343
  versions = db_session.query(SQLArtifactModel.version).filter(
@@ -0,0 +1,34 @@
1
+ """Custom credential service implementations for Google ADK.
2
+
3
+ Optional services are imported lazily to avoid import-time failures when
4
+ their third-party dependencies are not installed.
5
+ """
6
+
7
+ from .base_custom_credential_service import BaseCustomCredentialService
8
+ from .google_oauth2_credential_service import GoogleOAuth2CredentialService
9
+ from .github_oauth2_credential_service import GitHubOAuth2CredentialService
10
+ from .microsoft_oauth2_credential_service import MicrosoftOAuth2CredentialService
11
+ from .x_oauth2_credential_service import XOAuth2CredentialService
12
+ from .http_basic_auth_credential_service import (
13
+ HTTPBasicAuthCredentialService,
14
+ HTTPBasicAuthWithCredentialsService,
15
+ )
16
+
17
+ # Optional: JWT (requires PyJWT)
18
+ try:
19
+ from .jwt_credential_service import JWTCredentialService # type: ignore
20
+ except Exception: # ImportError or transitive import errors
21
+ JWTCredentialService = None # type: ignore
22
+
23
+ __all__ = [
24
+ "BaseCustomCredentialService",
25
+ "GoogleOAuth2CredentialService",
26
+ "GitHubOAuth2CredentialService",
27
+ "MicrosoftOAuth2CredentialService",
28
+ "XOAuth2CredentialService",
29
+ "HTTPBasicAuthCredentialService",
30
+ "HTTPBasicAuthWithCredentialsService",
31
+ ]
32
+
33
+ if JWTCredentialService is not None:
34
+ __all__.append("JWTCredentialService")
@@ -0,0 +1,113 @@
1
+ """Base class for custom credential services."""
2
+
3
+ import abc
4
+ from typing import Optional
5
+
6
+ from google.adk.auth.credential_service.base_credential_service import BaseCredentialService, CallbackContext
7
+ from google.adk.auth import AuthConfig, AuthCredential
8
+
9
+
10
+ class BaseCustomCredentialService(BaseCredentialService, abc.ABC):
11
+ """Base class for custom credential services with common functionality.
12
+
13
+ This abstract base class provides a foundation for implementing custom
14
+ credential services with automatic initialization and cleanup handling.
15
+ """
16
+
17
+ def __init__(self):
18
+ """Initialize the base custom credential service."""
19
+ super().__init__()
20
+ self._initialized = False
21
+
22
+ async def initialize(self) -> None:
23
+ """Initialize the credential service.
24
+
25
+ This method should be called before using the service to ensure
26
+ any required setup (connections, validations, etc.) is complete.
27
+
28
+ Raises:
29
+ RuntimeError: If initialization fails.
30
+ """
31
+ if not self._initialized:
32
+ await self._initialize_impl()
33
+ self._initialized = True
34
+
35
+ @abc.abstractmethod
36
+ async def _initialize_impl(self) -> None:
37
+ """Implementation of service initialization.
38
+
39
+ This method should handle any setup required for the service to function,
40
+ such as validating credentials, establishing connections, etc.
41
+
42
+ Raises:
43
+ RuntimeError: If initialization fails.
44
+ """
45
+ pass
46
+
47
+ async def cleanup(self) -> None:
48
+ """Clean up resources used by the credential service.
49
+
50
+ This method should be called when the service is no longer needed
51
+ to ensure proper cleanup of any resources.
52
+ """
53
+ if self._initialized:
54
+ await self._cleanup_impl()
55
+ self._initialized = False
56
+
57
+ async def _cleanup_impl(self) -> None:
58
+ """Implementation of service cleanup.
59
+
60
+ This method should handle cleanup of any resources used by the service.
61
+ The default implementation does nothing, but subclasses can override
62
+ to perform specific cleanup operations.
63
+ """
64
+ pass
65
+
66
+ def _check_initialized(self) -> None:
67
+ """Check if the service has been initialized.
68
+
69
+ Raises:
70
+ RuntimeError: If the service has not been initialized.
71
+ """
72
+ if not self._initialized:
73
+ raise RuntimeError(
74
+ f"{self.__class__.__name__} must be initialized before use. "
75
+ "Call await service.initialize() first."
76
+ )
77
+
78
+ @abc.abstractmethod
79
+ async def load_credential(
80
+ self,
81
+ auth_config: AuthConfig,
82
+ callback_context: CallbackContext,
83
+ ) -> Optional[AuthCredential]:
84
+ """Load the credential by auth config and current callback context.
85
+
86
+ Args:
87
+ auth_config: The auth config which contains the auth scheme and auth
88
+ credential information. auth_config.get_credential_key will be used to
89
+ build the key to load the credential.
90
+ callback_context: The context of the current invocation when the tool is
91
+ trying to load the credential.
92
+
93
+ Returns:
94
+ Optional[AuthCredential]: the credential saved in the store, or None if not found.
95
+ """
96
+ pass
97
+
98
+ @abc.abstractmethod
99
+ async def save_credential(
100
+ self,
101
+ auth_config: AuthConfig,
102
+ callback_context: CallbackContext,
103
+ ) -> None:
104
+ """Save the exchanged_auth_credential in auth config.
105
+
106
+ Args:
107
+ auth_config: The auth config which contains the auth scheme and auth
108
+ credential information. auth_config.get_credential_key will be used to
109
+ build the key to save the credential.
110
+ callback_context: The context of the current invocation when the tool is
111
+ trying to save the credential.
112
+ """
113
+ pass