chainlit 2.0.0__py3-none-any.whl → 2.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic; see the registry's advisory page for this release for more details.
- chainlit/__init__.py +57 -55
- chainlit/action.py +10 -12
- chainlit/{auth/__init__.py → auth.py} +34 -20
- chainlit/cache.py +1 -2
- chainlit/callbacks.py +7 -52
- chainlit/chat_context.py +2 -2
- chainlit/chat_settings.py +1 -3
- chainlit/cli/__init__.py +1 -14
- chainlit/config.py +69 -35
- chainlit/context.py +2 -3
- chainlit/copilot/dist/index.js +935 -8533
- chainlit/data/__init__.py +8 -96
- chainlit/data/acl.py +2 -3
- chainlit/data/base.py +1 -1
- chainlit/data/dynamodb.py +3 -5
- chainlit/data/literalai.py +6 -4
- chainlit/data/sql_alchemy.py +7 -8
- chainlit/data/storage_clients/azure.py +0 -1
- chainlit/data/storage_clients/base.py +0 -6
- chainlit/data/storage_clients/s3.py +3 -16
- chainlit/discord/app.py +1 -2
- chainlit/element.py +9 -13
- chainlit/emitter.py +21 -17
- chainlit/frontend/dist/assets/{DailyMotion-DgRzV5GZ.js → DailyMotion-D1ipkdPJ.js} +1 -1
- chainlit/frontend/dist/assets/{Facebook-C0vx6HWv.js → Facebook-d4TLeTik.js} +1 -1
- chainlit/frontend/dist/assets/{FilePlayer-CdhzeHPP.js → FilePlayer-BcU7tttX.js} +1 -1
- chainlit/frontend/dist/assets/{Kaltura-5iVmeUct.js → Kaltura-DdaRjZrh.js} +1 -1
- chainlit/frontend/dist/assets/{Mixcloud-C2zi77Ex.js → Mixcloud-BaJoMsaU.js} +1 -1
- chainlit/frontend/dist/assets/{Mux-Vkebogdf.js → Mux-DxPCM5d3.js} +1 -1
- chainlit/frontend/dist/assets/{Preview-DwY_sEIl.js → Preview-tUK_Z9pZ.js} +1 -1
- chainlit/frontend/dist/assets/{SoundCloud-CREBXAWo.js → SoundCloud-K8-lFZC6.js} +1 -1
- chainlit/frontend/dist/assets/{Streamable-B5Lu25uy.js → Streamable-hB-AQ54w.js} +1 -1
- chainlit/frontend/dist/assets/{Twitch-y9iKCcM1.js → Twitch-pmuNY0J5.js} +1 -1
- chainlit/frontend/dist/assets/{Vidyard-ClYvcuEu.js → Vidyard-BSUm6trV.js} +1 -1
- chainlit/frontend/dist/assets/{Vimeo-D6HvM2jt.js → Vimeo-JIPn71zS.js} +1 -1
- chainlit/frontend/dist/assets/Wistia-D75KkqOG.js +1 -0
- chainlit/frontend/dist/assets/{YouTube-D10tR6CJ.js → YouTube-CPlwqNm_.js} +1 -1
- chainlit/frontend/dist/assets/index-CuSbXjG5.js +1091 -0
- chainlit/frontend/dist/assets/index-CwmincdQ.css +1 -0
- chainlit/frontend/dist/assets/{react-plotly-BpxUS-ab.js → react-plotly-DALmanjC.js} +1 -1
- chainlit/frontend/dist/index.html +2 -2
- chainlit/haystack/callbacks.py +4 -5
- chainlit/input_widget.py +4 -6
- chainlit/langchain/callbacks.py +47 -56
- chainlit/langflow/__init__.py +0 -1
- chainlit/llama_index/callbacks.py +7 -7
- chainlit/message.py +7 -6
- chainlit/mistralai/__init__.py +2 -3
- chainlit/oauth_providers.py +3 -70
- chainlit/openai/__init__.py +2 -3
- chainlit/secret.py +1 -1
- chainlit/server.py +174 -474
- chainlit/session.py +5 -7
- chainlit/slack/app.py +2 -3
- chainlit/socket.py +103 -78
- chainlit/step.py +11 -11
- chainlit/sync.py +1 -2
- chainlit/teams/app.py +0 -1
- chainlit/types.py +4 -20
- chainlit/user.py +1 -2
- chainlit/utils.py +2 -3
- {chainlit-2.0.0.dist-info → chainlit-2.0.dev1.dist-info}/METADATA +38 -8
- chainlit-2.0.dev1.dist-info/RECORD +99 -0
- chainlit/auth/cookie.py +0 -123
- chainlit/auth/jwt.py +0 -37
- chainlit/data/chainlit_data_layer.py +0 -584
- chainlit/data/storage_clients/azure_blob.py +0 -80
- chainlit/data/storage_clients/gcs.py +0 -78
- chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +0 -22
- chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +0 -1
- chainlit/frontend/dist/assets/index-CI4qFOt5.js +0 -8665
- chainlit/frontend/dist/assets/index-CrrqM0nZ.css +0 -1
- chainlit/translations/nl-NL.json +0 -229
- chainlit-2.0.0.dist-info/RECORD +0 -106
- {chainlit-2.0.0.dist-info → chainlit-2.0.dev1.dist-info}/WHEEL +0 -0
- {chainlit-2.0.0.dist-info → chainlit-2.0.dev1.dist-info}/entry_points.txt +0 -0
chainlit/server.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import fnmatch
|
|
3
2
|
import glob
|
|
4
3
|
import json
|
|
5
4
|
import mimetypes
|
|
@@ -10,36 +9,10 @@ import urllib.parse
|
|
|
10
9
|
import webbrowser
|
|
11
10
|
from contextlib import asynccontextmanager
|
|
12
11
|
from pathlib import Path
|
|
13
|
-
from typing import
|
|
12
|
+
from typing import Any, Optional, Union
|
|
14
13
|
|
|
15
14
|
import socketio
|
|
16
|
-
from
|
|
17
|
-
APIRouter,
|
|
18
|
-
Depends,
|
|
19
|
-
FastAPI,
|
|
20
|
-
Form,
|
|
21
|
-
HTTPException,
|
|
22
|
-
Query,
|
|
23
|
-
Request,
|
|
24
|
-
Response,
|
|
25
|
-
UploadFile,
|
|
26
|
-
status,
|
|
27
|
-
)
|
|
28
|
-
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
29
|
-
from fastapi.security import OAuth2PasswordRequestForm
|
|
30
|
-
from starlette.datastructures import URL
|
|
31
|
-
from starlette.middleware.cors import CORSMiddleware
|
|
32
|
-
from typing_extensions import Annotated
|
|
33
|
-
from watchfiles import awatch
|
|
34
|
-
|
|
35
|
-
from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
|
|
36
|
-
from chainlit.auth.cookie import (
|
|
37
|
-
clear_auth_cookie,
|
|
38
|
-
clear_oauth_state_cookie,
|
|
39
|
-
set_auth_cookie,
|
|
40
|
-
set_oauth_state_cookie,
|
|
41
|
-
validate_oauth_state_cookie,
|
|
42
|
-
)
|
|
15
|
+
from chainlit.auth import create_jwt, get_configuration, get_current_user
|
|
43
16
|
from chainlit.config import (
|
|
44
17
|
APP_ROOT,
|
|
45
18
|
BACKEND_ROOT,
|
|
@@ -48,7 +21,6 @@ from chainlit.config import (
|
|
|
48
21
|
PACKAGE_ROOT,
|
|
49
22
|
config,
|
|
50
23
|
load_module,
|
|
51
|
-
public_dir,
|
|
52
24
|
reload_config,
|
|
53
25
|
)
|
|
54
26
|
from chainlit.data import get_data_layer
|
|
@@ -58,16 +30,33 @@ from chainlit.markdown import get_markdown_str
|
|
|
58
30
|
from chainlit.oauth_providers import get_oauth_provider
|
|
59
31
|
from chainlit.secret import random_secret
|
|
60
32
|
from chainlit.types import (
|
|
61
|
-
CallActionRequest,
|
|
62
33
|
DeleteFeedbackRequest,
|
|
63
34
|
DeleteThreadRequest,
|
|
64
|
-
ElementRequest,
|
|
65
35
|
GetThreadsRequest,
|
|
66
36
|
Theme,
|
|
67
37
|
UpdateFeedbackRequest,
|
|
68
|
-
UpdateThreadRequest,
|
|
69
38
|
)
|
|
70
39
|
from chainlit.user import PersistedUser, User
|
|
40
|
+
from fastapi import (
|
|
41
|
+
APIRouter,
|
|
42
|
+
Depends,
|
|
43
|
+
FastAPI,
|
|
44
|
+
File,
|
|
45
|
+
Form,
|
|
46
|
+
HTTPException,
|
|
47
|
+
Query,
|
|
48
|
+
Request,
|
|
49
|
+
Response,
|
|
50
|
+
UploadFile,
|
|
51
|
+
status,
|
|
52
|
+
)
|
|
53
|
+
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
54
|
+
from fastapi.security import OAuth2PasswordRequestForm
|
|
55
|
+
from fastapi.staticfiles import StaticFiles
|
|
56
|
+
from starlette.datastructures import URL
|
|
57
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
58
|
+
from typing_extensions import Annotated
|
|
59
|
+
from watchfiles import awatch
|
|
71
60
|
|
|
72
61
|
from ._utils import is_path_inside
|
|
73
62
|
|
|
@@ -216,59 +205,29 @@ app.add_middleware(
|
|
|
216
205
|
|
|
217
206
|
router = APIRouter(prefix=PREFIX)
|
|
218
207
|
|
|
208
|
+
app.mount(
|
|
209
|
+
f"{PREFIX}/public",
|
|
210
|
+
StaticFiles(directory="public", check_dir=False),
|
|
211
|
+
name="public",
|
|
212
|
+
)
|
|
219
213
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
if not is_path_inside(file_path, base_path):
|
|
230
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
231
|
-
|
|
232
|
-
if file_path.is_file():
|
|
233
|
-
return FileResponse(file_path)
|
|
234
|
-
else:
|
|
235
|
-
raise HTTPException(status_code=404, detail="File not found")
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
@router.get("/assets/{filename:path}")
|
|
239
|
-
async def serve_asset_file(
|
|
240
|
-
filename: str,
|
|
241
|
-
):
|
|
242
|
-
"""Serve a file from assets dir."""
|
|
243
|
-
|
|
244
|
-
base_path = Path(os.path.join(build_dir, "assets"))
|
|
245
|
-
file_path = (base_path / filename).resolve()
|
|
246
|
-
|
|
247
|
-
if not is_path_inside(file_path, base_path):
|
|
248
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
249
|
-
|
|
250
|
-
if file_path.is_file():
|
|
251
|
-
return FileResponse(file_path)
|
|
252
|
-
else:
|
|
253
|
-
raise HTTPException(status_code=404, detail="File not found")
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
@router.get("/copilot/{filename:path}")
|
|
257
|
-
async def serve_copilot_file(
|
|
258
|
-
filename: str,
|
|
259
|
-
):
|
|
260
|
-
"""Serve a file from assets dir."""
|
|
261
|
-
|
|
262
|
-
base_path = Path(copilot_build_dir)
|
|
263
|
-
file_path = (base_path / filename).resolve()
|
|
264
|
-
|
|
265
|
-
if not is_path_inside(file_path, base_path):
|
|
266
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
214
|
+
app.mount(
|
|
215
|
+
f"{PREFIX}/assets",
|
|
216
|
+
StaticFiles(
|
|
217
|
+
packages=[("chainlit", os.path.join(build_dir, "assets"))],
|
|
218
|
+
follow_symlink=config.project.follow_symlink,
|
|
219
|
+
),
|
|
220
|
+
name="assets",
|
|
221
|
+
)
|
|
267
222
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
223
|
+
app.mount(
|
|
224
|
+
f"{PREFIX}/copilot",
|
|
225
|
+
StaticFiles(
|
|
226
|
+
packages=[("chainlit", copilot_build_dir)],
|
|
227
|
+
follow_symlink=config.project.follow_symlink,
|
|
228
|
+
),
|
|
229
|
+
name="copilot",
|
|
230
|
+
)
|
|
272
231
|
|
|
273
232
|
|
|
274
233
|
# -------------------------------------------------------------------------------
|
|
@@ -289,7 +248,6 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
|
|
|
289
248
|
|
|
290
249
|
if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
|
|
291
250
|
from botbuilder.schema import Activity
|
|
292
|
-
|
|
293
251
|
from chainlit.teams.app import adapter, bot
|
|
294
252
|
|
|
295
253
|
@router.post("/teams/events")
|
|
@@ -319,16 +277,6 @@ def get_html_template():
|
|
|
319
277
|
"""
|
|
320
278
|
Get HTML template for the index view.
|
|
321
279
|
"""
|
|
322
|
-
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
323
|
-
|
|
324
|
-
custom_theme = None
|
|
325
|
-
custom_theme_file_path = Path(public_dir) / "theme.json"
|
|
326
|
-
if (
|
|
327
|
-
is_path_inside(custom_theme_file_path, Path(public_dir))
|
|
328
|
-
and custom_theme_file_path.is_file()
|
|
329
|
-
):
|
|
330
|
-
custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
|
|
331
|
-
|
|
332
280
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
333
281
|
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
|
|
334
282
|
CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
|
|
@@ -351,10 +299,7 @@ def get_html_template():
|
|
|
351
299
|
<meta property="og:url" content="{url}">
|
|
352
300
|
<meta property="og:root_path" content="{ROOT_PATH}">"""
|
|
353
301
|
|
|
354
|
-
js = f"""<script>
|
|
355
|
-
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
|
|
356
|
-
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
|
|
357
|
-
</script>"""
|
|
302
|
+
js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
|
|
358
303
|
|
|
359
304
|
css = None
|
|
360
305
|
if config.ui.custom_css:
|
|
@@ -366,15 +311,12 @@ def get_html_template():
|
|
|
366
311
|
js += f"""<script src="{config.ui.custom_js}" defer></script>"""
|
|
367
312
|
|
|
368
313
|
font = None
|
|
369
|
-
if
|
|
370
|
-
font = "
|
|
371
|
-
f"""<link rel="stylesheet" href="{font}">"""
|
|
372
|
-
for font in custom_theme.get("custom_fonts")
|
|
373
|
-
)
|
|
314
|
+
if config.ui.custom_font:
|
|
315
|
+
font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
|
|
374
316
|
|
|
375
317
|
index_html_file_path = os.path.join(build_dir, "index.html")
|
|
376
318
|
|
|
377
|
-
with open(index_html_file_path, encoding="utf-8") as f:
|
|
319
|
+
with open(index_html_file_path, "r", encoding="utf-8") as f:
|
|
378
320
|
content = f.read()
|
|
379
321
|
content = content.replace(PLACEHOLDER, tags)
|
|
380
322
|
if js:
|
|
@@ -419,132 +361,46 @@ async def auth(request: Request):
|
|
|
419
361
|
return get_configuration()
|
|
420
362
|
|
|
421
363
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
response_dict = _get_response_dict(access_token)
|
|
432
|
-
|
|
433
|
-
if redirect_to_callback:
|
|
434
|
-
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
435
|
-
redirect_url = (
|
|
436
|
-
f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
|
|
437
|
-
)
|
|
438
|
-
|
|
439
|
-
return RedirectResponse(
|
|
440
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
441
|
-
url=redirect_url,
|
|
442
|
-
status_code=302,
|
|
364
|
+
@router.post("/login")
|
|
365
|
+
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
366
|
+
"""
|
|
367
|
+
Login a user using the password auth callback.
|
|
368
|
+
"""
|
|
369
|
+
if not config.code.password_auth_callback:
|
|
370
|
+
raise HTTPException(
|
|
371
|
+
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
443
372
|
)
|
|
444
373
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
def _get_oauth_redirect_error(error: str) -> Response:
|
|
449
|
-
"""Get the redirect response for an OAuth error."""
|
|
450
|
-
params = urllib.parse.urlencode(
|
|
451
|
-
{
|
|
452
|
-
"error": error,
|
|
453
|
-
}
|
|
454
|
-
)
|
|
455
|
-
response = RedirectResponse(
|
|
456
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
457
|
-
url=f"/login?{params}", # Shouldn't there be {root_path} here?
|
|
374
|
+
user = await config.code.password_auth_callback(
|
|
375
|
+
form_data.username, form_data.password
|
|
458
376
|
)
|
|
459
|
-
return response
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
async def _authenticate_user(
|
|
463
|
-
user: Optional[User], redirect_to_callback: bool = False
|
|
464
|
-
) -> Response:
|
|
465
|
-
"""Authenticate a user and return the response."""
|
|
466
377
|
|
|
467
378
|
if not user:
|
|
468
379
|
raise HTTPException(
|
|
469
380
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
470
381
|
detail="credentialssignin",
|
|
471
382
|
)
|
|
472
|
-
|
|
473
|
-
# If a data layer is defined, attempt to persist user.
|
|
383
|
+
access_token = create_jwt(user)
|
|
474
384
|
if data_layer := get_data_layer():
|
|
475
385
|
try:
|
|
476
386
|
await data_layer.create_user(user)
|
|
477
387
|
except Exception as e:
|
|
478
|
-
# Catch and log exceptions during user creation.
|
|
479
|
-
# TODO: Make this catch only specific errors and allow others to propagate.
|
|
480
388
|
logger.error(f"Error creating user: {e}")
|
|
481
389
|
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
set_auth_cookie(response, access_token)
|
|
487
|
-
|
|
488
|
-
return response
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
@router.post("/login")
|
|
492
|
-
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
|
|
493
|
-
"""
|
|
494
|
-
Login a user using the password auth callback.
|
|
495
|
-
"""
|
|
496
|
-
if not config.code.password_auth_callback:
|
|
497
|
-
raise HTTPException(
|
|
498
|
-
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
499
|
-
)
|
|
500
|
-
|
|
501
|
-
user = await config.code.password_auth_callback(
|
|
502
|
-
form_data.username, form_data.password
|
|
503
|
-
)
|
|
504
|
-
|
|
505
|
-
return await _authenticate_user(user)
|
|
390
|
+
return {
|
|
391
|
+
"access_token": access_token,
|
|
392
|
+
"token_type": "bearer",
|
|
393
|
+
}
|
|
506
394
|
|
|
507
395
|
|
|
508
396
|
@router.post("/logout")
|
|
509
397
|
async def logout(request: Request, response: Response):
|
|
510
398
|
"""Logout the user by calling the on_logout callback."""
|
|
511
|
-
clear_auth_cookie(response)
|
|
512
|
-
|
|
513
399
|
if config.code.on_logout:
|
|
514
400
|
return await config.code.on_logout(request, response)
|
|
515
|
-
|
|
516
401
|
return {"success": True}
|
|
517
402
|
|
|
518
403
|
|
|
519
|
-
@router.post("/auth/jwt")
|
|
520
|
-
async def jwt_auth(request: Request):
|
|
521
|
-
"""Login a user using a valid jwt."""
|
|
522
|
-
from jwt import InvalidTokenError
|
|
523
|
-
|
|
524
|
-
auth_header: Optional[str] = request.headers.get("Authorization")
|
|
525
|
-
if not auth_header:
|
|
526
|
-
raise HTTPException(status_code=401, detail="Authorization header missing")
|
|
527
|
-
|
|
528
|
-
# Check if it starts with "Bearer "
|
|
529
|
-
try:
|
|
530
|
-
scheme, token = auth_header.split()
|
|
531
|
-
if scheme.lower() != "bearer":
|
|
532
|
-
raise HTTPException(
|
|
533
|
-
status_code=401,
|
|
534
|
-
detail="Invalid authentication scheme. Please use Bearer",
|
|
535
|
-
)
|
|
536
|
-
except ValueError:
|
|
537
|
-
raise HTTPException(
|
|
538
|
-
status_code=401, detail="Invalid authorization header format"
|
|
539
|
-
)
|
|
540
|
-
|
|
541
|
-
try:
|
|
542
|
-
user = decode_jwt(token)
|
|
543
|
-
return await _authenticate_user(user)
|
|
544
|
-
except InvalidTokenError:
|
|
545
|
-
raise HTTPException(status_code=401, detail="Invalid token")
|
|
546
|
-
|
|
547
|
-
|
|
548
404
|
@router.post("/auth/header")
|
|
549
405
|
async def header_auth(request: Request):
|
|
550
406
|
"""Login a user using the header_auth_callback."""
|
|
@@ -556,7 +412,23 @@ async def header_auth(request: Request):
|
|
|
556
412
|
|
|
557
413
|
user = await config.code.header_auth_callback(request.headers)
|
|
558
414
|
|
|
559
|
-
|
|
415
|
+
if not user:
|
|
416
|
+
raise HTTPException(
|
|
417
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
418
|
+
detail="Unauthorized",
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
access_token = create_jwt(user)
|
|
422
|
+
if data_layer := get_data_layer():
|
|
423
|
+
try:
|
|
424
|
+
await data_layer.create_user(user)
|
|
425
|
+
except Exception as e:
|
|
426
|
+
logger.error(f"Error creating user: {e}")
|
|
427
|
+
|
|
428
|
+
return {
|
|
429
|
+
"access_token": access_token,
|
|
430
|
+
"token_type": "bearer",
|
|
431
|
+
}
|
|
560
432
|
|
|
561
433
|
|
|
562
434
|
@router.get("/auth/oauth/{provider_id}")
|
|
@@ -588,9 +460,16 @@ async def oauth_login(provider_id: str, request: Request):
|
|
|
588
460
|
response = RedirectResponse(
|
|
589
461
|
url=f"{provider.authorize_url}?{params}",
|
|
590
462
|
)
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
463
|
+
samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
|
|
464
|
+
secure = samesite.lower() == "none"
|
|
465
|
+
response.set_cookie(
|
|
466
|
+
"oauth_state",
|
|
467
|
+
random,
|
|
468
|
+
httponly=True,
|
|
469
|
+
samesite=samesite,
|
|
470
|
+
secure=secure,
|
|
471
|
+
max_age=3 * 60,
|
|
472
|
+
)
|
|
594
473
|
return response
|
|
595
474
|
|
|
596
475
|
|
|
@@ -618,7 +497,16 @@ async def oauth_callback(
|
|
|
618
497
|
)
|
|
619
498
|
|
|
620
499
|
if error:
|
|
621
|
-
|
|
500
|
+
params = urllib.parse.urlencode(
|
|
501
|
+
{
|
|
502
|
+
"error": error,
|
|
503
|
+
}
|
|
504
|
+
)
|
|
505
|
+
response = RedirectResponse(
|
|
506
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
507
|
+
url=f"/login?{params}",
|
|
508
|
+
)
|
|
509
|
+
return response
|
|
622
510
|
|
|
623
511
|
if not code or not state:
|
|
624
512
|
raise HTTPException(
|
|
@@ -626,11 +514,9 @@ async def oauth_callback(
|
|
|
626
514
|
detail="Missing code or state",
|
|
627
515
|
)
|
|
628
516
|
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
logger.exception("Unable to validate oauth state: %1", e)
|
|
633
|
-
|
|
517
|
+
# Check the state from the oauth provider against the browser cookie
|
|
518
|
+
oauth_state = request.cookies.get("oauth_state")
|
|
519
|
+
if oauth_state != state:
|
|
634
520
|
raise HTTPException(
|
|
635
521
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
636
522
|
detail="Unauthorized",
|
|
@@ -645,10 +531,34 @@ async def oauth_callback(
|
|
|
645
531
|
provider_id, token, raw_user_data, default_user
|
|
646
532
|
)
|
|
647
533
|
|
|
648
|
-
|
|
534
|
+
if not user:
|
|
535
|
+
raise HTTPException(
|
|
536
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
537
|
+
detail="Unauthorized",
|
|
538
|
+
)
|
|
649
539
|
|
|
650
|
-
|
|
540
|
+
access_token = create_jwt(user)
|
|
651
541
|
|
|
542
|
+
if data_layer := get_data_layer():
|
|
543
|
+
try:
|
|
544
|
+
await data_layer.create_user(user)
|
|
545
|
+
except Exception as e:
|
|
546
|
+
logger.error(f"Error creating user: {e}")
|
|
547
|
+
|
|
548
|
+
params = urllib.parse.urlencode(
|
|
549
|
+
{
|
|
550
|
+
"access_token": access_token,
|
|
551
|
+
"token_type": "bearer",
|
|
552
|
+
}
|
|
553
|
+
)
|
|
554
|
+
|
|
555
|
+
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
556
|
+
|
|
557
|
+
response = RedirectResponse(
|
|
558
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
559
|
+
url=f"{root_path}/login/callback?{params}",
|
|
560
|
+
)
|
|
561
|
+
response.delete_cookie("oauth_state")
|
|
652
562
|
return response
|
|
653
563
|
|
|
654
564
|
|
|
@@ -677,7 +587,16 @@ async def oauth_azure_hf_callback(
|
|
|
677
587
|
)
|
|
678
588
|
|
|
679
589
|
if error:
|
|
680
|
-
|
|
590
|
+
params = urllib.parse.urlencode(
|
|
591
|
+
{
|
|
592
|
+
"error": error,
|
|
593
|
+
}
|
|
594
|
+
)
|
|
595
|
+
response = RedirectResponse(
|
|
596
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
597
|
+
url=f"/login?{params}",
|
|
598
|
+
)
|
|
599
|
+
return response
|
|
681
600
|
|
|
682
601
|
if not code:
|
|
683
602
|
raise HTTPException(
|
|
@@ -694,20 +613,36 @@ async def oauth_azure_hf_callback(
|
|
|
694
613
|
provider_id, token, raw_user_data, default_user, id_token
|
|
695
614
|
)
|
|
696
615
|
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
616
|
+
if not user:
|
|
617
|
+
raise HTTPException(
|
|
618
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
619
|
+
detail="Unauthorized",
|
|
620
|
+
)
|
|
700
621
|
|
|
701
|
-
|
|
622
|
+
access_token = create_jwt(user)
|
|
702
623
|
|
|
624
|
+
if data_layer := get_data_layer():
|
|
625
|
+
try:
|
|
626
|
+
await data_layer.create_user(user)
|
|
627
|
+
except Exception as e:
|
|
628
|
+
logger.error(f"Error creating user: {e}")
|
|
703
629
|
|
|
704
|
-
|
|
705
|
-
|
|
630
|
+
params = urllib.parse.urlencode(
|
|
631
|
+
{
|
|
632
|
+
"access_token": access_token,
|
|
633
|
+
"token_type": "bearer",
|
|
634
|
+
}
|
|
635
|
+
)
|
|
706
636
|
|
|
637
|
+
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
707
638
|
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
639
|
+
response = RedirectResponse(
|
|
640
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
641
|
+
url=f"{root_path}/login/callback?{params}",
|
|
642
|
+
status_code=302,
|
|
643
|
+
)
|
|
644
|
+
response.delete_cookie("oauth_state")
|
|
645
|
+
return response
|
|
711
646
|
|
|
712
647
|
|
|
713
648
|
_language_pattern = (
|
|
@@ -735,7 +670,7 @@ async def project_translations(
|
|
|
735
670
|
|
|
736
671
|
@router.get("/project/settings")
|
|
737
672
|
async def project_settings(
|
|
738
|
-
current_user:
|
|
673
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
739
674
|
language: str = Query(
|
|
740
675
|
default="en-US", description="Language code", pattern=_language_pattern
|
|
741
676
|
),
|
|
@@ -786,7 +721,7 @@ async def project_settings(
|
|
|
786
721
|
async def update_feedback(
|
|
787
722
|
request: Request,
|
|
788
723
|
update: UpdateFeedbackRequest,
|
|
789
|
-
current_user:
|
|
724
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
790
725
|
):
|
|
791
726
|
"""Update the human feedback for a particular message."""
|
|
792
727
|
data_layer = get_data_layer()
|
|
@@ -805,7 +740,7 @@ async def update_feedback(
|
|
|
805
740
|
async def delete_feedback(
|
|
806
741
|
request: Request,
|
|
807
742
|
payload: DeleteFeedbackRequest,
|
|
808
|
-
current_user:
|
|
743
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
809
744
|
):
|
|
810
745
|
"""Delete a feedback."""
|
|
811
746
|
|
|
@@ -824,7 +759,7 @@ async def delete_feedback(
|
|
|
824
759
|
async def get_user_threads(
|
|
825
760
|
request: Request,
|
|
826
761
|
payload: GetThreadsRequest,
|
|
827
|
-
current_user:
|
|
762
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
828
763
|
):
|
|
829
764
|
"""Get the threads page by page."""
|
|
830
765
|
|
|
@@ -833,9 +768,6 @@ async def get_user_threads(
|
|
|
833
768
|
if not data_layer:
|
|
834
769
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
835
770
|
|
|
836
|
-
if not current_user:
|
|
837
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
838
|
-
|
|
839
771
|
if not isinstance(current_user, PersistedUser):
|
|
840
772
|
persisted_user = await data_layer.get_user(identifier=current_user.identifier)
|
|
841
773
|
if not persisted_user:
|
|
@@ -852,7 +784,7 @@ async def get_user_threads(
|
|
|
852
784
|
async def get_thread(
|
|
853
785
|
request: Request,
|
|
854
786
|
thread_id: str,
|
|
855
|
-
current_user:
|
|
787
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
856
788
|
):
|
|
857
789
|
"""Get a specific thread."""
|
|
858
790
|
data_layer = get_data_layer()
|
|
@@ -860,9 +792,6 @@ async def get_thread(
|
|
|
860
792
|
if not data_layer:
|
|
861
793
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
862
794
|
|
|
863
|
-
if not current_user:
|
|
864
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
865
|
-
|
|
866
795
|
await is_thread_author(current_user.identifier, thread_id)
|
|
867
796
|
|
|
868
797
|
res = await data_layer.get_thread(thread_id)
|
|
@@ -874,7 +803,7 @@ async def get_thread_element(
|
|
|
874
803
|
request: Request,
|
|
875
804
|
thread_id: str,
|
|
876
805
|
element_id: str,
|
|
877
|
-
current_user:
|
|
806
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
878
807
|
):
|
|
879
808
|
"""Get a specific thread element."""
|
|
880
809
|
data_layer = get_data_layer()
|
|
@@ -882,135 +811,17 @@ async def get_thread_element(
|
|
|
882
811
|
if not data_layer:
|
|
883
812
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
884
813
|
|
|
885
|
-
if not current_user:
|
|
886
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
887
|
-
|
|
888
814
|
await is_thread_author(current_user.identifier, thread_id)
|
|
889
815
|
|
|
890
816
|
res = await data_layer.get_element(thread_id, element_id)
|
|
891
817
|
return JSONResponse(content=res)
|
|
892
818
|
|
|
893
819
|
|
|
894
|
-
@router.put("/project/element")
|
|
895
|
-
async def update_thread_element(
|
|
896
|
-
payload: ElementRequest,
|
|
897
|
-
current_user: UserParam,
|
|
898
|
-
):
|
|
899
|
-
"""Update a specific thread element."""
|
|
900
|
-
|
|
901
|
-
from chainlit.context import init_ws_context
|
|
902
|
-
from chainlit.element import CustomElement, ElementDict
|
|
903
|
-
from chainlit.session import WebsocketSession
|
|
904
|
-
|
|
905
|
-
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
906
|
-
context = init_ws_context(session)
|
|
907
|
-
|
|
908
|
-
element_dict = cast(ElementDict, payload.element)
|
|
909
|
-
|
|
910
|
-
if element_dict["type"] != "custom":
|
|
911
|
-
return {"success": False}
|
|
912
|
-
|
|
913
|
-
element = CustomElement(
|
|
914
|
-
id=element_dict["id"],
|
|
915
|
-
object_key=element_dict["objectKey"],
|
|
916
|
-
chainlit_key=element_dict["chainlitKey"],
|
|
917
|
-
url=element_dict["url"],
|
|
918
|
-
for_id=element_dict.get("forId") or "",
|
|
919
|
-
thread_id=element_dict.get("threadId") or "",
|
|
920
|
-
name=element_dict["name"],
|
|
921
|
-
props=element_dict.get("props") or {},
|
|
922
|
-
display=element_dict["display"],
|
|
923
|
-
)
|
|
924
|
-
|
|
925
|
-
if current_user:
|
|
926
|
-
if (
|
|
927
|
-
not context.session.user
|
|
928
|
-
or context.session.user.identifier != current_user.identifier
|
|
929
|
-
):
|
|
930
|
-
raise HTTPException(
|
|
931
|
-
status_code=401,
|
|
932
|
-
detail="You are not authorized to update elements for this session",
|
|
933
|
-
)
|
|
934
|
-
|
|
935
|
-
await element.send(for_id=element.for_id or "")
|
|
936
|
-
return {"success": True}
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
@router.delete("/project/element")
|
|
940
|
-
async def delete_thread_element(
|
|
941
|
-
payload: ElementRequest,
|
|
942
|
-
current_user: UserParam,
|
|
943
|
-
):
|
|
944
|
-
"""Delete a specific thread element."""
|
|
945
|
-
|
|
946
|
-
from chainlit.context import init_ws_context
|
|
947
|
-
from chainlit.element import CustomElement, ElementDict
|
|
948
|
-
from chainlit.session import WebsocketSession
|
|
949
|
-
|
|
950
|
-
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
951
|
-
context = init_ws_context(session)
|
|
952
|
-
|
|
953
|
-
element_dict = cast(ElementDict, payload.element)
|
|
954
|
-
|
|
955
|
-
if element_dict["type"] != "custom":
|
|
956
|
-
return {"success": False}
|
|
957
|
-
|
|
958
|
-
element = CustomElement(
|
|
959
|
-
id=element_dict["id"],
|
|
960
|
-
object_key=element_dict["objectKey"],
|
|
961
|
-
chainlit_key=element_dict["chainlitKey"],
|
|
962
|
-
url=element_dict["url"],
|
|
963
|
-
for_id=element_dict.get("forId") or "",
|
|
964
|
-
thread_id=element_dict.get("threadId") or "",
|
|
965
|
-
name=element_dict["name"],
|
|
966
|
-
props=element_dict.get("props") or {},
|
|
967
|
-
display=element_dict["display"],
|
|
968
|
-
)
|
|
969
|
-
|
|
970
|
-
if current_user:
|
|
971
|
-
if (
|
|
972
|
-
not context.session.user
|
|
973
|
-
or context.session.user.identifier != current_user.identifier
|
|
974
|
-
):
|
|
975
|
-
raise HTTPException(
|
|
976
|
-
status_code=401,
|
|
977
|
-
detail="You are not authorized to remove elements for this session",
|
|
978
|
-
)
|
|
979
|
-
|
|
980
|
-
await element.remove()
|
|
981
|
-
|
|
982
|
-
return {"success": True}
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
@router.put("/project/thread")
|
|
986
|
-
async def rename_thread(
|
|
987
|
-
request: Request,
|
|
988
|
-
payload: UpdateThreadRequest,
|
|
989
|
-
current_user: UserParam,
|
|
990
|
-
):
|
|
991
|
-
"""Rename a thread."""
|
|
992
|
-
|
|
993
|
-
data_layer = get_data_layer()
|
|
994
|
-
|
|
995
|
-
if not data_layer:
|
|
996
|
-
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
997
|
-
|
|
998
|
-
if not current_user:
|
|
999
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1000
|
-
|
|
1001
|
-
thread_id = payload.threadId
|
|
1002
|
-
|
|
1003
|
-
await is_thread_author(current_user.identifier, thread_id)
|
|
1004
|
-
|
|
1005
|
-
await data_layer.update_thread(thread_id, name=payload.name)
|
|
1006
|
-
return JSONResponse(content={"success": True})
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
820
|
@router.delete("/project/thread")
|
|
1010
821
|
async def delete_thread(
|
|
1011
822
|
request: Request,
|
|
1012
823
|
payload: DeleteThreadRequest,
|
|
1013
|
-
current_user:
|
|
824
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
1014
825
|
):
|
|
1015
826
|
"""Delete a thread."""
|
|
1016
827
|
|
|
@@ -1019,9 +830,6 @@ async def delete_thread(
|
|
|
1019
830
|
if not data_layer:
|
|
1020
831
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
1021
832
|
|
|
1022
|
-
if not current_user:
|
|
1023
|
-
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1024
|
-
|
|
1025
833
|
thread_id = payload.threadId
|
|
1026
834
|
|
|
1027
835
|
await is_thread_author(current_user.identifier, thread_id)
|
|
@@ -1030,47 +838,9 @@ async def delete_thread(
|
|
|
1030
838
|
return JSONResponse(content={"success": True})
|
|
1031
839
|
|
|
1032
840
|
|
|
1033
|
-
@router.post("/project/action")
async def call_action(
    payload: CallActionRequest,
    current_user: UserParam,
):
    """Run the registered callback for an action in a websocket session.

    Args:
        payload: Carries the target ``sessionId`` and the serialized ``action``.
        current_user: The user resolved for this request; falsy when nobody
            is authenticated.

    Returns:
        A JSON response ``{"success": True}`` once the callback has completed.

    Raises:
        HTTPException: 404 when the session does not exist, 401 when the
            current user does not own the session, and 404 when no callback
            is registered under the action's name.
    """
    from chainlit.action import Action
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    # An unknown or expired session id is a client error. Report it as 404
    # instead of letting context initialisation fail with an unhandled error.
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    # With an authenticated user, only the owner of the session may trigger
    # its actions. This is checked before the untrusted action payload is
    # deserialized.
    if current_user and (
        not context.session.user
        or context.session.user.identifier != current_user.identifier
    ):
        raise HTTPException(
            status_code=401,
            detail="You are not authorized to run actions for this session",
        )

    action = Action(**payload.action)

    callback = config.code.action_callbacks.get(action.name)
    if not callback:
        raise HTTPException(
            status_code=404,
            detail=f"No callback found for action {action.name}",
        )

    await callback(action)
    return JSONResponse(content={"success": True})
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
841
|
@router.post("/project/file")
|
|
1072
842
|
async def upload_file(
|
|
1073
|
-
current_user:
|
|
843
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
1074
844
|
session_id: str,
|
|
1075
845
|
file: UploadFile,
|
|
1076
846
|
):
|
|
@@ -1100,11 +870,6 @@ async def upload_file(
|
|
|
1100
870
|
assert file.filename, "No filename for uploaded file"
|
|
1101
871
|
assert file.content_type, "No content type for uploaded file"
|
|
1102
872
|
|
|
1103
|
-
try:
|
|
1104
|
-
validate_file_upload(file)
|
|
1105
|
-
except ValueError as e:
|
|
1106
|
-
raise HTTPException(status_code=400, detail=str(e))
|
|
1107
|
-
|
|
1108
873
|
file_response = await session.persist_file(
|
|
1109
874
|
name=file.filename, content=content, mime=file.content_type
|
|
1110
875
|
)
|
|
@@ -1112,79 +877,14 @@ async def upload_file(
|
|
|
1112
877
|
return JSONResponse(content=file_response)
|
|
1113
878
|
|
|
1114
879
|
|
|
1115
|
-
def validate_file_upload(file: UploadFile):
    """Check an uploaded file against ``config.features.spontaneous_file_upload``.

    A missing ``spontaneous_file_upload`` section means uploads are accepted
    without restrictions. Otherwise uploads must not be explicitly disabled,
    and the file must pass both the type check and the size check.

    Args:
        file (UploadFile): The uploaded file to check.

    Raises:
        ValueError: If uploads are disabled or one of the individual checks
            rejects the file.
    """
    upload_settings = config.features.spontaneous_file_upload

    # No configuration at all: allow the upload without any restrictions.
    if upload_settings is None:
        return

    # Only an explicit ``False`` disables uploads; any other value falls through.
    if upload_settings.enabled is False:
        raise ValueError("File upload is not enabled")

    validate_file_mime_type(file)
    validate_file_size(file)
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
def validate_file_mime_type(file: UploadFile):
    """Validate the file type as configured in config.features.spontaneous_file_upload.

    ``accept`` may be:

    * ``None``: every file type is allowed.
    * a list of MIME glob patterns (e.g. ``["image/*", "text/plain"]``): the
      file's content type must match at least one pattern.
    * a dict mapping MIME glob patterns to lists of filename extensions: the
      content type must match a pattern. If that pattern lists extensions,
      the filename must also end with one of them. An empty list accepts any
      filename.

    MIME types and extensions are compared case-insensitively, so
    ``PHOTO.PNG`` sent as ``Image/PNG`` is accepted by ``{"image/*": [".png"]}``.

    Args:
        file (UploadFile): The file to validate.

    Raises:
        TypeError: If ``accept`` is configured as something other than a
            list or a dict.
        ValueError: If the file type is not allowed.
    """
    accept = config.features.spontaneous_file_upload.accept
    if accept is None:
        # Accept is not configured, allowing all file types.
        return

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not isinstance(accept, (list, dict)):
        raise TypeError(
            "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
        )

    # MIME types and extensions are case-insensitive. Normalise once and use
    # fnmatchcase so the result does not depend on the host OS (plain
    # fnmatch.fnmatch applies os.path.normcase). A missing content type or
    # filename is treated as empty instead of crashing.
    content_type = (file.content_type or "").lower()
    filename = (file.filename or "").lower()

    if isinstance(accept, list):
        if any(fnmatch.fnmatchcase(content_type, pattern.lower()) for pattern in accept):
            return
    else:
        for pattern, extensions in accept.items():
            if not fnmatch.fnmatchcase(content_type, pattern.lower()):
                continue
            if not extensions:
                # Pattern matched and it places no extension restriction.
                return
            if filename.endswith(tuple(ext.lower() for ext in extensions)):
                return

    raise ValueError("File type not allowed")
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
def validate_file_size(file: UploadFile):
    """Validate the file size as configured in config.features.spontaneous_file_upload.

    No limit is enforced when the ``spontaneous_file_upload`` section or its
    ``max_size_mb`` is missing.

    When ``file.size`` is not set, the size is measured from the underlying
    file object. This keeps the limit from being silently skipped. If the size
    still cannot be determined, the file is allowed, as before.

    Args:
        file (UploadFile): The file to validate.

    Raises:
        ValueError: If the file size is too large.
    """
    upload_settings = config.features.spontaneous_file_upload
    # Consistent with validate_file_upload: a missing section means no
    # restrictions (previously an AttributeError when called directly).
    if upload_settings is None or upload_settings.max_size_mb is None:
        return

    size = file.size
    if size is None:
        # Measure the underlying file object and restore the read position
        # afterwards.
        try:
            position = file.file.tell()
            file.file.seek(0, 2)  # 2 == os.SEEK_END
            size = file.file.tell()
            file.file.seek(position)
        except (AttributeError, OSError, ValueError):
            # Not seekable or already closed: we cannot measure it.
            size = None

    if size is not None and size > upload_settings.max_size_mb * 1024 * 1024:
        raise ValueError("File size too large")
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
880
|
@router.get("/project/file/{file_id}")
|
|
1182
881
|
async def get_file(
|
|
1183
882
|
file_id: str,
|
|
1184
883
|
session_id: str,
|
|
1185
|
-
current_user:
|
|
884
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
1186
885
|
):
|
|
1187
886
|
"""Get a file from the session files directory."""
|
|
887
|
+
|
|
1188
888
|
from chainlit.session import WebsocketSession
|
|
1189
889
|
|
|
1190
890
|
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
@@ -1212,7 +912,7 @@ async def get_file(
|
|
|
1212
912
|
@router.get("/files/{filename:path}")
|
|
1213
913
|
async def serve_file(
|
|
1214
914
|
filename: str,
|
|
1215
|
-
current_user:
|
|
915
|
+
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
1216
916
|
):
|
|
1217
917
|
"""Serve a file from the local filesystem."""
|
|
1218
918
|
|