remdb 0.3.14__py3-none-any.whl → 0.3.133__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rem/agentic/README.md +76 -0
- rem/agentic/__init__.py +15 -0
- rem/agentic/agents/__init__.py +16 -2
- rem/agentic/agents/sse_simulator.py +502 -0
- rem/agentic/context.py +51 -27
- rem/agentic/llm_provider_models.py +301 -0
- rem/agentic/mcp/tool_wrapper.py +112 -17
- rem/agentic/otel/setup.py +93 -4
- rem/agentic/providers/phoenix.py +302 -109
- rem/agentic/providers/pydantic_ai.py +215 -26
- rem/agentic/schema.py +361 -21
- rem/agentic/tools/rem_tools.py +3 -3
- rem/api/README.md +215 -1
- rem/api/deps.py +255 -0
- rem/api/main.py +132 -40
- rem/api/mcp_router/resources.py +1 -1
- rem/api/mcp_router/server.py +26 -5
- rem/api/mcp_router/tools.py +465 -7
- rem/api/routers/admin.py +494 -0
- rem/api/routers/auth.py +70 -0
- rem/api/routers/chat/completions.py +402 -20
- rem/api/routers/chat/models.py +88 -10
- rem/api/routers/chat/otel_utils.py +33 -0
- rem/api/routers/chat/sse_events.py +542 -0
- rem/api/routers/chat/streaming.py +642 -45
- rem/api/routers/dev.py +81 -0
- rem/api/routers/feedback.py +268 -0
- rem/api/routers/messages.py +473 -0
- rem/api/routers/models.py +78 -0
- rem/api/routers/query.py +360 -0
- rem/api/routers/shared_sessions.py +406 -0
- rem/auth/middleware.py +126 -27
- rem/cli/commands/README.md +237 -64
- rem/cli/commands/cluster.py +1808 -0
- rem/cli/commands/configure.py +1 -3
- rem/cli/commands/db.py +386 -143
- rem/cli/commands/experiments.py +418 -27
- rem/cli/commands/process.py +14 -8
- rem/cli/commands/schema.py +97 -50
- rem/cli/main.py +27 -6
- rem/config.py +10 -3
- rem/models/core/core_model.py +7 -1
- rem/models/core/experiment.py +54 -0
- rem/models/core/rem_query.py +5 -2
- rem/models/entities/__init__.py +21 -0
- rem/models/entities/domain_resource.py +38 -0
- rem/models/entities/feedback.py +123 -0
- rem/models/entities/message.py +30 -1
- rem/models/entities/session.py +83 -0
- rem/models/entities/shared_session.py +180 -0
- rem/registry.py +10 -4
- rem/schemas/agents/rem.yaml +7 -3
- rem/services/content/service.py +92 -20
- rem/services/embeddings/api.py +4 -4
- rem/services/embeddings/worker.py +16 -16
- rem/services/phoenix/client.py +154 -14
- rem/services/postgres/README.md +159 -15
- rem/services/postgres/__init__.py +2 -1
- rem/services/postgres/diff_service.py +531 -0
- rem/services/postgres/pydantic_to_sqlalchemy.py +427 -129
- rem/services/postgres/repository.py +132 -0
- rem/services/postgres/schema_generator.py +205 -4
- rem/services/postgres/service.py +6 -6
- rem/services/rem/parser.py +44 -9
- rem/services/rem/service.py +36 -2
- rem/services/session/compression.py +24 -1
- rem/services/session/reload.py +1 -1
- rem/settings.py +324 -23
- rem/sql/background_indexes.sql +21 -16
- rem/sql/migrations/001_install.sql +387 -54
- rem/sql/migrations/002_install_models.sql +2320 -393
- rem/sql/migrations/003_optional_extensions.sql +326 -0
- rem/sql/migrations/004_cache_system.sql +548 -0
- rem/utils/__init__.py +18 -0
- rem/utils/date_utils.py +2 -2
- rem/utils/model_helpers.py +156 -1
- rem/utils/schema_loader.py +220 -22
- rem/utils/sql_paths.py +146 -0
- rem/utils/sql_types.py +3 -1
- rem/workers/__init__.py +3 -1
- rem/workers/db_listener.py +579 -0
- rem/workers/unlogged_maintainer.py +463 -0
- {remdb-0.3.14.dist-info → remdb-0.3.133.dist-info}/METADATA +335 -226
- {remdb-0.3.14.dist-info → remdb-0.3.133.dist-info}/RECORD +86 -66
- {remdb-0.3.14.dist-info → remdb-0.3.133.dist-info}/WHEEL +1 -1
- rem/sql/002_install_models.sql +0 -1068
- rem/sql/install_models.sql +0 -1051
- rem/sql/migrations/003_seed_default_user.sql +0 -48
- {remdb-0.3.14.dist-info → remdb-0.3.133.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Session sharing endpoints.
|
|
3
|
+
|
|
4
|
+
Enables session sharing between users for collaborative access to conversation history.
|
|
5
|
+
|
|
6
|
+
Endpoints:
|
|
7
|
+
POST /api/v1/sessions/{session_id}/share - Share a session with another user
|
|
8
|
+
DELETE /api/v1/sessions/{session_id}/share/{shared_with_user_id} - Revoke a share (soft delete)
|
|
9
|
+
GET /api/v1/sessions/shared-with-me - Get users sharing sessions with you
|
|
10
|
+
GET /api/v1/sessions/shared-with-me/{owner_user_id}/messages - Get messages from a user's shared sessions
|
|
11
|
+
|
|
12
|
+
See src/rem/models/entities/shared_session.py for full documentation.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import Literal
|
|
16
|
+
|
|
17
|
+
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request
|
|
18
|
+
from loguru import logger
|
|
19
|
+
from pydantic import BaseModel, Field
|
|
20
|
+
|
|
21
|
+
from ..deps import get_current_user, require_auth
|
|
22
|
+
from ...models.entities import (
|
|
23
|
+
Message,
|
|
24
|
+
SharedSession,
|
|
25
|
+
SharedSessionCreate,
|
|
26
|
+
SharedWithMeResponse,
|
|
27
|
+
SharedWithMeSummary,
|
|
28
|
+
)
|
|
29
|
+
from ...services.postgres import get_postgres_service
|
|
30
|
+
from ...settings import settings
|
|
31
|
+
from ...utils.date_utils import utc_now
|
|
32
|
+
|
|
33
|
+
# Every route declared below is mounted beneath the versioned "/api/v1" prefix.
router: APIRouter = APIRouter(prefix="/api/v1")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def get_connected_postgres():
    """
    Return the PostgresService, opening its connection pool on first use.

    Every endpoint in this module calls ``pg.fetchrow`` / ``pg.fetch`` on the
    returned object right away. If ``get_postgres_service()`` yields nothing,
    returning it would turn into an ``AttributeError`` (an opaque 500), so an
    explicit 503 is raised instead. This mirrors the "Database not enabled"
    responses the endpoints already produce.

    Returns:
        A PostgresService whose ``pool`` has been connected.

    Raises:
        HTTPException: 503 when no PostgresService is available.
    """
    pg = get_postgres_service()
    if not pg:
        raise HTTPException(status_code=503, detail="Database not available")
    if not pg.pool:
        # Lazily connect the pool the first time a request needs it.
        await pg.connect()
    return pg
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# =============================================================================
|
|
45
|
+
# Request/Response Models
|
|
46
|
+
# =============================================================================
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class PaginationMetadata(BaseModel):
    """
    Page-position details returned alongside a paginated list.

    Every field is required; callers derive them from the total row count,
    the requested page number, and the page size.
    """

    # Count of all records matching the query, across every page.
    total: int = Field(..., description="Total number of records matching filters")
    # 1-based index of the page being returned.
    page: int = Field(..., description="Current page number (1-indexed)")
    page_size: int = Field(..., description="Number of records per page")
    total_pages: int = Field(..., description="Total number of pages")
    # Navigation flags derived from ``page`` relative to ``total_pages``.
    has_next: bool = Field(..., description="Whether there are more pages after this one")
    has_previous: bool = Field(..., description="Whether there are pages before this one")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SharedMessagesResponse(BaseModel):
    """
    Envelope for one page of messages taken from sessions shared with the caller.
    """

    # Fixed type tag; always the string "list".
    object: Literal["list"] = Field(default="list")
    data: list[Message] = Field(..., description="List of messages from shared sessions")
    metadata: PaginationMetadata = Field(..., description="Pagination metadata")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ShareSessionResponse(BaseModel):
    """
    Body returned by ``POST /sessions/{session_id}/share`` when a share is created.
    """

    success: bool = Field(default=True)
    # Human-readable confirmation text.
    message: str = Field(...)
    # The SharedSession row that was just inserted.
    share: SharedSession = Field(...)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# =============================================================================
|
|
77
|
+
# Share Session Endpoints
|
|
78
|
+
# =============================================================================
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@router.post(
    "/sessions/{session_id}/share",
    response_model=ShareSessionResponse,
    status_code=201,
    tags=["sessions"],
)
async def share_session(
    request: Request,
    session_id: str,
    body: SharedSessionCreate,
    user: dict = Depends(require_auth),
    x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
) -> ShareSessionResponse:
    """
    Share a session with another user.

    Creates a SharedSession record that grants the recipient access to view
    messages in this session. The authenticated caller is recorded as
    ``owner_user_id``.

    Args:
        request: Incoming request (unused; kept for interface compatibility)
        session_id: The session to share
        body: Contains shared_with_user_id - the recipient
        user: Authenticated user supplied by ``require_auth``
        x_tenant_id: Tenant scope from the ``X-Tenant-Id`` header

    Returns:
        The created SharedSession record

    Raises:
        400: Attempt to share a session with yourself, or session already
             shared with this user
        503: Database not enabled
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    current_user_id = user.get("id", "default")

    # A self-share would name the caller as both owner and recipient. It
    # grants nothing new, so reject it before touching the database.
    if body.shared_with_user_id == current_user_id:
        raise HTTPException(
            status_code=400,
            detail="Cannot share a session with yourself",
        )

    pg = await get_connected_postgres()

    # NOTE(review): the caller's ownership of ``session_id`` is not verified
    # here. Any authenticated user can create a row naming themselves as owner
    # of an arbitrary session_id. Confirm that the downstream SQL functions
    # enforce ownership.

    # Check if an active share already exists. Soft-deleted rows are ignored,
    # so a previously revoked share can be re-created.
    # NOTE(review): this check-then-insert is not atomic. Concurrent requests
    # could both pass the check unless shared_sessions has a unique constraint
    # on these columns (not visible here) — confirm.
    existing = await pg.fetchrow(
        """
        SELECT id FROM shared_sessions
        WHERE tenant_id = $1
          AND session_id = $2
          AND owner_user_id = $3
          AND shared_with_user_id = $4
          AND deleted_at IS NULL
        """,
        x_tenant_id,
        session_id,
        current_user_id,
        body.shared_with_user_id,
    )

    if existing:
        raise HTTPException(
            status_code=400,
            detail=f"Session '{session_id}' is already shared with user '{body.shared_with_user_id}'",
        )

    # Create the share
    result = await pg.fetchrow(
        """
        INSERT INTO shared_sessions (session_id, owner_user_id, shared_with_user_id, tenant_id)
        VALUES ($1, $2, $3, $4)
        RETURNING id, session_id, owner_user_id, shared_with_user_id, tenant_id, created_at, updated_at, deleted_at
        """,
        session_id,
        current_user_id,
        body.shared_with_user_id,
        x_tenant_id,
    )

    share = SharedSession(
        id=result["id"],
        session_id=result["session_id"],
        owner_user_id=result["owner_user_id"],
        shared_with_user_id=result["shared_with_user_id"],
        tenant_id=result["tenant_id"],
        created_at=result["created_at"],
        updated_at=result["updated_at"],
        deleted_at=result["deleted_at"],
    )

    logger.debug(
        f"User {current_user_id} shared session '{session_id}' with {body.shared_with_user_id}"
    )

    return ShareSessionResponse(
        success=True,
        message=f"Session shared with {body.shared_with_user_id}",
        share=share,
    )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@router.delete(
    "/sessions/{session_id}/share/{shared_with_user_id}",
    status_code=200,
    tags=["sessions"],
)
async def remove_session_share(
    request: Request,
    session_id: str,
    shared_with_user_id: str,
    user: dict = Depends(require_auth),
    x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
) -> dict:
    """
    Revoke another user's access to one of the caller's sessions.

    The share row is kept. Its ``deleted_at`` and ``updated_at`` columns are
    stamped with the current time, and a new share for the same pair can be
    created afterwards.

    Args:
        request: Incoming request (not used by the handler)
        session_id: Session whose share is being revoked
        shared_with_user_id: Recipient who is losing access
        user: Authenticated caller; only shares they own are matched
        x_tenant_id: Tenant scope taken from the ``X-Tenant-Id`` header

    Returns:
        A dict with ``success`` and a human-readable ``message``

    Raises:
        404: No active share matches (tenant, session, owner, recipient)
        503: Database not enabled
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    owner_id = user.get("id", "default")
    pg = await get_connected_postgres()

    # Only still-active rows (deleted_at IS NULL) are eligible; RETURNING id
    # lets us tell "revoked" apart from "nothing matched".
    soft_delete_sql = """
        UPDATE shared_sessions
        SET deleted_at = $1, updated_at = $1
        WHERE tenant_id = $2
          AND session_id = $3
          AND owner_user_id = $4
          AND shared_with_user_id = $5
          AND deleted_at IS NULL
        RETURNING id
        """
    revoked = await pg.fetchrow(
        soft_delete_sql,
        utc_now(),
        x_tenant_id,
        session_id,
        owner_id,
        shared_with_user_id,
    )

    if not revoked:
        raise HTTPException(
            status_code=404,
            detail=f"Share not found for session '{session_id}' with user '{shared_with_user_id}'",
        )

    logger.debug(
        f"User {owner_id} removed share for session '{session_id}' with {shared_with_user_id}"
    )

    return dict(
        success=True,
        message=f"Share removed for user {shared_with_user_id}",
    )
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# =============================================================================
|
|
245
|
+
# Shared With Me Endpoints
|
|
246
|
+
# =============================================================================
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@router.get(
    "/sessions/shared-with-me",
    response_model=SharedWithMeResponse,
    tags=["sessions"],
)
async def get_shared_with_me(
    request: Request,
    page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
    page_size: int = Query(default=50, ge=1, le=100, description="Results per page"),
    user: dict = Depends(require_auth),
    x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
) -> SharedWithMeResponse:
    """
    List, one page at a time, the users who have shared sessions with the caller.

    Each summary carries the sharer's user_id, name and email, plus
    message_count, session_count, first_message_at and last_message_at.
    Both the count and the page rows come from the SQL functions
    ``fn_count_shared_with_me`` and ``fn_get_shared_with_me``. Per this
    endpoint's contract, rows are ordered by most recent message first.

    Args:
        request: Incoming request (not used by the handler)
        page: 1-based page number
        page_size: Rows per page (1-100)
        user: Authenticated caller whose incoming shares are listed
        x_tenant_id: Tenant scope taken from the ``X-Tenant-Id`` header

    Raises:
        503: Database not enabled
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    viewer_id = user.get("id", "default")
    pg = await get_connected_postgres()
    skip = (page - 1) * page_size

    # Overall number of sharers; drives the pagination metadata.
    count_row = await pg.fetchrow(
        "SELECT fn_count_shared_with_me($1, $2) as total",
        x_tenant_id,
        viewer_id,
    )
    total = 0 if not count_row else count_row["total"]

    # The requested slice of sharer summaries.
    rows = await pg.fetch(
        "SELECT * FROM fn_get_shared_with_me($1, $2, $3, $4)",
        x_tenant_id,
        viewer_id,
        page_size,
        skip,
    )

    summary_fields = (
        "user_id",
        "name",
        "email",
        "message_count",
        "session_count",
        "first_message_at",
        "last_message_at",
    )
    summaries = [
        SharedWithMeSummary(**{field: row[field] for field in summary_fields})
        for row in rows
    ]

    # An empty result still reports a single (empty) page.
    total_pages = 1
    if total > 0:
        total_pages = -(-total // page_size)  # ceiling division

    return SharedWithMeResponse(
        data=summaries,
        metadata=dict(
            total=total,
            page=page,
            page_size=page_size,
            total_pages=total_pages,
            has_next=page < total_pages,
            has_previous=page > 1,
        ),
    )
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
@router.get(
    "/sessions/shared-with-me/{owner_user_id}/messages",
    response_model=SharedMessagesResponse,
    tags=["sessions"],
)
async def get_shared_messages(
    request: Request,
    owner_user_id: str,
    page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
    page_size: int = Query(default=50, ge=1, le=100, description="Results per page"),
    user: dict = Depends(require_auth),
    x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
) -> SharedMessagesResponse:
    """
    Page through the messages of every session ``owner_user_id`` shared with the caller.

    Counting and fetching are delegated to the SQL functions
    ``fn_count_shared_messages`` and ``fn_get_shared_messages``. Per this
    endpoint's contract, messages are ordered by created_at descending.

    Args:
        request: Incoming request (not used by the handler)
        owner_user_id: The user who shared the sessions
        page: 1-based page number
        page_size: Rows per page (1-100)
        user: Authenticated caller (the share recipient)
        x_tenant_id: Tenant scope taken from the ``X-Tenant-Id`` header

    Returns:
        A SharedMessagesResponse holding Message objects and pagination metadata

    Raises:
        503: Database not enabled
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    viewer_id = user.get("id", "default")
    pg = await get_connected_postgres()
    skip = (page - 1) * page_size

    # Overall number of visible messages; drives the pagination metadata.
    count_row = await pg.fetchrow(
        "SELECT fn_count_shared_messages($1, $2, $3) as total",
        x_tenant_id,
        viewer_id,
        owner_user_id,
    )
    total = 0 if not count_row else count_row["total"]

    # The requested slice of message rows.
    rows = await pg.fetch(
        "SELECT * FROM fn_get_shared_messages($1, $2, $3, $4, $5)",
        x_tenant_id,
        viewer_id,
        owner_user_id,
        page_size,
        skip,
    )

    def _row_to_message(row) -> Message:
        # tenant_id is taken from the request header, not from the row.
        # NOTE(review): assumes the driver decodes the metadata column into a
        # dict; if it returns raw JSON text, Message validation may reject it.
        return Message(
            id=row["id"],
            content=row["content"],
            message_type=row["message_type"],
            session_id=row["session_id"],
            model=row["model"],
            token_count=row["token_count"],
            created_at=row["created_at"],
            metadata=row["metadata"] or {},
            tenant_id=x_tenant_id,
        )

    messages = list(map(_row_to_message, rows))

    # An empty result still reports a single (empty) page.
    total_pages = 1
    if total > 0:
        total_pages = -(-total // page_size)  # ceiling division

    page_info = PaginationMetadata(
        total=total,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
        has_next=page < total_pages,
        has_previous=page > 1,
    )
    return SharedMessagesResponse(data=messages, metadata=page_info)
|
rem/auth/middleware.py
CHANGED
|
@@ -2,14 +2,29 @@
|
|
|
2
2
|
OAuth Authentication Middleware for FastAPI.
|
|
3
3
|
|
|
4
4
|
Protects API endpoints by requiring valid session.
|
|
5
|
-
|
|
5
|
+
Supports anonymous access with rate limiting when allow_anonymous=True.
|
|
6
|
+
MCP endpoints are always protected unless explicitly disabled.
|
|
6
7
|
|
|
7
8
|
Design Pattern:
|
|
8
9
|
- Check session for user on protected paths
|
|
9
|
-
-
|
|
10
|
-
-
|
|
10
|
+
- Check Bearer token for dev token (non-production only)
|
|
11
|
+
- MCP paths always require authentication (protected service)
|
|
12
|
+
- If allow_anonymous=True: Allow unauthenticated requests (marked as ANONYMOUS tier)
|
|
13
|
+
- If allow_anonymous=False: Return 401 for API calls, redirect browsers to login
|
|
11
14
|
- Exclude auth endpoints and public paths
|
|
12
15
|
|
|
16
|
+
Access Modes (configured in settings.auth):
|
|
17
|
+
- enabled=true, allow_anonymous=true: Auth available, anonymous gets rate-limited access
|
|
18
|
+
- enabled=true, allow_anonymous=false: Auth required for all requests
|
|
19
|
+
- enabled=false: Middleware not loaded, all requests pass through
|
|
20
|
+
- mcp_requires_auth=true (default): MCP always requires login regardless of allow_anonymous
|
|
21
|
+
- mcp_requires_auth=false: MCP follows normal allow_anonymous rules (dev only)
|
|
22
|
+
|
|
23
|
+
Dev Token Support (non-production only):
|
|
24
|
+
- GET /api/auth/dev/token returns a Bearer token for test-user
|
|
25
|
+
- Include as: Authorization: Bearer dev_<signature>
|
|
26
|
+
- Only works when ENVIRONMENT != "production"
|
|
27
|
+
|
|
13
28
|
Usage:
|
|
14
29
|
from rem.auth.middleware import AuthMiddleware
|
|
15
30
|
|
|
@@ -17,6 +32,8 @@ Usage:
|
|
|
17
32
|
AuthMiddleware,
|
|
18
33
|
protected_paths=["/api/v1"],
|
|
19
34
|
excluded_paths=["/api/auth", "/health"],
|
|
35
|
+
allow_anonymous=settings.auth.allow_anonymous,
|
|
36
|
+
mcp_requires_auth=settings.auth.mcp_requires_auth,
|
|
20
37
|
)
|
|
21
38
|
"""
|
|
22
39
|
|
|
@@ -25,6 +42,8 @@ from starlette.requests import Request
|
|
|
25
42
|
from starlette.responses import JSONResponse, RedirectResponse
|
|
26
43
|
from loguru import logger
|
|
27
44
|
|
|
45
|
+
from ..settings import settings
|
|
46
|
+
|
|
28
47
|
|
|
29
48
|
class AuthMiddleware(BaseHTTPMiddleware):
|
|
30
49
|
"""
|
|
@@ -32,6 +51,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
32
51
|
|
|
33
52
|
Checks for valid user session on protected paths.
|
|
34
53
|
Compatible with OAuth flows from auth router.
|
|
54
|
+
Supports anonymous access with rate limiting.
|
|
55
|
+
MCP endpoints are always protected unless explicitly disabled.
|
|
35
56
|
"""
|
|
36
57
|
|
|
37
58
|
def __init__(
|
|
@@ -39,6 +60,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
39
60
|
app,
|
|
40
61
|
protected_paths: list[str] | None = None,
|
|
41
62
|
excluded_paths: list[str] | None = None,
|
|
63
|
+
allow_anonymous: bool = True,
|
|
64
|
+
mcp_requires_auth: bool = True,
|
|
65
|
+
mcp_path: str = "/api/v1/mcp",
|
|
42
66
|
):
|
|
43
67
|
"""
|
|
44
68
|
Initialize auth middleware.
|
|
@@ -47,10 +71,52 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
47
71
|
app: ASGI application
|
|
48
72
|
protected_paths: Paths that require authentication
|
|
49
73
|
excluded_paths: Paths to exclude from auth check
|
|
74
|
+
allow_anonymous: Allow unauthenticated requests (rate-limited)
|
|
75
|
+
mcp_requires_auth: Always require auth for MCP (protected service)
|
|
76
|
+
mcp_path: Path prefix for MCP endpoints
|
|
50
77
|
"""
|
|
51
78
|
super().__init__(app)
|
|
52
79
|
self.protected_paths = protected_paths or ["/api/v1"]
|
|
53
80
|
self.excluded_paths = excluded_paths or ["/api/auth", "/health", "/docs", "/openapi.json"]
|
|
81
|
+
self.allow_anonymous = allow_anonymous
|
|
82
|
+
self.mcp_requires_auth = mcp_requires_auth
|
|
83
|
+
self.mcp_path = mcp_path
|
|
84
|
+
|
|
85
|
+
def _check_dev_token(self, request: Request) -> dict | None:
    """
    Authenticate a ``dev_``-prefixed Bearer token (non-production only).

    The Authorization header is parsed into ``<scheme> <token>``. The scheme
    is compared case-insensitively, as RFC 7235 requires for auth-schemes.
    Only tokens starting with ``dev_`` are considered, and those are checked
    with ``verify_dev_token``.

    Args:
        request: Incoming Starlette request

    Returns:
        A fixed test-user dict (pro tier, admin role) when the token verifies;
        None when running in production, when the header is missing, when the
        scheme is not Bearer, or when the token is not a valid dev token.
    """
    if settings.environment == "production":
        return None

    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    # Auth-scheme names are case-insensitive ("Bearer", "bearer", "BEARER").
    if scheme.lower() != "bearer":
        return None
    token = token.strip()

    # Only dev tokens are handled here; other Bearer tokens fall through.
    if not token.startswith("dev_"):
        return None

    # Imported lazily, presumably to avoid an import cycle with the API package.
    from ..api.routers.dev import verify_dev_token

    if not verify_dev_token(token):
        return None

    logger.debug("Dev token authenticated as test-user")
    return {
        "id": "test-user",
        "email": "test@rem.local",
        "name": "Test User",
        "provider": "dev",
        "tenant_id": "default",
        "tier": "pro",  # Give test user pro tier for full access
        "roles": ["admin"],
    }
|
|
54
120
|
|
|
55
121
|
async def dispatch(self, request: Request, call_next):
|
|
56
122
|
"""
|
|
@@ -61,7 +127,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
61
127
|
call_next: Next middleware in chain
|
|
62
128
|
|
|
63
129
|
Returns:
|
|
64
|
-
Response (401/redirect if unauthorized, normal response if authorized)
|
|
130
|
+
Response (401/redirect if unauthorized, normal response if authorized/anonymous)
|
|
65
131
|
"""
|
|
66
132
|
path = request.url.path
|
|
67
133
|
|
|
@@ -69,32 +135,65 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
69
135
|
is_protected = any(path.startswith(p) for p in self.protected_paths)
|
|
70
136
|
is_excluded = any(path.startswith(p) for p in self.excluded_paths)
|
|
71
137
|
|
|
138
|
+
# Check if this is an MCP path (paid service, always requires auth)
|
|
139
|
+
is_mcp_path = path.startswith(self.mcp_path)
|
|
140
|
+
|
|
72
141
|
# Skip auth check for excluded paths
|
|
73
142
|
if not is_protected or is_excluded:
|
|
74
143
|
return await call_next(request)
|
|
75
144
|
|
|
145
|
+
# Check for dev token (non-production only)
|
|
146
|
+
dev_user = self._check_dev_token(request)
|
|
147
|
+
if dev_user:
|
|
148
|
+
request.state.user = dev_user
|
|
149
|
+
request.state.is_anonymous = False
|
|
150
|
+
return await call_next(request)
|
|
151
|
+
|
|
76
152
|
# Check for valid session
|
|
77
153
|
user = request.session.get("user")
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
154
|
+
|
|
155
|
+
if user:
|
|
156
|
+
# Authenticated user - add to request state
|
|
157
|
+
request.state.user = user
|
|
158
|
+
request.state.is_anonymous = False
|
|
159
|
+
return await call_next(request)
|
|
160
|
+
|
|
161
|
+
# No user session - check if MCP path requires auth
|
|
162
|
+
if is_mcp_path and self.mcp_requires_auth:
|
|
163
|
+
# MCP is a protected service - always require authentication
|
|
164
|
+
logger.warning(f"Unauthorized MCP access attempt: {path}")
|
|
165
|
+
return JSONResponse(
|
|
166
|
+
status_code=401,
|
|
167
|
+
content={
|
|
168
|
+
"detail": "Authentication required for MCP. Please login to use this service.",
|
|
169
|
+
"code": "MCP_AUTH_REQUIRED",
|
|
170
|
+
},
|
|
171
|
+
headers={
|
|
172
|
+
"WWW-Authenticate": 'Bearer realm="REM MCP"',
|
|
173
|
+
},
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
# No user session - handle anonymous access for non-MCP paths
|
|
177
|
+
if self.allow_anonymous:
|
|
178
|
+
# Allow anonymous access - rate limiting handled downstream
|
|
179
|
+
request.state.user = None
|
|
180
|
+
request.state.is_anonymous = True
|
|
181
|
+
logger.debug(f"Anonymous access: {path}")
|
|
182
|
+
return await call_next(request)
|
|
183
|
+
|
|
184
|
+
# Anonymous not allowed - require authentication
|
|
185
|
+
logger.warning(f"Unauthorized access attempt: {path}")
|
|
186
|
+
|
|
187
|
+
# Return 401 for API requests (JSON)
|
|
188
|
+
accept = request.headers.get("accept", "")
|
|
189
|
+
if "application/json" in accept or path.startswith("/api/"):
|
|
190
|
+
return JSONResponse(
|
|
191
|
+
status_code=401,
|
|
192
|
+
content={"detail": "Authentication required"},
|
|
193
|
+
headers={
|
|
194
|
+
"WWW-Authenticate": 'Bearer realm="REM API"',
|
|
195
|
+
},
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
# Redirect to login for browser requests
|
|
199
|
+
return RedirectResponse(url="/api/auth/google/login", status_code=302)
|