remdb 0.2.6__py3-none-any.whl → 0.3.103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of remdb might be problematic. Click here for more details.

Files changed (82) hide show
  1. rem/__init__.py +129 -2
  2. rem/agentic/README.md +76 -0
  3. rem/agentic/__init__.py +15 -0
  4. rem/agentic/agents/__init__.py +16 -2
  5. rem/agentic/agents/sse_simulator.py +500 -0
  6. rem/agentic/context.py +7 -5
  7. rem/agentic/llm_provider_models.py +301 -0
  8. rem/agentic/providers/phoenix.py +32 -43
  9. rem/agentic/providers/pydantic_ai.py +84 -10
  10. rem/api/README.md +238 -1
  11. rem/api/deps.py +255 -0
  12. rem/api/main.py +70 -22
  13. rem/api/mcp_router/server.py +8 -1
  14. rem/api/mcp_router/tools.py +80 -0
  15. rem/api/middleware/tracking.py +172 -0
  16. rem/api/routers/admin.py +277 -0
  17. rem/api/routers/auth.py +124 -0
  18. rem/api/routers/chat/completions.py +123 -14
  19. rem/api/routers/chat/models.py +7 -3
  20. rem/api/routers/chat/sse_events.py +526 -0
  21. rem/api/routers/chat/streaming.py +468 -45
  22. rem/api/routers/dev.py +81 -0
  23. rem/api/routers/feedback.py +455 -0
  24. rem/api/routers/messages.py +473 -0
  25. rem/api/routers/models.py +78 -0
  26. rem/api/routers/shared_sessions.py +406 -0
  27. rem/auth/middleware.py +126 -27
  28. rem/cli/commands/ask.py +15 -11
  29. rem/cli/commands/configure.py +169 -94
  30. rem/cli/commands/db.py +53 -7
  31. rem/cli/commands/experiments.py +278 -96
  32. rem/cli/commands/process.py +8 -7
  33. rem/cli/commands/scaffold.py +47 -0
  34. rem/cli/commands/schema.py +9 -9
  35. rem/cli/main.py +10 -0
  36. rem/config.py +2 -2
  37. rem/models/core/core_model.py +7 -1
  38. rem/models/entities/__init__.py +21 -0
  39. rem/models/entities/domain_resource.py +38 -0
  40. rem/models/entities/feedback.py +123 -0
  41. rem/models/entities/message.py +30 -1
  42. rem/models/entities/session.py +83 -0
  43. rem/models/entities/shared_session.py +206 -0
  44. rem/models/entities/user.py +10 -3
  45. rem/registry.py +367 -0
  46. rem/schemas/agents/rem.yaml +7 -3
  47. rem/services/content/providers.py +94 -140
  48. rem/services/content/service.py +85 -16
  49. rem/services/dreaming/affinity_service.py +2 -16
  50. rem/services/dreaming/moment_service.py +2 -15
  51. rem/services/embeddings/api.py +20 -13
  52. rem/services/phoenix/EXPERIMENT_DESIGN.md +3 -3
  53. rem/services/phoenix/client.py +252 -19
  54. rem/services/postgres/README.md +29 -10
  55. rem/services/postgres/repository.py +132 -0
  56. rem/services/postgres/schema_generator.py +86 -5
  57. rem/services/rate_limit.py +113 -0
  58. rem/services/rem/README.md +14 -0
  59. rem/services/session/compression.py +17 -1
  60. rem/services/user_service.py +98 -0
  61. rem/settings.py +115 -17
  62. rem/sql/background_indexes.sql +10 -0
  63. rem/sql/migrations/001_install.sql +152 -2
  64. rem/sql/migrations/002_install_models.sql +580 -231
  65. rem/sql/migrations/003_seed_default_user.sql +48 -0
  66. rem/utils/constants.py +97 -0
  67. rem/utils/date_utils.py +228 -0
  68. rem/utils/embeddings.py +17 -4
  69. rem/utils/files.py +167 -0
  70. rem/utils/mime_types.py +158 -0
  71. rem/utils/model_helpers.py +156 -1
  72. rem/utils/schema_loader.py +273 -14
  73. rem/utils/sql_types.py +3 -1
  74. rem/utils/vision.py +9 -14
  75. rem/workers/README.md +14 -14
  76. rem/workers/db_maintainer.py +74 -0
  77. {remdb-0.2.6.dist-info → remdb-0.3.103.dist-info}/METADATA +486 -132
  78. {remdb-0.2.6.dist-info → remdb-0.3.103.dist-info}/RECORD +80 -57
  79. {remdb-0.2.6.dist-info → remdb-0.3.103.dist-info}/WHEEL +1 -1
  80. rem/sql/002_install_models.sql +0 -1068
  81. rem/sql/install_models.sql +0 -1038
  82. {remdb-0.2.6.dist-info → remdb-0.3.103.dist-info}/entry_points.txt +0 -0
