google-adk-extras 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,323 @@
1
+ """S3-compatible artifact service implementation."""
2
+
3
+ import json
4
+ import base64
5
+ from typing import Optional, List
6
+ from datetime import datetime
7
+
8
+ try:
9
+ import boto3
10
+ from botocore.exceptions import ClientError, NoCredentialsError
11
+ except ImportError:
12
+ raise ImportError(
13
+ "Boto3 is required for S3ArtifactService. "
14
+ "Install it with: pip install boto3"
15
+ )
16
+
17
+ from google.genai import types
18
+ from .base_custom_artifact_service import BaseCustomArtifactService
19
+
20
+
21
class S3ArtifactService(BaseCustomArtifactService):
    """Artifact service that keeps versioned artifacts in an S3-compatible bucket.

    For every artifact two kinds of objects are written:

    * ``{prefix}/{app_name}/{user_id}/{session_id}/{filename}.json`` -- a JSON
      metadata document whose ``"versions"`` list has one entry per saved
      version (version number, mime type, creation time, data key).
    * ``{prefix}/{app_name}/{user_id}/{session_id}/{filename}.v{N}.data`` --
      the raw bytes of version ``N``.

    The boto3 client calls are made directly (without awaiting) inside the
    ``async`` methods.
    """

    # Error codes treated as "bucket does not exist" when probing the bucket.
    # NOTE(review): botocore typically reports the bare HTTP status for HEAD
    # requests; the symbolic code is accepted as well for S3-compatible
    # services that return it -- confirm against the targeted backends.
    _MISSING_BUCKET_CODES = ("404", "NoSuchBucket")

    # Region for which AWS rejects an explicit LocationConstraint on create.
    _DEFAULT_AWS_REGION = "us-east-1"

    def __init__(
        self,
        bucket_name: str,
        endpoint_url: Optional[str] = None,
        region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        prefix: str = "adk-artifacts",
    ):
        """Initialize the S3 artifact service.

        Args:
            bucket_name: S3 bucket name
            endpoint_url: S3 endpoint URL (for non-AWS S3 services)
            region_name: AWS region name
            aws_access_key_id: AWS access key ID
            aws_secret_access_key: AWS secret access key
            prefix: Prefix for artifact storage paths
        """
        super().__init__()
        self.bucket_name = bucket_name
        self.endpoint_url = endpoint_url
        self.region_name = region_name
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.prefix = prefix
        # Created in _initialize_impl and dropped again in _cleanup_impl.
        self.s3_client = None

    async def _initialize_impl(self) -> None:
        """Create the boto3 S3 client and make sure the bucket exists.

        Raises:
            RuntimeError: If no credentials are found or an S3 call fails.
        """
        try:
            self.s3_client = boto3.client(
                's3',
                endpoint_url=self.endpoint_url,
                region_name=self.region_name,
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
            )
            self._ensure_bucket()
        except NoCredentialsError as e:
            raise RuntimeError("AWS credentials not found. Please provide credentials.") from e
        except ClientError as e:
            raise RuntimeError(f"Failed to initialize S3 artifact service: {e}") from e

    def _ensure_bucket(self) -> None:
        """Verify the configured bucket exists, creating it when it is missing.

        Raises:
            ClientError: For any failure other than "bucket not found", or if
                bucket creation itself fails.
        """
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
            return
        except ClientError as e:
            # Compare as strings: the code may be symbolic (e.g. "AccessDenied"),
            # so converting it with int() would raise ValueError.
            error_code = str(e.response.get('Error', {}).get('Code', ''))
            if error_code not in self._MISSING_BUCKET_CODES:
                raise

        # Bucket doesn't exist, create it. A LocationConstraint of "us-east-1"
        # is rejected by AWS, so that region is created without one.
        if self.region_name and self.region_name != self._DEFAULT_AWS_REGION:
            self.s3_client.create_bucket(
                Bucket=self.bucket_name,
                CreateBucketConfiguration={'LocationConstraint': self.region_name},
            )
        else:
            self.s3_client.create_bucket(Bucket=self.bucket_name)

    async def _cleanup_impl(self) -> None:
        """Drop the reference to the S3 client."""
        self.s3_client = None

    def _get_artifact_key(self, app_name: str, user_id: str, session_id: str, filename: str) -> str:
        """Generate S3 key for artifact metadata."""
        return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.json"

    def _get_artifact_data_key(self, app_name: str, user_id: str, session_id: str, filename: str, version: int) -> str:
        """Generate S3 key for artifact data."""
        return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.v{version}.data"

    def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
        """Extract blob data and mime type from a Part.

        Raises:
            ValueError: If the part carries no ``inline_data``.
        """
        if not part.inline_data:
            raise ValueError("Only inline_data parts are supported")
        return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"

    def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
        """Create a Part from blob data and mime type."""
        blob = types.Blob(data=data, mime_type=mime_type)
        return types.Part(inline_data=blob)

    @staticmethod
    def _is_missing_key(error: ClientError) -> bool:
        """Return True when *error* reports that the requested key does not exist."""
        return error.response['Error']['Code'] == 'NoSuchKey'

    def _load_metadata(self, metadata_key: str) -> Optional[dict]:
        """Fetch and decode an artifact's metadata document.

        Returns:
            The parsed JSON document, or None when the key does not exist.

        Raises:
            ClientError: For any S3 failure other than a missing key.
        """
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=metadata_key)
        except ClientError as e:
            if self._is_missing_key(e):
                return None
            raise
        return json.loads(response['Body'].read().decode('utf-8'))

    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Store *artifact* as the next version of *filename*.

        Returns:
            The version number assigned to the saved artifact (starting at 0).

        Raises:
            ValueError: If the artifact is not an inline-data part.
            RuntimeError: If an S3 call fails.
        """
        # Local import: the module only imports `datetime` from this package.
        from datetime import timezone

        try:
            # Extract blob data
            data, mime_type = self._serialize_blob(artifact)

            metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
            metadata = self._load_metadata(metadata_key)
            if metadata is None:
                # Metadata doesn't exist, create new
                metadata = {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "filename": filename,
                    "versions": [],
                }

            # setdefault tolerates a stored document without a "versions" key.
            versions = metadata.setdefault("versions", [])
            version = len(versions)

            # Save data to S3
            data_key = self._get_artifact_data_key(app_name, user_id, session_id, filename, version)
            self.s3_client.put_object(
                Bucket=self.bucket_name,
                Key=data_key,
                Body=data,
            )

            # Naive UTC timestamp, same string format as the deprecated
            # datetime.utcnow().isoformat() so stored metadata stays uniform.
            created_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            versions.append({
                "version": version,
                "mime_type": mime_type,
                "created_at": created_at,
                "data_key": data_key,
            })

            # Save metadata to S3
            self.s3_client.put_object(
                Bucket=self.bucket_name,
                Key=metadata_key,
                Body=json.dumps(metadata, indent=2).encode('utf-8'),
            )

            return version
        except ClientError as e:
            raise RuntimeError(f"Failed to save artifact: {e}") from e

    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Load one version of an artifact.

        Args:
            version: Version to load; None selects the last recorded version.

        Returns:
            The artifact as a Part, or None when the metadata, the requested
            version, or the data object does not exist.

        Raises:
            RuntimeError: If an S3 call fails for another reason.
        """
        try:
            metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
            metadata = self._load_metadata(metadata_key)
            if metadata is None:
                return None

            versions = metadata.get("versions", [])
            if not versions:
                return None

            if version is None:
                # Load latest version
                version_info = versions[-1]
            else:
                # Find specific version
                version_info = next(
                    (v for v in versions if v["version"] == version), None
                )
                if version_info is None:
                    return None

            try:
                data_response = self.s3_client.get_object(
                    Bucket=self.bucket_name,
                    Key=version_info["data_key"],
                )
            except ClientError as e:
                if self._is_missing_key(e):
                    return None
                raise
            data = data_response['Body'].read()

            return self._deserialize_blob(data, version_info["mime_type"])
        except ClientError as e:
            raise RuntimeError(f"Failed to load artifact: {e}") from e

    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """List the filenames of all artifacts stored for a session.

        Returns:
            One entry per ``.json`` metadata object found directly under the
            session prefix, with the prefix and extension removed.

        Raises:
            RuntimeError: If an S3 call fails.
        """
        suffix = '.json'
        try:
            prefix = f"{self.prefix}/{app_name}/{user_id}/{session_id}/"
            # Paginate: a single list_objects_v2 response is capped, so a plain
            # call would silently drop artifacts beyond the first page.
            paginator = self.s3_client.get_paginator('list_objects_v2')
            pages = paginator.paginate(
                Bucket=self.bucket_name,
                Prefix=prefix,
                Delimiter='/',
            )

            artifact_keys = []
            for page in pages:
                for obj in page.get('Contents', []):
                    key = obj['Key']
                    # Only metadata objects identify an artifact.
                    if key.endswith(suffix) and key.startswith(prefix):
                        artifact_keys.append(key[len(prefix):-len(suffix)])

            return artifact_keys
        except ClientError as e:
            raise RuntimeError(f"Failed to list artifact keys: {e}") from e

    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Delete every stored version of an artifact and its metadata.

        A missing artifact is not an error.

        Raises:
            RuntimeError: If an S3 call fails for a reason other than a
                missing key.
        """
        try:
            metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
            metadata = self._load_metadata(metadata_key)
            if metadata is None:
                return

            # Delete all version data files
            for version_info in metadata.get("versions", []):
                self.s3_client.delete_object(
                    Bucket=self.bucket_name,
                    Key=version_info["data_key"],
                )

            # Delete metadata file last so a partial failure can be retried.
            self.s3_client.delete_object(
                Bucket=self.bucket_name,
                Key=metadata_key,
            )
        except ClientError as e:
            # Missing keys are ignored at every step, as before.
            if self._is_missing_key(e):
                return
            raise RuntimeError(f"Failed to delete artifact: {e}") from e

    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """List the version numbers recorded for an artifact.

        Returns:
            Version numbers in the order they appear in the metadata document;
            an empty list when the artifact does not exist.

        Raises:
            RuntimeError: If an S3 call fails.
        """
        try:
            metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
            metadata = self._load_metadata(metadata_key)
            if metadata is None:
                return []
            return [v["version"] for v in metadata.get("versions", [])]
        except ClientError as e:
            raise RuntimeError(f"Failed to list versions: {e}") from e
@@ -0,0 +1,255 @@
1
+ """SQL-based artifact service implementation using SQLAlchemy."""
2
+
3
+ import json
4
+ from typing import Optional, List
5
+ from datetime import datetime, timezone
6
+
7
+ try:
8
+ from sqlalchemy import create_engine, Column, String, Text, Integer, DateTime, LargeBinary
9
+ from sqlalchemy.orm import declarative_base, sessionmaker
10
+ from sqlalchemy.exc import SQLAlchemyError
11
+ except ImportError:
12
+ raise ImportError(
13
+ "SQLAlchemy is required for SQLArtifactService. "
14
+ "Install it with: pip install sqlalchemy"
15
+ )
16
+
17
+ from google.genai import types
18
+ from .base_custom_artifact_service import BaseCustomArtifactService
19
+
20
+
21
# Declarative base shared by the ORM models in this module
# (declarative_base is imported from sqlalchemy.orm).
Base = declarative_base()


def _current_utc_time():
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class SQLArtifactModel(Base):
    """ORM row holding one version of one artifact.

    A row is identified by the combination of application, user, session,
    filename and version number.
    """

    __tablename__ = 'adk_artifacts'

    # --- identity: all five columns together form the primary key ---
    app_name = Column(String, primary_key=True)
    user_id = Column(String, primary_key=True)
    session_id = Column(String, primary_key=True)
    filename = Column(String, primary_key=True)
    version = Column(Integer, primary_key=True)

    # --- payload ---
    mime_type = Column(String, nullable=False)
    # Raw artifact bytes.
    data = Column(LargeBinary, nullable=False)
    # Filled in with the current UTC time when no value is supplied.
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=_current_utc_time,
    )

    # --- optional extras ---
    # Free-form additional metadata, stored as JSON text.
    metadata_json = Column(Text, nullable=True)
43
+
44
+
45
class SQLArtifactService(BaseCustomArtifactService):
    """Artifact service that stores versioned artifacts in a SQL database.

    Each saved version is one ``SQLArtifactModel`` row. Every operation opens
    its own session from the configured session factory and closes it before
    returning.
    """

    def __init__(self, database_url: str):
        """Initialize the SQL artifact service.

        Args:
            database_url: Database connection string (e.g., 'sqlite:///artifacts.db')
        """
        super().__init__()
        self.database_url = database_url
        # Both are populated by _initialize_impl and reset by _cleanup_impl.
        self.engine: Optional[object] = None
        self.session_local: Optional[object] = None

    async def _initialize_impl(self) -> None:
        """Create the engine, the tables and the session factory.

        Raises:
            RuntimeError: If SQLAlchemy reports an error.
        """
        try:
            self.engine = create_engine(self.database_url)
            Base.metadata.create_all(self.engine)
            self.session_local = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine,
            )
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to initialize SQL artifact service: {e}") from e

    async def _cleanup_impl(self) -> None:
        """Dispose of the engine and forget the session factory."""
        if self.engine:
            self.engine.dispose()
            self.engine = None
            self.session_local = None

    def _get_db_session(self):
        """Open a new database session.

        Raises:
            RuntimeError: If the service has not been initialized.
        """
        if not self.session_local:
            raise RuntimeError("Service not initialized")
        return self.session_local()

    @staticmethod
    def _session_criteria(app_name: str, user_id: str, session_id: str) -> list:
        """Build the filter criteria that select one session's artifacts."""
        return [
            SQLArtifactModel.app_name == app_name,
            SQLArtifactModel.user_id == user_id,
            SQLArtifactModel.session_id == session_id,
        ]

    def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
        """Extract blob data and mime type from a Part.

        Raises:
            ValueError: If the part carries no ``inline_data``; other Part
                kinds are not handled by this implementation.
        """
        if not part.inline_data:
            raise ValueError("Only inline_data parts are supported")
        return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"

    def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
        """Create a Part from blob data and mime type."""
        blob = types.Blob(data=data, mime_type=mime_type)
        return types.Part(inline_data=blob)

    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Store *artifact* as the next version of *filename*.

        Returns:
            The version number assigned to the new row (starting at 0).

        Raises:
            ValueError: If the artifact is not an inline-data part.
            RuntimeError: If the database operation fails.
        """
        db_session = self._get_db_session()
        try:
            # Extract blob data
            data, mime_type = self._serialize_blob(artifact)

            # Select only the version column: fetching the full row would pull
            # the previous version's binary payload just to read one integer.
            latest = (
                db_session.query(SQLArtifactModel.version)
                .filter(
                    *self._session_criteria(app_name, user_id, session_id),
                    SQLArtifactModel.filename == filename,
                )
                .order_by(SQLArtifactModel.version.desc())
                .first()
            )
            version = latest[0] + 1 if latest is not None else 0

            db_session.add(
                SQLArtifactModel(
                    app_name=app_name,
                    user_id=user_id,
                    session_id=session_id,
                    filename=filename,
                    version=version,
                    mime_type=mime_type,
                    data=data,
                )
            )
            db_session.commit()

            return version
        except SQLAlchemyError as e:
            db_session.rollback()
            raise RuntimeError(f"Failed to save artifact: {e}") from e
        finally:
            db_session.close()

    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Load one version of an artifact.

        Args:
            version: Version to load; None selects the highest version.

        Returns:
            The artifact as a Part, or None when no matching row exists.

        Raises:
            RuntimeError: If the database operation fails.
        """
        db_session = self._get_db_session()
        try:
            query = db_session.query(SQLArtifactModel).filter(
                *self._session_criteria(app_name, user_id, session_id),
                SQLArtifactModel.filename == filename,
            )

            if version is not None:
                query = query.filter(SQLArtifactModel.version == version)
            else:
                # Get the latest version
                query = query.order_by(SQLArtifactModel.version.desc())

            db_artifact = query.first()
            if not db_artifact:
                return None

            return self._deserialize_blob(db_artifact.data, db_artifact.mime_type)
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to load artifact: {e}") from e
        finally:
            db_session.close()

    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """List the distinct filenames stored for a session.

        Raises:
            RuntimeError: If the database operation fails.
        """
        db_session = self._get_db_session()
        try:
            rows = (
                db_session.query(SQLArtifactModel.filename)
                .filter(*self._session_criteria(app_name, user_id, session_id))
                .distinct()
                .all()
            )
            return [row[0] for row in rows]
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to list artifact keys: {e}") from e
        finally:
            db_session.close()

    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Delete all versions of an artifact.

        Raises:
            RuntimeError: If the database operation fails.
        """
        db_session = self._get_db_session()
        try:
            db_session.query(SQLArtifactModel).filter(
                *self._session_criteria(app_name, user_id, session_id),
                SQLArtifactModel.filename == filename,
            ).delete()

            db_session.commit()
        except SQLAlchemyError as e:
            db_session.rollback()
            raise RuntimeError(f"Failed to delete artifact: {e}") from e
        finally:
            db_session.close()

    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """List an artifact's version numbers in ascending order.

        Raises:
            RuntimeError: If the database operation fails.
        """
        db_session = self._get_db_session()
        try:
            rows = (
                db_session.query(SQLArtifactModel.version)
                .filter(
                    *self._session_criteria(app_name, user_id, session_id),
                    SQLArtifactModel.filename == filename,
                )
                .order_by(SQLArtifactModel.version.asc())
                .all()
            )
            return [row[0] for row in rows]
        except SQLAlchemyError as e:
            raise RuntimeError(f"Failed to list versions: {e}") from e
        finally:
            db_session.close()
@@ -0,0 +1,15 @@
1
+ """Custom ADK memory services package."""
2
+
3
+ from .base_custom_memory_service import BaseCustomMemoryService
4
+ from .sql_memory_service import SQLMemoryService
5
+ from .mongo_memory_service import MongoMemoryService
6
+ from .redis_memory_service import RedisMemoryService
7
+ from .yaml_file_memory_service import YamlFileMemoryService
8
+
9
# Names exported by the package: the base class plus the four memory service
# classes imported above.
+ __all__ = [
10
+ "BaseCustomMemoryService",
11
+ "SQLMemoryService",
12
+ "MongoMemoryService",
13
+ "RedisMemoryService",
14
+ "YamlFileMemoryService",
15
+ ]
@@ -0,0 +1,90 @@
1
+ """Base class for custom memory services."""
2
+
3
+ from abc import abstractmethod
4
+ from typing import TYPE_CHECKING
5
+
6
+ from google.adk.memory.base_memory_service import BaseMemoryService
7
+
8
+ if TYPE_CHECKING:
9
+ from google.adk.sessions.session import Session
10
+
11
+
12
class BaseCustomMemoryService(BaseMemoryService):
    """Shared scaffolding for custom memory service backends.

    The public ``add_session_to_memory`` and ``search_memory`` methods make
    sure the service has been initialized and then delegate to the
    ``_..._impl`` hooks that concrete subclasses provide.

    NOTE(review): "SearchMemoryResponse" appears only as a string annotation
    here and no import for it is visible alongside this class -- confirm where
    it is expected to come from.
    """

    def __init__(self):
        """Set up the base service in the not-yet-initialized state."""
        super().__init__()
        self._initialized = False

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def initialize(self) -> None:
        """Run backend initialization once; later calls are no-ops."""
        if self._initialized:
            return
        await self._initialize_impl()
        self._initialized = True

    async def cleanup(self) -> None:
        """Run backend cleanup if the service is currently initialized."""
        if not self._initialized:
            return
        await self._cleanup_impl()
        self._initialized = False

    # ------------------------------------------------------------------
    # Public memory API
    # ------------------------------------------------------------------

    async def add_session_to_memory(self, session: "Session") -> None:
        """Store a session in memory, initializing the service first if needed.

        Args:
            session: The session to add.
        """
        if not self._initialized:
            await self.initialize()
        await self._add_session_to_memory_impl(session)

    async def search_memory(
        self, *, app_name: str, user_id: str, query: str
    ) -> "SearchMemoryResponse":
        """Search stored memories, initializing the service first if needed.

        Args:
            app_name: The name of the application.
            user_id: The id of the user.
            query: The query to search for.

        Returns:
            A SearchMemoryResponse containing the matching memories.
        """
        if not self._initialized:
            await self.initialize()
        response = await self._search_memory_impl(
            app_name=app_name, user_id=user_id, query=query
        )
        return response

    # ------------------------------------------------------------------
    # Hooks for subclasses
    # ------------------------------------------------------------------

    @abstractmethod
    async def _initialize_impl(self) -> None:
        """Backend-specific initialization."""

    @abstractmethod
    async def _cleanup_impl(self) -> None:
        """Backend-specific cleanup."""

    @abstractmethod
    async def _add_session_to_memory_impl(self, session: "Session") -> None:
        """Backend-specific storage of a session.

        Args:
            session: The session to add to memory.
        """

    @abstractmethod
    async def _search_memory_impl(
        self, *, app_name: str, user_id: str, query: str
    ) -> "SearchMemoryResponse":
        """Backend-specific memory search.

        Args:
            app_name: The name of the application.
            user_id: The id of the user.
            query: The query to search for.

        Returns:
            A SearchMemoryResponse containing the matching memories.
        """