google-adk-extras 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,4 +28,4 @@ __all__ = [
28
28
  "CustomAgentLoader",
29
29
  ]
30
30
 
31
- __version__ = "0.2.5"
31
+ __version__ = "0.2.6"
@@ -5,6 +5,7 @@ that properly supports custom credential services.
5
5
  """
6
6
 
7
7
  import json
8
+ import asyncio
8
9
  import logging
9
10
  import os
10
11
  from pathlib import Path
@@ -34,6 +35,7 @@ from google.adk.sessions.database_session_service import DatabaseSessionService
34
35
  from google.adk.utils.feature_decorator import working_in_progress
35
36
  from google.adk.cli.adk_web_server import AdkWebServer
36
37
  from .enhanced_adk_web_server import EnhancedAdkWebServer
38
+ from .streaming import StreamingController, StreamingConfig
37
39
  from google.adk.cli.utils import envs
38
40
  from google.adk.cli.utils import evals
39
41
  from google.adk.cli.utils.agent_change_handler import AgentChangeEventHandler
@@ -64,6 +66,9 @@ def get_enhanced_fast_api_app(
64
66
  trace_to_cloud: bool = False,
65
67
  reload_agents: bool = False,
66
68
  lifespan: Optional[Lifespan[FastAPI]] = None,
69
+ # Streaming layer (optional)
70
+ enable_streaming: bool = False,
71
+ streaming_config: Optional[StreamingConfig] = None,
67
72
  ) -> FastAPI:
68
73
  """Enhanced version of Google ADK's get_fast_api_app with EnhancedRunner integration.
69
74
 
