remdb 0.3.7__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rem/__init__.py +129 -2
- rem/agentic/context.py +7 -5
- rem/agentic/providers/phoenix.py +32 -43
- rem/api/README.md +23 -0
- rem/api/main.py +27 -2
- rem/api/middleware/tracking.py +172 -0
- rem/api/routers/auth.py +54 -0
- rem/api/routers/chat/completions.py +1 -1
- rem/cli/commands/ask.py +13 -10
- rem/cli/commands/configure.py +4 -3
- rem/cli/commands/db.py +17 -3
- rem/cli/commands/experiments.py +76 -72
- rem/cli/commands/process.py +8 -7
- rem/cli/commands/scaffold.py +47 -0
- rem/cli/main.py +2 -0
- rem/models/entities/user.py +10 -3
- rem/registry.py +367 -0
- rem/services/content/providers.py +92 -133
- rem/services/dreaming/affinity_service.py +2 -16
- rem/services/dreaming/moment_service.py +2 -15
- rem/services/embeddings/api.py +20 -13
- rem/services/phoenix/EXPERIMENT_DESIGN.md +3 -3
- rem/services/phoenix/client.py +148 -14
- rem/services/postgres/schema_generator.py +86 -5
- rem/services/rate_limit.py +113 -0
- rem/services/rem/README.md +14 -0
- rem/services/user_service.py +98 -0
- rem/settings.py +79 -10
- rem/sql/install_models.sql +13 -0
- rem/sql/migrations/003_seed_default_user.sql +48 -0
- rem/utils/constants.py +97 -0
- rem/utils/date_utils.py +228 -0
- rem/utils/embeddings.py +17 -4
- rem/utils/files.py +167 -0
- rem/utils/mime_types.py +158 -0
- rem/utils/schema_loader.py +63 -14
- rem/utils/vision.py +9 -14
- rem/workers/README.md +14 -14
- rem/workers/db_maintainer.py +74 -0
- {remdb-0.3.7.dist-info → remdb-0.3.14.dist-info}/METADATA +169 -121
- {remdb-0.3.7.dist-info → remdb-0.3.14.dist-info}/RECORD +43 -32
- {remdb-0.3.7.dist-info → remdb-0.3.14.dist-info}/WHEEL +0 -0
- {remdb-0.3.7.dist-info → remdb-0.3.14.dist-info}/entry_points.txt +0 -0
rem/__init__.py
CHANGED
|
@@ -1,2 +1,129 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""
|
|
2
|
+
REM - Resources, Entities, Moments.
|
|
3
|
+
|
|
4
|
+
A bio-inspired memory system for agentic AI, built on FastAPI.
|
|
5
|
+
|
|
6
|
+
Usage (API mode):
|
|
7
|
+
from rem import create_app
|
|
8
|
+
|
|
9
|
+
# Create REM app (FastAPI with MCP server pre-configured)
|
|
10
|
+
app = create_app()
|
|
11
|
+
|
|
12
|
+
# Extend like any FastAPI app
|
|
13
|
+
@app.get("/my-endpoint")
|
|
14
|
+
async def my_endpoint():
|
|
15
|
+
return {"custom": True}
|
|
16
|
+
|
|
17
|
+
# Add routers
|
|
18
|
+
app.include_router(my_router)
|
|
19
|
+
|
|
20
|
+
# Access MCP server directly (FastMCP instance)
|
|
21
|
+
@app.mcp_server.tool()
|
|
22
|
+
async def my_custom_tool(query: str) -> dict:
|
|
23
|
+
'''Custom MCP tool for my application.'''
|
|
24
|
+
return {"result": "..."}
|
|
25
|
+
|
|
26
|
+
@app.mcp_server.resource("custom://config")
|
|
27
|
+
async def get_config() -> str:
|
|
28
|
+
'''Custom resource.'''
|
|
29
|
+
return '{"setting": "value"}'
|
|
30
|
+
|
|
31
|
+
Usage (model registration - works with or without API):
|
|
32
|
+
import rem
|
|
33
|
+
from rem.models.core import CoreModel
|
|
34
|
+
|
|
35
|
+
@rem.register_model
|
|
36
|
+
class CustomEntity(CoreModel):
|
|
37
|
+
name: str
|
|
38
|
+
custom_field: str
|
|
39
|
+
|
|
40
|
+
# Or register multiple:
|
|
41
|
+
rem.register_models(ModelA, ModelB)
|
|
42
|
+
|
|
43
|
+
# Then schema generation includes your models:
|
|
44
|
+
# rem db schema generate
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
from .registry import (
|
|
48
|
+
# Model registration
|
|
49
|
+
register_model,
|
|
50
|
+
register_models,
|
|
51
|
+
get_model_registry,
|
|
52
|
+
clear_model_registry,
|
|
53
|
+
# Schema path registration
|
|
54
|
+
register_schema_path,
|
|
55
|
+
register_schema_paths,
|
|
56
|
+
get_schema_paths,
|
|
57
|
+
get_schema_path_registry,
|
|
58
|
+
clear_schema_path_registry,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def create_app():
    """
    Build a new FastAPI application with the REM features already wired in.

    According to the package documentation the returned app provides:
    - the MCP server, mounted at /api/v1/mcp
    - the chat completions endpoint at /api/v1/chat/completions
    - a health check at /health
    - OpenAPI docs at /docs

    The FastMCP instance is reachable as `app.mcp_server`, so callers can
    register their own tools, resources, and prompts on it.

    Returns:
        FastAPI application with .mcp_server attribute

    Example:
        from rem import create_app

        app = create_app()

        # Add custom endpoint
        @app.get("/custom")
        async def custom():
            return {"custom": True}

        # Add custom MCP tool
        @app.mcp_server.tool()
        async def my_tool(query: str) -> dict:
            return {"result": query}
    """
    # Imported inside the function so that `import rem` (e.g. for model
    # registration only) does not pull in the API stack.
    from .api.main import create_app as build_rem_app

    application = build_rem_app()
    return application
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Lazy app instance - created on first access
|
|
98
|
+
_app = None
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_app():
    """
    Return the shared default REM app, building it on the first call.

    The instance is cached in the module-level `_app`; later calls return
    that same object. Prefer create_app() when a fresh, independent
    instance is needed - this accessor exists for convenience in simple
    scripts.
    """
    global _app
    if _app is not None:
        # Already built by an earlier call.
        return _app
    _app = create_app()
    return _app
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
__all__ = [
|
|
115
|
+
# App creation
|
|
116
|
+
"create_app",
|
|
117
|
+
"get_app",
|
|
118
|
+
# Model registration
|
|
119
|
+
"register_model",
|
|
120
|
+
"register_models",
|
|
121
|
+
"get_model_registry",
|
|
122
|
+
"clear_model_registry",
|
|
123
|
+
# Schema path registration
|
|
124
|
+
"register_schema_path",
|
|
125
|
+
"register_schema_paths",
|
|
126
|
+
"get_schema_paths",
|
|
127
|
+
"get_schema_path_registry",
|
|
128
|
+
"clear_schema_path_registry",
|
|
129
|
+
]
|
rem/agentic/context.py
CHANGED
|
@@ -72,7 +72,7 @@ class AgentContext(BaseModel):
|
|
|
72
72
|
def get_user_id_or_default(
|
|
73
73
|
user_id: str | None,
|
|
74
74
|
source: str = "context",
|
|
75
|
-
default: str =
|
|
75
|
+
default: str | None = None,
|
|
76
76
|
) -> str:
|
|
77
77
|
"""
|
|
78
78
|
Get user_id or fallback to default with logging.
|
|
@@ -83,10 +83,10 @@ class AgentContext(BaseModel):
|
|
|
83
83
|
Args:
|
|
84
84
|
user_id: User identifier (may be None)
|
|
85
85
|
source: Source of the call (for logging clarity)
|
|
86
|
-
default: Default value to use (default:
|
|
86
|
+
default: Default value to use (default: settings.test.effective_user_id)
|
|
87
87
|
|
|
88
88
|
Returns:
|
|
89
|
-
user_id if provided, otherwise default
|
|
89
|
+
user_id if provided, otherwise default from settings
|
|
90
90
|
|
|
91
91
|
Example:
|
|
92
92
|
# In MCP tool
|
|
@@ -105,8 +105,10 @@ class AgentContext(BaseModel):
|
|
|
105
105
|
)
|
|
106
106
|
"""
|
|
107
107
|
if user_id is None:
|
|
108
|
-
|
|
109
|
-
|
|
108
|
+
from rem.settings import settings
|
|
109
|
+
effective_default = default or settings.test.effective_user_id
|
|
110
|
+
logger.debug(f"No user_id provided from {source}, using '{effective_default}'")
|
|
111
|
+
return effective_default
|
|
110
112
|
return user_id
|
|
111
113
|
|
|
112
114
|
@classmethod
|
rem/agentic/providers/phoenix.py
CHANGED
|
@@ -128,15 +128,16 @@ def sanitize_tool_name(tool_name: str) -> str:
|
|
|
128
128
|
|
|
129
129
|
|
|
130
130
|
def load_evaluator_schema(evaluator_name: str) -> dict[str, Any]:
|
|
131
|
-
"""Load evaluator schema
|
|
131
|
+
"""Load evaluator schema using centralized schema loader.
|
|
132
132
|
|
|
133
|
-
|
|
134
|
-
|
|
133
|
+
Uses the same unified search logic as agent schemas:
|
|
134
|
+
- "hello-world/default" → schemas/evaluators/hello-world/default.yaml
|
|
135
|
+
- "lookup-correctness" → schemas/evaluators/rem/lookup-correctness.yaml
|
|
136
|
+
- "rem-lookup-correctness" → schemas/evaluators/rem/lookup-correctness.yaml
|
|
135
137
|
|
|
136
138
|
Args:
|
|
137
|
-
evaluator_name: Evaluator name
|
|
138
|
-
e.g., "
|
|
139
|
-
"rem-lookup-correctness.yaml"
|
|
139
|
+
evaluator_name: Evaluator name or path
|
|
140
|
+
e.g., "hello-world/default", "lookup-correctness"
|
|
140
141
|
|
|
141
142
|
Returns:
|
|
142
143
|
Evaluator schema dictionary with keys:
|
|
@@ -150,43 +151,13 @@ def load_evaluator_schema(evaluator_name: str) -> dict[str, Any]:
|
|
|
150
151
|
FileNotFoundError: If evaluator schema not found
|
|
151
152
|
|
|
152
153
|
Example:
|
|
153
|
-
>>> schema = load_evaluator_schema("
|
|
154
|
+
>>> schema = load_evaluator_schema("hello-world/default")
|
|
154
155
|
>>> print(schema["description"])
|
|
155
156
|
"""
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
#
|
|
159
|
-
|
|
160
|
-
rem_module_dir = Path(rem.__file__).parent # rem/src/rem
|
|
161
|
-
rem_package_root = rem_module_dir.parent.parent # rem/src/rem -> rem/src -> rem
|
|
162
|
-
schema_dir = rem_package_root / "schemas" / "evaluators"
|
|
163
|
-
|
|
164
|
-
# Try .yaml first (preferred format)
|
|
165
|
-
yaml_path = schema_dir / f"{evaluator_name}.yaml"
|
|
166
|
-
if yaml_path.exists():
|
|
167
|
-
logger.debug(f"Loading evaluator schema from {yaml_path}")
|
|
168
|
-
with open(yaml_path) as f:
|
|
169
|
-
return yaml.safe_load(f)
|
|
170
|
-
|
|
171
|
-
# Try .yml
|
|
172
|
-
yml_path = schema_dir / f"{evaluator_name}.yml"
|
|
173
|
-
if yml_path.exists():
|
|
174
|
-
logger.debug(f"Loading evaluator schema from {yml_path}")
|
|
175
|
-
with open(yml_path) as f:
|
|
176
|
-
return yaml.safe_load(f)
|
|
177
|
-
|
|
178
|
-
# Try .json
|
|
179
|
-
json_path = schema_dir / f"{evaluator_name}.json"
|
|
180
|
-
if json_path.exists():
|
|
181
|
-
logger.debug(f"Loading evaluator schema from {json_path}")
|
|
182
|
-
with open(json_path) as f:
|
|
183
|
-
return json.load(f)
|
|
184
|
-
|
|
185
|
-
raise FileNotFoundError(
|
|
186
|
-
f"Evaluator schema not found: {evaluator_name}\n"
|
|
187
|
-
f"Searched in: {schema_dir}\n"
|
|
188
|
-
f"Supported formats: .yaml, .yml, .json"
|
|
189
|
-
)
|
|
157
|
+
from ...utils.schema_loader import load_agent_schema
|
|
158
|
+
|
|
159
|
+
# Use centralized schema loader (searches evaluator paths too)
|
|
160
|
+
return load_agent_schema(evaluator_name)
|
|
190
161
|
|
|
191
162
|
|
|
192
163
|
# =============================================================================
|
|
@@ -338,6 +309,22 @@ def create_evaluator_from_schema(
|
|
|
338
309
|
# Already a dict
|
|
339
310
|
schema = evaluator_schema_path
|
|
340
311
|
|
|
312
|
+
# Extract model from schema's provider_configs if not explicitly provided
|
|
313
|
+
if model_name is None:
|
|
314
|
+
json_schema_extra = schema.get("json_schema_extra", {})
|
|
315
|
+
provider_configs = json_schema_extra.get("provider_configs", [])
|
|
316
|
+
if provider_configs:
|
|
317
|
+
# Use first provider config
|
|
318
|
+
first_provider = provider_configs[0]
|
|
319
|
+
provider_name = first_provider.get("provider_name", "openai")
|
|
320
|
+
schema_model_name = first_provider.get("model_name", "gpt-4o-mini")
|
|
321
|
+
# Format as "provider:model" if not OpenAI (OpenAI is default)
|
|
322
|
+
if provider_name == "openai":
|
|
323
|
+
model_name = schema_model_name
|
|
324
|
+
else:
|
|
325
|
+
model_name = f"{provider_name}:{schema_model_name}"
|
|
326
|
+
logger.debug(f"Using model from schema provider_configs: {model_name}")
|
|
327
|
+
|
|
341
328
|
# Create evaluator config
|
|
342
329
|
evaluator_config = create_phoenix_evaluator(
|
|
343
330
|
evaluator_schema=schema,
|
|
@@ -361,7 +348,8 @@ def create_evaluator_from_schema(
|
|
|
361
348
|
Returns:
|
|
362
349
|
Evaluation result with score, label, explanation
|
|
363
350
|
"""
|
|
364
|
-
|
|
351
|
+
input_preview = str(example.get('input', ''))[:100]
|
|
352
|
+
logger.debug(f"Evaluating example: {input_preview}...")
|
|
365
353
|
|
|
366
354
|
# Phoenix llm_classify() expects a flat dict with string values
|
|
367
355
|
# Build evaluation input by flattening nested dicts
|
|
@@ -393,6 +381,7 @@ def create_evaluator_from_schema(
|
|
|
393
381
|
|
|
394
382
|
try:
|
|
395
383
|
# Create single-row DataFrame for llm_classify
|
|
384
|
+
# Note: Phoenix's llm_classify requires pandas DataFrame (imported above)
|
|
396
385
|
df = pd.DataFrame([eval_input])
|
|
397
386
|
|
|
398
387
|
# Call Phoenix llm_classify
|
|
@@ -404,7 +393,7 @@ def create_evaluator_from_schema(
|
|
|
404
393
|
provide_explanation=True,
|
|
405
394
|
)
|
|
406
395
|
|
|
407
|
-
# Extract result
|
|
396
|
+
# Extract result (results_df is pandas DataFrame from Phoenix)
|
|
408
397
|
if not results_df.empty:
|
|
409
398
|
row = results_df.iloc[0]
|
|
410
399
|
label = row.get("label", "error")
|
rem/api/README.md
CHANGED
|
@@ -392,6 +392,29 @@ Middleware runs in reverse order of addition:
|
|
|
392
392
|
|
|
393
393
|
## Error Responses
|
|
394
394
|
|
|
395
|
+
### 429 - Rate Limit Exceeded
|
|
396
|
+
|
|
397
|
+
When a user exceeds their rate limit (based on their tier), the API returns a 429 status code with a structured error body. The frontend should intercept this error to prompt the user to sign in or upgrade.
|
|
398
|
+
|
|
399
|
+
```json
|
|
400
|
+
{
|
|
401
|
+
"error": {
|
|
402
|
+
"code": "rate_limit_exceeded",
|
|
403
|
+
"message": "You have exceeded your rate limit. Please sign in or upgrade to continue.",
|
|
404
|
+
"details": {
|
|
405
|
+
"limit": 50,
|
|
406
|
+
"tier": "anonymous",
|
|
407
|
+
"retry_after": 60
|
|
408
|
+
}
|
|
409
|
+
}
|
|
410
|
+
}
|
|
411
|
+
```
|
|
412
|
+
|
|
413
|
+
**Handling Strategy:**
|
|
414
|
+
1. **Intercept 429s:** API client should listen for `status === 429`.
|
|
415
|
+
2. **Check Code:** If `error.code === 'rate_limit_exceeded'` AND `error.details.tier === 'anonymous'`, trigger "Login / Sign Up" flow.
|
|
416
|
+
3. **Authenticated Users:** If `tier !== 'anonymous'`, prompt to upgrade plan.
|
|
417
|
+
|
|
395
418
|
### 500 - Agent Schema Not Found
|
|
396
419
|
|
|
397
420
|
```json
|
rem/api/main.py
CHANGED
|
@@ -163,7 +163,22 @@ async def lifespan(app: FastAPI):
|
|
|
163
163
|
|
|
164
164
|
def create_app() -> FastAPI:
|
|
165
165
|
"""
|
|
166
|
-
Create and configure the FastAPI application.
|
|
166
|
+
Create and configure the FastAPI application with MCP server.
|
|
167
|
+
|
|
168
|
+
The returned app exposes `app.mcp_server` (FastMCP instance) for adding
|
|
169
|
+
custom tools, resources, and prompts:
|
|
170
|
+
|
|
171
|
+
app = create_app()
|
|
172
|
+
|
|
173
|
+
@app.mcp_server.tool()
|
|
174
|
+
async def my_tool(query: str) -> dict:
|
|
175
|
+
'''Custom MCP tool.'''
|
|
176
|
+
return {"result": query}
|
|
177
|
+
|
|
178
|
+
@app.mcp_server.resource("custom://data")
|
|
179
|
+
async def my_resource() -> str:
|
|
180
|
+
'''Custom resource.'''
|
|
181
|
+
return '{"data": "value"}'
|
|
167
182
|
|
|
168
183
|
Design Pattern:
|
|
169
184
|
1. Create MCP server
|
|
@@ -174,9 +189,10 @@ def create_app() -> FastAPI:
|
|
|
174
189
|
6. Define health endpoints
|
|
175
190
|
7. Register API routers
|
|
176
191
|
8. Mount MCP app
|
|
192
|
+
9. Expose mcp_server on app for extension
|
|
177
193
|
|
|
178
194
|
Returns:
|
|
179
|
-
Configured FastAPI application
|
|
195
|
+
Configured FastAPI application with .mcp_server attribute
|
|
180
196
|
"""
|
|
181
197
|
# Create MCP server and get HTTP app
|
|
182
198
|
# path="/" creates routes at root, then mount at /api/v1/mcp
|
|
@@ -228,6 +244,11 @@ def create_app() -> FastAPI:
|
|
|
228
244
|
|
|
229
245
|
# Add SSE buffering middleware (for MCP SSE transport)
|
|
230
246
|
app.add_middleware(SSEBufferingMiddleware)
|
|
247
|
+
|
|
248
|
+
# Add Anonymous Tracking & Rate Limiting (Runs AFTER Auth if Auth is enabled)
|
|
249
|
+
# Must be added BEFORE AuthMiddleware in code to be INNER in the stack
|
|
250
|
+
from .middleware.tracking import AnonymousTrackingMiddleware
|
|
251
|
+
app.add_middleware(AnonymousTrackingMiddleware)
|
|
231
252
|
|
|
232
253
|
# Add authentication middleware (if enabled)
|
|
233
254
|
if settings.auth.enabled:
|
|
@@ -305,6 +326,10 @@ def create_app() -> FastAPI:
|
|
|
305
326
|
# Mount MCP app at /api/v1/mcp
|
|
306
327
|
app.mount("/api/v1/mcp", mcp_app)
|
|
307
328
|
|
|
329
|
+
# Expose MCP server on app for extension
|
|
330
|
+
# Users can add tools/resources/prompts via app.mcp_server
|
|
331
|
+
app.mcp_server = mcp_server # type: ignore[attr-defined]
|
|
332
|
+
|
|
308
333
|
return app
|
|
309
334
|
|
|
310
335
|
|
|
rem/api/middleware/tracking.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Anonymous User Tracking & Rate Limiting Middleware.
|
|
3
|
+
|
|
4
|
+
Handles:
|
|
5
|
+
1. Anonymous Identity: Generates/Validates 'rem_anon_id' cookie.
|
|
6
|
+
2. Context Injection: Sets request.state.anon_id.
|
|
7
|
+
3. Rate Limiting: Enforces tenant-aware tiered limits via RateLimitService.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import hmac
|
|
11
|
+
import hashlib
|
|
12
|
+
import uuid
|
|
13
|
+
import secrets
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
from fastapi import Request, Response
|
|
17
|
+
from fastapi.responses import JSONResponse
|
|
18
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
19
|
+
from starlette.types import ASGIApp
|
|
20
|
+
|
|
21
|
+
from ...services.postgres.service import PostgresService
|
|
22
|
+
from ...services.rate_limit import RateLimitService
|
|
23
|
+
from ...models.entities.user import UserTier
|
|
24
|
+
from ...settings import settings
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AnonymousTrackingMiddleware(BaseHTTPMiddleware):
    """
    Middleware for anonymous user tracking and rate limiting.

    Design Pattern:
    - Uses an HMAC-SHA256 signed cookie ('rem_anon_id') for the anonymous ID.
    - Enforces rate limits before request processing.
    - Injects the unsigned anon_id into request.state.anon_id.
    """

    # Value reported to clients in the 429 body and the Retry-After header.
    RETRY_AFTER_SECONDS = 60

    # Number of hex characters of the HMAC digest kept as the cookie signature.
    SIGNATURE_LENGTH = 12

    def __init__(self, app: ASGIApp):
        """Set up the signing secret, cookie name, DB service and path exclusions."""
        super().__init__(app)
        # Secret for signing cookies.
        # NOTE(review): when settings.auth.session_secret is unset, the
        # hard-coded fallback below is used; anyone who knows it can forge
        # valid anonymous IDs. Configure a real secret in deployments.
        self.secret_key = settings.auth.session_secret or "fallback-secret-change-me"
        self.cookie_name = "rem_anon_id"

        # Dedicated DB service for this middleware (one pool per app instance)
        self.db = PostgresService()
        self.rate_limiter = RateLimitService(self.db)

        # Excluded paths (health checks, static assets, auth callbacks).
        # Matching is by prefix, see dispatch().
        self.excluded_paths = {
            "/health",
            "/docs",
            "/openapi.json",
            "/favicon.ico",
            "/api/auth",  # Don't rate limit auth flow heavily
        }

    async def dispatch(self, request: Request, call_next):
        """
        Identify the caller, enforce the rate limit, then run the request.

        Args:
            request: Incoming request; request.state.anon_id is set on it.
            call_next: Callable that invokes the rest of the stack.

        Returns:
            The downstream response (with the anonymous cookie and
            X-RateLimit-* headers added where applicable), or a 429
            JSONResponse when the rate limit is exceeded.
        """
        # 0. Skip excluded paths (str.startswith accepts a tuple of prefixes)
        if request.url.path.startswith(tuple(self.excluded_paths)):
            return await call_next(request)

        # 1. Lazy DB Connection
        # Note: simple lazy init. Concurrent first requests may each call
        # connect(); ideally this would be hooked into the app lifespan.
        if not self.db.pool and settings.postgres.enabled:
            await self.db.connect()

        # 2. Identification (Cookie Strategy)
        anon_id = request.cookies.get(self.cookie_name)
        is_new_anon = False

        if not anon_id or not self._validate_signature(anon_id):
            # Missing or tampered cookie: issue a fresh signed ID.
            anon_id = self._generate_signed_id()
            is_new_anon = True

        # Strip signature for internal use
        raw_anon_id = anon_id.partition(".")[0]
        request.state.anon_id = raw_anon_id

        # 3. Determine User Tier & ID for Rate Limiting
        # NOTE(review): this relies on an earlier layer having populated
        # request.state.user with a dict-like user - confirm the auth
        # middleware does so (the OAuth callback stores the user in
        # request.session).
        user = getattr(request.state, "user", None)
        if user:
            # Authenticated User
            identifier = user.get("id")
            # Unknown tier strings fall back to FREE.
            tier_str = user.get("tier", UserTier.FREE.value)
            try:
                tier = UserTier(tier_str)
            except ValueError:
                tier = UserTier.FREE
            tenant_id = user.get("tenant_id", "default")
        else:
            # Anonymous User
            identifier = raw_anon_id
            tier = UserTier.ANONYMOUS
            # Tenant ID from header or default
            tenant_id = request.headers.get("X-Tenant-Id", "default")

        # 4. Rate Limiting
        # rate_info stays None when no check ran, so the header step below
        # does not have to probe locals().
        rate_info = None
        if settings.postgres.enabled:
            is_allowed, current, limit = await self.rate_limiter.check_rate_limit(
                tenant_id=tenant_id,
                identifier=identifier,
                tier=tier,
            )
            if not is_allowed:
                return self._rate_limit_response(limit, tier)
            rate_info = (current, limit)

        # 5. Process Request
        response = await call_next(request)

        # 6. Set Cookie if new
        if is_new_anon:
            response.set_cookie(
                key=self.cookie_name,
                value=anon_id,
                max_age=31536000,  # 1 year
                httponly=True,
                samesite="lax",
                secure=settings.environment == "production",
            )

        # Add Rate Limit headers
        if settings.postgres.enabled and rate_info is not None:
            current, limit = rate_info
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(max(0, limit - current))

        return response

    def _rate_limit_response(self, limit, tier) -> JSONResponse:
        """Build the structured 429 response returned when the limit is exceeded."""
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": "rate_limit_exceeded",
                    "message": "You have exceeded your rate limit. Please sign in or upgrade to continue.",
                    "details": {
                        "limit": limit,
                        "tier": tier.value,
                        "retry_after": self.RETRY_AFTER_SECONDS,
                    },
                }
            },
            headers={"Retry-After": str(self.RETRY_AFTER_SECONDS)},
        )

    def _sign(self, val: str) -> str:
        """Return the truncated hex HMAC-SHA256 of val under the secret key."""
        digest = hmac.new(
            self.secret_key.encode(),
            val.encode(),
            hashlib.sha256,
        ).hexdigest()
        return digest[: self.SIGNATURE_LENGTH]  # Short signature

    def _generate_signed_id(self) -> str:
        """Generate a UUID4 signed with HMAC, formatted as '<uuid>.<signature>'."""
        val = str(uuid.uuid4())
        return f"{val}.{self._sign(val)}"

    def _validate_signature(self, signed_val: str) -> bool:
        """
        Validate the HMAC signature of a '<value>.<signature>' cookie.

        Returns False (never raises) for malformed input: a wrong number of
        '.'-separated parts, or text that cannot be encoded.
        """
        try:
            val, sig = signed_val.split(".")
            expected_sig = self._sign(val)
            # Compare as bytes: compare_digest raises TypeError for str
            # arguments containing non-ASCII characters, and the cookie
            # value is client-controlled.
            return secrets.compare_digest(sig.encode(), expected_sig.encode())
        except ValueError:
            return False
|
rem/api/routers/auth.py
CHANGED
|
@@ -49,6 +49,8 @@ from authlib.integrations.starlette_client import OAuth
|
|
|
49
49
|
from loguru import logger
|
|
50
50
|
|
|
51
51
|
from ...settings import settings
|
|
52
|
+
from ...services.postgres.service import PostgresService
|
|
53
|
+
from ...services.user_service import UserService
|
|
52
54
|
|
|
53
55
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
|
54
56
|
|
|
@@ -168,6 +170,53 @@ async def callback(provider: str, request: Request):
|
|
|
168
170
|
if not user_info:
|
|
169
171
|
# Fetch from userinfo endpoint if not in ID token
|
|
170
172
|
user_info = await client.userinfo(token=token)
|
|
173
|
+
|
|
174
|
+
# --- REM Integration Start ---
|
|
175
|
+
if settings.postgres.enabled:
|
|
176
|
+
# Connect to DB
|
|
177
|
+
db = PostgresService()
|
|
178
|
+
try:
|
|
179
|
+
await db.connect()
|
|
180
|
+
user_service = UserService(db)
|
|
181
|
+
|
|
182
|
+
# Get/Create User
|
|
183
|
+
user_entity = await user_service.get_or_create_user(
|
|
184
|
+
email=user_info.get("email"),
|
|
185
|
+
name=user_info.get("name", "New User"),
|
|
186
|
+
avatar_url=user_info.get("picture"),
|
|
187
|
+
tenant_id="default", # Single tenant for now
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# Link Anonymous Session
|
|
191
|
+
# TrackingMiddleware sets request.state.anon_id
|
|
192
|
+
anon_id = getattr(request.state, "anon_id", None)
|
|
193
|
+
# Fallback to cookie if middleware didn't run or state missing
|
|
194
|
+
if not anon_id:
|
|
195
|
+
# Attempt to parse cookie manually if needed, but middleware
|
|
196
|
+
# usually handles the signature logic.
|
|
197
|
+
# Just check raw cookie for simple case (not recommended if signed)
|
|
198
|
+
pass
|
|
199
|
+
|
|
200
|
+
if anon_id:
|
|
201
|
+
await user_service.link_anonymous_session(user_entity, anon_id)
|
|
202
|
+
|
|
203
|
+
# Enrich session user with DB info
|
|
204
|
+
db_info = {
|
|
205
|
+
"id": str(user_entity.id),
|
|
206
|
+
"tenant_id": user_entity.tenant_id,
|
|
207
|
+
"tier": user_entity.tier.value if user_entity.tier else "free",
|
|
208
|
+
"roles": [user_entity.role] if user_entity.role else [],
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
except Exception as db_e:
|
|
212
|
+
logger.error(f"Database error during auth callback: {db_e}")
|
|
213
|
+
# Continue login even if DB fails, but warn
|
|
214
|
+
db_info = {"id": "db_error", "tier": "free"}
|
|
215
|
+
finally:
|
|
216
|
+
await db.disconnect()
|
|
217
|
+
else:
|
|
218
|
+
db_info = {"id": "no_db", "tier": "free"}
|
|
219
|
+
# --- REM Integration End ---
|
|
171
220
|
|
|
172
221
|
# Store user info in session
|
|
173
222
|
request.session["user"] = {
|
|
@@ -176,6 +225,11 @@ async def callback(provider: str, request: Request):
|
|
|
176
225
|
"email": user_info.get("email"),
|
|
177
226
|
"name": user_info.get("name"),
|
|
178
227
|
"picture": user_info.get("picture"),
|
|
228
|
+
# Add DB info
|
|
229
|
+
"id": db_info.get("id"),
|
|
230
|
+
"tenant_id": db_info.get("tenant_id", "default"),
|
|
231
|
+
"tier": db_info.get("tier"),
|
|
232
|
+
"roles": db_info.get("roles", []),
|
|
179
233
|
}
|
|
180
234
|
|
|
181
235
|
# Store tokens in session for API access
|
|
rem/api/routers/chat/completions.py
CHANGED
|
@@ -251,7 +251,7 @@ async def chat_completions(body: ChatCompletionRequest, request: Request):
|
|
|
251
251
|
}
|
|
252
252
|
|
|
253
253
|
# Store messages with compression
|
|
254
|
-
store = SessionMessageStore(user_id=context.user_id or
|
|
254
|
+
store = SessionMessageStore(user_id=context.user_id or settings.test.effective_user_id)
|
|
255
255
|
|
|
256
256
|
await store.store_session_messages(
|
|
257
257
|
session_id=context.session_id,
|