remdb 0.3.180__py3-none-any.whl → 0.3.230__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- rem/agentic/README.md +36 -2
- rem/agentic/context.py +173 -0
- rem/agentic/context_builder.py +12 -2
- rem/agentic/mcp/tool_wrapper.py +2 -2
- rem/agentic/providers/pydantic_ai.py +1 -1
- rem/agentic/schema.py +2 -2
- rem/api/main.py +1 -1
- rem/api/mcp_router/server.py +4 -0
- rem/api/mcp_router/tools.py +542 -166
- rem/api/routers/admin.py +30 -4
- rem/api/routers/auth.py +106 -10
- rem/api/routers/chat/child_streaming.py +379 -0
- rem/api/routers/chat/completions.py +74 -37
- rem/api/routers/chat/sse_events.py +7 -3
- rem/api/routers/chat/streaming.py +352 -257
- rem/api/routers/chat/streaming_utils.py +327 -0
- rem/api/routers/common.py +18 -0
- rem/api/routers/dev.py +7 -1
- rem/api/routers/feedback.py +9 -1
- rem/api/routers/messages.py +176 -38
- rem/api/routers/models.py +9 -1
- rem/api/routers/query.py +12 -1
- rem/api/routers/shared_sessions.py +16 -0
- rem/auth/jwt.py +19 -4
- rem/auth/middleware.py +42 -28
- rem/cli/README.md +62 -0
- rem/cli/commands/ask.py +61 -81
- rem/cli/commands/db.py +55 -31
- rem/cli/commands/process.py +171 -43
- rem/models/entities/ontology.py +18 -20
- rem/schemas/agents/rem.yaml +1 -1
- rem/services/content/service.py +18 -5
- rem/services/embeddings/worker.py +26 -12
- rem/services/postgres/__init__.py +28 -3
- rem/services/postgres/diff_service.py +57 -5
- rem/services/postgres/programmable_diff_service.py +635 -0
- rem/services/postgres/pydantic_to_sqlalchemy.py +2 -2
- rem/services/postgres/register_type.py +11 -10
- rem/services/postgres/repository.py +39 -29
- rem/services/postgres/schema_generator.py +5 -5
- rem/services/postgres/sql_builder.py +6 -5
- rem/services/session/__init__.py +8 -1
- rem/services/session/compression.py +40 -2
- rem/services/session/pydantic_messages.py +292 -0
- rem/settings.py +28 -0
- rem/sql/migrations/001_install.sql +125 -7
- rem/sql/migrations/002_install_models.sql +159 -149
- rem/sql/migrations/004_cache_system.sql +7 -275
- rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
- rem/utils/schema_loader.py +79 -51
- {remdb-0.3.180.dist-info → remdb-0.3.230.dist-info}/METADATA +2 -2
- {remdb-0.3.180.dist-info → remdb-0.3.230.dist-info}/RECORD +54 -48
- {remdb-0.3.180.dist-info → remdb-0.3.230.dist-info}/WHEEL +0 -0
- {remdb-0.3.180.dist-info → remdb-0.3.230.dist-info}/entry_points.txt +0 -0
rem/api/routers/admin.py
CHANGED

@@ -31,6 +31,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Background
 from loguru import logger
 from pydantic import BaseModel
 
+from .common import ErrorResponse
+
 from ..deps import require_admin
 from ...models.entities import Message, Session, SessionMode
 from ...services.postgres import Repository
@@ -103,7 +105,13 @@ class SystemStats(BaseModel):
 # =============================================================================
 
 
-@router.get("/users", response_model=UserListResponse)
+@router.get(
+    "/users",
+    response_model=UserListResponse,
+    responses={
+        503: {"model": ErrorResponse, "description": "Database not enabled"},
+    },
+)
 async def list_all_users(
     user: dict = Depends(require_admin),
     limit: int = Query(default=50, ge=1, le=100),
@@ -155,7 +163,13 @@ async def list_all_users(
     return UserListResponse(data=summaries, total=total, has_more=has_more)
 
 
-@router.get("/sessions", response_model=SessionListResponse)
+@router.get(
+    "/sessions",
+    response_model=SessionListResponse,
+    responses={
+        503: {"model": ErrorResponse, "description": "Database not enabled"},
+    },
+)
 async def list_all_sessions(
     user: dict = Depends(require_admin),
     user_id: str | None = Query(default=None, description="Filter by user ID"),
@@ -202,7 +216,13 @@ async def list_all_sessions(
     return SessionListResponse(data=sessions, total=total, has_more=has_more)
 
 
-@router.get("/messages", response_model=MessageListResponse)
+@router.get(
+    "/messages",
+    response_model=MessageListResponse,
+    responses={
+        503: {"model": ErrorResponse, "description": "Database not enabled"},
+    },
+)
 async def list_all_messages(
     user: dict = Depends(require_admin),
     user_id: str | None = Query(default=None, description="Filter by user ID"),
@@ -252,7 +272,13 @@ async def list_all_messages(
     return MessageListResponse(data=messages, total=total, has_more=has_more)
 
 
-@router.get("/stats", response_model=SystemStats)
+@router.get(
+    "/stats",
+    response_model=SystemStats,
+    responses={
+        503: {"model": ErrorResponse, "description": "Database not enabled"},
+    },
+)
 async def get_system_stats(
     user: dict = Depends(require_admin),
 ) -> SystemStats:
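Every endpoint above now declares its error payloads against a shared `ErrorResponse` model imported from `rem/api/routers/common.py`, a file added in this release (+18 lines) whose body the diff does not show. A minimal sketch of what such a model plausibly looks like — the `detail` field is an assumption, mirroring FastAPI's default error body:

```python
# Hypothetical sketch of rem/api/routers/common.py (not shown in this diff).
from pydantic import BaseModel


class ErrorResponse(BaseModel):
    """Assumed shape, mirroring FastAPI's default {"detail": ...} error body."""

    detail: str
```

Note that `responses=` only enriches the generated OpenAPI schema with documented error shapes; the handlers still raise `HTTPException` at runtime exactly as before.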
rem/api/routers/auth.py
CHANGED

@@ -101,6 +101,8 @@ from authlib.integrations.starlette_client import OAuth
 from pydantic import BaseModel, EmailStr
 from loguru import logger
 
+from .common import ErrorResponse
+
 from ...settings import settings
 from ...services.postgres.service import PostgresService
 from ...services.user_service import UserService
@@ -159,7 +161,14 @@ class EmailVerifyRequest(BaseModel):
     code: str
 
 
-@router.post("/email/send-code")
+@router.post(
+    "/email/send-code",
+    responses={
+        400: {"model": ErrorResponse, "description": "Invalid request or email rejected"},
+        500: {"model": ErrorResponse, "description": "Failed to send login code"},
+        501: {"model": ErrorResponse, "description": "Email auth or database not configured"},
+    },
+)
 async def send_email_code(request: Request, body: EmailSendCodeRequest):
     """
     Send a login code to an email address.
@@ -221,7 +230,14 @@ async def send_email_code(request: Request, body: EmailSendCodeRequest):
         await db.disconnect()
 
 
-@router.post("/email/verify")
+@router.post(
+    "/email/verify",
+    responses={
+        400: {"model": ErrorResponse, "description": "Invalid or expired code"},
+        500: {"model": ErrorResponse, "description": "Failed to verify login code"},
+        501: {"model": ErrorResponse, "description": "Email auth or database not configured"},
+    },
+)
 async def verify_email_code(request: Request, body: EmailVerifyRequest):
     """
     Verify login code and create session with JWT tokens.
@@ -319,7 +335,13 @@ async def verify_email_code(request: Request, body: EmailVerifyRequest):
 # =============================================================================
 
 
-@router.get("/{provider}/login")
+@router.get(
+    "/{provider}/login",
+    responses={
+        400: {"model": ErrorResponse, "description": "Unknown OAuth provider"},
+        501: {"model": ErrorResponse, "description": "Authentication is disabled"},
+    },
+)
 async def login(provider: str, request: Request):
     """
     Initiate OAuth flow with provider.
@@ -361,7 +383,13 @@ async def login(provider: str, request: Request):
     return await client.authorize_redirect(request, redirect_uri)
 
 
-@router.get("/{provider}/callback")
+@router.get(
+    "/{provider}/callback",
+    responses={
+        400: {"model": ErrorResponse, "description": "Authentication failed or unknown provider"},
+        501: {"model": ErrorResponse, "description": "Authentication is disabled"},
+    },
+)
 async def callback(provider: str, request: Request):
     """
     OAuth callback endpoint.
@@ -498,7 +526,12 @@ async def logout(request: Request):
     return {"message": "Logged out successfully"}
 
 
-@router.get("/me")
+@router.get(
+    "/me",
+    responses={
+        401: {"model": ErrorResponse, "description": "Not authenticated"},
+    },
+)
 async def me(request: Request):
     """
     Get current user information from session or JWT.
@@ -536,11 +569,19 @@ class TokenRefreshRequest(BaseModel):
     refresh_token: str
 
 
-@router.post("/token/refresh")
+@router.post(
+    "/token/refresh",
+    responses={
+        401: {"model": ErrorResponse, "description": "Invalid or expired refresh token"},
+    },
+)
 async def refresh_token(body: TokenRefreshRequest):
     """
     Refresh access token using refresh token.
 
+    Fetches the user's current role/tier from the database to ensure
+    the new access token reflects their actual permissions.
+
     Args:
         body: TokenRefreshRequest with refresh_token
 
@@ -548,7 +589,46 @@ async def refresh_token(body: TokenRefreshRequest):
         New access token or 401 if refresh token is invalid
     """
     jwt_service = get_jwt_service()
-    result = jwt_service.refresh_access_token(body.refresh_token)
+
+    # First decode the refresh token to get user_id (without full verification yet)
+    payload = jwt_service.decode_without_verification(body.refresh_token)
+    if not payload:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid refresh token format"
+        )
+
+    user_id = payload.get("sub")
+    if not user_id:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid refresh token: missing user ID"
+        )
+
+    # Fetch user from database to get current role/tier
+    user_override = None
+    if settings.postgres.enabled:
+        db = PostgresService()
+        try:
+            await db.connect()
+            user_service = UserService(db)
+            user_entity = await user_service.get_user_by_id(user_id)
+            if user_entity:
+                user_override = {
+                    "role": user_entity.role or "user",
+                    "roles": [user_entity.role] if user_entity.role else ["user"],
+                    "tier": user_entity.tier.value if user_entity.tier else "free",
+                    "name": user_entity.name,
+                }
+                logger.debug(f"Refresh token: fetched user {user_id} with role={user_override['role']}, tier={user_override['tier']}")
+        except Exception as e:
+            logger.warning(f"Could not fetch user for token refresh: {e}")
+            # Continue without override - will use defaults
+        finally:
+            await db.disconnect()
+
+    # Now do the actual refresh with proper verification
+    result = jwt_service.refresh_access_token(body.refresh_token, user_override=user_override)
 
     if not result:
         raise HTTPException(
@@ -559,7 +639,12 @@ async def refresh_token(body: TokenRefreshRequest):
     return result
 
 
-@router.post("/token/verify")
+@router.post(
+    "/token/verify",
+    responses={
+        401: {"model": ErrorResponse, "description": "Missing, invalid, or expired token"},
+    },
+)
 async def verify_token(request: Request):
     """
     Verify an access token is valid.
@@ -623,7 +708,12 @@ def verify_dev_token(token: str) -> bool:
     return token == expected
 
 
-@router.get("/dev/token")
+@router.get(
+    "/dev/token",
+    responses={
+        401: {"model": ErrorResponse, "description": "Dev tokens not available in production"},
+    },
+)
 async def get_dev_token(request: Request):
     """
     Get a development token for testing (non-production only).
@@ -659,7 +749,13 @@ async def get_dev_token(request: Request):
     }
 
 
-@router.get("/dev/mock-code/{email}")
+@router.get(
+    "/dev/mock-code/{email}",
+    responses={
+        401: {"model": ErrorResponse, "description": "Mock codes not available in production"},
+        404: {"model": ErrorResponse, "description": "No code found for email"},
+    },
+)
 async def get_mock_code(email: str, request: Request):
     """
     Get the mock login code for testing (non-production only).
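The reworked refresh flow decodes the token without verification only to learn the `sub` user ID, re-reads the user's current role/tier from Postgres, and then performs the fully verified refresh, so a role or tier change takes effect on the next token refresh rather than at the next login. A hypothetical client-side sketch of the flow (the `/auth` mount point for this router is an assumption):

```python
# Hypothetical client usage of the refresh endpoint; path prefix assumed.
import httpx


async def refresh(base_url: str, refresh_token: str) -> dict:
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            "/auth/token/refresh",  # assumes the router is mounted at /auth
            json={"refresh_token": refresh_token},
        )
        resp.raise_for_status()  # 401 means re-login is required
        return resp.json()  # new access token, with role/tier re-read from DB
```

A 401 from this endpoint indicates the refresh token itself is invalid or expired, not merely a stale access token, so the client should fall back to a full re-authentication.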
rem/api/routers/chat/child_streaming.py
ADDED

@@ -0,0 +1,379 @@
+"""
+Child Agent Event Handling.
+
+Handles events from child agents during multi-agent orchestration.
+
+Event Flow:
+```
+Parent Agent (Siggy)
+    │
+    ▼
+ask_agent tool
+    │
+    ├──────────────────────────────────┐
+    ▼                                  │
+Child Agent (intake_diverge)           │
+    │                                  │
+    ├── child_tool_start ──────────────┼──► Event Sink (Queue)
+    ├── child_content ─────────────────┤
+    └── child_tool_result ─────────────┘
+                                       │
+                                       ▼
+                          drain_child_events()
+                                       │
+                                       ├── SSE to client
+                                       └── DB persistence
+```
+
+IMPORTANT: When child_content is streamed, parent text output should be SKIPPED
+to prevent content duplication.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import uuid
+from typing import TYPE_CHECKING, Any, AsyncGenerator
+
+from loguru import logger
+
+from .streaming_utils import StreamingState, build_content_chunk
+from .sse_events import MetadataEvent, ToolCallEvent, format_sse_event
+from ....services.session import SessionMessageStore
+from ....settings import settings
+from ....utils.date_utils import to_iso, utc_now
+
+if TYPE_CHECKING:
+    from ....agentic.context import AgentContext
+
+
+async def handle_child_tool_start(
+    state: StreamingState,
+    child_agent: str,
+    tool_name: str,
+    arguments: dict | None,
+    session_id: str | None,
+    user_id: str | None,
+) -> AsyncGenerator[str, None]:
+    """
+    Handle child_tool_start event.
+
+    Actions:
+    1. Log the tool call
+    2. Emit SSE event
+    3. Save to database
+    """
+    full_tool_name = f"{child_agent}:{tool_name}"
+    tool_id = f"call_{uuid.uuid4().hex[:8]}"
+
+    # Normalize arguments
+    if not isinstance(arguments, dict):
+        arguments = None
+
+    # 1. LOG
+    logger.info(f"🔧 {full_tool_name}")
+
+    # 2. EMIT SSE
+    yield format_sse_event(ToolCallEvent(
+        tool_name=full_tool_name,
+        tool_id=tool_id,
+        status="started",
+        arguments=arguments,
+    ))
+
+    # 3. SAVE TO DB
+    if session_id and settings.postgres.enabled:
+        try:
+            store = SessionMessageStore(
+                user_id=user_id or settings.test.effective_user_id
+            )
+            tool_msg = {
+                "role": "tool",
+                "tool_name": full_tool_name,
+                "content": json.dumps(arguments) if arguments else "",
+                "timestamp": to_iso(utc_now()),
+            }
+            await store.store_session_messages(
+                session_id=session_id,
+                messages=[tool_msg],
+                user_id=user_id,
+                compress=False,
+            )
+        except Exception as e:
+            logger.warning(f"Failed to save child tool call: {e}")
+
+
+def handle_child_content(
+    state: StreamingState,
+    child_agent: str,
+    content: str,
+) -> str | None:
+    """
+    Handle child_content event.
+
+    CRITICAL: Sets state.child_content_streamed = True
+    This flag is used to skip parent text output and prevent duplication.
+
+    Returns:
+        SSE chunk or None if content is empty
+    """
+    if not content:
+        return None
+
+    # Track that child content was streamed
+    # Parent text output should be SKIPPED when this is True
+    state.child_content_streamed = True
+    state.responding_agent = child_agent
+
+    return build_content_chunk(state, content)
+
+
+async def handle_child_tool_result(
+    state: StreamingState,
+    child_agent: str,
+    result: Any,
+    message_id: str | None,
+    session_id: str | None,
+    agent_schema: str | None,
+) -> AsyncGenerator[str, None]:
+    """
+    Handle child_tool_result event.
+
+    Actions:
+    1. Log metadata if present
+    2. Emit metadata event if present
+    3. Emit tool completion event
+    """
+    # Check for metadata registration
+    if isinstance(result, dict) and result.get("_metadata_event"):
+        risk = result.get("risk_level", "")
+        conf = result.get("confidence", "")
+        logger.info(f"📊 {child_agent} metadata: risk={risk}, confidence={conf}")
+
+        # Update responding agent from child
+        if result.get("agent_schema"):
+            state.responding_agent = result.get("agent_schema")
+
+        # Build extra dict with risk fields
+        extra_data = {}
+        if risk:
+            extra_data["risk_level"] = risk
+
+        yield format_sse_event(MetadataEvent(
+            message_id=message_id,
+            session_id=session_id,
+            agent_schema=agent_schema,
+            responding_agent=state.responding_agent,
+            confidence=result.get("confidence"),
+            extra=extra_data if extra_data else None,
+        ))
+
+    # Emit tool completion
+    yield format_sse_event(ToolCallEvent(
+        tool_name=f"{child_agent}:tool",
+        tool_id=f"call_{uuid.uuid4().hex[:8]}",
+        status="completed",
+        result=str(result)[:200] if result else None,
+    ))
+
+
+async def drain_child_events(
+    event_sink: asyncio.Queue,
+    state: StreamingState,
+    session_id: str | None = None,
+    user_id: str | None = None,
+    message_id: str | None = None,
+    agent_schema: str | None = None,
+) -> AsyncGenerator[str, None]:
+    """
+    Drain all pending child events from the event sink.
+
+    This is called during tool execution to process events
+    pushed by child agents via ask_agent.
+
+    IMPORTANT: When child_content events are processed, this sets
+    state.child_content_streamed = True. Callers should check this
+    flag and skip parent text output to prevent duplication.
+    """
+    while not event_sink.empty():
+        try:
+            child_event = event_sink.get_nowait()
+            async for chunk in process_child_event(
+                child_event, state, session_id, user_id, message_id, agent_schema
+            ):
+                yield chunk
+        except Exception as e:
+            logger.warning(f"Error processing child event: {e}")
+
+
+async def process_child_event(
+    child_event: dict,
+    state: StreamingState,
+    session_id: str | None = None,
+    user_id: str | None = None,
+    message_id: str | None = None,
+    agent_schema: str | None = None,
+) -> AsyncGenerator[str, None]:
+    """Process a single child event and yield SSE chunks."""
+    event_type = child_event.get("type", "")
+    child_agent = child_event.get("agent_name", "child")
+
+    if event_type == "child_tool_start":
+        async for chunk in handle_child_tool_start(
+            state=state,
+            child_agent=child_agent,
+            tool_name=child_event.get("tool_name", "tool"),
+            arguments=child_event.get("arguments"),
+            session_id=session_id,
+            user_id=user_id,
+        ):
+            yield chunk
+
+    elif event_type == "child_content":
+        chunk = handle_child_content(
+            state=state,
+            child_agent=child_agent,
+            content=child_event.get("content", ""),
+        )
+        if chunk:
+            yield chunk
+
+    elif event_type == "child_tool_result":
+        async for chunk in handle_child_tool_result(
+            state=state,
+            child_agent=child_agent,
+            result=child_event.get("result"),
+            message_id=message_id,
+            session_id=session_id,
+            agent_schema=agent_schema,
+        ):
+            yield chunk
+
+
+async def stream_with_child_events(
+    tools_stream,
+    child_event_sink: asyncio.Queue,
+    state: StreamingState,
+    session_id: str | None = None,
+    user_id: str | None = None,
+    message_id: str | None = None,
+    agent_schema: str | None = None,
+) -> AsyncGenerator[tuple[str, Any], None]:
+    """
+    Multiplex tool events with child events using asyncio.wait().
+
+    This is the key fix for child agent streaming - instead of draining
+    the queue synchronously during tool event iteration, we concurrently
+    listen to both sources and yield events as they arrive.
+
+    Yields:
+        Tuples of (event_type, event_data) where event_type is either
+        "tool" or "child", allowing the caller to handle each appropriately.
+    """
+    tool_iter = tools_stream.__aiter__()
+
+    # Create initial tasks
+    pending_tool: asyncio.Task | None = None
+    pending_child: asyncio.Task | None = None
+
+    try:
+        pending_tool = asyncio.create_task(tool_iter.__anext__())
+    except StopAsyncIteration:
+        # No tool events, just drain any remaining child events
+        while not child_event_sink.empty():
+            try:
+                child_event = child_event_sink.get_nowait()
+                yield ("child", child_event)
+            except asyncio.QueueEmpty:
+                break
+        return
+
+    # Start listening for child events with a short timeout
+    pending_child = asyncio.create_task(
+        _get_child_event_with_timeout(child_event_sink, timeout=0.05)
+    )
+
+    try:
+        while True:
+            # Wait for either source to produce an event
+            tasks = {t for t in [pending_tool, pending_child] if t is not None}
+            if not tasks:
+                break
+
+            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+            for task in done:
+                try:
+                    result = task.result()
+                except asyncio.TimeoutError:
+                    # Child queue timeout - restart listener
+                    if task is pending_child:
+                        pending_child = asyncio.create_task(
+                            _get_child_event_with_timeout(child_event_sink, timeout=0.05)
+                        )
+                    continue
+                except StopAsyncIteration:
+                    # Tool stream exhausted
+                    if task is pending_tool:
+                        pending_tool = None
+                        # Final drain of any remaining child events
+                        if pending_child:
+                            pending_child.cancel()
+                            try:
+                                await pending_child
+                            except asyncio.CancelledError:
+                                pass
+                        while not child_event_sink.empty():
+                            try:
+                                child_event = child_event_sink.get_nowait()
+                                yield ("child", child_event)
+                            except asyncio.QueueEmpty:
+                                break
+                        return
+                    continue
+
+                if task is pending_child and result is not None:
+                    # Got a child event
+                    yield ("child", result)
+                    # Restart child listener
+                    pending_child = asyncio.create_task(
+                        _get_child_event_with_timeout(child_event_sink, timeout=0.05)
+                    )
+                elif task is pending_tool:
+                    # Got a tool event
+                    yield ("tool", result)
+                    # Get next tool event
+                    try:
+                        pending_tool = asyncio.create_task(tool_iter.__anext__())
+                    except StopAsyncIteration:
+                        pending_tool = None
+                elif task is pending_child and result is None:
+                    # Timeout with no event - restart listener
+                    pending_child = asyncio.create_task(
+                        _get_child_event_with_timeout(child_event_sink, timeout=0.05)
+                    )
+    finally:
+        # Cleanup any pending tasks
+        for task in [pending_tool, pending_child]:
+            if task and not task.done():
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+
+async def _get_child_event_with_timeout(
+    queue: asyncio.Queue, timeout: float = 0.05
+) -> dict | None:
+    """
+    Get an event from the queue with a timeout.
+
+    Returns None on timeout (no event available).
+    This allows the multiplexer to check for tool events regularly.
+    """
+    try:
+        return await asyncio.wait_for(queue.get(), timeout=timeout)
+    except asyncio.TimeoutError:
+        return None
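Taken together, `stream_with_child_events` yields tagged `("tool", ...)` / `("child", ...)` tuples, and `process_child_event` turns child events into SSE chunks. A hypothetical consumer sketch — the real wiring lives in the reworked `streaming.py`, which this diff does not show, so the function name and plumbing here are assumptions:

```python
# Hypothetical consumer of the multiplexer defined above. `tools_stream`,
# `sink`, and `state` come from the surrounding streaming code.

async def forward_events(tools_stream, sink, state,
                         session_id=None, user_id=None,
                         message_id=None, agent_schema=None):
    async for kind, event in stream_with_child_events(
        tools_stream, sink, state,
        session_id=session_id, user_id=user_id,
        message_id=message_id, agent_schema=agent_schema,
    ):
        if kind == "child":
            # Child events become SSE chunks; handle_child_content also sets
            # state.child_content_streamed so parent text can be suppressed.
            async for chunk in process_child_event(
                event, state, session_id, user_id, message_id, agent_schema
            ):
                yield chunk
        else:
            # kind == "tool": the parent run's own tool event; the real
            # streaming code formats these itself (elided here).
            pass
```

The design point the module docstring calls out carries through here: once any `child_content` event has been forwarded, the caller must check `state.child_content_streamed` and skip the parent agent's own text output, otherwise the same content reaches the client twice.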