remdb 0.3.202__py3-none-any.whl → 0.3.245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of remdb might be problematic; consult the registry's advisory page for more details.

Files changed (44)
  1. rem/agentic/README.md +36 -2
  2. rem/agentic/context.py +86 -3
  3. rem/agentic/context_builder.py +39 -33
  4. rem/agentic/mcp/tool_wrapper.py +2 -2
  5. rem/agentic/providers/pydantic_ai.py +68 -51
  6. rem/agentic/schema.py +2 -2
  7. rem/api/mcp_router/resources.py +223 -0
  8. rem/api/mcp_router/tools.py +170 -18
  9. rem/api/routers/admin.py +30 -4
  10. rem/api/routers/auth.py +175 -18
  11. rem/api/routers/chat/child_streaming.py +394 -0
  12. rem/api/routers/chat/completions.py +24 -29
  13. rem/api/routers/chat/sse_events.py +5 -1
  14. rem/api/routers/chat/streaming.py +242 -272
  15. rem/api/routers/chat/streaming_utils.py +327 -0
  16. rem/api/routers/common.py +18 -0
  17. rem/api/routers/dev.py +7 -1
  18. rem/api/routers/feedback.py +9 -1
  19. rem/api/routers/messages.py +80 -15
  20. rem/api/routers/models.py +9 -1
  21. rem/api/routers/query.py +17 -15
  22. rem/api/routers/shared_sessions.py +16 -0
  23. rem/cli/commands/ask.py +205 -114
  24. rem/cli/commands/process.py +12 -4
  25. rem/cli/commands/query.py +109 -0
  26. rem/cli/commands/session.py +117 -0
  27. rem/cli/main.py +2 -0
  28. rem/models/entities/session.py +1 -0
  29. rem/schemas/agents/rem.yaml +1 -1
  30. rem/services/postgres/repository.py +7 -7
  31. rem/services/rem/service.py +47 -0
  32. rem/services/session/__init__.py +2 -1
  33. rem/services/session/compression.py +14 -12
  34. rem/services/session/pydantic_messages.py +111 -11
  35. rem/services/session/reload.py +2 -1
  36. rem/settings.py +71 -0
  37. rem/sql/migrations/001_install.sql +4 -4
  38. rem/sql/migrations/004_cache_system.sql +3 -1
  39. rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
  40. rem/utils/schema_loader.py +139 -111
  41. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/METADATA +2 -2
  42. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/RECORD +44 -39
  43. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/WHEEL +0 -0
  44. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/entry_points.txt +0 -0
rem/api/routers/auth.py CHANGED
@@ -3,11 +3,12 @@ Authentication Router.
3
3
 
4
4
  Supports multiple authentication methods:
5
5
  1. Email (passwordless): POST /api/auth/email/send-code, POST /api/auth/email/verify
6
- 2. OAuth (Google, Microsoft): GET /api/auth/{provider}/login, GET /api/auth/{provider}/callback
6
+ 2. Pre-approved codes: POST /api/auth/email/verify (with pre-approved code, no send-code needed)
7
+ 3. OAuth (Google, Microsoft): GET /api/auth/{provider}/login, GET /api/auth/{provider}/callback
7
8
 
8
9
  Endpoints:
9
10
  - POST /api/auth/email/send-code - Send login code to email
10
- - POST /api/auth/email/verify - Verify code and create session
11
+ - POST /api/auth/email/verify - Verify code and create session (supports pre-approved codes)
11
12
  - GET /api/auth/{provider}/login - Initiate OAuth flow
12
13
  - GET /api/auth/{provider}/callback - OAuth callback
13
14
  - POST /api/auth/logout - Clear session
@@ -15,9 +16,39 @@ Endpoints:
15
16
 
16
17
  Supported providers:
17
18
  - email: Passwordless email login
19
+ - preapproved: Pre-approved codes (bypass email, set via AUTH__PREAPPROVED_CODES)
18
20
  - google: Google OAuth 2.0 / OIDC
19
21
  - microsoft: Microsoft Entra ID OIDC
20
22
 
