google-adk-extras 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +1 -0
- google_adk_extras/artifacts/__init__.py +15 -0
- google_adk_extras/artifacts/base_custom_artifact_service.py +207 -0
- google_adk_extras/artifacts/local_folder_artifact_service.py +242 -0
- google_adk_extras/artifacts/mongo_artifact_service.py +221 -0
- google_adk_extras/artifacts/s3_artifact_service.py +323 -0
- google_adk_extras/artifacts/sql_artifact_service.py +255 -0
- google_adk_extras/memory/__init__.py +15 -0
- google_adk_extras/memory/base_custom_memory_service.py +90 -0
- google_adk_extras/memory/mongo_memory_service.py +174 -0
- google_adk_extras/memory/redis_memory_service.py +188 -0
- google_adk_extras/memory/sql_memory_service.py +213 -0
- google_adk_extras/memory/yaml_file_memory_service.py +176 -0
- google_adk_extras/py.typed +0 -0
- google_adk_extras/sessions/__init__.py +13 -0
- google_adk_extras/sessions/base_custom_session_service.py +183 -0
- google_adk_extras/sessions/mongo_session_service.py +243 -0
- google_adk_extras/sessions/redis_session_service.py +271 -0
- google_adk_extras/sessions/sql_session_service.py +308 -0
- google_adk_extras/sessions/yaml_file_session_service.py +245 -0
- google_adk_extras-0.1.1.dist-info/METADATA +175 -0
- google_adk_extras-0.1.1.dist-info/RECORD +25 -0
- google_adk_extras-0.1.1.dist-info/WHEEL +5 -0
- google_adk_extras-0.1.1.dist-info/licenses/LICENSE +202 -0
- google_adk_extras-0.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,323 @@
|
|
1
|
+
"""S3-compatible artifact service implementation."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
import base64
|
5
|
+
from typing import Optional, List
|
6
|
+
from datetime import datetime
|
7
|
+
|
8
|
+
try:
|
9
|
+
import boto3
|
10
|
+
from botocore.exceptions import ClientError, NoCredentialsError
|
11
|
+
except ImportError:
|
12
|
+
raise ImportError(
|
13
|
+
"Boto3 is required for S3ArtifactService. "
|
14
|
+
"Install it with: pip install boto3"
|
15
|
+
)
|
16
|
+
|
17
|
+
from google.genai import types
|
18
|
+
from .base_custom_artifact_service import BaseCustomArtifactService
|
19
|
+
|
20
|
+
|
21
|
+
class S3ArtifactService(BaseCustomArtifactService):
|
22
|
+
"""S3-compatible artifact service implementation."""
|
23
|
+
|
24
|
+
def __init__(
    self,
    bucket_name: str,
    endpoint_url: Optional[str] = None,
    region_name: Optional[str] = None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    prefix: str = "adk-artifacts",
):
    """Initialize the S3 artifact service.

    Connection settings are only stored here; the boto3 client itself is
    created lazily in ``_initialize_impl``.

    Args:
        bucket_name: S3 bucket name
        endpoint_url: S3 endpoint URL (for non-AWS S3 services)
        region_name: AWS region name
        aws_access_key_id: AWS access key ID
        aws_secret_access_key: AWS secret access key
        prefix: Prefix for artifact storage paths
    """
    super().__init__()
    self.bucket_name = bucket_name
    self.endpoint_url = endpoint_url
    self.region_name = region_name
    self.aws_access_key_id = aws_access_key_id
    self.aws_secret_access_key = aws_secret_access_key
    self.prefix = prefix
    # Set by _initialize_impl(); None means "not yet initialized".
    self.s3_client = None
|
51
|
+
|
52
|
+
async def _initialize_impl(self) -> None:
    """Create the boto3 S3 client and ensure the bucket exists.

    Raises:
        RuntimeError: If credentials are missing or the bucket cannot be
            verified or created.
    """
    try:
        # Create S3 client
        self.s3_client = boto3.client(
            's3',
            endpoint_url=self.endpoint_url,
            region_name=self.region_name,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
        )

        # Verify bucket exists or create it
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
        except ClientError as e:
            # Fix: the error code may be a numeric string ('404') or a
            # symbolic one ('NoSuchBucket') depending on the backend;
            # the previous int(...) conversion raised ValueError on
            # symbolic codes instead of creating the bucket.
            error_code = e.response['Error']['Code']
            if error_code in ('404', 'NoSuchBucket'):
                # Bucket doesn't exist, create it
                if self.region_name:
                    self.s3_client.create_bucket(
                        Bucket=self.bucket_name,
                        CreateBucketConfiguration={'LocationConstraint': self.region_name}
                    )
                else:
                    self.s3_client.create_bucket(Bucket=self.bucket_name)
            else:
                raise
    except NoCredentialsError:
        raise RuntimeError("AWS credentials not found. Please provide credentials.")
    except ClientError as e:
        raise RuntimeError(f"Failed to initialize S3 artifact service: {e}")
|
84
|
+
|
85
|
+
async def _cleanup_impl(self) -> None:
    """Release the boto3 client reference.

    boto3 clients hold no resources that require an explicit shutdown,
    so dropping the reference is sufficient.
    """
    self.s3_client = None
|
89
|
+
|
90
|
+
def _get_artifact_key(self, app_name: str, user_id: str, session_id: str, filename: str) -> str:
    """Generate the S3 key for an artifact's metadata document.

    Fix: the key previously hard-coded "(unknown)" and ignored *filename*,
    so every artifact in a session collided on one metadata object. The
    key-listing code strips a trailing '.json' to recover the filename,
    which requires the filename to appear in the key.
    """
    return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.json"
|
93
|
+
|
94
|
+
def _get_artifact_data_key(self, app_name: str, user_id: str, session_id: str, filename: str, version: int) -> str:
    """Generate the S3 key for one version's blob data.

    Fix: the key previously hard-coded "(unknown)" and ignored *filename*,
    so data objects of different artifacts in a session collided.
    """
    return f"{self.prefix}/{app_name}/{user_id}/{session_id}/{filename}.v{version}.data"
|
97
|
+
|
98
|
+
def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
    """Extract raw bytes and a mime type from a Part.

    Raises:
        ValueError: If the Part carries anything other than inline_data.
    """
    blob = part.inline_data
    if not blob:
        raise ValueError("Only inline_data parts are supported")
    return blob.data, blob.mime_type or "application/octet-stream"
|
104
|
+
|
105
|
+
def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
    """Wrap raw bytes and a mime type back into an inline-data Part."""
    return types.Part(inline_data=types.Blob(data=data, mime_type=mime_type))
|
109
|
+
|
110
|
+
async def _save_artifact_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
    artifact: types.Part,
) -> int:
    """Save a new version of an artifact and return its version number.

    Versions are tracked in a per-filename JSON metadata document; each
    version's blob lives under its own versioned data key.

    Raises:
        ValueError: If the Part carries no inline_data.
        RuntimeError: If any S3 operation fails.
    """
    # Local import keeps the module's import block untouched.
    from datetime import timezone

    try:
        # Extract blob data
        data, mime_type = self._serialize_blob(artifact)

        # Load existing metadata (if any) to determine the next version.
        metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=metadata_key)
            metadata = json.loads(response['Body'].read().decode('utf-8'))
            version = len(metadata.get("versions", []))
        except ClientError as e:
            if e.response['Error']['Code'] != 'NoSuchKey':
                raise
            # Metadata doesn't exist, start a fresh document at version 0.
            metadata = {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "filename": filename,
                "versions": []
            }
            version = 0

        # Write the blob first so the metadata never references missing data.
        data_key = self._get_artifact_data_key(app_name, user_id, session_id, filename, version)
        self.s3_client.put_object(
            Bucket=self.bucket_name,
            Key=data_key,
            Body=data
        )

        # Record the new version. Fix: use a timezone-aware UTC timestamp;
        # datetime.utcnow() is deprecated and yields naive datetimes
        # (the sibling SQL service already uses aware UTC timestamps).
        metadata["versions"].append({
            "version": version,
            "mime_type": mime_type,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "data_key": data_key
        })

        # Persist the updated metadata document.
        self.s3_client.put_object(
            Bucket=self.bucket_name,
            Key=metadata_key,
            Body=json.dumps(metadata, indent=2).encode('utf-8')
        )

        return version
    except ClientError as e:
        raise RuntimeError(f"Failed to save artifact: {e}")
|
173
|
+
|
174
|
+
async def _load_artifact_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
    version: Optional[int] = None,
) -> Optional[types.Part]:
    """Load one version of an artifact (latest when *version* is None).

    Returns None when the artifact, the requested version, or its data
    object does not exist.
    """
    try:
        metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)

        # Fetch the metadata document; a missing key means "no artifact".
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=metadata_key)
        except ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return None
            raise
        metadata = json.loads(response['Body'].read().decode('utf-8'))

        versions = metadata.get("versions", [])
        if not versions:
            return None

        # Pick the requested version entry, or the most recent one.
        if version is None:
            version_info = versions[-1]
        else:
            version_info = next((v for v in versions if v["version"] == version), None)
            if not version_info:
                return None

        # Fetch the blob for that version.
        try:
            data_response = self.s3_client.get_object(
                Bucket=self.bucket_name,
                Key=version_info["data_key"]
            )
        except ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return None
            raise
        data = data_response['Body'].read()

        return self._deserialize_blob(data, version_info["mime_type"])
    except ClientError as e:
        raise RuntimeError(f"Failed to load artifact: {e}")
|
232
|
+
|
233
|
+
async def _list_artifact_keys_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
) -> List[str]:
    """List the filenames of all artifacts stored for a session.

    Raises:
        RuntimeError: If the S3 listing fails.
    """
    try:
        prefix = f"{self.prefix}/{app_name}/{user_id}/{session_id}/"
        artifact_keys = []

        # Fix: list_objects_v2 returns at most 1000 keys per call, so a
        # single request silently truncated large sessions; paginate to
        # cover every key under the prefix.
        paginator = self.s3_client.get_paginator('list_objects_v2')
        for page in paginator.paginate(
            Bucket=self.bucket_name,
            Prefix=prefix,
            Delimiter='/'
        ):
            for obj in page.get('Contents', []):
                key = obj['Key']
                # Metadata documents end in '.json'; per-version data
                # objects end in '.data' and are skipped here.
                if key.endswith('.json') and key.startswith(prefix):
                    # Strip prefix and '.json' suffix to recover the filename.
                    artifact_keys.append(key[len(prefix):-5])

        return artifact_keys
    except ClientError as e:
        raise RuntimeError(f"Failed to list artifact keys: {e}")
|
263
|
+
|
264
|
+
async def _delete_artifact_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
) -> None:
    """Delete every version of an artifact plus its metadata document.

    A missing artifact is treated as a no-op.
    """
    try:
        metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)

        try:
            # The metadata document enumerates every version's data key.
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=metadata_key)
            metadata = json.loads(response['Body'].read().decode('utf-8'))

            # Remove each version's data object first, then the metadata.
            for version_info in metadata.get("versions", []):
                self.s3_client.delete_object(
                    Bucket=self.bucket_name,
                    Key=version_info["data_key"]
                )

            self.s3_client.delete_object(
                Bucket=self.bucket_name,
                Key=metadata_key
            )
        except ClientError as e:
            # Nothing to delete if the metadata document never existed.
            if e.response['Error']['Code'] != 'NoSuchKey':
                raise
    except ClientError as e:
        raise RuntimeError(f"Failed to delete artifact: {e}")
|
298
|
+
|
299
|
+
async def _list_versions_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
) -> List[int]:
    """Return all stored version numbers for an artifact (empty if none)."""
    try:
        metadata_key = self._get_artifact_key(app_name, user_id, session_id, filename)

        # A missing metadata document simply means no versions exist.
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=metadata_key)
        except ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchKey':
                return []
            raise

        metadata = json.loads(response['Body'].read().decode('utf-8'))
        return [entry["version"] for entry in metadata.get("versions", [])]
    except ClientError as e:
        raise RuntimeError(f"Failed to list versions: {e}")
|
@@ -0,0 +1,255 @@
|
|
1
|
+
"""SQL-based artifact service implementation using SQLAlchemy."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
from typing import Optional, List
|
5
|
+
from datetime import datetime, timezone
|
6
|
+
|
7
|
+
try:
|
8
|
+
from sqlalchemy import create_engine, Column, String, Text, Integer, DateTime, LargeBinary
|
9
|
+
from sqlalchemy.orm import declarative_base, sessionmaker
|
10
|
+
from sqlalchemy.exc import SQLAlchemyError
|
11
|
+
except ImportError:
|
12
|
+
raise ImportError(
|
13
|
+
"SQLAlchemy is required for SQLArtifactService. "
|
14
|
+
"Install it with: pip install sqlalchemy"
|
15
|
+
)
|
16
|
+
|
17
|
+
from google.genai import types
|
18
|
+
from .base_custom_artifact_service import BaseCustomArtifactService
|
19
|
+
|
20
|
+
|
21
|
+
# Use the modern declarative_base import
|
22
|
+
Base = declarative_base()
|
23
|
+
|
24
|
+
|
25
|
+
class SQLArtifactModel(Base):
    """SQLAlchemy model for storing artifacts.

    Each row holds one version of one artifact; the composite primary key
    (app_name, user_id, session_id, filename, version) keeps versions
    append-only per filename.
    """
    __tablename__ = 'adk_artifacts'

    # Composite primary key
    app_name = Column(String, primary_key=True)
    user_id = Column(String, primary_key=True)
    session_id = Column(String, primary_key=True)
    filename = Column(String, primary_key=True)
    version = Column(Integer, primary_key=True)  # 0-based, assigned by the service at save time

    # Artifact data
    mime_type = Column(String, nullable=False)
    data = Column(LargeBinary, nullable=False)  # Blob data
    # Timezone-aware creation timestamp, defaulted to UTC at insert time.
    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))

    # Metadata
    metadata_json = Column(Text, nullable=True)  # Additional metadata as JSON (not populated by this service's visible code)
