ragbits-chat 1.4.0.dev202512160238__py3-none-any.whl → 1.4.0.dev202601130240__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ragbits/chat/api.py +364 -79
  2. ragbits/chat/auth/__init__.py +3 -1
  3. ragbits/chat/auth/backends.py +519 -81
  4. ragbits/chat/auth/base.py +8 -10
  5. ragbits/chat/auth/oauth2_providers.py +108 -0
  6. ragbits/chat/auth/provider_config.py +81 -0
  7. ragbits/chat/auth/session_store.py +178 -0
  8. ragbits/chat/auth/types.py +66 -29
  9. ragbits/chat/interface/types.py +15 -0
  10. ragbits/chat/persistence/sql.py +1 -0
  11. ragbits/chat/providers/model_provider.py +10 -8
  12. ragbits/chat/ui-build/assets/AuthGuard-Bq7UOJ7y.js +1 -0
  13. ragbits/chat/ui-build/assets/{ChatHistory--EyHdeYk.js → ChatHistory-B2hLBYMJ.js} +2 -2
  14. ragbits/chat/ui-build/assets/{ChatOptionsForm-BQJ-bYMu.js → ChatOptionsForm-bfNG8UIW.js} +1 -1
  15. ragbits/chat/ui-build/assets/CredentialsLogin-0g5-w2vR.js +1 -0
  16. ragbits/chat/ui-build/assets/{FeedbackForm-BYee4uF-.js → FeedbackForm-oSbly5oN.js} +1 -1
  17. ragbits/chat/ui-build/assets/Login-DSW_CNFu.js +1 -0
  18. ragbits/chat/ui-build/assets/LogoutButton-BQE8NNsg.js +1 -0
  19. ragbits/chat/ui-build/assets/OAuth2Login-EmJ39PUe.js +2 -0
  20. ragbits/chat/ui-build/assets/ShareButton-B7DyIVH0.js +1 -0
  21. ragbits/chat/ui-build/assets/{UsageButton-C0lzhbc6.js → UsageButton-BABA7a-w.js} +1 -1
  22. ragbits/chat/ui-build/assets/authStore-BfGlL8rp.js +1 -0
  23. ragbits/chat/ui-build/assets/{chunk-IGSAU2ZA-ZqFHUQCB.js → chunk-IGSAU2ZA-NXd1g0Qd.js} +1 -1
  24. ragbits/chat/ui-build/assets/{chunk-SSA7SXE4-Dj8WqXIN.js → chunk-SSA7SXE4-CJa0HuAU.js} +1 -1
  25. ragbits/chat/ui-build/assets/index-BZLU40Mk.js +83 -0
  26. ragbits/chat/ui-build/assets/index-Be0kkf3d.js +24 -0
  27. ragbits/chat/ui-build/assets/index-Bvn9K6h_.js +1 -0
  28. ragbits/chat/ui-build/assets/{index-Bpba6d6u.js → index-Ceq7Rkzy.js} +1 -1
  29. ragbits/chat/ui-build/assets/index-ClAYkAiv.css +1 -0
  30. ragbits/chat/ui-build/assets/useInitializeUserStore-DyHP7g8x.js +1 -0
  31. ragbits/chat/ui-build/assets/{useMenuTriggerState-DaMXoDzf.js → useMenuTriggerState-SaFmATkk.js} +1 -1
  32. ragbits/chat/ui-build/assets/{useSelectableItem-D8MlMsUd.js → useSelectableItem-DhuFnc0W.js} +1 -1
  33. ragbits/chat/ui-build/index.html +2 -2
  34. {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601130240.dist-info}/METADATA +2 -2
  35. ragbits_chat-1.4.0.dev202601130240.dist-info/RECORD +58 -0
  36. ragbits/chat/ui-build/assets/AuthGuard-Da8Duw5e.js +0 -1
  37. ragbits/chat/ui-build/assets/Login-DkcluI7q.js +0 -1
  38. ragbits/chat/ui-build/assets/LogoutButton-DSdCqsdC.js +0 -1
  39. ragbits/chat/ui-build/assets/ShareButton-YVgRY6bN.js +0 -1
  40. ragbits/chat/ui-build/assets/authStore-BFeUV-Bg.js +0 -1
  41. ragbits/chat/ui-build/assets/index-8hpVK8cj.js +0 -131
  42. ragbits/chat/ui-build/assets/index-BUbs7vFP.js +0 -1
  43. ragbits/chat/ui-build/assets/index-BZGp6GjF.js +0 -32
  44. ragbits/chat/ui-build/assets/index-CmsICuOz.css +0 -1
  45. ragbits_chat-1.4.0.dev202512160238.dist-info/RECORD +0 -52
  46. {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601130240.dist-info}/WHEEL +0 -0