23
+ =============================================================================
24
+ Pre-Approved Code Authentication
25
+ =============================================================================
26
+
27
+ Pre-approved codes allow login without email verification. Useful for:
28
+ - Demo accounts
29
+ - Testing
30
+ - Beta access codes
31
+ - Admin provisioning
32
+
33
+ Configuration:
34
+ AUTH__PREAPPROVED_CODES=A12345,A67890,B11111,B22222
35
+
36
+ Code prefixes:
37
+ A = Admin role (e.g., A12345, AADMIN1)
38
+ B = Normal user role (e.g., B11111, BUSER1)
39
+
40
+ Flow:
41
+ 1. User enters email + pre-approved code (no send-code step needed)
42
+ 2. POST /api/auth/email/verify with email and code
43
+ 3. System validates code against AUTH__PREAPPROVED_CODES
44
+ 4. Creates user if not exists, sets role based on prefix
45
+ 5. Returns JWT tokens (same as email auth)
46
+
47
+ Example:
48
+ curl -X POST http://localhost:8000/api/auth/email/verify \
49
+ -H "Content-Type: application/json" \
50
+ -d '{"email": "admin@example.com", "code": "A12345"}'
51
+
21
52
  =============================================================================
22
53
  Email Authentication Access Control
23
54
  =============================================================================
@@ -101,6 +132,8 @@ from authlib.integrations.starlette_client import OAuth
101
132
  from pydantic import BaseModel, EmailStr
102
133
  from loguru import logger
103
134
 
135
+ from .common import ErrorResponse
136
+
104
137
  from ...settings import settings
105
138
  from ...services.postgres.service import PostgresService
106
139
  from ...services.user_service import UserService
@@ -159,7 +192,14 @@ class EmailVerifyRequest(BaseModel):
159
192
  code: str
160
193
 
161
194
 
162
- @router.post("/email/send-code")
195
+ @router.post(
196
+ "/email/send-code",
197
+ responses={
198
+ 400: {"model": ErrorResponse, "description": "Invalid request or email rejected"},
199
+ 500: {"model": ErrorResponse, "description": "Failed to send login code"},
200
+ 501: {"model": ErrorResponse, "description": "Email auth or database not configured"},
201
+ },
202
+ )
163
203
  async def send_email_code(request: Request, body: EmailSendCodeRequest):
164
204
  """
165
205
  Send a login code to an email address.
@@ -221,11 +261,24 @@ async def send_email_code(request: Request, body: EmailSendCodeRequest):
221
261
  await db.disconnect()
222
262
 
223
263
 
224
- @router.post("/email/verify")
264
+ @router.post(
265
+ "/email/verify",
266
+ responses={
267
+ 400: {"model": ErrorResponse, "description": "Invalid or expired code"},
268
+ 500: {"model": ErrorResponse, "description": "Failed to verify login code"},
269
+ 501: {"model": ErrorResponse, "description": "Email auth or database not configured"},
270
+ },
271
+ )
225
272
  async def verify_email_code(request: Request, body: EmailVerifyRequest):
226
273
  """
227
274
  Verify login code and create session with JWT tokens.
228
275
 
276
+ Supports two authentication methods:
277
+ 1. Pre-approved codes: Codes from AUTH__PREAPPROVED_CODES bypass email verification.
278
+ - A prefix = admin role, B prefix = normal user role
279
+ - Creates user if not exists, logs in directly
280
+ 2. Email verification: Standard 6-digit code sent via email
281
+
229
282
  Args:
230
283
  request: FastAPI request
231
284
  body: EmailVerifyRequest with email and code
@@ -233,12 +286,6 @@ async def verify_email_code(request: Request, body: EmailVerifyRequest):
233
286
  Returns:
234
287
  Success status with user info and JWT tokens
