remdb 0.3.171__py3-none-any.whl → 0.3.230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. rem/agentic/README.md +36 -2
  2. rem/agentic/context.py +173 -0
  3. rem/agentic/context_builder.py +12 -2
  4. rem/agentic/mcp/tool_wrapper.py +39 -16
  5. rem/agentic/providers/pydantic_ai.py +78 -45
  6. rem/agentic/schema.py +6 -5
  7. rem/agentic/tools/rem_tools.py +11 -0
  8. rem/api/main.py +1 -1
  9. rem/api/mcp_router/resources.py +75 -14
  10. rem/api/mcp_router/server.py +31 -24
  11. rem/api/mcp_router/tools.py +621 -166
  12. rem/api/routers/admin.py +30 -4
  13. rem/api/routers/auth.py +114 -15
  14. rem/api/routers/chat/child_streaming.py +379 -0
  15. rem/api/routers/chat/completions.py +74 -37
  16. rem/api/routers/chat/sse_events.py +7 -3
  17. rem/api/routers/chat/streaming.py +352 -257
  18. rem/api/routers/chat/streaming_utils.py +327 -0
  19. rem/api/routers/common.py +18 -0
  20. rem/api/routers/dev.py +7 -1
  21. rem/api/routers/feedback.py +9 -1
  22. rem/api/routers/messages.py +176 -38
  23. rem/api/routers/models.py +9 -1
  24. rem/api/routers/query.py +12 -1
  25. rem/api/routers/shared_sessions.py +16 -0
  26. rem/auth/jwt.py +19 -4
  27. rem/auth/middleware.py +42 -28
  28. rem/cli/README.md +62 -0
  29. rem/cli/commands/ask.py +61 -81
  30. rem/cli/commands/db.py +148 -70
  31. rem/cli/commands/process.py +171 -43
  32. rem/models/entities/ontology.py +91 -101
  33. rem/schemas/agents/rem.yaml +1 -1
  34. rem/services/content/service.py +18 -5
  35. rem/services/email/service.py +11 -2
  36. rem/services/embeddings/worker.py +26 -12
  37. rem/services/postgres/__init__.py +28 -3
  38. rem/services/postgres/diff_service.py +57 -5
  39. rem/services/postgres/programmable_diff_service.py +635 -0
  40. rem/services/postgres/pydantic_to_sqlalchemy.py +2 -2
  41. rem/services/postgres/register_type.py +12 -11
  42. rem/services/postgres/repository.py +39 -29
  43. rem/services/postgres/schema_generator.py +5 -5
  44. rem/services/postgres/sql_builder.py +6 -5
  45. rem/services/session/__init__.py +8 -1
  46. rem/services/session/compression.py +40 -2
  47. rem/services/session/pydantic_messages.py +292 -0
  48. rem/settings.py +34 -0
  49. rem/sql/background_indexes.sql +5 -0
  50. rem/sql/migrations/001_install.sql +157 -10
  51. rem/sql/migrations/002_install_models.sql +160 -132
  52. rem/sql/migrations/004_cache_system.sql +7 -275
  53. rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
  54. rem/utils/model_helpers.py +101 -0
  55. rem/utils/schema_loader.py +79 -51
  56. {remdb-0.3.171.dist-info → remdb-0.3.230.dist-info}/METADATA +2 -2
  57. {remdb-0.3.171.dist-info → remdb-0.3.230.dist-info}/RECORD +59 -53
  58. {remdb-0.3.171.dist-info → remdb-0.3.230.dist-info}/WHEEL +0 -0
  59. {remdb-0.3.171.dist-info → remdb-0.3.230.dist-info}/entry_points.txt +0 -0
rem/api/routers/admin.py CHANGED
@@ -31,6 +31,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Background
31
31
  from loguru import logger
32
32
  from pydantic import BaseModel
33
33
 
34
+ from .common import ErrorResponse
35
+
34
36
  from ..deps import require_admin
35
37
  from ...models.entities import Message, Session, SessionMode
36
38
  from ...services.postgres import Repository
@@ -103,7 +105,13 @@ class SystemStats(BaseModel):
103
105
  # =============================================================================
104
106
 
105
107
 