@@ -422,6 +422,7 @@ async def ingest_into_rem(
422
422
  tags: list[str] | None = None,
423
423
  is_local_server: bool = False,
424
424
  user_id: str | None = None,
425
+ resource_type: str | None = None,
425
426
  ) -> dict[str, Any]:
426
427
  """
427
428
  Ingest file into REM, creating searchable resources and embeddings.
@@ -448,6 +449,11 @@ async def ingest_into_rem(
448
449
  tags: Optional tags for file
449
450
  is_local_server: True if running as local/stdio MCP server
450
451
  user_id: Optional user identifier (defaults to authenticated user or "default")
452
+ resource_type: Optional resource type for storing chunks (case-insensitive).
453
+ Supports flexible naming:
454
+ - "resource", "resources", "Resource" → Resource (default)
455
+ - "domain-resource", "domain_resource", "DomainResource",
456
+ "domain-resources" → DomainResource (curated internal knowledge)
451
457
 
452
458
  Returns:
453
459
  Dict with:
@@ -478,6 +484,13 @@ async def ingest_into_rem(
478
484
  file_uri="https://example.com/whitepaper.pdf",
479
485
  tags=["research", "whitepaper"]
480
486
  )
487
+
488
+ # Ingest as curated domain knowledge
489
+ ingest_into_rem(
490
+ file_uri="s3://bucket/internal/procedures.pdf",
491
+ resource_type="domain-resource",
492
+ category="procedures"
493
+ )
481
494
  """
482
495
  from ...services.content import ContentService
483
496
 
@@ -493,6 +506,7 @@ async def ingest_into_rem(
493
506
  category=category,
494
507
  tags=tags,
495
508
  is_local_server=is_local_server,
509
+ resource_type=resource_type,
496
510
  )
497
511
 
498
512
  logger.info(
@@ -582,3 +596,69 @@ async def read_resource(uri: str) -> dict[str, Any]:
582
596
  "uri": uri,
583
597
  "data": {"content": result},
584
598
  }
