google-adk-extras 0.1.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. google_adk_extras/__init__.py +31 -1
  2. google_adk_extras/adk_builder.py +1030 -0
  3. google_adk_extras/artifacts/__init__.py +25 -12
  4. google_adk_extras/artifacts/base_custom_artifact_service.py +148 -11
  5. google_adk_extras/artifacts/local_folder_artifact_service.py +133 -13
  6. google_adk_extras/artifacts/s3_artifact_service.py +135 -19
  7. google_adk_extras/artifacts/sql_artifact_service.py +109 -10
  8. google_adk_extras/credentials/__init__.py +34 -0
  9. google_adk_extras/credentials/base_custom_credential_service.py +113 -0
  10. google_adk_extras/credentials/github_oauth2_credential_service.py +213 -0
  11. google_adk_extras/credentials/google_oauth2_credential_service.py +216 -0
  12. google_adk_extras/credentials/http_basic_auth_credential_service.py +388 -0
  13. google_adk_extras/credentials/jwt_credential_service.py +345 -0
  14. google_adk_extras/credentials/microsoft_oauth2_credential_service.py +250 -0
  15. google_adk_extras/credentials/x_oauth2_credential_service.py +240 -0
  16. google_adk_extras/custom_agent_loader.py +156 -0
  17. google_adk_extras/enhanced_adk_web_server.py +137 -0
  18. google_adk_extras/enhanced_fastapi.py +470 -0
  19. google_adk_extras/enhanced_runner.py +38 -0
  20. google_adk_extras/memory/__init__.py +30 -13
  21. google_adk_extras/memory/base_custom_memory_service.py +37 -5
  22. google_adk_extras/memory/sql_memory_service.py +105 -19
  23. google_adk_extras/memory/yaml_file_memory_service.py +115 -22
  24. google_adk_extras/sessions/__init__.py +29 -13
  25. google_adk_extras/sessions/base_custom_session_service.py +133 -11
  26. google_adk_extras/sessions/sql_session_service.py +127 -16
  27. google_adk_extras/sessions/yaml_file_session_service.py +122 -14
  28. google_adk_extras-0.2.3.dist-info/METADATA +302 -0
  29. google_adk_extras-0.2.3.dist-info/RECORD +37 -0
  30. google_adk_extras/py.typed +0 -0
  31. google_adk_extras-0.1.1.dist-info/METADATA +0 -175
  32. google_adk_extras-0.1.1.dist-info/RECORD +0 -25
  33. {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/WHEEL +0 -0
  34. {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/licenses/LICENSE +0 -0
  35. {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,470 @@
1
+ """Enhanced FastAPI app creation with credential service support.
2
+
3
+ This module provides an enhanced version of Google ADK's get_fast_api_app function
4
+ that properly supports custom credential services.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ from pathlib import Path
11
+ import shutil
12
+ from typing import Any, Mapping, Optional, List, Callable, Dict
13
+
14
+ import click
15
+ from fastapi import FastAPI
16
+ from fastapi import UploadFile
17
+ from fastapi.responses import FileResponse
18
+ from fastapi.responses import PlainTextResponse
19
+ from starlette.types import Lifespan
20
+ from watchdog.observers import Observer
21
+
22
+ from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
23
+ from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
24
+ from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
25
+ from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
26
+ from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
27
+ from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
28
+ from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
29
+ from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
30
+ from google.adk.runners import Runner
31
+ from google.adk.sessions.in_memory_session_service import InMemorySessionService
32
+ from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
33
+ from google.adk.sessions.database_session_service import DatabaseSessionService
34
+ from google.adk.utils.feature_decorator import working_in_progress
35
+ from google.adk.cli.adk_web_server import AdkWebServer
36
+ from .enhanced_adk_web_server import EnhancedAdkWebServer
37
+ from google.adk.cli.utils import envs
38
+ from google.adk.cli.utils import evals
39
+ from google.adk.cli.utils.agent_change_handler import AgentChangeEventHandler
40
+ from google.adk.cli.utils.agent_loader import AgentLoader
41
+ from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
42
+
43
logger = logging.getLogger(__name__)


def _safe_join(base: Path, *parts: str) -> Optional[Path]:
    """Join ``parts`` onto ``base``, refusing any result that escapes ``base``.

    The check is purely lexical (``os.path.abspath`` + ``os.path.normpath``).
    ``..`` segments and absolute components that would leave ``base`` are
    rejected. Symlinks are not resolved, so existing links inside ``base``
    keep working.

    Args:
        base: Directory the result must stay within.
        *parts: Path components, typically taken from an HTTP request.

    Returns:
        The normalized absolute path, or None if it lies outside ``base``.
    """
    root = Path(os.path.abspath(base))
    candidate = Path(os.path.normpath(os.path.join(root, *parts)))
    if candidate == root or root in candidate.parents:
        return candidate
    return None


def get_enhanced_fast_api_app(
    *,
    agents_dir: Optional[str] = None,
    agent_loader: Optional[BaseAgentLoader] = None,
    session_service_uri: Optional[str] = None,
    session_db_kwargs: Optional[Mapping[str, Any]] = None,
    artifact_service_uri: Optional[str] = None,
    memory_service_uri: Optional[str] = None,
    credential_service: Optional[BaseCredentialService] = None,  # Optional credential service
    eval_storage_uri: Optional[str] = None,
    allow_origins: Optional[List[str]] = None,
    web: bool = True,
    a2a: bool = False,
    programmatic_a2a: bool = False,
    programmatic_a2a_mount_base: str = "/a2a",
    programmatic_a2a_card_factory: Optional[Callable[[str, Any], Dict[str, Any]]] = None,
    host: str = "127.0.0.1",
    port: int = 8000,
    trace_to_cloud: bool = False,
    reload_agents: bool = False,
    lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
    """Build a FastAPI app backed by ``EnhancedAdkWebServer``.

    This mirrors Google ADK's ``get_fast_api_app`` with these differences:

    1. Requests are served through ``EnhancedAdkWebServer``.
    2. A caller-supplied credential service can replace the hard-coded
       ``InMemoryCredentialService``.
    3. Any ``BaseAgentLoader`` can be used instead of a directory-based loader.
    4. Agents can be exposed over A2A without an agents directory
       (``programmatic_a2a``).

    Args:
        agents_dir: Directory containing agent definitions (optional if
            agent_loader is provided).
        agent_loader: Custom agent loader (optional if agents_dir is provided).
        session_service_uri: ``agentengine://...`` for Vertex AI sessions;
            any other value is treated as a database URL.
        session_db_kwargs: Extra keyword arguments for DatabaseSessionService.
        artifact_service_uri: ``gs://<bucket>`` for GCS artifacts; in-memory
            if omitted.
        memory_service_uri: ``rag://<corpus>`` or ``agentengine://...``;
            in-memory if omitted.
        credential_service: Credential service instance. If None, the choice
            is left to EnhancedAdkWebServer's default.
        eval_storage_uri: GCS URI for eval managers; local files if omitted.
        allow_origins: CORS allowed origins.
        web: Whether to serve the web UI (if ADK's assets can be found).
        a2a: Whether to mount A2A routes.
        programmatic_a2a: Also mount A2A routes for every agent returned by
            ``agent_loader.list_agents()``. Only takes effect when ``a2a``
            is True.
        programmatic_a2a_mount_base: URL prefix for programmatic A2A routes.
        programmatic_a2a_card_factory: Optional ``(app_name, agent) -> dict``
            used to build each AgentCard. A factory that accepts only
            ``app_name`` is also tolerated. A minimal default card is used
            when no factory is given or the agent fails to load.
        host: Server host (accepted for API compatibility; unused here).
        port: Server port (accepted for API compatibility; unused here).
        trace_to_cloud: Accepted for compatibility. Tracing is not bundled,
            so only a warning is logged.
        reload_agents: Watch agents_dir and refresh runners on changes.
        lifespan: FastAPI lifespan callable.

    Returns:
        FastAPI: Configured FastAPI application.

    Raises:
        ValueError: If neither agents_dir nor agent_loader is provided.
        click.ClickException: If a service URI is unsupported or malformed.
    """
    if not agent_loader and not agents_dir:
        raise ValueError("Either agent_loader or agents_dir must be provided")

    # False only when agents_dir is the temp-dir placeholder below, which must
    # not be treated as a directory of agents (e.g. scanned for agent.json).
    has_agents_dir = True
    if agent_loader is not None:
        final_agent_loader = agent_loader
        if agents_dir is None:
            # Reuse the loader's directory when it exposes one (e.g. AgentLoader).
            agents_dir = getattr(agent_loader, "agents_dir", None)
        if agents_dir is None:
            # Non-directory loaders: the eval managers still need a directory,
            # so fall back to the system temp dir.
            import tempfile

            agents_dir = tempfile.gettempdir()
            has_agents_dir = False
    else:
        final_agent_loader = AgentLoader(agents_dir)

    logger.info("Using agent loader: %s", type(final_agent_loader).__name__)

    # Eval managers (same as ADK).
    if eval_storage_uri:
        gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(eval_storage_uri)
        eval_sets_manager = gcs_eval_managers.eval_sets_manager
        eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
    else:
        eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
        eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)

    def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
        """Return (project, location, agent_engine_id) for an id or full name.

        A bare id takes project/location from GOOGLE_CLOUD_PROJECT and
        GOOGLE_CLOUD_LOCATION after loading the agents-dir .env file.
        """
        if not agent_engine_id_or_resource_name:
            raise click.ClickException(
                "Agent engine resource name or resource id can not be empty."
            )

        if "/" in agent_engine_id_or_resource_name:
            segments = agent_engine_id_or_resource_name.split("/")
            if len(segments) != 6:
                raise click.ClickException(
                    "Agent engine resource name is mal-formatted. It should be of"
                    " format: projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
                )
            return segments[1], segments[3], segments[-1]

        envs.load_dotenv_for_agent("", agents_dir)
        project = os.environ["GOOGLE_CLOUD_PROJECT"]
        location = os.environ["GOOGLE_CLOUD_LOCATION"]
        return project, location, agent_engine_id_or_resource_name

    # Memory service (same as ADK).
    if memory_service_uri:
        if memory_service_uri.startswith("rag://"):
            from google.adk.memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService

            rag_corpus = memory_service_uri.split("://")[1]
            if not rag_corpus:
                raise click.ClickException("Rag corpus can not be empty.")
            envs.load_dotenv_for_agent("", agents_dir)
            memory_service = VertexAiRagMemoryService(
                rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
            )
        elif memory_service_uri.startswith("agentengine://"):
            agent_engine_id_or_resource_name = memory_service_uri.split("://")[1]
            project, location, agent_engine_id = _parse_agent_engine_resource_name(
                agent_engine_id_or_resource_name
            )
            memory_service = VertexAiMemoryBankService(
                project=project,
                location=location,
                agent_engine_id=agent_engine_id,
            )
        else:
            raise click.ClickException(
                "Unsupported memory service URI: %s" % memory_service_uri
            )
    else:
        memory_service = InMemoryMemoryService()

    # Session service (same as ADK).
    if session_service_uri:
        if session_service_uri.startswith("agentengine://"):
            agent_engine_id_or_resource_name = session_service_uri.split("://")[1]
            project, location, agent_engine_id = _parse_agent_engine_resource_name(
                agent_engine_id_or_resource_name
            )
            session_service = VertexAiSessionService(
                project=project,
                location=location,
                agent_engine_id=agent_engine_id,
            )
        else:
            if session_db_kwargs is None:
                session_db_kwargs = {}
            session_service = DatabaseSessionService(
                db_url=session_service_uri, **session_db_kwargs
            )
    else:
        session_service = InMemorySessionService()

    # Artifact service (same as ADK).
    if artifact_service_uri:
        if artifact_service_uri.startswith("gs://"):
            gcs_bucket = artifact_service_uri.split("://")[1]
            artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
        else:
            raise click.ClickException(
                "Unsupported artifact service URI: %s" % artifact_service_uri
            )
    else:
        artifact_service = InMemoryArtifactService()

    # The credential service is optional; EnhancedAdkWebServer picks a default.
    credential_service_instance = credential_service
    if credential_service_instance is None:
        logger.info("No credential service provided; server will use its default")

    adk_web_server = EnhancedAdkWebServer(
        agent_loader=final_agent_loader,
        session_service=session_service,
        artifact_service=artifact_service,
        memory_service=memory_service,
        credential_service=credential_service_instance,
        eval_sets_manager=eval_sets_manager,
        eval_set_results_manager=eval_set_results_manager,
        agents_dir=agents_dir,
    )

    # Optional callbacks and arguments forwarded to get_fast_api_app.
    extra_fast_api_args = {}

    if trace_to_cloud:
        logger.warning(
            "trace_to_cloud requested but OpenTelemetry exporters are not bundled. "
            "Tracing is disabled."
        )

    if reload_agents:
        def setup_observer(observer: Observer, adk_web_server: AdkWebServer):
            agent_change_handler = AgentChangeEventHandler(
                agent_loader=final_agent_loader,
                runners_to_clean=adk_web_server.runners_to_clean,
                current_app_name_ref=adk_web_server.current_app_name_ref,
            )
            observer.schedule(agent_change_handler, agents_dir, recursive=True)
            observer.start()

        def tear_down_observer(observer: Observer, _: AdkWebServer):
            observer.stop()
            observer.join()

        extra_fast_api_args.update(
            setup_observer=setup_observer,
            tear_down_observer=tear_down_observer,
        )

    if web:
        try:
            # Prefer the web assets shipped with ADK itself.
            from google.adk.cli.fast_api import BASE_DIR

            ANGULAR_DIST_PATH = BASE_DIR / "browser"
        except (ImportError, AttributeError):
            # Fallback if ADK's module layout changes.
            BASE_DIR = Path(__file__).parent.resolve()
            ANGULAR_DIST_PATH = BASE_DIR / "browser"

        if ANGULAR_DIST_PATH.exists():
            extra_fast_api_args.update(web_assets_dir=ANGULAR_DIST_PATH)
        else:
            logger.warning("Web UI assets not found, web interface will not be available")

    app = adk_web_server.get_fast_api_app(
        lifespan=lifespan,
        allow_origins=allow_origins,
        **extra_fast_api_args,
    )

    # Expose the web server on app state (used by tests).
    app.state.adk_web_server = adk_web_server

    # Builder endpoints that ADK normally adds. All path components below come
    # from the HTTP request, so every filesystem path is containment-checked.
    @working_in_progress("builder_save is not ready for use.")
    @app.post("/builder/save", response_model_exclude_none=True)
    async def builder_build(files: List[UploadFile]) -> bool:
        base_path = Path.cwd() / agents_dir
        root = Path(os.path.abspath(base_path))
        for file in files:
            try:
                if not file.filename:
                    # logger.error, not logger.exception: there is no active exception here.
                    logger.error("Agent name is missing in the input files")
                    return False
                # Expected shape: "<agent_name>/<filename>"; anything else raises
                # ValueError and is handled by the except below.
                agent_name, filename = file.filename.split("/")
                file_path = _safe_join(base_path, agent_name, filename)
                # Require exactly <agents_dir>/<agent_name>/<filename>. This
                # rejects "..", "." and absolute components.
                if file_path is None or file_path.parent.parent != root:
                    logger.error("Rejected unsafe upload path: %s", file.filename)
                    return False
                os.makedirs(file_path.parent, exist_ok=True)
                with open(file_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
            except Exception as e:
                logger.exception("Error in builder_build: %s", e)
                return False
        return True

    @working_in_progress("builder_get is not ready for use.")
    @app.get(
        "/builder/app/{app_name}",
        response_model_exclude_none=True,
        response_class=PlainTextResponse,
    )
    async def get_agent_builder(app_name: str, file_path: Optional[str] = None):
        base_path = Path.cwd() / agents_dir
        agent_dir = _safe_join(base_path, app_name)
        if agent_dir is None:
            return ""
        if not file_path:
            target = agent_dir / "root_agent.yaml"
            download_name = f"{app_name}.yaml"
        else:
            # file_path is a query parameter; keep it inside the agent's directory.
            target = _safe_join(agent_dir, file_path)
            download_name = file_path
        if target is None or not target.is_file():
            return ""
        return FileResponse(
            path=target,
            media_type="application/x-yaml",
            filename=download_name,
            headers={"Cache-Control": "no-store"},
        )

    # A2A protocol support (same as ADK, plus programmatic mounting).
    if a2a:
        try:
            from a2a.server.apps import A2AStarletteApplication
            from a2a.server.request_handlers import DefaultRequestHandler
            from a2a.server.tasks import InMemoryTaskStore
            from a2a.types import AgentCard
            from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
            from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
        except ImportError as e:
            import sys

            if sys.version_info < (3, 12):
                raise ImportError(
                    "A2A requires Python 3.12 or above. Please upgrade your Python version."
                ) from e
            else:
                raise e

        a2a_task_store = InMemoryTaskStore()

        def create_a2a_runner_loader(captured_app_name: str):
            # Bind app_name now; the runner itself is resolved per request.
            async def _get_a2a_runner_async() -> Runner:
                return await adk_web_server.get_runner_async(captured_app_name)

            return _get_a2a_runner_async

        # 1) Directory-based A2A: every subdirectory with an agent.json card.
        # Skipped when agents_dir is only the temp-dir placeholder.
        if agents_dir and has_agents_dir:
            base_path = Path.cwd() / agents_dir
            if base_path.exists() and base_path.is_dir():
                for p in base_path.iterdir():
                    try:
                        if (
                            p.is_file()
                            or p.name.startswith((".", "__pycache__"))
                            or not (p / "agent.json").is_file()
                        ):
                            continue
                    except PermissionError:
                        # Skip directories we cannot inspect.
                        continue

                    app_name = p.name
                    logger.info("Setting up A2A agent (dir): %s", app_name)

                    try:
                        agent_executor = A2aAgentExecutor(
                            runner=create_a2a_runner_loader(app_name),
                        )
                        request_handler = DefaultRequestHandler(
                            agent_executor=agent_executor, task_store=a2a_task_store
                        )
                        with (p / "agent.json").open("r", encoding="utf-8") as f:
                            data = json.load(f)
                        agent_card = AgentCard(**data)
                        a2a_app = A2AStarletteApplication(
                            agent_card=agent_card,
                            http_handler=request_handler,
                        )
                        routes = a2a_app.routes(
                            rpc_url=f"/a2a/{app_name}",
                            agent_card_url=f"/a2a/{app_name}{AGENT_CARD_WELL_KNOWN_PATH}",
                        )
                        for new_route in routes:
                            app.router.routes.append(new_route)
                        logger.info("Configured A2A agent (dir): %s", app_name)
                    except Exception as e:
                        logger.error("Failed to setup A2A agent %s: %s", app_name, e)

        # 2) Programmatic A2A for agents registered with the loader.
        if programmatic_a2a:
            agent_names: List[str] = []
            if hasattr(final_agent_loader, "list_agents"):
                try:
                    agent_names = list(final_agent_loader.list_agents() or [])  # type: ignore[attr-defined]
                except Exception as e:
                    # Best-effort: without a listing there is nothing to mount.
                    logger.warning("Could not list agents for programmatic A2A: %s", e)
                    agent_names = []

            for app_name in agent_names:
                try:
                    agent_instance = final_agent_loader.load_agent(app_name)
                except Exception as e:
                    # Still mount the agent; it gets the default card below.
                    logger.warning(
                        "Could not load agent %s for its A2A card: %s", app_name, e
                    )
                    agent_instance = None

                logger.info("Setting up A2A agent (programmatic): %s", app_name)
                try:
                    data: Dict[str, Any]
                    if programmatic_a2a_card_factory and agent_instance is not None:
                        try:
                            data = programmatic_a2a_card_factory(app_name, agent_instance)
                        except TypeError:
                            # Backward compatibility: factory taking only the name.
                            data = programmatic_a2a_card_factory(app_name)  # type: ignore[misc]
                    else:
                        # Minimal default card.
                        data = {
                            "name": app_name,
                            "description": f"A2A-exposed agent {app_name}",
                            "defaultInputModes": ["text/plain"],
                            "defaultOutputModes": ["application/json"],
                            "version": "1.0.0",
                        }

                    agent_executor = A2aAgentExecutor(
                        runner=create_a2a_runner_loader(app_name),
                    )
                    request_handler = DefaultRequestHandler(
                        agent_executor=agent_executor, task_store=a2a_task_store
                    )
                    agent_card = AgentCard(**data)
                    a2a_app = A2AStarletteApplication(
                        agent_card=agent_card,
                        http_handler=request_handler,
                    )
                    routes = a2a_app.routes(
                        rpc_url=f"{programmatic_a2a_mount_base}/{app_name}",
                        agent_card_url=f"{programmatic_a2a_mount_base}/{app_name}{AGENT_CARD_WELL_KNOWN_PATH}",
                    )
                    for new_route in routes:
                        app.router.routes.append(new_route)
                    logger.info("Configured A2A agent (programmatic): %s", app_name)
                except Exception as e:
                    logger.error("Failed to setup programmatic A2A agent %s: %s", app_name, e)

    logger.info("Enhanced FastAPI app created with credential service support")
    return app
@@ -0,0 +1,38 @@
1
+ """Thin wrapper over google.adk.runners.Runner.
2
+
3
+ EnhancedRunner exists for compatibility with this package’s FastAPI server
4
+ integration. It does not add behavior beyond the base ADK Runner.
5
+ """
6
+
7
+ from typing import List, Optional
8
+
9
+ from google.adk.agents.base_agent import BaseAgent
10
+ from google.adk.artifacts.base_artifact_service import BaseArtifactService
11
+ from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
12
+ from google.adk.plugins.base_plugin import BasePlugin
13
+ from google.adk.runners import Runner
14
+ from google.adk.sessions.base_session_service import BaseSessionService
15
+ from google.adk.memory.base_memory_service import BaseMemoryService
16
+
17
+
18
class EnhancedRunner(Runner):
    """ADK ``Runner`` subclass used by this package's FastAPI integration.

    Every constructor argument is forwarded unchanged to ``Runner.__init__``.
    No behavior is added on top of the base class.
    """

    def __init__(
        self,
        *,
        app_name: str,
        agent: BaseAgent,
        plugins: Optional[List[BasePlugin]] = None,
        artifact_service: Optional[BaseArtifactService] = None,
        session_service: BaseSessionService,
        memory_service: Optional[BaseMemoryService] = None,
        credential_service: Optional[BaseCredentialService] = None,
    ):
        forwarded = {
            "app_name": app_name,
            "agent": agent,
            "plugins": plugins,
            "artifact_service": artifact_service,
            "session_service": session_service,
            "memory_service": memory_service,
            "credential_service": credential_service,
        }
        super().__init__(**forwarded)
@@ -1,15 +1,32 @@
1
"""Custom ADK memory services package.

Optional backends are imported when the package is imported, each inside its
own guard. If a backend module fails to import (e.g. its optional dependency
is not installed), its name is bound to None and it is left out of
``__all__``. The failure reason is logged at DEBUG level.
"""

import logging as _logging

from .base_custom_memory_service import BaseCustomMemoryService

_logger = _logging.getLogger(__name__)

# Optional dependencies. Catching Exception is deliberate: a broken optional
# backend must not make the whole package unimportable.
try:
    from .sql_memory_service import SQLMemoryService  # type: ignore
except Exception as _exc:
    _logger.debug("SQLMemoryService unavailable: %s", _exc)
    SQLMemoryService = None  # type: ignore

try:
    from .mongo_memory_service import MongoMemoryService  # type: ignore
except Exception as _exc:
    _logger.debug("MongoMemoryService unavailable: %s", _exc)
    MongoMemoryService = None  # type: ignore

try:
    from .redis_memory_service import RedisMemoryService  # type: ignore
except Exception as _exc:
    _logger.debug("RedisMemoryService unavailable: %s", _exc)
    RedisMemoryService = None  # type: ignore

try:
    from .yaml_file_memory_service import YamlFileMemoryService  # type: ignore
except Exception as _exc:
    _logger.debug("YamlFileMemoryService unavailable: %s", _exc)
    YamlFileMemoryService = None  # type: ignore

# Export only the backends that actually imported.
__all__ = ["BaseCustomMemoryService"]
for _name in ("SQLMemoryService", "MongoMemoryService", "RedisMemoryService", "YamlFileMemoryService"):
    if globals().get(_name) is not None:
        __all__.append(_name)
del _name  # don't leave the loop variable as a package attribute
@@ -10,7 +10,11 @@ if TYPE_CHECKING:
10
10
 
11
11
 
12
12
  class BaseCustomMemoryService(BaseMemoryService):
13
- """Base class for custom memory services with common functionality."""
13
+ """Base class for custom memory services with common functionality.
14
+
15
+ This abstract base class provides a foundation for implementing custom
16
+ memory services with automatic initialization and cleanup handling.
17
+ """
14
18
 
15
19
  def __init__(self):
16
20
  """Initialize the base custom memory service."""
@@ -23,6 +27,9 @@ class BaseCustomMemoryService(BaseMemoryService):
23
27
 
24
28
  Args:
25
29
  session: The session to add to memory.
30
+
31
+ Raises:
32
+ RuntimeError: If adding the session to memory fails.
26
33
  """
27
34
 
28
35
  @abstractmethod
@@ -38,6 +45,9 @@ class BaseCustomMemoryService(BaseMemoryService):
38
45
 
39
46
  Returns:
40
47
  A SearchMemoryResponse containing the matching memories.
48
+
49
+ Raises:
50
+ RuntimeError: If searching memory fails.
41
51
  """
42
52
 
43
53
  async def add_session_to_memory(self, session: "Session") -> None:
@@ -45,6 +55,9 @@ class BaseCustomMemoryService(BaseMemoryService):
45
55
 
46
56
  Args:
47
57
  session: The session to add.
58
+
59
+ Raises:
60
+ RuntimeError: If adding the session to memory fails.
48
61
  """
49
62
  if not self._initialized:
50
63
  await self.initialize()
@@ -62,6 +75,9 @@ class BaseCustomMemoryService(BaseMemoryService):
62
75
 
63
76
  Returns:
64
77
  A SearchMemoryResponse containing the matching memories.
78
+
79
+ Raises:
80
+ RuntimeError: If searching memory fails.
65
81
  """
66
82
  if not self._initialized:
67
83
  await self.initialize()
@@ -70,21 +86,37 @@ class BaseCustomMemoryService(BaseMemoryService):
70
86
  )
71
87
 
72
88
  async def initialize(self) -> None:
73
- """Initialize the memory service."""
89
+ """Initialize the memory service.
90
+
91
+ Raises:
92
+ RuntimeError: If initialization fails.
93
+ """
74
94
  if not self._initialized:
75
95
  await self._initialize_impl()
76
96
  self._initialized = True
77
97
 
78
98
  async def cleanup(self) -> None:
79
- """Clean up the memory service."""
99
+ """Clean up the memory service.
100
+
101
+ Raises:
102
+ RuntimeError: If cleanup fails.
103
+ """
80
104
  if self._initialized:
81
105
  await self._cleanup_impl()
82
106
  self._initialized = False
83
107
 
84
108
  @abstractmethod
85
109
  async def _initialize_impl(self) -> None:
86
- """Implementation of initialization."""
110
+ """Implementation of initialization.
111
+
112
+ Raises:
113
+ RuntimeError: If initialization fails.
114
+ """
87
115
 
88
116
  @abstractmethod
89
117
  async def _cleanup_impl(self) -> None:
90
- """Implementation of cleanup."""
118
+ """Implementation of cleanup.
119
+
120
+ Raises:
121
+ RuntimeError: If cleanup fails.
122
+ """