chainlit 2.0.0__py3-none-any.whl → 2.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +57 -56
- chainlit/action.py +10 -12
- chainlit/{auth/__init__.py → auth.py} +34 -26
- chainlit/cache.py +6 -4
- chainlit/callbacks.py +7 -52
- chainlit/chat_context.py +2 -2
- chainlit/chat_settings.py +1 -3
- chainlit/cli/__init__.py +2 -15
- chainlit/config.py +70 -41
- chainlit/context.py +9 -8
- chainlit/copilot/dist/index.js +874 -8533
- chainlit/data/__init__.py +8 -96
- chainlit/data/acl.py +2 -3
- chainlit/data/base.py +15 -1
- chainlit/data/dynamodb.py +4 -7
- chainlit/data/literalai.py +6 -4
- chainlit/data/sql_alchemy.py +9 -10
- chainlit/data/{storage_clients/azure.py → storage_clients.py} +33 -2
- chainlit/discord/__init__.py +4 -4
- chainlit/discord/app.py +1 -2
- chainlit/element.py +9 -41
- chainlit/emitter.py +21 -17
- chainlit/frontend/dist/assets/DailyMotion-b4b7af47.js +1 -0
- chainlit/frontend/dist/assets/Facebook-572972a0.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-85c69ca8.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-dfc24672.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-705011f4.js +1 -0
- chainlit/frontend/dist/assets/Mux-4201a9e6.js +1 -0
- chainlit/frontend/dist/assets/Preview-23ba40a6.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-1a582d51.js +1 -0
- chainlit/frontend/dist/assets/Streamable-5017c4ba.js +1 -0
- chainlit/frontend/dist/assets/Twitch-bb2de2fa.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-54e269b1.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-d92c37dd.js +1 -0
- chainlit/frontend/dist/assets/Wistia-25a1363b.js +1 -0
- chainlit/frontend/dist/assets/YouTube-616e8cb7.js +1 -0
- chainlit/frontend/dist/assets/index-aaf974a9.css +1 -0
- chainlit/frontend/dist/assets/index-f5df2072.js +1027 -0
- chainlit/frontend/dist/assets/{react-plotly-BpxUS-ab.js → react-plotly-f0315f86.js} +94 -94
- chainlit/frontend/dist/index.html +3 -2
- chainlit/haystack/callbacks.py +4 -5
- chainlit/input_widget.py +4 -6
- chainlit/langchain/callbacks.py +47 -56
- chainlit/langflow/__init__.py +0 -1
- chainlit/llama_index/callbacks.py +7 -7
- chainlit/message.py +10 -8
- chainlit/mistralai/__init__.py +2 -3
- chainlit/oauth_providers.py +12 -113
- chainlit/openai/__init__.py +7 -6
- chainlit/secret.py +1 -1
- chainlit/server.py +181 -491
- chainlit/session.py +5 -7
- chainlit/slack/__init__.py +3 -3
- chainlit/slack/app.py +2 -3
- chainlit/socket.py +103 -78
- chainlit/step.py +29 -21
- chainlit/sync.py +1 -2
- chainlit/teams/__init__.py +3 -3
- chainlit/teams/app.py +0 -1
- chainlit/types.py +4 -20
- chainlit/user.py +1 -2
- chainlit/utils.py +2 -3
- chainlit/version.py +2 -3
- {chainlit-2.0.0.dist-info → chainlit-2.0.dev0.dist-info}/METADATA +39 -27
- chainlit-2.0.dev0.dist-info/RECORD +96 -0
- chainlit/auth/cookie.py +0 -123
- chainlit/auth/jwt.py +0 -37
- chainlit/data/chainlit_data_layer.py +0 -584
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure_blob.py +0 -80
- chainlit/data/storage_clients/base.py +0 -22
- chainlit/data/storage_clients/gcs.py +0 -78
- chainlit/data/storage_clients/s3.py +0 -49
- chainlit/frontend/dist/assets/DailyMotion-DgRzV5GZ.js +0 -1
- chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +0 -22
- chainlit/frontend/dist/assets/Facebook-C0vx6HWv.js +0 -1
- chainlit/frontend/dist/assets/FilePlayer-CdhzeHPP.js +0 -1
- chainlit/frontend/dist/assets/Kaltura-5iVmeUct.js +0 -1
- chainlit/frontend/dist/assets/Mixcloud-C2zi77Ex.js +0 -1
- chainlit/frontend/dist/assets/Mux-Vkebogdf.js +0 -1
- chainlit/frontend/dist/assets/Preview-DwY_sEIl.js +0 -1
- chainlit/frontend/dist/assets/SoundCloud-CREBXAWo.js +0 -1
- chainlit/frontend/dist/assets/Streamable-B5Lu25uy.js +0 -1
- chainlit/frontend/dist/assets/Twitch-y9iKCcM1.js +0 -1
- chainlit/frontend/dist/assets/Vidyard-ClYvcuEu.js +0 -1
- chainlit/frontend/dist/assets/Vimeo-D6HvM2jt.js +0 -1
- chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +0 -1
- chainlit/frontend/dist/assets/YouTube-D10tR6CJ.js +0 -1
- chainlit/frontend/dist/assets/index-CI4qFOt5.js +0 -8665
- chainlit/frontend/dist/assets/index-CrrqM0nZ.css +0 -1
- chainlit/translations/nl-NL.json +0 -229
- chainlit-2.0.0.dist-info/RECORD +0 -106
- /chainlit/copilot/dist/assets/{logo_dark-IkGJ_IwC.svg → logo_dark-2a3cf740.svg} +0 -0
- /chainlit/copilot/dist/assets/{logo_light-Bb_IPh6r.svg → logo_light-b078e7bc.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_dark-IkGJ_IwC.svg → logo_dark-2a3cf740.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_light-Bb_IPh6r.svg → logo_light-b078e7bc.svg} +0 -0
- {chainlit-2.0.0.dist-info → chainlit-2.0.dev0.dist-info}/WHEEL +0 -0
- {chainlit-2.0.0.dist-info → chainlit-2.0.dev0.dist-info}/entry_points.txt +0 -0
chainlit/server.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import fnmatch
|
|
3
2
|
import glob
|
|
4
3
|
import json
|
|
5
4
|
import mimetypes
|
|
@@ -10,36 +9,10 @@ import urllib.parse
|
|
|
10
9
|
import webbrowser
|
|
11
10
|
from contextlib import asynccontextmanager
|
|
12
11
|
from pathlib import Path
|
|
13
|
-
from typing import
|
|
12
|
+
from typing import Any, Optional, Union
|
|
14
13
|
|
|
15
14
|
import socketio
|
|
16
|
-
from
|
|
17
|
-
APIRouter,
|
|
18
|
-
Depends,
|
|
19
|
-
FastAPI,
|
|
20
|
-
Form,
|
|
21
|
-
HTTPException,
|
|
22
|
-
Query,
|
|
23
|
-
Request,
|
|
24
|
-
Response,
|
|
25
|
-
UploadFile,
|
|
26
|
-
status,
|
|
27
|
-
)
|
|
28
|
-
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
29
|
-
from fastapi.security import OAuth2PasswordRequestForm
|
|
30
|
-
from starlette.datastructures import URL
|
|
31
|
-
from starlette.middleware.cors import CORSMiddleware
|
|
32
|
-
from typing_extensions import Annotated
|
|
33
|
-
from watchfiles import awatch
|
|
34
|
-
|
|
35
|
-
from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
|
|
36
|
-
from chainlit.auth.cookie import (
|
|
37
|
-
clear_auth_cookie,
|
|
38
|
-
clear_oauth_state_cookie,
|
|
39
|
-
set_auth_cookie,
|
|
40
|
-
set_oauth_state_cookie,
|
|
41
|
-
validate_oauth_state_cookie,
|
|
42
|
-
)
|
|
15
|
+
from chainlit.auth import create_jwt, get_configuration, get_current_user
|
|
43
16
|
from chainlit.config import (
|
|
44
17
|
APP_ROOT,
|
|
45
18
|
BACKEND_ROOT,
|
|
@@ -48,7 +21,6 @@ from chainlit.config import (
|
|
|
48
21
|
PACKAGE_ROOT,
|
|
49
22
|
config,
|
|
50
23
|
load_module,
|
|
51
|
-
public_dir,
|
|
52
24
|
reload_config,
|
|
53
25
|
)
|
|
54
26
|
from chainlit.data import get_data_layer
|
|
@@ -58,16 +30,32 @@ from chainlit.markdown import get_markdown_str
|
|
|
58
30
|
from chainlit.oauth_providers import get_oauth_provider
|
|
59
31
|
from chainlit.secret import random_secret
|
|
60
32
|
from chainlit.types import (
|
|
61
|
-
CallActionRequest,
|
|
62
33
|
DeleteFeedbackRequest,
|
|
63
34
|
DeleteThreadRequest,
|
|
64
|
-
ElementRequest,
|
|
65
35
|
GetThreadsRequest,
|
|
66
36
|
Theme,
|
|
67
37
|
UpdateFeedbackRequest,
|
|
68
|
-
UpdateThreadRequest,
|
|
69
38
|
)
|
|
70
39
|
from chainlit.user import PersistedUser, User
|
|
40
|
+
from fastapi import (
|
|
41
|
+
APIRouter,
|
|
42
|
+
Depends,
|
|
43
|
+
FastAPI,
|
|
44
|
+
Form,
|
|
45
|
+
HTTPException,
|
|
46
|
+
Query,
|
|
47
|
+
Request,
|
|
48
|
+
Response,
|
|
49
|
+
UploadFile,
|
|
50
|
+
status,
|
|
51
|
+
)
|
|
52
|
+
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
53
|
+
from fastapi.security import OAuth2PasswordRequestForm
|
|
54
|
+
from fastapi.staticfiles import StaticFiles
|
|
55
|
+
from starlette.datastructures import URL
|
|
56
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
57
|
+
from typing_extensions import Annotated
|
|
58
|
+
from watchfiles import awatch
|
|
71
59
|
|
|
72
60
|
from ._utils import is_path_inside
|
|
73
61
|
|
|
@@ -216,59 +204,29 @@ app.add_middleware(
|
|
|
216
204
|
|
|
217
205
|
router = APIRouter(prefix=PREFIX)
|
|
218
206
|
|
|
207
|
+
app.mount(
|
|
208
|
+
f"{PREFIX}/public",
|
|
209
|
+
StaticFiles(directory="public", check_dir=False),
|
|
210
|
+
name="public",
|
|
211
|
+
)
|
|
219
212
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
if not is_path_inside(file_path, base_path):
|
|
230
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
231
|
-
|
|
232
|
-
if file_path.is_file():
|
|
233
|
-
return FileResponse(file_path)
|
|
234
|
-
else:
|
|
235
|
-
raise HTTPException(status_code=404, detail="File not found")
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
@router.get("/assets/{filename:path}")
|
|
239
|
-
async def serve_asset_file(
|
|
240
|
-
filename: str,
|
|
241
|
-
):
|
|
242
|
-
"""Serve a file from assets dir."""
|
|
243
|
-
|
|
244
|
-
base_path = Path(os.path.join(build_dir, "assets"))
|
|
245
|
-
file_path = (base_path / filename).resolve()
|
|
246
|
-
|
|
247
|
-
if not is_path_inside(file_path, base_path):
|
|
248
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
249
|
-
|
|
250
|
-
if file_path.is_file():
|
|
251
|
-
return FileResponse(file_path)
|
|
252
|
-
else:
|
|
253
|
-
raise HTTPException(status_code=404, detail="File not found")
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
@router.get("/copilot/{filename:path}")
|
|
257
|
-
async def serve_copilot_file(
|
|
258
|
-
filename: str,
|
|
259
|
-
):
|
|
260
|
-
"""Serve a file from assets dir."""
|
|
261
|
-
|
|
262
|
-
base_path = Path(copilot_build_dir)
|
|
263
|
-
file_path = (base_path / filename).resolve()
|
|
264
|
-
|
|
265
|
-
if not is_path_inside(file_path, base_path):
|
|
266
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
213
|
+
app.mount(
|
|
214
|
+
f"{PREFIX}/assets",
|
|
215
|
+
StaticFiles(
|
|
216
|
+
packages=[("chainlit", os.path.join(build_dir, "assets"))],
|
|
217
|
+
follow_symlink=config.project.follow_symlink,
|
|
218
|
+
),
|
|
219
|
+
name="assets",
|
|
220
|
+
)
|
|
267
221
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
222
|
+
app.mount(
|
|
223
|
+
f"{PREFIX}/copilot",
|
|
224
|
+
StaticFiles(
|
|
225
|
+
packages=[("chainlit", copilot_build_dir)],
|
|
226
|
+
follow_symlink=config.project.follow_symlink,
|
|
227
|
+
),
|
|
228
|
+
name="copilot",
|
|
229
|
+
)
|
|
272
230
|
|
|
273
231
|
|
|
274
232
|
# -------------------------------------------------------------------------------
|
|
@@ -289,7 +247,6 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
|
|
|
289
247
|
|
|
290
248
|
if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
|
|
291
249
|
from botbuilder.schema import Activity
|
|
292
|
-
|
|
293
250
|
from chainlit.teams.app import adapter, bot
|
|
294
251
|
|
|
295
252
|
@router.post("/teams/events")
|
|
@@ -319,16 +276,6 @@ def get_html_template():
|
|
|
319
276
|
"""
|
|
320
277
|
Get HTML template for the index view.
|
|
321
278
|
"""
|
|
322
|
-
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
323
|
-
|
|
324
|
-
custom_theme = None
|
|
325
|
-
custom_theme_file_path = Path(public_dir) / "theme.json"
|
|
326
|
-
if (
|
|
327
|
-
is_path_inside(custom_theme_file_path, Path(public_dir))
|
|
328
|
-
and custom_theme_file_path.is_file()
|
|
329
|
-
):
|
|
330
|
-
custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
|
|
331
|
-
|
|
332
279
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
333
280
|
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
|
|
334
281
|
CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
|
|
@@ -351,10 +298,7 @@ def get_html_template():
|
|
|
351
298
|
<meta property="og:url" content="{url}">
|
|
352
299
|
<meta property="og:root_path" content="{ROOT_PATH}">"""
|
|
353
300
|
|
|
354
|
-
js = f"""<script>
|
|
355
|
-
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
|
|
356
|
-
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
|
|
357
|
-
</script>"""
|
|
301
|
+
js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
|
|
358
302
|
|
|
359
303
|
css = None
|
|
360
304
|
if config.ui.custom_css:
|
|
@@ -366,15 +310,12 @@ def get_html_template():
|
|
|
366
310
|
js += f"""<script src="{config.ui.custom_js}" defer></script>"""
|
|
367
311
|
|
|
368
312
|
font = None
|
|
369
|
-
if
|
|
370
|
-
font = "
|
|
371
|
-
f"""<link rel="stylesheet" href="{font}">"""
|
|
372
|
-
for font in custom_theme.get("custom_fonts")
|
|
373
|
-
)
|
|
313
|
+
if config.ui.custom_font:
|
|
314
|
+
font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
|
|
374
315
|
|
|
375
316
|
index_html_file_path = os.path.join(build_dir, "index.html")
|
|
376
317
|
|
|
377
|
-
with open(index_html_file_path, encoding="utf-8") as f:
|
|
318
|
+
with open(index_html_file_path, "r", encoding="utf-8") as f:
|
|
378
319
|
content = f.read()
|
|
379
320
|
content = content.replace(PLACEHOLDER, tags)
|
|
380
321
|
if js:
|
|
@@ -419,132 +360,46 @@ async def auth(request: Request):
|
|
|
419
360
|
return get_configuration()
|
|
420
361
|
|
|
421
362
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
response_dict = _get_response_dict(access_token)
|
|
432
|
-
|
|
433
|
-
if redirect_to_callback:
|
|
434
|
-
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
435
|
-
redirect_url = (
|
|
436
|
-
f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
|
|
437
|
-
)
|
|
438
|
-
|
|
439
|
-
return RedirectResponse(
|
|
440
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
441
|
-
url=redirect_url,
|
|
442
|
-
status_code=302,
|
|
363
|
+
@router.post("/login")
|
|
364
|
+
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
365
|
+
"""
|
|
366
|
+
Login a user using the password auth callback.
|
|
367
|
+
"""
|
|
368
|
+
if not config.code.password_auth_callback:
|
|
369
|
+
raise HTTPException(
|
|
370
|
+
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
443
371
|
)
|
|
444
372
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
def _get_oauth_redirect_error(error: str) -> Response:
|
|
449
|
-
"""Get the redirect response for an OAuth error."""
|
|
450
|
-
params = urllib.parse.urlencode(
|
|
451
|
-
{
|
|
452
|
-
"error": error,
|
|
453
|
-
}
|
|
454
|
-
)
|
|
455
|
-
response = RedirectResponse(
|
|
456
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
457
|
-
url=f"/login?{params}", # Shouldn't there be {root_path} here?
|
|
373
|
+
user = await config.code.password_auth_callback(
|
|
374
|
+
form_data.username, form_data.password
|
|
458
375
|
)
|
|
459
|
-
return response
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
async def _authenticate_user(
|
|
463
|
-
user: Optional[User], redirect_to_callback: bool = False
|
|
464
|
-
) -> Response:
|
|
465
|
-
"""Authenticate a user and return the response."""
|
|
466
376
|
|
|
467
377
|
if not user:
|
|
468
378
|
raise HTTPException(
|
|
469
379
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
470
380
|
detail="credentialssignin",
|
|
471
381
|
)
|
|
472
|
-
|
|
473
|
-
# If a data layer is defined, attempt to persist user.
|
|
382
|
+
access_token = create_jwt(user)
|
|
474
383
|
if data_layer := get_data_layer():
|
|
475
384
|
try:
|
|
476
385
|
await data_layer.create_user(user)
|
|
477
386
|
except Exception as e:
|
|
478
|
-
# Catch and log exceptions during user creation.
|
|
479
|
-
# TODO: Make this catch only specific errors and allow others to propagate.
|
|
480
387
|
logger.error(f"Error creating user: {e}")
|
|
481
388
|
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
set_auth_cookie(response, access_token)
|
|
487
|
-
|
|
488
|
-
return response
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
@router.post("/login")
|
|
492
|
-
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
|
|
493
|
-
"""
|
|
494
|
-
Login a user using the password auth callback.
|
|
495
|
-
"""
|
|
496
|
-
if not config.code.password_auth_callback:
|
|
497
|
-
raise HTTPException(
|
|
498
|
-
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
499
|
-
)
|
|
500
|
-
|
|
501
|
-
user = await config.code.password_auth_callback(
|
|
502
|
-
form_data.username, form_data.password
|
|
503
|
-
)
|
|
504
|
-
|
|
505
|
-
return await _authenticate_user(user)
|
|
389
|
+
return {
|
|
390
|
+
"access_token": access_token,
|
|
391
|
+
"token_type": "bearer",
|
|
392
|
+
}
|
|
506
393
|
|
|
507
394
|
|
|
508
395
|
@router.post("/logout")
|
|
509
396
|
async def logout(request: Request, response: Response):
|
|
510
397
|
"""Logout the user by calling the on_logout callback."""
|
|
511
|
-
clear_auth_cookie(response)
|
|
512
|
-
|
|
513
398
|
if config.code.on_logout:
|
|
514
399
|
return await config.code.on_logout(request, response)
|
|
515
|
-
|
|
516
400
|
return {"success": True}
|
|
517
401
|
|
|
518
402
|
|
|
519
|
-
@router.post("/auth/jwt")
|
|
520
|
-
async def jwt_auth(request: Request):
|
|
521
|
-
"""Login a user using a valid jwt."""
|
|
522
|
-
from jwt import InvalidTokenError
|
|
523
|
-
|
|
524
|
-
auth_header: Optional[str] = request.headers.get("Authorization")
|
|
525
|
-
if not auth_header:
|
|
526
|
-
raise HTTPException(status_code=401, detail="Authorization header missing")
|
|
527
|
-
|
|
528
|
-
# Check if it starts with "Bearer "
|
|
529
|
-
try:
|
|
530
|
-
scheme, token = auth_header.split()
|
|
531
|
-
if scheme.lower() != "bearer":
|
|
532
|
-
raise HTTPException(
|
|
533
|
-
status_code=401,
|
|
534
|
-
detail="Invalid authentication scheme. Please use Bearer",
|
|
535
|
-
)
|
|
536
|
-
except ValueError:
|
|
537
|
-
raise HTTPException(
|
|
538
|
-
status_code=401, detail="Invalid authorization header format"
|
|
539
|
-
)
|
|
540
|
-
|
|
541
|
-
try:
|
|
542
|
-
user = decode_jwt(token)
|
|
543
|
-
return await _authenticate_user(user)
|
|
544
|
-
except InvalidTokenError:
|
|
545
|
-
raise HTTPException(status_code=401, detail="Invalid token")
|
|
546
|
-
|
|
547
|
-
|
|
548
403
|
@router.post("/auth/header")
|
|
549
404
|
async def header_auth(request: Request):
|
|
550
405
|
"""Login a user using the header_auth_callback."""
|
|
@@ -556,7 +411,23 @@ async def header_auth(request: Request):
|
|
|
556
411
|
|
|
557
412
|
user = await config.code.header_auth_callback(request.headers)
|
|
558
413
|
|
|
559
|
-
|
|
414
|
+
if not user:
|
|
415
|
+
raise HTTPException(
|
|
416
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
417
|
+
detail="Unauthorized",
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
access_token = create_jwt(user)
|
|
421
|
+
if data_layer := get_data_layer():
|
|
422
|
+
try:
|
|
423
|
+
await data_layer.create_user(user)
|
|
424
|
+
except Exception as e:
|
|
425
|
+
logger.error(f"Error creating user: {e}")
|
|
426
|
+
|
|
427
|
+
return {
|
|
428
|
+
"access_token": access_token,
|
|
429
|
+
"token_type": "bearer",
|
|
430
|
+
}
|
|
560
431
|
|
|
561
432
|
|
|
562
433
|
@router.get("/auth/oauth/{provider_id}")
|
|
@@ -588,9 +459,16 @@ async def oauth_login(provider_id: str, request: Request):
|
|
|
588
459
|
response = RedirectResponse(
|
|
589
460
|
url=f"{provider.authorize_url}?{params}",
|
|
590
461
|
)
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
462
|
+
samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
|
|
463
|
+
secure = samesite.lower() == "none"
|
|
464
|
+
response.set_cookie(
|
|
465
|
+
"oauth_state",
|
|
466
|
+
random,
|
|
467
|
+
httponly=True,
|
|
468
|
+
samesite=samesite,
|
|
469
|
+
secure=secure,
|
|
470
|
+
max_age=3 * 60,
|
|
471
|
+
)
|
|
594
472
|
return response
|
|
595
473
|
|
|
596
474
|
|
|
@@ -618,7 +496,16 @@ async def oauth_callback(
|
|
|
618
496
|
)
|
|
619
497
|
|
|
620
498
|
if error:
|
|
621
|
-
|
|
499
|
+
params = urllib.parse.urlencode(
|
|
500
|
+
{
|
|
501
|
+
"error": error,
|
|
502
|
+
}
|
|
503
|
+
)
|
|
504
|
+
response = RedirectResponse(
|
|
505
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
506
|
+
url=f"/login?{params}",
|
|
507
|
+
)
|
|
508
|
+
return response
|
|
622
509
|
|
|
623
510
|
if not code or not state:
|
|
624
511
|
raise HTTPException(
|
|
@@ -626,11 +513,9 @@ async def oauth_callback(
|
|
|
626
513
|
detail="Missing code or state",
|
|
627
514
|
)
|
|
628
515
|
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
logger.exception("Unable to validate oauth state: %1", e)
|
|
633
|
-
|
|
516
|
+
# Check the state from the oauth provider against the browser cookie
|
|
517
|
+
oauth_state = request.cookies.get("oauth_state")
|
|
518
|
+
if oauth_state != state:
|
|
634
519
|
raise HTTPException(
|
|
635
520
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
636
521
|
detail="Unauthorized",
|
|
@@ -645,10 +530,34 @@ async def oauth_callback(
|
|
|
645
530
|
provider_id, token, raw_user_data, default_user
|
|
646
531
|
)
|
|
647
532
|
|
|
648
|
-
|
|
533
|
+
if not user:
|
|
534
|
+
raise HTTPException(
|
|
535
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
536
|
+
detail="Unauthorized",
|
|
537
|
+
)
|
|
649
538
|
|
|
650
|
-
|
|
539
|
+
access_token = create_jwt(user)
|
|
651
540
|
|
|
541
|
+
if data_layer := get_data_layer():
|
|
542
|
+
try:
|
|
543
|
+
await data_layer.create_user(user)
|
|
544
|
+
except Exception as e:
|
|
545
|
+
logger.error(f"Error creating user: {e}")
|
|
546
|
+
|
|
547
|
+
params = urllib.parse.urlencode(
|
|
548
|
+
{
|
|
549
|
+
"access_token": access_token,
|
|
550
|
+
"token_type": "bearer",
|
|
551
|
+
}
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
555
|
+
|
|
556
|
+
response = RedirectResponse(
|
|
557
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
558
|
+
url=f"{root_path}/login/callback?{params}",
|
|
559
|
+
)
|
|
560
|
+
response.delete_cookie("oauth_state")
|
|
652
561
|
return response
|
|
653
562
|
|
|
654
563
|
|
|
@@ -677,7 +586,16 @@ async def oauth_azure_hf_callback(
|
|
|
677
586
|
)
|
|
678
587
|
|
|
679
588
|
if error:
|
|
680
|
-
|
|
589
|
+
params = urllib.parse.urlencode(
|
|
590
|
+
{
|
|
591
|
+
"error": error,
|
|
592
|
+
}
|
|
593
|
+
)
|
|
594
|
+
response = RedirectResponse(
|
|
595
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
596
|
+
url=f"/login?{params}",
|
|
597
|
+
)
|
|
598
|
+
return response
|
|
681
599
|
|
|
682
600
|
if not code:
|
|
683
601
|
raise HTTPException(
|
|
@@ -694,24 +612,40 @@ async def oauth_azure_hf_callback(
|
|
|
694
612
|
provider_id, token, raw_user_data, default_user, id_token
|
|
695
613
|
)
|
|
696
614
|
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
615
|
+
if not user:
|
|
616
|
+
raise HTTPException(
|
|
617
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
618
|
+
detail="Unauthorized",
|
|
619
|
+
)
|
|
700
620
|
|
|
701
|
-
|
|
621
|
+
access_token = create_jwt(user)
|
|
702
622
|
|
|
623
|
+
if data_layer := get_data_layer():
|
|
624
|
+
try:
|
|
625
|
+
await data_layer.create_user(user)
|
|
626
|
+
except Exception as e:
|
|
627
|
+
logger.error(f"Error creating user: {e}")
|
|
703
628
|
|
|
704
|
-
|
|
705
|
-
|
|
629
|
+
params = urllib.parse.urlencode(
|
|
630
|
+
{
|
|
631
|
+
"access_token": access_token,
|
|
632
|
+
"token_type": "bearer",
|
|
633
|
+
}
|
|
634
|
+
)
|
|
706
635
|
|
|
636
|
+
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
707
637
|
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
638
|
+
response = RedirectResponse(
|
|
639
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
640
|
+
url=f"{root_path}/login/callback?{params}",
|
|
641
|
+
status_code=302,
|
|
642
|
+
)
|
|
643
|
+
response.delete_cookie("oauth_state")
|
|
644
|
+
return response
|
|
711
645
|
|
|
712
646
|
|
|
713
647
|
_language_pattern = (
|
|
714
|
-
"^[a-zA-Z]{2,3}(-[a-zA-
|
|
648
|
+
"^[a-zA-Z]{2,3}(-[a-zA-Z]{2,3})?(-[a-zA-Z]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
|
|
715
649
|
)
|
|
716
650
|
|
|
717
651
|
|
|
@@ -735,7 +669,7 @@ async def project_translations(
|
|
|
735
669
|
|
|
736
670
|
@router.get("/project/settings")
|
|
737
671
|
async def project_settings(
|
|
738
|
-
current_user:
|
|
672
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
739
673
|
language: str = Query(
|
|
740
674
|
default="en-US", description="Language code", pattern=_language_pattern
|
|
741
675
|
),
|
|
@@ -786,7 +720,7 @@ async def project_settings(
|
|
|
786
720
|
async def update_feedback(
|
|
787
721
|
request: Request,
|
|
788
722
|
update: UpdateFeedbackRequest,
|
|
789
|
-
current_user:
|
|
723
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
790
724
|
):
|
|
791
725
|
"""Update the human feedback for a particular message."""
|
|
792
726
|
data_layer = get_data_layer()
|
|
@@ -796,7 +730,7 @@ async def update_feedback(
|
|
|
796
730
|
try:
|
|
797
731
|
feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
|
|
798
732
|
except Exception as e:
|
|
799
|
-
raise HTTPException(detail=str(e), status_code=500)
|
|
733
|
+
raise HTTPException(detail=str(e), status_code=500)
|
|
800
734
|
|
|
801
735
|
return JSONResponse(content={"success": True, "feedbackId": feedback_id})
|
|
802
736
|
|
|
@@ -805,7 +739,7 @@ async def update_feedback(
|
|
|
805
739
|
async def delete_feedback(
|
|
806
740
|
request: Request,
|
|
807
741
|
payload: DeleteFeedbackRequest,
|
|
808
|
-
current_user:
|
|
742
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
809
743
|
):
|
|
810
744
|
"""Delete a feedback."""
|
|
811
745
|
|
|
@@ -824,7 +758,7 @@ async def delete_feedback(
|
|
|
824
758
|
async def get_user_threads(
|
|
825
759
|
request: Request,
|
|
826
760
|
payload: GetThreadsRequest,
|
|
827
|
-
current_user:
|
|
761
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
828
762
|
):
|
|
829
763
|
"""Get the threads page by page."""
|
|
830
764
|
|
|
@@ -833,9 +767,6 @@ async def get_user_threads(
|
|
|
833
767
|
if not data_layer:
|
|
834
768
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
835
769
|
|
|
836
|
-
if not current_user:
|
|
837
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
838
|
-
|
|
839
770
|
if not isinstance(current_user, PersistedUser):
|
|
840
771
|
persisted_user = await data_layer.get_user(identifier=current_user.identifier)
|
|
841
772
|
if not persisted_user:
|
|
@@ -852,7 +783,7 @@ async def get_user_threads(
|
|
|
852
783
|
async def get_thread(
|
|
853
784
|
request: Request,
|
|
854
785
|
thread_id: str,
|
|
855
|
-
current_user:
|
|
786
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
856
787
|
):
|
|
857
788
|
"""Get a specific thread."""
|
|
858
789
|
data_layer = get_data_layer()
|
|
@@ -860,9 +791,6 @@ async def get_thread(
|
|
|
860
791
|
if not data_layer:
|
|
861
792
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
862
793
|
|
|
863
|
-
if not current_user:
|
|
864
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
865
|
-
|
|
866
794
|
await is_thread_author(current_user.identifier, thread_id)
|
|
867
795
|
|
|
868
796
|
res = await data_layer.get_thread(thread_id)
|
|
@@ -874,7 +802,7 @@ async def get_thread_element(
|
|
|
874
802
|
request: Request,
|
|
875
803
|
thread_id: str,
|
|
876
804
|
element_id: str,
|
|
877
|
-
current_user:
|
|
805
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
878
806
|
):
|
|
879
807
|
"""Get a specific thread element."""
|
|
880
808
|
data_layer = get_data_layer()
|
|
@@ -882,135 +810,17 @@ async def get_thread_element(
|
|
|
882
810
|
if not data_layer:
|
|
883
811
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
884
812
|
|
|
885
|
-
if not current_user:
|
|
886
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
887
|
-
|
|
888
813
|
await is_thread_author(current_user.identifier, thread_id)
|
|
889
814
|
|
|
890
815
|
res = await data_layer.get_element(thread_id, element_id)
|
|
891
816
|
return JSONResponse(content=res)
|
|
892
817
|
|
|
893
818
|
|
|
894
|
-
@router.put("/project/element")
|
|
895
|
-
async def update_thread_element(
|
|
896
|
-
payload: ElementRequest,
|
|
897
|
-
current_user: UserParam,
|
|
898
|
-
):
|
|
899
|
-
"""Update a specific thread element."""
|
|
900
|
-
|
|
901
|
-
from chainlit.context import init_ws_context
|
|
902
|
-
from chainlit.element import CustomElement, ElementDict
|
|
903
|
-
from chainlit.session import WebsocketSession
|
|
904
|
-
|
|
905
|
-
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
906
|
-
context = init_ws_context(session)
|
|
907
|
-
|
|
908
|
-
element_dict = cast(ElementDict, payload.element)
|
|
909
|
-
|
|
910
|
-
if element_dict["type"] != "custom":
|
|
911
|
-
return {"success": False}
|
|
912
|
-
|
|
913
|
-
element = CustomElement(
|
|
914
|
-
id=element_dict["id"],
|
|
915
|
-
object_key=element_dict["objectKey"],
|
|
916
|
-
chainlit_key=element_dict["chainlitKey"],
|
|
917
|
-
url=element_dict["url"],
|
|
918
|
-
for_id=element_dict.get("forId") or "",
|
|
919
|
-
thread_id=element_dict.get("threadId") or "",
|
|
920
|
-
name=element_dict["name"],
|
|
921
|
-
props=element_dict.get("props") or {},
|
|
922
|
-
display=element_dict["display"],
|
|
923
|
-
)
|
|
924
|
-
|
|
925
|
-
if current_user:
|
|
926
|
-
if (
|
|
927
|
-
not context.session.user
|
|
928
|
-
or context.session.user.identifier != current_user.identifier
|
|
929
|
-
):
|
|
930
|
-
raise HTTPException(
|
|
931
|
-
status_code=401,
|
|
932
|
-
detail="You are not authorized to update elements for this session",
|
|
933
|
-
)
|
|
934
|
-
|
|
935
|
-
await element.send(for_id=element.for_id or "")
|
|
936
|
-
return {"success": True}
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
@router.delete("/project/element")
|
|
940
|
-
async def delete_thread_element(
|
|
941
|
-
payload: ElementRequest,
|
|
942
|
-
current_user: UserParam,
|
|
943
|
-
):
|
|
944
|
-
"""Delete a specific thread element."""
|
|
945
|
-
|
|
946
|
-
from chainlit.context import init_ws_context
|
|
947
|
-
from chainlit.element import CustomElement, ElementDict
|
|
948
|
-
from chainlit.session import WebsocketSession
|
|
949
|
-
|
|
950
|
-
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
951
|
-
context = init_ws_context(session)
|
|
952
|
-
|
|
953
|
-
element_dict = cast(ElementDict, payload.element)
|
|
954
|
-
|
|
955
|
-
if element_dict["type"] != "custom":
|
|
956
|
-
return {"success": False}
|
|
957
|
-
|
|
958
|
-
element = CustomElement(
|
|
959
|
-
id=element_dict["id"],
|
|
960
|
-
object_key=element_dict["objectKey"],
|
|
961
|
-
chainlit_key=element_dict["chainlitKey"],
|
|
962
|
-
url=element_dict["url"],
|
|
963
|
-
for_id=element_dict.get("forId") or "",
|
|
964
|
-
thread_id=element_dict.get("threadId") or "",
|
|
965
|
-
name=element_dict["name"],
|
|
966
|
-
props=element_dict.get("props") or {},
|
|
967
|
-
display=element_dict["display"],
|
|
968
|
-
)
|
|
969
|
-
|
|
970
|
-
if current_user:
|
|
971
|
-
if (
|
|
972
|
-
not context.session.user
|
|
973
|
-
or context.session.user.identifier != current_user.identifier
|
|
974
|
-
):
|
|
975
|
-
raise HTTPException(
|
|
976
|
-
status_code=401,
|
|
977
|
-
detail="You are not authorized to remove elements for this session",
|
|
978
|
-
)
|
|
979
|
-
|
|
980
|
-
await element.remove()
|
|
981
|
-
|
|
982
|
-
return {"success": True}
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
@router.put("/project/thread")
|
|
986
|
-
async def rename_thread(
|
|
987
|
-
request: Request,
|
|
988
|
-
payload: UpdateThreadRequest,
|
|
989
|
-
current_user: UserParam,
|
|
990
|
-
):
|
|
991
|
-
"""Rename a thread."""
|
|
992
|
-
|
|
993
|
-
data_layer = get_data_layer()
|
|
994
|
-
|
|
995
|
-
if not data_layer:
|
|
996
|
-
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
997
|
-
|
|
998
|
-
if not current_user:
|
|
999
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1000
|
-
|
|
1001
|
-
thread_id = payload.threadId
|
|
1002
|
-
|
|
1003
|
-
await is_thread_author(current_user.identifier, thread_id)
|
|
1004
|
-
|
|
1005
|
-
await data_layer.update_thread(thread_id, name=payload.name)
|
|
1006
|
-
return JSONResponse(content={"success": True})
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
819
|
@router.delete("/project/thread")
|
|
1010
820
|
async def delete_thread(
|
|
1011
821
|
request: Request,
|
|
1012
822
|
payload: DeleteThreadRequest,
|
|
1013
|
-
current_user:
|
|
823
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
1014
824
|
):
|
|
1015
825
|
"""Delete a thread."""
|
|
1016
826
|
|
|
@@ -1019,9 +829,6 @@ async def delete_thread(
|
|
|
1019
829
|
if not data_layer:
|
|
1020
830
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
1021
831
|
|
|
1022
|
-
if not current_user:
|
|
1023
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1024
|
-
|
|
1025
832
|
thread_id = payload.threadId
|
|
1026
833
|
|
|
1027
834
|
await is_thread_author(current_user.identifier, thread_id)
|
|
@@ -1030,49 +837,13 @@ async def delete_thread(
|
|
|
1030
837
|
return JSONResponse(content={"success": True})
|
|
1031
838
|
|
|
1032
839
|
|
|
1033
|
-
@router.post("/project/action")
|
|
1034
|
-
async def call_action(
|
|
1035
|
-
payload: CallActionRequest,
|
|
1036
|
-
current_user: UserParam,
|
|
1037
|
-
):
|
|
1038
|
-
"""Run an action."""
|
|
1039
|
-
|
|
1040
|
-
from chainlit.action import Action
|
|
1041
|
-
from chainlit.context import init_ws_context
|
|
1042
|
-
from chainlit.session import WebsocketSession
|
|
1043
|
-
|
|
1044
|
-
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
1045
|
-
context = init_ws_context(session)
|
|
1046
|
-
|
|
1047
|
-
action = Action(**payload.action)
|
|
1048
|
-
|
|
1049
|
-
if current_user:
|
|
1050
|
-
if (
|
|
1051
|
-
not context.session.user
|
|
1052
|
-
or context.session.user.identifier != current_user.identifier
|
|
1053
|
-
):
|
|
1054
|
-
raise HTTPException(
|
|
1055
|
-
status_code=401,
|
|
1056
|
-
detail="You are not authorized to upload files for this session",
|
|
1057
|
-
)
|
|
1058
|
-
|
|
1059
|
-
callback = config.code.action_callbacks.get(action.name)
|
|
1060
|
-
if callback:
|
|
1061
|
-
await callback(action)
|
|
1062
|
-
else:
|
|
1063
|
-
raise HTTPException(
|
|
1064
|
-
status_code=404,
|
|
1065
|
-
detail=f"No callback found for action {action.name}",
|
|
1066
|
-
)
|
|
1067
|
-
|
|
1068
|
-
return JSONResponse(content={"success": True})
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
840
|
@router.post("/project/file")
|
|
1072
841
|
async def upload_file(
|
|
1073
|
-
current_user: UserParam,
|
|
1074
842
|
session_id: str,
|
|
1075
843
|
file: UploadFile,
|
|
844
|
+
current_user: Annotated[
|
|
845
|
+
Union[None, User, PersistedUser], Depends(get_current_user)
|
|
846
|
+
],
|
|
1076
847
|
):
|
|
1077
848
|
"""Upload a file to the session files directory."""
|
|
1078
849
|
|
|
@@ -1097,111 +868,30 @@ async def upload_file(
|
|
|
1097
868
|
|
|
1098
869
|
content = await file.read()
|
|
1099
870
|
|
|
1100
|
-
assert file.filename, "No filename for uploaded file"
|
|
1101
|
-
assert file.content_type, "No content type for uploaded file"
|
|
1102
|
-
|
|
1103
|
-
try:
|
|
1104
|
-
validate_file_upload(file)
|
|
1105
|
-
except ValueError as e:
|
|
1106
|
-
raise HTTPException(status_code=400, detail=str(e))
|
|
1107
|
-
|
|
1108
871
|
file_response = await session.persist_file(
|
|
1109
872
|
name=file.filename, content=content, mime=file.content_type
|
|
1110
873
|
)
|
|
1111
874
|
|
|
1112
|
-
return JSONResponse(
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
def validate_file_upload(file: UploadFile):
|
|
1116
|
-
"""Validate the file upload as configured in config.features.spontaneous_file_upload.
|
|
1117
|
-
Args:
|
|
1118
|
-
file (UploadFile): The file to validate.
|
|
1119
|
-
Raises:
|
|
1120
|
-
ValueError: If the file is not allowed.
|
|
1121
|
-
"""
|
|
1122
|
-
if config.features.spontaneous_file_upload is None:
|
|
1123
|
-
"""Default for a missing config is to allow the fileupload without any restrictions"""
|
|
1124
|
-
return
|
|
1125
|
-
if config.features.spontaneous_file_upload.enabled is False:
|
|
1126
|
-
raise ValueError("File upload is not enabled")
|
|
1127
|
-
|
|
1128
|
-
validate_file_mime_type(file)
|
|
1129
|
-
validate_file_size(file)
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
def validate_file_mime_type(file: UploadFile):
|
|
1133
|
-
"""Validate the file mime type as configured in config.features.spontaneous_file_upload.
|
|
1134
|
-
Args:
|
|
1135
|
-
file (UploadFile): The file to validate.
|
|
1136
|
-
Raises:
|
|
1137
|
-
ValueError: If the file type is not allowed.
|
|
1138
|
-
"""
|
|
1139
|
-
accept = config.features.spontaneous_file_upload.accept
|
|
1140
|
-
if accept is None:
|
|
1141
|
-
"Accept is not configured, allowing all file types"
|
|
1142
|
-
return
|
|
1143
|
-
|
|
1144
|
-
assert (
|
|
1145
|
-
isinstance(accept, List) or isinstance(accept, dict)
|
|
1146
|
-
), "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
|
|
1147
|
-
|
|
1148
|
-
if isinstance(accept, List):
|
|
1149
|
-
for pattern in accept:
|
|
1150
|
-
if fnmatch.fnmatch(file.content_type, pattern):
|
|
1151
|
-
return
|
|
1152
|
-
elif isinstance(accept, dict):
|
|
1153
|
-
for pattern, extensions in accept.items():
|
|
1154
|
-
if fnmatch.fnmatch(file.content_type, pattern):
|
|
1155
|
-
if len(extensions) == 0:
|
|
1156
|
-
return
|
|
1157
|
-
for extension in extensions:
|
|
1158
|
-
if file.filename is not None and file.filename.endswith(extension):
|
|
1159
|
-
return
|
|
1160
|
-
raise ValueError("File type not allowed")
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
def validate_file_size(file: UploadFile):
|
|
1164
|
-
"""Validate the file size as configured in config.features.spontaneous_file_upload.
|
|
1165
|
-
Args:
|
|
1166
|
-
file (UploadFile): The file to validate.
|
|
1167
|
-
Raises:
|
|
1168
|
-
ValueError: If the file size is too large.
|
|
1169
|
-
"""
|
|
1170
|
-
if config.features.spontaneous_file_upload.max_size_mb is None:
|
|
1171
|
-
return
|
|
1172
|
-
|
|
1173
|
-
if (
|
|
1174
|
-
file.size is not None
|
|
1175
|
-
and file.size
|
|
1176
|
-
> config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024
|
|
1177
|
-
):
|
|
1178
|
-
raise ValueError("File size too large")
|
|
875
|
+
return JSONResponse(file_response)
|
|
1179
876
|
|
|
1180
877
|
|
|
1181
878
|
@router.get("/project/file/{file_id}")
|
|
1182
879
|
async def get_file(
|
|
1183
880
|
file_id: str,
|
|
1184
|
-
session_id: str,
|
|
1185
|
-
current_user: UserParam,
|
|
881
|
+
session_id: Optional[str] = None,
|
|
1186
882
|
):
|
|
1187
883
|
"""Get a file from the session files directory."""
|
|
884
|
+
|
|
1188
885
|
from chainlit.session import WebsocketSession
|
|
1189
886
|
|
|
1190
887
|
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
1191
888
|
|
|
1192
889
|
if not session:
|
|
1193
890
|
raise HTTPException(
|
|
1194
|
-
status_code=
|
|
1195
|
-
detail="
|
|
891
|
+
status_code=404,
|
|
892
|
+
detail="Session not found",
|
|
1196
893
|
)
|
|
1197
894
|
|
|
1198
|
-
if current_user:
|
|
1199
|
-
if not session.user or session.user.identifier != current_user.identifier:
|
|
1200
|
-
raise HTTPException(
|
|
1201
|
-
status_code=401,
|
|
1202
|
-
detail="You are not authorized to download files from this session",
|
|
1203
|
-
)
|
|
1204
|
-
|
|
1205
895
|
if file_id in session.files:
|
|
1206
896
|
file = session.files[file_id]
|
|
1207
897
|
return FileResponse(file["path"], media_type=file["type"])
|
|
@@ -1212,7 +902,7 @@ async def get_file(
|
|
|
1212
902
|
@router.get("/files/{filename:path}")
|
|
1213
903
|
async def serve_file(
|
|
1214
904
|
filename: str,
|
|
1215
|
-
current_user:
|
|
905
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
1216
906
|
):
|
|
1217
907
|
"""Serve a file from the local filesystem."""
|
|
1218
908
|
|
|
@@ -1271,7 +961,7 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
|
|
|
1271
961
|
@router.get("/avatars/{avatar_id:str}")
|
|
1272
962
|
async def get_avatar(avatar_id: str):
|
|
1273
963
|
"""Get the avatar for the user based on the avatar_id."""
|
|
1274
|
-
if not re.match(r"^[a-zA-Z0-9_
|
|
964
|
+
if not re.match(r"^[a-zA-Z0-9_-]+$", avatar_id):
|
|
1275
965
|
raise HTTPException(status_code=400, detail="Invalid avatar_id")
|
|
1276
966
|
|
|
1277
967
|
if avatar_id == "default":
|