ragbits-chat 1.4.0.dev202512160238__py3-none-any.whl → 1.4.0.dev202601130240__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/chat/api.py +364 -79
- ragbits/chat/auth/__init__.py +3 -1
- ragbits/chat/auth/backends.py +519 -81
- ragbits/chat/auth/base.py +8 -10
- ragbits/chat/auth/oauth2_providers.py +108 -0
- ragbits/chat/auth/provider_config.py +81 -0
- ragbits/chat/auth/session_store.py +178 -0
- ragbits/chat/auth/types.py +66 -29
- ragbits/chat/interface/types.py +15 -0
- ragbits/chat/persistence/sql.py +1 -0
- ragbits/chat/providers/model_provider.py +10 -8
- ragbits/chat/ui-build/assets/AuthGuard-Bq7UOJ7y.js +1 -0
- ragbits/chat/ui-build/assets/{ChatHistory--EyHdeYk.js → ChatHistory-B2hLBYMJ.js} +2 -2
- ragbits/chat/ui-build/assets/{ChatOptionsForm-BQJ-bYMu.js → ChatOptionsForm-bfNG8UIW.js} +1 -1
- ragbits/chat/ui-build/assets/CredentialsLogin-0g5-w2vR.js +1 -0
- ragbits/chat/ui-build/assets/{FeedbackForm-BYee4uF-.js → FeedbackForm-oSbly5oN.js} +1 -1
- ragbits/chat/ui-build/assets/Login-DSW_CNFu.js +1 -0
- ragbits/chat/ui-build/assets/LogoutButton-BQE8NNsg.js +1 -0
- ragbits/chat/ui-build/assets/OAuth2Login-EmJ39PUe.js +2 -0
- ragbits/chat/ui-build/assets/ShareButton-B7DyIVH0.js +1 -0
- ragbits/chat/ui-build/assets/{UsageButton-C0lzhbc6.js → UsageButton-BABA7a-w.js} +1 -1
- ragbits/chat/ui-build/assets/authStore-BfGlL8rp.js +1 -0
- ragbits/chat/ui-build/assets/{chunk-IGSAU2ZA-ZqFHUQCB.js → chunk-IGSAU2ZA-NXd1g0Qd.js} +1 -1
- ragbits/chat/ui-build/assets/{chunk-SSA7SXE4-Dj8WqXIN.js → chunk-SSA7SXE4-CJa0HuAU.js} +1 -1
- ragbits/chat/ui-build/assets/index-BZLU40Mk.js +83 -0
- ragbits/chat/ui-build/assets/index-Be0kkf3d.js +24 -0
- ragbits/chat/ui-build/assets/index-Bvn9K6h_.js +1 -0
- ragbits/chat/ui-build/assets/{index-Bpba6d6u.js → index-Ceq7Rkzy.js} +1 -1
- ragbits/chat/ui-build/assets/index-ClAYkAiv.css +1 -0
- ragbits/chat/ui-build/assets/useInitializeUserStore-DyHP7g8x.js +1 -0
- ragbits/chat/ui-build/assets/{useMenuTriggerState-DaMXoDzf.js → useMenuTriggerState-SaFmATkk.js} +1 -1
- ragbits/chat/ui-build/assets/{useSelectableItem-D8MlMsUd.js → useSelectableItem-DhuFnc0W.js} +1 -1
- ragbits/chat/ui-build/index.html +2 -2
- {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601130240.dist-info}/METADATA +2 -2
- ragbits_chat-1.4.0.dev202601130240.dist-info/RECORD +58 -0
- ragbits/chat/ui-build/assets/AuthGuard-Da8Duw5e.js +0 -1
- ragbits/chat/ui-build/assets/Login-DkcluI7q.js +0 -1
- ragbits/chat/ui-build/assets/LogoutButton-DSdCqsdC.js +0 -1
- ragbits/chat/ui-build/assets/ShareButton-YVgRY6bN.js +0 -1
- ragbits/chat/ui-build/assets/authStore-BFeUV-Bg.js +0 -1
- ragbits/chat/ui-build/assets/index-8hpVK8cj.js +0 -131
- ragbits/chat/ui-build/assets/index-BUbs7vFP.js +0 -1
- ragbits/chat/ui-build/assets/index-BZGp6GjF.js +0 -32
- ragbits/chat/ui-build/assets/index-CmsICuOz.css +0 -1
- ragbits_chat-1.4.0.dev202512160238.dist-info/RECORD +0 -52
- {ragbits_chat-1.4.0.dev202512160238.dist-info → ragbits_chat-1.4.0.dev202601130240.dist-info}/WHEEL +0 -0
ragbits/chat/api.py
CHANGED
|
@@ -1,24 +1,27 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import importlib
|
|
2
3
|
import json
|
|
3
4
|
import logging
|
|
5
|
+
import os
|
|
4
6
|
import re
|
|
5
7
|
import time
|
|
6
8
|
from collections.abc import AsyncGenerator
|
|
7
|
-
from contextlib import asynccontextmanager
|
|
9
|
+
from contextlib import asynccontextmanager, suppress
|
|
8
10
|
from pathlib import Path
|
|
9
11
|
from typing import Any, cast
|
|
10
12
|
|
|
11
13
|
import uvicorn
|
|
12
|
-
from fastapi import
|
|
14
|
+
from fastapi import FastAPI, HTTPException, Request, status
|
|
13
15
|
from fastapi.exceptions import RequestValidationError
|
|
14
16
|
from fastapi.middleware.cors import CORSMiddleware
|
|
15
|
-
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, StreamingResponse
|
|
16
|
-
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
17
|
+
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse, StreamingResponse
|
|
17
18
|
from fastapi.staticfiles import StaticFiles
|
|
18
19
|
from pydantic import BaseModel
|
|
19
20
|
|
|
20
21
|
from ragbits.chat.auth import AuthenticationBackend, User
|
|
21
|
-
from ragbits.chat.auth.
|
|
22
|
+
from ragbits.chat.auth.backends import MultiAuthenticationBackend, OAuth2AuthenticationBackend
|
|
23
|
+
from ragbits.chat.auth.provider_config import get_provider_visual_config
|
|
24
|
+
from ragbits.chat.auth.types import LoginRequest, LoginResponse, OAuth2Credentials
|
|
22
25
|
from ragbits.chat.interface import ChatInterface
|
|
23
26
|
from ragbits.chat.interface.types import (
|
|
24
27
|
AuthenticationConfig,
|
|
@@ -33,6 +36,7 @@ from ragbits.chat.interface.types import (
|
|
|
33
36
|
FeedbackRequest,
|
|
34
37
|
Image,
|
|
35
38
|
ImageResponse,
|
|
39
|
+
OAuth2ProviderConfig,
|
|
36
40
|
)
|
|
37
41
|
from ragbits.core.audit.metrics import record_metric
|
|
38
42
|
from ragbits.core.audit.metrics.base import MetricType
|
|
@@ -42,6 +46,12 @@ from .metrics import ChatCounterMetric, ChatHistogramMetric
|
|
|
42
46
|
|
|
43
47
|
logger = logging.getLogger(__name__)
|
|
44
48
|
|
|
49
|
+
# Cookie hardening is tied to the deployment environment: the ``Secure`` (HTTPS-only)
# cookie flag is enabled only when ENVIRONMENT equals "production" (compared
# case-insensitively). A missing or different value allows cookies over plain HTTP.
IS_PRODUCTION = os.environ.get("ENVIRONMENT", "").lower() == "production"

# Name of the HTTP-only cookie that carries the server-side session ID.
SESSION_COOKIE_NAME = "ragbits_session"
|
|
54
|
+
|
|
45
55
|
# Chunk size for large base64 images to prevent SSE message size issues
|
|
46
56
|
# Keep chunks extremely small to avoid JSON string length limits in browsers and SSE parsing issues
|
|
47
57
|
# Account for JSON overhead: metadata + base64 data should fit comfortably in browser limits
|
|
@@ -79,14 +89,34 @@ class RagbitsAPI:
|
|
|
79
89
|
self.cors_origins = cors_origins or []
|
|
80
90
|
self.debug_mode = debug_mode
|
|
81
91
|
self.auth_backend = self._load_auth_backend(auth_backend)
|
|
82
|
-
self.security = HTTPBearer(auto_error=False) if auth_backend else None
|
|
83
92
|
self.theme_path = Path(theme_path) if theme_path else None
|
|
84
93
|
|
|
94
|
+
# get frontend base URL from environment variable or use default
|
|
95
|
+
frontend_base_url = os.getenv("FRONTEND_BASE_URL", "http://localhost:8000")
|
|
96
|
+
# remove trailing slash from frontend base URL
|
|
97
|
+
frontend_base_url = frontend_base_url.rstrip("/")
|
|
98
|
+
# set frontend base URL on the API
|
|
99
|
+
self.frontend_base_url = frontend_base_url
|
|
100
|
+
|
|
85
101
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    FastAPI lifespan handler for the Ragbits chat API.

    Awaits ``chat_interface.setup()``. When an authentication backend is configured,
    it also starts the periodic session and OAuth2-state cleanup tasks. Those tasks
    are cancelled and awaited when the lifespan context exits.

    Args:
        app: The FastAPI application (not used here; required by the lifespan protocol).
    """
    await self.chat_interface.setup()

    # Background cleanup tasks for session and OAuth state management
    cleanup_tasks: list[asyncio.Task] = []
    if self.auth_backend:
        cleanup_tasks.append(asyncio.create_task(self._session_cleanup_loop()))
        cleanup_tasks.append(asyncio.create_task(self._oauth_state_cleanup_loop()))

    try:
        yield
    finally:
        # Cancel inside ``finally`` so the never-ending loops are stopped even when
        # the lifespan context is exited with an exception, not only on clean shutdown.
        for task in cleanup_tasks:
            task.cancel()
        for task in cleanup_tasks:
            with suppress(asyncio.CancelledError):
                await task
|
|
119
|
+
|
|
90
120
|
self.app = FastAPI(lifespan=lifespan)
|
|
91
121
|
|
|
92
122
|
self.configure_app()
|
|
@@ -122,52 +152,127 @@ class RagbitsAPI:
|
|
|
122
152
|
content={"detail": exc.errors(), "body": exc.body},
|
|
123
153
|
)
|
|
124
154
|
|
|
125
|
-
def setup_routes(self) -> None:
|
|
155
|
+
def setup_routes(self) -> None: # noqa: PLR0915
|
|
126
156
|
"""Defines API routes."""
|
|
127
157
|
# Authentication routes
|
|
128
158
|
if self.auth_backend:
|
|
129
|
-
# Create security dependency variable to avoid B008 linting error
|
|
130
|
-
security_dependency = Depends(self.security)
|
|
131
159
|
|
|
132
160
|
@self.app.post("/api/auth/login", response_class=JSONResponse)
|
|
133
161
|
async def login(request: LoginRequest) -> JSONResponse:
|
|
134
162
|
return await self._handle_login(request)
|
|
135
163
|
|
|
136
164
|
@self.app.post("/api/auth/logout", response_class=JSONResponse)
|
|
137
|
-
async def logout(request:
|
|
165
|
+
async def logout(request: Request) -> JSONResponse:
|
|
138
166
|
return await self._handle_logout(request)
|
|
139
167
|
|
|
168
|
+
# OAuth2 routes (if OAuth2 backend is configured)
|
|
169
|
+
oauth2_backends = []
|
|
170
|
+
if isinstance(self.auth_backend, MultiAuthenticationBackend):
|
|
171
|
+
oauth2_backends = self.auth_backend.get_oauth2_backends()
|
|
172
|
+
elif isinstance(self.auth_backend, OAuth2AuthenticationBackend):
|
|
173
|
+
oauth2_backends = [self.auth_backend]
|
|
174
|
+
|
|
175
|
+
if oauth2_backends:
|
|
176
|
+
# Create a mapping of provider name to backend for quick lookup
|
|
177
|
+
oauth2_backend_map = {backend.provider.name: backend for backend in oauth2_backends}
|
|
178
|
+
|
|
179
|
+
@self.app.get("/api/auth/authorize/{provider}", response_class=JSONResponse)
|
|
180
|
+
async def oauth2_authorize(provider: str) -> JSONResponse:
|
|
181
|
+
"""Generate OAuth2 authorization URL for specified provider."""
|
|
182
|
+
backend = oauth2_backend_map.get(provider)
|
|
183
|
+
if not backend:
|
|
184
|
+
raise HTTPException(status_code=404, detail=f"OAuth2 provider '{provider}' not found")
|
|
185
|
+
|
|
186
|
+
authorize_url, state = backend.generate_authorize_url()
|
|
187
|
+
return JSONResponse(content={"authorize_url": authorize_url, "state": state})
|
|
188
|
+
|
|
189
|
+
@self.app.get("/api/auth/callback/{provider}", response_class=RedirectResponse)
|
|
190
|
+
async def oauth2_callback(
|
|
191
|
+
provider: str, code: str | None = None, state: str | None = None
|
|
192
|
+
) -> RedirectResponse:
|
|
193
|
+
"""Handle OAuth2 callback from provider."""
|
|
194
|
+
backend = oauth2_backend_map.get(provider)
|
|
195
|
+
if not backend:
|
|
196
|
+
raise HTTPException(status_code=404, detail=f"OAuth2 provider '{provider}' not found")
|
|
197
|
+
|
|
198
|
+
return await self._handle_oauth2_callback(code, state, backend)
|
|
199
|
+
|
|
140
200
|
@self.app.post("/api/chat", response_class=StreamingResponse)
|
|
141
201
|
async def chat_message(
|
|
142
|
-
request:
|
|
143
|
-
|
|
202
|
+
request: Request,
|
|
203
|
+
chat_request: ChatMessageRequest,
|
|
144
204
|
) -> StreamingResponse:
|
|
145
|
-
return await self._handle_chat_message(
|
|
205
|
+
return await self._handle_chat_message(chat_request, request)
|
|
146
206
|
|
|
147
207
|
@self.app.post("/api/feedback", response_class=JSONResponse)
|
|
148
208
|
async def feedback(
|
|
149
|
-
request:
|
|
150
|
-
|
|
209
|
+
request: Request,
|
|
210
|
+
feedback_request: FeedbackRequest,
|
|
151
211
|
) -> JSONResponse:
|
|
152
|
-
return await self._handle_feedback(
|
|
212
|
+
return await self._handle_feedback(feedback_request, request)
|
|
153
213
|
else:
|
|
154
214
|
|
|
155
215
|
@self.app.post("/api/chat", response_class=StreamingResponse)
|
|
156
216
|
async def chat_message(
|
|
157
|
-
request:
|
|
217
|
+
request: Request,
|
|
218
|
+
chat_request: ChatMessageRequest,
|
|
158
219
|
) -> StreamingResponse:
|
|
159
|
-
return await self._handle_chat_message(
|
|
220
|
+
return await self._handle_chat_message(chat_request, request)
|
|
160
221
|
|
|
161
222
|
@self.app.post("/api/feedback", response_class=JSONResponse)
|
|
162
223
|
async def feedback(
|
|
163
|
-
request:
|
|
224
|
+
request: Request,
|
|
225
|
+
feedback_request: FeedbackRequest,
|
|
164
226
|
) -> JSONResponse:
|
|
165
|
-
return await self._handle_feedback(
|
|
227
|
+
return await self._handle_feedback(feedback_request, request)
|
|
166
228
|
|
|
167
229
|
@self.app.get("/api/config", response_class=JSONResponse)
|
|
168
230
|
async def config() -> JSONResponse:
|
|
169
231
|
feedback_config = self.chat_interface.feedback_config
|
|
170
232
|
|
|
233
|
+
# Determine available auth types and OAuth2 providers based on backend
|
|
234
|
+
auth_types = []
|
|
235
|
+
oauth2_providers = []
|
|
236
|
+
|
|
237
|
+
if self.auth_backend:
|
|
238
|
+
if isinstance(self.auth_backend, MultiAuthenticationBackend):
|
|
239
|
+
# Multi backend: check what types are available
|
|
240
|
+
if self.auth_backend.get_credentials_backends():
|
|
241
|
+
auth_types.append(AuthType.CREDENTIALS)
|
|
242
|
+
|
|
243
|
+
oauth2_backends = self.auth_backend.get_oauth2_backends()
|
|
244
|
+
if oauth2_backends:
|
|
245
|
+
auth_types.append(AuthType.OAUTH2)
|
|
246
|
+
for backend in oauth2_backends:
|
|
247
|
+
visual_config = get_provider_visual_config(backend.provider.name)
|
|
248
|
+
oauth2_providers.append(
|
|
249
|
+
OAuth2ProviderConfig(
|
|
250
|
+
name=backend.provider.name,
|
|
251
|
+
display_name=backend.provider.display_name,
|
|
252
|
+
color=visual_config.color,
|
|
253
|
+
button_color=visual_config.button_color,
|
|
254
|
+
text_color=visual_config.text_color,
|
|
255
|
+
icon_svg=visual_config.icon_svg,
|
|
256
|
+
)
|
|
257
|
+
)
|
|
258
|
+
elif isinstance(self.auth_backend, OAuth2AuthenticationBackend):
|
|
259
|
+
# Single OAuth2 backend
|
|
260
|
+
auth_types = [AuthType.OAUTH2]
|
|
261
|
+
visual_config = get_provider_visual_config(self.auth_backend.provider.name)
|
|
262
|
+
oauth2_providers = [
|
|
263
|
+
OAuth2ProviderConfig(
|
|
264
|
+
name=self.auth_backend.provider.name,
|
|
265
|
+
display_name=self.auth_backend.provider.display_name,
|
|
266
|
+
color=visual_config.color,
|
|
267
|
+
button_color=visual_config.button_color,
|
|
268
|
+
text_color=visual_config.text_color,
|
|
269
|
+
icon_svg=visual_config.icon_svg,
|
|
270
|
+
)
|
|
271
|
+
]
|
|
272
|
+
else:
|
|
273
|
+
# Single credentials backend
|
|
274
|
+
auth_types = [AuthType.CREDENTIALS]
|
|
275
|
+
|
|
171
276
|
config_response = ConfigResponse(
|
|
172
277
|
feedback=FeedbackConfig(
|
|
173
278
|
like=FeedbackItem(
|
|
@@ -186,12 +291,22 @@ class RagbitsAPI:
|
|
|
186
291
|
show_usage=self.chat_interface.show_usage,
|
|
187
292
|
authentication=AuthenticationConfig(
|
|
188
293
|
enabled=self.auth_backend is not None,
|
|
189
|
-
auth_types=
|
|
294
|
+
auth_types=auth_types,
|
|
295
|
+
oauth2_providers=oauth2_providers,
|
|
190
296
|
),
|
|
191
297
|
)
|
|
192
298
|
|
|
193
299
|
return JSONResponse(content=config_response.model_dump())
|
|
194
300
|
|
|
301
|
+
# User info endpoint - returns current authenticated user
|
|
302
|
+
@self.app.get("/api/user", response_class=JSONResponse)
|
|
303
|
+
async def get_user(request: Request) -> JSONResponse:
|
|
304
|
+
"""Get current authenticated user from session cookie."""
|
|
305
|
+
user = await self.get_current_user_from_cookie(request)
|
|
306
|
+
if user:
|
|
307
|
+
return JSONResponse(content=user.model_dump())
|
|
308
|
+
return JSONResponse(content=None, status_code=401)
|
|
309
|
+
|
|
195
310
|
# Theme CSS endpoint - always available, returns 404 if no theme configured
|
|
196
311
|
@self.app.get("/api/theme", response_class=PlainTextResponse)
|
|
197
312
|
async def theme() -> PlainTextResponse:
|
|
@@ -217,41 +332,18 @@ class RagbitsAPI:
|
|
|
217
332
|
with open(str(index_file)) as file:
|
|
218
333
|
return HTMLResponse(content=file.read())
|
|
219
334
|
|
|
220
|
-
async def _validate_authentication(self, credentials: HTTPAuthorizationCredentials | None) -> User | None:
|
|
221
|
-
"""Validate authentication credentials and return user if valid."""
|
|
222
|
-
if not self.auth_backend:
|
|
223
|
-
return None
|
|
224
|
-
|
|
225
|
-
if not credentials:
|
|
226
|
-
raise HTTPException(
|
|
227
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
228
|
-
detail="Authentication required",
|
|
229
|
-
headers={"WWW-Authenticate": "Bearer"},
|
|
230
|
-
)
|
|
231
|
-
|
|
232
|
-
# The jwt_token should be the session_id
|
|
233
|
-
auth_result = await self.auth_backend.validate_token(credentials.credentials)
|
|
234
|
-
if not auth_result.success:
|
|
235
|
-
raise HTTPException(
|
|
236
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
237
|
-
detail=auth_result.error_message or "Invalid session",
|
|
238
|
-
headers={"WWW-Authenticate": "Bearer"},
|
|
239
|
-
)
|
|
240
|
-
|
|
241
|
-
return auth_result.user
|
|
242
|
-
|
|
243
335
|
@staticmethod
|
|
244
336
|
def _prepare_chat_context(
|
|
245
337
|
request: ChatMessageRequest,
|
|
246
338
|
authenticated_user: User | None,
|
|
247
|
-
|
|
339
|
+
session_id: str | None,
|
|
248
340
|
) -> ChatContext:
|
|
249
341
|
"""Prepare and validate chat context from request."""
|
|
250
342
|
chat_context = ChatContext(**request.context)
|
|
251
343
|
|
|
252
344
|
# Add session_id to context if authenticated
|
|
253
|
-
if authenticated_user and
|
|
254
|
-
chat_context.session_id =
|
|
345
|
+
if authenticated_user and session_id:
|
|
346
|
+
chat_context.session_id = session_id
|
|
255
347
|
chat_context.user = authenticated_user
|
|
256
348
|
|
|
257
349
|
# Verify state signature if provided
|
|
@@ -275,9 +367,7 @@ class RagbitsAPI:
|
|
|
275
367
|
|
|
276
368
|
return chat_context
|
|
277
369
|
|
|
278
|
-
async def _handle_chat_message(
|
|
279
|
-
self, request: ChatMessageRequest, credentials: HTTPAuthorizationCredentials | None = None
|
|
280
|
-
) -> StreamingResponse: # noqa: PLR0915
|
|
370
|
+
async def _handle_chat_message(self, chat_request: ChatMessageRequest, request: Request) -> StreamingResponse: # noqa: PLR0915
|
|
281
371
|
"""Handle chat message requests with metrics tracking."""
|
|
282
372
|
start_time = time.time()
|
|
283
373
|
|
|
@@ -287,8 +377,8 @@ class RagbitsAPI:
|
|
|
287
377
|
)
|
|
288
378
|
|
|
289
379
|
try:
|
|
290
|
-
# Validate authentication if required
|
|
291
|
-
authenticated_user = await self.
|
|
380
|
+
# Validate authentication if required using cookies
|
|
381
|
+
authenticated_user = await self.require_authenticated_user(request)
|
|
292
382
|
|
|
293
383
|
if not self.chat_interface:
|
|
294
384
|
record_metric(
|
|
@@ -302,12 +392,13 @@ class RagbitsAPI:
|
|
|
302
392
|
raise HTTPException(status_code=500, detail="Chat implementation is not initialized")
|
|
303
393
|
|
|
304
394
|
# Prepare chat context
|
|
305
|
-
|
|
395
|
+
session_id = request.cookies.get(SESSION_COOKIE_NAME)
|
|
396
|
+
chat_context = RagbitsAPI._prepare_chat_context(chat_request, authenticated_user, session_id)
|
|
306
397
|
|
|
307
398
|
# Get the response generator from the chat interface
|
|
308
399
|
response_generator = self.chat_interface.chat(
|
|
309
|
-
message=
|
|
310
|
-
history=[msg.model_dump() for msg in
|
|
400
|
+
message=chat_request.message,
|
|
401
|
+
history=[msg.model_dump() for msg in chat_request.history],
|
|
311
402
|
context=chat_context,
|
|
312
403
|
)
|
|
313
404
|
|
|
@@ -318,8 +409,8 @@ class RagbitsAPI:
|
|
|
318
409
|
state_update_text = ""
|
|
319
410
|
|
|
320
411
|
with trace(
|
|
321
|
-
message=
|
|
322
|
-
history=[msg.model_dump() for msg in
|
|
412
|
+
message=chat_request.message,
|
|
413
|
+
history=[msg.model_dump() for msg in chat_request.history],
|
|
323
414
|
context=chat_context,
|
|
324
415
|
) as outputs:
|
|
325
416
|
async for chunk in RagbitsAPI._chat_response_to_sse(response_generator):
|
|
@@ -406,9 +497,7 @@ class RagbitsAPI:
|
|
|
406
497
|
)
|
|
407
498
|
raise HTTPException(status_code=500, detail="Internal server error") from None
|
|
408
499
|
|
|
409
|
-
async def _handle_feedback(
|
|
410
|
-
self, request: FeedbackRequest, credentials: HTTPAuthorizationCredentials | None = None
|
|
411
|
-
) -> JSONResponse:
|
|
500
|
+
async def _handle_feedback(self, feedback_request: FeedbackRequest, request: Request) -> JSONResponse:
|
|
412
501
|
"""Handle feedback requests with metrics tracking."""
|
|
413
502
|
start_time = time.time()
|
|
414
503
|
|
|
@@ -422,8 +511,8 @@ class RagbitsAPI:
|
|
|
422
511
|
)
|
|
423
512
|
|
|
424
513
|
try:
|
|
425
|
-
# Validate authentication if required
|
|
426
|
-
await self.
|
|
514
|
+
# Validate authentication if required using cookies
|
|
515
|
+
await self.require_authenticated_user(request)
|
|
427
516
|
|
|
428
517
|
if not self.chat_interface:
|
|
429
518
|
record_metric(
|
|
@@ -437,9 +526,9 @@ class RagbitsAPI:
|
|
|
437
526
|
raise HTTPException(status_code=500, detail="Chat implementation is not initialized")
|
|
438
527
|
|
|
439
528
|
await self.chat_interface.save_feedback(
|
|
440
|
-
message_id=
|
|
441
|
-
feedback=
|
|
442
|
-
payload=
|
|
529
|
+
message_id=feedback_request.message_id,
|
|
530
|
+
feedback=feedback_request.feedback,
|
|
531
|
+
payload=feedback_request.payload,
|
|
443
532
|
)
|
|
444
533
|
|
|
445
534
|
# Track successful request duration
|
|
@@ -496,26 +585,109 @@ class RagbitsAPI:
|
|
|
496
585
|
)
|
|
497
586
|
raise HTTPException(status_code=500, detail="Internal server error") from None
|
|
498
587
|
|
|
588
|
+
async def get_current_user_from_cookie(self, request: Request) -> User | None:
    """
    Resolve the user bound to the session cookie on ``request``.

    Args:
        request: Incoming FastAPI request whose cookies are inspected.

    Returns:
        The ``User`` from a successfully validated session. Returns ``None`` when
        no auth backend is configured, the session cookie is missing or empty, or
        the backend reports the session as invalid.
    """
    backend = self.auth_backend
    if not backend:
        return None

    cookie_value = request.cookies.get(SESSION_COOKIE_NAME)
    if not cookie_value:
        return None

    validation = await backend.validate_session(cookie_value)
    return validation.user if validation.success else None
|
|
610
|
+
|
|
611
|
+
async def require_authenticated_user(self, request: Request) -> User | None:
    """
    Return the cookie-authenticated user, enforcing authentication when it is enabled.

    Intended to be shared by every handler that needs an authenticated caller.

    Args:
        request: Incoming FastAPI request carrying the session cookie.

    Returns:
        The authenticated ``User``. Returns ``None`` only when no auth backend is
        configured, in which case authentication is not enforced.

    Raises:
        HTTPException: 401 Unauthorized when an auth backend is configured but the
            request has no valid session.
    """
    user = await self.get_current_user_from_cookie(request)
    if user or not self.auth_backend:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
    )
|
|
634
|
+
|
|
635
|
+
def get_session_expiry_seconds(self) -> int:
    """
    Return the session lifetime, in whole seconds, configured on the auth backend.

    Resolution order:
    1. ``self.auth_backend.session_expiry_hours`` if the backend defines that attribute.
    2. For a ``MultiAuthenticationBackend``, the first entry of ``backends`` that
       defines ``session_expiry_hours``.
    3. Otherwise 24 hours. This default also applies when no auth backend is configured.

    Returns:
        Session expiry in seconds, always an ``int``. Fractional results are truncated.
        The value is used as a cookie ``Max-Age``, which must be a whole number of
        seconds; a float such as ``1800.0`` would be ignored by browsers.
    """
    default_expiry_hours = 24

    def to_seconds(hours: float) -> int:
        # Truncate so callers always get an integral Max-Age.
        return int(hours * 3600)

    backend = self.auth_backend
    if not backend:
        return to_seconds(default_expiry_hours)

    if hasattr(backend, "session_expiry_hours"):
        return to_seconds(backend.session_expiry_hours)

    # For MultiAuthenticationBackend, use the first sub-backend that defines an expiry.
    if isinstance(backend, MultiAuthenticationBackend):
        for sub_backend in backend.backends:
            if hasattr(sub_backend, "session_expiry_hours"):
                return to_seconds(sub_backend.session_expiry_hours)

    return to_seconds(default_expiry_hours)
|
|
659
|
+
|
|
499
660
|
async def _handle_login(self, request: LoginRequest) -> JSONResponse:
|
|
500
|
-
"""Handle user login requests."""
|
|
661
|
+
"""Handle user login requests with credentials."""
|
|
501
662
|
if not self.auth_backend:
|
|
502
663
|
raise HTTPException(status_code=500, detail="Authentication not configured")
|
|
503
664
|
|
|
504
665
|
try:
|
|
505
|
-
|
|
666
|
+
# LoginRequest is UserCredentials
|
|
667
|
+
auth_result = await self.auth_backend.authenticate_with_credentials(request)
|
|
506
668
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
if auth_result.success and auth_result.jwt_token:
|
|
511
|
-
return JSONResponse(
|
|
669
|
+
if auth_result.success and auth_result.session_id:
|
|
670
|
+
response = JSONResponse(
|
|
512
671
|
content=LoginResponse(
|
|
513
672
|
success=True,
|
|
514
673
|
user=auth_result.user if auth_result.user else None,
|
|
515
674
|
error_message=None,
|
|
516
|
-
jwt_token=auth_result.jwt_token,
|
|
517
675
|
).model_dump()
|
|
518
676
|
)
|
|
677
|
+
|
|
678
|
+
# Set secure HTTP-only cookie using backend's session expiry configuration
|
|
679
|
+
session_expiry_seconds = self.get_session_expiry_seconds()
|
|
680
|
+
response.set_cookie(
|
|
681
|
+
key=SESSION_COOKIE_NAME,
|
|
682
|
+
value=auth_result.session_id,
|
|
683
|
+
httponly=True,
|
|
684
|
+
secure=IS_PRODUCTION, # Only require HTTPS in production
|
|
685
|
+
samesite="lax",
|
|
686
|
+
max_age=session_expiry_seconds,
|
|
687
|
+
path="/",
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
return response
|
|
519
691
|
else:
|
|
520
692
|
return JSONResponse(
|
|
521
693
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
@@ -523,7 +695,6 @@ class RagbitsAPI:
|
|
|
523
695
|
success=False,
|
|
524
696
|
user=None,
|
|
525
697
|
error_message=auth_result.error_message or "Invalid credentials",
|
|
526
|
-
jwt_token=None,
|
|
527
698
|
).model_dump(),
|
|
528
699
|
)
|
|
529
700
|
except Exception as e:
|
|
@@ -534,25 +705,100 @@ class RagbitsAPI:
|
|
|
534
705
|
success=False,
|
|
535
706
|
user=None,
|
|
536
707
|
error_message="Internal server error",
|
|
537
|
-
jwt_token=None,
|
|
538
708
|
).model_dump(),
|
|
539
709
|
)
|
|
540
710
|
|
|
541
|
-
async def _handle_logout(self, request: Request) -> JSONResponse:
    """
    Handle user logout requests.

    Revokes the server-side session referenced by the session cookie, if there is
    one, and always tells the browser to delete that cookie.

    Args:
        request: Incoming FastAPI request carrying the session cookie.

    Returns:
        On normal completion, ``{"success": <result of revoke_session>}``. The value
        is ``True`` when no session cookie was present. If revocation raises, returns
        a 500 response with ``{"success": False, "error": "Internal server error"}``.

    Raises:
        HTTPException: 500 when no authentication backend is configured.
    """
    if not self.auth_backend:
        raise HTTPException(status_code=500, detail="Authentication not configured")

    try:
        session_id = request.cookies.get(SESSION_COOKIE_NAME)
        # With no session cookie there is nothing to revoke server-side.
        success = await self.auth_backend.revoke_session(session_id) if session_id else True
        response = JSONResponse(content={"success": success})
    except Exception as e:
        # logger.exception keeps the traceback, unlike logger.error(f"...{e}").
        logger.exception("Logout error: %s", e)
        response = JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"success": False, "error": "Internal server error"},
        )

    # Clear the session cookie on every path, including failures. Otherwise a failed
    # revocation would leave this browser holding a still-valid session after "logout".
    response.delete_cookie(key=SESSION_COOKIE_NAME, path="/")
    return response
|
|
740
|
+
|
|
741
|
+
async def _handle_oauth2_callback(
    self, code: str | None, state: str | None, backend: OAuth2AuthenticationBackend
) -> RedirectResponse:
    """
    Handle the OAuth2 redirect back from a provider.

    Steps:
    1. Validate the authorization ``code`` and the CSRF ``state``.
    2. Exchange the code for an access token.
    3. Authenticate the user through ``backend``.
    4. On success, redirect to the frontend root and store the session ID in an
       HTTP-only cookie.

    Every failure redirects to ``{frontend_base_url}/login`` with a URL-encoded
    ``error`` query parameter.

    Args:
        code: Authorization code from the OAuth2 provider (``None`` if absent).
        state: CSRF protection state parameter (``None`` if absent).
        backend: The OAuth2 backend registered for this provider.

    Returns:
        A 302 ``RedirectResponse`` to the frontend on success, or to its login page
        on failure.
    """
    from urllib.parse import urlencode

    def redirect_to_login(error: str) -> RedirectResponse:
        # The error text may come from the provider or backend and contain spaces,
        # '&', '#' or '='. Encode it so it cannot break or inject into the query string.
        query = urlencode({"error": error})
        return RedirectResponse(url=f"{self.frontend_base_url}/login?{query}", status_code=302)

    if not code:
        return redirect_to_login("missing_code")

    # Verify state parameter for CSRF protection
    if not state or not backend.verify_state(state):
        return redirect_to_login("invalid_state")

    try:
        access_token = await backend.exchange_code_for_token(code)
        if not access_token:
            return redirect_to_login("token_exchange_failed")

        oauth_credentials = OAuth2Credentials(access_token=access_token, token_type="Bearer")  # noqa: S106
        auth_result = await backend.authenticate_with_oauth2(oauth_credentials)
        if not auth_result.success or not auth_result.session_id:
            return redirect_to_login(auth_result.error_message or "Authentication failed")

        response = RedirectResponse(url=f"{self.frontend_base_url}/", status_code=302)
        # Cookie lifetime follows this provider backend's own session expiry setting.
        response.set_cookie(
            key=SESSION_COOKIE_NAME,
            value=auth_result.session_id,
            httponly=True,
            secure=IS_PRODUCTION,  # Only require HTTPS in production
            samesite="lax",
            max_age=backend.session_expiry_hours * 3600,
            path="/",
        )
        return response

    except Exception as e:
        # logger.exception keeps the traceback, unlike logger.error(f"...{e}").
        logger.exception("OAuth2 callback error: %s", e)
        return redirect_to_login("internal_error")
|
|
801
|
+
|
|
556
802
|
@staticmethod
|
|
557
803
|
async def _chat_response_to_sse(
|
|
558
804
|
responses: AsyncGenerator[ChatResponseUnion],
|
|
@@ -776,3 +1022,42 @@ class RagbitsAPI:
|
|
|
776
1022
|
css_lines.append(f" --heroui-{color_name}: {color_value};")
|
|
777
1023
|
|
|
778
1024
|
return css_lines
|
|
1025
|
+
|
|
1026
|
+
async def _session_cleanup_loop(self) -> None:
|
|
1027
|
+
if (
|
|
1028
|
+
not self.auth_backend
|
|
1029
|
+
or not hasattr(self.auth_backend, "session_store")
|
|
1030
|
+
or not hasattr(self.auth_backend.session_store, "cleanup_expired_sessions")
|
|
1031
|
+
):
|
|
1032
|
+
return
|
|
1033
|
+
|
|
1034
|
+
while True:
|
|
1035
|
+
await asyncio.sleep(3600) # Run every hour
|
|
1036
|
+
try:
|
|
1037
|
+
removed = self.auth_backend.session_store.cleanup_expired_sessions() # type: ignore
|
|
1038
|
+
if removed > 0:
|
|
1039
|
+
logger.info(f"Cleaned up {removed} expired sessions")
|
|
1040
|
+
except Exception as e:
|
|
1041
|
+
logger.exception(f"Error during session cleanup: {e}")
|
|
1042
|
+
|
|
1043
|
+
async def _oauth_state_cleanup_loop(self) -> None:
|
|
1044
|
+
oauth2_backends = []
|
|
1045
|
+
if isinstance(self.auth_backend, MultiAuthenticationBackend):
|
|
1046
|
+
oauth2_backends = self.auth_backend.get_oauth2_backends()
|
|
1047
|
+
elif isinstance(self.auth_backend, OAuth2AuthenticationBackend):
|
|
1048
|
+
oauth2_backends = [self.auth_backend]
|
|
1049
|
+
|
|
1050
|
+
if not oauth2_backends:
|
|
1051
|
+
return
|
|
1052
|
+
|
|
1053
|
+
while True:
|
|
1054
|
+
await asyncio.sleep(600) # Run every 10 minutes
|
|
1055
|
+
try:
|
|
1056
|
+
total_removed = 0
|
|
1057
|
+
for backend in oauth2_backends:
|
|
1058
|
+
removed = backend.cleanup_expired_states()
|
|
1059
|
+
total_removed += removed
|
|
1060
|
+
if total_removed > 0:
|
|
1061
|
+
logger.info(f"Cleaned up {total_removed} expired OAuth2 state tokens")
|
|
1062
|
+
except Exception as e:
|
|
1063
|
+
logger.exception(f"Error during OAuth state cleanup: {e}")
|
ragbits/chat/auth/__init__.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from .backends import ListAuthenticationBackend
|
|
2
2
|
from .base import AuthenticationBackend, AuthenticationResponse
|
|
3
|
-
from .types import User, UserCredentials
|
|
3
|
+
from .types import Session, SessionStore, User, UserCredentials
|
|
4
4
|
|
|
5
5
|
__all__ = [
|
|
6
6
|
"AuthenticationBackend",
|
|
7
7
|
"AuthenticationResponse",
|
|
8
8
|
"ListAuthenticationBackend",
|
|
9
|
+
"Session",
|
|
10
|
+
"SessionStore",
|
|
9
11
|
"User",
|
|
10
12
|
"UserCredentials",
|
|
11
13
|
]
|