235
288
  """
236
- if not settings.email.is_configured:
237
- raise HTTPException(
238
- status_code=501,
239
- detail="Email authentication is not configured"
240
- )
241
-
242
289
  if not settings.postgres.enabled:
243
290
  raise HTTPException(
244
291
  status_code=501,
@@ -248,6 +295,79 @@ async def verify_email_code(request: Request, body: EmailVerifyRequest):
248
295
  db = PostgresService()
249
296
  try:
250
297
  await db.connect()
298
+ user_service = UserService(db)
299
+
300
+ # Check for pre-approved code first
301
+ preapproved = settings.auth.check_preapproved_code(body.code)
302
+ if preapproved:
303
+ logger.info(f"Pre-approved code login attempt for {body.email} (role: {preapproved['role']})")
304
+
305
+ # Get or create user with pre-approved role
306
+ user_id = email_to_user_id(body.email)
307
+ user_entity = await user_service.get_user_by_id(user_id)
308
+
309
+ if not user_entity:
310
+ # Create new user with role from pre-approved code
311
+ user_entity = await user_service.get_or_create_user(
312
+ email=body.email,
313
+ name=body.email.split("@")[0],
314
+ tenant_id="default",
315
+ )
316
+ # Update role based on pre-approved code prefix
317
+ user_entity.role = preapproved["role"]
318
+ from ...services.postgres.repository import Repository
319
+ from ...models.entities.user import User
320
+ user_repo = Repository(User, "users", db=db)
321
+ await user_repo.upsert(user_entity)
322
+ logger.info(f"Created user {body.email} with role={preapproved['role']} via pre-approved code")
323
+ else:
324
+ # Update existing user's role if admin code used
325
+ if preapproved["role"] == "admin" and user_entity.role != "admin":
326
+ user_entity.role = "admin"
327
+ from ...services.postgres.repository import Repository
328
+ from ...models.entities.user import User
329
+ user_repo = Repository(User, "users", db=db)
330
+ await user_repo.upsert(user_entity)
331
+ logger.info(f"Upgraded user {body.email} to admin via pre-approved code")
332
+
333
+ # Build user dict for session/JWT
334
+ user_dict = {
335
+ "id": str(user_entity.id),
336
+ "email": body.email,
337
+ "email_verified": True,
338
+ "name": user_entity.name or body.email.split("@")[0],
339
+ "provider": "preapproved",
340
+ "tenant_id": user_entity.tenant_id or "default",
341
+ "tier": user_entity.tier.value if user_entity.tier else "free",
342
+ "role": user_entity.role or preapproved["role"],
343
+ "roles": [user_entity.role or preapproved["role"]],
344
+ }
345
+
346
+ # Generate JWT tokens
347
+ jwt_service = get_jwt_service()
348
+ tokens = jwt_service.create_tokens(user_dict)
349
+
350
+ # Store user in session
351
+ request.session["user"] = user_dict
352
+
353
+ logger.info(f"User authenticated via pre-approved code: {body.email} (role: {user_dict['role']})")
354
+
355
+ return {
356
+ "success": True,
357
+ "message": "Successfully authenticated with pre-approved code!",
358
+ "user": user_dict,
359
+ "access_token": tokens["access_token"],
360
+ "refresh_token": tokens["refresh_token"],
361
+ "token_type": tokens["token_type"],
362
+ "expires_in": tokens["expires_in"],
363
+ }
364
+
365
+ # Standard email verification flow
366
+ if not settings.email.is_configured:
367
+ raise HTTPException(
368
+ status_code=501,
369
+ detail="Email authentication is not configured"
370
+ )
251
371
 
252
372
  # Initialize email auth provider
253
373
  email_auth = EmailAuthProvider()
@@ -272,7 +392,6 @@ async def verify_email_code(request: Request, body: EmailVerifyRequest):
272
392
  )
273
393
 
274
394
  # Fetch actual user data from database to get role/tier
275
- user_service = UserService(db)
276
395
  try:
277
396
  user_entity = await user_service.get_user_by_id(result.user_id)
278
397
  if user_entity:
@@ -319,7 +438,13 @@ async def verify_email_code(request: Request, body: EmailVerifyRequest):
319
438
  # =============================================================================
320
439
 
321
440
 
322
- @router.get("/{provider}/login")
441
+ @router.get(
442
+ "/{provider}/login",
443
+ responses={
444
+ 400: {"model": ErrorResponse, "description": "Unknown OAuth provider"},
445
+ 501: {"model": ErrorResponse, "description": "Authentication is disabled"},
446
+ },
447
+ )
323
448
  async def login(provider: str, request: Request):
324
449
  """
325
450
  Initiate OAuth flow with provider.
@@ -361,7 +486,13 @@ async def login(provider: str, request: Request):
361
486
  return await client.authorize_redirect(request, redirect_uri)
362
487
 
363
488
 