599
+
600
+
601
async def register_metadata(
    confidence: float | None = None,
    references: list[str] | None = None,
    sources: list[str] | None = None,
    flags: list[str] | None = None,
) -> dict[str, Any]:
    """
    Register response metadata to be emitted as an SSE MetadataEvent.

    Call this tool BEFORE generating your final response. The values you pass
    are echoed back in the tool result together with a ``_metadata_event``
    marker, so the streaming layer can forward them to the client as a
    MetadataEvent while the natural-language answer is streamed separately.

    **Design Pattern**: call it once, right before the final answer. Structured
    metadata is thereby decoupled from the format of the response itself.

    Args:
        confidence: Confidence score (0.0-1.0) for the response quality.
            - 0.9-1.0: High confidence, answer is well-supported
            - 0.7-0.9: Medium confidence, some uncertainty
            - 0.5-0.7: Low confidence, significant gaps
            - <0.5: Very uncertain, may need clarification
            The value is passed through as given; no range check is applied.
        references: Identifiers (file paths, document IDs, entity labels)
            that support the response.
        sources: Source descriptions (e.g., "REM database",
            "search results", "user context").
        flags: Optional flags for the response (e.g., "needs_review",
            "uncertain", "incomplete").

    Returns:
        Dict with:
        - status: "success"
        - _metadata_event: True (marker for streaming layer)
        - confidence, references, sources, flags: The registered values,
          unchanged (``None`` when not supplied)

    Examples:
        # High confidence answer with references
        register_metadata(
            confidence=0.95,
            references=["sarah-chen", "q3-report-2024"],
            sources=["REM database lookup"]
        )

        # Lower confidence with flags
        register_metadata(
            confidence=0.65,
            flags=["needs_review", "incomplete_data"]
        )
    """
    # Only the counts are logged, not the reference/source contents.
    reference_count = len(references) if references else 0
    source_count = len(sources) if sources else 0
    logger.info(
        f"📊 Registering metadata: confidence={confidence}, "
        f"refs={reference_count}, sources={source_count}"
    )

    # The marker key lets the streaming layer recognise this tool result.
    payload: dict[str, Any] = {"status": "success", "_metadata_event": True}
    payload.update(
        confidence=confidence,
        references=references,
        sources=sources,
        flags=flags,
    )
    return payload