@@ -504,4 +509,96 @@ def get_enhanced_fast_api_app(
504
509
  logger.error("Failed to setup programmatic A2A agent %s: %s", app_name, e)
505
510
 
506
511
  logger.info("Enhanced FastAPI app created with credential service support")
512
+
513
+ # Optional streaming mounts (SSE + WebSocket)
514
+ if enable_streaming:
515
+ cfg = streaming_config or StreamingConfig(enable_streaming=True)
516
+ controller = StreamingController(
517
+ config=cfg,
518
+ get_runner_async=adk_web_server.get_runner_async,
519
+ session_service=session_service,
520
+ )
521
+ app.state.streaming_controller = controller
522
+ @app.on_event("startup")
523
+ async def _start_streaming(): # pragma: no cover - lifecycle glue
524
+ controller.start()
525
+ @app.on_event("shutdown")
526
+ async def _stop_streaming(): # pragma: no cover - lifecycle glue
527
+ await controller.stop()
528
+
529
+ from fastapi import APIRouter, WebSocket, Query
530
+ from fastapi.responses import StreamingResponse
531
+ from google.adk.cli.adk_web_server import RunAgentRequest
532
+
533
+ router = APIRouter()
534
+ base = cfg.streaming_path_base.rstrip("/")
535
+
536
+ @router.get(f"{base}/events/{{channel_id}}")
537
+ async def stream_events(channel_id: str, appName: str = Query(...), userId: str = Query(...), sessionId: Optional[str] = Query(None)):
538
+ ch = await app.state.streaming_controller.open_or_bind_channel(
539
+ channel_id=channel_id, app_name=appName, user_id=userId, session_id=sessionId
540
+ )
541
+ q = app.state.streaming_controller.subscribe(channel_id, kind="sse")
542
+
543
+ async def gen():
544
+ try:
545
+ # Announce channel binding with session id
546
+ yield "event: channel-bound\n"
547
+ yield f"data: {{\"appName\":\"{appName}\",\"userId\":\"{userId}\",\"sessionId\":\"{ch.session_id}\"}}\n\n"
548
+ while True:
549
+ payload = await q.get()
550
+ yield f"data: {payload}\n\n"
551
+ except asyncio.CancelledError:
552
+ pass
553
+ finally:
554
+ app.state.streaming_controller.unsubscribe(channel_id, q)
555
+
556
+ return StreamingResponse(gen(), media_type="text/event-stream")
557
+
558
+ @router.post(f"{base}/send/{{channel_id}}")
559
+ async def send_message(channel_id: str, req: RunAgentRequest):
560
+ # Validation: channel binding must match
561
+ await app.state.streaming_controller.enqueue(channel_id, req)
562
+ return PlainTextResponse("", status_code=204)
563
+
564
+ @router.websocket(f"{base}/ws/{{channel_id}}")
565
+ async def ws_endpoint(websocket: WebSocket, channel_id: str, appName: str, userId: str, sessionId: Optional[str] = None):
566
+ await websocket.accept()
567
+ try:
568
+ await app.state.streaming_controller.open_or_bind_channel(
569
+ channel_id=channel_id, app_name=appName, user_id=userId, session_id=sessionId
570
+ )
571
+ q = app.state.streaming_controller.subscribe(channel_id, kind="ws")
572
+ # Send channel binding info including session id
573
+ await websocket.send_text(json.dumps({"event": "channel-bound", "appName": appName, "userId": userId, "sessionId": app.state.streaming_controller._channels[channel_id].session_id}))
574
+
575
+ async def downlink():
576
+ try:
577
+ while True:
578
+ payload = await q.get()
579
+ await websocket.send_text(payload)
580
+ except asyncio.CancelledError:
581
+ pass
582
+
583
+ async def uplink():
584
+ try:
585
+ while True:
586
+ text = await websocket.receive_text()
587
+ # Strict type parity by default
588
+ req = RunAgentRequest.model_validate_json(text)
589
+ await app.state.streaming_controller.enqueue(channel_id, req)
590
+ except Exception:
591
+ return
592
+
593
+ down = asyncio.create_task(downlink())
594
+ up = asyncio.create_task(uplink())
595
+ await asyncio.wait({down, up}, return_when=asyncio.FIRST_COMPLETED)
596
+ finally:
597
+ try:
598
+ app.state.streaming_controller.unsubscribe(channel_id, q)
599
+ except Exception:
600
+ pass
601
+
602
+ app.include_router(router)
603
+
507
604
  return app
@@ -0,0 +1,12 @@
1
+ """Streaming support (SSE/WebSocket) for google-adk-extras.
2
+
3
+ This package provides an optional, persistent bi-directional streaming layer
4
+ with strict ADK type parity by default. It complements ADK's built-in
5
+ `/run`, `/run_sse`, and `/run_live` endpoints by offering per-channel
6
+ subscription and send semantics for chat-style UIs.
7
+ """
8
+
9
+ from .streaming_controller import StreamingConfig, StreamingController
10
+
11
+ __all__ = ["StreamingConfig", "StreamingController"]
12
+
@@ -0,0 +1,262 @@
1
+ import asyncio
2
+ import time
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, Optional, Set, Callable, Awaitable
5
+
6
+ from fastapi import WebSocket, HTTPException
7
+ from pydantic import BaseModel
8
+
9
+ from google.adk.events.event import Event
10
+ from google.adk.runners import Runner
11
+
12
+
13
class StreamingConfig(BaseModel):
    """Settings for the optional SSE/WebSocket streaming layer."""

    # NOTE(review): not read by StreamingController; in the visible app
    # factory, route mounting is governed by its own `enable_streaming` argument.
    enable_streaming: bool = False
    # URL prefix under which the /events, /send and /ws routes are mounted
    # (a trailing "/" is stripped by the app factory).
    streaming_path_base: str = "/stream"
    # NOTE(review): not referenced by StreamingController; confirm its consumer.
    strict_types: bool = True
    # When a channel is opened without a sessionId, create a fresh ADK session
    # instead of rejecting the request with HTTP 400.
    create_session_on_open: bool = True
    # Idle seconds after which a channel with no subscribers and an empty
    # input queue is garbage-collected.
    ttl_seconds: int = 900
    # Capacity of each subscriber's outbound queue; a subscriber whose queue
    # is full when a payload is published is dropped from the channel.
    max_queue_size: int = 128
    # Maximum channels one user_id may hold open; <= 0 disables the limit.
    max_channels_per_user: int = 20
    # None disables heartbeat payloads. NOTE(review): the numeric interval is
    # not used for scheduling anywhere in this module.
    heartbeat_interval: Optional[float] = 20.0
    reuse_session_policy: str = "per_channel" # "per_channel" or "external"
23
+
24
+
25
@dataclass
class _Subscriber:
    """One consumer attached to a channel, identified by its outbound queue."""

    # Receives serialized JSON payloads (events, heartbeats, errors) pushed
    # by StreamingController.
    queue: "asyncio.Queue[str]"
    kind: str # "sse" | "ws"
29
+
30
+
31
@dataclass
class _Channel:
    """State for one streaming channel bound to an (app, user, session) triple."""

    channel_id: str
    app_name: str
    user_id: str
    session_id: str
    # Inbound run requests, consumed one at a time by the channel's worker.
    in_q: "asyncio.Queue[Any]" = field(default_factory=asyncio.Queue)
    # Current consumers; every published payload is pushed to each queue.
    subscribers: list[_Subscriber] = field(default_factory=list)
    # Background task running StreamingController._worker for this channel.
    worker_task: Optional[asyncio.Task] = None
    # Wall-clock creation time (time.time()).
    created_at: float = field(default_factory=lambda: time.time())
    # Wall-clock time of the last bind/subscribe/enqueue/event; drives idle GC.
    last_activity: float = field(default_factory=lambda: time.time())
    # NOTE(review): not referenced by StreamingController in this module.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
43
+
44
+
45
class StreamingController:
    """Manages streaming channels and their background workers.

    A channel binds a caller-chosen ``channel_id`` to an
    ``(app_name, user_id, session_id)`` triple. Each channel owns an inbound
    request queue and one worker task that runs those requests through the
    ADK ``Runner`` and pushes every resulting ``Event`` (serialized to JSON)
    onto the queue of each subscriber.
    """

    def __init__(
        self,
        *,
        config: StreamingConfig,
        get_runner_async: Callable[[str], Awaitable[Runner]],
        session_service,
    ) -> None:
        """Create a controller with no open channels.

        Args:
            config: Limits and behavior flags for the streaming layer.
            get_runner_async: Coroutine function returning the ``Runner`` for
                an app name.
            session_service: ADK session service used to create sessions
                (``create_session`` or ``create_session_sync``) and to look
                them up (``get_session``) when channels are opened.
        """
        self._config = config
        self._get_runner_async = get_runner_async
        self._session_service = session_service
        self._channels: Dict[str, _Channel] = {}
        self._gc_task: Optional[asyncio.Task] = None

    def start(self) -> None:
        """Start the idle-channel garbage collector if it is not running.

        Uses ``asyncio.create_task``, so it must be called while an event
        loop is running.
        """
        if self._gc_task is None:
            self._gc_task = asyncio.create_task(self._gc_loop())

    async def stop(self) -> None:
        """Cancel the garbage collector and all channel workers; forget all channels.

        The cancelled tasks are awaited, so any in-flight agent run has been
        closed before this coroutine returns. Safe to call more than once, and
        ``start()`` may be called again afterwards.
        """
        gc_task, self._gc_task = self._gc_task, None
        pending = [] if gc_task is None else [gc_task]
        for ch in self._channels.values():
            if ch.worker_task is not None and not ch.worker_task.done():
                pending.append(ch.worker_task)
        self._channels.clear()

        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True: a task cancelled before it ever ran
            # finishes with CancelledError, which must not abort shutdown.
            await asyncio.gather(*pending, return_exceptions=True)

    def _ensure_user_limit(self, user_id: str) -> None:
        """Raise HTTP 429 if ``user_id`` already holds the maximum number of channels.

        A non-positive ``max_channels_per_user`` disables the check.
        """
        if self._config.max_channels_per_user <= 0:
            return
        count = sum(1 for c in self._channels.values() if c.user_id == user_id)
        if count >= self._config.max_channels_per_user:
            raise HTTPException(status_code=429, detail="Too many channels for this user")

    @staticmethod
    def _bind_existing(
        ch: _Channel,
        *,
        app_name: str,
        user_id: str,
        session_id: Optional[str],
    ) -> _Channel:
        """Validate a bind request against an already-open channel and return it.

        Raises:
            HTTPException: 409 if app/user differ, or if a session id is given
                that differs from the channel's bound session.
        """
        if ch.app_name != app_name or ch.user_id != user_id:
            raise HTTPException(status_code=409, detail="Channel binding conflict")
        if session_id and session_id != ch.session_id:
            raise HTTPException(status_code=409, detail="Channel already bound to different session")
        ch.last_activity = time.time()
        return ch

    async def _resolve_session_id(
        self, app_name: str, user_id: str, session_id: Optional[str]
    ) -> str:
        """Return a usable session id for a new channel.

        An explicit ``session_id`` must exist in the session service (else
        HTTP 404). Without one, a fresh session is created when
        ``create_session_on_open`` is set (else HTTP 400).
        """
        if session_id:
            session = await self._session_service.get_session(
                app_name=app_name, user_id=user_id, session_id=session_id
            )
            if not session:
                raise HTTPException(status_code=404, detail="Session not found")
            return session_id

        if not self._config.create_session_on_open:
            raise HTTPException(status_code=400, detail="sessionId required for this channel")
        create = getattr(self._session_service, "create_session", None)
        if create is None:
            # Older ADK interfaces may expose a sync variant.
            create = getattr(self._session_service, "create_session_sync", None)
        if create is None:
            raise HTTPException(status_code=500, detail="Session service does not support create_session")
        if asyncio.iscoroutinefunction(create):
            session = await create(app_name=app_name, user_id=user_id)
        else:
            session = create(app_name=app_name, user_id=user_id)
        return session.id

    async def open_or_bind_channel(
        self,
        *,
        channel_id: str,
        app_name: str,
        user_id: str,
        session_id: Optional[str],
    ) -> _Channel:
        """Return the channel for ``channel_id``, opening it if necessary.

        An existing channel is validated against the requested binding. A new
        channel gets a validated or freshly created session and its own
        worker task.

        Raises:
            HTTPException: 409 on binding conflicts, 429 when the per-user
                channel limit is reached, 400/404/500 from session resolution.
        """
        existing = self._channels.get(channel_id)
        if existing is not None:
            return self._bind_existing(
                existing, app_name=app_name, user_id=user_id, session_id=session_id
            )

        self._ensure_user_limit(user_id)
        resolved_session_id = await self._resolve_session_id(app_name, user_id, session_id)

        # The await above yields to the event loop, so a concurrent call may
        # have opened this channel meanwhile. Bind to that one instead of
        # overwriting it, which would orphan its still-running worker task.
        # (A session created above for this call is then simply unused.)
        existing = self._channels.get(channel_id)
        if existing is not None:
            return self._bind_existing(
                existing, app_name=app_name, user_id=user_id, session_id=session_id
            )

        ch = _Channel(
            channel_id=channel_id,
            app_name=app_name,
            user_id=user_id,
            session_id=resolved_session_id,
            in_q=asyncio.Queue(),
        )
        self._channels[channel_id] = ch
        ch.worker_task = asyncio.create_task(self._worker(ch))
        return ch

    def subscribe(self, channel_id: str, kind: str) -> asyncio.Queue[str]:
        """Attach a new bounded outbound queue to ``channel_id`` and return it.

        Raises:
            HTTPException: 404 if the channel is not open.
        """
        if channel_id not in self._channels:
            raise HTTPException(status_code=404, detail="Channel not found")
        q: asyncio.Queue[str] = asyncio.Queue(maxsize=self._config.max_queue_size)
        self._channels[channel_id].subscribers.append(_Subscriber(queue=q, kind=kind))
        self._channels[channel_id].last_activity = time.time()
        return q

    def unsubscribe(self, channel_id: str, q: asyncio.Queue[str]) -> None:
        """Detach queue ``q`` from ``channel_id``; a no-op if the channel is gone."""
        ch = self._channels.get(channel_id)
        if not ch:
            return
        ch.subscribers = [s for s in ch.subscribers if s.queue is not q]
        ch.last_activity = time.time()

    async def enqueue(self, channel_id: str, req: Any) -> None:
        """Queue a run request for the channel's worker.

        ``req`` must carry ``app_name``, ``user_id`` and ``session_id`` equal to
        the channel's binding.

        Raises:
            HTTPException: 404 if the channel is not open, 409 on a binding
                mismatch.
        """
        ch = self._channels.get(channel_id)
        if not ch:
            raise HTTPException(status_code=404, detail="Channel not found")
        if (
            getattr(req, "app_name", None) != ch.app_name
            or getattr(req, "user_id", None) != ch.user_id
            or getattr(req, "session_id", None) != ch.session_id
        ):
            raise HTTPException(status_code=409, detail="Request does not match channel binding")
        await ch.in_q.put(req)
        ch.last_activity = time.time()

    async def _worker(self, ch: _Channel) -> None:
        """Run queued requests for ``ch`` sequentially, broadcasting every event.

        A failing run is reported to subscribers as an error payload and the
        worker continues with the next request; cancellation ends the worker.
        """
        try:
            while True:
                req = await ch.in_q.get()
                ch.last_activity = time.time()
                try:
                    runner = await self._get_runner_async(ch.app_name)
                    # _aclosing guarantees the run's async generator is closed
                    # even when this task is cancelled mid-stream.
                    async with _aclosing(
                        runner.run_async(
                            user_id=ch.user_id,
                            session_id=ch.session_id,
                            new_message=req.new_message,
                            state_delta=getattr(req, "state_delta", None),
                            run_config=_maybe_run_config_streaming(True),
                        )
                    ) as agen:
                        async for event in agen:
                            await self._broadcast_event(ch, event)
                except Exception as e:  # pragma: no cover - safety
                    await self._broadcast_error(ch, str(e))
        except asyncio.CancelledError:  # worker shutdown
            return

    def _publish(self, ch: _Channel, payload: str) -> None:
        """Push ``payload`` to every subscriber of ``ch`` without blocking.

        A subscriber whose queue is full is removed, so one slow consumer
        cannot stall the channel's worker.
        """
        # NOTE(review): a dropped subscriber's consumer is not notified and
        # keeps awaiting its queue; consider signalling it.
        for sub in list(ch.subscribers):
            try:
                sub.queue.put_nowait(payload)
            except asyncio.QueueFull:
                ch.subscribers = [s for s in ch.subscribers if s is not sub]

    async def _broadcast_event(self, ch: _Channel, event: Event) -> None:
        """Publish ``event`` as JSON (aliases, no null fields); counts as activity."""
        self._publish(ch, event.model_dump_json(exclude_none=True, by_alias=True))
        ch.last_activity = time.time()

    async def _broadcast_heartbeat(self, ch: _Channel) -> None:
        """Publish a heartbeat payload unless heartbeats are disabled (``None``)."""
        if self._config.heartbeat_interval is None:
            return
        self._publish(ch, '{"event":"heartbeat"}')

    async def _broadcast_error(self, ch: _Channel, message: str) -> None:
        """Publish ``{"error": <message>}`` to every subscriber."""
        self._publish(ch, '{"error": %s}' % _json_escape(message))

    async def _gc_loop(self) -> None:
        """Periodically drop idle channels and cancel their workers.

        A channel is collected once it has been idle for ``ttl_seconds``, has
        no subscribers and has no queued requests. The sweep interval is a
        third of the TTL, clamped to 1..10 seconds.
        """
        try:
            while True:
                await asyncio.sleep(min(10, max(1, int(self._config.ttl_seconds / 3))))
                now = time.time()
                for cid, ch in list(self._channels.items()):
                    idle = now - ch.last_activity
                    if idle >= self._config.ttl_seconds and not ch.subscribers and ch.in_q.empty():
                        if ch.worker_task and not ch.worker_task.done():
                            ch.worker_task.cancel()
                        self._channels.pop(cid, None)
        except asyncio.CancelledError:
            return
233
+
234
+ # Utilities (avoid importing optional internals at module import time)
235
def _maybe_run_config_streaming(enabled: bool):
    """Build an ADK ``RunConfig`` requesting SSE streaming, or no streaming.

    ``RunConfig`` and ``StreamingMode`` are imported lazily and from whichever
    module the installed ADK version provides them in.
    """
    try:
        from google.adk.runners import RunConfig  # type: ignore
    except Exception:  # pragma: no cover - version fallback
        from google.adk.agents.run_config import RunConfig  # type: ignore

    try:
        from google.adk.agents.run_config import StreamingMode  # type: ignore
    except Exception:  # pragma: no cover - defensive
        # Minimal stand-in exposing only the two members consulted below.
        StreamingMode = type("StreamingMode", (), {"SSE": "sse", "NONE": None})

    if enabled:
        mode = StreamingMode.SSE
    else:
        mode = StreamingMode.NONE
    return RunConfig(streaming_mode=mode)
247
+
248
+
249
+ class _aclosing:
250
+ def __init__(self, agen):
251
+ self._agen = agen
252
+ async def __aenter__(self):
253
+ return self._agen
254
+ async def __aexit__(self, exc_type, exc, tb):
255
+ try:
256
+ await self._agen.aclose()
257
+ except Exception:
258
+ pass
259
+
260
+
261
+ def _json_escape(s: str) -> str:
262
+ return '"' + s.replace('\\', '\\\\').replace('"', '\\"') + '"'
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-adk-extras
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: Production-ready services, credentials, and FastAPI wiring for Google ADK
5
5
  Home-page: https://github.com/DeadMeme5441/google-adk-extras
6
6
  Author: DeadMeme5441
@@ -1,8 +1,8 @@
1
- google_adk_extras/__init__.py,sha256=D0P8SjXVhABGaVYLQmkP9LlXRMLlsRfyhWZ4XsiWaOs,830
1
+ google_adk_extras/__init__.py,sha256=cIaqZH7E__SmojqMjCqOA4FV6S6-k1pH9_mGCfkCfU4,830
2
2
  google_adk_extras/adk_builder.py,sha256=anTpd-UYtPugRSqdYNmWa_uesVjHGEKO2mntkcU-J6g,41179
3
3
  google_adk_extras/custom_agent_loader.py,sha256=6rQyBTAxvuwFPs3QGt8hWYZKsQVpbI4WIy6OS73tsac,6001
4
4
  google_adk_extras/enhanced_adk_web_server.py,sha256=rML_m4Um9QCcnnlHvCOyLGjLHDcyladEunX-x4JX43Q,5414
5
- google_adk_extras/enhanced_fastapi.py,sha256=qMAawDdduVp3DMZtFJq1f1l6tD8o5_lyHf3Vf4k5JyM,23790
5
+ google_adk_extras/enhanced_fastapi.py,sha256=TzIr159B0tnEqj2dXKmV_tJvkA66Fq2I-kL5cz9VGSk,28304
6
6
  google_adk_extras/enhanced_runner.py,sha256=b7O1a9-4S49LduILOEDs6IxjCI4w_E39sc-Hs4y3Rys,1410
7
7
  google_adk_extras/artifacts/__init__.py,sha256=_IsKDgf6wanWR0HXvSpK9SiLa3n5URKLtazkKyH1P-o,931
8
8
  google_adk_extras/artifacts/base_custom_artifact_service.py,sha256=O9rkc250B3yDRYbyDI0EvTrCKvnih5_DQas5OF-hRMY,9721
@@ -30,8 +30,10 @@ google_adk_extras/sessions/mongo_session_service.py,sha256=r3jZ3PmDpbZ0veNzXzuAj
30
30
  google_adk_extras/sessions/redis_session_service.py,sha256=yyfXZozeFWJ2S_kz7zXqz--f_ymE6HMMpT3MhcpFXIE,10434
31
31
  google_adk_extras/sessions/sql_session_service.py,sha256=TaOeEVWnwQ_8nvDZBW7e3qhzR_ecuGsjvZ_kh6Guq8g,14558
32
32
  google_adk_extras/sessions/yaml_file_session_service.py,sha256=SpTh8YHIALcoxzmturhcZ4ReHKQrJI1CxoYiJQ-baRc,11819
33
- google_adk_extras-0.2.5.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
34
- google_adk_extras-0.2.5.dist-info/METADATA,sha256=QmkrJkhlsJQb09By4L-NlRFxfU9P_WfjJJIoo1YwCq8,10707
35
- google_adk_extras-0.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
36
- google_adk_extras-0.2.5.dist-info/top_level.txt,sha256=DDWgVkz8G8ihPzznxAWyKa2jgJW3F6Fwy__qMddoKTs,18
37
- google_adk_extras-0.2.5.dist-info/RECORD,,
33
+ google_adk_extras/streaming/__init__.py,sha256=rcjmlCJHTlvUiCrx6qNGw5ObCnEtfENkGTvzfEiGL0M,461
34
+ google_adk_extras/streaming/streaming_controller.py,sha256=Qxg4yqakd3mBPJ9sISe4gO_pvBpYttvSq5SIK4W5ZPQ,10505
35
+ google_adk_extras-0.2.6.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
36
+ google_adk_extras-0.2.6.dist-info/METADATA,sha256=NXbByeiIsdHl2UcZ5F-i22thmm-rFcTr1poTFdVY5cI,10707
37
+ google_adk_extras-0.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
38
+ google_adk_extras-0.2.6.dist-info/top_level.txt,sha256=DDWgVkz8G8ihPzznxAWyKa2jgJW3F6Fwy__qMddoKTs,18
39
+ google_adk_extras-0.2.6.dist-info/RECORD,,