google-adk-extras 0.2.3__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +1 -1
- google_adk_extras/custom_agent_loader.py +15 -1
- google_adk_extras/enhanced_fastapi.py +138 -4
- google_adk_extras/streaming/__init__.py +12 -0
- google_adk_extras/streaming/streaming_controller.py +262 -0
- {google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/METADATA +1 -1
- {google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/RECORD +10 -8
- {google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/WHEEL +0 -0
- {google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/top_level.txt +0 -0
google_adk_extras/__init__.py
CHANGED
google_adk_extras/custom_agent_loader.py
CHANGED
@@ -146,6 +146,20 @@ class CustomAgentLoader(BaseAgentLoader):
         sorted_agents = sorted(agent_names)
         logger.debug("Total registered agents: %d", len(sorted_agents))
         return sorted_agents
+
+    # Compatibility with ADK's AgentLoader API used by AgentChangeEventHandler
+    def remove_agent_from_cache(self, name: str) -> None:
+        """No-op cache invalidation for compatibility with ADK hot reload.
+
+        ADK's file-watcher calls `agent_loader.remove_agent_from_cache(current_app)`
+        when files change. Our loader does not cache filesystem-loaded agents,
+        but we provide this method to satisfy the expected interface.
+
+        Args:
+            name: Agent name to invalidate (ignored here).
+        """
+        # Nothing to do; present for interface compatibility.
+        logger.debug("CustomAgentLoader.remove_agent_from_cache(%s) - no-op", name)


     def __repr__(self) -> str:
@@ -153,4 +167,4 @@ class CustomAgentLoader(BaseAgentLoader):
         with self._lock:
             registered_count = len(self._registered_agents)

-        return f"CustomAgentLoader(registered={registered_count})"
+        return f"CustomAgentLoader(registered={registered_count})"
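The new method above exists purely for interface parity with ADK's hot-reload path: the file-watcher calls `remove_agent_from_cache` on whatever loader is installed. A minimal sketch of that call, assuming `CustomAgentLoader` can be constructed with no arguments (its constructor is not part of this diff):

```python
# Hedged sketch; the no-arg constructor is an assumption. Only
# remove_agent_from_cache() and __repr__() appear in the diff above.
from google_adk_extras.custom_agent_loader import CustomAgentLoader

loader = CustomAgentLoader()

# This is what ADK's AgentChangeEventHandler does when a watched file changes.
# For this loader it is a deliberate no-op (nothing is cached from the filesystem).
loader.remove_agent_from_cache("my_agent")

print(loader)  # e.g. CustomAgentLoader(registered=0)
```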
google_adk_extras/enhanced_fastapi.py
CHANGED
@@ -5,6 +5,7 @@ that properly supports custom credential services.
 """

 import json
+import asyncio
 import logging
 import os
 from pathlib import Path
@@ -34,6 +35,7 @@ from google.adk.sessions.database_session_service import DatabaseSessionService
 from google.adk.utils.feature_decorator import working_in_progress
 from google.adk.cli.adk_web_server import AdkWebServer
 from .enhanced_adk_web_server import EnhancedAdkWebServer
+from .streaming import StreamingController, StreamingConfig
 from google.adk.cli.utils import envs
 from google.adk.cli.utils import evals
 from google.adk.cli.utils.agent_change_handler import AgentChangeEventHandler
@@ -64,6 +66,9 @@ def get_enhanced_fast_api_app(
     trace_to_cloud: bool = False,
     reload_agents: bool = False,
     lifespan: Optional[Lifespan[FastAPI]] = None,
+    # Streaming layer (optional)
+    enable_streaming: bool = False,
+    streaming_config: Optional[StreamingConfig] = None,
 ) -> FastAPI:
     """Enhanced version of Google ADK's get_fast_api_app with EnhancedRunner integration.

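The two new keyword arguments are the only public switch for the streaming layer. A hedged sketch of opting in; `enable_streaming` and `streaming_config` come from the hunk above, every other argument and value is an illustrative assumption:

```python
# Sketch only: apart from enable_streaming/streaming_config, the argument
# names and values below are illustrative assumptions, not taken from the diff.
from google_adk_extras.enhanced_fastapi import get_enhanced_fast_api_app
from google_adk_extras.streaming import StreamingConfig

app = get_enhanced_fast_api_app(
    agents_dir="./agents",                        # assumed; mirrors ADK's get_fast_api_app
    session_service_uri="sqlite:///./sessions.db",
    enable_streaming=True,                        # new in 0.2.6
    streaming_config=StreamingConfig(
        enable_streaming=True,
        streaming_path_base="/stream",
    ),
)
# Serve with e.g. `uvicorn my_module:app`, as with any FastAPI app.
```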
@@ -151,7 +156,7 @@
         agent_engine_id = agent_engine_id_or_resource_name
         return project, location, agent_engine_id

-    # Build the Memory service (
+    # Build the Memory service (enhanced to recognize extras URIs)
     if memory_service_uri:
         if memory_service_uri.startswith("rag://"):
             from google.adk.memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
@@ -172,6 +177,19 @@
                 location=location,
                 agent_engine_id=agent_engine_id,
             )
+        elif memory_service_uri.startswith("yaml://"):
+            from .memory.yaml_file_memory_service import YamlFileMemoryService
+            base_directory = memory_service_uri.split("://")[1]
+            memory_service = YamlFileMemoryService(base_directory=base_directory)
+        elif memory_service_uri.startswith("redis://"):
+            from .memory.redis_memory_service import RedisMemoryService
+            memory_service = RedisMemoryService(connection_string=memory_service_uri)  # type: ignore[arg-type]
+        elif memory_service_uri.startswith(("sqlite://", "postgresql://", "mysql://")):
+            from .memory.sql_memory_service import SQLMemoryService
+            memory_service = SQLMemoryService(database_url=memory_service_uri)
+        elif memory_service_uri.startswith("mongodb://"):
+            from .memory.mongo_memory_service import MongoMemoryService
+            memory_service = MongoMemoryService(connection_string=memory_service_uri)
         else:
             raise click.ClickException(
                 "Unsupported memory service URI: %s" % memory_service_uri
@@ -179,7 +197,7 @@
     else:
         memory_service = InMemoryMemoryService()

-    # Build the Session service (
+    # Build the Session service (enhanced to recognize extras URIs)
     if session_service_uri:
         if session_service_uri.startswith("agentengine://"):
             agent_engine_id_or_resource_name = session_service_uri.split("://")[1]
@@ -191,8 +209,18 @@
                 location=location,
                 agent_engine_id=agent_engine_id,
             )
+        elif session_service_uri.startswith("yaml://"):
+            from .sessions.yaml_file_session_service import YamlFileSessionService
+            base_directory = session_service_uri.split("://")[1]
+            session_service = YamlFileSessionService(base_directory=base_directory)
+        elif session_service_uri.startswith("redis://"):
+            from .sessions.redis_session_service import RedisSessionService
+            session_service = RedisSessionService(connection_string=session_service_uri)  # type: ignore[arg-type]
+        elif session_service_uri.startswith("mongodb://"):
+            from .sessions.mongo_session_service import MongoSessionService
+            session_service = MongoSessionService(connection_string=session_service_uri)
         else:
-            #
+            # Treat remaining schemes as database URLs (sqlite/postgres/mysql)
             if session_db_kwargs is None:
                 session_db_kwargs = {}
             session_service = DatabaseSessionService(
@@ -201,11 +229,25 @@
     else:
         session_service = InMemorySessionService()

-    # Build the Artifact service (
+    # Build the Artifact service (enhanced to recognize extras URIs)
     if artifact_service_uri:
         if artifact_service_uri.startswith("gs://"):
             gcs_bucket = artifact_service_uri.split("://")[1]
             artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
+        elif artifact_service_uri.startswith("local://"):
+            from .artifacts.local_folder_artifact_service import LocalFolderArtifactService
+            base_directory = artifact_service_uri.split("://")[1]
+            artifact_service = LocalFolderArtifactService(base_directory=base_directory)
+        elif artifact_service_uri.startswith("s3://"):
+            from .artifacts.s3_artifact_service import S3ArtifactService
+            bucket_name = artifact_service_uri.split("://")[1]
+            artifact_service = S3ArtifactService(bucket_name=bucket_name)
+        elif artifact_service_uri.startswith(("sqlite://", "postgresql://", "mysql://")):
+            from .artifacts.sql_artifact_service import SQLArtifactService
+            artifact_service = SQLArtifactService(database_url=artifact_service_uri)
+        elif artifact_service_uri.startswith("mongodb://"):
+            from .artifacts.mongo_artifact_service import MongoArtifactService
+            artifact_service = MongoArtifactService(connection_string=artifact_service_uri)
         else:
             raise click.ClickException(
                 "Unsupported artifact service URI: %s" % artifact_service_uri
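Taken together, the memory, session, and artifact blocks above route extras-specific URI schemes to this package's own backends while leaving ADK's defaults untouched. A quick reference sketch of the URI shapes the new branches accept (values are illustrative):

```python
# Illustrative values only; each scheme maps to a branch added in the hunks above.
memory_service_uri = "yaml://./memory_store"        # YamlFileMemoryService
session_service_uri = "redis://localhost:6379/0"    # RedisSessionService
artifact_service_uri = "s3://my-artifact-bucket"    # S3ArtifactService

# New schemes per service, as added above:
#   memory:   yaml://, redis://, sqlite://, postgresql://, mysql://, mongodb://
#   session:  yaml://, redis://, mongodb://  (other schemes still fall through to DatabaseSessionService)
#   artifact: local://, s3://, sqlite://, postgresql://, mysql://, mongodb://
```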
@@ -467,4 +509,96 @@ def get_enhanced_fast_api_app(
             logger.error("Failed to setup programmatic A2A agent %s: %s", app_name, e)

     logger.info("Enhanced FastAPI app created with credential service support")
+
+    # Optional streaming mounts (SSE + WebSocket)
+    if enable_streaming:
+        cfg = streaming_config or StreamingConfig(enable_streaming=True)
+        controller = StreamingController(
+            config=cfg,
+            get_runner_async=adk_web_server.get_runner_async,
+            session_service=session_service,
+        )
+        app.state.streaming_controller = controller
+        @app.on_event("startup")
+        async def _start_streaming():  # pragma: no cover - lifecycle glue
+            controller.start()
+        @app.on_event("shutdown")
+        async def _stop_streaming():  # pragma: no cover - lifecycle glue
+            await controller.stop()
+
+        from fastapi import APIRouter, WebSocket, Query
+        from fastapi.responses import StreamingResponse
+        from google.adk.cli.adk_web_server import RunAgentRequest
+
+        router = APIRouter()
+        base = cfg.streaming_path_base.rstrip("/")
+
+        @router.get(f"{base}/events/{{channel_id}}")
+        async def stream_events(channel_id: str, appName: str = Query(...), userId: str = Query(...), sessionId: Optional[str] = Query(None)):
+            ch = await app.state.streaming_controller.open_or_bind_channel(
+                channel_id=channel_id, app_name=appName, user_id=userId, session_id=sessionId
+            )
+            q = app.state.streaming_controller.subscribe(channel_id, kind="sse")
+
+            async def gen():
+                try:
+                    # Announce channel binding with session id
+                    yield "event: channel-bound\n"
+                    yield f"data: {{\"appName\":\"{appName}\",\"userId\":\"{userId}\",\"sessionId\":\"{ch.session_id}\"}}\n\n"
+                    while True:
+                        payload = await q.get()
+                        yield f"data: {payload}\n\n"
+                except asyncio.CancelledError:
+                    pass
+                finally:
+                    app.state.streaming_controller.unsubscribe(channel_id, q)
+
+            return StreamingResponse(gen(), media_type="text/event-stream")
+
+        @router.post(f"{base}/send/{{channel_id}}")
+        async def send_message(channel_id: str, req: RunAgentRequest):
+            # Validation: channel binding must match
+            await app.state.streaming_controller.enqueue(channel_id, req)
+            return PlainTextResponse("", status_code=204)
+
+        @router.websocket(f"{base}/ws/{{channel_id}}")
+        async def ws_endpoint(websocket: WebSocket, channel_id: str, appName: str, userId: str, sessionId: Optional[str] = None):
+            await websocket.accept()
+            try:
+                await app.state.streaming_controller.open_or_bind_channel(
+                    channel_id=channel_id, app_name=appName, user_id=userId, session_id=sessionId
+                )
+                q = app.state.streaming_controller.subscribe(channel_id, kind="ws")
+                # Send channel binding info including session id
+                await websocket.send_text(json.dumps({"event": "channel-bound", "appName": appName, "userId": userId, "sessionId": app.state.streaming_controller._channels[channel_id].session_id}))
+
+                async def downlink():
+                    try:
+                        while True:
+                            payload = await q.get()
+                            await websocket.send_text(payload)
+                    except asyncio.CancelledError:
+                        pass
+
+                async def uplink():
+                    try:
+                        while True:
+                            text = await websocket.receive_text()
+                            # Strict type parity by default
+                            req = RunAgentRequest.model_validate_json(text)
+                            await app.state.streaming_controller.enqueue(channel_id, req)
+                    except Exception:
+                        return
+
+                down = asyncio.create_task(downlink())
+                up = asyncio.create_task(uplink())
+                await asyncio.wait({down, up}, return_when=asyncio.FIRST_COMPLETED)
+            finally:
+                try:
+                    app.state.streaming_controller.unsubscribe(channel_id, q)
+                except Exception:
+                    pass
+
+        app.include_router(router)
+
     return app
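When `enable_streaming=True`, the block above mounts routes under `streaming_path_base` (default `/stream`): `GET {base}/events/{channel_id}` for SSE, `POST {base}/send/{channel_id}` for submitting a `RunAgentRequest`, and `{base}/ws/{channel_id}` for WebSocket. A hedged client sketch against the SSE + send pair; the host/port, app and user names, and the camelCase `RunAgentRequest` field names are assumptions (the latter based on ADK's usual `/run` payload):

```python
# Hedged client sketch; assumes the app runs on localhost:8000, the default
# streaming_path_base ("/stream"), and camelCase RunAgentRequest fields
# (appName/userId/sessionId/newMessage) as used by ADK's /run endpoint.
import json
import httpx

CHANNEL = "chat-1"
PARAMS = {"appName": "my_app", "userId": "u1"}  # sessionId omitted: created on open by default

with httpx.Client(base_url="http://localhost:8000", timeout=None) as client:
    # Open the SSE channel; the first data line is the "channel-bound" payload
    # carrying the session id the channel was bound to.
    with client.stream("GET", f"/stream/events/{CHANNEL}", params=PARAMS) as events:
        for line in events.iter_lines():
            if line.startswith("data: "):
                bound = json.loads(line[len("data: "):])
                break

    # Enqueue a run on the same channel; the server answers 204 No Content.
    client.post(
        f"/stream/send/{CHANNEL}",
        json={
            "appName": "my_app",
            "userId": "u1",
            "sessionId": bound["sessionId"],
            "newMessage": {"role": "user", "parts": [{"text": "hello"}]},
        },
    )
```

In practice a client would keep the SSE stream open (or use the WebSocket route) while sending, so that the per-channel worker's event broadcasts are actually received.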
google_adk_extras/streaming/__init__.py
ADDED
@@ -0,0 +1,12 @@
+"""Streaming support (SSE/WebSocket) for google-adk-extras.
+
+This package provides an optional, persistent bi-directional streaming layer
+with strict ADK type parity by default. It complements ADK's built-in
+`/run`, `/run_sse`, and `/run_live` endpoints by offering per-channel
+subscription and send semantics for chat-style UIs.
+"""
+
+from .streaming_controller import StreamingConfig, StreamingController
+
+__all__ = ["StreamingConfig", "StreamingController"]
+
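The subpackage re-exports the two public names. A short sketch of tuning `StreamingConfig`; the field names come from `streaming_controller.py` (shown in the next section), the values are illustrative:

```python
# Field names are taken from StreamingConfig in streaming_controller.py below;
# the values here are illustrative, not recommendations.
from google_adk_extras.streaming import StreamingConfig

cfg = StreamingConfig(
    enable_streaming=True,
    streaming_path_base="/stream",   # endpoints are mounted under this base
    create_session_on_open=True,     # auto-create an ADK session when a channel opens
    ttl_seconds=900,                 # idle channels are garbage-collected after this
    max_queue_size=128,              # per-subscriber queue limit before it is dropped
    max_channels_per_user=20,
    heartbeat_interval=20.0,
)
```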
google_adk_extras/streaming/streaming_controller.py
ADDED
@@ -0,0 +1,262 @@
+import asyncio
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Set, Callable, Awaitable
+
+from fastapi import WebSocket, HTTPException
+from pydantic import BaseModel
+
+from google.adk.events.event import Event
+from google.adk.runners import Runner
+
+
+class StreamingConfig(BaseModel):
+    enable_streaming: bool = False
+    streaming_path_base: str = "/stream"
+    strict_types: bool = True
+    create_session_on_open: bool = True
+    ttl_seconds: int = 900
+    max_queue_size: int = 128
+    max_channels_per_user: int = 20
+    heartbeat_interval: Optional[float] = 20.0
+    reuse_session_policy: str = "per_channel"  # "per_channel" or "external"
+
+
+@dataclass
+class _Subscriber:
+    queue: "asyncio.Queue[str]"
+    kind: str  # "sse" | "ws"
+
+
+@dataclass
+class _Channel:
+    channel_id: str
+    app_name: str
+    user_id: str
+    session_id: str
+    in_q: "asyncio.Queue[Any]" = field(default_factory=asyncio.Queue)
+    subscribers: list[_Subscriber] = field(default_factory=list)
+    worker_task: Optional[asyncio.Task] = None
+    created_at: float = field(default_factory=lambda: time.time())
+    last_activity: float = field(default_factory=lambda: time.time())
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+
+
+class StreamingController:
+    """Manages streaming channels and workers.
+
+    This controller binds a channel to (app_name, user_id, session_id) and
+    runs a background worker per channel to execute streamed runs and push
+    ADK Event JSON to all subscribers.
+    """
+
+    def __init__(
+        self,
+        *,
+        config: StreamingConfig,
+        get_runner_async: Callable[[str], Awaitable[Runner]],
+        session_service,
+    ) -> None:
+        self._config = config
+        self._get_runner_async = get_runner_async
+        self._session_service = session_service
+        self._channels: Dict[str, _Channel] = {}
+        self._gc_task: Optional[asyncio.Task] = None
+
+    def start(self) -> None:
+        if self._gc_task is None:
+            self._gc_task = asyncio.create_task(self._gc_loop())
+
+    async def stop(self) -> None:
+        if self._gc_task:
+            self._gc_task.cancel()
+            with asyncio.CancelledError:
+                pass
+            self._gc_task = None
+        # Cancel workers
+        for ch in list(self._channels.values()):
+            if ch.worker_task and not ch.worker_task.done():
+                ch.worker_task.cancel()
+        self._channels.clear()
+
+    def _ensure_user_limit(self, user_id: str) -> None:
+        if self._config.max_channels_per_user <= 0:
+            return
+        count = sum(1 for c in self._channels.values() if c.user_id == user_id)
+        if count >= self._config.max_channels_per_user:
+            raise HTTPException(status_code=429, detail="Too many channels for this user")
+
+    async def open_or_bind_channel(
+        self,
+        *,
+        channel_id: str,
+        app_name: str,
+        user_id: str,
+        session_id: Optional[str],
+    ) -> _Channel:
+        # Existing channel validation/match
+        if channel_id in self._channels:
+            ch = self._channels[channel_id]
+            if ch.app_name != app_name or ch.user_id != user_id:
+                raise HTTPException(status_code=409, detail="Channel binding conflict")
+            if session_id and session_id != ch.session_id:
+                raise HTTPException(status_code=409, detail="Channel already bound to different session")
+            ch.last_activity = time.time()
+            return ch
+
+        # New channel
+        self._ensure_user_limit(user_id)
+        if not session_id:
+            if not self._config.create_session_on_open:
+                raise HTTPException(status_code=400, detail="sessionId required for this channel")
+            # Create a fresh ADK session
+            create = getattr(self._session_service, "create_session", None)
+            if create is None:
+                # Older ADK interfaces may expose sync variant
+                create = getattr(self._session_service, "create_session_sync", None)
+            if create is None:
+                raise HTTPException(status_code=500, detail="Session service does not support create_session")
+            if asyncio.iscoroutinefunction(create):
+                session = await create(app_name=app_name, user_id=user_id)
+            else:
+                # Call sync and wrap
+                session = create(app_name=app_name, user_id=user_id)
+            session_id = session.id
+        else:
+            # Validate existing session
+            session = await self._session_service.get_session(app_name=app_name, user_id=user_id, session_id=session_id)
+            if not session:
+                raise HTTPException(status_code=404, detail="Session not found")
+
+        ch = _Channel(
+            channel_id=channel_id,
+            app_name=app_name,
+            user_id=user_id,
+            session_id=session_id,
+            in_q=asyncio.Queue(),
+        )
+        self._channels[channel_id] = ch
+        ch.worker_task = asyncio.create_task(self._worker(ch))
+        return ch
+
+    def subscribe(self, channel_id: str, kind: str) -> asyncio.Queue[str]:
+        if channel_id not in self._channels:
+            raise HTTPException(status_code=404, detail="Channel not found")
+        q: asyncio.Queue[str] = asyncio.Queue(maxsize=self._config.max_queue_size)
+        self._channels[channel_id].subscribers.append(_Subscriber(queue=q, kind=kind))
+        self._channels[channel_id].last_activity = time.time()
+        return q
+
+    def unsubscribe(self, channel_id: str, q: asyncio.Queue[str]) -> None:
+        ch = self._channels.get(channel_id)
+        if not ch:
+            return
+        ch.subscribers = [s for s in ch.subscribers if s.queue is not q]
+        ch.last_activity = time.time()
+
+    async def enqueue(self, channel_id: str, req: Any) -> None:
+        ch = self._channels.get(channel_id)
+        if not ch:
+            raise HTTPException(status_code=404, detail="Channel not found")
+        # Validate binding
+        if getattr(req, "app_name", None) != ch.app_name or getattr(req, "user_id", None) != ch.user_id or getattr(req, "session_id", None) != ch.session_id:
+            raise HTTPException(status_code=409, detail="Request does not match channel binding")
+        await ch.in_q.put(req)
+        ch.last_activity = time.time()
+
+    async def _worker(self, ch: _Channel) -> None:
+        try:
+            while True:
+                req = await ch.in_q.get()
+                ch.last_activity = time.time()
+                try:
+                    runner = await self._get_runner_async(ch.app_name)
+                    # Stream events for this request
+                    async with _aclosing(
+                        runner.run_async(
+                            user_id=ch.user_id,
+                            session_id=ch.session_id,
+                            new_message=req.new_message,
+                            state_delta=getattr(req, "state_delta", None),
+                            run_config=_maybe_run_config_streaming(True),
+                        )
+                    ) as agen:
+                        async for event in agen:
+                            await self._broadcast_event(ch, event)
+                except Exception as e:  # pragma: no cover - safety
+                    await self._broadcast_error(ch, str(e))
+        except asyncio.CancelledError:  # worker shutdown
+            return
+
+    async def _broadcast_event(self, ch: _Channel, event: Event) -> None:
+        payload = event.model_dump_json(exclude_none=True, by_alias=True)
+        for sub in list(ch.subscribers):
+            try:
+                sub.queue.put_nowait(payload)
+            except asyncio.QueueFull:
+                # Drop subscriber on backpressure
+                ch.subscribers = [s for s in ch.subscribers if s is not sub]
+        ch.last_activity = time.time()
+
+    async def _broadcast_heartbeat(self, ch: _Channel) -> None:
+        if self._config.heartbeat_interval is None:
+            return
+        payload = '{"event":"heartbeat"}'
+        for sub in list(ch.subscribers):
+            try:
+                sub.queue.put_nowait(payload)
+            except asyncio.QueueFull:
+                ch.subscribers = [s for s in ch.subscribers if s is not sub]
+
+    async def _broadcast_error(self, ch: _Channel, message: str) -> None:
+        payload = '{"error": %s}' % _json_escape(message)
+        for sub in list(ch.subscribers):
+            try:
+                sub.queue.put_nowait(payload)
+            except asyncio.QueueFull:
+                ch.subscribers = [s for s in ch.subscribers if s is not sub]
+
+    async def _gc_loop(self) -> None:
+        try:
+            while True:
+                await asyncio.sleep(min(10, max(1, int(self._config.ttl_seconds / 3))))
+                now = time.time()
+                for cid, ch in list(self._channels.items()):
+                    idle = now - ch.last_activity
+                    if idle >= self._config.ttl_seconds and not ch.subscribers and ch.in_q.empty():
+                        if ch.worker_task and not ch.worker_task.done():
+                            ch.worker_task.cancel()
+                        self._channels.pop(cid, None)
+        except asyncio.CancelledError:
+            return
+
+
+# Utilities (avoid importing optional internals at module import time)
+def _maybe_run_config_streaming(enabled: bool):
+    # Support multiple ADK versions by resolving RunConfig/StreamingMode from
+    # either google.adk.runners or google.adk.agents.run_config
+    try:
+        from google.adk.runners import RunConfig  # type: ignore
+    except Exception:  # pragma: no cover - version fallback
+        from google.adk.agents.run_config import RunConfig  # type: ignore
+    try:
+        from google.adk.agents.run_config import StreamingMode  # type: ignore
+    except Exception:  # pragma: no cover - defensive
+        StreamingMode = type("StreamingMode", (), {"SSE": "sse", "NONE": None})  # minimal stub
+    return RunConfig(streaming_mode=StreamingMode.SSE if enabled else StreamingMode.NONE)
+
+
+class _aclosing:
+    def __init__(self, agen):
+        self._agen = agen
+    async def __aenter__(self):
+        return self._agen
+    async def __aexit__(self, exc_type, exc, tb):
+        try:
+            await self._agen.aclose()
+        except Exception:
+            pass
+
+
+def _json_escape(s: str) -> str:
+    return '"' + s.replace('\\', '\\\\').replace('"', '\\"') + '"'
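Outside of `get_enhanced_fast_api_app`, the controller can also be driven directly, which is mostly useful for tests. A hedged sketch, assuming ADK's `InMemorySessionService` and a stand-in runner factory (in the wheel the factory is `adk_web_server.get_runner_async`). Note that `stop()` as diffed uses `with asyncio.CancelledError:`, which is not a context manager, so shutdown may raise:

```python
# Hedged sketch: wiring StreamingController by hand, outside get_enhanced_fast_api_app.
# The runner factory below is a stand-in; the real one is adk_web_server.get_runner_async.
import asyncio
from google.adk.sessions import InMemorySessionService
from google_adk_extras.streaming import StreamingConfig, StreamingController

async def main() -> None:
    session_service = InMemorySessionService()

    async def get_runner_async(app_name: str):
        raise NotImplementedError("supply a real Runner factory here")

    controller = StreamingController(
        config=StreamingConfig(enable_streaming=True),
        get_runner_async=get_runner_async,
        session_service=session_service,
    )
    controller.start()  # launches the idle-channel GC loop

    # Bind a channel; with no session_id a fresh ADK session is created.
    ch = await controller.open_or_bind_channel(
        channel_id="chat-1", app_name="my_app", user_id="u1", session_id=None
    )
    queue = controller.subscribe("chat-1", kind="sse")
    print("bound to session", ch.session_id, "subscriber queue", queue)

    try:
        await controller.stop()
    except (TypeError, AttributeError):
        # As diffed, stop() uses `with asyncio.CancelledError:`, which is not a
        # context manager, so cleanup can fail once the GC task is cancelled.
        pass

asyncio.run(main())
```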
{google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
-google_adk_extras/__init__.py,sha256=
+google_adk_extras/__init__.py,sha256=cIaqZH7E__SmojqMjCqOA4FV6S6-k1pH9_mGCfkCfU4,830
 google_adk_extras/adk_builder.py,sha256=anTpd-UYtPugRSqdYNmWa_uesVjHGEKO2mntkcU-J6g,41179
-google_adk_extras/custom_agent_loader.py,sha256=
+google_adk_extras/custom_agent_loader.py,sha256=6rQyBTAxvuwFPs3QGt8hWYZKsQVpbI4WIy6OS73tsac,6001
 google_adk_extras/enhanced_adk_web_server.py,sha256=rML_m4Um9QCcnnlHvCOyLGjLHDcyladEunX-x4JX43Q,5414
-google_adk_extras/enhanced_fastapi.py,sha256=
+google_adk_extras/enhanced_fastapi.py,sha256=TzIr159B0tnEqj2dXKmV_tJvkA66Fq2I-kL5cz9VGSk,28304
 google_adk_extras/enhanced_runner.py,sha256=b7O1a9-4S49LduILOEDs6IxjCI4w_E39sc-Hs4y3Rys,1410
 google_adk_extras/artifacts/__init__.py,sha256=_IsKDgf6wanWR0HXvSpK9SiLa3n5URKLtazkKyH1P-o,931
 google_adk_extras/artifacts/base_custom_artifact_service.py,sha256=O9rkc250B3yDRYbyDI0EvTrCKvnih5_DQas5OF-hRMY,9721
@@ -30,8 +30,10 @@ google_adk_extras/sessions/mongo_session_service.py,sha256=r3jZ3PmDpbZ0veNzXzuAj
 google_adk_extras/sessions/redis_session_service.py,sha256=yyfXZozeFWJ2S_kz7zXqz--f_ymE6HMMpT3MhcpFXIE,10434
 google_adk_extras/sessions/sql_session_service.py,sha256=TaOeEVWnwQ_8nvDZBW7e3qhzR_ecuGsjvZ_kh6Guq8g,14558
 google_adk_extras/sessions/yaml_file_session_service.py,sha256=SpTh8YHIALcoxzmturhcZ4ReHKQrJI1CxoYiJQ-baRc,11819
-google_adk_extras
-google_adk_extras
-google_adk_extras-0.2.
-google_adk_extras-0.2.
-google_adk_extras-0.2.
+google_adk_extras/streaming/__init__.py,sha256=rcjmlCJHTlvUiCrx6qNGw5ObCnEtfENkGTvzfEiGL0M,461
+google_adk_extras/streaming/streaming_controller.py,sha256=Qxg4yqakd3mBPJ9sISe4gO_pvBpYttvSq5SIK4W5ZPQ,10505
+google_adk_extras-0.2.6.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+google_adk_extras-0.2.6.dist-info/METADATA,sha256=NXbByeiIsdHl2UcZ5F-i22thmm-rFcTr1poTFdVY5cI,10707
+google_adk_extras-0.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+google_adk_extras-0.2.6.dist-info/top_level.txt,sha256=DDWgVkz8G8ihPzznxAWyKa2jgJW3F6Fwy__qMddoKTs,18
+google_adk_extras-0.2.6.dist-info/RECORD,,
{google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/WHEEL
File without changes
{google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/licenses/LICENSE
File without changes
{google_adk_extras-0.2.3.dist-info → google_adk_extras-0.2.6.dist-info}/top_level.txt
File without changes