|
43
|
+
|
44
|
+
|
45
|
+
class SQLArtifactService(BaseCustomArtifactService):
|
46
|
+
"""SQL-based artifact service implementation."""
|
47
|
+
|
48
|
+
def __init__(self, database_url: str):
    """Initialize the SQL artifact service.

    Args:
        database_url: Database connection string (e.g., 'sqlite:///artifacts.db')
    """
    super().__init__()
    self.database_url = database_url
    # Engine and session factory are created lazily in _initialize_impl().
    self.engine: Optional[object] = None
    self.session_local: Optional[object] = None
|
58
|
+
|
59
|
+
async def _initialize_impl(self) -> None:
    """Create the engine and session factory, and ensure tables exist.

    Raises:
        RuntimeError: If the database cannot be reached or the schema
            cannot be created.
    """
    try:
        self.engine = create_engine(self.database_url)
        # Idempotent: creates the adk_artifacts table only if missing.
        Base.metadata.create_all(self.engine)
        self.session_local = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
    except SQLAlchemyError as e:
        raise RuntimeError(f"Failed to initialize SQL artifact service: {e}")
|
71
|
+
|
72
|
+
async def _cleanup_impl(self) -> None:
    """Dispose of the engine's connection pool and drop the factory."""
    if self.engine is None:
        return
    self.engine.dispose()
    self.engine = None
    self.session_local = None