ragbits/chat/api.py CHANGED
@@ -1,24 +1,27 @@
1
+ import asyncio
1
2
  import importlib
2
3
  import json
3
4
  import logging
5
+ import os
4
6
  import re
5
7
  import time
6
8
  from collections.abc import AsyncGenerator
7
- from contextlib import asynccontextmanager
9
+ from contextlib import asynccontextmanager, suppress
8
10
  from pathlib import Path
9
11
  from typing import Any, cast
10
12
 
11
13
  import uvicorn
12
- from fastapi import Depends, FastAPI, HTTPException, Request, status
14
+ from fastapi import FastAPI, HTTPException, Request, status
13
15
  from fastapi.exceptions import RequestValidationError
14
16
  from fastapi.middleware.cors import CORSMiddleware
15
- from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, StreamingResponse
16
- from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
17
+ from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse, StreamingResponse
17
18
  from fastapi.staticfiles import StaticFiles
18
19
  from pydantic import BaseModel
19
20
 
20
21
  from ragbits.chat.auth import AuthenticationBackend, User
21
- from ragbits.chat.auth.types import LoginRequest, LoginResponse, LogoutRequest
22
+ from ragbits.chat.auth.backends import MultiAuthenticationBackend, OAuth2AuthenticationBackend
23
+ from ragbits.chat.auth.provider_config import get_provider_visual_config
24
+ from ragbits.chat.auth.types import LoginRequest, LoginResponse, OAuth2Credentials
22
25
  from ragbits.chat.interface import ChatInterface
23
26
  from ragbits.chat.interface.types import (
24
27
  AuthenticationConfig,
@@ -33,6 +36,7 @@ from ragbits.chat.interface.types import (
33
36
  FeedbackRequest,
34
37
  Image,
35
38
  ImageResponse,
39
+ OAuth2ProviderConfig,
36
40
  )
37
41
  from ragbits.core.audit.metrics import record_metric
38
42
  from ragbits.core.audit.metrics.base import MetricType
@@ -42,6 +46,12 @@ from .metrics import ChatCounterMetric, ChatHistogramMetric
42
46
 
43
47
  logger = logging.getLogger(__name__)
44
48
 
49
+ # Environment-aware cookie security: only require HTTPS in production
50
+ IS_PRODUCTION = os.getenv("ENVIRONMENT", "").lower() == "production"
51
+
52
+ # Session cookie name - used for storing session ID in HTTP-only cookie
53
+ SESSION_COOKIE_NAME = "ragbits_session"
54
+
45
55
  # Chunk size for large base64 images to prevent SSE message size issues
46
56
  # Keep chunks extremely small to avoid JSON string length limits in browsers and SSE parsing issues
47
57
  # Account for JSON overhead: metadata + base64 data should fit comfortably in browser limits
@@ -79,14 +89,34 @@ class RagbitsAPI:
79
89
  self.cors_origins = cors_origins or []
80
90
  self.debug_mode = debug_mode
81
91
  self.auth_backend = self._load_auth_backend(auth_backend)
82
- self.security = HTTPBearer(auto_error=False) if auth_backend else None
83
92
  self.theme_path = Path(theme_path) if theme_path else None
84
93
 
94
+ # get frontend base URL from environment variable or use default
95
+ frontend_base_url = os.getenv("FRONTEND_BASE_URL", "http://localhost:8000")
96
+ # remove trailing slash from frontend base URL
97
+ frontend_base_url = frontend_base_url.rstrip("/")
98
+ # set frontend base URL on the API
99
+ self.frontend_base_url = frontend_base_url
100
+
85
101
  @asynccontextmanager
86
102
  async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
87
103
  await self.chat_interface.setup()
104
+
105
+ # Start background cleanup tasks for session and OAuth state management
106
+ cleanup_tasks: list[asyncio.Task] = []
107
+
108
+ if self.auth_backend:
109
+ cleanup_tasks.append(asyncio.create_task(self._session_cleanup_loop()))
110
+ cleanup_tasks.append(asyncio.create_task(self._oauth_state_cleanup_loop()))
111
+
88
112
  yield
89
113
 
114
+ # Cancel all cleanup tasks on shutdown
115
+ for task in cleanup_tasks:
116
+ task.cancel()
117
+ with suppress(asyncio.CancelledError):
118
+ await task
119
+
90
120
  self.app = FastAPI(lifespan=lifespan)
91
121
 
92
122
  self.configure_app()
@@ -122,52 +152,127 @@ class RagbitsAPI:
122
152
  content={"detail": exc.errors(), "body": exc.body},
123
153
  )
