google-adk-extras 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +1 -0
- google_adk_extras/artifacts/__init__.py +15 -0
- google_adk_extras/artifacts/base_custom_artifact_service.py +207 -0
- google_adk_extras/artifacts/local_folder_artifact_service.py +242 -0
- google_adk_extras/artifacts/mongo_artifact_service.py +221 -0
- google_adk_extras/artifacts/s3_artifact_service.py +323 -0
- google_adk_extras/artifacts/sql_artifact_service.py +255 -0
- google_adk_extras/memory/__init__.py +15 -0
- google_adk_extras/memory/base_custom_memory_service.py +90 -0
- google_adk_extras/memory/mongo_memory_service.py +174 -0
- google_adk_extras/memory/redis_memory_service.py +188 -0
- google_adk_extras/memory/sql_memory_service.py +213 -0
- google_adk_extras/memory/yaml_file_memory_service.py +176 -0
- google_adk_extras/py.typed +0 -0
- google_adk_extras/sessions/__init__.py +13 -0
- google_adk_extras/sessions/base_custom_session_service.py +183 -0
- google_adk_extras/sessions/mongo_session_service.py +243 -0
- google_adk_extras/sessions/redis_session_service.py +271 -0
- google_adk_extras/sessions/sql_session_service.py +308 -0
- google_adk_extras/sessions/yaml_file_session_service.py +245 -0
- google_adk_extras-0.1.1.dist-info/METADATA +175 -0
- google_adk_extras-0.1.1.dist-info/RECORD +25 -0
- google_adk_extras-0.1.1.dist-info/WHEEL +5 -0
- google_adk_extras-0.1.1.dist-info/licenses/LICENSE +202 -0
- google_adk_extras-0.1.1.dist-info/top_level.txt +1 -0
google_adk_extras/__init__.py
@@ -0,0 +1 @@
"""Custom ADK Services - Extended implementations of Google ADK services."""
google_adk_extras/artifacts/__init__.py
@@ -0,0 +1,15 @@
"""Custom artifact service implementations for Google ADK."""

from .base_custom_artifact_service import BaseCustomArtifactService
from .sql_artifact_service import SQLArtifactService
from .mongo_artifact_service import MongoArtifactService
from .local_folder_artifact_service import LocalFolderArtifactService
from .s3_artifact_service import S3ArtifactService

__all__ = [
    "BaseCustomArtifactService",
    "SQLArtifactService",
    "MongoArtifactService",
    "LocalFolderArtifactService",
    "S3ArtifactService",
]
google_adk_extras/artifacts/base_custom_artifact_service.py
@@ -0,0 +1,207 @@
"""Base class for custom artifact services."""

import abc
from typing import Optional, List

from google.adk.artifacts.base_artifact_service import BaseArtifactService
from google.genai import types


class BaseCustomArtifactService(BaseArtifactService, abc.ABC):
    """Base class for custom artifact services with common functionality."""

    def __init__(self):
        """Initialize the base custom artifact service."""
        super().__init__()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the artifact service.

        This method should be called before using the service to ensure
        any required setup (database connections, etc.) is complete.
        """
        if not self._initialized:
            await self._initialize_impl()
            self._initialized = True

    @abc.abstractmethod
    async def _initialize_impl(self) -> None:
        """Implementation of service initialization.

        This method should handle any setup required for the service to function,
        such as database connections, creating tables, directories, etc.
        """
        pass

    async def cleanup(self) -> None:
        """Clean up resources used by the artifact service.

        This method should be called when the service is no longer needed
        to ensure proper cleanup of resources.
        """
        if self._initialized:
            await self._cleanup_impl()
            self._initialized = False

    @abc.abstractmethod
    async def _cleanup_impl(self) -> None:
        """Implementation of service cleanup.

        This method should handle any cleanup required for the service,
        such as closing database connections.
        """
        pass

    async def save_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Save an artifact."""
        if not self._initialized:
            await self.initialize()
        return await self._save_artifact_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            artifact=artifact,
        )

    async def load_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Load an artifact."""
        if not self._initialized:
            await self.initialize()
        return await self._load_artifact_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            version=version,
        )

    async def list_artifact_keys(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """List artifact keys."""
        if not self._initialized:
            await self.initialize()
        return await self._list_artifact_keys_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
        )

    async def delete_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Delete an artifact."""
        if not self._initialized:
            await self.initialize()
        await self._delete_artifact_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
        )

    async def list_versions(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """List versions of an artifact."""
        if not self._initialized:
            await self.initialize()
        return await self._list_versions_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
        )

    @abc.abstractmethod
    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Implementation of artifact saving."""
        pass

    @abc.abstractmethod
    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Implementation of artifact loading."""
        pass

    @abc.abstractmethod
    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Implementation of artifact key listing."""
        pass

    @abc.abstractmethod
    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Implementation of artifact deletion."""
        pass

    @abc.abstractmethod
    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Implementation of version listing."""
        pass
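The base class follows a template-method pattern: the public methods (save_artifact, load_artifact, and friends) lazily call initialize() on first use and then delegate to the abstract _*_impl hooks that each backend fills in. The sketch below is illustrative only and is not part of the package: a hypothetical InMemoryArtifactService that satisfies the abstract interface with a plain dict, assuming only the base class shown above and google.genai types.

# Illustrative only: a minimal in-memory subclass showing how the _impl hooks
# are filled in. This class is NOT part of google-adk-extras.
from typing import Dict, List, Tuple

from google.genai import types

from google_adk_extras.artifacts import BaseCustomArtifactService


class InMemoryArtifactService(BaseCustomArtifactService):
    """Stores artifact versions in a dict keyed by (app, user, session, filename)."""

    def __init__(self) -> None:
        super().__init__()
        self._store: Dict[Tuple[str, str, str, str], List[types.Part]] = {}

    async def _initialize_impl(self) -> None:
        pass  # nothing to set up for an in-memory store

    async def _cleanup_impl(self) -> None:
        self._store.clear()

    async def _save_artifact_impl(self, *, app_name, user_id, session_id, filename, artifact) -> int:
        versions = self._store.setdefault((app_name, user_id, session_id, filename), [])
        versions.append(artifact)
        return len(versions) - 1  # version numbers start at 0, as in the bundled backends

    async def _load_artifact_impl(self, *, app_name, user_id, session_id, filename, version=None):
        versions = self._store.get((app_name, user_id, session_id, filename), [])
        if not versions:
            return None
        if version is None:
            return versions[-1]
        return versions[version] if 0 <= version < len(versions) else None

    async def _list_artifact_keys_impl(self, *, app_name, user_id, session_id) -> List[str]:
        return [key[3] for key in self._store if key[:3] == (app_name, user_id, session_id)]

    async def _delete_artifact_impl(self, *, app_name, user_id, session_id, filename) -> None:
        self._store.pop((app_name, user_id, session_id, filename), None)

    async def _list_versions_impl(self, *, app_name, user_id, session_id, filename) -> List[int]:
        return list(range(len(self._store.get((app_name, user_id, session_id, filename), []))))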
google_adk_extras/artifacts/local_folder_artifact_service.py
@@ -0,0 +1,242 @@
"""Local folder-based artifact service implementation."""

import os
import json
import base64
from typing import Optional, List
from pathlib import Path
from datetime import datetime, timezone

from google.genai import types
from .base_custom_artifact_service import BaseCustomArtifactService


class LocalFolderArtifactService(BaseCustomArtifactService):
    """Local folder-based artifact service implementation."""

    def __init__(self, base_directory: str = "./artifacts"):
        """Initialize the local folder artifact service.

        Args:
            base_directory: Base directory for storing artifacts
        """
        super().__init__()
        self.base_directory = Path(base_directory)
        # Create base directory if it doesn't exist
        self.base_directory.mkdir(parents=True, exist_ok=True)

    async def _initialize_impl(self) -> None:
        """Initialize the file system artifact service."""
        # Ensure base directory exists
        self.base_directory.mkdir(parents=True, exist_ok=True)

    async def _cleanup_impl(self) -> None:
        """Clean up resources (no cleanup needed for file-based service)."""
        pass

    def _get_artifact_directory(self, app_name: str, user_id: str, session_id: str) -> Path:
        """Generate directory path for artifacts."""
        directory = self.base_directory / app_name / user_id / session_id
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    def _get_artifact_file_path(self, app_name: str, user_id: str, session_id: str, filename: str) -> Path:
        """Generate file path for artifact metadata."""
        directory = self._get_artifact_directory(app_name, user_id, session_id)
        return directory / f"{filename}.json"

    def _get_artifact_data_path(self, app_name: str, user_id: str, session_id: str, filename: str, version: int) -> Path:
        """Generate file path for artifact data."""
        directory = self._get_artifact_directory(app_name, user_id, session_id)
        return directory / f"{filename}.v{version}.data"

    def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
        """Extract blob data and mime type from a Part."""
        if part.inline_data:
            return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
        else:
            raise ValueError("Only inline_data parts are supported")

    def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
        """Create a Part from blob data and mime type."""
        blob = types.Blob(data=data, mime_type=mime_type)
        return types.Part(inline_data=blob)

    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Implementation of artifact saving."""
        try:
            # Extract blob data
            data, mime_type = self._serialize_blob(artifact)

            # Get the next version number
            metadata_file = self._get_artifact_file_path(app_name, user_id, session_id, filename)

            if metadata_file.exists():
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)
                version = len(metadata.get("versions", []))
            else:
                metadata = {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "filename": filename,
                    "versions": []
                }
                version = 0

            # Save data to file
            data_file = self._get_artifact_data_path(app_name, user_id, session_id, filename, version)
            with open(data_file, 'wb') as f:
                f.write(data)

            # Update metadata
            version_info = {
                "version": version,
                "mime_type": mime_type,
                "created_at": datetime.now(timezone.utc).isoformat(),
                "data_file": data_file.name
            }
            metadata["versions"].append(version_info)

            # Save metadata
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

            return version
        except Exception as e:
            raise RuntimeError(f"Failed to save artifact: {e}")

    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Implementation of artifact loading."""
        try:
            # Load metadata
            metadata_file = self._get_artifact_file_path(app_name, user_id, session_id, filename)

            if not metadata_file.exists():
                return None

            with open(metadata_file, 'r') as f:
                metadata = json.load(f)

            # Determine version to load
            versions = metadata.get("versions", [])
            if not versions:
                return None

            if version is not None:
                # Find specific version
                version_info = None
                for v in versions:
                    if v["version"] == version:
                        version_info = v
                        break
                if not version_info:
                    return None
            else:
                # Load latest version
                version_info = versions[-1]

            # Load data
            data_file = self._get_artifact_directory(app_name, user_id, session_id) / version_info["data_file"]
            if not data_file.exists():
                return None

            with open(data_file, 'rb') as f:
                data = f.read()

            # Create Part from blob data
            return self._deserialize_blob(data, version_info["mime_type"])
        except Exception as e:
            raise RuntimeError(f"Failed to load artifact: {e}")

    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Implementation of artifact key listing."""
        try:
            directory = self._get_artifact_directory(app_name, user_id, session_id)

            # Find all metadata files
            artifact_keys = []
            if directory.exists():
                for file_path in directory.glob("*.json"):
                    # Extract filename from metadata file name (remove .json extension)
                    filename = file_path.name[:-5]  # Remove .json
                    artifact_keys.append(filename)

            return artifact_keys
        except Exception as e:
            raise RuntimeError(f"Failed to list artifact keys: {e}")

    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Implementation of artifact deletion."""
        try:
            # Load metadata to find all version files
            metadata_file = self._get_artifact_file_path(app_name, user_id, session_id, filename)

            if metadata_file.exists():
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)

                # Delete all version data files
                directory = self._get_artifact_directory(app_name, user_id, session_id)
                for version_info in metadata.get("versions", []):
                    data_file = directory / version_info["data_file"]
                    if data_file.exists():
                        data_file.unlink()

                # Delete metadata file
                metadata_file.unlink()
        except Exception as e:
            raise RuntimeError(f"Failed to delete artifact: {e}")

    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Implementation of version listing."""
        try:
            metadata_file = self._get_artifact_file_path(app_name, user_id, session_id, filename)

            if not metadata_file.exists():
                return []

            with open(metadata_file, 'r') as f:
                metadata = json.load(f)

            versions = metadata.get("versions", [])
            return [v["version"] for v in versions]
        except Exception as e:
            raise RuntimeError(f"Failed to list versions: {e}")
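For reference, a minimal round trip against the local-folder backend might look like the following. This is an illustrative sketch, not code from the package; it assumes the package is importable and the working directory is writable, and the app/user/session names and the greeting.txt filename are arbitrary.

# Illustrative usage sketch for LocalFolderArtifactService.
import asyncio

from google.genai import types

from google_adk_extras.artifacts import LocalFolderArtifactService


async def main() -> None:
    service = LocalFolderArtifactService(base_directory="./artifacts")
    part = types.Part(
        inline_data=types.Blob(data=b"hello world", mime_type="text/plain")
    )
    # save_artifact() initializes the service on first use and returns the new version (0 here).
    version = await service.save_artifact(
        app_name="demo_app", user_id="u1", session_id="s1",
        filename="greeting.txt", artifact=part,
    )
    loaded = await service.load_artifact(
        app_name="demo_app", user_id="u1", session_id="s1",
        filename="greeting.txt", version=version,
    )
    if loaded and loaded.inline_data:
        print(version, loaded.inline_data.data)
    await service.cleanup()


if __name__ == "__main__":
    asyncio.run(main())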
google_adk_extras/artifacts/mongo_artifact_service.py
@@ -0,0 +1,221 @@
"""MongoDB-based artifact service implementation."""

import json
from typing import Optional, List
from datetime import datetime

try:
    from pymongo import MongoClient
    from pymongo.errors import PyMongoError
except ImportError:
    raise ImportError(
        "PyMongo is required for MongoArtifactService. "
        "Install it with: pip install pymongo"
    )

from google.genai import types
from .base_custom_artifact_service import BaseCustomArtifactService


class MongoArtifactService(BaseCustomArtifactService):
    """MongoDB-based artifact service implementation."""

    def __init__(self, connection_string: str, database_name: str = "adk_artifacts"):
        """Initialize the MongoDB artifact service.

        Args:
            connection_string: MongoDB connection string
            database_name: Name of the database to use
        """
        super().__init__()
        self.connection_string = connection_string
        self.database_name = database_name
        self.client: Optional[MongoClient] = None
        self.db = None
        self.collection = None

    async def _initialize_impl(self) -> None:
        """Initialize the MongoDB connection."""
        try:
            self.client = MongoClient(self.connection_string)
            self.db = self.client[self.database_name]
            self.collection = self.db.artifacts

            # Create indexes for better performance
            self.collection.create_index([
                ("app_name", 1),
                ("user_id", 1),
                ("session_id", 1),
                ("filename", 1),
                ("version", 1)
            ])
        except PyMongoError as e:
            raise RuntimeError(f"Failed to initialize MongoDB artifact service: {e}")

    async def _cleanup_impl(self) -> None:
        """Clean up MongoDB connections."""
        if self.client:
            self.client.close()
            self.client = None
            self.db = None
            self.collection = None

    def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
        """Extract blob data and mime type from a Part."""
        if part.inline_data:
            return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
        else:
            raise ValueError("Only inline_data parts are supported")

    def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
        """Create a Part from blob data and mime type."""
        blob = types.Blob(data=data, mime_type=mime_type)
        return types.Part(inline_data=blob)

    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Implementation of artifact saving."""
        try:
            # Extract blob data
            data, mime_type = self._serialize_blob(artifact)

            # Get the next version number
            latest_version_doc = self.collection.find_one(
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "filename": filename
                },
                sort=[("version", -1)]
            )

            version = (latest_version_doc["version"] + 1) if latest_version_doc else 0

            # Create document
            document = {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "filename": filename,
                "version": version,
                "mime_type": mime_type,
                "data": data,
                "created_at": datetime.utcnow()
            }

            # Insert into MongoDB
            self.collection.insert_one(document)

            return version
        except PyMongoError as e:
            raise RuntimeError(f"Failed to save artifact: {e}")

    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Implementation of artifact loading."""
        try:
            query = {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "filename": filename
            }

            if version is not None:
                query["version"] = version
                sort = None
            else:
                # Sort by version descending to get the latest
                sort = [("version", -1)]

            document = self.collection.find_one(query, sort=sort)

            if not document:
                return None

            # Create Part from blob data
            return self._deserialize_blob(document["data"], document["mime_type"])
        except PyMongoError as e:
            raise RuntimeError(f"Failed to load artifact: {e}")

    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Implementation of artifact key listing."""
        try:
            # Get distinct filenames
            cursor = self.collection.distinct(
                "filename",
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id
                }
            )

            return list(cursor)
        except PyMongoError as e:
            raise RuntimeError(f"Failed to list artifact keys: {e}")

    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Implementation of artifact deletion."""
        try:
            # Delete all versions of the artifact
            self.collection.delete_many({
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "filename": filename
            })
        except PyMongoError as e:
            raise RuntimeError(f"Failed to delete artifact: {e}")

    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Implementation of version listing."""
        try:
            cursor = self.collection.find(
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "filename": filename
                },
                {"version": 1, "_id": 0}
            ).sort("version", 1)

            return [doc["version"] for doc in cursor]
        except PyMongoError as e:
            raise RuntimeError(f"Failed to list versions: {e}")
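Usage of the MongoDB backend is analogous. The sketch below is illustrative only: the connection string is a placeholder, the database name is the default from the constructor above, and a reachable MongoDB instance plus an installed pymongo are assumed.

# Illustrative usage sketch for MongoArtifactService (placeholder connection string).
import asyncio

from google_adk_extras.artifacts import MongoArtifactService


async def main() -> None:
    service = MongoArtifactService(
        connection_string="mongodb://localhost:27017",
        database_name="adk_artifacts",
    )
    await service.initialize()
    # Each saved version is stored as a separate document in the `artifacts` collection.
    versions = await service.list_versions(
        app_name="demo_app", user_id="u1", session_id="s1", filename="greeting.txt",
    )
    print(versions)
    await service.cleanup()


if __name__ == "__main__":
    asyncio.run(main())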