|
78
|
+
|
79
|
+
def _get_db_session(self):
    """Return a fresh SQLAlchemy session; the caller must close it.

    Raises:
        RuntimeError: If called before the service is initialized.
    """
    factory = self.session_local
    if not factory:
        raise RuntimeError("Service not initialized")
    return factory()
|
84
|
+
|
85
|
+
def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
    """Extract raw bytes and a mime type from a Part.

    Only inline_data parts are handled; other Part flavors would need
    dedicated serialization and are rejected.

    Raises:
        ValueError: If the Part carries anything other than inline_data.
    """
    blob = part.inline_data
    if not blob:
        raise ValueError("Only inline_data parts are supported")
    return blob.data, blob.mime_type or "application/octet-stream"
|
94
|
+
|
95
|
+
def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
    """Wrap raw bytes and a mime type back into an inline-data Part."""
    return types.Part(inline_data=types.Blob(data=data, mime_type=mime_type))
|
99
|
+
|
100
|
+
async def _save_artifact_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
    artifact: types.Part,
) -> int:
    """Insert a new version row for the artifact and return its version.

    Raises:
        ValueError: If the Part carries no inline_data.
        RuntimeError: If the database operation fails.
    """
    db_session = self._get_db_session()
    try:
        # Extract blob data
        data, mime_type = self._serialize_blob(artifact)

        # Next version = latest existing version + 1 (or 0 for a new file).
        # Improvement: query only the version column — loading the full
        # row pulled the (potentially large) blob just to read an integer.
        latest = db_session.query(SQLArtifactModel.version).filter(
            SQLArtifactModel.app_name == app_name,
            SQLArtifactModel.user_id == user_id,
            SQLArtifactModel.session_id == session_id,
            SQLArtifactModel.filename == filename
        ).order_by(SQLArtifactModel.version.desc()).first()

        version = (latest[0] + 1) if latest else 0

        # Create and persist the new version row.
        db_artifact = SQLArtifactModel(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            version=version,
            mime_type=mime_type,
            data=data
        )
        db_session.add(db_artifact)
        db_session.commit()

        return version
    except SQLAlchemyError as e:
        db_session.rollback()
        raise RuntimeError(f"Failed to save artifact: {e}")
    finally:
        db_session.close()