124
154
 
125
- def setup_routes(self) -> None:
155
+ def setup_routes(self) -> None: # noqa: PLR0915
126
156
  """Defines API routes."""
127
157
  # Authentication routes
128
158
  if self.auth_backend:
129
- # Create security dependency variable to avoid B008 linting error
130
- security_dependency = Depends(self.security)
131
159
 
132
160
  @self.app.post("/api/auth/login", response_class=JSONResponse)
133
161
  async def login(request: LoginRequest) -> JSONResponse:
134
162
  return await self._handle_login(request)
135
163
 
136
164
  @self.app.post("/api/auth/logout", response_class=JSONResponse)
137
- async def logout(request: LogoutRequest) -> JSONResponse:
165
+ async def logout(request: Request) -> JSONResponse:
138
166
  return await self._handle_logout(request)
139
167
 
168
+ # OAuth2 routes (if OAuth2 backend is configured)
169
+ oauth2_backends = []
170
+ if isinstance(self.auth_backend, MultiAuthenticationBackend):
171
+ oauth2_backends = self.auth_backend.get_oauth2_backends()
172
+ elif isinstance(self.auth_backend, OAuth2AuthenticationBackend):
173
+ oauth2_backends = [self.auth_backend]
174
+
175
+ if oauth2_backends:
176
+ # Create a mapping of provider name to backend for quick lookup
177
+ oauth2_backend_map = {backend.provider.name: backend for backend in oauth2_backends}
178
+
179
+ @self.app.get("/api/auth/authorize/{provider}", response_class=JSONResponse)
180
+ async def oauth2_authorize(provider: str) -> JSONResponse:
181
+ """Generate OAuth2 authorization URL for specified provider."""
182
+ backend = oauth2_backend_map.get(provider)
183
+ if not backend:
184
+ raise HTTPException(status_code=404, detail=f"OAuth2 provider '{provider}' not found")
185
+
186
+ authorize_url, state = backend.generate_authorize_url()
187
+ return JSONResponse(content={"authorize_url": authorize_url, "state": state})
188
+
189
+ @self.app.get("/api/auth/callback/{provider}", response_class=RedirectResponse)
190
+ async def oauth2_callback(
191
+ provider: str, code: str | None = None, state: str | None = None
192
+ ) -> RedirectResponse:
193
+ """Handle OAuth2 callback from provider."""
194
+ backend = oauth2_backend_map.get(provider)
195
+ if not backend:
196
+ raise HTTPException(status_code=404, detail=f"OAuth2 provider '{provider}' not found")
197
+
198
+ return await self._handle_oauth2_callback(code, state, backend)
199
+
140
200
  @self.app.post("/api/chat", response_class=StreamingResponse)
141
201
  async def chat_message(
142
- request: ChatMessageRequest,
143
- credentials: HTTPAuthorizationCredentials | None = security_dependency,
202
+ request: Request,
203
+ chat_request: ChatMessageRequest,
144
204
  ) -> StreamingResponse:
145
- return await self._handle_chat_message(request, credentials)
205
+ return await self._handle_chat_message(chat_request, request)
146
206
 
147
207
  @self.app.post("/api/feedback", response_class=JSONResponse)
148
208
  async def feedback(
149
- request: FeedbackRequest,
150
- credentials: HTTPAuthorizationCredentials | None = security_dependency,
209
+ request: Request,
210
+ feedback_request: FeedbackRequest,
151
211
  ) -> JSONResponse:
152
- return await self._handle_feedback(request, credentials)
212
+ return await self._handle_feedback(feedback_request, request)
153
213
  else:
154
214
 
155
215
  @self.app.post("/api/chat", response_class=StreamingResponse)