@@ -0,0 +1,172 @@
1
+ """
2
+ Anonymous User Tracking & Rate Limiting Middleware.
3
+
4
+ Handles:
5
+ 1. Anonymous Identity: Generates/Validates 'rem_anon_id' cookie.
6
+ 2. Context Injection: Sets request.state.anon_id.
7
+ 3. Rate Limiting: Enforces tenant-aware tiered limits via RateLimitService.
8
+ """
9
+
10
+ import hmac
11
+ import hashlib
12
+ import uuid
13
+ import secrets
14
+ from typing import Optional
15
+
16
+ from fastapi import Request, Response
17
+ from fastapi.responses import JSONResponse
18
+ from starlette.middleware.base import BaseHTTPMiddleware
19
+ from starlette.types import ASGIApp
20
+
21
+ from ...services.postgres.service import PostgresService
22
+ from ...services.rate_limit import RateLimitService
23
+ from ...models.entities.user import UserTier
24
+ from ...settings import settings
25
+
26
+
27
class AnonymousTrackingMiddleware(BaseHTTPMiddleware):
    """
    Middleware for anonymous user tracking and rate limiting.

    Design Pattern:
    - Uses a signed cookie (``rem_anon_id``) of the form ``<uuid4>.<sig>``,
      where ``sig`` is the first 12 hex characters of an HMAC-SHA256 of the UUID.
    - Enforces rate limits before request processing (only when Postgres is
      enabled in settings).
    - Injects the unsigned anonymous ID into ``request.state.anon_id``.
    """

    def __init__(self, app: ASGIApp):
        """Set up the signing secret, cookie name, DB service and exclusions."""
        super().__init__(app)
        # NOTE(security): when no session secret is configured, a hard-coded,
        # publicly known fallback is used, so anonymous-ID cookies can be forged.
        # Deployments should always set settings.auth.session_secret.
        self.secret_key = settings.auth.session_secret or "fallback-secret-change-me"
        self.cookie_name = "rem_anon_id"

        # Dedicated DB service for this middleware; connected lazily in dispatch().
        self.db = PostgresService()
        self.rate_limiter = RateLimitService(self.db)

        # Path prefixes that bypass tracking and rate limiting entirely
        # (health checks, API docs, favicon, auth flow).
        self.excluded_paths = {
            "/health",
            "/docs",
            "/openapi.json",
            "/favicon.ico",
            "/api/auth",  # Don't rate limit auth flow heavily
        }

    async def dispatch(self, request: Request, call_next):
        """
        Identify the caller, enforce the rate limit, then run the request.

        Returns the downstream response (with the anonymous-ID cookie and
        ``X-RateLimit-*`` headers added where applicable), or a 429 JSON
        response when the rate limiter rejects the request.
        """
        # 0. Skip excluded paths (prefix match).
        path = request.url.path
        if any(path.startswith(prefix) for prefix in self.excluded_paths):
            return await call_next(request)

        rate_limiting_enabled = settings.postgres.enabled

        # 1. Lazy DB connection. Concurrent first requests may each attempt to
        # connect; ideally this would be hooked into the app lifespan instead.
        if rate_limiting_enabled and not self.db.pool:
            await self.db.connect()

        # 2. Identification: reuse a validly signed cookie, otherwise mint one.
        anon_id = request.cookies.get(self.cookie_name)
        is_new_anon = not anon_id or not self._validate_signature(anon_id)
        if is_new_anon:
            anon_id = self._generate_signed_id()

        # Strip the signature for internal use.
        raw_anon_id = anon_id.split(".")[0]
        request.state.anon_id = raw_anon_id

        # 3. Determine identifier, tier and tenant for rate limiting.
        # NOTE(review): request.state.user is only populated if an auth
        # middleware ran before this one - confirm the middleware ordering.
        user = getattr(request.state, "user", None)
        if user:
            # Authenticated user (assumed to be a dict-like object).
            identifier = user.get("id")
            try:
                tier = UserTier(user.get("tier", UserTier.FREE.value))
            except ValueError:
                # Unknown tier string: fall back to the free tier.
                tier = UserTier.FREE
            tenant_id = user.get("tenant_id", "default")
        else:
            # Anonymous user: tenant comes from a header, else "default".
            identifier = raw_anon_id
            tier = UserTier.ANONYMOUS
            tenant_id = request.headers.get("X-Tenant-Id", "default")

        # 4. Rate limiting. `limit`/`current` stay None when it did not run,
        # which is what decides whether the headers are added below.
        limit = None
        current = None
        if rate_limiting_enabled:
            is_allowed, current, limit = await self.rate_limiter.check_rate_limit(
                tenant_id=tenant_id,
                identifier=identifier,
                tier=tier,
            )

            if not is_allowed:
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": {
                            "code": "rate_limit_exceeded",
                            "message": "You have exceeded your rate limit. Please sign in or upgrade to continue.",
                            "details": {
                                "limit": limit,
                                "tier": tier.value,
                                "retry_after": 60,
                            },
                        }
                    },
                    headers={"Retry-After": "60"},
                )

        # 5. Process request.
        response = await call_next(request)

        # 6. Persist a freshly minted anonymous ID.
        if is_new_anon:
            response.set_cookie(
                key=self.cookie_name,
                value=anon_id,
                max_age=31536000,  # 1 year
                httponly=True,
                samesite="lax",
                secure=settings.environment == "production",
            )

        # Add rate limit headers only when the limiter actually ran.
        if limit is not None and current is not None:
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(max(0, limit - current))

        return response

    def _sign(self, val: str) -> str:
        """Return the truncated HMAC-SHA256 hex signature for ``val``."""
        digest = hmac.new(
            self.secret_key.encode(),
            val.encode(),
            hashlib.sha256,
        ).hexdigest()
        # Short (12 hex chars) signature; length kept so cookies that were
        # already issued remain valid.
        return digest[:12]

    def _generate_signed_id(self) -> str:
        """Generate a UUID4 signed with HMAC, formatted as ``<uuid>.<sig>``."""
        val = str(uuid.uuid4())
        return f"{val}.{self._sign(val)}"

    def _validate_signature(self, signed_val: str) -> bool:
        """
        Validate the HMAC signature of a ``<value>.<sig>`` cookie string.

        Returns False for malformed input instead of raising: the cookie is
        client-controlled, so it may contain anything.
        """
        try:
            val, sig = signed_val.split(".")
            # Compare as bytes: compare_digest() raises TypeError when given a
            # str containing non-ASCII characters, which a crafted cookie could
            # otherwise use to trigger an unhandled error.
            return hmac.compare_digest(sig.encode(), self._sign(val).encode())
        except ValueError:
            # Wrong number of "." separated parts, or text that cannot be
            # encoded (UnicodeEncodeError is a ValueError).
            return False
