google-adk-extras 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ """Custom ADK Services - Extended implementations of Google ADK services."""
@@ -0,0 +1,15 @@
1
+ """Custom artifact service implementations for Google ADK."""
2
+
3
+ from .base_custom_artifact_service import BaseCustomArtifactService
4
+ from .sql_artifact_service import SQLArtifactService
5
+ from .mongo_artifact_service import MongoArtifactService
6
+ from .local_folder_artifact_service import LocalFolderArtifactService
7
+ from .s3_artifact_service import S3ArtifactService
8
+
9
# Public API of the artifacts subpackage: the shared base class followed by
# each concrete storage backend imported above.
__all__ = [
    "BaseCustomArtifactService",
    "SQLArtifactService",
    "MongoArtifactService",
    "LocalFolderArtifactService",
    "S3ArtifactService",
]
@@ -0,0 +1,207 @@
1
+ """Base class for custom artifact services."""
2
+
3
+ import abc
4
+ from typing import Optional, List
5
+
6
+ from google.adk.artifacts.base_artifact_service import BaseArtifactService
7
+ from google.genai import types
8
+
9
+
10
class BaseCustomArtifactService(BaseArtifactService, abc.ABC):
    """Shared scaffolding for the custom artifact services in this package.

    Concrete services implement ``_initialize_impl``/``_cleanup_impl`` and the
    storage-specific ``_*_impl`` coroutines. The public ``BaseArtifactService``
    coroutines defined here run ``initialize()`` lazily on first use and then
    forward their keyword arguments unchanged to the matching ``_*_impl``.
    """

    def __init__(self):
        """Create the service in the uninitialized state."""
        super().__init__()
        # Set by initialize(), cleared again by cleanup().
        self._initialized = False

    async def _ensure_initialized(self) -> None:
        """Call ``initialize()`` unless the service is already initialized."""
        if self._initialized:
            return
        await self.initialize()

    async def initialize(self) -> None:
        """Run ``_initialize_impl`` once and mark the service initialized.

        Further calls are no-ops while the service stays initialized. If
        ``_initialize_impl`` raises, the flag stays unset, so a later call
        tries again.
        """
        if self._initialized:
            return
        await self._initialize_impl()
        self._initialized = True

    @abc.abstractmethod
    async def _initialize_impl(self) -> None:
        """Perform backend setup (connections, tables, directories, ...)."""

    async def cleanup(self) -> None:
        """Run ``_cleanup_impl`` and mark the service uninitialized.

        Does nothing when the service was never initialized or has already
        been cleaned up.
        """
        if not self._initialized:
            return
        await self._cleanup_impl()
        self._initialized = False

    @abc.abstractmethod
    async def _cleanup_impl(self) -> None:
        """Release backend resources (e.g. close database connections)."""

    async def save_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Save ``artifact``; returns the version number reported by the backend."""
        await self._ensure_initialized()
        return await self._save_artifact_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            artifact=artifact,
        )

    async def load_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Load one version of an artifact (``version=None`` defers to the backend)."""
        await self._ensure_initialized()
        return await self._load_artifact_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            version=version,
        )

    async def list_artifact_keys(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Return the artifact filenames stored for one session."""
        await self._ensure_initialized()
        return await self._list_artifact_keys_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
        )

    async def delete_artifact(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Delete an artifact via the backend's ``_delete_artifact_impl``."""
        await self._ensure_initialized()
        await self._delete_artifact_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
        )

    async def list_versions(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Return the version numbers the backend holds for one artifact."""
        await self._ensure_initialized()
        return await self._list_versions_impl(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
        )

    @abc.abstractmethod
    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Backend-specific save; must return the new version number."""

    @abc.abstractmethod
    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Backend-specific load; returns ``None`` when nothing matches."""

    @abc.abstractmethod
    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Backend-specific listing of artifact filenames for a session."""

    @abc.abstractmethod
    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Backend-specific deletion of an artifact."""

    @abc.abstractmethod
    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Backend-specific listing of an artifact's version numbers."""
@@ -0,0 +1,242 @@
1
+ """Local folder-based artifact service implementation."""
2
+
3
+ import os
4
+ import json
5
+ import base64
6
+ from typing import Optional, List
7
+ from pathlib import Path
8
+ from datetime import datetime, timezone
9
+
10
+ from google.genai import types
11
+ from .base_custom_artifact_service import BaseCustomArtifactService
12
+
13
+
14
class LocalFolderArtifactService(BaseCustomArtifactService):
    """Artifact service that stores artifacts as plain files under a local folder.

    On-disk layout, relative to ``base_directory``::

        <app_name>/<user_id>/<session_id>/<filename>.json         metadata
        <app_name>/<user_id>/<session_id>/<filename>.v<N>.data    bytes of version N

    The metadata JSON holds the identifying fields plus a ``versions`` list.
    Each entry records the version number, MIME type, UTC creation time and
    the name of its data file. Every path is checked so that no id or
    filename can resolve outside ``base_directory``. Failures inside the
    ``_*_impl`` methods surface as ``RuntimeError``, with the original
    exception chained as ``__cause__``.
    """

    def __init__(self, base_directory: str = "./artifacts"):
        """Initialize the local folder artifact service.

        Args:
            base_directory: Root folder for all artifacts; created if missing.
        """
        super().__init__()
        self.base_directory = Path(base_directory)
        # Create base directory if it doesn't exist
        self.base_directory.mkdir(parents=True, exist_ok=True)

    async def _initialize_impl(self) -> None:
        """Ensure the base directory exists (it may have been removed since __init__)."""
        self.base_directory.mkdir(parents=True, exist_ok=True)

    async def _cleanup_impl(self) -> None:
        """Clean up resources (no cleanup needed for file-based service)."""

    @staticmethod
    def _join_inside(root: Path, *parts: str) -> Path:
        """Join ``parts`` onto ``root``, refusing any result outside ``root``.

        The check is lexical: ``os.path.abspath`` normalizes ``..`` segments,
        and an absolute component replaces everything before it. As a result,
        ``..`` and absolute ids or filenames are caught, while symlinks inside
        ``root`` are left alone.

        Raises:
            ValueError: If the joined path is neither ``root`` nor a descendant of it.
        """
        candidate = root.joinpath(*parts)
        root_abs = os.path.abspath(root)
        if os.path.commonpath([root_abs, os.path.abspath(candidate)]) != root_abs:
            raise ValueError(
                f"Artifact path {str(candidate)!r} escapes {str(root)!r}"
            )
        return candidate

    def _get_artifact_directory(
        self, app_name: str, user_id: str, session_id: str, create: bool = False
    ) -> Path:
        """Return the folder holding one session's artifacts.

        Args:
            create: Create the folder (and parents) when missing. Only saving
                needs this; read-only operations must not touch the filesystem.
        """
        directory = self._join_inside(
            self.base_directory, app_name, user_id, session_id
        )
        if create:
            directory.mkdir(parents=True, exist_ok=True)
        return directory

    def _get_artifact_file_path(
        self, app_name: str, user_id: str, session_id: str, filename: str
    ) -> Path:
        """Generate file path for artifact metadata (``<filename>.json``)."""
        directory = self._get_artifact_directory(app_name, user_id, session_id)
        return self._join_inside(directory, f"{filename}.json")

    def _get_artifact_data_path(
        self, app_name: str, user_id: str, session_id: str, filename: str, version: int
    ) -> Path:
        """Generate file path for artifact data (``<filename>.v<version>.data``)."""
        directory = self._get_artifact_directory(app_name, user_id, session_id)
        return self._join_inside(directory, f"{filename}.v{version}.data")

    @staticmethod
    def _read_json(path: Path) -> dict:
        """Load a metadata JSON file."""
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)

    @staticmethod
    def _write_json_atomic(path: Path, payload: dict) -> None:
        """Write ``payload`` as JSON to a sibling temp file, then ``os.replace`` it.

        The final step is a rename, so a crash mid-write leaves the previous
        metadata intact rather than a truncated file. The ``.tmp`` suffix keeps
        the temp file out of the ``*.json`` glob used for key listing.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(payload, fh, indent=2)
        os.replace(tmp_path, path)

    def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
        """Extract blob data and mime type from a Part."""
        if part.inline_data:
            return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
        else:
            raise ValueError("Only inline_data parts are supported")

    def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
        """Create a Part from blob data and mime type."""
        blob = types.Blob(data=data, mime_type=mime_type)
        return types.Part(inline_data=blob)

    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Store ``artifact`` as a new version and return its version number.

        Versions start at 0 and are one greater than the highest recorded
        version. The data file is written before the metadata, so the metadata
        never references a data file that does not exist yet.

        Raises:
            RuntimeError: If the part has no inline data, a path is invalid,
                or any filesystem operation fails.
        """
        try:
            data, mime_type = self._serialize_blob(artifact)

            # Saving is the only operation that may create the session folder.
            self._get_artifact_directory(app_name, user_id, session_id, create=True)
            metadata_file = self._get_artifact_file_path(
                app_name, user_id, session_id, filename
            )

            if metadata_file.exists():
                metadata = self._read_json(metadata_file)
                metadata.setdefault("versions", [])
            else:
                metadata = {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "filename": filename,
                    "versions": [],
                }

            version = max(
                (v["version"] for v in metadata["versions"]), default=-1
            ) + 1

            data_file = self._get_artifact_data_path(
                app_name, user_id, session_id, filename, version
            )
            with open(data_file, "wb") as f:
                f.write(data)

            metadata["versions"].append(
                {
                    "version": version,
                    "mime_type": mime_type,
                    "created_at": datetime.now(timezone.utc).isoformat(),
                    "data_file": data_file.name,
                }
            )
            self._write_json_atomic(metadata_file, metadata)

            return version
        except Exception as e:
            raise RuntimeError(f"Failed to save artifact: {e}") from e

    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Load a specific version, or the last recorded version when ``version`` is None.

        Returns:
            The artifact as an ``inline_data`` Part, or ``None`` when there is
            no metadata, no matching version, or the data file is missing.

        Raises:
            RuntimeError: If a path is invalid or reading/parsing fails.
        """
        try:
            metadata_file = self._get_artifact_file_path(
                app_name, user_id, session_id, filename
            )
            if not metadata_file.exists():
                return None

            versions = self._read_json(metadata_file).get("versions", [])
            if not versions:
                return None

            if version is None:
                version_info = versions[-1]
            else:
                version_info = next(
                    (v for v in versions if v["version"] == version), None
                )
                if version_info is None:
                    return None

            # data_file comes from the metadata, so it is confined to the
            # session folder just like caller-supplied names.
            directory = self._get_artifact_directory(app_name, user_id, session_id)
            data_file = self._join_inside(directory, version_info["data_file"])
            if not data_file.exists():
                return None

            return self._deserialize_blob(
                data_file.read_bytes(), version_info["mime_type"]
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load artifact: {e}") from e

    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Return the sorted artifact filenames (metadata file names minus ``.json``).

        Raises:
            RuntimeError: If a path is invalid or the folder cannot be read.
        """
        try:
            directory = self._get_artifact_directory(app_name, user_id, session_id)
            if not directory.is_dir():
                return []
            return sorted(
                path.name[: -len(".json")]
                for path in directory.glob("*.json")
                if path.is_file()
            )
        except Exception as e:
            raise RuntimeError(f"Failed to list artifact keys: {e}") from e

    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Delete every version's data file and then the metadata file.

        Does nothing when the artifact has no metadata file.

        Raises:
            RuntimeError: If a path is invalid or a file cannot be removed.
        """
        try:
            metadata_file = self._get_artifact_file_path(
                app_name, user_id, session_id, filename
            )
            if not metadata_file.exists():
                return

            metadata = self._read_json(metadata_file)
            directory = self._get_artifact_directory(app_name, user_id, session_id)
            for version_info in metadata.get("versions", []):
                self._join_inside(directory, version_info["data_file"]).unlink(
                    missing_ok=True
                )

            metadata_file.unlink()
        except Exception as e:
            raise RuntimeError(f"Failed to delete artifact: {e}") from e

    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Return the recorded version numbers in save order ([] if none).

        Raises:
            RuntimeError: If a path is invalid or the metadata cannot be read.
        """
        try:
            metadata_file = self._get_artifact_file_path(
                app_name, user_id, session_id, filename
            )
            if not metadata_file.exists():
                return []

            versions = self._read_json(metadata_file).get("versions", [])
            return [v["version"] for v in versions]
        except Exception as e:
            raise RuntimeError(f"Failed to list versions: {e}") from e