156
216
  async def chat_message(
157
- request: ChatMessageRequest,
217
+ request: Request,
218
+ chat_request: ChatMessageRequest,
158
219
  ) -> StreamingResponse:
159
- return await self._handle_chat_message(request, None)
220
+ return await self._handle_chat_message(chat_request, request)
160
221
 
161
222
  @self.app.post("/api/feedback", response_class=JSONResponse)
162
223
  async def feedback(
163
- request: FeedbackRequest,
224
+ request: Request,
225
+ feedback_request: FeedbackRequest,
164
226
  ) -> JSONResponse:
165
- return await self._handle_feedback(request, None)
227
+ return await self._handle_feedback(feedback_request, request)
166
228
 
167
229
  @self.app.get("/api/config", response_class=JSONResponse)
168
230
  async def config() -> JSONResponse:
169
231
  feedback_config = self.chat_interface.feedback_config
170
232
 
233
+ # Determine available auth types and OAuth2 providers based on backend
234
+ auth_types = []
235
+ oauth2_providers = []
236
+
237
+ if self.auth_backend:
238
+ if isinstance(self.auth_backend, MultiAuthenticationBackend):
239
+ # Multi backend: check what types are available
240
+ if self.auth_backend.get_credentials_backends():
241
+ auth_types.append(AuthType.CREDENTIALS)
242
+
243
+ oauth2_backends = self.auth_backend.get_oauth2_backends()
244
+ if oauth2_backends:
245
+ auth_types.append(AuthType.OAUTH2)
246
+ for backend in oauth2_backends:
247
+ visual_config = get_provider_visual_config(backend.provider.name)
248
+ oauth2_providers.append(
249
+ OAuth2ProviderConfig(
250
+ name=backend.provider.name,
251
+ display_name=backend.provider.display_name,
252
+ color=visual_config.color,
253
+ button_color=visual_config.button_color,
254
+ text_color=visual_config.text_color,
255
+ icon_svg=visual_config.icon_svg,
256
+ )
257
+ )
258
+ elif isinstance(self.auth_backend, OAuth2AuthenticationBackend):
259
+ # Single OAuth2 backend
260
+ auth_types = [AuthType.OAUTH2]
261
+ visual_config = get_provider_visual_config(self.auth_backend.provider.name)
262
+ oauth2_providers = [
263
+ OAuth2ProviderConfig(
264
+ name=self.auth_backend.provider.name,
265
+ display_name=self.auth_backend.provider.display_name,
266
+ color=visual_config.color,
267
+ button_color=visual_config.button_color,
268
+ text_color=visual_config.text_color,
269
+ icon_svg=visual_config.icon_svg,
270
+ )
271
+ ]
272
+ else:
273
+ # Single credentials backend
274
+ auth_types = [AuthType.CREDENTIALS]
275
+
171
276
  config_response = ConfigResponse(
172
277
  feedback=FeedbackConfig(
173
278
  like=FeedbackItem(
@@ -186,12 +291,22 @@ class RagbitsAPI:
186
291
  show_usage=self.chat_interface.show_usage,
187
292
  authentication=AuthenticationConfig(
188
293
  enabled=self.auth_backend is not None,
189
- auth_types=[AuthType.CREDENTIALS],
294
+ auth_types=auth_types,
295
+ oauth2_providers=oauth2_providers,
190
296
  ),
191
297
  )
192
298
 
193
299
  return JSONResponse(content=config_response.model_dump())
194
300
 
301
+ # User info endpoint - returns current authenticated user
302
+ @self.app.get("/api/user", response_class=JSONResponse)
303
+ async def get_user(request: Request) -> JSONResponse:
304
+ """Get current authenticated user from session cookie."""
305
+ user = await self.get_current_user_from_cookie(request)
306
+ if user:
307
+ return JSONResponse(content=user.model_dump())
308
+ return JSONResponse(content=None, status_code=401)
309
+
195
310
  # Theme CSS endpoint - always available, returns 404 if no theme configured
196
311
  @self.app.get("/api/theme", response_class=PlainTextResponse)
197
312
  async def theme() -> PlainTextResponse:
@@ -217,41 +332,18 @@ class RagbitsAPI:
217
332
  with open(str(index_file)) as file:
218
333
  return HTMLResponse(content=file.read())
219
334
 
220
- async def _validate_authentication(self, credentials: HTTPAuthorizationCredentials | None) -> User | None:
221
- """Validate authentication credentials and return user if valid."""
222
- if not self.auth_backend:
223
- return None
224
-
225
- if not credentials:
226
- raise HTTPException(
227
- status_code=status.HTTP_401_UNAUTHORIZED,
228
- detail="Authentication required",
229
- headers={"WWW-Authenticate": "Bearer"},
230
- )
231
-
232
- # The jwt_token should be the session_id
233
- auth_result = await self.auth_backend.validate_token(credentials.credentials)
234
- if not auth_result.success:
235
- raise HTTPException(
236
- status_code=status.HTTP_401_UNAUTHORIZED,
237
- detail=auth_result.error_message or "Invalid session",
238
- headers={"WWW-Authenticate": "Bearer"},
239
- )
240
-
241
- return auth_result.user
242
-
243
335
  @staticmethod
244
336
  def _prepare_chat_context(
245
337
  request: ChatMessageRequest,
246
338
  authenticated_user: User | None,
247
- credentials: HTTPAuthorizationCredentials | None,
339
+ session_id: str | None,
248
340
  ) -> ChatContext:
249
341
  """Prepare and validate chat context from request."""
250
342
  chat_context = ChatContext(**request.context)
251
343
 
252
344
  # Add session_id to context if authenticated
253
- if authenticated_user and credentials:
254
- chat_context.session_id = credentials.credentials
345
+ if authenticated_user and session_id:
346
+ chat_context.session_id = session_id
255
347
  chat_context.user = authenticated_user
256
348
 
257
349
  # Verify state signature if provided
@@ -275,9 +367,7 @@ class RagbitsAPI:
275
367
 
276
368
  return chat_context
277
369
 
278
- async def _handle_chat_message(
279
- self, request: ChatMessageRequest, credentials: HTTPAuthorizationCredentials | None = None
280
- ) -> StreamingResponse: # noqa: PLR0915
370
+ async def _handle_chat_message(self, chat_request: ChatMessageRequest, request: Request) -> StreamingResponse: # noqa: PLR0915
281
371
  """Handle chat message requests with metrics tracking."""
282
372
  start_time = time.time()
283
373
 
@@ -287,8 +377,8 @@ class RagbitsAPI:
287
377
  )
288
378
 
289
379
  try:
290
- # Validate authentication if required
291
- authenticated_user = await self._validate_authentication(credentials)
380
+ # Validate authentication if required using cookies
381
+ authenticated_user = await self.require_authenticated_user(request)
292
382
 
293
383
  if not self.chat_interface:
294
384
  record_metric(
@@ -302,12 +392,13 @@ class RagbitsAPI:
302
392
  raise HTTPException(status_code=500, detail="Chat implementation is not initialized")
303
393
 
304
394
  # Prepare chat context
305
- chat_context = RagbitsAPI._prepare_chat_context(request, authenticated_user, credentials)
395
+ session_id = request.cookies.get(SESSION_COOKIE_NAME)
396
+ chat_context = RagbitsAPI._prepare_chat_context(chat_request, authenticated_user, session_id)
306
397
 
307
398
  # Get the response generator from the chat interface
308
399
  response_generator = self.chat_interface.chat(
309
- message=request.message,
310
- history=[msg.model_dump() for msg in request.history],
400
+ message=chat_request.message,
401
+ history=[msg.model_dump() for msg in chat_request.history],
311
402
  context=chat_context,
312
403
  )
313
404
 
@@ -318,8 +409,8 @@ class RagbitsAPI:
318
409
  state_update_text = ""
319
410
 
320
411
  with trace(
321
- message=request.message,
322
- history=[msg.model_dump() for msg in request.history],
412
+ message=chat_request.message,
413
+ history=[msg.model_dump() for msg in chat_request.history],
323
414
  context=chat_context,
324
415
  ) as outputs:
325
416
  async for chunk in RagbitsAPI._chat_response_to_sse(response_generator):
@@ -406,9 +497,7 @@ class RagbitsAPI:
406
497
  )
407
498
  raise HTTPException(status_code=500, detail="Internal server error") from None
408
499
 
409
- async def _handle_feedback(
410
- self, request: FeedbackRequest, credentials: HTTPAuthorizationCredentials | None = None
411
- ) -> JSONResponse:
500
+ async def _handle_feedback(self, feedback_request: FeedbackRequest, request: Request) -> JSONResponse:
412
501
  """Handle feedback requests with metrics tracking."""
413
502
  start_time = time.time()
414
503
 
@@ -422,8 +511,8 @@ class RagbitsAPI:
422
511
  )