|
146
|
+
|
147
|
+
async def _load_artifact_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
    version: Optional[int] = None,
) -> Optional[types.Part]:
    """Load one version of an artifact (latest when *version* is None).

    Returns None when no matching row exists.
    """
    db_session = self._get_db_session()
    try:
        query = db_session.query(SQLArtifactModel).filter(
            SQLArtifactModel.app_name == app_name,
            SQLArtifactModel.user_id == user_id,
            SQLArtifactModel.session_id == session_id,
            SQLArtifactModel.filename == filename
        )

        # Narrow to the requested version, or sort newest-first so that
        # .first() yields the latest one.
        if version is not None:
            query = query.filter(SQLArtifactModel.version == version)
        else:
            query = query.order_by(SQLArtifactModel.version.desc())

        row = query.first()
        if not row:
            return None

        return self._deserialize_blob(row.data, row.mime_type)
    except SQLAlchemyError as e:
        raise RuntimeError(f"Failed to load artifact: {e}")
    finally:
        db_session.close()
|
183
|
+
|
184
|
+
async def _list_artifact_keys_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
) -> List[str]:
    """Return the distinct artifact filenames stored for a session."""
    db_session = self._get_db_session()
    try:
        rows = (
            db_session.query(SQLArtifactModel.filename)
            .filter(
                SQLArtifactModel.app_name == app_name,
                SQLArtifactModel.user_id == user_id,
                SQLArtifactModel.session_id == session_id,
            )
            .distinct()
            .all()
        )
        # Each row is a one-element tuple containing the filename.
        return [row[0] for row in rows]
    except SQLAlchemyError as e:
        raise RuntimeError(f"Failed to list artifact keys: {e}")
    finally:
        db_session.close()
