remdb 0.3.226__py3-none-any.whl → 0.3.245__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of remdb might be problematic. Click here for more details.
- rem/agentic/README.md +22 -248
- rem/agentic/context.py +13 -2
- rem/agentic/context_builder.py +39 -33
- rem/agentic/providers/pydantic_ai.py +67 -50
- rem/api/mcp_router/resources.py +223 -0
- rem/api/mcp_router/tools.py +25 -9
- rem/api/routers/auth.py +112 -9
- rem/api/routers/chat/child_streaming.py +394 -0
- rem/api/routers/chat/streaming.py +166 -357
- rem/api/routers/chat/streaming_utils.py +327 -0
- rem/api/routers/query.py +5 -14
- rem/cli/commands/ask.py +144 -33
- rem/cli/commands/process.py +9 -1
- rem/cli/commands/query.py +109 -0
- rem/cli/commands/session.py +117 -0
- rem/cli/main.py +2 -0
- rem/models/entities/session.py +1 -0
- rem/services/postgres/repository.py +7 -17
- rem/services/rem/service.py +47 -0
- rem/services/session/compression.py +7 -3
- rem/services/session/pydantic_messages.py +45 -11
- rem/services/session/reload.py +2 -1
- rem/settings.py +43 -0
- rem/sql/migrations/004_cache_system.sql +3 -1
- rem/utils/schema_loader.py +99 -99
- {remdb-0.3.226.dist-info → remdb-0.3.245.dist-info}/METADATA +2 -2
- {remdb-0.3.226.dist-info → remdb-0.3.245.dist-info}/RECORD +29 -26
- {remdb-0.3.226.dist-info → remdb-0.3.245.dist-info}/WHEEL +0 -0
- {remdb-0.3.226.dist-info → remdb-0.3.245.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Streaming Utilities.
|
|
3
|
+
|
|
4
|
+
Pure functions and data structures for SSE streaming.
|
|
5
|
+
No I/O, no database calls - just data transformation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import time
|
|
12
|
+
import uuid
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from loguru import logger
|
|
17
|
+
|
|
18
|
+
from .models import (
|
|
19
|
+
ChatCompletionMessageDelta,
|
|
20
|
+
ChatCompletionStreamChoice,
|
|
21
|
+
ChatCompletionStreamResponse,
|
|
22
|
+
)
|
|
23
|
+
from .sse_events import (
|
|
24
|
+
MetadataEvent,
|
|
25
|
+
ProgressEvent,
|
|
26
|
+
ReasoningEvent,
|
|
27
|
+
ToolCallEvent,
|
|
28
|
+
format_sse_event,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# =============================================================================
|
|
33
|
+
# STREAMING STATE
|
|
34
|
+
# =============================================================================
|
|
35
|
+
|
|
36
|
+
@dataclass
class StreamingState:
    """
    Tracks state during SSE streaming.

    This is a pure data container - no methods that do I/O.

    Mutable container fields use ``default_factory`` so every instance gets
    its own dict/list instead of sharing one across streams.
    """
    # Identity of the completion being streamed
    request_id: str
    created_at: int
    model: str
    # Wall-clock start (time.time()); latency_ms() is measured against it
    start_time: float = field(default_factory=time.time)

    # Content tracking
    is_first_chunk: bool = True
    token_count: int = 0

    # Child agent tracking - KEY FOR DUPLICATION FIX
    child_content_streamed: bool = False
    responding_agent: str | None = None

    # Tool tracking
    active_tool_calls: dict = field(default_factory=dict)  # index -> (name, id)
    pending_tool_completions: list = field(default_factory=list)  # FIFO queue
    pending_tool_data: dict = field(default_factory=dict)  # tool_id -> data

    # Reasoning tracking
    reasoning_step: int = 0

    # Progress tracking
    current_step: int = 0
    total_steps: int = 3

    # Metadata tracking
    metadata_registered: bool = False

    # Trace context (captured from OTEL)
    trace_id: str | None = None
    span_id: str | None = None

    @classmethod
    def create(cls, model: str, request_id: str | None = None) -> "StreamingState":
        """
        Create a new streaming state.

        Args:
            model: Model name stored on the state.
            request_id: Identifier to reuse; when missing or empty a fresh
                ``chatcmpl-`` id with 24 hex characters is generated.

        Returns:
            A state whose ``created_at`` is the current time in whole seconds.
        """
        return cls(
            request_id=request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created_at=int(time.time()),
            model=model,
        )

    def latency_ms(self) -> int:
        """
        Calculate latency since start, in milliseconds.

        The measurement uses the wall clock, which may be stepped backwards
        (e.g. by a clock sync); the result is clamped so it is never negative.
        """
        elapsed_ms = int((time.time() - self.start_time) * 1000)
        return max(0, elapsed_ms)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# =============================================================================
|
|
90
|
+
# SSE CHUNK BUILDERS
|
|
91
|
+
# =============================================================================
|
|
92
|
+
|
|
93
|
+
def build_content_chunk(state: StreamingState, content: str) -> str:
    """
    Render one piece of assistant text as an OpenAI-style SSE ``data:`` line.

    Side effects on ``state``: ``token_count`` grows by the number of
    whitespace-separated words in ``content`` and ``is_first_chunk`` is
    cleared, so only the first chunk of a stream carries the assistant role.
    """
    # Word count is used as the token estimate.
    state.token_count += len(content.split())

    role = "assistant" if state.is_first_chunk else None
    delta = ChatCompletionMessageDelta(role=role, content=content)
    choice = ChatCompletionStreamChoice(index=0, delta=delta, finish_reason=None)
    payload = ChatCompletionStreamResponse(
        id=state.request_id,
        created=state.created_at,
        model=state.model,
        choices=[choice],
    )

    state.is_first_chunk = False
    return f"data: {payload.model_dump_json()}\n\n"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def build_final_chunk(state: StreamingState) -> str:
    """Render the terminating SSE chunk: an empty delta with finish_reason=stop."""
    closing_choice = ChatCompletionStreamChoice(
        index=0,
        delta=ChatCompletionMessageDelta(),
        finish_reason="stop",
    )
    payload = ChatCompletionStreamResponse(
        id=state.request_id,
        created=state.created_at,
        model=state.model,
        choices=[closing_choice],
    )
    return f"data: {payload.model_dump_json()}\n\n"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def build_reasoning_event(state: StreamingState, content: str) -> str:
    """Format a reasoning SSE event tagged with the state's current reasoning step."""
    event = ReasoningEvent(content=content, step=state.reasoning_step)
    return format_sse_event(event)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def build_progress_event(
    step: int,
    total_steps: int,
    label: str,
    status: str = "in_progress",
) -> str:
    """Format a progress SSE event from the given step, total, label and status."""
    event = ProgressEvent(
        step=step,
        total_steps=total_steps,
        label=label,
        status=status,
    )
    return format_sse_event(event)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def build_tool_start_event(
    tool_name: str,
    tool_id: str,
    arguments: dict | None = None,
) -> str:
    """Format the SSE event announcing that a tool call has started."""
    started = ToolCallEvent(
        tool_name=tool_name,
        tool_id=tool_id,
        status="started",
        arguments=arguments,
    )
    return format_sse_event(started)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def build_tool_complete_event(
    tool_name: str,
    tool_id: str,
    arguments: dict | None = None,
    result: Any = None,
) -> str:
    """
    Format the SSE event announcing that a tool call has completed.

    The result is stringified and cut to 200 characters followed by "..."
    when longer; a ``None`` result is passed through as ``None``.
    """
    if result is None:
        preview = None
    else:
        text = str(result)
        preview = text if len(text) <= 200 else text[:200] + "..."

    completed = ToolCallEvent(
        tool_name=tool_name,
        tool_id=tool_id,
        status="completed",
        arguments=arguments,
        result=preview,
    )
    return format_sse_event(completed)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def build_metadata_event(
    message_id: str | None = None,
    in_reply_to: str | None = None,
    session_id: str | None = None,
    agent_schema: str | None = None,
    responding_agent: str | None = None,
    confidence: float | None = None,
    sources: list | None = None,
    model_version: str | None = None,
    latency_ms: int | None = None,
    token_count: int | None = None,
    trace_id: str | None = None,
    span_id: str | None = None,
    extra: dict | None = None,
) -> str:
    """
    Format a metadata SSE event.

    Every argument is optional and is forwarded unchanged to the
    ``MetadataEvent`` field of the same name.
    """
    fields = {
        "message_id": message_id,
        "in_reply_to": in_reply_to,
        "session_id": session_id,
        "agent_schema": agent_schema,
        "responding_agent": responding_agent,
        "confidence": confidence,
        "sources": sources,
        "model_version": model_version,
        "latency_ms": latency_ms,
        "token_count": token_count,
        "trace_id": trace_id,
        "span_id": span_id,
        "extra": extra,
    }
    return format_sse_event(MetadataEvent(**fields))
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# =============================================================================
|
|
230
|
+
# TOOL ARGUMENT EXTRACTION
|
|
231
|
+
# =============================================================================
|
|
232
|
+
|
|
233
|
+
def extract_tool_args(part) -> dict | None:
    """
    Extract arguments from a ToolCallPart.

    Handles various formats:
    - ArgsDict object with args_dict attribute
    - Plain dict
    - JSON string

    Args:
        part: Object exposing an ``args`` attribute.

    Returns:
        The arguments as a dict, or None when there are no arguments, the
        string is empty or not valid JSON, or the JSON decodes to something
        other than an object (list, number, string, bool, null).
    """
    raw = part.args
    if raw is None:
        return None

    if hasattr(raw, 'args_dict'):
        return raw.args_dict

    if isinstance(raw, dict):
        return raw

    if isinstance(raw, str) and raw:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse tool args: {raw[:100]}")
            return None

        if isinstance(parsed, dict):
            return parsed

        # Valid JSON is not necessarily an object; callers rely on the
        # declared dict | None contract (e.g. they call .get on the result).
        if parsed is not None:
            logger.warning(
                f"Ignoring non-object tool args ({type(parsed).__name__}): {raw[:100]}"
            )

    return None
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def log_tool_call(tool_name: str, args_dict: dict | None) -> None:
    """
    Log a tool call with key parameters.

    For ``search_rem`` calls with arguments, the query type, a query text
    preview (at most 50 characters), the table and the limit are included;
    every other call is logged by name only.

    Args:
        tool_name: Name of the tool being called.
        args_dict: Parsed tool arguments, if any.
    """
    if not (args_dict and tool_name == "search_rem"):
        logger.info(f"🔧 {tool_name}")
        return

    # Tool arguments are not guaranteed to be strings; a logging helper must
    # not raise on a None or non-string query_type.
    raw_type = args_dict.get("query_type", "?")
    query_type = "?" if raw_type is None else str(raw_type).upper()

    limit = args_dict.get("limit", 20)
    table = args_dict.get("table", "")
    query_text = args_dict.get("query_text", args_dict.get("entity_key", ""))
    if query_text and len(str(query_text)) > 50:
        query_text = str(query_text)[:50] + "..."

    logger.info(f"🔧 {tool_name} {query_type} '{query_text}' table={table} limit={limit}")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def log_tool_result(tool_name: str, result_content: Any) -> None:
    """
    Log a tool result with key metrics.

    Only ``search_rem`` results delivered as a dict are logged; anything else
    is ignored. The ``results`` entry may be a dict (count, query type, query
    text and table are read from it), a list (its length is the count), or
    something else (placeholders are logged).

    Args:
        tool_name: Name of the tool that produced the result.
        result_content: Raw result payload of the tool.
    """
    if tool_name != "search_rem" or not isinstance(result_content, dict):
        return

    # Placeholders used whenever a field cannot be determined.
    count = "?"
    query_type = "?"
    query_text = ""
    table = ""

    results = result_content.get("results", {})
    if isinstance(results, dict):
        if "count" in results:
            count = results["count"]
        else:
            # Fall back to the row count only when rows are actually sized;
            # a None or otherwise unsized value must not raise while logging.
            rows = results.get("results", [])
            count = len(rows) if hasattr(rows, "__len__") else "?"
        query_type = results.get("query_type", "?")
        query_text = results.get("query_text", results.get("key", ""))
        table = results.get("table_name", "")
    elif isinstance(results, list):
        count = len(results)

    if query_text and len(str(query_text)) > 40:
        query_text = str(query_text)[:40] + "..."
    logger.info(f"  ↳ {tool_name} {query_type} '{query_text}' table={table} → {count} results")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
# =============================================================================
|
|
300
|
+
# METADATA EXTRACTION
|
|
301
|
+
# =============================================================================
|
|
302
|
+
|
|
303
|
+
def extract_metadata_from_result(result_content: Any) -> dict | None:
    """
    Pull the metadata fields out of a register_metadata tool result.

    A result counts as a metadata event only when it is a dict with a truthy
    ``_metadata_event`` entry; otherwise None is returned. Fields missing
    from the result are reported as None.
    """
    is_metadata = isinstance(result_content, dict) and result_content.get("_metadata_event")
    if not is_metadata:
        return None

    wanted = (
        "confidence",
        "sources",
        "references",
        "flags",
        "session_name",
        "risk_level",
        "risk_score",
        "risk_reasoning",
        "recommended_action",
        "agent_schema",
        "extra",
    )
    return {key: result_content.get(key) for key in wanted}
|
rem/api/routers/query.py
CHANGED
|
@@ -90,8 +90,6 @@ from .common import ErrorResponse
|
|
|
90
90
|
|
|
91
91
|
from ...services.postgres import get_postgres_service
|
|
92
92
|
from ...services.rem.service import RemService
|
|
93
|
-
from ...services.rem.parser import RemQueryParser
|
|
94
|
-
from ...models.core import RemQuery
|
|
95
93
|
from ...settings import settings
|
|
96
94
|
|
|
97
95
|
router = APIRouter(prefix="/api/v1", tags=["query"])
|
|
@@ -331,7 +329,7 @@ async def execute_query(
|
|
|
331
329
|
return response
|
|
332
330
|
|
|
333
331
|
else:
|
|
334
|
-
# REM dialect mode -
|
|
332
|
+
# REM dialect mode - use unified execute_query_string
|
|
335
333
|
if not request.query:
|
|
336
334
|
raise HTTPException(
|
|
337
335
|
status_code=400,
|
|
@@ -340,17 +338,10 @@ async def execute_query(
|
|
|
340
338
|
|
|
341
339
|
logger.info(f"REM dialect query: {request.query[:100]}...")
|
|
342
340
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
rem_query = RemQuery.model_validate({
|
|
348
|
-
"query_type": query_type,
|
|
349
|
-
"parameters": parameters,
|
|
350
|
-
"user_id": effective_user_id,
|
|
351
|
-
})
|
|
352
|
-
|
|
353
|
-
result = await rem_service.execute_query(rem_query)
|
|
341
|
+
# Use the unified execute_query_string method
|
|
342
|
+
result = await rem_service.execute_query_string(
|
|
343
|
+
request.query, user_id=effective_user_id
|
|
344
|
+
)
|
|
354
345
|
|
|
355
346
|
return QueryResponse(
|
|
356
347
|
query_type=result["query_type"],
|
rem/cli/commands/ask.py
CHANGED
|
@@ -164,9 +164,13 @@ async def run_agent_non_streaming(
|
|
|
164
164
|
context: AgentContext | None = None,
|
|
165
165
|
plan: bool = False,
|
|
166
166
|
max_iterations: int | None = None,
|
|
167
|
+
user_message: str | None = None,
|
|
167
168
|
) -> dict[str, Any] | None:
|
|
168
169
|
"""
|
|
169
|
-
Run agent in non-streaming mode using agent.
|
|
170
|
+
Run agent in non-streaming mode using agent.iter() to capture tool calls.
|
|
171
|
+
|
|
172
|
+
This mirrors the streaming code path to ensure tool messages are properly
|
|
173
|
+
persisted to the database for state tracking across turns.
|
|
170
174
|
|
|
171
175
|
Args:
|
|
172
176
|
agent: Pydantic AI agent
|
|
@@ -176,77 +180,183 @@ async def run_agent_non_streaming(
|
|
|
176
180
|
context: Optional AgentContext for session persistence
|
|
177
181
|
plan: If True, output only the generated query (for query-agent)
|
|
178
182
|
max_iterations: Maximum iterations/requests (from agent schema or settings)
|
|
183
|
+
user_message: The user's original message (for database storage)
|
|
179
184
|
|
|
180
185
|
Returns:
|
|
181
186
|
Output data if successful, None otherwise
|
|
182
187
|
"""
|
|
183
188
|
from pydantic_ai import UsageLimits
|
|
189
|
+
from pydantic_ai.agent import Agent
|
|
190
|
+
from pydantic_ai.messages import (
|
|
191
|
+
FunctionToolResultEvent,
|
|
192
|
+
PartStartEvent,
|
|
193
|
+
PartEndEvent,
|
|
194
|
+
TextPart,
|
|
195
|
+
ToolCallPart,
|
|
196
|
+
)
|
|
184
197
|
from rem.utils.date_utils import to_iso_with_z, utc_now
|
|
185
198
|
|
|
186
199
|
logger.info("Running agent in non-streaming mode...")
|
|
187
200
|
|
|
188
201
|
try:
|
|
189
|
-
#
|
|
190
|
-
|
|
191
|
-
|
|
202
|
+
# Track tool calls for persistence (same as streaming code path)
|
|
203
|
+
tool_calls: list = []
|
|
204
|
+
pending_tool_data: dict = {}
|
|
205
|
+
pending_tool_completions: list = []
|
|
206
|
+
accumulated_content: list = []
|
|
207
|
+
|
|
208
|
+
# Get the underlying pydantic-ai agent
|
|
209
|
+
pydantic_agent = agent.agent if hasattr(agent, 'agent') else agent
|
|
210
|
+
|
|
211
|
+
# Use agent.iter() to capture tool calls (same as streaming)
|
|
212
|
+
async with pydantic_agent.iter(prompt) as agent_run:
|
|
213
|
+
async for node in agent_run:
|
|
214
|
+
# Handle model request nodes (text + tool call starts)
|
|
215
|
+
if Agent.is_model_request_node(node):
|
|
216
|
+
async with node.stream(agent_run.ctx) as request_stream:
|
|
217
|
+
async for event in request_stream:
|
|
218
|
+
# Capture text content
|
|
219
|
+
if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
|
|
220
|
+
if event.part.content:
|
|
221
|
+
accumulated_content.append(event.part.content)
|
|
222
|
+
|
|
223
|
+
# Capture tool call starts
|
|
224
|
+
elif isinstance(event, PartStartEvent) and isinstance(event.part, ToolCallPart):
|
|
225
|
+
tool_name = event.part.tool_name
|
|
226
|
+
if tool_name == "final_result":
|
|
227
|
+
continue
|
|
228
|
+
|
|
229
|
+
import uuid
|
|
230
|
+
tool_id = f"call_{uuid.uuid4().hex[:8]}"
|
|
231
|
+
pending_tool_completions.append((tool_name, tool_id))
|
|
232
|
+
|
|
233
|
+
# Extract arguments
|
|
234
|
+
args_dict = {}
|
|
235
|
+
if hasattr(event.part, 'args'):
|
|
236
|
+
args = event.part.args
|
|
237
|
+
if isinstance(args, str):
|
|
238
|
+
try:
|
|
239
|
+
args_dict = json.loads(args)
|
|
240
|
+
except json.JSONDecodeError:
|
|
241
|
+
args_dict = {"raw": args}
|
|
242
|
+
elif isinstance(args, dict):
|
|
243
|
+
args_dict = args
|
|
244
|
+
|
|
245
|
+
pending_tool_data[tool_id] = {
|
|
246
|
+
"tool_name": tool_name,
|
|
247
|
+
"tool_id": tool_id,
|
|
248
|
+
"arguments": args_dict,
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
# Print tool call for CLI visibility
|
|
252
|
+
print(f"\n[Calling: {tool_name}]", flush=True)
|
|
253
|
+
|
|
254
|
+
# Capture tool call end (update arguments if changed)
|
|
255
|
+
elif isinstance(event, PartEndEvent) and isinstance(event.part, ToolCallPart):
|
|
256
|
+
pass # Arguments already captured at start
|
|
257
|
+
|
|
258
|
+
# Handle tool execution nodes (results)
|
|
259
|
+
elif Agent.is_call_tools_node(node):
|
|
260
|
+
async with node.stream(agent_run.ctx) as tools_stream:
|
|
261
|
+
async for event in tools_stream:
|
|
262
|
+
if isinstance(event, FunctionToolResultEvent):
|
|
263
|
+
# Get tool info from pending queue
|
|
264
|
+
if pending_tool_completions:
|
|
265
|
+
tool_name, tool_id = pending_tool_completions.pop(0)
|
|
266
|
+
else:
|
|
267
|
+
import uuid
|
|
268
|
+
tool_name = "tool"
|
|
269
|
+
tool_id = f"call_{uuid.uuid4().hex[:8]}"
|
|
270
|
+
|
|
271
|
+
result_content = event.result.content if hasattr(event.result, 'content') else event.result
|
|
272
|
+
|
|
273
|
+
# Capture tool call for persistence
|
|
274
|
+
if tool_id in pending_tool_data:
|
|
275
|
+
tool_data = pending_tool_data[tool_id]
|
|
276
|
+
tool_data["result"] = result_content
|
|
277
|
+
tool_calls.append(tool_data)
|
|
278
|
+
del pending_tool_data[tool_id]
|
|
279
|
+
|
|
280
|
+
# Get final result
|
|
281
|
+
result = agent_run.result
|
|
192
282
|
|
|
193
283
|
# Extract output data
|
|
194
284
|
output_data = None
|
|
195
285
|
assistant_content = None
|
|
196
|
-
if hasattr(result, "output"):
|
|
286
|
+
if result is not None and hasattr(result, "output"):
|
|
197
287
|
output = result.output
|
|
198
288
|
from rem.agentic.serialization import serialize_agent_result
|
|
199
289
|
output_data = serialize_agent_result(output)
|
|
200
290
|
|
|
201
291
|
if plan and isinstance(output_data, dict) and "query" in output_data:
|
|
202
|
-
# Plan mode: Output only the query
|
|
203
|
-
# Use sql formatting if possible or just raw string
|
|
204
292
|
assistant_content = output_data["query"]
|
|
205
293
|
print(assistant_content)
|
|
206
294
|
else:
|
|
207
|
-
#
|
|
208
|
-
|
|
295
|
+
# For string output, use it directly
|
|
296
|
+
if isinstance(output_data, str):
|
|
297
|
+
assistant_content = output_data
|
|
298
|
+
else:
|
|
299
|
+
assistant_content = json.dumps(output_data, indent=2)
|
|
209
300
|
print(assistant_content)
|
|
210
301
|
else:
|
|
211
|
-
|
|
212
|
-
assistant_content
|
|
213
|
-
|
|
302
|
+
assistant_content = str(result) if result else ""
|
|
303
|
+
if assistant_content:
|
|
304
|
+
print(assistant_content)
|
|
214
305
|
|
|
215
306
|
# Save to file if requested
|
|
216
307
|
if output_file and output_data:
|
|
217
308
|
await _save_output_file(output_file, output_data)
|
|
218
309
|
|
|
219
|
-
# Save session messages (
|
|
310
|
+
# Save session messages including tool calls (same as streaming code path)
|
|
220
311
|
if context and context.session_id and settings.postgres.enabled:
|
|
221
312
|
from ...services.session.compression import SessionMessageStore
|
|
222
313
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
# We need to extract the last user message
|
|
226
|
-
user_message_content = prompt.split("\n\n")[-1] if "\n\n" in prompt else prompt
|
|
314
|
+
timestamp = to_iso_with_z(utc_now())
|
|
315
|
+
messages_to_store = []
|
|
227
316
|
|
|
228
|
-
|
|
317
|
+
# Save user message first
|
|
318
|
+
user_message_content = user_message or (prompt.split("\n\n")[-1] if "\n\n" in prompt else prompt)
|
|
319
|
+
messages_to_store.append({
|
|
229
320
|
"role": "user",
|
|
230
321
|
"content": user_message_content,
|
|
231
|
-
"timestamp":
|
|
232
|
-
}
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
322
|
+
"timestamp": timestamp,
|
|
323
|
+
})
|
|
324
|
+
|
|
325
|
+
# Save tool call messages (message_type: "tool") - CRITICAL for state tracking
|
|
326
|
+
for tool_call in tool_calls:
|
|
327
|
+
if not tool_call:
|
|
328
|
+
continue
|
|
329
|
+
tool_message = {
|
|
330
|
+
"role": "tool",
|
|
331
|
+
"content": json.dumps(tool_call.get("result", {}), default=str),
|
|
332
|
+
"timestamp": timestamp,
|
|
333
|
+
"tool_call_id": tool_call.get("tool_id"),
|
|
334
|
+
"tool_name": tool_call.get("tool_name"),
|
|
335
|
+
"tool_arguments": tool_call.get("arguments"),
|
|
336
|
+
}
|
|
337
|
+
messages_to_store.append(tool_message)
|
|
338
|
+
|
|
339
|
+
# Save assistant message
|
|
340
|
+
if assistant_content:
|
|
341
|
+
messages_to_store.append({
|
|
342
|
+
"role": "assistant",
|
|
343
|
+
"content": assistant_content,
|
|
344
|
+
"timestamp": timestamp,
|
|
345
|
+
})
|
|
346
|
+
|
|
347
|
+
# Store all messages
|
|
241
348
|
store = SessionMessageStore(user_id=context.user_id or settings.test.effective_user_id)
|
|
242
349
|
await store.store_session_messages(
|
|
243
350
|
session_id=context.session_id,
|
|
244
|
-
messages=
|
|
351
|
+
messages=messages_to_store,
|
|
245
352
|
user_id=context.user_id,
|
|
246
|
-
compress=
|
|
353
|
+
compress=False, # Store uncompressed; compression happens on reload
|
|
247
354
|
)
|
|
248
355
|
|
|
249
|
-
logger.debug(
|
|
356
|
+
logger.debug(
|
|
357
|
+
f"Saved {len(tool_calls)} tool calls + user/assistant messages "
|
|
358
|
+
f"to session {context.session_id}"
|
|
359
|
+
)
|
|
250
360
|
|
|
251
361
|
return output_data
|
|
252
362
|
|
|
@@ -332,8 +442,8 @@ async def _save_output_file(file_path: Path, data: dict[str, Any]) -> None:
|
|
|
332
442
|
)
|
|
333
443
|
@click.option(
|
|
334
444
|
"--stream/--no-stream",
|
|
335
|
-
default=
|
|
336
|
-
help="Enable streaming mode (default:
|
|
445
|
+
default=True,
|
|
446
|
+
help="Enable streaming mode (default: enabled)",
|
|
337
447
|
)
|
|
338
448
|
@click.option(
|
|
339
449
|
"--user-id",
|
|
@@ -538,6 +648,7 @@ async def _ask_async(
|
|
|
538
648
|
output_file=output_file,
|
|
539
649
|
context=context,
|
|
540
650
|
plan=plan,
|
|
651
|
+
user_message=query,
|
|
541
652
|
)
|
|
542
653
|
|
|
543
654
|
# Log session ID for reuse
|
rem/cli/commands/process.py
CHANGED
|
@@ -193,7 +193,15 @@ def process_ingest(
|
|
|
193
193
|
try:
|
|
194
194
|
# Read file content
|
|
195
195
|
content = file_path.read_text(encoding="utf-8")
|
|
196
|
-
|
|
196
|
+
|
|
197
|
+
# Generate entity key from filename
|
|
198
|
+
# Special case: README files use parent directory as section name
|
|
199
|
+
if file_path.stem.lower() == "readme":
|
|
200
|
+
# Use parent directory name, e.g., "drugs" for drugs/README.md
|
|
201
|
+
# For nested paths like disorders/anxiety/README.md -> "anxiety"
|
|
202
|
+
entity_key = file_path.parent.name
|
|
203
|
+
else:
|
|
204
|
+
entity_key = file_path.stem # filename without extension
|
|
197
205
|
|
|
198
206
|
# Build entity based on table
|
|
199
207
|
entity_data = {
|