423
512
 
424
513
  try:
425
- # Validate authentication if required
426
- await self._validate_authentication(credentials)
514
+ # Validate authentication if required using cookies
515
+ await self.require_authenticated_user(request)
427
516
 
428
517
  if not self.chat_interface:
429
518
  record_metric(
@@ -437,9 +526,9 @@ class RagbitsAPI:
437
526
  raise HTTPException(status_code=500, detail="Chat implementation is not initialized")
438
527
 
439
528
  await self.chat_interface.save_feedback(
440
- message_id=request.message_id,
441
- feedback=request.feedback,
442
- payload=request.payload,
529
+ message_id=feedback_request.message_id,
530
+ feedback=feedback_request.feedback,
531
+ payload=feedback_request.payload,
443
532
  )
444
533
 
445
534
  # Track successful request duration
@@ -496,26 +585,109 @@ class RagbitsAPI:
496
585
  )
497
586
  raise HTTPException(status_code=500, detail="Internal server error") from None
498
587
 
588
+ async def get_current_user_from_cookie(self, request: Request) -> User | None:
589
+ """
590
+ Get current user from session cookie.
591
+
592
+ Args:
593
+ request: FastAPI request object
594
+
595
+ Returns:
596
+ User object if authenticated, None otherwise
597
+ """
598
+ if not self.auth_backend:
599
+ return None
600
+
601
+ session_id = request.cookies.get(SESSION_COOKIE_NAME)
602
+ if not session_id:
603
+ return None
604
+
605
+ result = await self.auth_backend.validate_session(session_id)
606
+ if result.success:
607
+ return result.user
608
+
609
+ return None
610
+
611
+ async def require_authenticated_user(self, request: Request) -> User | None:
612
+ """
613
+ Get current user from session cookie and raise HTTPException if authentication
614
+ is required but user is not authenticated.
615
+
616
+ This is a reusable dependency for handlers that require authentication.
617
+
618
+ Args:
619
+ request: FastAPI request object
620
+
621
+ Returns:
622
+ User object if authenticated (or if no auth is required), None if no auth backend
623
+
624
+ Raises:
625
+ HTTPException: 401 Unauthorized if authentication is required but user is not authenticated
626
+ """
627
+ authenticated_user = await self.get_current_user_from_cookie(request)
628
+ if self.auth_backend and not authenticated_user:
629
+ raise HTTPException(
630
+ status_code=status.HTTP_401_UNAUTHORIZED,
631
+ detail="Authentication required",
632
+ )
633
+ return authenticated_user
634
+
635
+ def get_session_expiry_seconds(self) -> int:
636
+ """
637
+ Get session expiry time in seconds from the auth backend configuration.
638
+
639
+ Returns:
640
+ Session expiry time in seconds (default: 24 hours if not configured)
641
+ """
642
+ # Default to 24 hours
643
+ default_expiry_hours = 24
644
+
645
+ if not self.auth_backend:
646
+ return default_expiry_hours * 3600
647
+
648
+ # Check if the backend has session_expiry_hours attribute
649
+ if hasattr(self.auth_backend, "session_expiry_hours"):
650
+ return self.auth_backend.session_expiry_hours * 3600
651
+
652
+ # For MultiAuthenticationBackend, try to get from first backend that has it
653
+ if isinstance(self.auth_backend, MultiAuthenticationBackend):
654
+ for backend in self.auth_backend.backends:
655
+ if hasattr(backend, "session_expiry_hours"):
656
+ return backend.session_expiry_hours * 3600
657
+
658
+ return default_expiry_hours * 3600
659
+
499
660
  async def _handle_login(self, request: LoginRequest) -> JSONResponse:
500
- """Handle user login requests."""
661
+ """Handle user login requests with credentials."""
501
662
  if not self.auth_backend:
502
663
  raise HTTPException(status_code=500, detail="Authentication not configured")
503
664
 
504
665
  try:
505
- from .auth.types import UserCredentials
666
+ # LoginRequest is UserCredentials
667
+ auth_result = await self.auth_backend.authenticate_with_credentials(request)
506
668
 
507
- credentials = UserCredentials(username=request.username, password=request.password)
508
- auth_result = await self.auth_backend.authenticate_with_credentials(credentials)
509
-
510
- if auth_result.success and auth_result.jwt_token:
511
- return JSONResponse(
669
+ if auth_result.success and auth_result.session_id:
670
+ response = JSONResponse(
512
671
  content=LoginResponse(
513
672
  success=True,
514
673
  user=auth_result.user if auth_result.user else None,
515
674
  error_message=None,
516
- jwt_token=auth_result.jwt_token,
517
675
  ).model_dump()
518
676
  )
677
+
678
+ # Set secure HTTP-only cookie using backend's session expiry configuration
679
+ session_expiry_seconds = self.get_session_expiry_seconds()
680
+ response.set_cookie(
681
+ key=SESSION_COOKIE_NAME,
682
+ value=auth_result.session_id,
683
+ httponly=True,
684
+ secure=IS_PRODUCTION, # Only require HTTPS in production
685
+ samesite="lax",
686
+ max_age=session_expiry_seconds,
687
+ path="/",
688
+ )
689
+
690
+ return response
519
691
  else:
520
692
  return JSONResponse(
521
693
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -523,7 +695,6 @@ class RagbitsAPI:
523
695
  success=False,
524
696
  user=None,
525
697
  error_message=auth_result.error_message or "Invalid credentials",
526
- jwt_token=None,
527
698
  ).model_dump(),
528
699
  )