|
206
|
+
|
207
|
+
async def _delete_artifact_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
) -> None:
    """Delete every version of an artifact; a no-op if none exist."""
    db_session = self._get_db_session()
    try:
        # Bulk-delete all version rows for this filename in one statement.
        (
            db_session.query(SQLArtifactModel)
            .filter(
                SQLArtifactModel.app_name == app_name,
                SQLArtifactModel.user_id == user_id,
                SQLArtifactModel.session_id == session_id,
                SQLArtifactModel.filename == filename,
            )
            .delete()
        )
        db_session.commit()
    except SQLAlchemyError as e:
        db_session.rollback()
        raise RuntimeError(f"Failed to delete artifact: {e}")
    finally:
        db_session.close()
|
232
|
+
|
233
|
+
async def _list_versions_impl(
    self,
    *,
    app_name: str,
    user_id: str,
    session_id: str,
    filename: str,
) -> List[int]:
    """Return all version numbers for an artifact, oldest first."""
    db_session = self._get_db_session()
    try:
        rows = (
            db_session.query(SQLArtifactModel.version)
            .filter(
                SQLArtifactModel.app_name == app_name,
                SQLArtifactModel.user_id == user_id,
                SQLArtifactModel.session_id == session_id,
                SQLArtifactModel.filename == filename,
            )
            .order_by(SQLArtifactModel.version.asc())
            .all()
        )
        # Each row is a one-element tuple containing the version integer.
        return [row[0] for row in rows]
    except SQLAlchemyError as e:
        raise RuntimeError(f"Failed to list versions: {e}")
    finally:
        db_session.close()
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""Custom ADK memory services package."""
|
2
|
+
|
3
|
+
from .base_custom_memory_service import BaseCustomMemoryService
|
4
|
+
from .sql_memory_service import SQLMemoryService
|
5
|
+
from .mongo_memory_service import MongoMemoryService
|
6
|
+
from .redis_memory_service import RedisMemoryService
|
7
|
+
from .yaml_file_memory_service import YamlFileMemoryService
|
8
|
+
|
9
|
+
__all__ = [
|
10
|
+
"BaseCustomMemoryService",
|
11
|
+
"SQLMemoryService",
|
12
|
+
"MongoMemoryService",
|
13
|
+
"RedisMemoryService",
|
14
|
+
"YamlFileMemoryService",
|
15
|
+
]
|
@@ -0,0 +1,90 @@
|
|
1
|
+
"""Base class for custom memory services."""
|
2
|
+
|
3
|
+
from abc import abstractmethod
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
from google.adk.memory.base_memory_service import BaseMemoryService
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from google.adk.sessions.session import Session
|
10
|
+
|
11
|
+
|
12
|
+
class BaseCustomMemoryService(BaseMemoryService):
    """Base class for custom memory services with common functionality.

    Subclasses implement the ``_*_impl`` hooks; the public methods ensure
    the backend is lazily initialized before delegating to them.
    """

    def __init__(self):
        """Initialize the base custom memory service."""
        super().__init__()
        # Becomes True once _initialize_impl() has completed successfully.
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the memory service (idempotent)."""
        if self._initialized:
            return
        await self._initialize_impl()
        self._initialized = True

    async def cleanup(self) -> None:
        """Clean up the memory service (no-op if never initialized)."""
        if self._initialized:
            await self._cleanup_impl()
            self._initialized = False

    async def add_session_to_memory(self, session: "Session") -> None:
        """Add a session to the memory service.

        Args:
            session: The session to add.
        """
        await self.initialize()
        await self._add_session_to_memory_impl(session)

    async def search_memory(
        self, *, app_name: str, user_id: str, query: str
    ) -> "SearchMemoryResponse":
        """Search for sessions that match the query.

        Args:
            app_name: The name of the application.
            user_id: The id of the user.
            query: The query to search for.

        Returns:
            A SearchMemoryResponse containing the matching memories.
        """
        await self.initialize()
        return await self._search_memory_impl(
            app_name=app_name, user_id=user_id, query=query
        )

    @abstractmethod
    async def _add_session_to_memory_impl(self, session: "Session") -> None:
        """Implementation of adding a session to memory.

        Args:
            session: The session to add to memory.
        """

    @abstractmethod
    async def _search_memory_impl(
        self, *, app_name: str, user_id: str, query: str
    ) -> "SearchMemoryResponse":
        """Implementation of searching memory.

        Args:
            app_name: The name of the application.
            user_id: The id of the user.
            query: The query to search for.

        Returns:
            A SearchMemoryResponse containing the matching memories.
        """

    @abstractmethod
    async def _initialize_impl(self) -> None:
        """Implementation of initialization."""

    @abstractmethod
    async def _cleanup_impl(self) -> None:
        """Implementation of cleanup."""
|