remdb 0.3.141__py3-none-any.whl → 0.3.163__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of remdb might be problematic.

Files changed (44)
  1. rem/agentic/agents/__init__.py +16 -0
  2. rem/agentic/agents/agent_manager.py +310 -0
  3. rem/agentic/context.py +81 -3
  4. rem/agentic/context_builder.py +18 -3
  5. rem/api/deps.py +3 -5
  6. rem/api/main.py +22 -3
  7. rem/api/mcp_router/server.py +2 -0
  8. rem/api/mcp_router/tools.py +90 -0
  9. rem/api/middleware/tracking.py +5 -5
  10. rem/api/routers/auth.py +346 -5
  11. rem/api/routers/chat/completions.py +4 -2
  12. rem/api/routers/chat/streaming.py +77 -22
  13. rem/api/routers/messages.py +24 -15
  14. rem/auth/__init__.py +13 -3
  15. rem/auth/jwt.py +352 -0
  16. rem/auth/middleware.py +108 -6
  17. rem/auth/providers/__init__.py +4 -1
  18. rem/auth/providers/email.py +215 -0
  19. rem/cli/commands/experiments.py +32 -46
  20. rem/models/core/experiment.py +4 -14
  21. rem/models/entities/__init__.py +4 -0
  22. rem/models/entities/subscriber.py +175 -0
  23. rem/models/entities/user.py +1 -0
  24. rem/schemas/agents/core/agent-builder.yaml +134 -0
  25. rem/services/__init__.py +3 -1
  26. rem/services/content/service.py +4 -3
  27. rem/services/email/__init__.py +10 -0
  28. rem/services/email/service.py +511 -0
  29. rem/services/email/templates.py +360 -0
  30. rem/services/postgres/README.md +38 -0
  31. rem/services/postgres/diff_service.py +19 -3
  32. rem/services/postgres/pydantic_to_sqlalchemy.py +45 -13
  33. rem/services/postgres/repository.py +5 -4
  34. rem/services/session/compression.py +113 -50
  35. rem/services/session/reload.py +14 -7
  36. rem/services/user_service.py +29 -0
  37. rem/settings.py +199 -4
  38. rem/sql/migrations/005_schema_update.sql +145 -0
  39. rem/utils/README.md +45 -0
  40. rem/utils/files.py +157 -1
  41. {remdb-0.3.141.dist-info → remdb-0.3.163.dist-info}/METADATA +7 -5
  42. {remdb-0.3.141.dist-info → remdb-0.3.163.dist-info}/RECORD +44 -35
  43. {remdb-0.3.141.dist-info → remdb-0.3.163.dist-info}/WHEEL +0 -0
  44. {remdb-0.3.141.dist-info → remdb-0.3.163.dist-info}/entry_points.txt +0 -0
rem/api/routers/chat/streaming.py CHANGED
@@ -76,6 +76,9 @@ async def stream_openai_response(
  agent_schema: str | None = None,
  # Mutable container to capture trace context (deterministic, not AI-dependent)
  trace_context_out: dict | None = None,
+ # Mutable container to capture tool calls for persistence
+ # Format: list of {"tool_name": str, "tool_id": str, "arguments": dict, "result": any}
+ tool_calls_out: list | None = None,
  ) -> AsyncGenerator[str, None]:
  """
  Stream Pydantic AI agent responses with rich SSE events.
@@ -146,6 +149,9 @@ async def stream_openai_response(
  pending_tool_completions: list[tuple[str, str]] = []
  # Track if metadata was registered via register_metadata tool
  metadata_registered = False
+ # Track pending tool calls with full data for persistence
+ # Maps tool_id -> {"tool_name": str, "tool_id": str, "arguments": dict}
+ pending_tool_data: dict[str, dict] = {}

  try:
  # Emit initial progress event
@@ -299,6 +305,13 @@ async def stream_openai_response(
  arguments=args_dict
  ))

+ # Track tool call data for persistence (especially register_metadata)
+ pending_tool_data[tool_id] = {
+ "tool_name": tool_name,
+ "tool_id": tool_id,
+ "arguments": args_dict,
+ }
+
  # Update progress
  current_step = 2
  total_steps = 4 # Added tool execution step
@@ -421,6 +434,15 @@ async def stream_openai_response(
  hidden=False,
  ))

+ # Capture tool call with result for persistence
+ # Special handling for register_metadata - always capture full data
+ if tool_calls_out is not None and tool_id in pending_tool_data:
+ tool_data = pending_tool_data[tool_id]
+ tool_data["result"] = result_content
+ tool_data["is_metadata"] = is_metadata_event
+ tool_calls_out.append(tool_data)
+ del pending_tool_data[tool_id]
+
  if not is_metadata_event:
  # Normal tool completion - emit ToolCallEvent
  result_str = str(result_content)
@@ -728,6 +750,9 @@ async def stream_openai_response_with_save(
  # Accumulate content during streaming
  accumulated_content = []

+ # Capture tool calls for persistence (especially register_metadata)
+ tool_calls: list = []
+
  async for chunk in stream_openai_response(
  agent=agent,
  prompt=prompt,
@@ -737,6 +762,7 @@ async def stream_openai_response_with_save(
  session_id=session_id,
  message_id=message_id,
  trace_context_out=trace_context, # Pass container to capture trace IDs
+ tool_calls_out=tool_calls, # Capture tool calls for persistence
  ):
  yield chunk

@@ -755,28 +781,57 @@ async def stream_openai_response_with_save(
  except (json.JSONDecodeError, KeyError, IndexError):
  pass # Skip non-JSON or malformed chunks

- # After streaming completes, save the assistant response
- if settings.postgres.enabled and session_id and accumulated_content:
- full_content = "".join(accumulated_content)
+ # After streaming completes, save tool calls and assistant response
+ # Note: All messages stored UNCOMPRESSED. Compression happens on reload.
+ if settings.postgres.enabled and session_id:
  # Get captured trace context from container (deterministically captured inside agent execution)
  captured_trace_id = trace_context.get("trace_id")
  captured_span_id = trace_context.get("span_id")
- assistant_message = {
- "id": message_id, # Use pre-generated ID for consistency with metadata event
- "role": "assistant",
- "content": full_content,
- "timestamp": to_iso(utc_now()),
- "trace_id": captured_trace_id,
- "span_id": captured_span_id,
- }
- try:
- store = SessionMessageStore(user_id=user_id or settings.test.effective_user_id)
- await store.store_session_messages(
- session_id=session_id,
- messages=[assistant_message],
- user_id=user_id,
- compress=True, # Compress long assistant responses
- )
- logger.debug(f"Saved assistant response {message_id} to session {session_id} ({len(full_content)} chars)")
- except Exception as e:
- logger.error(f"Failed to save assistant response: {e}", exc_info=True)
+ timestamp = to_iso(utc_now())
+
+ messages_to_store = []
+
+ # First, store tool call messages (message_type: "tool")
+ for tool_call in tool_calls:
+ tool_message = {
+ "role": "tool",
+ "content": json.dumps(tool_call.get("result", {}), default=str),
+ "timestamp": timestamp,
+ "trace_id": captured_trace_id,
+ "span_id": captured_span_id,
+ # Store tool call details in a way that can be reconstructed
+ "tool_call_id": tool_call.get("tool_id"),
+ "tool_name": tool_call.get("tool_name"),
+ "tool_arguments": tool_call.get("arguments"),
+ }
+ messages_to_store.append(tool_message)
+
+ # Then store assistant text response (if any)
+ if accumulated_content:
+ full_content = "".join(accumulated_content)
+ assistant_message = {
+ "id": message_id, # Use pre-generated ID for consistency with metadata event
+ "role": "assistant",
+ "content": full_content,
+ "timestamp": timestamp,
+ "trace_id": captured_trace_id,
+ "span_id": captured_span_id,
+ }
+ messages_to_store.append(assistant_message)
+
+ if messages_to_store:
+ try:
+ store = SessionMessageStore(user_id=user_id or settings.test.effective_user_id)
+ await store.store_session_messages(
+ session_id=session_id,
+ messages=messages_to_store,
+ user_id=user_id,
+ compress=False, # Store uncompressed; compression happens on reload
+ )
+ logger.debug(
+ f"Saved {len(tool_calls)} tool calls and "
+ f"{'assistant response' if accumulated_content else 'no text'} "
+ f"to session {session_id}"
+ )
+ except Exception as e:
+ logger.error(f"Failed to save session messages: {e}", exc_info=True)
rem/api/routers/messages.py CHANGED
@@ -134,7 +134,6 @@ async def list_messages(
  ),
  limit: int = Query(default=50, ge=1, le=100, description="Max results to return"),
  offset: int = Query(default=0, ge=0, description="Offset for pagination"),
- x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
  ) -> MessageListResponse:
  """
  List messages with optional filters.
@@ -158,15 +157,18 @@ async def list_messages(

  repo = Repository(Message, table_name="messages")

+ # Get current user for logging
+ current_user = get_current_user(request)
+ jwt_user_id = current_user.get("id") if current_user else None
+
  # If mine=true, force filter to current user's ID from JWT
  effective_user_id = user_id
  if mine:
- current_user = get_current_user(request)
  if current_user:
  effective_user_id = current_user.get("id")

  # Build user-scoped filters (admin can see all, regular users see only their own)
- filters = await get_user_filter(request, x_user_id=effective_user_id, x_tenant_id=x_tenant_id)
+ filters = await get_user_filter(request, x_user_id=effective_user_id)

  # Apply optional filters
  if session_id:
@@ -174,6 +176,13 @@ async def list_messages(
  if message_type:
  filters["message_type"] = message_type

+ # Log the query parameters for debugging
+ logger.debug(
+ f"[messages] Query: session_id={session_id} | "
+ f"jwt_user_id={jwt_user_id} | "
+ f"filters={filters}"
+ )
+
  # For date filtering, we need custom SQL (not supported by basic Repository)
  # For now, fetch all matching base filters and filter in Python
  # TODO: Extend Repository to support date range filters
@@ -206,6 +215,12 @@ async def list_messages(
  # Get total count for pagination info
  total = await repo.count(filters)

+ # Log result count
+ logger.debug(
+ f"[messages] Result: returned={len(messages)} | total={total} | "
+ f"session_id={session_id}"
+ )
+
  return MessageListResponse(data=messages, total=total, has_more=has_more)


@@ -213,7 +228,6 @@ async def list_messages(
  async def get_message(
  request: Request,
  message_id: str,
- x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
  ) -> Message:
  """
  Get a specific message by ID.
@@ -236,7 +250,7 @@ async def get_message(
  raise HTTPException(status_code=503, detail="Database not enabled")

  repo = Repository(Message, table_name="messages")
- message = await repo.get_by_id(message_id, x_tenant_id)
+ message = await repo.get_by_id(message_id)

  if not message:
  raise HTTPException(status_code=404, detail=f"Message '{message_id}' not found")
@@ -263,7 +277,6 @@ async def list_sessions(
  mode: SessionMode | None = Query(default=None, description="Filter by session mode"),
  page: int = Query(default=1, ge=1, description="Page number (1-indexed)"),
  page_size: int = Query(default=50, ge=1, le=100, description="Number of results per page"),
- x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
  ) -> SessionsQueryResponse:
  """
  List sessions with optional filters and page-based pagination.
@@ -288,7 +301,7 @@ async def list_sessions(
  repo = Repository(Session, table_name="sessions")

  # Build user-scoped filters (admin can see all, regular users see only their own)
- filters = await get_user_filter(request, x_user_id=user_id, x_tenant_id=x_tenant_id)
+ filters = await get_user_filter(request, x_user_id=user_id)
  if mode:
  filters["mode"] = mode.value

@@ -319,7 +332,6 @@ async def create_session(
  request_body: SessionCreateRequest,
  user: dict = Depends(require_admin),
  x_user_id: str = Header(alias="X-User-Id", default="default"),
- x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
  ) -> Session:
  """
  Create a new session.
@@ -334,7 +346,6 @@ async def create_session(

  Headers:
  - X-User-Id: User identifier (owner of the session)
- - X-Tenant-Id: Tenant identifier

  Returns:
  Created session object
@@ -354,7 +365,7 @@ async def create_session(
  prompt=request_body.prompt,
  agent_schema_uri=request_body.agent_schema_uri,
  user_id=effective_user_id,
- tenant_id=x_tenant_id,
+ tenant_id="default", # tenant_id not used for filtering, set to default
  )

  repo = Repository(Session, table_name="sessions")
@@ -372,7 +383,6 @@ async def get_session(
  async def get_session(
  request: Request,
  session_id: str,
- x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
  ) -> Session:
  """
  Get a specific session by ID.
@@ -395,11 +405,11 @@ async def get_session(
  raise HTTPException(status_code=503, detail="Database not enabled")

  repo = Repository(Session, table_name="sessions")
- session = await repo.get_by_id(session_id, x_tenant_id)
+ session = await repo.get_by_id(session_id)

  if not session:
  # Try finding by name
- sessions = await repo.find({"name": session_id, "tenant_id": x_tenant_id}, limit=1)
+ sessions = await repo.find({"name": session_id}, limit=1)
  if sessions:
  session = sessions[0]
  else:
@@ -420,7 +430,6 @@ async def update_session(
  request: Request,
  session_id: str,
  request_body: SessionUpdateRequest,
- x_tenant_id: str = Header(alias="X-Tenant-Id", default="default"),
  ) -> Session:
  """
  Update an existing session.
@@ -450,7 +459,7 @@ async def update_session(
  raise HTTPException(status_code=503, detail="Database not enabled")

  repo = Repository(Session, table_name="sessions")
- session = await repo.get_by_id(session_id, x_tenant_id)
+ session = await repo.get_by_id(session_id)

  if not session:
  raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
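
These routes no longer read an X-Tenant-Id header; scoping comes from the JWT bearer token via get_current_user and get_user_filter. A hypothetical client call (the router prefix is not visible in this diff, so the /api/messages path is an assumption; the query parameters and response fields match the handler above):

    import httpx

    async def list_my_messages(base_url: str, access_token: str, session_id: str) -> dict:
        # Hypothetical path; only the bearer token is needed, no X-Tenant-Id header
        async with httpx.AsyncClient(base_url=base_url) as client:
            resp = await client.get(
                "/api/messages",
                params={"session_id": session_id, "mine": "true", "limit": 50},
                headers={"Authorization": f"Bearer {access_token}"},
            )
            resp.raise_for_status()
            return resp.json()  # MessageListResponse: {"data": [...], "total": ..., "has_more": ...}
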
rem/auth/__init__.py CHANGED
@@ -1,26 +1,36 @@
  """
  REM Authentication Module.

- OAuth 2.1 compliant authentication with support for:
+ Authentication with support for:
+ - Email passwordless login (verification codes)
  - Google OAuth
  - Microsoft Entra ID (Azure AD) OIDC
  - Custom OIDC providers

  Design Pattern:
  - Provider-agnostic base classes
- - PKCE (Proof Key for Code Exchange) for all flows
+ - PKCE (Proof Key for Code Exchange) for OAuth flows
  - State parameter for CSRF protection
  - Nonce for ID token replay protection
  - Token validation with JWKS
- - Clean separation: providers/ for OAuth logic, middleware.py for FastAPI integration
+ - Clean separation: providers/ for auth logic, middleware.py for FastAPI integration
+
+ Email Auth Flow:
+ 1. POST /api/auth/email/send-code with {email}
+ 2. User receives code via email
+ 3. POST /api/auth/email/verify with {email, code}
+ 4. Session created, user authenticated
  """

  from .providers.base import OAuthProvider
+ from .providers.email import EmailAuthProvider, EmailAuthResult
  from .providers.google import GoogleOAuthProvider
  from .providers.microsoft import MicrosoftOAuthProvider

  __all__ = [
  "OAuthProvider",
+ "EmailAuthProvider",
+ "EmailAuthResult",
  "GoogleOAuthProvider",
  "MicrosoftOAuthProvider",
  ]
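
The docstring's email flow maps onto two endpoints. A client-side sketch, assuming the paths and request bodies listed above; the response shape of the verify step is an assumption based on the JWT tokens introduced in rem/auth/jwt.py (next file):

    import httpx

    async def send_code(client: httpx.AsyncClient, email: str) -> None:
        # Step 1: request a verification code for the address
        resp = await client.post("/api/auth/email/send-code", json={"email": email})
        resp.raise_for_status()

    async def verify_code(client: httpx.AsyncClient, email: str, code: str) -> dict:
        # Step 3: exchange the emailed code for an authenticated session
        resp = await client.post("/api/auth/email/verify", json={"email": email, "code": code})
        resp.raise_for_status()
        return resp.json()  # assumed to include access/refresh tokens

Step 2 (the user reading the code from their inbox) happens out of band.
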
rem/auth/jwt.py ADDED
@@ -0,0 +1,352 @@
+ """
+ JWT Token Service for REM Authentication.
+
+ Provides JWT token generation and validation for stateless authentication.
+ Uses HS256 algorithm with the session secret for signing.
+
+ Token Types:
+ - Access Token: Short-lived (default 1 hour), used for API authentication
+ - Refresh Token: Long-lived (default 7 days), used to obtain new access tokens
+
+ Token Claims:
+ - sub: User ID (UUID string)
+ - email: User email
+ - name: User display name
+ - role: User role (user, admin)
+ - tier: User subscription tier
+ - roles: List of roles for authorization
+ - provider: Auth provider (email, google, microsoft)
+ - tenant_id: Tenant identifier for multi-tenancy
+ - exp: Expiration timestamp
+ - iat: Issued at timestamp
+ - type: Token type (access, refresh)
+
+ Usage:
+ from rem.auth.jwt import JWTService
+
+ jwt_service = JWTService()
+
+ # Generate tokens after successful authentication
+ tokens = jwt_service.create_tokens(user_dict)
+ # Returns: {"access_token": "...", "refresh_token": "...", "token_type": "bearer", "expires_in": 3600}
+
+ # Validate token from Authorization header
+ user = jwt_service.verify_token(token)
+ # Returns user dict or None if invalid
+
+ # Refresh access token
+ new_tokens = jwt_service.refresh_access_token(refresh_token)
+ """
+
+ import time
+ import hmac
+ import hashlib
+ import base64
+ import json
+ from datetime import datetime, timezone
+ from typing import Optional
+
+ from loguru import logger
+
+
+ class JWTService:
+ """
+ JWT token service for authentication.
+
+ Uses HMAC-SHA256 for signing - simple and secure for single-service deployment.
+ For multi-service deployments, consider switching to RS256 with public/private keys.
+ """
+
+ def __init__(
+ self,
+ secret: str | None = None,
+ access_token_expiry_seconds: int = 3600, # 1 hour
+ refresh_token_expiry_seconds: int = 604800, # 7 days
+ issuer: str = "rem",
+ ):
+ """
+ Initialize JWT service.
+
+ Args:
+ secret: Secret key for signing (uses settings.auth.session_secret if not provided)
+ access_token_expiry_seconds: Access token lifetime in seconds
+ refresh_token_expiry_seconds: Refresh token lifetime in seconds
+ issuer: Token issuer identifier
+ """
+ if secret:
+ self._secret = secret
+ else:
+ from ..settings import settings
+ self._secret = settings.auth.session_secret
+
+ self._access_expiry = access_token_expiry_seconds
+ self._refresh_expiry = refresh_token_expiry_seconds
+ self._issuer = issuer
+
+ def _base64url_encode(self, data: bytes) -> str:
+ """Base64url encode without padding."""
+ return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8")
+
+ def _base64url_decode(self, data: str) -> bytes:
+ """Base64url decode with padding restoration."""
+ padding = 4 - len(data) % 4
+ if padding != 4:
+ data += "=" * padding
+ return base64.urlsafe_b64decode(data)
+
+ def _sign(self, message: str) -> str:
+ """Create HMAC-SHA256 signature."""
+ signature = hmac.new(
+ self._secret.encode("utf-8"),
+ message.encode("utf-8"),
+ hashlib.sha256
+ ).digest()
+ return self._base64url_encode(signature)
+
+ def _create_token(self, payload: dict) -> str:
+ """
+ Create a JWT token.
+
+ Args:
+ payload: Token claims
+
+ Returns:
+ Encoded JWT string
+ """
+ header = {"alg": "HS256", "typ": "JWT"}
+
+ header_encoded = self._base64url_encode(json.dumps(header, separators=(",", ":")).encode())
+ payload_encoded = self._base64url_encode(json.dumps(payload, separators=(",", ":")).encode())
+
+ message = f"{header_encoded}.{payload_encoded}"
+ signature = self._sign(message)
+
+ return f"{message}.{signature}"
+
+ def _verify_signature(self, token: str) -> dict | None:
+ """
+ Verify token signature and decode payload.
+
+ Args:
+ token: JWT token string
+
+ Returns:
+ Decoded payload dict or None if invalid
+ """
+ try:
+ parts = token.split(".")
+ if len(parts) != 3:
+ return None
+
+ header_encoded, payload_encoded, signature = parts
+
+ # Verify signature
+ message = f"{header_encoded}.{payload_encoded}"
+ expected_signature = self._sign(message)
+
+ if not hmac.compare_digest(signature, expected_signature):
+ logger.debug("JWT signature verification failed")
+ return None
+
+ # Decode payload
+ payload = json.loads(self._base64url_decode(payload_encoded))
+ return payload
+
+ except Exception as e:
+ logger.debug(f"JWT decode error: {e}")
+ return None
+
+ def create_tokens(
+ self,
+ user: dict,
+ access_expiry: int | None = None,
+ refresh_expiry: int | None = None,
+ ) -> dict:
+ """
+ Create access and refresh tokens for a user.
+
+ Args:
+ user: User dict with id, email, name, role, tier, roles, provider, tenant_id
+ access_expiry: Override access token expiry (seconds)
+ refresh_expiry: Override refresh token expiry (seconds)
+
+ Returns:
+ Dict with access_token, refresh_token, token_type, expires_in
+ """
+ now = int(time.time())
+ access_exp = access_expiry or self._access_expiry
+ refresh_exp = refresh_expiry or self._refresh_expiry
+
+ # Common claims
+ base_claims = {
+ "sub": user.get("id", ""),
+ "email": user.get("email", ""),
+ "name": user.get("name", ""),
+ "role": user.get("role"),
+ "tier": user.get("tier", "free"),
+ "roles": user.get("roles", ["user"]),
+ "provider": user.get("provider", "email"),
+ "tenant_id": user.get("tenant_id", "default"),
+ "iss": self._issuer,
+ "iat": now,
+ }
+
+ # Access token
+ access_payload = {
+ **base_claims,
+ "type": "access",
+ "exp": now + access_exp,
+ }
+ access_token = self._create_token(access_payload)
+
+ # Refresh token (minimal claims for security)
+ refresh_payload = {
+ "sub": user.get("id", ""),
+ "email": user.get("email", ""),
+ "type": "refresh",
+ "iss": self._issuer,
+ "iat": now,
+ "exp": now + refresh_exp,
+ }
+ refresh_token = self._create_token(refresh_payload)
+
+ return {
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ "token_type": "bearer",
+ "expires_in": access_exp,
+ }
+
+ def verify_token(self, token: str, token_type: str = "access") -> dict | None:
+ """
+ Verify a token and return user claims.
+
+ Args:
+ token: JWT token string
+ token_type: Expected token type ("access" or "refresh")
+
+ Returns:
+ User dict with claims or None if invalid/expired
+ """
+ payload = self._verify_signature(token)
+ if not payload:
+ return None
+
+ # Check token type
+ if payload.get("type") != token_type:
+ logger.debug(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
+ return None
+
+ # Check expiration
+ exp = payload.get("exp", 0)
+ if exp < time.time():
+ logger.debug("Token expired")
+ return None
+
+ # Check issuer
+ if payload.get("iss") != self._issuer:
+ logger.debug(f"Token issuer mismatch: expected {self._issuer}, got {payload.get('iss')}")
+ return None
+
+ # Return user dict (compatible with session user format)
+ return {
+ "id": payload.get("sub"),
+ "email": payload.get("email"),
+ "name": payload.get("name"),
+ "role": payload.get("role"),
+ "tier": payload.get("tier", "free"),
+ "roles": payload.get("roles", ["user"]),
+ "provider": payload.get("provider", "email"),
+ "tenant_id": payload.get("tenant_id", "default"),
+ }
+
+ def refresh_access_token(self, refresh_token: str) -> dict | None:
+ """
+ Create new access token using refresh token.
+
+ Args:
+ refresh_token: Valid refresh token
+
+ Returns:
+ New token dict or None if refresh token is invalid
+ """
+ # Verify refresh token
+ payload = self._verify_signature(refresh_token)
+ if not payload:
+ return None
+
+ if payload.get("type") != "refresh":
+ logger.debug("Not a refresh token")
+ return None
+
+ # Check expiration
+ exp = payload.get("exp", 0)
+ if exp < time.time():
+ logger.debug("Refresh token expired")
+ return None
+
+ # Create new access token with minimal info from refresh token
+ # In production, you'd look up the full user from database
+ user = {
+ "id": payload.get("sub"),
+ "email": payload.get("email"),
+ "name": payload.get("email", "").split("@")[0],
+ "provider": "email",
+ "tenant_id": "default",
+ "tier": "free",
+ "roles": ["user"],
+ }
+
+ # Only return new access token, keep same refresh token
+ now = int(time.time())
+ access_payload = {
+ "sub": user["id"],
+ "email": user["email"],
+ "name": user["name"],
+ "role": user.get("role"),
+ "tier": user["tier"],
+ "roles": user["roles"],
+ "provider": user["provider"],
+ "tenant_id": user["tenant_id"],
+ "type": "access",
+ "iss": self._issuer,
+ "iat": now,
+ "exp": now + self._access_expiry,
+ }
+
+ return {
+ "access_token": self._create_token(access_payload),
+ "token_type": "bearer",
+ "expires_in": self._access_expiry,
+ }
+
+ def decode_without_verification(self, token: str) -> dict | None:
+ """
+ Decode token without verification (for debugging only).
+
+ Args:
+ token: JWT token string
+
+ Returns:
+ Decoded payload or None
+ """
+ try:
+ parts = token.split(".")
+ if len(parts) != 3:
+ return None
+ payload = json.loads(self._base64url_decode(parts[1]))
+ return payload
+ except Exception:
+ return None
+
+
+ # Singleton instance for convenience
+ _jwt_service: Optional[JWTService] = None
+
+
+ def get_jwt_service() -> JWTService:
+ """Get or create the JWT service singleton."""
+ global _jwt_service
+ if _jwt_service is None:
+ _jwt_service = JWTService()
+ return _jwt_service
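
The module docstring above already sketches the API; the round trip below uses only what the added file defines (names and return shapes taken from the code, with example user values):

    from rem.auth.jwt import get_jwt_service

    jwt_service = get_jwt_service()  # singleton, signs with settings.auth.session_secret

    user = {
        "id": "123e4567-e89b-12d3-a456-426614174000",  # example values
        "email": "alice@example.com",
        "name": "Alice",
        "role": "user",
        "tier": "free",
        "roles": ["user"],
        "provider": "email",
        "tenant_id": "default",
    }

    tokens = jwt_service.create_tokens(user)
    # {"access_token": "...", "refresh_token": "...", "token_type": "bearer", "expires_in": 3600}

    claims = jwt_service.verify_token(tokens["access_token"])  # user dict, or None if invalid/expired
    refreshed = jwt_service.refresh_access_token(tokens["refresh_token"])  # new access token dict, or None
    assert claims is not None and claims["email"] == "alice@example.com"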