529
700
  except Exception as e:
@@ -534,25 +705,100 @@ class RagbitsAPI:
534
705
  success=False,
535
706
  user=None,
536
707
  error_message="Internal server error",
537
- jwt_token=None,
538
708
  ).model_dump(),
539
709
  )
540
710
 
541
- async def _handle_logout(self, request: LogoutRequest) -> JSONResponse:
711
+ async def _handle_logout(self, request: Request) -> JSONResponse:
542
712
  """Handle user logout requests."""
543
713
  if not self.auth_backend:
544
714
  raise HTTPException(status_code=500, detail="Authentication not configured")
545
715
 
546
716
  try:
547
- success = await self.auth_backend.revoke_token(request.token)
548
- return JSONResponse(content={"success": success})
717
+ # Get session ID from cookie
718
+ session_id = request.cookies.get(SESSION_COOKIE_NAME)
719
+
720
+ if not session_id:
721
+ # No session cookie, just return success
722
+ response = JSONResponse(content={"success": True})
723
+ response.delete_cookie(key=SESSION_COOKIE_NAME, path="/")
724
+ return response
725
+
726
+ # Delete the session from store
727
+ success = await self.auth_backend.revoke_session(session_id)
728
+
729
+ response = JSONResponse(content={"success": success})
730
+ # Clear the session cookie
731
+ response.delete_cookie(key=SESSION_COOKIE_NAME, path="/")
732
+ return response
733
+
549
734
  except Exception as e:
550
735
  logger.error(f"Logout error: {e}")
551
736
  return JSONResponse(
552
737
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
553
- content={"success": False, "error_message": "Internal server error"},
738
+ content={"success": False, "error": "Internal server error"},
739
+ )
740
+
741
+ async def _handle_oauth2_callback( # noqa: PLR6301
742
+ self, code: str | None, state: str | None, backend: OAuth2AuthenticationBackend
743
+ ) -> RedirectResponse:
744
+ """
745
+ Handle OAuth2 callback from provider.
746
+
747
+ This endpoint receives the authorization code from the OAuth2 provider,
748
+ exchanges it for an access token, authenticates the user, creates a session,
749
+ and redirects to the frontend with a secure HTTP-only cookie.
750
+
751
+ Args:
752
+ code: Authorization code from OAuth2 provider
753
+ state: CSRF protection state parameter
754
+ backend: The specific OAuth2 backend for this provider
755
+ """
756
+ # Verify required parameters
757
+ if not code:
758
+ # Redirect to login with error
759
+ return RedirectResponse(url=f"{self.frontend_base_url}/login?error=missing_code", status_code=302)
760
+
761
+ # Verify state parameter for CSRF protection
762
+ if not state or not backend.verify_state(state):
763
+ return RedirectResponse(url=f"{self.frontend_base_url}/login?error=invalid_state", status_code=302)
764
+
765
+ try:
766
+ # Exchange code for access token
767
+ access_token = await backend.exchange_code_for_token(code)
768
+ if not access_token:
769
+ return RedirectResponse(
770
+ url=f"{self.frontend_base_url}/login?error=token_exchange_failed", status_code=302
771
+ )
772
+
773
+ # Authenticate with the access token
774
+ oauth_credentials = OAuth2Credentials(access_token=access_token, token_type="Bearer") # noqa: S106
775
+ auth_result = await backend.authenticate_with_oauth2(oauth_credentials)
776
+
777
+ if not auth_result.success or not auth_result.session_id:
778
+ error_msg = auth_result.error_message or "Authentication failed"
779
+ return RedirectResponse(url=f"{self.frontend_base_url}/login?error={error_msg}", status_code=302)
780
+
781
+ # Success! Create redirect response with session cookie
782
+ response = RedirectResponse(url=f"{self.frontend_base_url}/", status_code=302)
783
+
784
+ # Set secure HTTP-only cookie
785
+ session_expiry_seconds = backend.session_expiry_hours * 3600
786
+ response.set_cookie(
787
+ key=SESSION_COOKIE_NAME,
788
+ value=auth_result.session_id,
789
+ httponly=True,
790
+ secure=IS_PRODUCTION, # Only require HTTPS in production
791
+ samesite="lax",
792
+ max_age=session_expiry_seconds,
793
+ path="/",
554
794
  )