106
- @router.get("/users", response_model=UserListResponse)
108
+ @router.get(
109
+ "/users",
110
+ response_model=UserListResponse,
111
+ responses={
112
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
113
+ },
114
+ )
107
115
  async def list_all_users(
108
116
  user: dict = Depends(require_admin),
109
117
  limit: int = Query(default=50, ge=1, le=100),
@@ -155,7 +163,13 @@ async def list_all_users(
155
163
  return UserListResponse(data=summaries, total=total, has_more=has_more)
156
164
 
157
165
 
158
- @router.get("/sessions", response_model=SessionListResponse)
166
+ @router.get(
167
+ "/sessions",
168
+ response_model=SessionListResponse,
169
+ responses={
170
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
171
+ },
172
+ )
159
173
  async def list_all_sessions(
160
174
  user: dict = Depends(require_admin),
161
175
  user_id: str | None = Query(default=None, description="Filter by user ID"),
@@ -202,7 +216,13 @@ async def list_all_sessions(
202
216
  return SessionListResponse(data=sessions, total=total, has_more=has_more)
203
217
 
204
218
 
205
- @router.get("/messages", response_model=MessageListResponse)
219
+ @router.get(
220
+ "/messages",
221
+ response_model=MessageListResponse,
222
+ responses={
223
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
224
+ },
225
+ )
206
226
  async def list_all_messages(
207
227
  user: dict = Depends(require_admin),
208
228
  user_id: str | None = Query(default=None, description="Filter by user ID"),
@@ -252,7 +272,13 @@ async def list_all_messages(
252
272
  return MessageListResponse(data=messages, total=total, has_more=has_more)
253
273
 
254
274
 
255
- @router.get("/stats", response_model=SystemStats)
275
+ @router.get(
276
+ "/stats",
277
+ response_model=SystemStats,
278
+ responses={
279
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
280
+ },
281
+ )
256
282
  async def get_system_stats(
257
283
  user: dict = Depends(require_admin),
258
284
  ) -> SystemStats:
rem/api/routers/auth.py CHANGED
@@ -30,14 +30,17 @@ Access Control Flow (send-code):
30
30
  │ ├── Yes → Check user.tier
31
31
  │ │ ├── tier == BLOCKED → Reject "Account is blocked"
32
32
  │ │ └── tier != BLOCKED → Allow (send code, existing users grandfathered)
33
- │ └── No (new user) → Check EMAIL__TRUSTED_EMAIL_DOMAINS
34
- │ ├── Setting configured → domain in trusted list?
35
- │ ├── Yes → Create user & send code
36
- │ └── No → Reject "Email domain not allowed for signup"
37
- └── Not configured (empty) → Create user & send code (no restrictions)
33
+ │ └── No (new user) → Check subscriber list first
34
+ │     ├── Email in subscribers table? → Allow (create user & send code)
35
+ │     └── Not a subscriber → Check EMAIL__TRUSTED_EMAIL_DOMAINS
36
+ │         ├── Setting configured → domain in trusted list?
37
+ │         │   ├── Yes → Create user & send code
38
+ │         │   └── No → Reject "Email domain not allowed for signup"
39
+ │         └── Not configured (empty) → Create user & send code (no restrictions)
38
40
 
39
41
  Key Behaviors:
40
42
  - Existing users: Always allowed to login (unless tier=BLOCKED)
43
+ - Subscribers: Always allowed to login (regardless of email domain)
41
44
  - New users: Must have email from trusted domain (if EMAIL__TRUSTED_EMAIL_DOMAINS is set)
42
45
  - No restrictions: Leave EMAIL__TRUSTED_EMAIL_DOMAINS empty to allow all domains
43
46
 
@@ -98,6 +101,8 @@ from authlib.integrations.starlette_client import OAuth
98
101
  from pydantic import BaseModel, EmailStr
99
102
  from loguru import logger
100
103
 
104
+ from .common import ErrorResponse
105
+
101
106
  from ...settings import settings
102
107
  from ...services.postgres.service import PostgresService
103
108
  from ...services.user_service import UserService
@@ -156,7 +161,14 @@ class EmailVerifyRequest(BaseModel):
156
161
  code: str
157
162
 
158
163
 
159
- @router.post("/email/send-code")
164
+ @router.post(
165
+ "/email/send-code",
166
+ responses={
167
+ 400: {"model": ErrorResponse, "description": "Invalid request or email rejected"},
168
+ 500: {"model": ErrorResponse, "description": "Failed to send login code"},
169
+ 501: {"model": ErrorResponse, "description": "Email auth or database not configured"},
170
+ },
171
+ )
160
172
  async def send_email_code(request: Request, body: EmailSendCodeRequest):
161
173
  """
162
174
  Send a login code to an email address.
@@ -218,7 +230,14 @@ async def send_email_code(request: Request, body: EmailSendCodeRequest):
218
230
  await db.disconnect()
219
231
 
220
232
 
221
- @router.post("/email/verify")
233
+ @router.post(
234
+ "/email/verify",
235
+ responses={
236
+ 400: {"model": ErrorResponse, "description": "Invalid or expired code"},
237
+ 500: {"model": ErrorResponse, "description": "Failed to verify login code"},
238
+ 501: {"model": ErrorResponse, "description": "Email auth or database not configured"},
239
+ },
240
+ )
222
241
  async def verify_email_code(request: Request, body: EmailVerifyRequest):
223
242
  """
224
243
  Verify login code and create session with JWT tokens.
@@ -316,7 +335,13 @@ async def verify_email_code(request: Request, body: EmailVerifyRequest):
316
335
  # =============================================================================
317
336
 
318
337
 
319
- @router.get("/{provider}/login")
338
+ @router.get(
339
+ "/{provider}/login",
340
+ responses={
341
+ 400: {"model": ErrorResponse, "description": "Unknown OAuth provider"},
342
+ 501: {"model": ErrorResponse, "description": "Authentication is disabled"},
343
+ },
344
+ )
320
345
  async def login(provider: str, request: Request):
321
346
  """
322
347
  Initiate OAuth flow with provider.
@@ -358,7 +383,13 @@ async def login(provider: str, request: Request):
358
383
  return await client.authorize_redirect(request, redirect_uri)
359
384
 
360
385
 
361
- @router.get("/{provider}/callback")
386
+ @router.get(
387
+ "/{provider}/callback",
388
+ responses={
389
+ 400: {"model": ErrorResponse, "description": "Authentication failed or unknown provider"},
390
+ 501: {"model": ErrorResponse, "description": "Authentication is disabled"},
391
+ },
392
+ )
362
393
  async def callback(provider: str, request: Request):
363
394
  """
364
395
  OAuth callback endpoint.
@@ -495,7 +526,12 @@ async def logout(request: Request):
495
526
  return {"message": "Logged out successfully"}
496
527
 
497
528
 
498
- @router.get("/me")
529
+ @router.get(
530
+ "/me",
531
+ responses={
532
+ 401: {"model": ErrorResponse, "description": "Not authenticated"},
533
+ },
534
+ )
499
535
  async def me(request: Request):
500
536
  """
501
537
  Get current user information from session or JWT.
@@ -533,11 +569,19 @@ class TokenRefreshRequest(BaseModel):
533
569
  refresh_token: str
534
570
 
535
571
 
536
- @router.post("/token/refresh")
572
+ @router.post(
573
+ "/token/refresh",
574
+ responses={
575
+ 401: {"model": ErrorResponse, "description": "Invalid or expired refresh token"},
576
+ },
577
+ )
537
578
  async def refresh_token(body: TokenRefreshRequest):
538
579
  """
539
580
  Refresh access token using refresh token.
540
581
 
582
+ Fetches the user's current role/tier from the database to ensure
583
+ the new access token reflects their actual permissions.
584
+
541
585
  Args:
542
586
  body: TokenRefreshRequest with refresh_token
543
587
 
@@ -545,7 +589,46 @@ async def refresh_token(body: TokenRefreshRequest):
545
589
  New access token or 401 if refresh token is invalid
546
590
  """
547
591
  jwt_service = get_jwt_service()
548
- result = jwt_service.refresh_access_token(body.refresh_token)
592
+
593
+ # First decode the refresh token to get user_id (without full verification yet)
594
+ payload = jwt_service.decode_without_verification(body.refresh_token)
595
+ if not payload:
596
+ raise HTTPException(
597
+ status_code=401,
598
+ detail="Invalid refresh token format"
599
+ )
600
+
601
+ user_id = payload.get("sub")
602
+ if not user_id:
603
+ raise HTTPException(
604
+ status_code=401,
605
+ detail="Invalid refresh token: missing user ID"
606
+ )
607
+
608
+ # Fetch user from database to get current role/tier
609
+ user_override = None
610
+ if settings.postgres.enabled:
611
+ db = PostgresService()
612
+ try:
613
+ await db.connect()
614
+ user_service = UserService(db)
615
+ user_entity = await user_service.get_user_by_id(user_id)
616
+ if user_entity:
617
+ user_override = {
618
+ "role": user_entity.role or "user",
619
+ "roles": [user_entity.role] if user_entity.role else ["user"],
620
+ "tier": user_entity.tier.value if user_entity.tier else "free",
621
+ "name": user_entity.name,
622
+ }
623
+ logger.debug(f"Refresh token: fetched user {user_id} with role={user_override['role']}, tier={user_override['tier']}")
624
+ except Exception as e:
625
+ logger.warning(f"Could not fetch user for token refresh: {e}")
626
+ # Continue without override - will use defaults
627
+ finally:
628
+ await db.disconnect()
629
+
630
+ # Now do the actual refresh with proper verification
631
+ result = jwt_service.refresh_access_token(body.refresh_token, user_override=user_override)
549
632
 
550
633
  if not result:
551
634
  raise HTTPException(
@@ -556,7 +639,12 @@ async def refresh_token(body: TokenRefreshRequest):
556
639
  return result
557
640
 
558
641
 
559
- @router.post("/token/verify")
642
+ @router.post(
643
+ "/token/verify",
644
+ responses={
645
+ 401: {"model": ErrorResponse, "description": "Missing, invalid, or expired token"},
646
+ },
647
+ )
560
648
  async def verify_token(request: Request):
561
649
  """
562
650
  Verify an access token is valid.
@@ -620,7 +708,12 @@ def verify_dev_token(token: str) -> bool:
620
708
  return token == expected
621
709
 
622
710
 
623
- @router.get("/dev/token")
711
+ @router.get(
712
+ "/dev/token",
713
+ responses={
714
+ 401: {"model": ErrorResponse, "description": "Dev tokens not available in production"},
715
+ },
716
+ )
624
717
  async def get_dev_token(request: Request):
625
718
  """
626
719
  Get a development token for testing (non-production only).
@@ -656,7 +749,13 @@ async def get_dev_token(request: Request):
656
749
  }
657
750
 
658
751
 
659
- @router.get("/dev/mock-code/{email}")
752
+ @router.get(
753
+ "/dev/mock-code/{email}",
754
+ responses={
755
+ 401: {"model": ErrorResponse, "description": "Mock codes not available in production"},
756
+ 404: {"model": ErrorResponse, "description": "No code found for email"},
757
+ },
758
+ )
660
759
  async def get_mock_code(email: str, request: Request):
661
760
  """
662
761
  Get the mock login code for testing (non-production only).
@@ -0,0 +1,379 @@
1
+ """
2
+ Child Agent Event Handling.
3
+
4
+ Handles events from child agents during multi-agent orchestration.
5
+
6
+ Event Flow:
7
+ ```
8
+ Parent Agent (Siggy)
9
+
10
+
11
+ ask_agent tool
12
+
13
+ ├──────────────────────────────────┐
14
+ ▼ │
15
+ Child Agent (intake_diverge) │
16
+ │ │
17
+ ├── child_tool_start ──────────────┼──► Event Sink (Queue)
18
+ ├── child_content ─────────────────┤
19
+ └── child_tool_result ─────────────┘
20
+
21
+
22
+ drain_child_events()
23
+
24
+ ├── SSE to client
25
+ └── DB persistence
26
+ ```
27
+
28
+ IMPORTANT: When child_content is streamed, parent text output should be SKIPPED
29
+ to prevent content duplication.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import asyncio
35
+ import json
36
+ import uuid
37
+ from typing import TYPE_CHECKING, Any, AsyncGenerator
38
+
39
+ from loguru import logger
40
+
41
+ from .streaming_utils import StreamingState, build_content_chunk
42
+ from .sse_events import MetadataEvent, ToolCallEvent, format_sse_event
43
+ from ....services.session import SessionMessageStore
44
+ from ....settings import settings
45
+ from ....utils.date_utils import to_iso, utc_now
46
+
47
+ if TYPE_CHECKING:
48
+ from ....agentic.context import AgentContext
49
+
50
+
51
async def handle_child_tool_start(
    state: StreamingState,
    child_agent: str,
    tool_name: str,
    arguments: dict | None,
    session_id: str | None,
    user_id: str | None,
) -> AsyncGenerator[str, None]:
    """
    Handle a child_tool_start event.

    Logs the call, emits a "started" ToolCallEvent over SSE, and
    best-effort persists the call as a tool-role message when a session
    and the database are both available.
    """
    qualified_name = f"{child_agent}:{tool_name}"
    call_id = f"call_{uuid.uuid4().hex[:8]}"

    # Anything that is not a dict is treated as "no arguments".
    args = arguments if isinstance(arguments, dict) else None

    logger.info(f"🔧 {qualified_name}")

    yield format_sse_event(
        ToolCallEvent(
            tool_name=qualified_name,
            tool_id=call_id,
            status="started",
            arguments=args,
        )
    )

    # Persistence is optional: skip entirely when there is no session
    # or the database is disabled.
    if not (session_id and settings.postgres.enabled):
        return

    try:
        store = SessionMessageStore(
            user_id=user_id or settings.test.effective_user_id
        )
        await store.store_session_messages(
            session_id=session_id,
            messages=[{
                "role": "tool",
                "tool_name": qualified_name,
                "content": json.dumps(args) if args else "",
                "timestamp": to_iso(utc_now()),
            }],
            user_id=user_id,
            compress=False,
        )
    except Exception as e:
        # Best-effort: a persistence failure must not break the stream.
        logger.warning(f"Failed to save child tool call: {e}")
105
+
106
+
107
def handle_child_content(
    state: StreamingState,
    child_agent: str,
    content: str,
) -> str | None:
    """
    Handle a child_content event.

    Marks the state so the parent's own text output is suppressed
    (``child_content_streamed``), records which agent is responding,
    and wraps the text in an SSE content chunk.

    Returns:
        The SSE chunk, or None when ``content`` is empty.
    """
    if content:
        # CRITICAL: flag that a child produced visible text; callers use
        # this to skip the parent's text and avoid duplicated output.
        state.child_content_streamed = True
        state.responding_agent = child_agent
        return build_content_chunk(state, content)
    return None
130
+
131
+
132
async def handle_child_tool_result(
    state: StreamingState,
    child_agent: str,
    result: Any,
    message_id: str | None,
    session_id: str | None,
    agent_schema: str | None,
) -> AsyncGenerator[str, None]:
    """
    Handle a child_tool_result event.

    When the result is a metadata-registration dict (``_metadata_event``
    key set), logs it, lets the child override the responding agent, and
    emits a MetadataEvent. Always finishes by emitting a "completed"
    ToolCallEvent for the child's tool call.
    """
    is_metadata = isinstance(result, dict) and bool(result.get("_metadata_event"))

    if is_metadata:
        risk = result.get("risk_level", "")
        conf = result.get("confidence", "")
        logger.info(f"📊 {child_agent} metadata: risk={risk}, confidence={conf}")

        # The child may override which agent is considered "responding".
        schema_from_child = result.get("agent_schema")
        if schema_from_child:
            state.responding_agent = schema_from_child

        # Only attach extras when there is something to report.
        extra = {"risk_level": risk} if risk else {}

        yield format_sse_event(
            MetadataEvent(
                message_id=message_id,
                session_id=session_id,
                agent_schema=agent_schema,
                responding_agent=state.responding_agent,
                confidence=result.get("confidence"),
                extra=extra or None,
            )
        )

    # Completion event; result preview is truncated to 200 chars.
    yield format_sse_event(
        ToolCallEvent(
            tool_name=f"{child_agent}:tool",
            tool_id=f"call_{uuid.uuid4().hex[:8]}",
            status="completed",
            result=str(result)[:200] if result else None,
        )
    )
179
+
180
+
181
async def drain_child_events(
    event_sink: asyncio.Queue,
    state: StreamingState,
    session_id: str | None = None,
    user_id: str | None = None,
    message_id: str | None = None,
    agent_schema: str | None = None,
) -> AsyncGenerator[str, None]:
    """
    Drain and process every pending child event in ``event_sink``.

    Called while a tool is executing so events pushed by child agents
    (via ask_agent) are forwarded as SSE chunks.

    IMPORTANT: processing a child_content event sets
    ``state.child_content_streamed``; callers should check that flag and
    skip parent text output to prevent duplication.

    Failures on individual events are logged and skipped (best-effort).
    """
    while True:
        if event_sink.empty():
            break
        try:
            event = event_sink.get_nowait()
            generator = process_child_event(
                event, state, session_id, user_id, message_id, agent_schema
            )
            async for sse_chunk in generator:
                yield sse_chunk
        except Exception as e:
            logger.warning(f"Error processing child event: {e}")
208
+
209
+
210
async def process_child_event(
    child_event: dict,
    state: StreamingState,
    session_id: str | None = None,
    user_id: str | None = None,
    message_id: str | None = None,
    agent_schema: str | None = None,
) -> AsyncGenerator[str, None]:
    """
    Dispatch one child event to its handler and yield SSE chunks.

    Recognized event types: child_tool_start, child_content, and
    child_tool_result. Any other type yields nothing.
    """
    kind = child_event.get("type", "")
    agent = child_event.get("agent_name", "child")

    # child_content has a synchronous handler returning a single chunk.
    if kind == "child_content":
        sse = handle_child_content(
            state=state,
            child_agent=agent,
            content=child_event.get("content", ""),
        )
        if sse:
            yield sse
        return

    # The remaining types are handled by async generators.
    if kind == "child_tool_start":
        events = handle_child_tool_start(
            state=state,
            child_agent=agent,
            tool_name=child_event.get("tool_name", "tool"),
            arguments=child_event.get("arguments"),
            session_id=session_id,
            user_id=user_id,
        )
    elif kind == "child_tool_result":
        events = handle_child_tool_result(
            state=state,
            child_agent=agent,
            result=child_event.get("result"),
            message_id=message_id,
            session_id=session_id,
            agent_schema=agent_schema,
        )
    else:
        return

    async for sse in events:
        yield sse
252
+
253
+
254
async def stream_with_child_events(
    tools_stream,
    child_event_sink: asyncio.Queue,
    state: StreamingState,
    session_id: str | None = None,
    user_id: str | None = None,
    message_id: str | None = None,
    agent_schema: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
    """
    Multiplex tool events with child events using asyncio.wait().

    Instead of draining the child queue synchronously during tool event
    iteration, both sources are awaited concurrently and events are
    yielded as they arrive, so child-agent output streams in real time.

    Args:
        tools_stream: Async iterable of tool events from the parent run.
        child_event_sink: Queue that child agents push event dicts into.
        state: Streaming state (kept for interface parity; not read here).
        session_id / user_id / message_id / agent_schema: Kept for
            interface parity with drain_child_events; not read here.

    Yields:
        (event_type, event_data) tuples where event_type is "tool" or
        "child", allowing the caller to handle each appropriately.
    """

    async def next_child_event() -> dict | None:
        # Short poll so the tool stream is re-checked regularly; None
        # means "no child event right now".
        try:
            return await asyncio.wait_for(child_event_sink.get(), timeout=0.05)
        except asyncio.TimeoutError:
            return None

    def flush_queue() -> list:
        # Collect whatever is still queued without blocking.
        leftovers = []
        while not child_event_sink.empty():
            try:
                leftovers.append(child_event_sink.get_nowait())
            except asyncio.QueueEmpty:
                break
        return leftovers

    tool_iter = tools_stream.__aiter__()

    # NOTE: create_task() itself never raises StopAsyncIteration — the
    # exception surfaces from task.result() below. (An earlier version
    # wrapped create_task in try/except StopAsyncIteration, which was
    # dead code and left an unreachable drain branch.)
    pending_tool: asyncio.Task | None = asyncio.create_task(tool_iter.__anext__())
    pending_child: asyncio.Task | None = asyncio.create_task(next_child_event())

    try:
        while True:
            tasks = {t for t in (pending_tool, pending_child) if t is not None}
            if not tasks:
                break

            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                if task is pending_tool:
                    try:
                        tool_event = task.result()
                    except StopAsyncIteration:
                        # Tool stream exhausted: stop the child listener,
                        # flush remaining child events, and finish.
                        pending_tool = None
                        if pending_child is not None:
                            pending_child.cancel()
                            try:
                                await pending_child
                            except asyncio.CancelledError:
                                pass
                            pending_child = None
                        for child_event in flush_queue():
                            yield ("child", child_event)
                        return
                    yield ("tool", tool_event)
                    pending_tool = asyncio.create_task(tool_iter.__anext__())
                elif task is pending_child:
                    # next_child_event() converts timeouts to None, so
                    # task.result() cannot raise TimeoutError here.
                    child_event = task.result()
                    if child_event is not None:
                        yield ("child", child_event)
                    # Restart the listener whether or not we got an event.
                    pending_child = asyncio.create_task(next_child_event())
    finally:
        # Cleanup any still-pending tasks on completion or early exit.
        for task in (pending_tool, pending_child):
            if task is not None and not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
365
+
366
+
367
async def _get_child_event_with_timeout(
    queue: asyncio.Queue, timeout: float = 0.05
) -> dict | None:
    """
    Await the next item from *queue*, giving up after *timeout* seconds.

    Returns None on timeout (no event available), which lets the
    multiplexer keep checking for tool events instead of blocking
    indefinitely on an empty queue.
    """
    try:
        event = await asyncio.wait_for(queue.get(), timeout=timeout)
    except asyncio.TimeoutError:
        return None
    return event