364
- @router.get("/{provider}/callback")
489
+ @router.get(
490
+ "/{provider}/callback",
491
+ responses={
492
+ 400: {"model": ErrorResponse, "description": "Authentication failed or unknown provider"},
493
+ 501: {"model": ErrorResponse, "description": "Authentication is disabled"},
494
+ },
495
+ )
365
496
  async def callback(provider: str, request: Request):
366
497
  """
367
498
  OAuth callback endpoint.
@@ -498,7 +629,12 @@ async def logout(request: Request):
498
629
  return {"message": "Logged out successfully"}
499
630
 
500
631
 
501
- @router.get("/me")
632
+ @router.get(
633
+ "/me",
634
+ responses={
635
+ 401: {"model": ErrorResponse, "description": "Not authenticated"},
636
+ },
637
+ )
502
638
  async def me(request: Request):
503
639
  """
504
640
  Get current user information from session or JWT.
@@ -536,7 +672,12 @@ class TokenRefreshRequest(BaseModel):
536
672
  refresh_token: str
537
673
 
538
674
 
539
- @router.post("/token/refresh")
675
+ @router.post(
676
+ "/token/refresh",
677
+ responses={
678
+ 401: {"model": ErrorResponse, "description": "Invalid or expired refresh token"},
679
+ },
680
+ )
540
681
  async def refresh_token(body: TokenRefreshRequest):
541
682
  """
542
683
  Refresh access token using refresh token.
@@ -601,7 +742,12 @@ async def refresh_token(body: TokenRefreshRequest):
601
742
  return result
602
743
 
603
744
 
604
- @router.post("/token/verify")
745
+ @router.post(
746
+ "/token/verify",
747
+ responses={
748
+ 401: {"model": ErrorResponse, "description": "Missing, invalid, or expired token"},
749
+ },
750
+ )
605
751
  async def verify_token(request: Request):
606
752
  """
607
753
  Verify an access token is valid.
@@ -665,7 +811,12 @@ def verify_dev_token(token: str) -> bool:
665
811
  return token == expected
666
812
 
667
813
 
668
- @router.get("/dev/token")
814
+ @router.get(
815
+ "/dev/token",
816
+ responses={
817
+ 401: {"model": ErrorResponse, "description": "Dev tokens not available in production"},
818
+ },
819
+ )
669
820
  async def get_dev_token(request: Request):
670
821
  """
671
822
  Get a development token for testing (non-production only).
@@ -701,7 +852,13 @@ async def get_dev_token(request: Request):
701
852
  }
702
853
 
703
854
 
704
- @router.get("/dev/mock-code/{email}")
855
+ @router.get(
856
+ "/dev/mock-code/{email}",
857
+ responses={
858
+ 401: {"model": ErrorResponse, "description": "Mock codes not available in production"},
859
+ 404: {"model": ErrorResponse, "description": "No code found for email"},
860
+ },
861
+ )
705
862
  async def get_mock_code(email: str, request: Request):
706
863
  """
707
864
  Get the mock login code for testing (non-production only).