@@ -0,0 +1,277 @@
1
+ """
2
+ Admin API Router.
3
+
4
+ Protected endpoints requiring admin role for system management tasks.
5
+
6
+ Endpoints:
7
+ GET /api/admin/users - List all users (admin only)
8
+ GET /api/admin/sessions - List all sessions across users (admin only)
9
+ GET /api/admin/messages - List all messages across users (admin only)
10
+ GET /api/admin/stats - System statistics (admin only)
11
+
12
+ All endpoints require:
13
+ 1. Authentication (valid session)
14
+ 2. Admin role in user's roles list
15
+
16
+ Design Pattern:
17
+ - Uses require_admin dependency for role enforcement
18
+ - Cross-tenant queries (no user_id filtering)
19
+ - Audit logging for admin actions
20
+ """
21
+
22
+ from typing import Literal
23
+
24
+ from fastapi import APIRouter, Depends, HTTPException, Query
25
+ from loguru import logger
26
+ from pydantic import BaseModel
27
+
28
+ from ..deps import require_admin
29
+ from ...models.entities import Message, Session, SessionMode
30
+ from ...services.postgres import Repository
31
+ from ...settings import settings
32
+
33
+ router = APIRouter(prefix="/api/admin", tags=["admin"])
34
+
35
+
36
+ # =============================================================================
37
+ # Response Models
38
+ # =============================================================================
39
+
40
+
41
class UserSummary(BaseModel):
    """Condensed user record returned by the admin user listing."""

    # Identifier as a string (the listing endpoint passes str(u.id)).
    id: str
    email: str | None
    name: str | None
    # Tier name; the listing endpoint substitutes "free" when no tier is set.
    tier: str
    role: str | None
    # Creation timestamp as an ISO-format string, or None when unknown.
    created_at: str | None
50
+
51
+
52
class UserListResponse(BaseModel):
    """Paginated envelope returned by the admin user list endpoint."""

    # Fixed discriminator; always the literal "list".
    object: Literal["list"] = "list"
    # The current page of user summaries.
    data: list[UserSummary]
    # Number of users matching the query, independent of paging.
    total: int
    # True when at least one more row exists beyond this page.
    has_more: bool
59
+
60
+
61
class SessionListResponse(BaseModel):
    """Paginated envelope returned by the admin session list endpoint."""

    # Fixed discriminator; always the literal "list".
    object: Literal["list"] = "list"
    # The current page of sessions.
    data: list[Session]
    # Number of sessions matching the applied filters, independent of paging.
    total: int
    # True when at least one more row exists beyond this page.
    has_more: bool
68
+
69
+
70
class MessageListResponse(BaseModel):
    """Paginated envelope returned by the admin message list endpoint."""

    # Fixed discriminator; always the literal "list".
    object: Literal["list"] = "list"
    # The current page of messages.
    data: list[Message]
    # Number of messages matching the applied filters, independent of paging.
    total: int
    # True when at least one more row exists beyond this page.
    has_more: bool
77
+
78
+
79
class SystemStats(BaseModel):
    """Aggregate counters shown on the admin dashboard."""

    # Row counts over the whole system.
    total_users: int
    total_sessions: int
    total_messages: int
    # Activity within the last 24 hours. The stats endpoint in this module
    # currently reports 0 for both (not yet implemented there).
    active_sessions_24h: int
    messages_24h: int
