remdb 0.3.181__py3-none-any.whl → 0.3.202__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- rem/agentic/context.py +101 -0
- rem/agentic/context_builder.py +12 -2
- rem/api/main.py +1 -1
- rem/api/mcp_router/server.py +4 -0
- rem/api/mcp_router/tools.py +395 -159
- rem/api/routers/auth.py +43 -1
- rem/api/routers/chat/completions.py +51 -9
- rem/api/routers/chat/sse_events.py +2 -2
- rem/api/routers/chat/streaming.py +146 -21
- rem/api/routers/messages.py +96 -23
- rem/auth/jwt.py +19 -4
- rem/auth/middleware.py +42 -28
- rem/cli/README.md +62 -0
- rem/cli/commands/db.py +33 -19
- rem/cli/commands/process.py +171 -43
- rem/models/entities/ontology.py +18 -20
- rem/services/content/service.py +18 -5
- rem/services/postgres/__init__.py +28 -3
- rem/services/postgres/diff_service.py +57 -5
- rem/services/postgres/programmable_diff_service.py +635 -0
- rem/services/postgres/pydantic_to_sqlalchemy.py +2 -2
- rem/services/postgres/register_type.py +11 -10
- rem/services/session/__init__.py +7 -1
- rem/services/session/compression.py +42 -2
- rem/services/session/pydantic_messages.py +210 -0
- rem/sql/migrations/001_install.sql +125 -7
- rem/sql/migrations/002_install_models.sql +136 -126
- rem/sql/migrations/004_cache_system.sql +7 -275
- rem/utils/schema_loader.py +6 -6
- {remdb-0.3.181.dist-info → remdb-0.3.202.dist-info}/METADATA +1 -1
- {remdb-0.3.181.dist-info → remdb-0.3.202.dist-info}/RECORD +33 -31
- {remdb-0.3.181.dist-info → remdb-0.3.202.dist-info}/WHEEL +0 -0
- {remdb-0.3.181.dist-info → remdb-0.3.202.dist-info}/entry_points.txt +0 -0
rem/api/routers/auth.py
CHANGED

@@ -541,6 +541,9 @@ async def refresh_token(body: TokenRefreshRequest):
     """
     Refresh access token using refresh token.
 
+    Fetches the user's current role/tier from the database to ensure
+    the new access token reflects their actual permissions.
+
     Args:
         body: TokenRefreshRequest with refresh_token
 
@@ -548,7 +551,46 @@ async def refresh_token(body: TokenRefreshRequest):
         New access token or 401 if refresh token is invalid
     """
     jwt_service = get_jwt_service()
-    result = jwt_service.refresh_access_token(body.refresh_token)
+
+    # First decode the refresh token to get user_id (without full verification yet)
+    payload = jwt_service.decode_without_verification(body.refresh_token)
+    if not payload:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid refresh token format"
+        )
+
+    user_id = payload.get("sub")
+    if not user_id:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid refresh token: missing user ID"
+        )
+
+    # Fetch user from database to get current role/tier
+    user_override = None
+    if settings.postgres.enabled:
+        db = PostgresService()
+        try:
+            await db.connect()
+            user_service = UserService(db)
+            user_entity = await user_service.get_user_by_id(user_id)
+            if user_entity:
+                user_override = {
+                    "role": user_entity.role or "user",
+                    "roles": [user_entity.role] if user_entity.role else ["user"],
+                    "tier": user_entity.tier.value if user_entity.tier else "free",
+                    "name": user_entity.name,
+                }
+                logger.debug(f"Refresh token: fetched user {user_id} with role={user_override['role']}, tier={user_override['tier']}")
+        except Exception as e:
+            logger.warning(f"Could not fetch user for token refresh: {e}")
+            # Continue without override - will use defaults
+        finally:
+            await db.disconnect()
+
+    # Now do the actual refresh with proper verification
+    result = jwt_service.refresh_access_token(body.refresh_token, user_override=user_override)
 
     if not result:
         raise HTTPException(
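What stands out in this refresh flow is the two-phase token handling: claims are decoded without signature verification solely to find the user row, and trust is only established by the verified `refresh_access_token` call at the end. The diff doesn't include remdb's `JWTService`, so the following is only a sketch of the unverified-decode step, shown with plain PyJWT; the function name mirrors the diff, the body is an assumption.

```python
# Sketch only - remdb's JWTService internals are not part of this diff.
import jwt  # PyJWT


def decode_without_verification(token: str) -> dict | None:
    """Read JWT claims (e.g. "sub") without checking the signature.

    Safe only for lookups like the user fetch above; nothing read here
    may be trusted until the signed refresh call succeeds.
    """
    try:
        return jwt.decode(token, options={"verify_signature": False})
    except jwt.InvalidTokenError:
        return None
```

Because the unverified decode sits outside the trust path, a forged token can at worst trigger a harmless user lookup; the subsequent verified refresh still rejects it with a 401.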
rem/api/routers/chat/completions.py
CHANGED

@@ -503,16 +503,42 @@ async def chat_completions(body: ChatCompletionRequest, request: Request):
             logger.error(f"Failed to transcribe audio: {e}")
             # Fall through with original content (will likely fail at agent)
 
-    # Use ContextBuilder to construct
-    #
-    # 2. Session history (if session_id provided)
-    # 3. New messages from request body (transcribed if audio)
+    # Use ContextBuilder to construct context and basic messages
+    # Note: We load session history separately for proper pydantic-ai message_history
     context, messages = await ContextBuilder.build_from_headers(
         headers=dict(request.headers),
         new_messages=new_messages,
         user_id=temp_context.user_id,  # From JWT token (source of truth)
     )
 
+    # Load raw session history for proper pydantic-ai message_history format
+    # This enables proper tool call/return pairing for LLM API compatibility
+    from ....services.session import SessionMessageStore, session_to_pydantic_messages
+    from ....agentic.schema import get_system_prompt
+
+    pydantic_message_history = None
+    if context.session_id and settings.postgres.enabled:
+        try:
+            store = SessionMessageStore(user_id=context.user_id or settings.test.effective_user_id)
+            raw_session_history = await store.load_session_messages(
+                session_id=context.session_id,
+                user_id=context.user_id,
+                compress_on_load=False,  # Don't compress - we need full data for reconstruction
+            )
+            if raw_session_history:
+                # CRITICAL: Extract and pass the agent's system prompt
+                # pydantic-ai only auto-adds system prompts when message_history is empty
+                # When we pass message_history, we must include the system prompt ourselves
+                agent_system_prompt = get_system_prompt(agent_schema) if agent_schema else None
+                pydantic_message_history = session_to_pydantic_messages(
+                    raw_session_history,
+                    system_prompt=agent_system_prompt,
+                )
+                logger.debug(f"Converted {len(raw_session_history)} session messages to {len(pydantic_message_history)} pydantic-ai messages (with system prompt)")
+        except Exception as e:
+            logger.warning(f"Failed to load session history for message_history: {e}")
+            # Fall back to old behavior (concatenated prompt)
+
     logger.info(f"Built context with {len(messages)} total messages (includes history + user context)")
 
     # Ensure session exists with metadata and eval mode if applicable
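The converter used above, `session_to_pydantic_messages`, comes from the new rem/services/session/pydantic_messages.py (+210 lines), which this diff doesn't reproduce. As a rough idea of what the conversion involves, here is a minimal sketch over pydantic-ai's public message classes; the real helper also reconstructs tool call/return pairs, which is the main point of the change, and is omitted here.

```python
# Minimal sketch; the real session_to_pydantic_messages is not shown in this diff.
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
)


def to_pydantic_messages(
    rows: list[dict], system_prompt: str | None = None
) -> list[ModelMessage]:
    messages: list[ModelMessage] = []
    if system_prompt:
        # pydantic-ai only auto-adds the system prompt when message_history
        # is empty, hence the explicit injection noted in the hunk above.
        messages.append(ModelRequest(parts=[SystemPromptPart(content=system_prompt)]))
    for row in rows:
        if row["role"] == "user":
            messages.append(ModelRequest(parts=[UserPromptPart(content=row["content"])]))
        elif row["role"] == "assistant":
            messages.append(ModelResponse(parts=[TextPart(content=row["content"])]))
        # "tool" rows would need paired ToolCallPart/ToolReturnPart entries.
    return messages
```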
@@ -533,9 +559,17 @@ async def chat_completions(body: ChatCompletionRequest, request: Request):
         model_override=body.model,  # type: ignore[arg-type]
     )
 
-    #
-    #
-    prompt = "\n".join(msg.content for msg in messages)
+    # Build the prompt for the agent
+    # If we have proper message_history, use just the latest user message as prompt
+    # Otherwise, fall back to concatenating all messages (legacy behavior)
+    if pydantic_message_history:
+        # Use the latest user message as the prompt, with history passed separately
+        user_prompt = body.messages[-1].content if body.messages else ""
+        prompt = user_prompt
+        logger.debug(f"Using message_history with {len(pydantic_message_history)} messages")
+    else:
+        # Legacy: Combine all messages into single prompt for agent
+        prompt = "\n".join(msg.content for msg in messages)
 
     # Generate OpenAI-compatible request ID
     request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
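With history split out, the downstream pydantic-ai call changes shape rather than content. A hedged sketch of the two call styles this branch produces (model id and prompts are illustrative; `result.output` is `result.data` on older pydantic-ai releases):

```python
from pydantic_ai import Agent

agent = Agent("openai:gpt-4o-mini")  # illustrative model id


async def answer(history: list | None) -> str:
    if history:
        # New path: prompt carries only the latest user turn; earlier turns,
        # including tool call/return pairs, travel in message_history.
        result = await agent.run("What did we decide earlier?", message_history=history)
    else:
        # Legacy path: the whole conversation flattened into one prompt string.
        result = await agent.run("user: hi\nassistant: hello\nuser: what next?")
    return result.output  # result.data on older pydantic-ai versions
```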
@@ -570,6 +604,8 @@ async def chat_completions(body: ChatCompletionRequest, request: Request):
                 agent_schema=schema_name,
                 session_id=context.session_id,
                 user_id=context.user_id,
+                agent_context=context,  # Pass context for multi-agent support
+                message_history=pydantic_message_history,  # Native pydantic-ai message history
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
@@ -592,10 +628,16 @@ async def chat_completions(body: ChatCompletionRequest, request: Request):
         ) as span:
             # Capture trace context from the span we just created
             trace_id, span_id = get_current_trace_context()
-            result = await agent.run(prompt)
+            if pydantic_message_history:
+                result = await agent.run(prompt, message_history=pydantic_message_history)
+            else:
+                result = await agent.run(prompt)
     else:
         # No tracer available, run without tracing
-        result = await agent.run(prompt)
+        if pydantic_message_history:
+            result = await agent.run(prompt, message_history=pydantic_message_history)
+        else:
+            result = await agent.run(prompt)
 
     # Determine content format based on response_format request
     if body.response_format and body.response_format.type == "json_object":
rem/api/routers/chat/sse_events.py
CHANGED

@@ -409,9 +409,9 @@ class ToolCallEvent(BaseModel):
         default=None,
         description="Tool arguments (for 'started' status)"
     )
-    result: str | None = Field(
+    result: str | dict[str, Any] | None = Field(
         default=None,
-        description="Tool result summary
+        description="Tool result - full dict for finalize_intake, summary string for others"
     )
     error: str | None = Field(
         default=None,
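The widened union means SSE consumers can no longer assume `result` is a string. A cut-down stand-in for the model (only this one field; the real event also carries tool_name, tool_id, status, arguments, and error) shows both shapes validating:

```python
from typing import Any

from pydantic import BaseModel, Field


class ToolCallEventSketch(BaseModel):
    # Stand-in for ToolCallEvent; all other fields omitted for brevity.
    result: str | dict[str, Any] | None = Field(
        default=None,
        description="Tool result - full dict for finalize_intake, summary string for others",
    )


print(ToolCallEventSketch(result="first 200 chars of a long result..."))
print(ToolCallEventSketch(result={"status": "complete", "missing_fields": []}))
```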
rem/api/routers/chat/streaming.py
CHANGED

@@ -15,6 +15,11 @@ Key Insight
 - Use PartEndEvent to detect tool completion
 - Use FunctionToolResultEvent to get tool results
 
+Multi-Agent Context Propagation:
+- AgentContext is set via agent_context_scope() before agent.iter()
+- Child agents (via ask_agent tool) can access parent context via get_current_context()
+- Context includes user_id, tenant_id, session_id, is_eval for proper scoping
+
 SSE Format (OpenAI-compatible):
     data: {"id": "chatcmpl-...", "choices": [{"delta": {"content": "..."}}]}\\n\\n
     data: [DONE]\\n\\n
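The get/set pair this docstring references lives in rem/agentic/context.py (+101 lines in this release), which the diff doesn't include. The conventional way to build this kind of propagation is a module-level `ContextVar`; the sketch below assumes that design, with field names taken from the docstring.

```python
# Assumed implementation sketch - rem/agentic/context.py is not shown in this diff.
from contextvars import ContextVar
from dataclasses import dataclass


@dataclass
class AgentContext:
    user_id: str | None = None
    tenant_id: str | None = None
    session_id: str | None = None
    is_eval: bool = False


_current_context: ContextVar[AgentContext | None] = ContextVar(
    "agent_context", default=None
)


def get_current_context() -> AgentContext | None:
    return _current_context.get()


def set_current_context(ctx: AgentContext | None) -> None:
    _current_context.set(ctx)
```

`ContextVar` values are inherited by tasks spawned in the same context, which is what lets a child agent launched by the ask_agent tool observe the parent's user_id and session_id; the save-and-restore in the streaming hunks further down keeps nested runs from leaking state back into the parent.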
@@ -28,10 +33,12 @@ Extended SSE Format (Custom Events):
 See sse_events.py for the full event type definitions.
 """
 
+from __future__ import annotations
+
 import json
 import time
 import uuid
-from typing import AsyncGenerator
+from typing import TYPE_CHECKING, AsyncGenerator
 
 from loguru import logger
 from pydantic_ai.agent import Agent
@@ -55,6 +62,7 @@ from .models import (
 )
 from .sse_events import (
     DoneEvent,
+    ErrorEvent,
     MetadataEvent,
     ProgressEvent,
     ReasoningEvent,
@@ -62,6 +70,9 @@ from .sse_events import (
     format_sse_event,
 )
 
+if TYPE_CHECKING:
+    from ....agentic.context import AgentContext
+
 
 async def stream_openai_response(
     agent: Agent,
@@ -79,6 +90,11 @@ async def stream_openai_response(
     # Mutable container to capture tool calls for persistence
     # Format: list of {"tool_name": str, "tool_id": str, "arguments": dict, "result": any}
     tool_calls_out: list | None = None,
+    # Agent context for multi-agent propagation
+    # When set, enables child agents to access parent context via get_current_context()
+    agent_context: "AgentContext | None" = None,
+    # Pydantic-ai native message history for proper tool call/return pairing
+    message_history: list | None = None,
 ) -> AsyncGenerator[str, None]:
     """
     Stream Pydantic AI agent responses with rich SSE events.
@@ -153,6 +169,17 @@ async def stream_openai_response(
     # Maps tool_id -> {"tool_name": str, "tool_id": str, "arguments": dict}
     pending_tool_data: dict[str, dict] = {}
 
+    # Import context functions for multi-agent support
+    from ....agentic.context import set_current_context
+
+    # Set up context for multi-agent propagation
+    # This allows child agents (via ask_agent tool) to access parent context
+    previous_context = None
+    if agent_context is not None:
+        from ....agentic.context import get_current_context
+        previous_context = get_current_context()
+        set_current_context(agent_context)
+
     try:
         # Emit initial progress event
         current_step = 1
@@ -164,7 +191,9 @@ async def stream_openai_response(
         ))
 
         # Use agent.iter() to get complete execution with tool calls
-        async with agent.iter(prompt) as agent_run:
+        # Pass message_history if available for proper tool call/return pairing
+        iter_kwargs = {"message_history": message_history} if message_history else {}
+        async with agent.iter(prompt, **iter_kwargs) as agent_run:
             # Capture trace context IMMEDIATELY inside agent execution
             # This is deterministic - it's the OTEL context from Pydantic AI instrumentation
             # NOT dependent on any AI-generated content
@@ -285,6 +314,12 @@ async def stream_openai_response(
                            args_dict = event.part.args.args_dict
                        elif isinstance(event.part.args, dict):
                            args_dict = event.part.args
+                        elif isinstance(event.part.args, str):
+                            # Parse JSON string args (common with pydantic-ai)
+                            try:
+                                args_dict = json.loads(event.part.args)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse tool args as JSON: {event.part.args[:100]}")
 
                        # Log tool call with key parameters
                        if args_dict and tool_name == "search_rem":
@@ -330,8 +365,25 @@ async def stream_openai_response(
                    ):
                        if event.index in active_tool_calls:
                            tool_name, tool_id = active_tool_calls[event.index]
-
-                            #
+
+                            # Extract full args from completed ToolCallPart
+                            # (PartStartEvent only has empty/partial args during streaming)
+                            args_dict = None
+                            if event.part.args is not None:
+                                if hasattr(event.part.args, 'args_dict'):
+                                    args_dict = event.part.args.args_dict
+                                elif isinstance(event.part.args, dict):
+                                    args_dict = event.part.args
+                                elif isinstance(event.part.args, str) and event.part.args:
+                                    try:
+                                        args_dict = json.loads(event.part.args)
+                                    except json.JSONDecodeError:
+                                        logger.warning(f"Failed to parse tool args: {event.part.args[:100]}")
+
+                            # Update pending_tool_data with complete args
+                            if tool_id in pending_tool_data:
+                                pending_tool_data[tool_id]["arguments"] = args_dict
+
                            del active_tool_calls[event.index]
 
                    # ============================================
@@ -434,6 +486,12 @@ async def stream_openai_response(
                            hidden=False,
                        ))
 
+                        # Get complete args from pending_tool_data BEFORE deleting
+                        # (captured at PartEndEvent with full args)
+                        completed_args = None
+                        if tool_id in pending_tool_data:
+                            completed_args = pending_tool_data[tool_id].get("arguments")
+
                        # Capture tool call with result for persistence
                        # Special handling for register_metadata - always capture full data
                        if tool_calls_out is not None and tool_id in pending_tool_data:
@@ -445,8 +503,12 @@ async def stream_openai_response(
 
                        if not is_metadata_event:
                            # Normal tool completion - emit ToolCallEvent
-
-
+                            # For finalize_intake, send full result dict for frontend
+                            if tool_name == "finalize_intake" and isinstance(result_content, dict):
+                                result_for_sse = result_content
+                            else:
+                                result_str = str(result_content)
+                                result_for_sse = result_str[:200] + "..." if len(result_str) > 200 else result_str
 
                            # Log result count for search_rem
                            if tool_name == "search_rem" and isinstance(result_content, dict):
@@ -477,7 +539,8 @@ async def stream_openai_response(
                                tool_name=tool_name,
                                tool_id=tool_id,
                                status="completed",
-
+                                arguments=completed_args,
+                                result=result_for_sse
                            ))
 
                            # Update progress after tool completion
@@ -587,25 +650,77 @@ async def stream_openai_response(
 
    except Exception as e:
        import traceback
+        import re
 
        error_msg = str(e)
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Parse error details for better client handling
+        error_code = "stream_error"
+        error_details: dict = {}
+        recoverable = True
+
+        # Check for rate limit errors (OpenAI 429)
+        if "429" in error_msg or "rate_limit" in error_msg.lower() or "RateLimitError" in type(e).__name__:
+            error_code = "rate_limit_exceeded"
+            recoverable = True
+
+            # Extract retry-after time from error message
+            # Pattern: "Please try again in X.XXs" or "Please try again in Xs"
+            retry_match = re.search(r"try again in (\d+(?:\.\d+)?)\s*s", error_msg)
+            if retry_match:
+                retry_seconds = float(retry_match.group(1))
+                error_details["retry_after_seconds"] = retry_seconds
+                error_details["retry_after_ms"] = int(retry_seconds * 1000)
+
+            # Extract token usage info if available
+            used_match = re.search(r"Used (\d+)", error_msg)
+            limit_match = re.search(r"Limit (\d+)", error_msg)
+            requested_match = re.search(r"Requested (\d+)", error_msg)
+            if used_match:
+                error_details["tokens_used"] = int(used_match.group(1))
+            if limit_match:
+                error_details["tokens_limit"] = int(limit_match.group(1))
+            if requested_match:
+                error_details["tokens_requested"] = int(requested_match.group(1))
+
+            logger.error(f"🔴 Streaming error: status_code: 429, model_name: {model}, body: {error_msg[:200]}")
+
+        # Check for authentication errors
+        elif "401" in error_msg or "AuthenticationError" in type(e).__name__:
+            error_code = "authentication_error"
+            recoverable = False
+            logger.error(f"🔴 Streaming error: Authentication failed")
+
+        # Check for model not found / invalid model
+        elif "404" in error_msg or "model" in error_msg.lower() and "not found" in error_msg.lower():
+            error_code = "model_not_found"
+            recoverable = False
+            logger.error(f"🔴 Streaming error: Model not found")
+
+        # Generic error
+        else:
+            logger.error(f"🔴 Streaming error: {error_msg}")
+
+        logger.error(f"🔴 {traceback.format_exc()}")
+
+        # Emit proper ErrorEvent via SSE (with event: prefix for client parsing)
+        yield format_sse_event(ErrorEvent(
+            code=error_code,
+            message=error_msg,
+            details=error_details if error_details else None,
+            recoverable=recoverable,
+        ))
 
        # Emit done event with error reason
        yield format_sse_event(DoneEvent(reason="error"))
        yield "data: [DONE]\n\n"
 
+    finally:
+        # Restore previous context for multi-agent support
+        # This ensures nested agent calls don't pollute the parent's context
+        if agent_context is not None:
+            set_current_context(previous_context)
+
 
 async def stream_simulator_response(
    prompt: str,
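The 429 branch above recovers structured data from free-text provider messages. A quick check of those regexes against a message in the shape OpenAI's rate-limit errors use (the sample string is invented for illustration, not a captured log):

```python
import re

# Invented sample in the shape the parser targets.
msg = (
    "Rate limit reached for gpt-4o. Limit 30000, Used 29500, Requested 1200. "
    "Please try again in 1.34s."
)

retry_match = re.search(r"try again in (\d+(?:\.\d+)?)\s*s", msg)
used_match = re.search(r"Used (\d+)", msg)

assert retry_match and float(retry_match.group(1)) == 1.34
assert used_match and int(used_match.group(1)) == 29500
print(int(float(retry_match.group(1)) * 1000), "ms backoff")  # prints: 1340 ms backoff
```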
@@ -716,6 +831,10 @@ async def stream_openai_response_with_save(
    agent_schema: str | None = None,
    session_id: str | None = None,
    user_id: str | None = None,
+    # Agent context for multi-agent propagation
+    agent_context: "AgentContext | None" = None,
+    # Pydantic-ai native message history for proper tool call/return pairing
+    message_history: list | None = None,
 ) -> AsyncGenerator[str, None]:
    """
    Wrapper around stream_openai_response that saves the assistant response after streaming.
@@ -731,6 +850,7 @@ async def stream_openai_response_with_save(
        agent_schema: Agent schema name
        session_id: Session ID for message storage
        user_id: User ID for message storage
+        agent_context: Agent context for multi-agent propagation (enables child agents)
 
    Yields:
        SSE-formatted strings
@@ -763,6 +883,8 @@ async def stream_openai_response_with_save(
        message_id=message_id,
        trace_context_out=trace_context,  # Pass container to capture trace IDs
        tool_calls_out=tool_calls,  # Capture tool calls for persistence
+        agent_context=agent_context,  # Pass context for multi-agent support
+        message_history=message_history,  # Native pydantic-ai message history
    ):
        yield chunk
 
@@ -793,6 +915,8 @@ async def stream_openai_response_with_save(
 
    # First, store tool call messages (message_type: "tool")
    for tool_call in tool_calls:
+        if not tool_call:
+            continue
        tool_message = {
            "role": "tool",
            "content": json.dumps(tool_call.get("result", {}), default=str),
@@ -838,8 +962,9 @@ async def stream_openai_response_with_save(
 
    # Update session description with session_name (non-blocking, after all yields)
    for tool_call in tool_calls:
-        if tool_call.get("tool_name") == "register_metadata" and tool_call.get("is_metadata"):
-
+        if tool_call and tool_call.get("tool_name") == "register_metadata" and tool_call.get("is_metadata"):
+            arguments = tool_call.get("arguments") or {}
+            session_name = arguments.get("session_name")
            if session_name:
                try:
                    from ....models.entities import Session
rem/api/routers/messages.py
CHANGED

@@ -93,6 +93,23 @@ class SessionListResponse(BaseModel):
    has_more: bool
 
 
+class SessionWithUser(BaseModel):
+    """Session with user info for admin views."""
+
+    id: str
+    name: str
+    mode: str | None = None
+    description: str | None = None
+    user_id: str | None = None
+    user_name: str | None = None
+    user_email: str | None = None
+    message_count: int = 0
+    total_tokens: int | None = None
+    created_at: datetime | None = None
+    updated_at: datetime | None = None
+    metadata: dict | None = None
+
+
 class PaginationMetadata(BaseModel):
    """Pagination metadata for paginated responses."""
 
@@ -108,7 +125,7 @@ class SessionsQueryResponse(BaseModel):
    """Response for paginated sessions query."""
 
    object: Literal["list"] = "list"
-    data: list[
+    data: list[SessionWithUser] = Field(description="List of sessions for the current page")
    metadata: PaginationMetadata = Field(description="Pagination metadata")
 
 
@@ -274,6 +291,8 @@ async def get_message(
 async def list_sessions(
    request: Request,
    user_id: str | None = Query(default=None, description="Filter by user ID (admin only for cross-user)"),
+    user_name: str | None = Query(default=None, description="Filter by user name (partial match, admin only)"),
+    user_email: str | None = Query(default=None, description="Filter by user email (partial match, admin only)"),
    mode: SessionMode | None = Query(default=None, description="Filter by session mode"),
    page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
    page_size: int = Query(default=50, ge=1, le=100, description="Number of results per page"),
@@ -283,46 +302,100 @@ async def list_sessions(
 
    Access Control:
    - Regular users: Only see their own sessions
-    - Admin users: Can filter by any user_id or see all sessions
+    - Admin users: Can filter by any user_id, user_name, user_email, or see all sessions
 
    Filters:
    - user_id: Filter by session owner (admin only for cross-user)
+    - user_name: Filter by user name partial match (admin only)
+    - user_email: Filter by user email partial match (admin only)
    - mode: Filter by session mode (normal or evaluation)
 
    Pagination:
    - page: Page number (1-indexed, default: 1)
    - page_size: Number of sessions per page (default: 50, max: 100)
 
-    Returns paginated results ordered by created_at descending
+    Returns paginated results with user info ordered by created_at descending.
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")
 
-
+    current_user = get_current_user(request)
+    admin = is_admin(current_user)
 
-    #
-
-    if
-
+    # Get postgres service for raw SQL query
+    db = get_postgres_service()
+    if not db:
+        raise HTTPException(status_code=503, detail="Database connection failed")
+    if not db.pool:
+        await db.connect()
 
-    #
-
-
-
-
-
-
-
+    # Build effective filters based on user role
+    effective_user_id = user_id
+    effective_user_name = user_name if admin else None  # Only admin can search by name
+    effective_user_email = user_email if admin else None  # Only admin can search by email
+
+    if not admin:
+        # Non-admin users can only see their own sessions
+        effective_user_id = current_user.get("id") if current_user else None
+        if not effective_user_id:
+            # Anonymous user - return empty
+            return SessionsQueryResponse(
+                data=[],
+                metadata=PaginationMetadata(
+                    total=0, page=page, page_size=page_size,
+                    total_pages=0, has_next=False, has_previous=False,
+                ),
+            )
+
+    # Call the SQL function for sessions with user info
+    async with db.pool.acquire() as conn:
+        rows = await conn.fetch(
+            """
+            SELECT * FROM fn_list_sessions_with_user(
+                $1, $2, $3, $4, $5, $6
+            )
+            """,
+            effective_user_id,
+            effective_user_name,
+            effective_user_email,
+            mode.value if mode else None,
+            page,
+            page_size,
+        )
 
+    # Extract total from first row
+    total = rows[0]["total_count"] if rows else 0
+
+    # Convert rows to SessionWithUser
+    data = [
+        SessionWithUser(
+            id=str(row["id"]),
+            name=row["name"],
+            mode=row["mode"],
+            description=row["description"],
+            user_id=row["user_id"],
+            user_name=row["user_name"],
+            user_email=row["user_email"],
+            message_count=row["message_count"] or 0,
+            total_tokens=row["total_tokens"],
+            created_at=row["created_at"],
+            updated_at=row["updated_at"],
+            metadata=row["metadata"],
+        )
+        for row in rows
+    ]
+
+    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
 
    return SessionsQueryResponse(
-        data=
+        data=data,
        metadata=PaginationMetadata(
-            total=
-            page=
-            page_size=
-            total_pages=
-            has_next=
-            has_previous=
+            total=total,
+            page=page,
+            page_size=page_size,
+            total_pages=total_pages,
+            has_next=page < total_pages,
+            has_previous=page > 1,
        ),
    )
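The handler leaves filtering, ordering, and counting to the `fn_list_sessions_with_user` SQL function (defined in the changed migration files, not shown here) and only does page math in Python. That arithmetic is worth a second look, since `has_next` is derived from a ceiling division rather than fetched; extracted for a quick self-check:

```python
# The handler's pagination arithmetic, pulled out for verification.
def paginate(total: int, page: int, page_size: int) -> dict:
    # Ceiling division: 101 rows at page_size 50 -> 3 pages.
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
    return {
        "total_pages": total_pages,
        "has_next": page < total_pages,
        "has_previous": page > 1,
    }


assert paginate(total=0, page=1, page_size=50) == {
    "total_pages": 0, "has_next": False, "has_previous": False,
}
assert paginate(total=101, page=2, page_size=50) == {
    "total_pages": 3, "has_next": True, "has_previous": True,
}
```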