remdb 0.3.181__py3-none-any.whl → 0.3.223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of remdb might be problematic. Click here for more details.
- rem/agentic/README.md +262 -2
- rem/agentic/context.py +173 -0
- rem/agentic/context_builder.py +12 -2
- rem/agentic/mcp/tool_wrapper.py +2 -2
- rem/agentic/providers/pydantic_ai.py +1 -1
- rem/agentic/schema.py +2 -2
- rem/api/main.py +1 -1
- rem/api/mcp_router/server.py +4 -0
- rem/api/mcp_router/tools.py +542 -170
- rem/api/routers/admin.py +30 -4
- rem/api/routers/auth.py +106 -10
- rem/api/routers/chat/completions.py +66 -18
- rem/api/routers/chat/sse_events.py +7 -3
- rem/api/routers/chat/streaming.py +254 -22
- rem/api/routers/common.py +18 -0
- rem/api/routers/dev.py +7 -1
- rem/api/routers/feedback.py +9 -1
- rem/api/routers/messages.py +176 -38
- rem/api/routers/models.py +9 -1
- rem/api/routers/query.py +12 -1
- rem/api/routers/shared_sessions.py +16 -0
- rem/auth/jwt.py +19 -4
- rem/auth/middleware.py +42 -28
- rem/cli/README.md +62 -0
- rem/cli/commands/db.py +33 -19
- rem/cli/commands/process.py +171 -43
- rem/models/entities/ontology.py +18 -20
- rem/schemas/agents/rem.yaml +1 -1
- rem/services/content/service.py +18 -5
- rem/services/postgres/__init__.py +28 -3
- rem/services/postgres/diff_service.py +57 -5
- rem/services/postgres/programmable_diff_service.py +635 -0
- rem/services/postgres/pydantic_to_sqlalchemy.py +2 -2
- rem/services/postgres/register_type.py +11 -10
- rem/services/postgres/repository.py +14 -4
- rem/services/session/__init__.py +8 -1
- rem/services/session/compression.py +40 -2
- rem/services/session/pydantic_messages.py +276 -0
- rem/settings.py +28 -0
- rem/sql/migrations/001_install.sql +125 -7
- rem/sql/migrations/002_install_models.sql +136 -126
- rem/sql/migrations/004_cache_system.sql +7 -275
- rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
- rem/utils/schema_loader.py +6 -6
- {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/METADATA +1 -1
- {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/RECORD +48 -44
- {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/WHEEL +0 -0
- {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/entry_points.txt +0 -0
rem/api/routers/messages.py
CHANGED
|
@@ -16,6 +16,7 @@ Endpoints:
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
from datetime import datetime
|
|
19
|
+
from enum import Enum
|
|
19
20
|
from typing import Literal
|
|
20
21
|
from uuid import UUID
|
|
21
22
|
|
|
@@ -23,6 +24,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request
|
|
|
23
24
|
from loguru import logger
|
|
24
25
|
from pydantic import BaseModel, Field
|
|
25
26
|
|
|
27
|
+
from .common import ErrorResponse
|
|
28
|
+
|
|
26
29
|
from ..deps import (
|
|
27
30
|
get_current_user,
|
|
28
31
|
get_user_filter,
|
|
@@ -38,6 +41,18 @@ from ...utils.date_utils import parse_iso, utc_now
|
|
|
38
41
|
router = APIRouter(prefix="/api/v1")
|
|
39
42
|
|
|
40
43
|
|
|
44
|
+
# =============================================================================
|
|
45
|
+
# Enums
|
|
46
|
+
# =============================================================================
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SortOrder(str, Enum):
|
|
50
|
+
"""Sort order for list queries."""
|
|
51
|
+
|
|
52
|
+
ASC = "asc"
|
|
53
|
+
DESC = "desc"
|
|
54
|
+
|
|
55
|
+
|
|
41
56
|
# =============================================================================
|
|
42
57
|
# Request/Response Models
|
|
43
58
|
# =============================================================================
|
|
@@ -93,6 +108,23 @@ class SessionListResponse(BaseModel):
|
|
|
93
108
|
has_more: bool
|
|
94
109
|
|
|
95
110
|
|
|
111
|
+
class SessionWithUser(BaseModel):
|
|
112
|
+
"""Session with user info for admin views."""
|
|
113
|
+
|
|
114
|
+
id: str
|
|
115
|
+
name: str
|
|
116
|
+
mode: str | None = None
|
|
117
|
+
description: str | None = None
|
|
118
|
+
user_id: str | None = None
|
|
119
|
+
user_name: str | None = None
|
|
120
|
+
user_email: str | None = None
|
|
121
|
+
message_count: int = 0
|
|
122
|
+
total_tokens: int | None = None
|
|
123
|
+
created_at: datetime | None = None
|
|
124
|
+
updated_at: datetime | None = None
|
|
125
|
+
metadata: dict | None = None
|
|
126
|
+
|
|
127
|
+
|
|
96
128
|
class PaginationMetadata(BaseModel):
|
|
97
129
|
"""Pagination metadata for paginated responses."""
|
|
98
130
|
|
|
@@ -108,7 +140,7 @@ class SessionsQueryResponse(BaseModel):
|
|
|
108
140
|
"""Response for paginated sessions query."""
|
|
109
141
|
|
|
110
142
|
object: Literal["list"] = "list"
|
|
111
|
-
data: list[
|
|
143
|
+
data: list[SessionWithUser] = Field(description="List of sessions for the current page")
|
|
112
144
|
metadata: PaginationMetadata = Field(description="Pagination metadata")
|
|
113
145
|
|
|
114
146
|
|
|
@@ -117,7 +149,14 @@ class SessionsQueryResponse(BaseModel):
|
|
|
117
149
|
# =============================================================================
|
|
118
150
|
|
|
119
151
|
|
|
120
|
-
@router.get(
|
|
152
|
+
@router.get(
|
|
153
|
+
"/messages",
|
|
154
|
+
response_model=MessageListResponse,
|
|
155
|
+
tags=["messages"],
|
|
156
|
+
responses={
|
|
157
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
158
|
+
},
|
|
159
|
+
)
|
|
121
160
|
async def list_messages(
|
|
122
161
|
request: Request,
|
|
123
162
|
mine: bool = Query(default=False, description="Only show my messages (uses JWT identity)"),
|
|
@@ -134,6 +173,7 @@ async def list_messages(
|
|
|
134
173
|
),
|
|
135
174
|
limit: int = Query(default=50, ge=1, le=100, description="Max results to return"),
|
|
136
175
|
offset: int = Query(default=0, ge=0, description="Offset for pagination"),
|
|
176
|
+
sort: SortOrder = Query(default=SortOrder.DESC, description="Sort order by created_at (asc or desc)"),
|
|
137
177
|
) -> MessageListResponse:
|
|
138
178
|
"""
|
|
139
179
|
List messages with optional filters.
|
|
@@ -149,8 +189,9 @@ async def list_messages(
|
|
|
149
189
|
- session_id: Filter by conversation session
|
|
150
190
|
- start_date/end_date: Filter by creation time range (ISO 8601 format)
|
|
151
191
|
- message_type: Filter by role (user, assistant, system, tool)
|
|
192
|
+
- sort: Sort order by created_at (asc or desc, default: desc)
|
|
152
193
|
|
|
153
|
-
Returns paginated results ordered by created_at
|
|
194
|
+
Returns paginated results ordered by created_at.
|
|
154
195
|
"""
|
|
155
196
|
if not settings.postgres.enabled:
|
|
156
197
|
raise HTTPException(status_code=503, detail="Database not enabled")
|
|
@@ -172,6 +213,7 @@ async def list_messages(
|
|
|
172
213
|
|
|
173
214
|
# Apply optional filters
|
|
174
215
|
if session_id:
|
|
216
|
+
# session_id is the session UUID - use directly
|
|
175
217
|
filters["session_id"] = session_id
|
|
176
218
|
if message_type:
|
|
177
219
|
filters["message_type"] = message_type
|
|
@@ -183,12 +225,15 @@ async def list_messages(
|
|
|
183
225
|
f"filters={filters}"
|
|
184
226
|
)
|
|
185
227
|
|
|
228
|
+
# Build order_by clause based on sort parameter
|
|
229
|
+
order_by = f"created_at {sort.value.upper()}"
|
|
230
|
+
|
|
186
231
|
# For date filtering, we need custom SQL (not supported by basic Repository)
|
|
187
232
|
# For now, fetch all matching base filters and filter in Python
|
|
188
233
|
# TODO: Extend Repository to support date range filters
|
|
189
234
|
messages = await repo.find(
|
|
190
235
|
filters,
|
|
191
|
-
order_by=
|
|
236
|
+
order_by=order_by,
|
|
192
237
|
limit=limit + 1, # Fetch one extra to determine has_more
|
|
193
238
|
offset=offset,
|
|
194
239
|
)
|
|
@@ -224,7 +269,16 @@ async def list_messages(
|
|
|
224
269
|
return MessageListResponse(data=messages, total=total, has_more=has_more)
|
|
225
270
|
|
|
226
271
|
|
|
227
|
-
@router.get(
|
|
272
|
+
@router.get(
|
|
273
|
+
"/messages/{message_id}",
|
|
274
|
+
response_model=Message,
|
|
275
|
+
tags=["messages"],
|
|
276
|
+
responses={
|
|
277
|
+
403: {"model": ErrorResponse, "description": "Access denied: not owner"},
|
|
278
|
+
404: {"model": ErrorResponse, "description": "Message not found"},
|
|
279
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
280
|
+
},
|
|
281
|
+
)
|
|
228
282
|
async def get_message(
|
|
229
283
|
request: Request,
|
|
230
284
|
message_id: str,
|
|
@@ -270,10 +324,19 @@ async def get_message(
|
|
|
270
324
|
# =============================================================================
|
|
271
325
|
|
|
272
326
|
|
|
273
|
-
@router.get(
|
|
327
|
+
@router.get(
|
|
328
|
+
"/sessions",
|
|
329
|
+
response_model=SessionsQueryResponse,
|
|
330
|
+
tags=["sessions"],
|
|
331
|
+
responses={
|
|
332
|
+
503: {"model": ErrorResponse, "description": "Database not enabled or connection failed"},
|
|
333
|
+
},
|
|
334
|
+
)
|
|
274
335
|
async def list_sessions(
|
|
275
336
|
request: Request,
|
|
276
337
|
user_id: str | None = Query(default=None, description="Filter by user ID (admin only for cross-user)"),
|
|
338
|
+
user_name: str | None = Query(default=None, description="Filter by user name (partial match, admin only)"),
|
|
339
|
+
user_email: str | None = Query(default=None, description="Filter by user email (partial match, admin only)"),
|
|
277
340
|
mode: SessionMode | None = Query(default=None, description="Filter by session mode"),
|
|
278
341
|
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
|
279
342
|
page_size: int = Query(default=50, ge=1, le=100, description="Number of results per page"),
|
|
@@ -283,51 +346,113 @@ async def list_sessions(
|
|
|
283
346
|
|
|
284
347
|
Access Control:
|
|
285
348
|
- Regular users: Only see their own sessions
|
|
286
|
-
- Admin users: Can filter by any user_id or see all sessions
|
|
349
|
+
- Admin users: Can filter by any user_id, user_name, user_email, or see all sessions
|
|
287
350
|
|
|
288
351
|
Filters:
|
|
289
352
|
- user_id: Filter by session owner (admin only for cross-user)
|
|
353
|
+
- user_name: Filter by user name partial match (admin only)
|
|
354
|
+
- user_email: Filter by user email partial match (admin only)
|
|
290
355
|
- mode: Filter by session mode (normal or evaluation)
|
|
291
356
|
|
|
292
357
|
Pagination:
|
|
293
358
|
- page: Page number (1-indexed, default: 1)
|
|
294
359
|
- page_size: Number of sessions per page (default: 50, max: 100)
|
|
295
360
|
|
|
296
|
-
Returns paginated results ordered by created_at descending
|
|
361
|
+
Returns paginated results with user info ordered by created_at descending.
|
|
297
362
|
"""
|
|
298
363
|
if not settings.postgres.enabled:
|
|
299
364
|
raise HTTPException(status_code=503, detail="Database not enabled")
|
|
300
365
|
|
|
301
|
-
|
|
366
|
+
current_user = get_current_user(request)
|
|
367
|
+
admin = is_admin(current_user)
|
|
302
368
|
|
|
303
|
-
#
|
|
304
|
-
|
|
305
|
-
if
|
|
306
|
-
|
|
369
|
+
# Get postgres service for raw SQL query
|
|
370
|
+
db = get_postgres_service()
|
|
371
|
+
if not db:
|
|
372
|
+
raise HTTPException(status_code=503, detail="Database connection failed")
|
|
373
|
+
if not db.pool:
|
|
374
|
+
await db.connect()
|
|
307
375
|
|
|
308
|
-
#
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
376
|
+
# Build effective filters based on user role
|
|
377
|
+
effective_user_id = user_id
|
|
378
|
+
effective_user_name = user_name if admin else None # Only admin can search by name
|
|
379
|
+
effective_user_email = user_email if admin else None # Only admin can search by email
|
|
380
|
+
|
|
381
|
+
if not admin:
|
|
382
|
+
# Non-admin users can only see their own sessions
|
|
383
|
+
effective_user_id = current_user.get("id") if current_user else None
|
|
384
|
+
if not effective_user_id:
|
|
385
|
+
# Anonymous user - return empty
|
|
386
|
+
return SessionsQueryResponse(
|
|
387
|
+
data=[],
|
|
388
|
+
metadata=PaginationMetadata(
|
|
389
|
+
total=0, page=page, page_size=page_size,
|
|
390
|
+
total_pages=0, has_next=False, has_previous=False,
|
|
391
|
+
),
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Call the SQL function for sessions with user info
|
|
395
|
+
async with db.pool.acquire() as conn:
|
|
396
|
+
rows = await conn.fetch(
|
|
397
|
+
"""
|
|
398
|
+
SELECT * FROM fn_list_sessions_with_user(
|
|
399
|
+
$1, $2, $3, $4, $5, $6
|
|
400
|
+
)
|
|
401
|
+
""",
|
|
402
|
+
effective_user_id,
|
|
403
|
+
effective_user_name,
|
|
404
|
+
effective_user_email,
|
|
405
|
+
mode.value if mode else None,
|
|
406
|
+
page,
|
|
407
|
+
page_size,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
# Extract total from first row
|
|
411
|
+
total = rows[0]["total_count"] if rows else 0
|
|
412
|
+
|
|
413
|
+
# Convert rows to SessionWithUser
|
|
414
|
+
data = [
|
|
415
|
+
SessionWithUser(
|
|
416
|
+
id=str(row["id"]),
|
|
417
|
+
name=row["name"],
|
|
418
|
+
mode=row["mode"],
|
|
419
|
+
description=row["description"],
|
|
420
|
+
user_id=row["user_id"],
|
|
421
|
+
user_name=row["user_name"],
|
|
422
|
+
user_email=row["user_email"],
|
|
423
|
+
message_count=row["message_count"] or 0,
|
|
424
|
+
total_tokens=row["total_tokens"],
|
|
425
|
+
created_at=row["created_at"],
|
|
426
|
+
updated_at=row["updated_at"],
|
|
427
|
+
metadata=row["metadata"],
|
|
428
|
+
)
|
|
429
|
+
for row in rows
|
|
430
|
+
]
|
|
431
|
+
|
|
432
|
+
total_pages = (total + page_size - 1) // page_size if total > 0 else 0
|
|
316
433
|
|
|
317
434
|
return SessionsQueryResponse(
|
|
318
|
-
data=
|
|
435
|
+
data=data,
|
|
319
436
|
metadata=PaginationMetadata(
|
|
320
|
-
total=
|
|
321
|
-
page=
|
|
322
|
-
page_size=
|
|
323
|
-
total_pages=
|
|
324
|
-
has_next=
|
|
325
|
-
has_previous=
|
|
437
|
+
total=total,
|
|
438
|
+
page=page,
|
|
439
|
+
page_size=page_size,
|
|
440
|
+
total_pages=total_pages,
|
|
441
|
+
has_next=page < total_pages,
|
|
442
|
+
has_previous=page > 1,
|
|
326
443
|
),
|
|
327
444
|
)
|
|
328
445
|
|
|
329
446
|
|
|
330
|
-
@router.post(
|
|
447
|
+
@router.post(
|
|
448
|
+
"/sessions",
|
|
449
|
+
response_model=Session,
|
|
450
|
+
status_code=201,
|
|
451
|
+
tags=["sessions"],
|
|
452
|
+
responses={
|
|
453
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
454
|
+
},
|
|
455
|
+
)
|
|
331
456
|
async def create_session(
|
|
332
457
|
request_body: SessionCreateRequest,
|
|
333
458
|
user: dict = Depends(require_admin),
|
|
@@ -379,7 +504,16 @@ async def create_session(
|
|
|
379
504
|
return result # type: ignore
|
|
380
505
|
|
|
381
506
|
|
|
382
|
-
@router.get(
|
|
507
|
+
@router.get(
|
|
508
|
+
"/sessions/{session_id}",
|
|
509
|
+
response_model=Session,
|
|
510
|
+
tags=["sessions"],
|
|
511
|
+
responses={
|
|
512
|
+
403: {"model": ErrorResponse, "description": "Access denied: not owner"},
|
|
513
|
+
404: {"model": ErrorResponse, "description": "Session not found"},
|
|
514
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
515
|
+
},
|
|
516
|
+
)
|
|
383
517
|
async def get_session(
|
|
384
518
|
request: Request,
|
|
385
519
|
session_id: str,
|
|
@@ -392,7 +526,7 @@ async def get_session(
|
|
|
392
526
|
- Admin users: Can access any session
|
|
393
527
|
|
|
394
528
|
Args:
|
|
395
|
-
session_id: UUID
|
|
529
|
+
session_id: UUID of the session
|
|
396
530
|
|
|
397
531
|
Returns:
|
|
398
532
|
Session object if found
|
|
@@ -408,12 +542,7 @@ async def get_session(
|
|
|
408
542
|
session = await repo.get_by_id(session_id)
|
|
409
543
|
|
|
410
544
|
if not session:
|
|
411
|
-
|
|
412
|
-
sessions = await repo.find({"name": session_id}, limit=1)
|
|
413
|
-
if sessions:
|
|
414
|
-
session = sessions[0]
|
|
415
|
-
else:
|
|
416
|
-
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
|
545
|
+
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
|
417
546
|
|
|
418
547
|
# Check access: admin or owner
|
|
419
548
|
current_user = get_current_user(request)
|
|
@@ -425,7 +554,16 @@ async def get_session(
|
|
|
425
554
|
return session
|
|
426
555
|
|
|
427
556
|
|
|
428
|
-
@router.put(
|
|
557
|
+
@router.put(
|
|
558
|
+
"/sessions/{session_id}",
|
|
559
|
+
response_model=Session,
|
|
560
|
+
tags=["sessions"],
|
|
561
|
+
responses={
|
|
562
|
+
403: {"model": ErrorResponse, "description": "Access denied: not owner"},
|
|
563
|
+
404: {"model": ErrorResponse, "description": "Session not found"},
|
|
564
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
565
|
+
},
|
|
566
|
+
)
|
|
429
567
|
async def update_session(
|
|
430
568
|
request: Request,
|
|
431
569
|
session_id: str,
|
rem/api/routers/models.py
CHANGED
|
@@ -15,6 +15,8 @@ from typing import Literal
|
|
|
15
15
|
from fastapi import APIRouter, HTTPException
|
|
16
16
|
from pydantic import BaseModel, Field
|
|
17
17
|
|
|
18
|
+
from .common import ErrorResponse
|
|
19
|
+
|
|
18
20
|
from rem.agentic.llm_provider_models import (
|
|
19
21
|
ModelInfo,
|
|
20
22
|
AVAILABLE_MODELS,
|
|
@@ -57,7 +59,13 @@ async def list_models() -> ModelsResponse:
|
|
|
57
59
|
return ModelsResponse(data=AVAILABLE_MODELS)
|
|
58
60
|
|
|
59
61
|
|
|
60
|
-
@router.get(
|
|
62
|
+
@router.get(
|
|
63
|
+
"/models/{model_id:path}",
|
|
64
|
+
response_model=ModelInfo,
|
|
65
|
+
responses={
|
|
66
|
+
404: {"model": ErrorResponse, "description": "Model not found"},
|
|
67
|
+
},
|
|
68
|
+
)
|
|
61
69
|
async def get_model(model_id: str) -> ModelInfo:
|
|
62
70
|
"""
|
|
63
71
|
Get information about a specific model.
|
rem/api/routers/query.py
CHANGED
|
@@ -86,6 +86,8 @@ from fastapi import APIRouter, Header, HTTPException
|
|
|
86
86
|
from loguru import logger
|
|
87
87
|
from pydantic import BaseModel, Field
|
|
88
88
|
|
|
89
|
+
from .common import ErrorResponse
|
|
90
|
+
|
|
89
91
|
from ...services.postgres import get_postgres_service
|
|
90
92
|
from ...services.rem.service import RemService
|
|
91
93
|
from ...services.rem.parser import RemQueryParser
|
|
@@ -213,7 +215,16 @@ class QueryResponse(BaseModel):
|
|
|
213
215
|
)
|
|
214
216
|
|
|
215
217
|
|
|
216
|
-
@router.post(
|
|
218
|
+
@router.post(
|
|
219
|
+
"/query",
|
|
220
|
+
response_model=QueryResponse,
|
|
221
|
+
responses={
|
|
222
|
+
400: {"model": ErrorResponse, "description": "Invalid query or missing required fields"},
|
|
223
|
+
500: {"model": ErrorResponse, "description": "Query execution failed"},
|
|
224
|
+
501: {"model": ErrorResponse, "description": "Feature not yet implemented"},
|
|
225
|
+
503: {"model": ErrorResponse, "description": "Database not configured or unavailable"},
|
|
226
|
+
},
|
|
227
|
+
)
|
|
217
228
|
async def execute_query(
|
|
218
229
|
request: QueryRequest,
|
|
219
230
|
x_user_id: str | None = Header(default=None, description="User ID for query isolation (optional, uses default if not provided)"),
|
|
@@ -18,6 +18,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request
|
|
|
18
18
|
from loguru import logger
|
|
19
19
|
from pydantic import BaseModel, Field
|
|
20
20
|
|
|
21
|
+
from .common import ErrorResponse
|
|
22
|
+
|
|
21
23
|
from ..deps import get_current_user, require_auth
|
|
22
24
|
from ...models.entities import (
|
|
23
25
|
Message,
|
|
@@ -83,6 +85,10 @@ class ShareSessionResponse(BaseModel):
|
|
|
83
85
|
response_model=ShareSessionResponse,
|
|
84
86
|
status_code=201,
|
|
85
87
|
tags=["sessions"],
|
|
88
|
+
responses={
|
|
89
|
+
400: {"model": ErrorResponse, "description": "Session already shared with this user"},
|
|
90
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
91
|
+
},
|
|
86
92
|
)
|
|
87
93
|
async def share_session(
|
|
88
94
|
request: Request,
|
|
@@ -175,6 +181,10 @@ async def share_session(
|
|
|
175
181
|
"/sessions/{session_id}/share/{shared_with_user_id}",
|
|
176
182
|
status_code=200,
|
|
177
183
|
tags=["sessions"],
|
|
184
|
+
responses={
|
|
185
|
+
404: {"model": ErrorResponse, "description": "Share not found"},
|
|
186
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
187
|
+
},
|
|
178
188
|
)
|
|
179
189
|
async def remove_session_share(
|
|
180
190
|
request: Request,
|
|
@@ -250,6 +260,9 @@ async def remove_session_share(
|
|
|
250
260
|
"/sessions/shared-with-me",
|
|
251
261
|
response_model=SharedWithMeResponse,
|
|
252
262
|
tags=["sessions"],
|
|
263
|
+
responses={
|
|
264
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
265
|
+
},
|
|
253
266
|
)
|
|
254
267
|
async def get_shared_with_me(
|
|
255
268
|
request: Request,
|
|
@@ -328,6 +341,9 @@ async def get_shared_with_me(
|
|
|
328
341
|
"/sessions/shared-with-me/{owner_user_id}/messages",
|
|
329
342
|
response_model=SharedMessagesResponse,
|
|
330
343
|
tags=["sessions"],
|
|
344
|
+
responses={
|
|
345
|
+
503: {"model": ErrorResponse, "description": "Database not enabled"},
|
|
346
|
+
},
|
|
331
347
|
)
|
|
332
348
|
async def get_shared_messages(
|
|
333
349
|
request: Request,
|
rem/auth/jwt.py
CHANGED
|
@@ -260,12 +260,16 @@ class JWTService:
|
|
|
260
260
|
"tenant_id": payload.get("tenant_id", "default"),
|
|
261
261
|
}
|
|
262
262
|
|
|
263
|
-
def refresh_access_token(
|
|
263
|
+
def refresh_access_token(
|
|
264
|
+
self, refresh_token: str, user_override: dict | None = None
|
|
265
|
+
) -> dict | None:
|
|
264
266
|
"""
|
|
265
267
|
Create new access token using refresh token.
|
|
266
268
|
|
|
267
269
|
Args:
|
|
268
270
|
refresh_token: Valid refresh token
|
|
271
|
+
user_override: Optional dict with user fields to override defaults
|
|
272
|
+
(e.g., role, roles, tier, name from database lookup)
|
|
269
273
|
|
|
270
274
|
Returns:
|
|
271
275
|
New token dict or None if refresh token is invalid
|
|
@@ -285,8 +289,7 @@ class JWTService:
|
|
|
285
289
|
logger.debug("Refresh token expired")
|
|
286
290
|
return None
|
|
287
291
|
|
|
288
|
-
#
|
|
289
|
-
# In production, you'd look up the full user from database
|
|
292
|
+
# Build user dict with defaults
|
|
290
293
|
user = {
|
|
291
294
|
"id": payload.get("sub"),
|
|
292
295
|
"email": payload.get("email"),
|
|
@@ -294,16 +297,28 @@ class JWTService:
|
|
|
294
297
|
"provider": "email",
|
|
295
298
|
"tenant_id": "default",
|
|
296
299
|
"tier": "free",
|
|
300
|
+
"role": "user",
|
|
297
301
|
"roles": ["user"],
|
|
298
302
|
}
|
|
299
303
|
|
|
304
|
+
# Apply overrides from database lookup if provided
|
|
305
|
+
if user_override:
|
|
306
|
+
if user_override.get("role"):
|
|
307
|
+
user["role"] = user_override["role"]
|
|
308
|
+
if user_override.get("roles"):
|
|
309
|
+
user["roles"] = user_override["roles"]
|
|
310
|
+
if user_override.get("tier"):
|
|
311
|
+
user["tier"] = user_override["tier"]
|
|
312
|
+
if user_override.get("name"):
|
|
313
|
+
user["name"] = user_override["name"]
|
|
314
|
+
|
|
300
315
|
# Only return new access token, keep same refresh token
|
|
301
316
|
now = int(time.time())
|
|
302
317
|
access_payload = {
|
|
303
318
|
"sub": user["id"],
|
|
304
319
|
"email": user["email"],
|
|
305
320
|
"name": user["name"],
|
|
306
|
-
"role": user
|
|
321
|
+
"role": user["role"],
|
|
307
322
|
"tier": user["tier"],
|
|
308
323
|
"roles": user["roles"],
|
|
309
324
|
"provider": user["provider"],
|
rem/auth/middleware.py
CHANGED
|
@@ -14,15 +14,14 @@ Design Pattern:
|
|
|
14
14
|
- MCP paths always require authentication (protected service)
|
|
15
15
|
|
|
16
16
|
Authentication Flow:
|
|
17
|
-
1.
|
|
18
|
-
2.
|
|
19
|
-
3.
|
|
20
|
-
4.
|
|
21
|
-
5. If allow_anonymous=
|
|
22
|
-
6. If allow_anonymous=False: Return 401 / redirect to login
|
|
17
|
+
1. Check JWT/dev token/session for user identity first
|
|
18
|
+
2. If user is admin: bypass API key check (admin privilege)
|
|
19
|
+
3. If API key enabled and user is not admin: Validate X-API-Key header
|
|
20
|
+
4. If allow_anonymous=True: Allow as anonymous (rate-limited)
|
|
21
|
+
5. If allow_anonymous=False: Return 401 / redirect to login
|
|
23
22
|
|
|
24
23
|
IMPORTANT: API key validates ACCESS, JWT identifies USER.
|
|
25
|
-
|
|
24
|
+
Admin users bypass the API key requirement (trusted identity).
|
|
26
25
|
|
|
27
26
|
Access Modes (configured in settings.auth):
|
|
28
27
|
- enabled=true, allow_anonymous=true: Auth available, anonymous gets rate-limited access
|
|
@@ -195,6 +194,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
195
194
|
|
|
196
195
|
return None
|
|
197
196
|
|
|
197
|
+
def _is_admin(self, user: dict | None) -> bool:
|
|
198
|
+
"""Check if user has admin role."""
|
|
199
|
+
if not user:
|
|
200
|
+
return False
|
|
201
|
+
return "admin" in user.get("roles", [])
|
|
202
|
+
|
|
198
203
|
async def dispatch(self, request: Request, call_next):
|
|
199
204
|
"""
|
|
200
205
|
Check authentication for protected paths.
|
|
@@ -219,8 +224,35 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
219
224
|
if not is_protected or is_excluded:
|
|
220
225
|
return await call_next(request)
|
|
221
226
|
|
|
222
|
-
#
|
|
223
|
-
#
|
|
227
|
+
# Check for user identity FIRST (JWT, dev token, session)
|
|
228
|
+
# This allows admin users to bypass API key requirement
|
|
229
|
+
user = None
|
|
230
|
+
|
|
231
|
+
# Check for JWT token in Authorization header (primary user identity)
|
|
232
|
+
jwt_user = self._check_jwt_token(request)
|
|
233
|
+
if jwt_user:
|
|
234
|
+
user = jwt_user
|
|
235
|
+
|
|
236
|
+
# Check for dev token (non-production only)
|
|
237
|
+
if not user:
|
|
238
|
+
dev_user = self._check_dev_token(request)
|
|
239
|
+
if dev_user:
|
|
240
|
+
user = dev_user
|
|
241
|
+
|
|
242
|
+
# Check for valid session (backward compatibility)
|
|
243
|
+
if not user:
|
|
244
|
+
session_user = request.session.get("user")
|
|
245
|
+
if session_user:
|
|
246
|
+
user = session_user
|
|
247
|
+
|
|
248
|
+
# If user is admin, bypass API key check entirely
|
|
249
|
+
if self._is_admin(user):
|
|
250
|
+
logger.debug(f"Admin user {user.get('email')} bypassing API key check")
|
|
251
|
+
request.state.user = user
|
|
252
|
+
request.state.is_anonymous = False
|
|
253
|
+
return await call_next(request)
|
|
254
|
+
|
|
255
|
+
# API key validation for non-admin users (access control guardrail)
|
|
224
256
|
if settings.api.api_key_enabled:
|
|
225
257
|
api_key = request.headers.get("x-api-key")
|
|
226
258
|
if not api_key:
|
|
@@ -238,27 +270,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
238
270
|
headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
|
|
239
271
|
)
|
|
240
272
|
logger.debug("X-API-Key validated for access")
|
|
241
|
-
# API key valid - continue to check JWT for user identity
|
|
242
|
-
|
|
243
|
-
# Check for JWT token in Authorization header (primary user identity)
|
|
244
|
-
jwt_user = self._check_jwt_token(request)
|
|
245
|
-
if jwt_user:
|
|
246
|
-
request.state.user = jwt_user
|
|
247
|
-
request.state.is_anonymous = False
|
|
248
|
-
return await call_next(request)
|
|
249
|
-
|
|
250
|
-
# Check for dev token (non-production only)
|
|
251
|
-
dev_user = self._check_dev_token(request)
|
|
252
|
-
if dev_user:
|
|
253
|
-
request.state.user = dev_user
|
|
254
|
-
request.state.is_anonymous = False
|
|
255
|
-
return await call_next(request)
|
|
256
|
-
|
|
257
|
-
# Check for valid session (backward compatibility)
|
|
258
|
-
user = request.session.get("user")
|
|
259
273
|
|
|
274
|
+
# If we have a valid user (non-admin, but passed API key check), allow access
|
|
260
275
|
if user:
|
|
261
|
-
# Authenticated user - add to request state
|
|
262
276
|
request.state.user = user
|
|
263
277
|
request.state.is_anonymous = False
|
|
264
278
|
return await call_next(request)
|