@@ -0,0 +1,221 @@
1
+ """MongoDB-based artifact service implementation."""
2
+
3
+ import json
4
+ from typing import Optional, List
5
+ from datetime import datetime
6
+
7
+ try:
8
+ from pymongo import MongoClient
9
+ from pymongo.errors import PyMongoError
10
+ except ImportError:
11
+ raise ImportError(
12
+ "PyMongo is required for MongoArtifactService. "
13
+ "Install it with: pip install pymongo"
14
+ )
15
+
16
+ from google.genai import types
17
+ from .base_custom_artifact_service import BaseCustomArtifactService
18
+
19
+
20
class MongoArtifactService(BaseCustomArtifactService):
    """MongoDB-based artifact service implementation.

    Each saved version is one document in the ``artifacts`` collection of
    ``database_name``. Its fields are ``app_name``, ``user_id``,
    ``session_id``, ``filename``, ``version``, ``mime_type``, ``data``
    (raw bytes) and ``created_at`` (UTC).

    Note: the synchronous PyMongo ``MongoClient`` is called directly inside
    these coroutines (no ``await``), so each database round trip blocks the
    event loop while it runs.
    """

    def __init__(self, connection_string: str, database_name: str = "adk_artifacts"):
        """Initialize the MongoDB artifact service.

        Args:
            connection_string: MongoDB connection string
            database_name: Name of the database to use
        """
        super().__init__()
        self.connection_string = connection_string
        self.database_name = database_name
        self.client: Optional[MongoClient] = None
        self.db = None
        self.collection = None

    async def _initialize_impl(self) -> None:
        """Connect, select the ``artifacts`` collection and ensure its index.

        Attributes are only assigned after setup succeeds. On failure, the
        freshly created client is closed instead of being left open.

        Raises:
            RuntimeError: If the client cannot be created or the index cannot
                be built.
        """
        client: Optional[MongoClient] = None
        try:
            client = MongoClient(self.connection_string)
            db = client[self.database_name]
            collection = db.artifacts

            # Compound index covering every query below: equality on the four
            # identifying fields, then ordering by version.
            collection.create_index([
                ("app_name", 1),
                ("user_id", 1),
                ("session_id", 1),
                ("filename", 1),
                ("version", 1),
            ])
        except PyMongoError as e:
            if client is not None:
                client.close()
            raise RuntimeError(
                f"Failed to initialize MongoDB artifact service: {e}"
            ) from e

        self.client = client
        self.db = db
        self.collection = collection

    async def _cleanup_impl(self) -> None:
        """Clean up MongoDB connections."""
        if self.client:
            self.client.close()
        self.client = None
        self.db = None
        self.collection = None

    @staticmethod
    def _session_filter(app_name: str, user_id: str, session_id: str) -> dict:
        """Build the query selecting all documents of one session."""
        return {"app_name": app_name, "user_id": user_id, "session_id": session_id}

    def _serialize_blob(self, part: types.Part) -> tuple[bytes, str]:
        """Extract blob data and mime type from a Part."""
        if part.inline_data:
            return part.inline_data.data, part.inline_data.mime_type or "application/octet-stream"
        else:
            raise ValueError("Only inline_data parts are supported")

    def _deserialize_blob(self, data: bytes, mime_type: str) -> types.Part:
        """Create a Part from blob data and mime type."""
        blob = types.Blob(data=data, mime_type=mime_type)
        return types.Part(inline_data=blob)

    async def _save_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        artifact: types.Part,
    ) -> int:
        """Insert ``artifact`` as a new version document and return its version.

        The version is one greater than the highest stored version, or 0 for
        the first save.

        Raises:
            ValueError: If the part has no ``inline_data``.
            RuntimeError: If a MongoDB operation fails.
        """
        # Local import: this module's header only imports ``datetime``.
        from datetime import timezone

        try:
            data, mime_type = self._serialize_blob(artifact)

            artifact_filter = {
                **self._session_filter(app_name, user_id, session_id),
                "filename": filename,
            }

            # NOTE(review): read-then-insert is not atomic; two concurrent
            # saves of the same artifact can pick the same version number.
            latest = self.collection.find_one(artifact_filter, sort=[("version", -1)])
            version = (latest["version"] + 1) if latest else 0

            self.collection.insert_one({
                **artifact_filter,
                "version": version,
                "mime_type": mime_type,
                "data": data,
                # Timezone-aware replacement for the deprecated utcnow();
                # BSON stores the same UTC instant either way.
                "created_at": datetime.now(timezone.utc),
            })

            return version
        except PyMongoError as e:
            raise RuntimeError(f"Failed to save artifact: {e}") from e

    async def _load_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
        version: Optional[int] = None,
    ) -> Optional[types.Part]:
        """Load a specific version, or the highest version when ``version`` is None.

        Returns:
            The artifact as an ``inline_data`` Part, or ``None`` if no
            document matches.

        Raises:
            RuntimeError: If the MongoDB query fails.
        """
        try:
            query = {
                **self._session_filter(app_name, user_id, session_id),
                "filename": filename,
            }

            if version is None:
                document = self.collection.find_one(query, sort=[("version", -1)])
            else:
                query["version"] = version
                document = self.collection.find_one(query)

            if document is None:
                return None

            return self._deserialize_blob(document["data"], document["mime_type"])
        except PyMongoError as e:
            raise RuntimeError(f"Failed to load artifact: {e}") from e

    async def _list_artifact_keys_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
    ) -> List[str]:
        """Return the sorted distinct filenames stored for one session.

        Raises:
            RuntimeError: If the MongoDB query fails.
        """
        try:
            return sorted(
                self.collection.distinct(
                    "filename", self._session_filter(app_name, user_id, session_id)
                )
            )
        except PyMongoError as e:
            raise RuntimeError(f"Failed to list artifact keys: {e}") from e

    async def _delete_artifact_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> None:
        """Delete every version document of the artifact.

        Raises:
            RuntimeError: If the MongoDB delete fails.
        """
        try:
            self.collection.delete_many({
                **self._session_filter(app_name, user_id, session_id),
                "filename": filename,
            })
        except PyMongoError as e:
            raise RuntimeError(f"Failed to delete artifact: {e}") from e

    async def _list_versions_impl(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        filename: str,
    ) -> List[int]:
        """Return the artifact's version numbers in ascending order.

        Raises:
            RuntimeError: If the MongoDB query fails.
        """
        try:
            cursor = self.collection.find(
                {
                    **self._session_filter(app_name, user_id, session_id),
                    "filename": filename,
                },
                {"version": 1, "_id": 0},
            ).sort("version", 1)

            return [doc["version"] for doc in cursor]
        except PyMongoError as e:
            raise RuntimeError(f"Failed to list versions: {e}") from e