remdb 0.3.202__py3-none-any.whl → 0.3.245__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of remdb might be problematic.

Files changed (44)
  1. rem/agentic/README.md +36 -2
  2. rem/agentic/context.py +86 -3
  3. rem/agentic/context_builder.py +39 -33
  4. rem/agentic/mcp/tool_wrapper.py +2 -2
  5. rem/agentic/providers/pydantic_ai.py +68 -51
  6. rem/agentic/schema.py +2 -2
  7. rem/api/mcp_router/resources.py +223 -0
  8. rem/api/mcp_router/tools.py +170 -18
  9. rem/api/routers/admin.py +30 -4
  10. rem/api/routers/auth.py +175 -18
  11. rem/api/routers/chat/child_streaming.py +394 -0
  12. rem/api/routers/chat/completions.py +24 -29
  13. rem/api/routers/chat/sse_events.py +5 -1
  14. rem/api/routers/chat/streaming.py +242 -272
  15. rem/api/routers/chat/streaming_utils.py +327 -0
  16. rem/api/routers/common.py +18 -0
  17. rem/api/routers/dev.py +7 -1
  18. rem/api/routers/feedback.py +9 -1
  19. rem/api/routers/messages.py +80 -15
  20. rem/api/routers/models.py +9 -1
  21. rem/api/routers/query.py +17 -15
  22. rem/api/routers/shared_sessions.py +16 -0
  23. rem/cli/commands/ask.py +205 -114
  24. rem/cli/commands/process.py +12 -4
  25. rem/cli/commands/query.py +109 -0
  26. rem/cli/commands/session.py +117 -0
  27. rem/cli/main.py +2 -0
  28. rem/models/entities/session.py +1 -0
  29. rem/schemas/agents/rem.yaml +1 -1
  30. rem/services/postgres/repository.py +7 -7
  31. rem/services/rem/service.py +47 -0
  32. rem/services/session/__init__.py +2 -1
  33. rem/services/session/compression.py +14 -12
  34. rem/services/session/pydantic_messages.py +111 -11
  35. rem/services/session/reload.py +2 -1
  36. rem/settings.py +71 -0
  37. rem/sql/migrations/001_install.sql +4 -4
  38. rem/sql/migrations/004_cache_system.sql +3 -1
  39. rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
  40. rem/utils/schema_loader.py +139 -111
  41. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/METADATA +2 -2
  42. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/RECORD +44 -39
  43. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/WHEEL +0 -0
  44. {remdb-0.3.202.dist-info → remdb-0.3.245.dist-info}/entry_points.txt +0 -0
rem/api/routers/chat/streaming_utils.py ADDED
@@ -0,0 +1,327 @@
+ """
+ Streaming Utilities.
+
+ Pure functions and data structures for SSE streaming.
+ No I/O, no database calls - just data transformation.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import time
+ import uuid
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ from loguru import logger
+
+ from .models import (
+     ChatCompletionMessageDelta,
+     ChatCompletionStreamChoice,
+     ChatCompletionStreamResponse,
+ )
+ from .sse_events import (
+     MetadataEvent,
+     ProgressEvent,
+     ReasoningEvent,
+     ToolCallEvent,
+     format_sse_event,
+ )
+
+
+ # =============================================================================
+ # STREAMING STATE
+ # =============================================================================
+
+ @dataclass
+ class StreamingState:
+     """
+     Tracks state during SSE streaming.
+
+     This is a pure data container - no methods that do I/O.
+     """
+     request_id: str
+     created_at: int
+     model: str
+     start_time: float = field(default_factory=time.time)
+
+     # Content tracking
+     is_first_chunk: bool = True
+     token_count: int = 0
+
+     # Child agent tracking - KEY FOR DUPLICATION FIX
+     child_content_streamed: bool = False
+     responding_agent: str | None = None
+
+     # Tool tracking
+     active_tool_calls: dict = field(default_factory=dict)  # index -> (name, id)
+     pending_tool_completions: list = field(default_factory=list)  # FIFO queue
+     pending_tool_data: dict = field(default_factory=dict)  # tool_id -> data
+
+     # Reasoning tracking
+     reasoning_step: int = 0
+
+     # Progress tracking
+     current_step: int = 0
+     total_steps: int = 3
+
+     # Metadata tracking
+     metadata_registered: bool = False
+
+     # Trace context (captured from OTEL)
+     trace_id: str | None = None
+     span_id: str | None = None
+
+     @classmethod
+     def create(cls, model: str, request_id: str | None = None) -> "StreamingState":
+         """Create a new streaming state."""
+         return cls(
+             request_id=request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
+             created_at=int(time.time()),
+             model=model,
+         )
+
+     def latency_ms(self) -> int:
+         """Calculate latency since start."""
+         return int((time.time() - self.start_time) * 1000)
+
+
+ # =============================================================================
+ # SSE CHUNK BUILDERS
+ # =============================================================================
+
+ def build_content_chunk(state: StreamingState, content: str) -> str:
+     """
+     Build an SSE content chunk in OpenAI format.
+
+     Updates state.is_first_chunk and state.token_count.
+     """
+     state.token_count += len(content.split())
+
+     chunk = ChatCompletionStreamResponse(
+         id=state.request_id,
+         created=state.created_at,
+         model=state.model,
+         choices=[
+             ChatCompletionStreamChoice(
+                 index=0,
+                 delta=ChatCompletionMessageDelta(
+                     role="assistant" if state.is_first_chunk else None,
+                     content=content,
+                 ),
+                 finish_reason=None,
+             )
+         ],
+     )
+     state.is_first_chunk = False
+     return f"data: {chunk.model_dump_json()}\n\n"
+
+
+ def build_final_chunk(state: StreamingState) -> str:
+     """Build the final SSE chunk with finish_reason=stop."""
+     chunk = ChatCompletionStreamResponse(
+         id=state.request_id,
+         created=state.created_at,
+         model=state.model,
+         choices=[
+             ChatCompletionStreamChoice(
+                 index=0,
+                 delta=ChatCompletionMessageDelta(),
+                 finish_reason="stop",
+             )
+         ],
+     )
+     return f"data: {chunk.model_dump_json()}\n\n"
+
+
+ def build_reasoning_event(state: StreamingState, content: str) -> str:
+     """Build a reasoning SSE event."""
+     return format_sse_event(ReasoningEvent(
+         content=content,
+         step=state.reasoning_step,
+     ))
+
+
+ def build_progress_event(
+     step: int,
+     total_steps: int,
+     label: str,
+     status: str = "in_progress",
+ ) -> str:
+     """Build a progress SSE event."""
+     return format_sse_event(ProgressEvent(
+         step=step,
+         total_steps=total_steps,
+         label=label,
+         status=status,
+     ))
+
+
+ def build_tool_start_event(
+     tool_name: str,
+     tool_id: str,
+     arguments: dict | None = None,
+ ) -> str:
+     """Build a tool call started SSE event."""
+     return format_sse_event(ToolCallEvent(
+         tool_name=tool_name,
+         tool_id=tool_id,
+         status="started",
+         arguments=arguments,
+     ))
+
+
+ def build_tool_complete_event(
+     tool_name: str,
+     tool_id: str,
+     arguments: dict | None = None,
+     result: Any = None,
+ ) -> str:
+     """Build a tool call completed SSE event."""
+     result_str = None
+     if result is not None:
+         result_str = str(result)
+         if len(result_str) > 200:
+             result_str = result_str[:200] + "..."
+
+     return format_sse_event(ToolCallEvent(
+         tool_name=tool_name,
+         tool_id=tool_id,
+         status="completed",
+         arguments=arguments,
+         result=result_str,
+     ))
+
+
+ def build_metadata_event(
+     message_id: str | None = None,
+     in_reply_to: str | None = None,
+     session_id: str | None = None,
+     agent_schema: str | None = None,
+     responding_agent: str | None = None,
+     confidence: float | None = None,
+     sources: list | None = None,
+     model_version: str | None = None,
+     latency_ms: int | None = None,
+     token_count: int | None = None,
+     trace_id: str | None = None,
+     span_id: str | None = None,
+     extra: dict | None = None,
+ ) -> str:
+     """Build a metadata SSE event."""
+     return format_sse_event(MetadataEvent(
+         message_id=message_id,
+         in_reply_to=in_reply_to,
+         session_id=session_id,
+         agent_schema=agent_schema,
+         responding_agent=responding_agent,
+         confidence=confidence,
+         sources=sources,
+         model_version=model_version,
+         latency_ms=latency_ms,
+         token_count=token_count,
+         trace_id=trace_id,
+         span_id=span_id,
+         extra=extra,
+     ))
+
+
+ # =============================================================================
+ # TOOL ARGUMENT EXTRACTION
+ # =============================================================================
+
+ def extract_tool_args(part) -> dict | None:
+     """
+     Extract arguments from a ToolCallPart.
+
+     Handles various formats:
+     - ArgsDict object with args_dict attribute
+     - Plain dict
+     - JSON string
+     """
+     if part.args is None:
+         return None
+
+     if hasattr(part.args, 'args_dict'):
+         return part.args.args_dict
+
+     if isinstance(part.args, dict):
+         return part.args
+
+     if isinstance(part.args, str) and part.args:
+         try:
+             return json.loads(part.args)
+         except json.JSONDecodeError:
+             logger.warning(f"Failed to parse tool args: {part.args[:100]}")
+
+     return None
+
+
+ def log_tool_call(tool_name: str, args_dict: dict | None) -> None:
+     """Log a tool call with key parameters."""
+     if args_dict and tool_name == "search_rem":
+         query_type = args_dict.get("query_type", "?")
+         limit = args_dict.get("limit", 20)
+         table = args_dict.get("table", "")
+         query_text = args_dict.get("query_text", args_dict.get("entity_key", ""))
+         if query_text and len(str(query_text)) > 50:
+             query_text = str(query_text)[:50] + "..."
+         logger.info(f"🔧 {tool_name} {query_type.upper()} '{query_text}' table={table} limit={limit}")
+     else:
+         logger.info(f"🔧 {tool_name}")
+
+
+ def log_tool_result(tool_name: str, result_content: Any) -> None:
+     """Log a tool result with key metrics."""
+     if tool_name == "search_rem" and isinstance(result_content, dict):
+         results = result_content.get("results", {})
+         if isinstance(results, dict):
+             count = results.get("count", len(results.get("results", [])))
+             query_type = results.get("query_type", "?")
+             query_text = results.get("query_text", results.get("key", ""))
+             table = results.get("table_name", "")
+         elif isinstance(results, list):
+             count = len(results)
+             query_type = "?"
+             query_text = ""
+             table = ""
+         else:
+             count = "?"
+             query_type = "?"
+             query_text = ""
+             table = ""
+
+         if query_text and len(str(query_text)) > 40:
+             query_text = str(query_text)[:40] + "..."
+         logger.info(f" ↳ {tool_name} {query_type} '{query_text}' table={table} → {count} results")
+
+
+ # =============================================================================
+ # METADATA EXTRACTION
+ # =============================================================================
+
+ def extract_metadata_from_result(result_content: Any) -> dict | None:
+     """
+     Extract metadata from a register_metadata tool result.
+
+     Returns dict with extracted fields or None if not a metadata event.
+     """
+     if not isinstance(result_content, dict):
+         return None
+
+     if not result_content.get("_metadata_event"):
+         return None
+
+     return {
+         "confidence": result_content.get("confidence"),
+         "sources": result_content.get("sources"),
+         "references": result_content.get("references"),
+         "flags": result_content.get("flags"),
+         "session_name": result_content.get("session_name"),
+         "risk_level": result_content.get("risk_level"),
+         "risk_score": result_content.get("risk_score"),
+         "risk_reasoning": result_content.get("risk_reasoning"),
+         "recommended_action": result_content.get("recommended_action"),
+         "agent_schema": result_content.get("agent_schema"),
+         "extra": result_content.get("extra"),
+     }
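The module above keeps all SSE formatting free of I/O, so a route's event generator only needs to thread a `StreamingState` through the builders. A minimal sketch of that wiring, assuming an async iterator of text deltas (`agent_events`, `sse_generator`, and the `[DONE]` terminator are illustrative, not part of this diff):

```python
# Hypothetical wiring - shows the intended call pattern, not remdb's actual route code.
from rem.api.routers.chat.streaming_utils import (
    StreamingState,
    build_content_chunk,
    build_final_chunk,
    build_progress_event,
)


async def sse_generator(agent_events, model: str):
    """Turn an async iterator of text deltas into OpenAI-style SSE chunks."""
    state = StreamingState.create(model=model)
    yield build_progress_event(step=1, total_steps=state.total_steps, label="thinking")

    async for delta in agent_events:  # assumed to yield plain text fragments
        yield build_content_chunk(state, delta)

    yield build_final_chunk(state)  # finish_reason="stop"
    yield "data: [DONE]\n\n"  # common OpenAI-compatible terminator (assumption)
```

A FastAPI route would typically wrap this generator in a `StreamingResponse(..., media_type="text/event-stream")`.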
rem/api/routers/common.py ADDED
@@ -0,0 +1,18 @@
+ """
+ Common models shared across API routers.
+ """
+
+ from pydantic import BaseModel, Field
+
+
+ class ErrorResponse(BaseModel):
+     """Standard error response format for HTTPException errors.
+
+     This is different from FastAPI's HTTPValidationError which is used
+     for Pydantic validation failures (422 errors with loc/msg/type array).
+
+     HTTPException errors return this simpler format:
+     {"detail": "Error message here"}
+     """
+
+     detail: str = Field(description="Error message describing what went wrong")
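`ErrorResponse` documents the `{"detail": ...}` body that FastAPI's `HTTPException` already produces; the routers below attach it through the `responses=` mapping so these error codes appear in the OpenAPI schema. A minimal sketch of the pattern (the `/example/{item_id}` route is invented for illustration):

```python
# Illustrative only - mirrors the responses= pattern used throughout this diff.
from fastapi import APIRouter, HTTPException

from rem.api.routers.common import ErrorResponse

router = APIRouter(prefix="/api/v1")


@router.get(
    "/example/{item_id}",  # hypothetical route, not part of remdb
    responses={
        404: {"model": ErrorResponse, "description": "Item not found"},
    },
)
async def get_example(item_id: str) -> dict:
    if item_id != "known":
        # FastAPI serializes this as {"detail": "..."} - the ErrorResponse shape
        raise HTTPException(status_code=404, detail=f"Item '{item_id}' not found")
    return {"id": item_id}
```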
rem/api/routers/dev.py CHANGED
@@ -11,6 +11,7 @@ Endpoints:
  from fastapi import APIRouter, HTTPException, Request
  from loguru import logger

+ from .common import ErrorResponse
  from ...settings import settings

  router = APIRouter(prefix="/api/dev", tags=["dev"])
@@ -45,7 +46,12 @@ def verify_dev_token(token: str) -> bool:
      return token == expected


- @router.get("/token")
+ @router.get(
+     "/token",
+     responses={
+         401: {"model": ErrorResponse, "description": "Dev tokens not available in production"},
+     },
+ )
  async def get_dev_token(request: Request):
      """
      Get a development token for testing (non-production only).
rem/api/routers/feedback.py CHANGED
@@ -63,6 +63,8 @@ from fastapi import APIRouter, Header, HTTPException, Request, Response
  from loguru import logger
  from pydantic import BaseModel, Field

+ from .common import ErrorResponse
+
  from ..deps import get_user_id_from_request
  from ...models.entities import Feedback
  from ...services.postgres import Repository
@@ -121,7 +123,13 @@ class FeedbackResponse(BaseModel):
  # =============================================================================


- @router.post("/messages/feedback", response_model=FeedbackResponse)
+ @router.post(
+     "/messages/feedback",
+     response_model=FeedbackResponse,
+     responses={
+         503: {"model": ErrorResponse, "description": "Database not enabled"},
+     },
+ )
  async def submit_feedback(
      request: Request,
      response: Response,
rem/api/routers/messages.py CHANGED
@@ -16,6 +16,7 @@ Endpoints:
  """

  from datetime import datetime
+ from enum import Enum
  from typing import Literal
  from uuid import UUID

@@ -23,6 +24,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request
  from loguru import logger
  from pydantic import BaseModel, Field

+ from .common import ErrorResponse
+
  from ..deps import (
      get_current_user,
      get_user_filter,
@@ -38,6 +41,18 @@ from ...utils.date_utils import parse_iso, utc_now
  router = APIRouter(prefix="/api/v1")


+ # =============================================================================
+ # Enums
+ # =============================================================================
+
+
+ class SortOrder(str, Enum):
+     """Sort order for list queries."""
+
+     ASC = "asc"
+     DESC = "desc"
+
+
  # =============================================================================
  # Request/Response Models
  # =============================================================================
@@ -134,7 +149,14 @@ class SessionsQueryResponse(BaseModel):
  # =============================================================================


- @router.get("/messages", response_model=MessageListResponse, tags=["messages"])
+ @router.get(
+     "/messages",
+     response_model=MessageListResponse,
+     tags=["messages"],
+     responses={
+         503: {"model": ErrorResponse, "description": "Database not enabled"},
+     },
+ )
  async def list_messages(
      request: Request,
      mine: bool = Query(default=False, description="Only show my messages (uses JWT identity)"),
@@ -151,6 +173,7 @@ async def list_messages(
      ),
      limit: int = Query(default=50, ge=1, le=100, description="Max results to return"),
      offset: int = Query(default=0, ge=0, description="Offset for pagination"),
+     sort: SortOrder = Query(default=SortOrder.DESC, description="Sort order by created_at (asc or desc)"),
  ) -> MessageListResponse:
      """
      List messages with optional filters.
@@ -166,8 +189,9 @@ async def list_messages(
      - session_id: Filter by conversation session
      - start_date/end_date: Filter by creation time range (ISO 8601 format)
      - message_type: Filter by role (user, assistant, system, tool)
+     - sort: Sort order by created_at (asc or desc, default: desc)

-     Returns paginated results ordered by created_at descending.
+     Returns paginated results ordered by created_at.
      """
      if not settings.postgres.enabled:
          raise HTTPException(status_code=503, detail="Database not enabled")
@@ -189,6 +213,7 @@ async def list_messages(

      # Apply optional filters
      if session_id:
+         # session_id is the session UUID - use directly
          filters["session_id"] = session_id
      if message_type:
          filters["message_type"] = message_type
@@ -200,12 +225,15 @@ async def list_messages(
          f"filters={filters}"
      )

+     # Build order_by clause based on sort parameter
+     order_by = f"created_at {sort.value.upper()}"
+
      # For date filtering, we need custom SQL (not supported by basic Repository)
      # For now, fetch all matching base filters and filter in Python
      # TODO: Extend Repository to support date range filters
      messages = await repo.find(
          filters,
-         order_by="created_at DESC",
+         order_by=order_by,
          limit=limit + 1,  # Fetch one extra to determine has_more
          offset=offset,
      )
@@ -241,7 +269,16 @@ async def list_messages(
      return MessageListResponse(data=messages, total=total, has_more=has_more)


- @router.get("/messages/{message_id}", response_model=Message, tags=["messages"])
+ @router.get(
+     "/messages/{message_id}",
+     response_model=Message,
+     tags=["messages"],
+     responses={
+         403: {"model": ErrorResponse, "description": "Access denied: not owner"},
+         404: {"model": ErrorResponse, "description": "Message not found"},
+         503: {"model": ErrorResponse, "description": "Database not enabled"},
+     },
+ )
  async def get_message(
      request: Request,
      message_id: str,
@@ -287,7 +324,14 @@ async def get_message(
  # =============================================================================


- @router.get("/sessions", response_model=SessionsQueryResponse, tags=["sessions"])
+ @router.get(
+     "/sessions",
+     response_model=SessionsQueryResponse,
+     tags=["sessions"],
+     responses={
+         503: {"model": ErrorResponse, "description": "Database not enabled or connection failed"},
+     },
+ )
  async def list_sessions(
      request: Request,
      user_id: str | None = Query(default=None, description="Filter by user ID (admin only for cross-user)"),
@@ -400,7 +444,15 @@ async def list_sessions(
      )


- @router.post("/sessions", response_model=Session, status_code=201, tags=["sessions"])
+ @router.post(
+     "/sessions",
+     response_model=Session,
+     status_code=201,
+     tags=["sessions"],
+     responses={
+         503: {"model": ErrorResponse, "description": "Database not enabled"},
+     },
+ )
  async def create_session(
      request_body: SessionCreateRequest,
      user: dict = Depends(require_admin),
@@ -452,7 +504,16 @@ async def create_session(
      return result  # type: ignore


- @router.get("/sessions/{session_id}", response_model=Session, tags=["sessions"])
+ @router.get(
+     "/sessions/{session_id}",
+     response_model=Session,
+     tags=["sessions"],
+     responses={
+         403: {"model": ErrorResponse, "description": "Access denied: not owner"},
+         404: {"model": ErrorResponse, "description": "Session not found"},
+         503: {"model": ErrorResponse, "description": "Database not enabled"},
+     },
+ )
  async def get_session(
      request: Request,
      session_id: str,
@@ -465,7 +526,7 @@ async def get_session(
      - Admin users: Can access any session

      Args:
-         session_id: UUID or name of the session
+         session_id: UUID of the session

      Returns:
          Session object if found
@@ -481,12 +542,7 @@ async def get_session(
      session = await repo.get_by_id(session_id)

      if not session:
-         # Try finding by name
-         sessions = await repo.find({"name": session_id}, limit=1)
-         if sessions:
-             session = sessions[0]
-         else:
-             raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
+         raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")

      # Check access: admin or owner
      current_user = get_current_user(request)
@@ -498,7 +554,16 @@ async def get_session(
      return session


- @router.put("/sessions/{session_id}", response_model=Session, tags=["sessions"])
+ @router.put(
+     "/sessions/{session_id}",
+     response_model=Session,
+     tags=["sessions"],
+     responses={
+         403: {"model": ErrorResponse, "description": "Access denied: not owner"},
+         404: {"model": ErrorResponse, "description": "Session not found"},
+         503: {"model": ErrorResponse, "description": "Database not enabled"},
+     },
+ )
  async def update_session(
      request: Request,
rem/api/routers/models.py CHANGED
@@ -15,6 +15,8 @@ from typing import Literal
  from fastapi import APIRouter, HTTPException
  from pydantic import BaseModel, Field

+ from .common import ErrorResponse
+
  from rem.agentic.llm_provider_models import (
      ModelInfo,
      AVAILABLE_MODELS,
@@ -57,7 +59,13 @@ async def list_models() -> ModelsResponse:
      return ModelsResponse(data=AVAILABLE_MODELS)


- @router.get("/models/{model_id:path}", response_model=ModelInfo)
+ @router.get(
+     "/models/{model_id:path}",
+     response_model=ModelInfo,
+     responses={
+         404: {"model": ErrorResponse, "description": "Model not found"},
+     },
+ )
  async def get_model(model_id: str) -> ModelInfo:
      """
      Get information about a specific model.
rem/api/routers/query.py CHANGED
@@ -86,10 +86,10 @@ from fastapi import APIRouter, Header, HTTPException
  from loguru import logger
  from pydantic import BaseModel, Field

+ from .common import ErrorResponse
+
  from ...services.postgres import get_postgres_service
  from ...services.rem.service import RemService
- from ...services.rem.parser import RemQueryParser
- from ...models.core import RemQuery
  from ...settings import settings

  router = APIRouter(prefix="/api/v1", tags=["query"])
@@ -213,7 +213,16 @@ class QueryResponse(BaseModel):
      )


- @router.post("/query", response_model=QueryResponse)
+ @router.post(
+     "/query",
+     response_model=QueryResponse,
+     responses={
+         400: {"model": ErrorResponse, "description": "Invalid query or missing required fields"},
+         500: {"model": ErrorResponse, "description": "Query execution failed"},
+         501: {"model": ErrorResponse, "description": "Feature not yet implemented"},
+         503: {"model": ErrorResponse, "description": "Database not configured or unavailable"},
+     },
+ )
  async def execute_query(
      request: QueryRequest,
      x_user_id: str | None = Header(default=None, description="User ID for query isolation (optional, uses default if not provided)"),
@@ -320,7 +329,7 @@ async def execute_query(
          return response

      else:
-         # REM dialect mode - parse and execute directly
+         # REM dialect mode - use unified execute_query_string
          if not request.query:
              raise HTTPException(
                  status_code=400,
@@ -329,17 +338,10 @@ async def execute_query(

          logger.info(f"REM dialect query: {request.query[:100]}...")

-         parser = RemQueryParser()
-         query_type, parameters = parser.parse(request.query)
-
-         # Create and execute RemQuery
-         rem_query = RemQuery.model_validate({
-             "query_type": query_type,
-             "parameters": parameters,
-             "user_id": effective_user_id,
-         })
-
-         result = await rem_service.execute_query(rem_query)
+         # Use the unified execute_query_string method
+         result = await rem_service.execute_query_string(
+             request.query, user_id=effective_user_id
+         )

          return QueryResponse(
              query_type=result["query_type"],
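The router no longer parses REM dialect strings itself; that responsibility moves behind `RemService.execute_query_string` (added in `rem/services/rem/service.py`, which is not shown in this section). Judging from the code removed above, it plausibly wraps the old parse-validate-execute sequence, roughly like this sketch (the method body is an assumption, not the actual implementation):

```python
# Hedged sketch of what execute_query_string likely wraps, inferred from the removed router code.
from rem.models.core import RemQuery
from rem.services.rem.parser import RemQueryParser


class RemServiceSketch:
    """Stand-in for rem.services.rem.service.RemService - illustration only."""

    async def execute_query_string(self, query: str, user_id: str | None = None) -> dict:
        parser = RemQueryParser()
        query_type, parameters = parser.parse(query)
        rem_query = RemQuery.model_validate({
            "query_type": query_type,
            "parameters": parameters,
            "user_id": user_id,
        })
        return await self.execute_query(rem_query)  # existing execution path

    async def execute_query(self, rem_query: RemQuery) -> dict:  # placeholder signature
        raise NotImplementedError
```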