@@ -0,0 +1,394 @@
1
+ """
2
+ Child Agent Event Handling.
3
+
4
+ Handles events from child agents during multi-agent orchestration.
5
+
6
+ Event Flow:
7
+ ```
8
+ Parent Agent (Siggy)
9
+
10
+
11
+ ask_agent tool
12
+
13
+ ├──────────────────────────────────┐
14
+ ▼ │
15
+ Child Agent (intake_diverge) │
16
+ │ │
17
+ ├── child_tool_start ──────────────┼──► Event Sink (Queue)
18
+ ├── child_content ─────────────────┤
19
+ └── child_tool_result ─────────────┘
20
+
21
+
22
+ drain_child_events()
23
+
24
+ ├── SSE to client
25
+ └── DB persistence
26
+ ```
27
+
28
+ IMPORTANT: When child_content is streamed, parent text output should be SKIPPED
29
+ to prevent content duplication.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import asyncio
35
+ import json
36
+ import uuid
37
+ from typing import TYPE_CHECKING, Any, AsyncGenerator
38
+
39
+ from loguru import logger
40
+
41
+ from .streaming_utils import StreamingState, build_content_chunk
42
+ from .sse_events import MetadataEvent, ToolCallEvent, format_sse_event
43
+ from ....services.session import SessionMessageStore
44
+ from ....settings import settings
45
+ from ....utils.date_utils import to_iso, utc_now
46
+
47
+ if TYPE_CHECKING:
48
+ from ....agentic.context import AgentContext
49
+
50
+
51
+ async def handle_child_tool_start(
52
+ state: StreamingState,
53
+ child_agent: str,
54
+ tool_name: str,
55
+ arguments: dict | str | None,
56
+ session_id: str | None,
57
+ user_id: str | None,
58
+ ) -> AsyncGenerator[str, None]:
59
+ """
60
+ Handle child_tool_start event.
61
+
62
+ Actions:
63
+ 1. Log the tool call
64
+ 2. Emit SSE event
65
+ 3. Save to database (with tool_arguments in metadata for consistency with parent)
66
+ """
67
+ full_tool_name = f"{child_agent}:{tool_name}"
68
+ tool_id = f"call_{uuid.uuid4().hex[:8]}"
69
+
70
+ # Normalize arguments - may come as JSON string from ToolCallPart.args
71
+ if isinstance(arguments, str):
72
+ try:
73
+ arguments = json.loads(arguments)
74
+ except json.JSONDecodeError:
75
+ arguments = None
76
+ elif not isinstance(arguments, dict):
77
+ arguments = None
78
+
79
+ # 1. LOG
80
+ logger.info(f"🔧 {full_tool_name}")
81
+
82
+ # 2. EMIT SSE
83
+ yield format_sse_event(ToolCallEvent(
84
+ tool_name=full_tool_name,
85
+ tool_id=tool_id,
86
+ status="started",
87
+ arguments=arguments,
88
+ ))
89
+
90
+ # 3. SAVE TO DB - content contains args as JSON (pydantic_messages.py parses it)
91
+ if session_id and settings.postgres.enabled:
92
+ try:
93
+ store = SessionMessageStore(
94
+ user_id=user_id or settings.test.effective_user_id
95
+ )
96
+ tool_msg = {
97
+ "role": "tool",
98
+ # Content is the tool call args as JSON - this is what the agent sees on reload
99
+ # and what pydantic_messages.py parses for ToolCallPart.args
100
+ "content": json.dumps(arguments) if arguments else "",
101
+ "timestamp": to_iso(utc_now()),
102
+ "tool_call_id": tool_id,
103
+ "tool_name": full_tool_name,
104
+ }
105
+ await store.store_session_messages(
106
+ session_id=session_id,
107
+ messages=[tool_msg],
108
+ user_id=user_id,
109
+ compress=False,
110
+ )
111
+ except Exception as e:
112
+ logger.warning(f"Failed to save child tool call: {e}")
113
+
114
+
115
+ def handle_child_content(
116
+ state: StreamingState,
117
+ child_agent: str,
118
+ content: str,
119
+ ) -> str | None:
120
+ """
121
+ Handle child_content event.
122
+
123
+ CRITICAL: Sets state.child_content_streamed = True
124
+ This flag is used to skip parent text output and prevent duplication.
125
+
126
+ Returns:
127
+ SSE chunk or None if content is empty
128
+ """
129
+ if not content:
130
+ return None
131
+
132
+ # Track that child content was streamed
133
+ # Parent text output should be SKIPPED when this is True
134
+ state.child_content_streamed = True
135
+ state.responding_agent = child_agent
136
+
137
+ return build_content_chunk(state, content)
138
+
139
+
140
+ async def handle_child_tool_result(
141
+ state: StreamingState,
142
+ child_agent: str,
143
+ result: Any,
144
+ message_id: str | None,
145
+ session_id: str | None,
146
+ agent_schema: str | None,
147
+ ) -> AsyncGenerator[str, None]:
148
+ """
149
+ Handle child_tool_result event.
150
+
151
+ Actions:
152
+ 1. Log metadata if present
153
+ 2. Emit metadata event if present
154
+ 3. Emit tool completion event
155
+ """
156
+ # Check for metadata registration
157
+ if isinstance(result, dict) and result.get("_metadata_event"):
158
+ risk = result.get("risk_level", "")
159
+ conf = result.get("confidence", "")
160
+ logger.info(f"📊 {child_agent} metadata: risk={risk}, confidence={conf}")
161
+
162
+ # Update responding agent from child
163
+ if result.get("agent_schema"):
164
+ state.responding_agent = result.get("agent_schema")
165
+
166
+ # Build extra dict with risk fields
167
+ extra_data = {}
168
+ if risk:
169
+ extra_data["risk_level"] = risk
170
+
171
+ yield format_sse_event(MetadataEvent(
172
+ message_id=message_id,
173
+ session_id=session_id,
174
+ agent_schema=agent_schema,
175
+ responding_agent=state.responding_agent,
176
+ confidence=result.get("confidence"),
177
+ extra=extra_data if extra_data else None,
178
+ ))
179
+
180
+ # Emit tool completion
181
+ # Preserve full result dict if it contains an artifact (e.g. finalize_intake)
182
+ # This is needed for frontend to extract artifact URLs for download
183
+ if isinstance(result, dict) and result.get("artifact"):
184
+ result_for_sse = result # Full dict with artifact
185
+ else:
186
+ result_for_sse = str(result)[:200] if result else None
187
+
188
+ yield format_sse_event(ToolCallEvent(
189
+ tool_name=f"{child_agent}:tool",
190
+ tool_id=f"call_{uuid.uuid4().hex[:8]}",
191
+ status="completed",
192
+ result=result_for_sse,
193
+ ))
194
+
195
+
196
+ async def drain_child_events(
197
+ event_sink: asyncio.Queue,
198
+ state: StreamingState,
199
+ session_id: str | None = None,
200
+ user_id: str | None = None,
201
+ message_id: str | None = None,
202
+ agent_schema: str | None = None,
203
+ ) -> AsyncGenerator[str, None]:
204
+ """
205
+ Drain all pending child events from the event sink.
206
+
207
+ This is called during tool execution to process events
208
+ pushed by child agents via ask_agent.
209
+
210
+ IMPORTANT: When child_content events are processed, this sets
211
+ state.child_content_streamed = True. Callers should check this
212
+ flag and skip parent text output to prevent duplication.
213
+ """
214
+ while not event_sink.empty():
215
+ try:
216
+ child_event = event_sink.get_nowait()
217
+ async for chunk in process_child_event(
218
+ child_event, state, session_id, user_id, message_id, agent_schema
219
+ ):
220
+ yield chunk
221
+ except Exception as e:
222
+ logger.warning(f"Error processing child event: {e}")
223
+
224
+
225
+ async def process_child_event(
226
+ child_event: dict,
227
+ state: StreamingState,
228
+ session_id: str | None = None,
229
+ user_id: str | None = None,
230
+ message_id: str | None = None,
231
+ agent_schema: str | None = None,
232
+ ) -> AsyncGenerator[str, None]:
233
+ """Process a single child event and yield SSE chunks."""
234
+ event_type = child_event.get("type", "")
235
+ child_agent = child_event.get("agent_name", "child")
236
+
237
+ if event_type == "child_tool_start":
238
+ async for chunk in handle_child_tool_start(
239
+ state=state,
240
+ child_agent=child_agent,
241
+ tool_name=child_event.get("tool_name", "tool"),
242
+ arguments=child_event.get("arguments"),
243
+ session_id=session_id,
244
+ user_id=user_id,
245
+ ):
246
+ yield chunk
247
+
248
+ elif event_type == "child_content":
249
+ chunk = handle_child_content(
250
+ state=state,
251
+ child_agent=child_agent,
252
+ content=child_event.get("content", ""),
253
+ )
254
+ if chunk:
255
+ yield chunk
256
+
257
+ elif event_type == "child_tool_result":
258
+ async for chunk in handle_child_tool_result(
259
+ state=state,
260
+ child_agent=child_agent,
261
+ result=child_event.get("result"),
262
+ message_id=message_id,
263
+ session_id=session_id,
264
+ agent_schema=agent_schema,
265
+ ):
266
+ yield chunk
267
+
268
+
269
+ async def stream_with_child_events(
270
+ tools_stream,
271
+ child_event_sink: asyncio.Queue,
272
+ state: StreamingState,
273
+ session_id: str | None = None,
274
+ user_id: str | None = None,
275
+ message_id: str | None = None,
276
+ agent_schema: str | None = None,
277
+ ) -> AsyncGenerator[tuple[str, Any], None]:
278
+ """
279
+ Multiplex tool events with child events using asyncio.wait().
280
+
281
+ This is the key fix for child agent streaming - instead of draining
282
+ the queue synchronously during tool event iteration, we concurrently
283
+ listen to both sources and yield events as they arrive.
284
+
285
+ Yields:
286
+ Tuples of (event_type, event_data) where event_type is either
287
+ "tool" or "child", allowing the caller to handle each appropriately.
288
+ """
289
+ tool_iter = tools_stream.__aiter__()
290
+
291
+ # Create initial tasks
292
+ pending_tool: asyncio.Task | None = None
293
+ pending_child: asyncio.Task | None = None
294
+
295
+ try:
296
+ pending_tool = asyncio.create_task(tool_iter.__anext__())
297
+ except StopAsyncIteration:
298
+ # No tool events, just drain any remaining child events
299
+ while not child_event_sink.empty():
300
+ try:
301
+ child_event = child_event_sink.get_nowait()
302
+ yield ("child", child_event)
303
+ except asyncio.QueueEmpty:
304
+ break
305
+ return
306
+
307
+ # Start listening for child events with a short timeout
308
+ pending_child = asyncio.create_task(
309
+ _get_child_event_with_timeout(child_event_sink, timeout=0.05)
310
+ )
311
+
312
+ try:
313
+ while True:
314
+ # Wait for either source to produce an event
315
+ tasks = {t for t in [pending_tool, pending_child] if t is not None}
316
+ if not tasks:
317
+ break
318
+
319
+ done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
320
+
321
+ for task in done:
322
+ try:
323
+ result = task.result()
324
+ except asyncio.TimeoutError:
325
+ # Child queue timeout - restart listener
326
+ if task is pending_child:
327
+ pending_child = asyncio.create_task(
328
+ _get_child_event_with_timeout(child_event_sink, timeout=0.05)
329
+ )
330
+ continue
331
+ except StopAsyncIteration:
332
+ # Tool stream exhausted
333
+ if task is pending_tool:
334
+ pending_tool = None
335
+ # Final drain of any remaining child events
336
+ if pending_child:
337
+ pending_child.cancel()
338
+ try:
339
+ await pending_child
340
+ except asyncio.CancelledError:
341
+ pass
342
+ while not child_event_sink.empty():
343
+ try:
344
+ child_event = child_event_sink.get_nowait()
345
+ yield ("child", child_event)
346
+ except asyncio.QueueEmpty:
347
+ break
348
+ return
349
+ continue
350
+
351
+ if task is pending_child and result is not None:
352
+ # Got a child event
353
+ yield ("child", result)
354
+ # Restart child listener
355
+ pending_child = asyncio.create_task(
356
+ _get_child_event_with_timeout(child_event_sink, timeout=0.05)
357
+ )
358
+ elif task is pending_tool:
359
+ # Got a tool event
360
+ yield ("tool", result)
361
+ # Get next tool event
362
+ try:
363
+ pending_tool = asyncio.create_task(tool_iter.__anext__())
364
+ except StopAsyncIteration:
365
+ pending_tool = None
366
+ elif task is pending_child and result is None:
367
+ # Timeout with no event - restart listener
368
+ pending_child = asyncio.create_task(
369
+ _get_child_event_with_timeout(child_event_sink, timeout=0.05)
370
+ )
371
+ finally:
372
+ # Cleanup any pending tasks
373
+ for task in [pending_tool, pending_child]:
374
+ if task and not task.done():
375
+ task.cancel()
376
+ try:
377
+ await task
378
+ except asyncio.CancelledError:
379
+ pass
380
+
381
+
382
+ async def _get_child_event_with_timeout(
383
+ queue: asyncio.Queue, timeout: float = 0.05
384
+ ) -> dict | None:
385
+ """
386
+ Get an event from the queue with a timeout.
387
+
388
+ Returns None on timeout (no event available).
389
+ This allows the multiplexer to check for tool events regularly.
390
+ """
391
+ try:
392
+ return await asyncio.wait_for(queue.get(), timeout=timeout)
393
+ except asyncio.TimeoutError:
394
+ return None