remdb 0.3.146__py3-none-any.whl → 0.3.163__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of remdb might be problematic. Click here for more details.
- rem/agentic/agents/__init__.py +16 -0
- rem/agentic/agents/agent_manager.py +310 -0
- rem/agentic/context.py +81 -3
- rem/agentic/context_builder.py +18 -3
- rem/api/deps.py +3 -5
- rem/api/main.py +22 -3
- rem/api/mcp_router/server.py +2 -0
- rem/api/mcp_router/tools.py +90 -0
- rem/api/middleware/tracking.py +5 -5
- rem/api/routers/auth.py +346 -5
- rem/api/routers/chat/completions.py +4 -2
- rem/api/routers/chat/streaming.py +77 -22
- rem/api/routers/messages.py +24 -15
- rem/auth/__init__.py +13 -3
- rem/auth/jwt.py +352 -0
- rem/auth/middleware.py +42 -5
- rem/auth/providers/__init__.py +4 -1
- rem/auth/providers/email.py +215 -0
- rem/models/entities/__init__.py +4 -0
- rem/models/entities/subscriber.py +175 -0
- rem/models/entities/user.py +1 -0
- rem/schemas/agents/core/agent-builder.yaml +134 -0
- rem/services/__init__.py +3 -1
- rem/services/content/service.py +4 -3
- rem/services/email/__init__.py +10 -0
- rem/services/email/service.py +511 -0
- rem/services/email/templates.py +360 -0
- rem/services/postgres/README.md +38 -0
- rem/services/postgres/diff_service.py +19 -3
- rem/services/postgres/pydantic_to_sqlalchemy.py +37 -2
- rem/services/postgres/repository.py +5 -4
- rem/services/session/compression.py +113 -50
- rem/services/session/reload.py +14 -7
- rem/services/user_service.py +29 -0
- rem/settings.py +175 -0
- rem/sql/migrations/005_schema_update.sql +145 -0
- {remdb-0.3.146.dist-info → remdb-0.3.163.dist-info}/METADATA +1 -1
- {remdb-0.3.146.dist-info → remdb-0.3.163.dist-info}/RECORD +40 -31
- {remdb-0.3.146.dist-info → remdb-0.3.163.dist-info}/WHEEL +0 -0
- {remdb-0.3.146.dist-info → remdb-0.3.163.dist-info}/entry_points.txt +0 -0
|
@@ -76,6 +76,9 @@ async def stream_openai_response(
|
|
|
76
76
|
agent_schema: str | None = None,
|
|
77
77
|
# Mutable container to capture trace context (deterministic, not AI-dependent)
|
|
78
78
|
trace_context_out: dict | None = None,
|
|
79
|
+
# Mutable container to capture tool calls for persistence
|
|
80
|
+
# Format: list of {"tool_name": str, "tool_id": str, "arguments": dict, "result": any}
|
|
81
|
+
tool_calls_out: list | None = None,
|
|
79
82
|
) -> AsyncGenerator[str, None]:
|
|
80
83
|
"""
|
|
81
84
|
Stream Pydantic AI agent responses with rich SSE events.
|
|
@@ -146,6 +149,9 @@ async def stream_openai_response(
|
|
|
146
149
|
pending_tool_completions: list[tuple[str, str]] = []
|
|
147
150
|
# Track if metadata was registered via register_metadata tool
|
|
148
151
|
metadata_registered = False
|
|
152
|
+
# Track pending tool calls with full data for persistence
|
|
153
|
+
# Maps tool_id -> {"tool_name": str, "tool_id": str, "arguments": dict}
|
|
154
|
+
pending_tool_data: dict[str, dict] = {}
|
|
149
155
|
|
|
150
156
|
try:
|
|
151
157
|
# Emit initial progress event
|
|
@@ -299,6 +305,13 @@ async def stream_openai_response(
|
|
|
299
305
|
arguments=args_dict
|
|
300
306
|
))
|
|
301
307
|
|
|
308
|
+
# Track tool call data for persistence (especially register_metadata)
|
|
309
|
+
pending_tool_data[tool_id] = {
|
|
310
|
+
"tool_name": tool_name,
|
|
311
|
+
"tool_id": tool_id,
|
|
312
|
+
"arguments": args_dict,
|
|
313
|
+
}
|
|
314
|
+
|
|
302
315
|
# Update progress
|
|
303
316
|
current_step = 2
|
|
304
317
|
total_steps = 4 # Added tool execution step
|
|
@@ -421,6 +434,15 @@ async def stream_openai_response(
|
|
|
421
434
|
hidden=False,
|
|
422
435
|
))
|
|
423
436
|
|
|
437
|
+
# Capture tool call with result for persistence
|
|
438
|
+
# Special handling for register_metadata - always capture full data
|
|
439
|
+
if tool_calls_out is not None and tool_id in pending_tool_data:
|
|
440
|
+
tool_data = pending_tool_data[tool_id]
|
|
441
|
+
tool_data["result"] = result_content
|
|
442
|
+
tool_data["is_metadata"] = is_metadata_event
|
|
443
|
+
tool_calls_out.append(tool_data)
|
|
444
|
+
del pending_tool_data[tool_id]
|
|
445
|
+
|
|
424
446
|
if not is_metadata_event:
|
|
425
447
|
# Normal tool completion - emit ToolCallEvent
|
|
426
448
|
result_str = str(result_content)
|
|
@@ -728,6 +750,9 @@ async def stream_openai_response_with_save(
|
|
|
728
750
|
# Accumulate content during streaming
|
|
729
751
|
accumulated_content = []
|
|
730
752
|
|
|
753
|
+
# Capture tool calls for persistence (especially register_metadata)
|
|
754
|
+
tool_calls: list = []
|
|
755
|
+
|
|
731
756
|
async for chunk in stream_openai_response(
|
|
732
757
|
agent=agent,
|
|
733
758
|
prompt=prompt,
|
|
@@ -737,6 +762,7 @@ async def stream_openai_response_with_save(
|
|
|
737
762
|
session_id=session_id,
|
|
738
763
|
message_id=message_id,
|
|
739
764
|
trace_context_out=trace_context, # Pass container to capture trace IDs
|
|
765
|
+
tool_calls_out=tool_calls, # Capture tool calls for persistence
|
|
740
766
|
):
|
|
741
767
|
yield chunk
|
|
742
768
|
|
|
@@ -755,28 +781,57 @@ async def stream_openai_response_with_save(
|
|
|
755
781
|
except (json.JSONDecodeError, KeyError, IndexError):
|
|
756
782
|
pass # Skip non-JSON or malformed chunks
|
|
757
783
|
|
|
758
|
-
# After streaming completes, save
|
|
759
|
-
|
|
760
|
-
|
|
784
|
+
# After streaming completes, save tool calls and assistant response
|
|
785
|
+
# Note: All messages stored UNCOMPRESSED. Compression happens on reload.
|
|
786
|
+
if settings.postgres.enabled and session_id:
|
|
761
787
|
# Get captured trace context from container (deterministically captured inside agent execution)
|
|
762
788
|
captured_trace_id = trace_context.get("trace_id")
|
|
763
789
|
captured_span_id = trace_context.get("span_id")
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
790
|
+
timestamp = to_iso(utc_now())
|
|
791
|
+
|
|
792
|
+
messages_to_store = []
|
|
793
|
+
|
|
794
|
+
# First, store tool call messages (message_type: "tool")
|
|
795
|
+
for tool_call in tool_calls:
|
|
796
|
+
tool_message = {
|
|
797
|
+
"role": "tool",
|
|
798
|
+
"content": json.dumps(tool_call.get("result", {}), default=str),
|
|
799
|
+
"timestamp": timestamp,
|
|
800
|
+
"trace_id": captured_trace_id,
|
|
801
|
+
"span_id": captured_span_id,
|
|
802
|
+
# Store tool call details in a way that can be reconstructed
|
|
803
|
+
"tool_call_id": tool_call.get("tool_id"),
|
|
804
|
+
"tool_name": tool_call.get("tool_name"),
|
|
805
|
+
"tool_arguments": tool_call.get("arguments"),
|
|
806
|
+
}
|
|
807
|
+
messages_to_store.append(tool_message)
|
|
808
|
+
|
|
809
|
+
# Then store assistant text response (if any)
|
|
810
|
+
if accumulated_content:
|
|
811
|
+
full_content = "".join(accumulated_content)
|
|
812
|
+
assistant_message = {
|
|
813
|
+
"id": message_id, # Use pre-generated ID for consistency with metadata event
|
|
814
|
+
"role": "assistant",
|
|
815
|
+
"content": full_content,
|
|
816
|
+
"timestamp": timestamp,
|
|
817
|
+
"trace_id": captured_trace_id,
|
|
818
|
+
"span_id": captured_span_id,
|
|
819
|
+
}
|
|
820
|
+
messages_to_store.append(assistant_message)
|
|
821
|
+
|
|
822
|
+
if messages_to_store:
|
|
823
|
+
try:
|
|
824
|
+
store = SessionMessageStore(user_id=user_id or settings.test.effective_user_id)
|
|
825
|
+
await store.store_session_messages(
|
|
826
|
+
session_id=session_id,
|
|
827
|
+
messages=messages_to_store,
|
|
828
|
+
user_id=user_id,
|
|
829
|
+
compress=False, # Store uncompressed; compression happens on reload
|
|
830
|
+
)
|
|
831
|
+
logger.debug(
|
|
832
|
+
f"Saved {len(tool_calls)} tool calls and "
|
|
833
|
+
f"{'assistant response' if accumulated_content else 'no text'} "
|
|
834
|
+
f"to session {session_id}"
|
|
835
|
+
)
|
|
836
|
+
except Exception as e:
|
|
837
|
+
logger.error(f"Failed to save session messages: {e}", exc_info=True)
|
rem/api/routers/messages.py
CHANGED
|
@@ -134,7 +134,6 @@ async def list_messages(
|
|
|
134
134
|
),
|
|
135
135
|
limit: int = Query(default=50, ge=1, le=100, description="Max results to return"),
|
|
136
136
|
offset: int = Query(default=0, ge=0, description="Offset for pagination"),
|
|
137
|
-
x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
|
|
138
137
|
) -> MessageListResponse:
|
|
139
138
|
"""
|
|
140
139
|
List messages with optional filters.
|
|
@@ -158,15 +157,18 @@ async def list_messages(
|
|
|
158
157
|
|
|
159
158
|
repo = Repository(Message, table_name="messages")
|
|
160
159
|
|
|
160
|
+
# Get current user for logging
|
|
161
|
+
current_user = get_current_user(request)
|
|
162
|
+
jwt_user_id = current_user.get("id") if current_user else None
|
|
163
|
+
|
|
161
164
|
# If mine=true, force filter to current user's ID from JWT
|
|
162
165
|
effective_user_id = user_id
|
|
163
166
|
if mine:
|
|
164
|
-
current_user = get_current_user(request)
|
|
165
167
|
if current_user:
|
|
166
168
|
effective_user_id = current_user.get("id")
|
|
167
169
|
|
|
168
170
|
# Build user-scoped filters (admin can see all, regular users see only their own)
|
|
169
|
-
filters = await get_user_filter(request, x_user_id=effective_user_id
|
|
171
|
+
filters = await get_user_filter(request, x_user_id=effective_user_id)
|
|
170
172
|
|
|
171
173
|
# Apply optional filters
|
|
172
174
|
if session_id:
|
|
@@ -174,6 +176,13 @@ async def list_messages(
|
|
|
174
176
|
if message_type:
|
|
175
177
|
filters["message_type"] = message_type
|
|
176
178
|
|
|
179
|
+
# Log the query parameters for debugging
|
|
180
|
+
logger.debug(
|
|
181
|
+
f"[messages] Query: session_id={session_id} | "
|
|
182
|
+
f"jwt_user_id={jwt_user_id} | "
|
|
183
|
+
f"filters={filters}"
|
|
184
|
+
)
|
|
185
|
+
|
|
177
186
|
# For date filtering, we need custom SQL (not supported by basic Repository)
|
|
178
187
|
# For now, fetch all matching base filters and filter in Python
|
|
179
188
|
# TODO: Extend Repository to support date range filters
|
|
@@ -206,6 +215,12 @@ async def list_messages(
|
|
|
206
215
|
# Get total count for pagination info
|
|
207
216
|
total = await repo.count(filters)
|
|
208
217
|
|
|
218
|
+
# Log result count
|
|
219
|
+
logger.debug(
|
|
220
|
+
f"[messages] Result: returned={len(messages)} | total={total} | "
|
|
221
|
+
f"session_id={session_id}"
|
|
222
|
+
)
|
|
223
|
+
|
|
209
224
|
return MessageListResponse(data=messages, total=total, has_more=has_more)
|
|
210
225
|
|
|
211
226
|
|
|
@@ -213,7 +228,6 @@ async def list_messages(
|
|
|
213
228
|
async def get_message(
|
|
214
229
|
request: Request,
|
|
215
230
|
message_id: str,
|
|
216
|
-
x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
|
|
217
231
|
) -> Message:
|
|
218
232
|
"""
|
|
219
233
|
Get a specific message by ID.
|
|
@@ -236,7 +250,7 @@ async def get_message(
|
|
|
236
250
|
raise HTTPException(status_code=503, detail="Database not enabled")
|
|
237
251
|
|
|
238
252
|
repo = Repository(Message, table_name="messages")
|
|
239
|
-
message = await repo.get_by_id(message_id
|
|
253
|
+
message = await repo.get_by_id(message_id)
|
|
240
254
|
|
|
241
255
|
if not message:
|
|
242
256
|
raise HTTPException(status_code=404, detail=f"Message '{message_id}' not found")
|
|
@@ -263,7 +277,6 @@ async def list_sessions(
|
|
|
263
277
|
mode: SessionMode | None = Query(default=None, description="Filter by session mode"),
|
|
264
278
|
page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
|
|
265
279
|
page_size: int = Query(default=50, ge=1, le=100, description="Number of results per page"),
|
|
266
|
-
x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
|
|
267
280
|
) -> SessionsQueryResponse:
|
|
268
281
|
"""
|
|
269
282
|
List sessions with optional filters and page-based pagination.
|
|
@@ -288,7 +301,7 @@ async def list_sessions(
|
|
|
288
301
|
repo = Repository(Session, table_name="sessions")
|
|
289
302
|
|
|
290
303
|
# Build user-scoped filters (admin can see all, regular users see only their own)
|
|
291
|
-
filters = await get_user_filter(request, x_user_id=user_id
|
|
304
|
+
filters = await get_user_filter(request, x_user_id=user_id)
|
|
292
305
|
if mode:
|
|
293
306
|
filters["mode"] = mode.value
|
|
294
307
|
|
|
@@ -319,7 +332,6 @@ async def create_session(
|
|
|
319
332
|
request_body: SessionCreateRequest,
|
|
320
333
|
user: dict = Depends(require_admin),
|
|
321
334
|
x_user_id: str = Header(alias="X-User-Id", default="default"),
|
|
322
|
-
x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
|
|
323
335
|
) -> Session:
|
|
324
336
|
"""
|
|
325
337
|
Create a new session.
|
|
@@ -334,7 +346,6 @@ async def create_session(
|
|
|
334
346
|
|
|
335
347
|
Headers:
|
|
336
348
|
- X-User-Id: User identifier (owner of the session)
|
|
337
|
-
- X-Tenant-Id: Tenant identifier
|
|
338
349
|
|
|
339
350
|
Returns:
|
|
340
351
|
Created session object
|
|
@@ -354,7 +365,7 @@ async def create_session(
|
|
|
354
365
|
prompt=request_body.prompt,
|
|
355
366
|
agent_schema_uri=request_body.agent_schema_uri,
|
|
356
367
|
user_id=effective_user_id,
|
|
357
|
-
tenant_id=
|
|
368
|
+
tenant_id="default", # tenant_id not used for filtering, set to default
|
|
358
369
|
)
|
|
359
370
|
|
|
360
371
|
repo = Repository(Session, table_name="sessions")
|
|
@@ -372,7 +383,6 @@ async def create_session(
|
|
|
372
383
|
async def get_session(
|
|
373
384
|
request: Request,
|
|
374
385
|
session_id: str,
|
|
375
|
-
x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
|
|
376
386
|
) -> Session:
|
|
377
387
|
"""
|
|
378
388
|
Get a specific session by ID.
|
|
@@ -395,11 +405,11 @@ async def get_session(
|
|
|
395
405
|
raise HTTPException(status_code=503, detail="Database not enabled")
|
|
396
406
|
|
|
397
407
|
repo = Repository(Session, table_name="sessions")
|
|
398
|
-
session = await repo.get_by_id(session_id
|
|
408
|
+
session = await repo.get_by_id(session_id)
|
|
399
409
|
|
|
400
410
|
if not session:
|
|
401
411
|
# Try finding by name
|
|
402
|
-
sessions = await repo.find({"name": session_id
|
|
412
|
+
sessions = await repo.find({"name": session_id}, limit=1)
|
|
403
413
|
if sessions:
|
|
404
414
|
session = sessions[0]
|
|
405
415
|
else:
|
|
@@ -420,7 +430,6 @@ async def update_session(
|
|
|
420
430
|
request: Request,
|
|
421
431
|
session_id: str,
|
|
422
432
|
request_body: SessionUpdateRequest,
|
|
423
|
-
x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
|
|
424
433
|
) -> Session:
|
|
425
434
|
"""
|
|
426
435
|
Update an existing session.
|
|
@@ -450,7 +459,7 @@ async def update_session(
|
|
|
450
459
|
raise HTTPException(status_code=503, detail="Database not enabled")
|
|
451
460
|
|
|
452
461
|
repo = Repository(Session, table_name="sessions")
|
|
453
|
-
session = await repo.get_by_id(session_id
|
|
462
|
+
session = await repo.get_by_id(session_id)
|
|
454
463
|
|
|
455
464
|
if not session:
|
|
456
465
|
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
rem/auth/__init__.py
CHANGED
|
@@ -1,26 +1,36 @@
|
|
|
1
1
|
"""
|
|
2
2
|
REM Authentication Module.
|
|
3
3
|
|
|
4
|
-
|
|
4
|
+
Authentication with support for:
|
|
5
|
+
- Email passwordless login (verification codes)
|
|
5
6
|
- Google OAuth
|
|
6
7
|
- Microsoft Entra ID (Azure AD) OIDC
|
|
7
8
|
- Custom OIDC providers
|
|
8
9
|
|
|
9
10
|
Design Pattern:
|
|
10
11
|
- Provider-agnostic base classes
|
|
11
|
-
- PKCE (Proof Key for Code Exchange) for
|
|
12
|
+
- PKCE (Proof Key for Code Exchange) for OAuth flows
|
|
12
13
|
- State parameter for CSRF protection
|
|
13
14
|
- Nonce for ID token replay protection
|
|
14
15
|
- Token validation with JWKS
|
|
15
|
-
- Clean separation: providers/ for
|
|
16
|
+
- Clean separation: providers/ for auth logic, middleware.py for FastAPI integration
|
|
17
|
+
|
|
18
|
+
Email Auth Flow:
|
|
19
|
+
1. POST /api/auth/email/send-code with {email}
|
|
20
|
+
2. User receives code via email
|
|
21
|
+
3. POST /api/auth/email/verify with {email, code}
|
|
22
|
+
4. Session created, user authenticated
|
|
16
23
|
"""
|
|
17
24
|
|
|
18
25
|
from .providers.base import OAuthProvider
|
|
26
|
+
from .providers.email import EmailAuthProvider, EmailAuthResult
|
|
19
27
|
from .providers.google import GoogleOAuthProvider
|
|
20
28
|
from .providers.microsoft import MicrosoftOAuthProvider
|
|
21
29
|
|
|
22
30
|
__all__ = [
|
|
23
31
|
"OAuthProvider",
|
|
32
|
+
"EmailAuthProvider",
|
|
33
|
+
"EmailAuthResult",
|
|
24
34
|
"GoogleOAuthProvider",
|
|
25
35
|
"MicrosoftOAuthProvider",
|
|
26
36
|
]
|
rem/auth/jwt.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JWT Token Service for REM Authentication.
|
|
3
|
+
|
|
4
|
+
Provides JWT token generation and validation for stateless authentication.
|
|
5
|
+
Uses HS256 algorithm with the session secret for signing.
|
|
6
|
+
|
|
7
|
+
Token Types:
|
|
8
|
+
- Access Token: Short-lived (default 1 hour), used for API authentication
|
|
9
|
+
- Refresh Token: Long-lived (default 7 days), used to obtain new access tokens
|
|
10
|
+
|
|
11
|
+
Token Claims:
|
|
12
|
+
- sub: User ID (UUID string)
|
|
13
|
+
- email: User email
|
|
14
|
+
- name: User display name
|
|
15
|
+
- role: User role (user, admin)
|
|
16
|
+
- tier: User subscription tier
|
|
17
|
+
- roles: List of roles for authorization
|
|
18
|
+
- provider: Auth provider (email, google, microsoft)
|
|
19
|
+
- tenant_id: Tenant identifier for multi-tenancy
|
|
20
|
+
- exp: Expiration timestamp
|
|
21
|
+
- iat: Issued at timestamp
|
|
22
|
+
- type: Token type (access, refresh)
|
|
23
|
+
|
|
24
|
+
Usage:
|
|
25
|
+
from rem.auth.jwt import JWTService
|
|
26
|
+
|
|
27
|
+
jwt_service = JWTService()
|
|
28
|
+
|
|
29
|
+
# Generate tokens after successful authentication
|
|
30
|
+
tokens = jwt_service.create_tokens(user_dict)
|
|
31
|
+
# Returns: {"access_token": "...", "refresh_token": "...", "token_type": "bearer", "expires_in": 3600}
|
|
32
|
+
|
|
33
|
+
# Validate token from Authorization header
|
|
34
|
+
user = jwt_service.verify_token(token)
|
|
35
|
+
# Returns user dict or None if invalid
|
|
36
|
+
|
|
37
|
+
# Refresh access token
|
|
38
|
+
new_tokens = jwt_service.refresh_access_token(refresh_token)
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
import time
|
|
42
|
+
import hmac
|
|
43
|
+
import hashlib
|
|
44
|
+
import base64
|
|
45
|
+
import json
|
|
46
|
+
from datetime import datetime, timezone
|
|
47
|
+
from typing import Optional
|
|
48
|
+
|
|
49
|
+
from loguru import logger
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class JWTService:
|
|
53
|
+
"""
|
|
54
|
+
JWT token service for authentication.
|
|
55
|
+
|
|
56
|
+
Uses HMAC-SHA256 for signing - simple and secure for single-service deployment.
|
|
57
|
+
For multi-service deployments, consider switching to RS256 with public/private keys.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
def __init__(
|
|
61
|
+
self,
|
|
62
|
+
secret: str | None = None,
|
|
63
|
+
access_token_expiry_seconds: int = 3600, # 1 hour
|
|
64
|
+
refresh_token_expiry_seconds: int = 604800, # 7 days
|
|
65
|
+
issuer: str = "rem",
|
|
66
|
+
):
|
|
67
|
+
"""
|
|
68
|
+
Initialize JWT service.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
secret: Secret key for signing (uses settings.auth.session_secret if not provided)
|
|
72
|
+
access_token_expiry_seconds: Access token lifetime in seconds
|
|
73
|
+
refresh_token_expiry_seconds: Refresh token lifetime in seconds
|
|
74
|
+
issuer: Token issuer identifier
|
|
75
|
+
"""
|
|
76
|
+
if secret:
|
|
77
|
+
self._secret = secret
|
|
78
|
+
else:
|
|
79
|
+
from ..settings import settings
|
|
80
|
+
self._secret = settings.auth.session_secret
|
|
81
|
+
|
|
82
|
+
self._access_expiry = access_token_expiry_seconds
|
|
83
|
+
self._refresh_expiry = refresh_token_expiry_seconds
|
|
84
|
+
self._issuer = issuer
|
|
85
|
+
|
|
86
|
+
def _base64url_encode(self, data: bytes) -> str:
|
|
87
|
+
"""Base64url encode without padding."""
|
|
88
|
+
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8")
|
|
89
|
+
|
|
90
|
+
def _base64url_decode(self, data: str) -> bytes:
|
|
91
|
+
"""Base64url decode with padding restoration."""
|
|
92
|
+
padding = 4 - len(data) % 4
|
|
93
|
+
if padding != 4:
|
|
94
|
+
data += "=" * padding
|
|
95
|
+
return base64.urlsafe_b64decode(data)
|
|
96
|
+
|
|
97
|
+
def _sign(self, message: str) -> str:
|
|
98
|
+
"""Create HMAC-SHA256 signature."""
|
|
99
|
+
signature = hmac.new(
|
|
100
|
+
self._secret.encode("utf-8"),
|
|
101
|
+
message.encode("utf-8"),
|
|
102
|
+
hashlib.sha256
|
|
103
|
+
).digest()
|
|
104
|
+
return self._base64url_encode(signature)
|
|
105
|
+
|
|
106
|
+
def _create_token(self, payload: dict) -> str:
|
|
107
|
+
"""
|
|
108
|
+
Create a JWT token.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
payload: Token claims
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
Encoded JWT string
|
|
115
|
+
"""
|
|
116
|
+
header = {"alg": "HS256", "typ": "JWT"}
|
|
117
|
+
|
|
118
|
+
header_encoded = self._base64url_encode(json.dumps(header, separators=(",", ":")).encode())
|
|
119
|
+
payload_encoded = self._base64url_encode(json.dumps(payload, separators=(",", ":")).encode())
|
|
120
|
+
|
|
121
|
+
message = f"{header_encoded}.{payload_encoded}"
|
|
122
|
+
signature = self._sign(message)
|
|
123
|
+
|
|
124
|
+
return f"{message}.{signature}"
|
|
125
|
+
|
|
126
|
+
def _verify_signature(self, token: str) -> dict | None:
|
|
127
|
+
"""
|
|
128
|
+
Verify token signature and decode payload.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
token: JWT token string
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
Decoded payload dict or None if invalid
|
|
135
|
+
"""
|
|
136
|
+
try:
|
|
137
|
+
parts = token.split(".")
|
|
138
|
+
if len(parts) != 3:
|
|
139
|
+
return None
|
|
140
|
+
|
|
141
|
+
header_encoded, payload_encoded, signature = parts
|
|
142
|
+
|
|
143
|
+
# Verify signature
|
|
144
|
+
message = f"{header_encoded}.{payload_encoded}"
|
|
145
|
+
expected_signature = self._sign(message)
|
|
146
|
+
|
|
147
|
+
if not hmac.compare_digest(signature, expected_signature):
|
|
148
|
+
logger.debug("JWT signature verification failed")
|
|
149
|
+
return None
|
|
150
|
+
|
|
151
|
+
# Decode payload
|
|
152
|
+
payload = json.loads(self._base64url_decode(payload_encoded))
|
|
153
|
+
return payload
|
|
154
|
+
|
|
155
|
+
except Exception as e:
|
|
156
|
+
logger.debug(f"JWT decode error: {e}")
|
|
157
|
+
return None
|
|
158
|
+
|
|
159
|
+
def create_tokens(
|
|
160
|
+
self,
|
|
161
|
+
user: dict,
|
|
162
|
+
access_expiry: int | None = None,
|
|
163
|
+
refresh_expiry: int | None = None,
|
|
164
|
+
) -> dict:
|
|
165
|
+
"""
|
|
166
|
+
Create access and refresh tokens for a user.
|
|
167
|
+
|
|
168
|
+
Args:
|
|
169
|
+
user: User dict with id, email, name, role, tier, roles, provider, tenant_id
|
|
170
|
+
access_expiry: Override access token expiry (seconds)
|
|
171
|
+
refresh_expiry: Override refresh token expiry (seconds)
|
|
172
|
+
|
|
173
|
+
Returns:
|
|
174
|
+
Dict with access_token, refresh_token, token_type, expires_in
|
|
175
|
+
"""
|
|
176
|
+
now = int(time.time())
|
|
177
|
+
access_exp = access_expiry or self._access_expiry
|
|
178
|
+
refresh_exp = refresh_expiry or self._refresh_expiry
|
|
179
|
+
|
|
180
|
+
# Common claims
|
|
181
|
+
base_claims = {
|
|
182
|
+
"sub": user.get("id", ""),
|
|
183
|
+
"email": user.get("email", ""),
|
|
184
|
+
"name": user.get("name", ""),
|
|
185
|
+
"role": user.get("role"),
|
|
186
|
+
"tier": user.get("tier", "free"),
|
|
187
|
+
"roles": user.get("roles", ["user"]),
|
|
188
|
+
"provider": user.get("provider", "email"),
|
|
189
|
+
"tenant_id": user.get("tenant_id", "default"),
|
|
190
|
+
"iss": self._issuer,
|
|
191
|
+
"iat": now,
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
# Access token
|
|
195
|
+
access_payload = {
|
|
196
|
+
**base_claims,
|
|
197
|
+
"type": "access",
|
|
198
|
+
"exp": now + access_exp,
|
|
199
|
+
}
|
|
200
|
+
access_token = self._create_token(access_payload)
|
|
201
|
+
|
|
202
|
+
# Refresh token (minimal claims for security)
|
|
203
|
+
refresh_payload = {
|
|
204
|
+
"sub": user.get("id", ""),
|
|
205
|
+
"email": user.get("email", ""),
|
|
206
|
+
"type": "refresh",
|
|
207
|
+
"iss": self._issuer,
|
|
208
|
+
"iat": now,
|
|
209
|
+
"exp": now + refresh_exp,
|
|
210
|
+
}
|
|
211
|
+
refresh_token = self._create_token(refresh_payload)
|
|
212
|
+
|
|
213
|
+
return {
|
|
214
|
+
"access_token": access_token,
|
|
215
|
+
"refresh_token": refresh_token,
|
|
216
|
+
"token_type": "bearer",
|
|
217
|
+
"expires_in": access_exp,
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
def verify_token(self, token: str, token_type: str = "access") -> dict | None:
|
|
221
|
+
"""
|
|
222
|
+
Verify a token and return user claims.
|
|
223
|
+
|
|
224
|
+
Args:
|
|
225
|
+
token: JWT token string
|
|
226
|
+
token_type: Expected token type ("access" or "refresh")
|
|
227
|
+
|
|
228
|
+
Returns:
|
|
229
|
+
User dict with claims or None if invalid/expired
|
|
230
|
+
"""
|
|
231
|
+
payload = self._verify_signature(token)
|
|
232
|
+
if not payload:
|
|
233
|
+
return None
|
|
234
|
+
|
|
235
|
+
# Check token type
|
|
236
|
+
if payload.get("type") != token_type:
|
|
237
|
+
logger.debug(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
|
|
238
|
+
return None
|
|
239
|
+
|
|
240
|
+
# Check expiration
|
|
241
|
+
exp = payload.get("exp", 0)
|
|
242
|
+
if exp < time.time():
|
|
243
|
+
logger.debug("Token expired")
|
|
244
|
+
return None
|
|
245
|
+
|
|
246
|
+
# Check issuer
|
|
247
|
+
if payload.get("iss") != self._issuer:
|
|
248
|
+
logger.debug(f"Token issuer mismatch: expected {self._issuer}, got {payload.get('iss')}")
|
|
249
|
+
return None
|
|
250
|
+
|
|
251
|
+
# Return user dict (compatible with session user format)
|
|
252
|
+
return {
|
|
253
|
+
"id": payload.get("sub"),
|
|
254
|
+
"email": payload.get("email"),
|
|
255
|
+
"name": payload.get("name"),
|
|
256
|
+
"role": payload.get("role"),
|
|
257
|
+
"tier": payload.get("tier", "free"),
|
|
258
|
+
"roles": payload.get("roles", ["user"]),
|
|
259
|
+
"provider": payload.get("provider", "email"),
|
|
260
|
+
"tenant_id": payload.get("tenant_id", "default"),
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
def refresh_access_token(self, refresh_token: str) -> dict | None:
|
|
264
|
+
"""
|
|
265
|
+
Create new access token using refresh token.
|
|
266
|
+
|
|
267
|
+
Args:
|
|
268
|
+
refresh_token: Valid refresh token
|
|
269
|
+
|
|
270
|
+
Returns:
|
|
271
|
+
New token dict or None if refresh token is invalid
|
|
272
|
+
"""
|
|
273
|
+
# Verify refresh token
|
|
274
|
+
payload = self._verify_signature(refresh_token)
|
|
275
|
+
if not payload:
|
|
276
|
+
return None
|
|
277
|
+
|
|
278
|
+
if payload.get("type") != "refresh":
|
|
279
|
+
logger.debug("Not a refresh token")
|
|
280
|
+
return None
|
|
281
|
+
|
|
282
|
+
# Check expiration
|
|
283
|
+
exp = payload.get("exp", 0)
|
|
284
|
+
if exp < time.time():
|
|
285
|
+
logger.debug("Refresh token expired")
|
|
286
|
+
return None
|
|
287
|
+
|
|
288
|
+
# Create new access token with minimal info from refresh token
|
|
289
|
+
# In production, you'd look up the full user from database
|
|
290
|
+
user = {
|
|
291
|
+
"id": payload.get("sub"),
|
|
292
|
+
"email": payload.get("email"),
|
|
293
|
+
"name": payload.get("email", "").split("@")[0],
|
|
294
|
+
"provider": "email",
|
|
295
|
+
"tenant_id": "default",
|
|
296
|
+
"tier": "free",
|
|
297
|
+
"roles": ["user"],
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
# Only return new access token, keep same refresh token
|
|
301
|
+
now = int(time.time())
|
|
302
|
+
access_payload = {
|
|
303
|
+
"sub": user["id"],
|
|
304
|
+
"email": user["email"],
|
|
305
|
+
"name": user["name"],
|
|
306
|
+
"role": user.get("role"),
|
|
307
|
+
"tier": user["tier"],
|
|
308
|
+
"roles": user["roles"],
|
|
309
|
+
"provider": user["provider"],
|
|
310
|
+
"tenant_id": user["tenant_id"],
|
|
311
|
+
"type": "access",
|
|
312
|
+
"iss": self._issuer,
|
|
313
|
+
"iat": now,
|
|
314
|
+
"exp": now + self._access_expiry,
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
return {
|
|
318
|
+
"access_token": self._create_token(access_payload),
|
|
319
|
+
"token_type": "bearer",
|
|
320
|
+
"expires_in": self._access_expiry,
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
def decode_without_verification(self, token: str) -> dict | None:
|
|
324
|
+
"""
|
|
325
|
+
Decode token without verification (for debugging only).
|
|
326
|
+
|
|
327
|
+
Args:
|
|
328
|
+
token: JWT token string
|
|
329
|
+
|
|
330
|
+
Returns:
|
|
331
|
+
Decoded payload or None
|
|
332
|
+
"""
|
|
333
|
+
try:
|
|
334
|
+
parts = token.split(".")
|
|
335
|
+
if len(parts) != 3:
|
|
336
|
+
return None
|
|
337
|
+
payload = json.loads(self._base64url_decode(parts[1]))
|
|
338
|
+
return payload
|
|
339
|
+
except Exception:
|
|
340
|
+
return None
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
# Lazily-created, process-wide JWTService shared by get_jwt_service().
_jwt_service: Optional[JWTService] = None


def get_jwt_service() -> JWTService:
    """
    Return the shared JWTService, constructing it on first use.

    The instance is built with default arguments, so its signing secret is
    taken from settings.auth.session_secret. Later calls return the same object.
    """
    global _jwt_service
    if _jwt_service is not None:
        return _jwt_service
    _jwt_service = JWTService()
    return _jwt_service
|