remdb-0.3.181-py3-none-any.whl → remdb-0.3.223-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of remdb might be problematic. See the package's registry page for more details.

Files changed (48)
  1. rem/agentic/README.md +262 -2
  2. rem/agentic/context.py +173 -0
  3. rem/agentic/context_builder.py +12 -2
  4. rem/agentic/mcp/tool_wrapper.py +2 -2
  5. rem/agentic/providers/pydantic_ai.py +1 -1
  6. rem/agentic/schema.py +2 -2
  7. rem/api/main.py +1 -1
  8. rem/api/mcp_router/server.py +4 -0
  9. rem/api/mcp_router/tools.py +542 -170
  10. rem/api/routers/admin.py +30 -4
  11. rem/api/routers/auth.py +106 -10
  12. rem/api/routers/chat/completions.py +66 -18
  13. rem/api/routers/chat/sse_events.py +7 -3
  14. rem/api/routers/chat/streaming.py +254 -22
  15. rem/api/routers/common.py +18 -0
  16. rem/api/routers/dev.py +7 -1
  17. rem/api/routers/feedback.py +9 -1
  18. rem/api/routers/messages.py +176 -38
  19. rem/api/routers/models.py +9 -1
  20. rem/api/routers/query.py +12 -1
  21. rem/api/routers/shared_sessions.py +16 -0
  22. rem/auth/jwt.py +19 -4
  23. rem/auth/middleware.py +42 -28
  24. rem/cli/README.md +62 -0
  25. rem/cli/commands/db.py +33 -19
  26. rem/cli/commands/process.py +171 -43
  27. rem/models/entities/ontology.py +18 -20
  28. rem/schemas/agents/rem.yaml +1 -1
  29. rem/services/content/service.py +18 -5
  30. rem/services/postgres/__init__.py +28 -3
  31. rem/services/postgres/diff_service.py +57 -5
  32. rem/services/postgres/programmable_diff_service.py +635 -0
  33. rem/services/postgres/pydantic_to_sqlalchemy.py +2 -2
  34. rem/services/postgres/register_type.py +11 -10
  35. rem/services/postgres/repository.py +14 -4
  36. rem/services/session/__init__.py +8 -1
  37. rem/services/session/compression.py +40 -2
  38. rem/services/session/pydantic_messages.py +276 -0
  39. rem/settings.py +28 -0
  40. rem/sql/migrations/001_install.sql +125 -7
  41. rem/sql/migrations/002_install_models.sql +136 -126
  42. rem/sql/migrations/004_cache_system.sql +7 -275
  43. rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
  44. rem/utils/schema_loader.py +6 -6
  45. {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/METADATA +1 -1
  46. {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/RECORD +48 -44
  47. {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/WHEEL +0 -0
  48. {remdb-0.3.181.dist-info → remdb-0.3.223.dist-info}/entry_points.txt +0 -0
@@ -16,6 +16,7 @@ Endpoints:
16
16
  """
17
17
 
18
18
  from datetime import datetime
19
+ from enum import Enum
19
20
  from typing import Literal
20
21
  from uuid import UUID
21
22
 
@@ -23,6 +24,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request
23
24
  from loguru import logger
24
25
  from pydantic import BaseModel, Field
25
26
 
27
+ from .common import ErrorResponse
28
+
26
29
  from ..deps import (
27
30
  get_current_user,
28
31
  get_user_filter,
@@ -38,6 +41,18 @@ from ...utils.date_utils import parse_iso, utc_now
38
41
  router = APIRouter(prefix="/api/v1")
39
42
 
40
43
 
44
+ # =============================================================================
45
+ # Enums
46
+ # =============================================================================
47
+
48
+
49
+ class SortOrder(str, Enum):
50
+ """Sort order for list queries."""
51
+
52
+ ASC = "asc"
53
+ DESC = "desc"
54
+
55
+
41
56
  # =============================================================================
42
57
  # Request/Response Models
43
58
  # =============================================================================
@@ -93,6 +108,23 @@ class SessionListResponse(BaseModel):
93
108
  has_more: bool
94
109
 
95
110
 
111
+ class SessionWithUser(BaseModel):
112
+ """Session with user info for admin views."""
113
+
114
+ id: str
115
+ name: str
116
+ mode: str | None = None
117
+ description: str | None = None
118
+ user_id: str | None = None
119
+ user_name: str | None = None
120
+ user_email: str | None = None
121
+ message_count: int = 0
122
+ total_tokens: int | None = None
123
+ created_at: datetime | None = None
124
+ updated_at: datetime | None = None
125
+ metadata: dict | None = None
126
+
127
+
96
128
  class PaginationMetadata(BaseModel):
97
129
  """Pagination metadata for paginated responses."""
98
130
 
@@ -108,7 +140,7 @@ class SessionsQueryResponse(BaseModel):
108
140
  """Response for paginated sessions query."""
109
141
 
110
142
  object: Literal["list"] = "list"
111
- data: list[Session] = Field(description="List of sessions for the current page")
143
+ data: list[SessionWithUser] = Field(description="List of sessions for the current page")
112
144
  metadata: PaginationMetadata = Field(description="Pagination metadata")
113
145
 
114
146
 
@@ -117,7 +149,14 @@ class SessionsQueryResponse(BaseModel):
117
149
  # =============================================================================
118
150
 
119
151
 
120
- @router.get("/messages", response_model=MessageListResponse, tags=["messages"])
152
+ @router.get(
153
+ "/messages",
154
+ response_model=MessageListResponse,
155
+ tags=["messages"],
156
+ responses={
157
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
158
+ },
159
+ )
121
160
  async def list_messages(
122
161
  request: Request,
123
162
  mine: bool = Query(default=False, description="Only show my messages (uses JWT identity)"),
@@ -134,6 +173,7 @@ async def list_messages(
134
173
  ),
135
174
  limit: int = Query(default=50, ge=1, le=100, description="Max results to return"),
136
175
  offset: int = Query(default=0, ge=0, description="Offset for pagination"),
176
+ sort: SortOrder = Query(default=SortOrder.DESC, description="Sort order by created_at (asc or desc)"),
137
177
  ) -> MessageListResponse:
138
178
  """
139
179
  List messages with optional filters.
@@ -149,8 +189,9 @@ async def list_messages(
149
189
  - session_id: Filter by conversation session
150
190
  - start_date/end_date: Filter by creation time range (ISO 8601 format)
151
191
  - message_type: Filter by role (user, assistant, system, tool)
192
+ - sort: Sort order by created_at (asc or desc, default: desc)
152
193
 
153
- Returns paginated results ordered by created_at descending.
194
+ Returns paginated results ordered by created_at.
154
195
  """
155
196
  if not settings.postgres.enabled:
156
197
  raise HTTPException(status_code=503, detail="Database not enabled")
@@ -172,6 +213,7 @@ async def list_messages(
172
213
 
173
214
  # Apply optional filters
174
215
  if session_id:
216
+ # session_id is the session UUID - use directly
175
217
  filters["session_id"] = session_id
176
218
  if message_type:
177
219
  filters["message_type"] = message_type
@@ -183,12 +225,15 @@ async def list_messages(
183
225
  f"filters={filters}"
184
226
  )
185
227
 
228
+ # Build order_by clause based on sort parameter
229
+ order_by = f"created_at {sort.value.upper()}"
230
+
186
231
  # For date filtering, we need custom SQL (not supported by basic Repository)
187
232
  # For now, fetch all matching base filters and filter in Python
188
233
  # TODO: Extend Repository to support date range filters
189
234
  messages = await repo.find(
190
235
  filters,
191
- order_by="created_at DESC",
236
+ order_by=order_by,
192
237
  limit=limit + 1, # Fetch one extra to determine has_more
193
238
  offset=offset,
194
239
  )
@@ -224,7 +269,16 @@ async def list_messages(
224
269
  return MessageListResponse(data=messages, total=total, has_more=has_more)
225
270
 
226
271
 
227
- @router.get("/messages/{message_id}", response_model=Message, tags=["messages"])
272
+ @router.get(
273
+ "/messages/{message_id}",
274
+ response_model=Message,
275
+ tags=["messages"],
276
+ responses={
277
+ 403: {"model": ErrorResponse, "description": "Access denied: not owner"},
278
+ 404: {"model": ErrorResponse, "description": "Message not found"},
279
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
280
+ },
281
+ )
228
282
  async def get_message(
229
283
  request: Request,
230
284
  message_id: str,
@@ -270,10 +324,19 @@ async def get_message(
270
324
  # =============================================================================
271
325
 
272
326
 
273
- @router.get("/sessions", response_model=SessionsQueryResponse, tags=["sessions"])
327
+ @router.get(
328
+ "/sessions",
329
+ response_model=SessionsQueryResponse,
330
+ tags=["sessions"],
331
+ responses={
332
+ 503: {"model": ErrorResponse, "description": "Database not enabled or connection failed"},
333
+ },
334
+ )
274
335
  async def list_sessions(
275
336
  request: Request,
276
337
  user_id: str | None = Query(default=None, description="Filter by user ID (admin only for cross-user)"),
338
+ user_name: str | None = Query(default=None, description="Filter by user name (partial match, admin only)"),
339
+ user_email: str | None = Query(default=None, description="Filter by user email (partial match, admin only)"),
277
340
  mode: SessionMode | None = Query(default=None, description="Filter by session mode"),
278
341
  page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
279
342
  page_size: int = Query(default=50, ge=1, le=100, description="Number of results per page"),
@@ -283,51 +346,113 @@ async def list_sessions(
283
346
 
284
347
  Access Control:
285
348
  - Regular users: Only see their own sessions
286
- - Admin users: Can filter by any user_id or see all sessions
349
+ - Admin users: Can filter by any user_id, user_name, user_email, or see all sessions
287
350
 
288
351
  Filters:
289
352
  - user_id: Filter by session owner (admin only for cross-user)
353
+ - user_name: Filter by user name partial match (admin only)
354
+ - user_email: Filter by user email partial match (admin only)
290
355
  - mode: Filter by session mode (normal or evaluation)
291
356
 
292
357
  Pagination:
293
358
  - page: Page number (1-indexed, default: 1)
294
359
  - page_size: Number of sessions per page (default: 50, max: 100)
295
360
 
296
- Returns paginated results ordered by created_at descending with pagination metadata.
361
+ Returns paginated results with user info ordered by created_at descending.
297
362
  """
298
363
  if not settings.postgres.enabled:
299
364
  raise HTTPException(status_code=503, detail="Database not enabled")
300
365
 
301
- repo = Repository(Session, table_name="sessions")
366
+ current_user = get_current_user(request)
367
+ admin = is_admin(current_user)
302
368
 
303
- # Build user-scoped filters (admin can see all, regular users see only their own)
304
- filters = await get_user_filter(request, x_user_id=user_id)
305
- if mode:
306
- filters["mode"] = mode.value
369
+ # Get postgres service for raw SQL query
370
+ db = get_postgres_service()
371
+ if not db:
372
+ raise HTTPException(status_code=503, detail="Database connection failed")
373
+ if not db.pool:
374
+ await db.connect()
307
375
 
308
- # Use CTE-based pagination with ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY created_at DESC)
309
- result = await repo.find_paginated(
310
- filters,
311
- page=page,
312
- page_size=page_size,
313
- order_by="created_at DESC",
314
- partition_by="user_id",
315
- )
376
+ # Build effective filters based on user role
377
+ effective_user_id = user_id
378
+ effective_user_name = user_name if admin else None # Only admin can search by name
379
+ effective_user_email = user_email if admin else None # Only admin can search by email
380
+
381
+ if not admin:
382
+ # Non-admin users can only see their own sessions
383
+ effective_user_id = current_user.get("id") if current_user else None
384
+ if not effective_user_id:
385
+ # Anonymous user - return empty
386
+ return SessionsQueryResponse(
387
+ data=[],
388
+ metadata=PaginationMetadata(
389
+ total=0, page=page, page_size=page_size,
390
+ total_pages=0, has_next=False, has_previous=False,
391
+ ),
392
+ )
393
+
394
+ # Call the SQL function for sessions with user info
395
+ async with db.pool.acquire() as conn:
396
+ rows = await conn.fetch(
397
+ """
398
+ SELECT * FROM fn_list_sessions_with_user(
399
+ $1, $2, $3, $4, $5, $6
400
+ )
401
+ """,
402
+ effective_user_id,
403
+ effective_user_name,
404
+ effective_user_email,
405
+ mode.value if mode else None,
406
+ page,
407
+ page_size,
408
+ )
409
+
410
+ # Extract total from first row
411
+ total = rows[0]["total_count"] if rows else 0
412
+
413
+ # Convert rows to SessionWithUser
414
+ data = [
415
+ SessionWithUser(
416
+ id=str(row["id"]),
417
+ name=row["name"],
418
+ mode=row["mode"],
419
+ description=row["description"],
420
+ user_id=row["user_id"],
421
+ user_name=row["user_name"],
422
+ user_email=row["user_email"],
423
+ message_count=row["message_count"] or 0,
424
+ total_tokens=row["total_tokens"],
425
+ created_at=row["created_at"],
426
+ updated_at=row["updated_at"],
427
+ metadata=row["metadata"],
428
+ )
429
+ for row in rows
430
+ ]
431
+
432
+ total_pages = (total + page_size - 1) // page_size if total > 0 else 0
316
433
 
317
434
  return SessionsQueryResponse(
318
- data=result["data"],
435
+ data=data,
319
436
  metadata=PaginationMetadata(
320
- total=result["total"],
321
- page=result["page"],
322
- page_size=result["page_size"],
323
- total_pages=result["total_pages"],
324
- has_next=result["has_next"],
325
- has_previous=result["has_previous"],
437
+ total=total,
438
+ page=page,
439
+ page_size=page_size,
440
+ total_pages=total_pages,
441
+ has_next=page < total_pages,
442
+ has_previous=page > 1,
326
443
  ),
327
444
  )
328
445
 
329
446
 
330
- @router.post("/sessions", response_model=Session, status_code=201, tags=["sessions"])
447
+ @router.post(
448
+ "/sessions",
449
+ response_model=Session,
450
+ status_code=201,
451
+ tags=["sessions"],
452
+ responses={
453
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
454
+ },
455
+ )
331
456
  async def create_session(
332
457
  request_body: SessionCreateRequest,
333
458
  user: dict = Depends(require_admin),
@@ -379,7 +504,16 @@ async def create_session(
379
504
  return result # type: ignore
380
505
 
381
506
 
382
- @router.get("/sessions/{session_id}", response_model=Session, tags=["sessions"])
507
+ @router.get(
508
+ "/sessions/{session_id}",
509
+ response_model=Session,
510
+ tags=["sessions"],
511
+ responses={
512
+ 403: {"model": ErrorResponse, "description": "Access denied: not owner"},
513
+ 404: {"model": ErrorResponse, "description": "Session not found"},
514
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
515
+ },
516
+ )
383
517
  async def get_session(
384
518
  request: Request,
385
519
  session_id: str,
@@ -392,7 +526,7 @@ async def get_session(
392
526
  - Admin users: Can access any session
393
527
 
394
528
  Args:
395
- session_id: UUID or name of the session
529
+ session_id: UUID of the session
396
530
 
397
531
  Returns:
398
532
  Session object if found
@@ -408,12 +542,7 @@ async def get_session(
408
542
  session = await repo.get_by_id(session_id)
409
543
 
410
544
  if not session:
411
- # Try finding by name
412
- sessions = await repo.find({"name": session_id}, limit=1)
413
- if sessions:
414
- session = sessions[0]
415
- else:
416
- raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
545
+ raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
417
546
 
418
547
  # Check access: admin or owner
419
548
  current_user = get_current_user(request)
@@ -425,7 +554,16 @@ async def get_session(
425
554
  return session
426
555
 
427
556
 
428
- @router.put("/sessions/{session_id}", response_model=Session, tags=["sessions"])
557
+ @router.put(
558
+ "/sessions/{session_id}",
559
+ response_model=Session,
560
+ tags=["sessions"],
561
+ responses={
562
+ 403: {"model": ErrorResponse, "description": "Access denied: not owner"},
563
+ 404: {"model": ErrorResponse, "description": "Session not found"},
564
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
565
+ },
566
+ )
429
567
  async def update_session(
430
568
  request: Request,
431
569
  session_id: str,
rem/api/routers/models.py CHANGED
@@ -15,6 +15,8 @@ from typing import Literal
15
15
  from fastapi import APIRouter, HTTPException
16
16
  from pydantic import BaseModel, Field
17
17
 
18
+ from .common import ErrorResponse
19
+
18
20
  from rem.agentic.llm_provider_models import (
19
21
  ModelInfo,
20
22
  AVAILABLE_MODELS,
@@ -57,7 +59,13 @@ async def list_models() -> ModelsResponse:
57
59
  return ModelsResponse(data=AVAILABLE_MODELS)
58
60
 
59
61
 
60
- @router.get("/models/{model_id:path}", response_model=ModelInfo)
62
+ @router.get(
63
+ "/models/{model_id:path}",
64
+ response_model=ModelInfo,
65
+ responses={
66
+ 404: {"model": ErrorResponse, "description": "Model not found"},
67
+ },
68
+ )
61
69
  async def get_model(model_id: str) -> ModelInfo:
62
70
  """
63
71
  Get information about a specific model.
rem/api/routers/query.py CHANGED
@@ -86,6 +86,8 @@ from fastapi import APIRouter, Header, HTTPException
86
86
  from loguru import logger
87
87
  from pydantic import BaseModel, Field
88
88
 
89
+ from .common import ErrorResponse
90
+
89
91
  from ...services.postgres import get_postgres_service
90
92
  from ...services.rem.service import RemService
91
93
  from ...services.rem.parser import RemQueryParser
@@ -213,7 +215,16 @@ class QueryResponse(BaseModel):
213
215
  )
214
216
 
215
217
 
216
- @router.post("/query", response_model=QueryResponse)
218
+ @router.post(
219
+ "/query",
220
+ response_model=QueryResponse,
221
+ responses={
222
+ 400: {"model": ErrorResponse, "description": "Invalid query or missing required fields"},
223
+ 500: {"model": ErrorResponse, "description": "Query execution failed"},
224
+ 501: {"model": ErrorResponse, "description": "Feature not yet implemented"},
225
+ 503: {"model": ErrorResponse, "description": "Database not configured or unavailable"},
226
+ },
227
+ )
217
228
  async def execute_query(
218
229
  request: QueryRequest,
219
230
  x_user_id: str | None = Header(default=None, description="User ID for query isolation (optional, uses default if not provided)"),
@@ -18,6 +18,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request
18
18
  from loguru import logger
19
19
  from pydantic import BaseModel, Field
20
20
 
21
+ from .common import ErrorResponse
22
+
21
23
  from ..deps import get_current_user, require_auth
22
24
  from ...models.entities import (
23
25
  Message,
@@ -83,6 +85,10 @@ class ShareSessionResponse(BaseModel):
83
85
  response_model=ShareSessionResponse,
84
86
  status_code=201,
85
87
  tags=["sessions"],
88
+ responses={
89
+ 400: {"model": ErrorResponse, "description": "Session already shared with this user"},
90
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
91
+ },
86
92
  )
87
93
  async def share_session(
88
94
  request: Request,
@@ -175,6 +181,10 @@ async def share_session(
175
181
  "/sessions/{session_id}/share/{shared_with_user_id}",
176
182
  status_code=200,
177
183
  tags=["sessions"],
184
+ responses={
185
+ 404: {"model": ErrorResponse, "description": "Share not found"},
186
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
187
+ },
178
188
  )
179
189
  async def remove_session_share(
180
190
  request: Request,
@@ -250,6 +260,9 @@ async def remove_session_share(
250
260
  "/sessions/shared-with-me",
251
261
  response_model=SharedWithMeResponse,
252
262
  tags=["sessions"],
263
+ responses={
264
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
265
+ },
253
266
  )
254
267
  async def get_shared_with_me(
255
268
  request: Request,
@@ -328,6 +341,9 @@ async def get_shared_with_me(
328
341
  "/sessions/shared-with-me/{owner_user_id}/messages",
329
342
  response_model=SharedMessagesResponse,
330
343
  tags=["sessions"],
344
+ responses={
345
+ 503: {"model": ErrorResponse, "description": "Database not enabled"},
346
+ },
331
347
  )
332
348
  async def get_shared_messages(
333
349
  request: Request,
rem/auth/jwt.py CHANGED
@@ -260,12 +260,16 @@ class JWTService:
260
260
  "tenant_id": payload.get("tenant_id", "default"),
261
261
  }
262
262
 
263
- def refresh_access_token(self, refresh_token: str) -> dict | None:
263
+ def refresh_access_token(
264
+ self, refresh_token: str, user_override: dict | None = None
265
+ ) -> dict | None:
264
266
  """
265
267
  Create new access token using refresh token.
266
268
 
267
269
  Args:
268
270
  refresh_token: Valid refresh token
271
+ user_override: Optional dict with user fields to override defaults
272
+ (e.g., role, roles, tier, name from database lookup)
269
273
 
270
274
  Returns:
271
275
  New token dict or None if refresh token is invalid
@@ -285,8 +289,7 @@ class JWTService:
285
289
  logger.debug("Refresh token expired")
286
290
  return None
287
291
 
288
- # Create new access token with minimal info from refresh token
289
- # In production, you'd look up the full user from database
292
+ # Build user dict with defaults
290
293
  user = {
291
294
  "id": payload.get("sub"),
292
295
  "email": payload.get("email"),
@@ -294,16 +297,28 @@ class JWTService:
294
297
  "provider": "email",
295
298
  "tenant_id": "default",
296
299
  "tier": "free",
300
+ "role": "user",
297
301
  "roles": ["user"],
298
302
  }
299
303
 
304
+ # Apply overrides from database lookup if provided
305
+ if user_override:
306
+ if user_override.get("role"):
307
+ user["role"] = user_override["role"]
308
+ if user_override.get("roles"):
309
+ user["roles"] = user_override["roles"]
310
+ if user_override.get("tier"):
311
+ user["tier"] = user_override["tier"]
312
+ if user_override.get("name"):
313
+ user["name"] = user_override["name"]
314
+
300
315
  # Only return new access token, keep same refresh token
301
316
  now = int(time.time())
302
317
  access_payload = {
303
318
  "sub": user["id"],
304
319
  "email": user["email"],
305
320
  "name": user["name"],
306
- "role": user.get("role"),
321
+ "role": user["role"],
307
322
  "tier": user["tier"],
308
323
  "roles": user["roles"],
309
324
  "provider": user["provider"],
rem/auth/middleware.py CHANGED
@@ -14,15 +14,14 @@ Design Pattern:
14
14
  - MCP paths always require authentication (protected service)
15
15
 
16
16
  Authentication Flow:
17
- 1. If API key enabled: Validate X-API-Key header (access gate)
18
- 2. Check JWT token for user identity (primary)
19
- 3. Check dev token for testing (non-production only)
20
- 4. Check session for user (backward compatibility)
21
- 5. If allow_anonymous=True: Allow as anonymous (rate-limited)
22
- 6. If allow_anonymous=False: Return 401 / redirect to login
17
+ 1. Check JWT/dev token/session for user identity first
18
+ 2. If user is admin: bypass API key check (admin privilege)
19
+ 3. If API key enabled and user is not admin: Validate X-API-Key header
20
+ 4. If allow_anonymous=True: Allow as anonymous (rate-limited)
21
+ 5. If allow_anonymous=False: Return 401 / redirect to login
23
22
 
24
23
  IMPORTANT: API key validates ACCESS, JWT identifies USER.
25
- Both can be required: API key for access + JWT for user identity.
24
+ Admin users bypass the API key requirement (trusted identity).
26
25
 
27
26
  Access Modes (configured in settings.auth):
28
27
  - enabled=true, allow_anonymous=true: Auth available, anonymous gets rate-limited access
@@ -195,6 +194,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
195
194
 
196
195
  return None
197
196
 
197
+ def _is_admin(self, user: dict | None) -> bool:
198
+ """Check if user has admin role."""
199
+ if not user:
200
+ return False
201
+ return "admin" in user.get("roles", [])
202
+
198
203
  async def dispatch(self, request: Request, call_next):
199
204
  """
200
205
  Check authentication for protected paths.
@@ -219,8 +224,35 @@ class AuthMiddleware(BaseHTTPMiddleware):
219
224
  if not is_protected or is_excluded:
220
225
  return await call_next(request)
221
226
 
222
- # API key validation (access control, not user identity)
223
- # API key is a guardrail for access - JWT identifies the actual user
227
+ # Check for user identity FIRST (JWT, dev token, session)
228
+ # This allows admin users to bypass API key requirement
229
+ user = None
230
+
231
+ # Check for JWT token in Authorization header (primary user identity)
232
+ jwt_user = self._check_jwt_token(request)
233
+ if jwt_user:
234
+ user = jwt_user
235
+
236
+ # Check for dev token (non-production only)
237
+ if not user:
238
+ dev_user = self._check_dev_token(request)
239
+ if dev_user:
240
+ user = dev_user
241
+
242
+ # Check for valid session (backward compatibility)
243
+ if not user:
244
+ session_user = request.session.get("user")
245
+ if session_user:
246
+ user = session_user
247
+
248
+ # If user is admin, bypass API key check entirely
249
+ if self._is_admin(user):
250
+ logger.debug(f"Admin user {user.get('email')} bypassing API key check")
251
+ request.state.user = user
252
+ request.state.is_anonymous = False
253
+ return await call_next(request)
254
+
255
+ # API key validation for non-admin users (access control guardrail)
224
256
  if settings.api.api_key_enabled:
225
257
  api_key = request.headers.get("x-api-key")
226
258
  if not api_key:
@@ -238,27 +270,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
238
270
  headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
239
271
  )
240
272
  logger.debug("X-API-Key validated for access")
241
- # API key valid - continue to check JWT for user identity
242
-
243
- # Check for JWT token in Authorization header (primary user identity)
244
- jwt_user = self._check_jwt_token(request)
245
- if jwt_user:
246
- request.state.user = jwt_user
247
- request.state.is_anonymous = False
248
- return await call_next(request)
249
-
250
- # Check for dev token (non-production only)
251
- dev_user = self._check_dev_token(request)
252
- if dev_user:
253
- request.state.user = dev_user
254
- request.state.is_anonymous = False
255
- return await call_next(request)
256
-
257
- # Check for valid session (backward compatibility)
258
- user = request.session.get("user")
259
273
 
274
+ # If we have a valid user (non-admin, but passed API key check), allow access
260
275
  if user:
261
- # Authenticated user - add to request state
262
276
  request.state.user = user
263
277
  request.state.is_anonymous = False
264
278
  return await call_next(request)