chainlit 1.3.2__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +58 -56
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +24 -34
- chainlit/auth/cookie.py +123 -0
- chainlit/auth/jwt.py +37 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +65 -11
- chainlit/chat_context.py +2 -2
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +15 -2
- chainlit/config.py +46 -90
- chainlit/context.py +4 -3
- chainlit/copilot/dist/index.js +8608 -642
- chainlit/data/__init__.py +96 -8
- chainlit/data/acl.py +3 -2
- chainlit/data/base.py +1 -15
- chainlit/data/chainlit_data_layer.py +584 -0
- chainlit/data/dynamodb.py +7 -4
- chainlit/data/literalai.py +4 -6
- chainlit/data/sql_alchemy.py +9 -8
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/{storage_clients.py → storage_clients/azure.py} +2 -33
- chainlit/data/storage_clients/azure_blob.py +80 -0
- chainlit/data/storage_clients/base.py +22 -0
- chainlit/data/storage_clients/gcs.py +78 -0
- chainlit/data/storage_clients/s3.py +49 -0
- chainlit/discord/__init__.py +4 -4
- chainlit/discord/app.py +2 -1
- chainlit/element.py +41 -9
- chainlit/emitter.py +37 -16
- chainlit/frontend/dist/assets/{DailyMotion-Bq4wFES6.js → DailyMotion-DgRzV5GZ.js} +1 -1
- chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +22 -0
- chainlit/frontend/dist/assets/{Facebook-CHEgeJDe.js → Facebook-C0vx6HWv.js} +1 -1
- chainlit/frontend/dist/assets/{FilePlayer-BMFA6He5.js → FilePlayer-CdhzeHPP.js} +1 -1
- chainlit/frontend/dist/assets/{Kaltura-BS4Q0SKd.js → Kaltura-5iVmeUct.js} +1 -1
- chainlit/frontend/dist/assets/{Mixcloud-tLlgZy_i.js → Mixcloud-C2zi77Ex.js} +1 -1
- chainlit/frontend/dist/assets/{Mux-Bcz0qNhS.js → Mux-Vkebogdf.js} +1 -1
- chainlit/frontend/dist/assets/{Preview-RsJjlwJx.js → Preview-DwY_sEIl.js} +1 -1
- chainlit/frontend/dist/assets/{SoundCloud-B9UgR7Bk.js → SoundCloud-CREBXAWo.js} +1 -1
- chainlit/frontend/dist/assets/{Streamable-BOgIqbui.js → Streamable-B5Lu25uy.js} +1 -1
- chainlit/frontend/dist/assets/{Twitch-CBX_d6nV.js → Twitch-y9iKCcM1.js} +1 -1
- chainlit/frontend/dist/assets/{Vidyard-C5HPuozf.js → Vidyard-ClYvcuEu.js} +1 -1
- chainlit/frontend/dist/assets/{Vimeo-CHBmywi9.js → Vimeo-D6HvM2jt.js} +1 -1
- chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +1 -0
- chainlit/frontend/dist/assets/{YouTube-CA7t0q0j.js → YouTube-D10tR6CJ.js} +1 -1
- chainlit/frontend/dist/assets/index-CI4qFOt5.js +8665 -0
- chainlit/frontend/dist/assets/index-CrrqM0nZ.css +1 -0
- chainlit/frontend/dist/assets/{react-plotly-Ba2Cl614.js → react-plotly-BpxUS-ab.js} +1 -1
- chainlit/frontend/dist/index.html +2 -2
- chainlit/haystack/callbacks.py +5 -4
- chainlit/input_widget.py +6 -4
- chainlit/langchain/callbacks.py +56 -47
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +7 -7
- chainlit/message.py +8 -10
- chainlit/mistralai/__init__.py +3 -2
- chainlit/oauth_providers.py +70 -3
- chainlit/openai/__init__.py +3 -2
- chainlit/secret.py +1 -1
- chainlit/server.py +481 -182
- chainlit/session.py +7 -5
- chainlit/slack/__init__.py +3 -3
- chainlit/slack/app.py +3 -2
- chainlit/socket.py +89 -112
- chainlit/step.py +12 -12
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +3 -3
- chainlit/teams/app.py +1 -0
- chainlit/translations/en-US.json +2 -1
- chainlit/translations/nl-NL.json +229 -0
- chainlit/types.py +24 -8
- chainlit/user.py +2 -1
- chainlit/utils.py +3 -2
- chainlit/version.py +3 -2
- {chainlit-1.3.2.dist-info → chainlit-2.0.0.dist-info}/METADATA +15 -35
- chainlit-2.0.0.dist-info/RECORD +106 -0
- chainlit/frontend/dist/assets/Wistia-1Gb23ljh.js +0 -1
- chainlit/frontend/dist/assets/index-CwmincdQ.css +0 -1
- chainlit/frontend/dist/assets/index-DnjoDoLU.js +0 -723
- chainlit-1.3.2.dist-info/RECORD +0 -96
- {chainlit-1.3.2.dist-info → chainlit-2.0.0.dist-info}/WHEEL +0 -0
- {chainlit-1.3.2.dist-info → chainlit-2.0.0.dist-info}/entry_points.txt +0 -0
chainlit/server.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import fnmatch
|
|
2
3
|
import glob
|
|
3
4
|
import json
|
|
4
5
|
import mimetypes
|
|
@@ -9,10 +10,36 @@ import urllib.parse
|
|
|
9
10
|
import webbrowser
|
|
10
11
|
from contextlib import asynccontextmanager
|
|
11
12
|
from pathlib import Path
|
|
12
|
-
from typing import
|
|
13
|
+
from typing import List, Optional, Union, cast
|
|
13
14
|
|
|
14
15
|
import socketio
|
|
15
|
-
from
|
|
16
|
+
from fastapi import (
|
|
17
|
+
APIRouter,
|
|
18
|
+
Depends,
|
|
19
|
+
FastAPI,
|
|
20
|
+
Form,
|
|
21
|
+
HTTPException,
|
|
22
|
+
Query,
|
|
23
|
+
Request,
|
|
24
|
+
Response,
|
|
25
|
+
UploadFile,
|
|
26
|
+
status,
|
|
27
|
+
)
|
|
28
|
+
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
29
|
+
from fastapi.security import OAuth2PasswordRequestForm
|
|
30
|
+
from starlette.datastructures import URL
|
|
31
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
32
|
+
from typing_extensions import Annotated
|
|
33
|
+
from watchfiles import awatch
|
|
34
|
+
|
|
35
|
+
from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
|
|
36
|
+
from chainlit.auth.cookie import (
|
|
37
|
+
clear_auth_cookie,
|
|
38
|
+
clear_oauth_state_cookie,
|
|
39
|
+
set_auth_cookie,
|
|
40
|
+
set_oauth_state_cookie,
|
|
41
|
+
validate_oauth_state_cookie,
|
|
42
|
+
)
|
|
16
43
|
from chainlit.config import (
|
|
17
44
|
APP_ROOT,
|
|
18
45
|
BACKEND_ROOT,
|
|
@@ -21,6 +48,7 @@ from chainlit.config import (
|
|
|
21
48
|
PACKAGE_ROOT,
|
|
22
49
|
config,
|
|
23
50
|
load_module,
|
|
51
|
+
public_dir,
|
|
24
52
|
reload_config,
|
|
25
53
|
)
|
|
26
54
|
from chainlit.data import get_data_layer
|
|
@@ -30,33 +58,16 @@ from chainlit.markdown import get_markdown_str
|
|
|
30
58
|
from chainlit.oauth_providers import get_oauth_provider
|
|
31
59
|
from chainlit.secret import random_secret
|
|
32
60
|
from chainlit.types import (
|
|
61
|
+
CallActionRequest,
|
|
33
62
|
DeleteFeedbackRequest,
|
|
34
63
|
DeleteThreadRequest,
|
|
64
|
+
ElementRequest,
|
|
35
65
|
GetThreadsRequest,
|
|
36
66
|
Theme,
|
|
37
67
|
UpdateFeedbackRequest,
|
|
68
|
+
UpdateThreadRequest,
|
|
38
69
|
)
|
|
39
70
|
from chainlit.user import PersistedUser, User
|
|
40
|
-
from fastapi import (
|
|
41
|
-
APIRouter,
|
|
42
|
-
Depends,
|
|
43
|
-
FastAPI,
|
|
44
|
-
File,
|
|
45
|
-
Form,
|
|
46
|
-
HTTPException,
|
|
47
|
-
Query,
|
|
48
|
-
Request,
|
|
49
|
-
Response,
|
|
50
|
-
UploadFile,
|
|
51
|
-
status,
|
|
52
|
-
)
|
|
53
|
-
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
54
|
-
from fastapi.security import OAuth2PasswordRequestForm
|
|
55
|
-
from fastapi.staticfiles import StaticFiles
|
|
56
|
-
from starlette.datastructures import URL
|
|
57
|
-
from starlette.middleware.cors import CORSMiddleware
|
|
58
|
-
from typing_extensions import Annotated
|
|
59
|
-
from watchfiles import awatch
|
|
60
71
|
|
|
61
72
|
from ._utils import is_path_inside
|
|
62
73
|
|
|
@@ -205,29 +216,59 @@ app.add_middleware(
|
|
|
205
216
|
|
|
206
217
|
router = APIRouter(prefix=PREFIX)
|
|
207
218
|
|
|
208
|
-
app.mount(
|
|
209
|
-
f"{PREFIX}/public",
|
|
210
|
-
StaticFiles(directory="public", check_dir=False),
|
|
211
|
-
name="public",
|
|
212
|
-
)
|
|
213
219
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
),
|
|
220
|
-
name="assets",
|
|
221
|
-
)
|
|
220
|
+
@router.get("/public/{filename:path}")
|
|
221
|
+
async def serve_public_file(
|
|
222
|
+
filename: str,
|
|
223
|
+
):
|
|
224
|
+
"""Serve a file from public dir."""
|
|
222
225
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
)
|
|
226
|
+
base_path = Path(public_dir)
|
|
227
|
+
file_path = (base_path / filename).resolve()
|
|
228
|
+
|
|
229
|
+
if not is_path_inside(file_path, base_path):
|
|
230
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
231
|
+
|
|
232
|
+
if file_path.is_file():
|
|
233
|
+
return FileResponse(file_path)
|
|
234
|
+
else:
|
|
235
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@router.get("/assets/{filename:path}")
|
|
239
|
+
async def serve_asset_file(
|
|
240
|
+
filename: str,
|
|
241
|
+
):
|
|
242
|
+
"""Serve a file from assets dir."""
|
|
243
|
+
|
|
244
|
+
base_path = Path(os.path.join(build_dir, "assets"))
|
|
245
|
+
file_path = (base_path / filename).resolve()
|
|
246
|
+
|
|
247
|
+
if not is_path_inside(file_path, base_path):
|
|
248
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
249
|
+
|
|
250
|
+
if file_path.is_file():
|
|
251
|
+
return FileResponse(file_path)
|
|
252
|
+
else:
|
|
253
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@router.get("/copilot/{filename:path}")
|
|
257
|
+
async def serve_copilot_file(
|
|
258
|
+
filename: str,
|
|
259
|
+
):
|
|
260
|
+
"""Serve a file from assets dir."""
|
|
261
|
+
|
|
262
|
+
base_path = Path(copilot_build_dir)
|
|
263
|
+
file_path = (base_path / filename).resolve()
|
|
264
|
+
|
|
265
|
+
if not is_path_inside(file_path, base_path):
|
|
266
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
267
|
+
|
|
268
|
+
if file_path.is_file():
|
|
269
|
+
return FileResponse(file_path)
|
|
270
|
+
else:
|
|
271
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
231
272
|
|
|
232
273
|
|
|
233
274
|
# -------------------------------------------------------------------------------
|
|
@@ -248,6 +289,7 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
|
|
|
248
289
|
|
|
249
290
|
if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
|
|
250
291
|
from botbuilder.schema import Activity
|
|
292
|
+
|
|
251
293
|
from chainlit.teams.app import adapter, bot
|
|
252
294
|
|
|
253
295
|
@router.post("/teams/events")
|
|
@@ -277,6 +319,16 @@ def get_html_template():
|
|
|
277
319
|
"""
|
|
278
320
|
Get HTML template for the index view.
|
|
279
321
|
"""
|
|
322
|
+
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
323
|
+
|
|
324
|
+
custom_theme = None
|
|
325
|
+
custom_theme_file_path = Path(public_dir) / "theme.json"
|
|
326
|
+
if (
|
|
327
|
+
is_path_inside(custom_theme_file_path, Path(public_dir))
|
|
328
|
+
and custom_theme_file_path.is_file()
|
|
329
|
+
):
|
|
330
|
+
custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
|
|
331
|
+
|
|
280
332
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
281
333
|
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
|
|
282
334
|
CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
|
|
@@ -299,7 +351,10 @@ def get_html_template():
|
|
|
299
351
|
<meta property="og:url" content="{url}">
|
|
300
352
|
<meta property="og:root_path" content="{ROOT_PATH}">"""
|
|
301
353
|
|
|
302
|
-
js = f"""<script>
|
|
354
|
+
js = f"""<script>
|
|
355
|
+
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
|
|
356
|
+
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
|
|
357
|
+
</script>"""
|
|
303
358
|
|
|
304
359
|
css = None
|
|
305
360
|
if config.ui.custom_css:
|
|
@@ -311,12 +366,15 @@ def get_html_template():
|
|
|
311
366
|
js += f"""<script src="{config.ui.custom_js}" defer></script>"""
|
|
312
367
|
|
|
313
368
|
font = None
|
|
314
|
-
if
|
|
315
|
-
font =
|
|
369
|
+
if custom_theme and custom_theme.get("custom_fonts"):
|
|
370
|
+
font = "\n".join(
|
|
371
|
+
f"""<link rel="stylesheet" href="{font}">"""
|
|
372
|
+
for font in custom_theme.get("custom_fonts")
|
|
373
|
+
)
|
|
316
374
|
|
|
317
375
|
index_html_file_path = os.path.join(build_dir, "index.html")
|
|
318
376
|
|
|
319
|
-
with open(index_html_file_path,
|
|
377
|
+
with open(index_html_file_path, encoding="utf-8") as f:
|
|
320
378
|
content = f.read()
|
|
321
379
|
content = content.replace(PLACEHOLDER, tags)
|
|
322
380
|
if js:
|
|
@@ -361,46 +419,132 @@ async def auth(request: Request):
|
|
|
361
419
|
return get_configuration()
|
|
362
420
|
|
|
363
421
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
422
|
+
def _get_response_dict(access_token: str) -> dict:
|
|
423
|
+
"""Get the response dictionary for the auth response."""
|
|
424
|
+
|
|
425
|
+
return {"success": True}
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
|
|
429
|
+
"""Get the redirect params for the OAuth callback."""
|
|
430
|
+
|
|
431
|
+
response_dict = _get_response_dict(access_token)
|
|
432
|
+
|
|
433
|
+
if redirect_to_callback:
|
|
434
|
+
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
435
|
+
redirect_url = (
|
|
436
|
+
f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
|
|
372
437
|
)
|
|
373
438
|
|
|
374
|
-
|
|
375
|
-
|
|
439
|
+
return RedirectResponse(
|
|
440
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
441
|
+
url=redirect_url,
|
|
442
|
+
status_code=302,
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
return JSONResponse(response_dict)
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def _get_oauth_redirect_error(error: str) -> Response:
|
|
449
|
+
"""Get the redirect response for an OAuth error."""
|
|
450
|
+
params = urllib.parse.urlencode(
|
|
451
|
+
{
|
|
452
|
+
"error": error,
|
|
453
|
+
}
|
|
376
454
|
)
|
|
455
|
+
response = RedirectResponse(
|
|
456
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
457
|
+
url=f"/login?{params}", # Shouldn't there be {root_path} here?
|
|
458
|
+
)
|
|
459
|
+
return response
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
async def _authenticate_user(
|
|
463
|
+
user: Optional[User], redirect_to_callback: bool = False
|
|
464
|
+
) -> Response:
|
|
465
|
+
"""Authenticate a user and return the response."""
|
|
377
466
|
|
|
378
467
|
if not user:
|
|
379
468
|
raise HTTPException(
|
|
380
469
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
381
470
|
detail="credentialssignin",
|
|
382
471
|
)
|
|
383
|
-
|
|
472
|
+
|
|
473
|
+
# If a data layer is defined, attempt to persist user.
|
|
384
474
|
if data_layer := get_data_layer():
|
|
385
475
|
try:
|
|
386
476
|
await data_layer.create_user(user)
|
|
387
477
|
except Exception as e:
|
|
478
|
+
# Catch and log exceptions during user creation.
|
|
479
|
+
# TODO: Make this catch only specific errors and allow others to propagate.
|
|
388
480
|
logger.error(f"Error creating user: {e}")
|
|
389
481
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
482
|
+
access_token = create_jwt(user)
|
|
483
|
+
|
|
484
|
+
response = _get_auth_response(access_token, redirect_to_callback)
|
|
485
|
+
|
|
486
|
+
set_auth_cookie(response, access_token)
|
|
487
|
+
|
|
488
|
+
return response
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
@router.post("/login")
|
|
492
|
+
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
|
|
493
|
+
"""
|
|
494
|
+
Login a user using the password auth callback.
|
|
495
|
+
"""
|
|
496
|
+
if not config.code.password_auth_callback:
|
|
497
|
+
raise HTTPException(
|
|
498
|
+
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
user = await config.code.password_auth_callback(
|
|
502
|
+
form_data.username, form_data.password
|
|
503
|
+
)
|
|
504
|
+
|
|
505
|
+
return await _authenticate_user(user)
|
|
394
506
|
|
|
395
507
|
|
|
396
508
|
@router.post("/logout")
|
|
397
509
|
async def logout(request: Request, response: Response):
|
|
398
510
|
"""Logout the user by calling the on_logout callback."""
|
|
511
|
+
clear_auth_cookie(response)
|
|
512
|
+
|
|
399
513
|
if config.code.on_logout:
|
|
400
514
|
return await config.code.on_logout(request, response)
|
|
515
|
+
|
|
401
516
|
return {"success": True}
|
|
402
517
|
|
|
403
518
|
|
|
519
|
+
@router.post("/auth/jwt")
|
|
520
|
+
async def jwt_auth(request: Request):
|
|
521
|
+
"""Login a user using a valid jwt."""
|
|
522
|
+
from jwt import InvalidTokenError
|
|
523
|
+
|
|
524
|
+
auth_header: Optional[str] = request.headers.get("Authorization")
|
|
525
|
+
if not auth_header:
|
|
526
|
+
raise HTTPException(status_code=401, detail="Authorization header missing")
|
|
527
|
+
|
|
528
|
+
# Check if it starts with "Bearer "
|
|
529
|
+
try:
|
|
530
|
+
scheme, token = auth_header.split()
|
|
531
|
+
if scheme.lower() != "bearer":
|
|
532
|
+
raise HTTPException(
|
|
533
|
+
status_code=401,
|
|
534
|
+
detail="Invalid authentication scheme. Please use Bearer",
|
|
535
|
+
)
|
|
536
|
+
except ValueError:
|
|
537
|
+
raise HTTPException(
|
|
538
|
+
status_code=401, detail="Invalid authorization header format"
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
try:
|
|
542
|
+
user = decode_jwt(token)
|
|
543
|
+
return await _authenticate_user(user)
|
|
544
|
+
except InvalidTokenError:
|
|
545
|
+
raise HTTPException(status_code=401, detail="Invalid token")
|
|
546
|
+
|
|
547
|
+
|
|
404
548
|
@router.post("/auth/header")
|
|
405
549
|
async def header_auth(request: Request):
|
|
406
550
|
"""Login a user using the header_auth_callback."""
|
|
@@ -412,23 +556,7 @@ async def header_auth(request: Request):
|
|
|
412
556
|
|
|
413
557
|
user = await config.code.header_auth_callback(request.headers)
|
|
414
558
|
|
|
415
|
-
|
|
416
|
-
raise HTTPException(
|
|
417
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
418
|
-
detail="Unauthorized",
|
|
419
|
-
)
|
|
420
|
-
|
|
421
|
-
access_token = create_jwt(user)
|
|
422
|
-
if data_layer := get_data_layer():
|
|
423
|
-
try:
|
|
424
|
-
await data_layer.create_user(user)
|
|
425
|
-
except Exception as e:
|
|
426
|
-
logger.error(f"Error creating user: {e}")
|
|
427
|
-
|
|
428
|
-
return {
|
|
429
|
-
"access_token": access_token,
|
|
430
|
-
"token_type": "bearer",
|
|
431
|
-
}
|
|
559
|
+
return await _authenticate_user(user)
|
|
432
560
|
|
|
433
561
|
|
|
434
562
|
@router.get("/auth/oauth/{provider_id}")
|
|
@@ -460,16 +588,9 @@ async def oauth_login(provider_id: str, request: Request):
|
|
|
460
588
|
response = RedirectResponse(
|
|
461
589
|
url=f"{provider.authorize_url}?{params}",
|
|
462
590
|
)
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
"oauth_state",
|
|
467
|
-
random,
|
|
468
|
-
httponly=True,
|
|
469
|
-
samesite=samesite,
|
|
470
|
-
secure=secure,
|
|
471
|
-
max_age=3 * 60,
|
|
472
|
-
)
|
|
591
|
+
|
|
592
|
+
set_oauth_state_cookie(response, random)
|
|
593
|
+
|
|
473
594
|
return response
|
|
474
595
|
|
|
475
596
|
|
|
@@ -497,16 +618,7 @@ async def oauth_callback(
|
|
|
497
618
|
)
|
|
498
619
|
|
|
499
620
|
if error:
|
|
500
|
-
|
|
501
|
-
{
|
|
502
|
-
"error": error,
|
|
503
|
-
}
|
|
504
|
-
)
|
|
505
|
-
response = RedirectResponse(
|
|
506
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
507
|
-
url=f"/login?{params}",
|
|
508
|
-
)
|
|
509
|
-
return response
|
|
621
|
+
return _get_oauth_redirect_error(error)
|
|
510
622
|
|
|
511
623
|
if not code or not state:
|
|
512
624
|
raise HTTPException(
|
|
@@ -514,9 +626,11 @@ async def oauth_callback(
|
|
|
514
626
|
detail="Missing code or state",
|
|
515
627
|
)
|
|
516
628
|
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
629
|
+
try:
|
|
630
|
+
validate_oauth_state_cookie(request, state)
|
|
631
|
+
except Exception as e:
|
|
632
|
+
logger.exception("Unable to validate oauth state: %1", e)
|
|
633
|
+
|
|
520
634
|
raise HTTPException(
|
|
521
635
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
522
636
|
detail="Unauthorized",
|
|
@@ -531,34 +645,10 @@ async def oauth_callback(
|
|
|
531
645
|
provider_id, token, raw_user_data, default_user
|
|
532
646
|
)
|
|
533
647
|
|
|
534
|
-
|
|
535
|
-
raise HTTPException(
|
|
536
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
537
|
-
detail="Unauthorized",
|
|
538
|
-
)
|
|
539
|
-
|
|
540
|
-
access_token = create_jwt(user)
|
|
541
|
-
|
|
542
|
-
if data_layer := get_data_layer():
|
|
543
|
-
try:
|
|
544
|
-
await data_layer.create_user(user)
|
|
545
|
-
except Exception as e:
|
|
546
|
-
logger.error(f"Error creating user: {e}")
|
|
547
|
-
|
|
548
|
-
params = urllib.parse.urlencode(
|
|
549
|
-
{
|
|
550
|
-
"access_token": access_token,
|
|
551
|
-
"token_type": "bearer",
|
|
552
|
-
}
|
|
553
|
-
)
|
|
648
|
+
response = await _authenticate_user(user, redirect_to_callback=True)
|
|
554
649
|
|
|
555
|
-
|
|
650
|
+
clear_oauth_state_cookie(response)
|
|
556
651
|
|
|
557
|
-
response = RedirectResponse(
|
|
558
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
559
|
-
url=f"{root_path}/login/callback?{params}",
|
|
560
|
-
)
|
|
561
|
-
response.delete_cookie("oauth_state")
|
|
562
652
|
return response
|
|
563
653
|
|
|
564
654
|
|
|
@@ -587,16 +677,7 @@ async def oauth_azure_hf_callback(
|
|
|
587
677
|
)
|
|
588
678
|
|
|
589
679
|
if error:
|
|
590
|
-
|
|
591
|
-
{
|
|
592
|
-
"error": error,
|
|
593
|
-
}
|
|
594
|
-
)
|
|
595
|
-
response = RedirectResponse(
|
|
596
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
597
|
-
url=f"/login?{params}",
|
|
598
|
-
)
|
|
599
|
-
return response
|
|
680
|
+
return _get_oauth_redirect_error(error)
|
|
600
681
|
|
|
601
682
|
if not code:
|
|
602
683
|
raise HTTPException(
|
|
@@ -613,36 +694,20 @@ async def oauth_azure_hf_callback(
|
|
|
613
694
|
provider_id, token, raw_user_data, default_user, id_token
|
|
614
695
|
)
|
|
615
696
|
|
|
616
|
-
|
|
617
|
-
raise HTTPException(
|
|
618
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
619
|
-
detail="Unauthorized",
|
|
620
|
-
)
|
|
697
|
+
response = await _authenticate_user(user, redirect_to_callback=True)
|
|
621
698
|
|
|
622
|
-
|
|
699
|
+
clear_oauth_state_cookie(response)
|
|
623
700
|
|
|
624
|
-
|
|
625
|
-
try:
|
|
626
|
-
await data_layer.create_user(user)
|
|
627
|
-
except Exception as e:
|
|
628
|
-
logger.error(f"Error creating user: {e}")
|
|
701
|
+
return response
|
|
629
702
|
|
|
630
|
-
params = urllib.parse.urlencode(
|
|
631
|
-
{
|
|
632
|
-
"access_token": access_token,
|
|
633
|
-
"token_type": "bearer",
|
|
634
|
-
}
|
|
635
|
-
)
|
|
636
703
|
|
|
637
|
-
|
|
704
|
+
GenericUser = Union[User, PersistedUser, None]
|
|
705
|
+
UserParam = Annotated[GenericUser, Depends(get_current_user)]
|
|
638
706
|
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
)
|
|
644
|
-
response.delete_cookie("oauth_state")
|
|
645
|
-
return response
|
|
707
|
+
|
|
708
|
+
@router.get("/user")
|
|
709
|
+
async def get_user(current_user: UserParam) -> GenericUser:
|
|
710
|
+
return current_user
|
|
646
711
|
|
|
647
712
|
|
|
648
713
|
_language_pattern = (
|
|
@@ -670,7 +735,7 @@ async def project_translations(
|
|
|
670
735
|
|
|
671
736
|
@router.get("/project/settings")
|
|
672
737
|
async def project_settings(
|
|
673
|
-
current_user:
|
|
738
|
+
current_user: UserParam,
|
|
674
739
|
language: str = Query(
|
|
675
740
|
default="en-US", description="Language code", pattern=_language_pattern
|
|
676
741
|
),
|
|
@@ -721,7 +786,7 @@ async def project_settings(
|
|
|
721
786
|
async def update_feedback(
|
|
722
787
|
request: Request,
|
|
723
788
|
update: UpdateFeedbackRequest,
|
|
724
|
-
current_user:
|
|
789
|
+
current_user: UserParam,
|
|
725
790
|
):
|
|
726
791
|
"""Update the human feedback for a particular message."""
|
|
727
792
|
data_layer = get_data_layer()
|
|
@@ -731,7 +796,7 @@ async def update_feedback(
|
|
|
731
796
|
try:
|
|
732
797
|
feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
|
|
733
798
|
except Exception as e:
|
|
734
|
-
raise HTTPException(detail=str(e), status_code=500)
|
|
799
|
+
raise HTTPException(detail=str(e), status_code=500) from e
|
|
735
800
|
|
|
736
801
|
return JSONResponse(content={"success": True, "feedbackId": feedback_id})
|
|
737
802
|
|
|
@@ -740,7 +805,7 @@ async def update_feedback(
|
|
|
740
805
|
async def delete_feedback(
|
|
741
806
|
request: Request,
|
|
742
807
|
payload: DeleteFeedbackRequest,
|
|
743
|
-
current_user:
|
|
808
|
+
current_user: UserParam,
|
|
744
809
|
):
|
|
745
810
|
"""Delete a feedback."""
|
|
746
811
|
|
|
@@ -759,7 +824,7 @@ async def delete_feedback(
|
|
|
759
824
|
async def get_user_threads(
|
|
760
825
|
request: Request,
|
|
761
826
|
payload: GetThreadsRequest,
|
|
762
|
-
current_user:
|
|
827
|
+
current_user: UserParam,
|
|
763
828
|
):
|
|
764
829
|
"""Get the threads page by page."""
|
|
765
830
|
|
|
@@ -768,6 +833,9 @@ async def get_user_threads(
|
|
|
768
833
|
if not data_layer:
|
|
769
834
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
770
835
|
|
|
836
|
+
if not current_user:
|
|
837
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
838
|
+
|
|
771
839
|
if not isinstance(current_user, PersistedUser):
|
|
772
840
|
persisted_user = await data_layer.get_user(identifier=current_user.identifier)
|
|
773
841
|
if not persisted_user:
|
|
@@ -784,7 +852,7 @@ async def get_user_threads(
|
|
|
784
852
|
async def get_thread(
|
|
785
853
|
request: Request,
|
|
786
854
|
thread_id: str,
|
|
787
|
-
current_user:
|
|
855
|
+
current_user: UserParam,
|
|
788
856
|
):
|
|
789
857
|
"""Get a specific thread."""
|
|
790
858
|
data_layer = get_data_layer()
|
|
@@ -792,6 +860,9 @@ async def get_thread(
|
|
|
792
860
|
if not data_layer:
|
|
793
861
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
794
862
|
|
|
863
|
+
if not current_user:
|
|
864
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
865
|
+
|
|
795
866
|
await is_thread_author(current_user.identifier, thread_id)
|
|
796
867
|
|
|
797
868
|
res = await data_layer.get_thread(thread_id)
|
|
@@ -803,7 +874,7 @@ async def get_thread_element(
|
|
|
803
874
|
request: Request,
|
|
804
875
|
thread_id: str,
|
|
805
876
|
element_id: str,
|
|
806
|
-
current_user:
|
|
877
|
+
current_user: UserParam,
|
|
807
878
|
):
|
|
808
879
|
"""Get a specific thread element."""
|
|
809
880
|
data_layer = get_data_layer()
|
|
@@ -811,17 +882,135 @@ async def get_thread_element(
|
|
|
811
882
|
if not data_layer:
|
|
812
883
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
813
884
|
|
|
885
|
+
if not current_user:
|
|
886
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
887
|
+
|
|
814
888
|
await is_thread_author(current_user.identifier, thread_id)
|
|
815
889
|
|
|
816
890
|
res = await data_layer.get_element(thread_id, element_id)
|
|
817
891
|
return JSONResponse(content=res)
|
|
818
892
|
|
|
819
893
|
|
|
894
|
+
@router.put("/project/element")
|
|
895
|
+
async def update_thread_element(
|
|
896
|
+
payload: ElementRequest,
|
|
897
|
+
current_user: UserParam,
|
|
898
|
+
):
|
|
899
|
+
"""Update a specific thread element."""
|
|
900
|
+
|
|
901
|
+
from chainlit.context import init_ws_context
|
|
902
|
+
from chainlit.element import CustomElement, ElementDict
|
|
903
|
+
from chainlit.session import WebsocketSession
|
|
904
|
+
|
|
905
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
906
|
+
context = init_ws_context(session)
|
|
907
|
+
|
|
908
|
+
element_dict = cast(ElementDict, payload.element)
|
|
909
|
+
|
|
910
|
+
if element_dict["type"] != "custom":
|
|
911
|
+
return {"success": False}
|
|
912
|
+
|
|
913
|
+
element = CustomElement(
|
|
914
|
+
id=element_dict["id"],
|
|
915
|
+
object_key=element_dict["objectKey"],
|
|
916
|
+
chainlit_key=element_dict["chainlitKey"],
|
|
917
|
+
url=element_dict["url"],
|
|
918
|
+
for_id=element_dict.get("forId") or "",
|
|
919
|
+
thread_id=element_dict.get("threadId") or "",
|
|
920
|
+
name=element_dict["name"],
|
|
921
|
+
props=element_dict.get("props") or {},
|
|
922
|
+
display=element_dict["display"],
|
|
923
|
+
)
|
|
924
|
+
|
|
925
|
+
if current_user:
|
|
926
|
+
if (
|
|
927
|
+
not context.session.user
|
|
928
|
+
or context.session.user.identifier != current_user.identifier
|
|
929
|
+
):
|
|
930
|
+
raise HTTPException(
|
|
931
|
+
status_code=401,
|
|
932
|
+
detail="You are not authorized to update elements for this session",
|
|
933
|
+
)
|
|
934
|
+
|
|
935
|
+
await element.send(for_id=element.for_id or "")
|
|
936
|
+
return {"success": True}
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
@router.delete("/project/element")
|
|
940
|
+
async def delete_thread_element(
|
|
941
|
+
payload: ElementRequest,
|
|
942
|
+
current_user: UserParam,
|
|
943
|
+
):
|
|
944
|
+
"""Delete a specific thread element."""
|
|
945
|
+
|
|
946
|
+
from chainlit.context import init_ws_context
|
|
947
|
+
from chainlit.element import CustomElement, ElementDict
|
|
948
|
+
from chainlit.session import WebsocketSession
|
|
949
|
+
|
|
950
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
951
|
+
context = init_ws_context(session)
|
|
952
|
+
|
|
953
|
+
element_dict = cast(ElementDict, payload.element)
|
|
954
|
+
|
|
955
|
+
if element_dict["type"] != "custom":
|
|
956
|
+
return {"success": False}
|
|
957
|
+
|
|
958
|
+
element = CustomElement(
|
|
959
|
+
id=element_dict["id"],
|
|
960
|
+
object_key=element_dict["objectKey"],
|
|
961
|
+
chainlit_key=element_dict["chainlitKey"],
|
|
962
|
+
url=element_dict["url"],
|
|
963
|
+
for_id=element_dict.get("forId") or "",
|
|
964
|
+
thread_id=element_dict.get("threadId") or "",
|
|
965
|
+
name=element_dict["name"],
|
|
966
|
+
props=element_dict.get("props") or {},
|
|
967
|
+
display=element_dict["display"],
|
|
968
|
+
)
|
|
969
|
+
|
|
970
|
+
if current_user:
|
|
971
|
+
if (
|
|
972
|
+
not context.session.user
|
|
973
|
+
or context.session.user.identifier != current_user.identifier
|
|
974
|
+
):
|
|
975
|
+
raise HTTPException(
|
|
976
|
+
status_code=401,
|
|
977
|
+
detail="You are not authorized to remove elements for this session",
|
|
978
|
+
)
|
|
979
|
+
|
|
980
|
+
await element.remove()
|
|
981
|
+
|
|
982
|
+
return {"success": True}
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
@router.put("/project/thread")
|
|
986
|
+
async def rename_thread(
|
|
987
|
+
request: Request,
|
|
988
|
+
payload: UpdateThreadRequest,
|
|
989
|
+
current_user: UserParam,
|
|
990
|
+
):
|
|
991
|
+
"""Rename a thread."""
|
|
992
|
+
|
|
993
|
+
data_layer = get_data_layer()
|
|
994
|
+
|
|
995
|
+
if not data_layer:
|
|
996
|
+
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
997
|
+
|
|
998
|
+
if not current_user:
|
|
999
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1000
|
+
|
|
1001
|
+
thread_id = payload.threadId
|
|
1002
|
+
|
|
1003
|
+
await is_thread_author(current_user.identifier, thread_id)
|
|
1004
|
+
|
|
1005
|
+
await data_layer.update_thread(thread_id, name=payload.name)
|
|
1006
|
+
return JSONResponse(content={"success": True})
|
|
1007
|
+
|
|
1008
|
+
|
|
820
1009
|
@router.delete("/project/thread")
|
|
821
1010
|
async def delete_thread(
|
|
822
1011
|
request: Request,
|
|
823
1012
|
payload: DeleteThreadRequest,
|
|
824
|
-
current_user:
|
|
1013
|
+
current_user: UserParam,
|
|
825
1014
|
):
|
|
826
1015
|
"""Delete a thread."""
|
|
827
1016
|
|
|
@@ -830,6 +1019,9 @@ async def delete_thread(
|
|
|
830
1019
|
if not data_layer:
|
|
831
1020
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
832
1021
|
|
|
1022
|
+
if not current_user:
|
|
1023
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1024
|
+
|
|
833
1025
|
thread_id = payload.threadId
|
|
834
1026
|
|
|
835
1027
|
await is_thread_author(current_user.identifier, thread_id)
|
|
@@ -838,9 +1030,47 @@ async def delete_thread(
|
|
|
838
1030
|
return JSONResponse(content={"success": True})
|
|
839
1031
|
|
|
840
1032
|
|
|
1033
|
+
@router.post("/project/action")
async def call_action(
    payload: CallActionRequest,
    current_user: UserParam,
):
    """Run the registered callback for an action inside a websocket session.

    The session is looked up by ``payload.sessionId``. When the request is
    authenticated, the session must belong to the same user. The action is
    then built from ``payload.action`` and dispatched to the callback
    registered under its name in ``config.code.action_callbacks``.

    Raises:
        HTTPException: 404 if the session does not exist or no callback is
            registered for the action; 401 if the authenticated user does not
            own the session.
    """

    from chainlit.action import Action
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    if not session:
        # Fail with a clear client error instead of passing a missing session
        # into init_ws_context.
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    # Authorize before touching the user-supplied action payload.
    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to call actions for this session",
            )

    action = Action(**payload.action)

    callback = config.code.action_callbacks.get(action.name)
    if not callback:
        raise HTTPException(
            status_code=404,
            detail=f"No callback found for action {action.name}",
        )

    await callback(action)
    return JSONResponse(content={"success": True})
|
|
1069
|
+
|
|
1070
|
+
|
|
841
1071
|
@router.post("/project/file")
|
|
842
1072
|
async def upload_file(
|
|
843
|
-
current_user:
|
|
1073
|
+
current_user: UserParam,
|
|
844
1074
|
session_id: str,
|
|
845
1075
|
file: UploadFile,
|
|
846
1076
|
):
|
|
@@ -870,6 +1100,11 @@ async def upload_file(
|
|
|
870
1100
|
assert file.filename, "No filename for uploaded file"
|
|
871
1101
|
assert file.content_type, "No content type for uploaded file"
|
|
872
1102
|
|
|
1103
|
+
try:
|
|
1104
|
+
validate_file_upload(file)
|
|
1105
|
+
except ValueError as e:
|
|
1106
|
+
raise HTTPException(status_code=400, detail=str(e))
|
|
1107
|
+
|
|
873
1108
|
file_response = await session.persist_file(
|
|
874
1109
|
name=file.filename, content=content, mime=file.content_type
|
|
875
1110
|
)
|
|
@@ -877,14 +1112,79 @@ async def upload_file(
|
|
|
877
1112
|
return JSONResponse(content=file_response)
|
|
878
1113
|
|
|
879
1114
|
|
|
1115
|
+
def validate_file_upload(file: UploadFile):
    """Validate the file upload as configured in config.features.spontaneous_file_upload.

    If the ``spontaneous_file_upload`` section is missing, uploads are allowed
    without any restriction. Otherwise the feature must not be explicitly
    disabled (``enabled is False``), and the file must pass the MIME-type and
    size checks.

    Args:
        file (UploadFile): The file to validate.
    Raises:
        ValueError: If uploads are disabled, or the file's type or size is not
            allowed.
    """
    upload_config = config.features.spontaneous_file_upload
    if upload_config is None:
        # Default for a missing config is to allow the file upload without
        # any restrictions.
        return
    if upload_config.enabled is False:
        raise ValueError("File upload is not enabled")

    validate_file_mime_type(file)
    validate_file_size(file)
|
|
1130
|
+
|
|
1131
|
+
|
|
1132
|
+
def validate_file_mime_type(file: UploadFile):
    """Validate the file mime type as configured in config.features.spontaneous_file_upload.

    ``accept`` may be:

    * ``None``: every file type is allowed.
    * a list of MIME glob patterns (e.g. ``["image/*", "text/plain"]``): the
      upload's content type must match at least one pattern.
    * a dict mapping MIME glob patterns to filename extensions
      (e.g. ``{"image/*": [".png", ".jpg"]}``): the content type must match a
      pattern, and if that pattern lists extensions, the filename must end
      with one of them. An empty extension list allows any filename.

    Matching is case-insensitive and independent of the host OS.

    Args:
        file (UploadFile): The file to validate.
    Raises:
        ValueError: If the file type is not allowed.
        TypeError: If ``accept`` is configured with something other than a
            list or a dict.
    """
    accept = config.features.spontaneous_file_upload.accept
    if accept is None:
        # Accept is not configured, allowing all file types.
        return

    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if not isinstance(accept, (list, dict)):
        raise TypeError(
            "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
        )

    if file.content_type is None:
        # No declared type: it cannot match any configured pattern.
        raise ValueError("File type not allowed")

    # MIME types are case-insensitive (RFC 2045). fnmatchcase on lowercased
    # values avoids fnmatch's OS-dependent normcase behaviour.
    content_type = file.content_type.lower()
    filename = file.filename.lower() if file.filename is not None else None

    if isinstance(accept, list):
        if any(
            fnmatch.fnmatchcase(content_type, pattern.lower()) for pattern in accept
        ):
            return
    else:
        for pattern, extensions in accept.items():
            if not fnmatch.fnmatchcase(content_type, pattern.lower()):
                continue
            if not extensions:
                return
            # Case-insensitive, so e.g. "photo.JPG" matches ".jpg".
            if filename is not None and filename.endswith(
                tuple(extension.lower() for extension in extensions)
            ):
                return

    raise ValueError("File type not allowed")
|
|
1161
|
+
|
|
1162
|
+
|
|
1163
|
+
def validate_file_size(file: UploadFile):
    """Reject uploads larger than the configured size limit.

    The limit comes from ``config.features.spontaneous_file_upload.max_size_mb``
    and is converted to bytes using 1 MB = 1024 * 1024 bytes. The file is
    accepted when no limit is configured or when its ``size`` is unknown
    (``None``).

    Args:
        file (UploadFile): The file whose size is checked.
    Raises:
        ValueError: If the file size is too large.
    """
    max_size_mb = config.features.spontaneous_file_upload.max_size_mb
    if max_size_mb is None:
        return

    size = file.size
    if size is None:
        # Size not reported: nothing to compare against, so accept.
        return

    if size > max_size_mb * 1024 * 1024:
        raise ValueError("File size too large")
|
|
1179
|
+
|
|
1180
|
+
|
|
880
1181
|
@router.get("/project/file/{file_id}")
|
|
881
1182
|
async def get_file(
|
|
882
1183
|
file_id: str,
|
|
883
1184
|
session_id: str,
|
|
884
|
-
|
|
1185
|
+
current_user: UserParam,
|
|
885
1186
|
):
|
|
886
1187
|
"""Get a file from the session files directory."""
|
|
887
|
-
|
|
888
1188
|
from chainlit.session import WebsocketSession
|
|
889
1189
|
|
|
890
1190
|
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
@@ -895,13 +1195,12 @@ async def get_file(
|
|
|
895
1195
|
detail="Unauthorized",
|
|
896
1196
|
)
|
|
897
1197
|
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
# )
|
|
1198
|
+
if current_user:
|
|
1199
|
+
if not session.user or session.user.identifier != current_user.identifier:
|
|
1200
|
+
raise HTTPException(
|
|
1201
|
+
status_code=401,
|
|
1202
|
+
detail="You are not authorized to download files from this session",
|
|
1203
|
+
)
|
|
905
1204
|
|
|
906
1205
|
if file_id in session.files:
|
|
907
1206
|
file = session.files[file_id]
|
|
@@ -913,7 +1212,7 @@ async def get_file(
|
|
|
913
1212
|
@router.get("/files/{filename:path}")
|
|
914
1213
|
async def serve_file(
|
|
915
1214
|
filename: str,
|
|
916
|
-
current_user:
|
|
1215
|
+
current_user: UserParam,
|
|
917
1216
|
):
|
|
918
1217
|
"""Serve a file from the local filesystem."""
|
|
919
1218
|
|