chainlit 2.0rc0__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; consult the package's registry page for more details.

Files changed (66)
  1. chainlit/__init__.py +5 -0
  2. chainlit/action.py +4 -2
  3. chainlit/{auth.py → auth/__init__.py} +20 -34
  4. chainlit/auth/cookie.py +124 -0
  5. chainlit/auth/jwt.py +37 -0
  6. chainlit/callbacks.py +28 -0
  7. chainlit/chat_context.py +2 -2
  8. chainlit/chat_settings.py +3 -1
  9. chainlit/cli/__init__.py +14 -1
  10. chainlit/config.py +18 -5
  11. chainlit/context.py +3 -2
  12. chainlit/copilot/dist/index.js +220 -220
  13. chainlit/data/__init__.py +29 -17
  14. chainlit/data/acl.py +3 -2
  15. chainlit/data/base.py +1 -1
  16. chainlit/data/dynamodb.py +5 -3
  17. chainlit/data/literalai.py +3 -5
  18. chainlit/data/sql_alchemy.py +6 -5
  19. chainlit/data/storage_clients/azure.py +1 -0
  20. chainlit/data/storage_clients/s3.py +1 -0
  21. chainlit/discord/app.py +2 -1
  22. chainlit/element.py +6 -5
  23. chainlit/emitter.py +19 -10
  24. chainlit/frontend/dist/assets/{DailyMotion-CleI-8Dh.js → DailyMotion-C-_sjrtO.js} +1 -1
  25. chainlit/frontend/dist/assets/{Facebook-C4PuTowX.js → Facebook-bB34P03l.js} +1 -1
  26. chainlit/frontend/dist/assets/{FilePlayer-D49YToZz.js → FilePlayer-BWgqGrXv.js} +1 -1
  27. chainlit/frontend/dist/assets/{Kaltura-BkZcQEIs.js → Kaltura-OY4P9Ofd.js} +1 -1
  28. chainlit/frontend/dist/assets/{Mixcloud-DzvBFYsm.js → Mixcloud-9CtT8w5Y.js} +1 -1
  29. chainlit/frontend/dist/assets/{Mux-UXPyWWYv.js → Mux-BH9A0qEi.js} +1 -1
  30. chainlit/frontend/dist/assets/{Preview-0YXzpiVm.js → Preview-Og00EJ05.js} +1 -1
  31. chainlit/frontend/dist/assets/{SoundCloud-CS54COex.js → SoundCloud-D7resGfn.js} +1 -1
  32. chainlit/frontend/dist/assets/{Streamable-DYYShO6Q.js → Streamable-6f_6bYz1.js} +1 -1
  33. chainlit/frontend/dist/assets/{Twitch-DG7403Hm.js → Twitch-BZJl3peM.js} +1 -1
  34. chainlit/frontend/dist/assets/{Vidyard-C5JbOHIQ.js → Vidyard-B7tv4b8_.js} +1 -1
  35. chainlit/frontend/dist/assets/{Vimeo-dFLZbhqH.js → Vimeo-F-eA4zQI.js} +1 -1
  36. chainlit/frontend/dist/assets/{Wistia-143Q9V9c.js → Wistia-Dhxhn3IB.js} +1 -1
  37. chainlit/frontend/dist/assets/{YouTube-Dct4gpfH.js → YouTube-aFdJGjI1.js} +1 -1
  38. chainlit/frontend/dist/assets/{index-2yAiK0R5.js → index-Ba33_hdJ.js} +122 -122
  39. chainlit/frontend/dist/assets/{react-plotly-CFHBSMgg.js → react-plotly-DoUJXMgz.js} +1 -1
  40. chainlit/frontend/dist/index.html +1 -1
  41. chainlit/haystack/callbacks.py +5 -4
  42. chainlit/input_widget.py +6 -4
  43. chainlit/langchain/callbacks.py +56 -47
  44. chainlit/langflow/__init__.py +1 -0
  45. chainlit/llama_index/callbacks.py +7 -7
  46. chainlit/message.py +6 -5
  47. chainlit/mistralai/__init__.py +3 -2
  48. chainlit/oauth_providers.py +70 -3
  49. chainlit/openai/__init__.py +3 -2
  50. chainlit/secret.py +1 -1
  51. chainlit/server.py +232 -156
  52. chainlit/session.py +7 -5
  53. chainlit/slack/app.py +3 -2
  54. chainlit/socket.py +88 -63
  55. chainlit/step.py +11 -10
  56. chainlit/sync.py +2 -1
  57. chainlit/teams/app.py +1 -0
  58. chainlit/translations/nl-NL.json +229 -0
  59. chainlit/types.py +3 -1
  60. chainlit/user.py +2 -1
  61. chainlit/utils.py +3 -2
  62. {chainlit-2.0rc0.dist-info → chainlit-2.0rc1.dist-info}/METADATA +1 -1
  63. chainlit-2.0rc1.dist-info/RECORD +102 -0
  64. chainlit-2.0rc0.dist-info/RECORD +0 -99
  65. {chainlit-2.0rc0.dist-info → chainlit-2.0rc1.dist-info}/WHEEL +0 -0
  66. {chainlit-2.0rc0.dist-info → chainlit-2.0rc1.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import fnmatch
2
3
  import glob
3
4
  import json
4
5
  import mimetypes
@@ -9,10 +10,37 @@ import urllib.parse
9
10
  import webbrowser
10
11
  from contextlib import asynccontextmanager
11
12
  from pathlib import Path
12
- from typing import Any, Optional, Union
13
+ from typing import List, Optional, Union
13
14
 
14
15
  import socketio
16
+ from fastapi import (
17
+ APIRouter,
18
+ Depends,
19
+ FastAPI,
20
+ Form,
21
+ HTTPException,
22
+ Query,
23
+ Request,
24
+ Response,
25
+ UploadFile,
26
+ status,
27
+ )
28
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
29
+ from fastapi.security import OAuth2PasswordRequestForm
30
+ from fastapi.staticfiles import StaticFiles
31
+ from starlette.datastructures import URL
32
+ from starlette.middleware.cors import CORSMiddleware
33
+ from typing_extensions import Annotated
34
+ from watchfiles import awatch
35
+
15
36
  from chainlit.auth import create_jwt, get_configuration, get_current_user
37
+ from chainlit.auth.cookie import (
38
+ clear_auth_cookie,
39
+ clear_oauth_state_cookie,
40
+ set_auth_cookie,
41
+ set_oauth_state_cookie,
42
+ validate_oauth_state_cookie,
43
+ )
16
44
  from chainlit.config import (
17
45
  APP_ROOT,
18
46
  BACKEND_ROOT,
@@ -37,26 +65,6 @@ from chainlit.types import (
37
65
  UpdateFeedbackRequest,
38
66
  )
39
67
  from chainlit.user import PersistedUser, User
40
- from fastapi import (
41
- APIRouter,
42
- Depends,
43
- FastAPI,
44
- File,
45
- Form,
46
- HTTPException,
47
- Query,
48
- Request,
49
- Response,
50
- UploadFile,
51
- status,
52
- )
53
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
54
- from fastapi.security import OAuth2PasswordRequestForm
55
- from fastapi.staticfiles import StaticFiles
56
- from starlette.datastructures import URL
57
- from starlette.middleware.cors import CORSMiddleware
58
- from typing_extensions import Annotated
59
- from watchfiles import awatch
60
68
 
61
69
  from ._utils import is_path_inside
62
70
 
@@ -248,6 +256,7 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
248
256
 
249
257
  if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
250
258
  from botbuilder.schema import Activity
259
+
251
260
  from chainlit.teams.app import adapter, bot
252
261
 
253
262
  @router.post("/teams/events")
@@ -299,7 +308,10 @@ def get_html_template():
299
308
  <meta property="og:url" content="{url}">
300
309
  <meta property="og:root_path" content="{ROOT_PATH}">"""
301
310
 
302
- js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
311
+ js = f"""<script>
312
+ {f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}
313
+ {f"window.transports = {json.dumps(config.project.transports)}; " if config.project.transports else "undefined"}
314
+ </script>"""
303
315
 
304
316
  css = None
305
317
  if config.ui.custom_css:
@@ -316,7 +328,7 @@ def get_html_template():
316
328
 
317
329
  index_html_file_path = os.path.join(build_dir, "index.html")
318
330
 
319
- with open(index_html_file_path, "r", encoding="utf-8") as f:
331
+ with open(index_html_file_path, encoding="utf-8") as f:
320
332
  content = f.read()
321
333
  content = content.replace(PLACEHOLDER, tags)
322
334
  if js:
@@ -361,43 +373,109 @@ async def auth(request: Request):
361
373
  return get_configuration()
362
374
 
363
375
 
364
- @router.post("/login")
365
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
366
- """
367
- Login a user using the password auth callback.
368
- """
369
- if not config.code.password_auth_callback:
370
- raise HTTPException(
371
- status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
376
def _get_response_dict(access_token: str) -> dict:
    """Build the JSON body returned after a successful authentication.

    When ``config.project.cookie_auth`` is enabled the body only signals
    success. Otherwise (legacy mode) the bearer token itself is returned so
    the client can store it.
    """
    if config.project.cookie_auth:
        return {"success": True}

    # Legacy bearer-token auth: hand the token back in the response body.
    return {"access_token": access_token, "token_type": "bearer"}
387
+
388
+
389
+ def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
390
+ """Get the redirect params for the OAuth callback."""
391
+
392
+ response_dict = _get_response_dict(access_token)
393
+
394
+ if redirect_to_callback:
395
+ root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
396
+ redirect_url = (
397
+ f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
372
398
  )
373
399
 
374
- user = await config.code.password_auth_callback(
375
- form_data.username, form_data.password
400
+ return RedirectResponse(
401
+ # FIXME: redirect to the right frontend base url to improve the dev environment
402
+ url=redirect_url,
403
+ status_code=302,
404
+ )
405
+
406
+ return JSONResponse(response_dict)
407
+
408
+
409
def _get_oauth_redirect_error(error: str) -> Response:
    """Redirect the browser back to the login page carrying an OAuth error.

    Args:
        error: Error value reported by the OAuth provider. It is URL-encoded
            into the ``error`` query parameter.

    Returns:
        A redirect to ``{CHAINLIT_ROOT_PATH}/login?error=...``.
    """
    params = urllib.parse.urlencode({"error": error})
    # Honour CHAINLIT_ROOT_PATH exactly like the success redirect to
    # /login/callback does; without the prefix, apps mounted under a
    # sub-path sent users to a login page outside the app.
    root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
    return RedirectResponse(
        # FIXME: redirect to the right frontend base url to improve the dev environment
        url=f"{root_path}/login?{params}",
    )
421
+
422
+
423
+ async def _authenticate_user(
424
+ user: Optional[User], redirect_to_callback: bool = False
425
+ ) -> Response:
426
+ """Authenticate a user and return the response."""
377
427
 
378
428
  if not user:
379
429
  raise HTTPException(
380
430
  status_code=status.HTTP_401_UNAUTHORIZED,
381
431
  detail="credentialssignin",
382
432
  )
383
- access_token = create_jwt(user)
433
+
434
+ # If a data layer is defined, attempt to persist user.
384
435
  if data_layer := get_data_layer():
385
436
  try:
386
437
  await data_layer.create_user(user)
387
438
  except Exception as e:
439
+ # Catch and log exceptions during user creation.
440
+ # TODO: Make this catch only specific errors and allow others to propagate.
388
441
  logger.error(f"Error creating user: {e}")
389
442
 
390
- return {
391
- "access_token": access_token,
392
- "token_type": "bearer",
393
- }
443
+ access_token = create_jwt(user)
444
+
445
+ response = _get_auth_response(access_token, redirect_to_callback)
446
+
447
+ if config.project.cookie_auth:
448
+ set_auth_cookie(response, access_token)
449
+
450
+ return response
451
+
452
+
453
@router.post("/login")
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
    """
    Log a user in through the configured password auth callback.

    Responds with HTTP 400 when no password auth callback is configured.
    Otherwise the submitted username/password are passed to the callback and
    its result is handed to ``_authenticate_user`` (which rejects a falsy
    result with HTTP 401). ``response`` is injected by FastAPI but unused here.
    """
    password_callback = config.code.password_auth_callback
    if not password_callback:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
        )

    authenticated = await password_callback(form_data.username, form_data.password)
    return await _authenticate_user(authenticated)
394
468
 
395
469
 
396
470
  @router.post("/logout")
397
471
  async def logout(request: Request, response: Response):
398
472
  """Logout the user by calling the on_logout callback."""
473
+ if config.project.cookie_auth:
474
+ clear_auth_cookie(response)
475
+
399
476
  if config.code.on_logout:
400
477
  return await config.code.on_logout(request, response)
478
+
401
479
  return {"success": True}
402
480
 
403
481
 
@@ -412,23 +490,7 @@ async def header_auth(request: Request):
412
490
 
413
491
  user = await config.code.header_auth_callback(request.headers)
414
492
 
415
- if not user:
416
- raise HTTPException(
417
- status_code=status.HTTP_401_UNAUTHORIZED,
418
- detail="Unauthorized",
419
- )
420
-
421
- access_token = create_jwt(user)
422
- if data_layer := get_data_layer():
423
- try:
424
- await data_layer.create_user(user)
425
- except Exception as e:
426
- logger.error(f"Error creating user: {e}")
427
-
428
- return {
429
- "access_token": access_token,
430
- "token_type": "bearer",
431
- }
493
+ return await _authenticate_user(user)
432
494
 
433
495
 
434
496
  @router.get("/auth/oauth/{provider_id}")
@@ -460,16 +522,9 @@ async def oauth_login(provider_id: str, request: Request):
460
522
  response = RedirectResponse(
461
523
  url=f"{provider.authorize_url}?{params}",
462
524
  )
463
- samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
464
- secure = samesite.lower() == "none"
465
- response.set_cookie(
466
- "oauth_state",
467
- random,
468
- httponly=True,
469
- samesite=samesite,
470
- secure=secure,
471
- max_age=3 * 60,
472
- )
525
+
526
+ set_oauth_state_cookie(response, random)
527
+
473
528
  return response
474
529
 
475
530
 
@@ -497,16 +552,7 @@ async def oauth_callback(
497
552
  )
498
553
 
499
554
  if error:
500
- params = urllib.parse.urlencode(
501
- {
502
- "error": error,
503
- }
504
- )
505
- response = RedirectResponse(
506
- # FIXME: redirect to the right frontend base url to improve the dev environment
507
- url=f"/login?{params}",
508
- )
509
- return response
555
+ return _get_oauth_redirect_error(error)
510
556
 
511
557
  if not code or not state:
512
558
  raise HTTPException(
@@ -514,9 +560,11 @@ async def oauth_callback(
514
560
  detail="Missing code or state",
515
561
  )
516
562
 
517
- # Check the state from the oauth provider against the browser cookie
518
- oauth_state = request.cookies.get("oauth_state")
519
- if oauth_state != state:
563
+ try:
564
+ validate_oauth_state_cookie(request, state)
565
+ except Exception as e:
566
+ logger.exception("Unable to validate oauth state: %1", e)
567
+
520
568
  raise HTTPException(
521
569
  status_code=status.HTTP_401_UNAUTHORIZED,
522
570
  detail="Unauthorized",
@@ -531,34 +579,10 @@ async def oauth_callback(
531
579
  provider_id, token, raw_user_data, default_user
532
580
  )
533
581
 
534
- if not user:
535
- raise HTTPException(
536
- status_code=status.HTTP_401_UNAUTHORIZED,
537
- detail="Unauthorized",
538
- )
539
-
540
- access_token = create_jwt(user)
541
-
542
- if data_layer := get_data_layer():
543
- try:
544
- await data_layer.create_user(user)
545
- except Exception as e:
546
- logger.error(f"Error creating user: {e}")
547
-
548
- params = urllib.parse.urlencode(
549
- {
550
- "access_token": access_token,
551
- "token_type": "bearer",
552
- }
553
- )
582
+ response = await _authenticate_user(user, redirect_to_callback=True)
554
583
 
555
- root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
584
+ clear_oauth_state_cookie(response)
556
585
 
557
- response = RedirectResponse(
558
- # FIXME: redirect to the right frontend base url to improve the dev environment
559
- url=f"{root_path}/login/callback?{params}",
560
- )
561
- response.delete_cookie("oauth_state")
562
586
  return response
563
587
 
564
588
 
@@ -587,16 +611,7 @@ async def oauth_azure_hf_callback(
587
611
  )
588
612
 
589
613
  if error:
590
- params = urllib.parse.urlencode(
591
- {
592
- "error": error,
593
- }
594
- )
595
- response = RedirectResponse(
596
- # FIXME: redirect to the right frontend base url to improve the dev environment
597
- url=f"/login?{params}",
598
- )
599
- return response
614
+ return _get_oauth_redirect_error(error)
600
615
 
601
616
  if not code:
602
617
  raise HTTPException(
@@ -613,36 +628,20 @@ async def oauth_azure_hf_callback(
613
628
  provider_id, token, raw_user_data, default_user, id_token
614
629
  )
615
630
 
616
- if not user:
617
- raise HTTPException(
618
- status_code=status.HTTP_401_UNAUTHORIZED,
619
- detail="Unauthorized",
620
- )
631
+ response = await _authenticate_user(user, redirect_to_callback=True)
621
632
 
622
- access_token = create_jwt(user)
633
+ clear_oauth_state_cookie(response)
623
634
 
624
- if data_layer := get_data_layer():
625
- try:
626
- await data_layer.create_user(user)
627
- except Exception as e:
628
- logger.error(f"Error creating user: {e}")
635
+ return response
629
636
 
630
- params = urllib.parse.urlencode(
631
- {
632
- "access_token": access_token,
633
- "token_type": "bearer",
634
- }
635
- )
636
637
 
637
- root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
638
# Either user model an endpoint may receive for the current user.
GenericUser = Union[User, PersistedUser]
# Shared parameter annotation: FastAPI resolves the value through the
# ``get_current_user`` dependency.
UserParam = Annotated[GenericUser, Depends(get_current_user)]
638
640
 
639
- response = RedirectResponse(
640
- # FIXME: redirect to the right frontend base url to improve the dev environment
641
- url=f"{root_path}/login/callback?{params}",
642
- status_code=302,
643
- )
644
- response.delete_cookie("oauth_state")
645
- return response
641
+
642
@router.get("/user")
async def get_user(current_user: UserParam) -> GenericUser:
    """Return the user resolved by the ``get_current_user`` dependency."""
    resolved_user = current_user
    return resolved_user
646
645
 
647
646
 
648
647
  _language_pattern = (
@@ -670,7 +669,7 @@ async def project_translations(
670
669
 
671
670
  @router.get("/project/settings")
672
671
  async def project_settings(
673
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
672
+ current_user: UserParam,
674
673
  language: str = Query(
675
674
  default="en-US", description="Language code", pattern=_language_pattern
676
675
  ),
@@ -721,7 +720,7 @@ async def project_settings(
721
720
  async def update_feedback(
722
721
  request: Request,
723
722
  update: UpdateFeedbackRequest,
724
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
723
+ current_user: UserParam,
725
724
  ):
726
725
  """Update the human feedback for a particular message."""
727
726
  data_layer = get_data_layer()
@@ -740,7 +739,7 @@ async def update_feedback(
740
739
  async def delete_feedback(
741
740
  request: Request,
742
741
  payload: DeleteFeedbackRequest,
743
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
742
+ current_user: UserParam,
744
743
  ):
745
744
  """Delete a feedback."""
746
745
 
@@ -759,7 +758,7 @@ async def delete_feedback(
759
758
  async def get_user_threads(
760
759
  request: Request,
761
760
  payload: GetThreadsRequest,
762
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
761
+ current_user: UserParam,
763
762
  ):
764
763
  """Get the threads page by page."""
765
764
 
@@ -784,7 +783,7 @@ async def get_user_threads(
784
783
  async def get_thread(
785
784
  request: Request,
786
785
  thread_id: str,
787
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
786
+ current_user: UserParam,
788
787
  ):
789
788
  """Get a specific thread."""
790
789
  data_layer = get_data_layer()
@@ -803,7 +802,7 @@ async def get_thread_element(
803
802
  request: Request,
804
803
  thread_id: str,
805
804
  element_id: str,
806
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
805
+ current_user: UserParam,
807
806
  ):
808
807
  """Get a specific thread element."""
809
808
  data_layer = get_data_layer()
@@ -821,7 +820,7 @@ async def get_thread_element(
821
820
  async def delete_thread(
822
821
  request: Request,
823
822
  payload: DeleteThreadRequest,
824
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
823
+ current_user: UserParam,
825
824
  ):
826
825
  """Delete a thread."""
827
826
 
@@ -840,7 +839,7 @@ async def delete_thread(
840
839
 
841
840
  @router.post("/project/file")
842
841
  async def upload_file(
843
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
842
+ current_user: UserParam,
844
843
  session_id: str,
845
844
  file: UploadFile,
846
845
  ):
@@ -870,6 +869,11 @@ async def upload_file(
870
869
  assert file.filename, "No filename for uploaded file"
871
870
  assert file.content_type, "No content type for uploaded file"
872
871
 
872
+ try:
873
+ validate_file_upload(file)
874
+ except ValueError as e:
875
+ raise HTTPException(status_code=400, detail=str(e))
876
+
873
877
  file_response = await session.persist_file(
874
878
  name=file.filename, content=content, mime=file.content_type
875
879
  )
@@ -877,14 +881,87 @@ async def upload_file(
877
881
  return JSONResponse(content=file_response)
878
882
 
879
883
 
884
def validate_file_upload(file: UploadFile):
    """Check an uploaded file against ``config.features.spontaneous_file_upload``.

    A missing ``spontaneous_file_upload`` section means uploads are allowed
    without restriction. When the section is present, uploads must not be
    explicitly disabled, and the file must pass the MIME type and size checks.

    Args:
        file (UploadFile): The file to validate.

    Raises:
        ValueError: If the file is not allowed.
    """
    upload_settings = config.features.spontaneous_file_upload
    if upload_settings is not None:
        # Only an explicit ``enabled = False`` disables uploads; None does not.
        if upload_settings.enabled is False:
            raise ValueError("File upload is not enabled")
        validate_file_mime_type(file)
        validate_file_size(file)
899
+
900
+
901
def validate_file_mime_type(file: UploadFile):
    """Validate the file MIME type as configured in config.features.spontaneous_file_upload.

    ``accept`` may be:

    * ``None``: every file type is allowed;
    * a list of fnmatch-style MIME patterns (e.g. ``"image/*"``);
    * a dict mapping MIME patterns to allowed filename extensions, where an
      empty extension list allows any filename for that pattern.

    MIME types are case-insensitive (RFC 2045 section 5.1), and a file named
    ``photo.PNG`` should not be rejected by a ``.png`` rule. Both comparisons
    therefore ignore case. ``fnmatchcase`` on lower-cased values is used
    because ``fnmatch.fnmatch`` folds case only on Windows, which made the
    result OS-dependent.

    Args:
        file (UploadFile): The file to validate.

    Raises:
        ValueError: If the file type is not allowed.
        TypeError: If ``accept`` is configured with an unsupported type.
    """
    accept = config.features.spontaneous_file_upload.accept
    if accept is None:
        # Accept is not configured: allow all file types.
        return

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not isinstance(accept, (list, dict)):
        raise TypeError(
            "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
        )

    # A missing content type is matched as "" instead of crashing fnmatch.
    mime_type = (file.content_type or "").lower()
    filename = file.filename.lower() if file.filename is not None else None

    def _mime_matches(pattern: str) -> bool:
        """Case-insensitive, OS-independent glob match of the MIME type."""
        return fnmatch.fnmatchcase(mime_type, pattern.lower())

    if isinstance(accept, list):
        if any(_mime_matches(pattern) for pattern in accept):
            return
    else:
        for pattern, extensions in accept.items():
            if not _mime_matches(pattern):
                continue
            if not extensions:
                # No extension restriction for this MIME pattern.
                return
            if filename is not None and any(
                filename.endswith(extension.lower()) for extension in extensions
            ):
                return

    raise ValueError("File type not allowed")
930
+
931
+
932
def validate_file_size(file: UploadFile):
    """Enforce ``max_size_mb`` from config.features.spontaneous_file_upload.

    No limit is applied when ``max_size_mb`` is None, and files whose size is
    unknown (``file.size`` is None) are not checked.

    Args:
        file (UploadFile): The file to validate.

    Raises:
        ValueError: If the file size is too large.
    """
    limit_mb = config.features.spontaneous_file_upload.max_size_mb
    if limit_mb is None or file.size is None:
        return

    if file.size > limit_mb * 1024 * 1024:
        raise ValueError("File size too large")
948
+
949
+
880
950
  @router.get("/project/file/{file_id}")
881
951
  async def get_file(
882
952
  file_id: str,
883
953
  session_id: str,
884
- # current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], #TODO: Causes 401 error. See https://github.com/Chainlit/chainlit/issues/1472
954
+ current_user: UserParam,
885
955
  ):
886
956
  """Get a file from the session files directory."""
887
957
 
958
+ if not config.project.cookie_auth:
959
+ # We cannot make this work safely without cookie auth, so disable it.
960
+ raise HTTPException(
961
+ status_code=404,
962
+ detail="File downloads unavailable.",
963
+ )
964
+
888
965
  from chainlit.session import WebsocketSession
889
966
 
890
967
  session = WebsocketSession.get_by_id(session_id) if session_id else None
@@ -895,13 +972,12 @@ async def get_file(
895
972
  detail="Unauthorized",
896
973
  )
897
974
 
898
- #TODO: Causes 401 error. See https://github.com/Chainlit/chainlit/issues/1472
899
- # if current_user:
900
- # if not session.user or session.user.identifier != current_user.identifier:
901
- # raise HTTPException(
902
- # status_code=401,
903
- # detail="You are not authorized to download files from this session",
904
- # )
975
+ if current_user:
976
+ if not session.user or session.user.identifier != current_user.identifier:
977
+ raise HTTPException(
978
+ status_code=401,
979
+ detail="You are not authorized to download files from this session",
980
+ )
905
981
 
906
982
  if file_id in session.files:
907
983
  file = session.files[file_id]
@@ -913,7 +989,7 @@ async def get_file(
913
989
  @router.get("/files/{filename:path}")
914
990
  async def serve_file(
915
991
  filename: str,
916
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
992
+ current_user: UserParam,
917
993
  ):
918
994
  """Serve a file from the local filesystem."""
919
995
 
chainlit/session.py CHANGED
@@ -6,6 +6,7 @@ import uuid
6
6
  from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Literal, Optional, Union
7
7
 
8
8
  import aiofiles
9
+
9
10
  from chainlit.logger import logger
10
11
  from chainlit.types import FileReference
11
12
 
@@ -17,9 +18,9 @@ ClientType = Literal["webapp", "copilot", "teams", "slack", "discord"]
17
18
 
18
19
 
19
20
  class JSONEncoderIgnoreNonSerializable(json.JSONEncoder):
20
- def default(self, obj):
21
+ def default(self, o):
21
22
  try:
22
- return super(JSONEncoderIgnoreNonSerializable, self).default(obj)
23
+ return super().default(o)
23
24
  except TypeError:
24
25
  return None
25
26
 
@@ -112,9 +113,10 @@ class BaseSession:
112
113
 
113
114
  if path:
114
115
  # Copy the file from the given path
115
- async with aiofiles.open(path, "rb") as src, aiofiles.open(
116
- file_path, "wb"
117
- ) as dst:
116
+ async with (
117
+ aiofiles.open(path, "rb") as src,
118
+ aiofiles.open(file_path, "wb") as dst,
119
+ ):
118
120
  await dst.write(await src.read())
119
121
  elif content:
120
122
  # Write the provided content to the file
chainlit/slack/app.py CHANGED
@@ -7,6 +7,9 @@ from functools import partial
7
7
  from typing import Dict, List, Optional, Union
8
8
 
9
9
  import httpx
10
+ from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
11
+ from slack_bolt.async_app import AsyncApp
12
+
10
13
  from chainlit.config import config
11
14
  from chainlit.context import ChainlitContext, HTTPSession, context, context_var
12
15
  from chainlit.data import get_data_layer
@@ -18,8 +21,6 @@ from chainlit.telemetry import trace
18
21
  from chainlit.types import Feedback
19
22
  from chainlit.user import PersistedUser, User
20
23
  from chainlit.user_session import user_session
21
- from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
22
- from slack_bolt.async_app import AsyncApp
23
24
 
24
25
 
25
26
  class SlackEmitter(BaseChainlitEmitter):