87
+
88
+
89
+ # =============================================================================
90
+ # Admin Endpoints
91
+ # =============================================================================
92
+
93
+
94
@router.get("/users", response_model=UserListResponse)
async def list_all_users(
    user: dict = Depends(require_admin),
    limit: int = Query(default=50, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> UserListResponse:
    """
    Return one page of users, newest first, without any tenant filter.

    Admin-only: the caller is resolved through the ``require_admin`` dependency.

    Args:
        user: The authenticated admin (injected); only used for the audit log.
        limit: Page size, 1-100.
        offset: Number of rows to skip.

    Raises:
        HTTPException: 503 when Postgres is disabled in settings.
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    logger.info(f"Admin {user.get('email')} listing all users")

    # Imported here rather than at module level to avoid circular imports.
    from ...models.entities import User

    user_repo = Repository(User, table_name="users")

    # Fetch one row more than requested: its presence tells us whether a
    # further page exists without a second query.
    rows = await user_repo.find(
        filters={},
        order_by="created_at DESC",
        limit=limit + 1,
        offset=offset,
    )
    has_more = len(rows) > limit
    page = rows[:limit]

    total = await user_repo.count({})

    # Project each record onto the summary shape.
    summaries = []
    for record in page:
        created = record.created_at.isoformat() if record.created_at else None
        tier_name = record.tier.value if record.tier else "free"
        summaries.append(
            UserSummary(
                id=str(record.id),
                email=record.email,
                name=record.name,
                tier=tier_name,
                role=record.role,
                created_at=created,
            )
        )

    return UserListResponse(data=summaries, total=total, has_more=has_more)
144
+
145
+
146
@router.get("/sessions", response_model=SessionListResponse)
async def list_all_sessions(
    user: dict = Depends(require_admin),
    user_id: str | None = Query(default=None, description="Filter by user ID"),
    mode: SessionMode | None = Query(default=None, description="Filter by mode"),
    limit: int = Query(default=50, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> SessionListResponse:
    """
    Return one page of sessions across all users, newest first.

    Admin-only endpoint for session monitoring. Results can be narrowed by
    ``user_id`` and/or ``mode``.

    Args:
        user: The authenticated admin (injected); only used for the audit log.
        user_id: Optional owner filter.
        mode: Optional session-mode filter.
        limit: Page size, 1-100.
        offset: Number of rows to skip.

    Raises:
        HTTPException: 503 when Postgres is disabled in settings.
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    logger.info(
        f"Admin {user.get('email')} listing sessions "
        f"(user_id={user_id}, mode={mode})"
    )

    # Only the filters that were actually supplied are applied.
    criteria: dict = {}
    if user_id:
        criteria["user_id"] = user_id
    if mode:
        criteria["mode"] = mode.value

    session_repo = Repository(Session, table_name="sessions")

    # Fetch one extra row so has_more can be derived without another query.
    rows = await session_repo.find(
        filters=criteria,
        order_by="created_at DESC",
        limit=limit + 1,
        offset=offset,
    )
    has_more = len(rows) > limit
    page = rows[:limit]

    total = await session_repo.count(criteria)

    return SessionListResponse(data=page, total=total, has_more=has_more)
191
+
192
+
193
@router.get("/messages", response_model=MessageListResponse)
async def list_all_messages(
    user: dict = Depends(require_admin),
    user_id: str | None = Query(default=None, description="Filter by user ID"),
    session_id: str | None = Query(default=None, description="Filter by session ID"),
    message_type: str | None = Query(default=None, description="Filter by type"),
    limit: int = Query(default=50, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> MessageListResponse:
    """
    List all messages across all users, newest first.

    Admin-only endpoint for message auditing.
    Can filter by user_id, session_id, or message_type.

    Args:
        user: The authenticated admin (injected); only used for the audit log.
        user_id: Optional owner filter.
        session_id: Optional session filter.
        message_type: Optional message-type filter.
        limit: Page size, 1-100.
        offset: Number of rows to skip.

    Raises:
        HTTPException: 503 when Postgres is disabled in settings.
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    # Audit log: record every filter that is applied to the query, including
    # message_type, so the log reflects what the admin actually looked at.
    logger.info(
        f"Admin {user.get('email')} listing messages "
        f"(user_id={user_id}, session_id={session_id}, "
        f"message_type={message_type})"
    )

    repo = Repository(Message, table_name="messages")

    # Build optional filters; unset (or empty) parameters are not applied.
    filters: dict = {}
    if user_id:
        filters["user_id"] = user_id
    if session_id:
        filters["session_id"] = session_id
    if message_type:
        filters["message_type"] = message_type

    # Fetch one extra row so has_more can be derived without another query.
    messages = await repo.find(
        filters=filters,
        order_by="created_at DESC",
        limit=limit + 1,
        offset=offset,
    )

    has_more = len(messages) > limit
    if has_more:
        messages = messages[:limit]

    total = await repo.count(filters)

    return MessageListResponse(data=messages, total=total, has_more=has_more)
241
+
242
+
243
@router.get("/stats", response_model=SystemStats)
async def get_system_stats(
    user: dict = Depends(require_admin),
) -> SystemStats:
    """
    Get system-wide statistics.

    Admin-only endpoint for monitoring dashboard. Returns total row counts for
    users, sessions and messages. The 24-hour figures are not implemented yet
    and are always reported as 0.

    Args:
        user: The authenticated admin (injected); only used for the audit log.

    Raises:
        HTTPException: 503 when Postgres is disabled in settings.
    """
    if not settings.postgres.enabled:
        raise HTTPException(status_code=503, detail="Database not enabled")

    logger.info(f"Admin {user.get('email')} fetching system stats")

    # Imported here rather than at module level to avoid circular imports.
    from ...models.entities import User

    user_repo = Repository(User, table_name="users")
    session_repo = Repository(Session, table_name="sessions")
    message_repo = Repository(Message, table_name="messages")

    # Get totals (empty filter = count every row).
    total_users = await user_repo.count({})
    total_sessions = await session_repo.count({})
    total_messages = await message_repo.count({})

    # For 24h stats, we'd need date filtering in Repository.
    # For now, return 0 placeholders (TODO: add date range support).
    return SystemStats(
        total_users=total_users,
        total_sessions=total_sessions,
        total_messages=total_messages,
        active_sessions_24h=0,  # TODO: implement
        messages_24h=0,  # TODO: implement
    )
rem/api/routers/auth.py CHANGED
@@ -49,6 +49,8 @@ from authlib.integrations.starlette_client import OAuth
49
49
  from loguru import logger
50
50
 
51
51
  from ...settings import settings
52
+ from ...services.postgres.service import PostgresService
53
+ from ...services.user_service import UserService
52
54
 
53
55
  router = APIRouter(prefix="/api/auth", tags=["auth"])
54
56
 
@@ -168,6 +170,53 @@ async def callback(provider: str, request: Request):
168
170
  if not user_info:
169
171
  # Fetch from userinfo endpoint if not in ID token
170
172
  user_info = await client.userinfo(token=token)
173
+
174
+ # --- REM Integration Start ---
175
+ if settings.postgres.enabled:
176
+ # Connect to DB
177
+ db = PostgresService()
178
+ try:
179
+ await db.connect()
180
+ user_service = UserService(db)
181
+
182
+ # Get/Create User
183
+ user_entity = await user_service.get_or_create_user(
184
+ email=user_info.get("email"),
185
+ name=user_info.get("name", "New User"),
186
+ avatar_url=user_info.get("picture"),
187
+ tenant_id="default", # Single tenant for now
188
+ )
189
+
190
+ # Link Anonymous Session
191
+ # TrackingMiddleware sets request.state.anon_id
192
+ anon_id = getattr(request.state, "anon_id", None)
193
+ # Fallback to cookie if middleware didn't run or state missing
194
+ if not anon_id:
195
+ # Attempt to parse cookie manually if needed, but middleware
196
+ # usually handles the signature logic.
197
+ # Just check raw cookie for simple case (not recommended if signed)
198
+ pass
199
+
200
+ if anon_id:
201
+ await user_service.link_anonymous_session(user_entity, anon_id)
202
+
203
+ # Enrich session user with DB info
204
+ db_info = {
205
+ "id": str(user_entity.id),
206
+ "tenant_id": user_entity.tenant_id,
207
+ "tier": user_entity.tier.value if user_entity.tier else "free",
208
+ "roles": [user_entity.role] if user_entity.role else [],
209
+ }
210
+
211
+ except Exception as db_e:
212
+ logger.error(f"Database error during auth callback: {db_e}")
213
+ # Continue login even if DB fails, but warn
214
+ db_info = {"id": "db_error", "tier": "free"}
215
+ finally:
216
+ await db.disconnect()
217
+ else:
218
+ db_info = {"id": "no_db", "tier": "free"}
219
+ # --- REM Integration End ---
171
220
 
172
221
  # Store user info in session
173
222
  request.session["user"] = {
@@ -176,6 +225,11 @@ async def callback(provider: str, request: Request):
176
225
  "email": user_info.get("email"),
177
226
  "name": user_info.get("name"),
178
227
  "picture": user_info.get("picture"),
228
+ # Add DB info
229
+ "id": db_info.get("id"),
230
+ "tenant_id": db_info.get("tenant_id", "default"),
231
+ "tier": db_info.get("tier"),
232
+ "roles": db_info.get("roles", []),
179
233
  }
180
234
 
181
235
  # Store tokens in session for API access
@@ -227,3 +281,73 @@ async def me(request: Request):
227
281
  raise HTTPException(status_code=401, detail="Not authenticated")
228
282
 
229
283
  return user
284
+
285
+
286
+ # =============================================================================
287
+ # Development Token Endpoints (non-production only)
288
+ # =============================================================================
289
+
290
+
291
def generate_dev_token() -> str:
    """
    Build the development bearer token used for testing.

    Token format: dev_<hmac_signature>
    The signature is the first 32 hex characters of an HMAC-SHA256 over a
    fixed message, keyed with the session secret. The result is therefore the
    same on every call for a given secret.

    NOTE(review): when no session secret is configured the key falls back to
    a hard-coded value, which makes the token publicly computable - confirm
    that callers only honour dev tokens outside production.
    """
    import hashlib
    import hmac

    # Session secret is the signing key (fixed fallback when unset).
    key = (settings.auth.session_secret or "dev-secret").encode()
    payload = "test-user:dev-token".encode()

    digest = hmac.new(key, payload, hashlib.sha256).hexdigest()
    return "dev_" + digest[:32]
312
+
313
+
314
def verify_dev_token(token: str) -> bool:
    """
    Verify a dev token is valid.

    Args:
        token: The candidate token, e.g. taken from a Bearer header.

    Returns:
        True only when ``token`` equals the value produced by
        ``generate_dev_token()``; False for anything else, including
        non-string input.
    """
    import hmac

    if not isinstance(token, str):
        return False

    expected = generate_dev_token()
    try:
        candidate = token.encode()
    except UnicodeEncodeError:
        # Text that cannot be encoded can never match the ASCII token.
        return False

    # Constant-time comparison so response timing does not leak how many
    # leading characters of a guess were correct. Bytes are compared because
    # compare_digest() rejects str values containing non-ASCII characters.
    return hmac.compare_digest(candidate, expected.encode())
318
+
319
+
320
@router.get("/dev/token")
async def get_dev_token(request: Request):
    """
    Issue the development token for testing (non-production only).

    The token can be sent as a Bearer token to authenticate as the
    test user (test-user / test@rem.local) without going through OAuth.

    Usage:
        curl -H "Authorization: Bearer <token>" http://localhost:8000/api/v1/...

    Returns:
        401 if in production environment
        Token and usage instructions otherwise
    """
    # Guard clause: never hand out dev tokens in production.
    in_production = settings.environment == "production"
    if in_production:
        raise HTTPException(
            status_code=401,
            detail="Dev tokens are not available in production",
        )

    dev_token = generate_dev_token()

    # Fixed identity that the dev token stands for.
    test_user = {
        "id": "test-user",
        "email": "test@rem.local",
        "name": "Test User",
    }
    usage_hint = f'curl -H "Authorization: Bearer {dev_token}" http://localhost:8000/api/v1/...'

    return {
        "token": dev_token,
        "type": "Bearer",
        "user": test_user,
        "usage": usage_hint,
        "warning": "This token is for development/testing only and will not work in production.",
    }