555
795
 
796
+ return response
797
+
798
+ except Exception as e:
799
+ logger.error(f"OAuth2 callback error: {e}")
800
+ return RedirectResponse(url=f"{self.frontend_base_url}/login?error=internal_error", status_code=302)
801
+
556
802
  @staticmethod
557
803
  async def _chat_response_to_sse(
558
804
  responses: AsyncGenerator[ChatResponseUnion],
@@ -776,3 +1022,42 @@ class RagbitsAPI:
776
1022
  css_lines.append(f" --heroui-{color_name}: {color_value};")
777
1023
 
778
1024
  return css_lines
1025
+
1026
+ async def _session_cleanup_loop(self) -> None:
1027
+ if (
1028
+ not self.auth_backend
1029
+ or not hasattr(self.auth_backend, "session_store")
1030
+ or not hasattr(self.auth_backend.session_store, "cleanup_expired_sessions")
1031
+ ):
1032
+ return
1033
+
1034
+ while True:
1035
+ await asyncio.sleep(3600) # Run every hour
1036
+ try:
1037
+ removed = self.auth_backend.session_store.cleanup_expired_sessions() # type: ignore
1038
+ if removed > 0:
1039
+ logger.info(f"Cleaned up {removed} expired sessions")
1040
+ except Exception as e:
1041
+ logger.exception(f"Error during session cleanup: {e}")
1042
+
1043
+ async def _oauth_state_cleanup_loop(self) -> None:
1044
+ oauth2_backends = []
1045
+ if isinstance(self.auth_backend, MultiAuthenticationBackend):
1046
+ oauth2_backends = self.auth_backend.get_oauth2_backends()
1047
+ elif isinstance(self.auth_backend, OAuth2AuthenticationBackend):
1048
+ oauth2_backends = [self.auth_backend]
1049
+
1050
+ if not oauth2_backends:
1051
+ return
1052
+
1053
+ while True:
1054
+ await asyncio.sleep(600) # Run every 10 minutes
1055
+ try:
1056
+ total_removed = 0
1057
+ for backend in oauth2_backends:
1058
+ removed = backend.cleanup_expired_states()
1059
+ total_removed += removed
1060
+ if total_removed > 0:
1061
+ logger.info(f"Cleaned up {total_removed} expired OAuth2 state tokens")
1062
+ except Exception as e:
1063
+ logger.exception(f"Error during OAuth state cleanup: {e}")
@@ -1,11 +1,13 @@
1
1
  from .backends import ListAuthenticationBackend
2
2
  from .base import AuthenticationBackend, AuthenticationResponse
3
- from .types import User, UserCredentials
3
+ from .types import Session, SessionStore, User, UserCredentials
4
4
 
5
5
  __all__ = [
6
6
  "AuthenticationBackend",
7
7
  "AuthenticationResponse",
8
8
  "ListAuthenticationBackend",
9
+ "Session",
10
+ "SessionStore",
9
11
  "User",
10
12
  "UserCredentials",
11
13
  ]