chainlit 2.0rc1__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. (The original page links to more details here; the link target was not preserved in this capture.)

Files changed (52):
  1. chainlit/__init__.py +47 -57
  2. chainlit/action.py +8 -8
  3. chainlit/auth/__init__.py +1 -1
  4. chainlit/auth/cookie.py +7 -9
  5. chainlit/auth/jwt.py +5 -3
  6. chainlit/callbacks.py +1 -1
  7. chainlit/config.py +8 -59
  8. chainlit/copilot/dist/index.js +8319 -1019
  9. chainlit/data/__init__.py +71 -2
  10. chainlit/data/chainlit_data_layer.py +608 -0
  11. chainlit/data/literalai.py +1 -1
  12. chainlit/data/sql_alchemy.py +26 -2
  13. chainlit/data/storage_clients/azure_blob.py +89 -0
  14. chainlit/data/storage_clients/base.py +10 -0
  15. chainlit/data/storage_clients/gcs.py +88 -0
  16. chainlit/data/storage_clients/s3.py +42 -4
  17. chainlit/element.py +7 -4
  18. chainlit/emitter.py +9 -14
  19. chainlit/frontend/dist/assets/{DailyMotion-C-_sjrtO.js → DailyMotion-DFvM941y.js} +1 -1
  20. chainlit/frontend/dist/assets/Dataframe-CA6SlUSB.js +22 -0
  21. chainlit/frontend/dist/assets/{Facebook-bB34P03l.js → Facebook-BM4MwXR1.js} +1 -1
  22. chainlit/frontend/dist/assets/{FilePlayer-BWgqGrXv.js → FilePlayer-CfjB8iXr.js} +1 -1
  23. chainlit/frontend/dist/assets/{Kaltura-OY4P9Ofd.js → Kaltura-Bg-U6Xkz.js} +1 -1
  24. chainlit/frontend/dist/assets/{Mixcloud-9CtT8w5Y.js → Mixcloud-xJrfoMTv.js} +1 -1
  25. chainlit/frontend/dist/assets/{Mux-BH9A0qEi.js → Mux-CKnKDBmk.js} +1 -1
  26. chainlit/frontend/dist/assets/{Preview-Og00EJ05.js → Preview-DwHPdmIg.js} +1 -1
  27. chainlit/frontend/dist/assets/{SoundCloud-D7resGfn.js → SoundCloud-Crd5dwXV.js} +1 -1
  28. chainlit/frontend/dist/assets/{Streamable-6f_6bYz1.js → Streamable-Dq0c8lyx.js} +1 -1
  29. chainlit/frontend/dist/assets/{Twitch-BZJl3peM.js → Twitch-DIDvP936.js} +1 -1
  30. chainlit/frontend/dist/assets/{Vidyard-B7tv4b8_.js → Vidyard-B1dz9WN4.js} +1 -1
  31. chainlit/frontend/dist/assets/{Vimeo-F-eA4zQI.js → Vimeo-22Su6q2w.js} +1 -1
  32. chainlit/frontend/dist/assets/Wistia-C7adXRjN.js +1 -0
  33. chainlit/frontend/dist/assets/{YouTube-aFdJGjI1.js → YouTube-Dt4UMtQI.js} +1 -1
  34. chainlit/frontend/dist/assets/index-DbdLVHtZ.js +8665 -0
  35. chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
  36. chainlit/frontend/dist/assets/{react-plotly-DoUJXMgz.js → react-plotly-DvpXYYRJ.js} +1 -1
  37. chainlit/frontend/dist/index.html +2 -2
  38. chainlit/message.py +1 -3
  39. chainlit/server.py +297 -78
  40. chainlit/session.py +9 -0
  41. chainlit/socket.py +5 -53
  42. chainlit/step.py +0 -1
  43. chainlit/translations/en-US.json +1 -1
  44. chainlit/types.py +17 -3
  45. chainlit/user_session.py +1 -0
  46. {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/METADATA +4 -35
  47. {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/RECORD +49 -45
  48. chainlit/frontend/dist/assets/Wistia-Dhxhn3IB.js +0 -1
  49. chainlit/frontend/dist/assets/index-Ba33_hdJ.js +0 -1091
  50. chainlit/frontend/dist/assets/index-CwmincdQ.css +0 -1
  51. {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/WHEEL +0 -0
  52. {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/entry_points.txt +0 -0
@@ -21,8 +21,8 @@
21
21
  <script>
22
22
  const global = globalThis;
23
23
  </script>
24
- <script type="module" crossorigin src="/assets/index-Ba33_hdJ.js"></script>
25
- <link rel="stylesheet" crossorigin href="/assets/index-CwmincdQ.css">
24
+ <script type="module" crossorigin src="/assets/index-DbdLVHtZ.js"></script>
25
+ <link rel="stylesheet" crossorigin href="/assets/index-g8LTJwwr.css">
26
26
  </head>
27
27
  <body>
28
28
  <div id="root"></div>
chainlit/message.py CHANGED
@@ -43,7 +43,6 @@ class MessageBase(ABC):
43
43
  metadata: Optional[Dict] = None
44
44
  tags: Optional[List[str]] = None
45
45
  wait_for_answer = False
46
- indent: Optional[int] = None
47
46
 
48
47
  def __post_init__(self) -> None:
49
48
  trace_event(f"init {self.__class__.__name__}")
@@ -86,7 +85,6 @@ class MessageBase(ABC):
86
85
  "streaming": self.streaming,
87
86
  "isError": self.is_error,
88
87
  "waitForAnswer": self.wait_for_answer,
89
- "indent": self.indent,
90
88
  "metadata": self.metadata or {},
91
89
  "tags": self.tags,
92
90
  }
@@ -542,7 +540,7 @@ class AskActionMessage(AskMessageBase):
542
540
  if res is None:
543
541
  self.content = "Timed out: no action was taken"
544
542
  else:
545
- self.content = f'**Selected:** {res["label"]}'
543
+ self.content = f"**Selected:** {res['label']}"
546
544
 
547
545
  self.wait_for_answer = False
548
546
 
chainlit/server.py CHANGED
@@ -10,7 +10,7 @@ import urllib.parse
10
10
  import webbrowser
11
11
  from contextlib import asynccontextmanager
12
12
  from pathlib import Path
13
- from typing import List, Optional, Union
13
+ from typing import List, Optional, Union, cast
14
14
 
15
15
  import socketio
16
16
  from fastapi import (
@@ -27,13 +27,12 @@ from fastapi import (
27
27
  )
28
28
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
29
29
  from fastapi.security import OAuth2PasswordRequestForm
30
- from fastapi.staticfiles import StaticFiles
31
30
  from starlette.datastructures import URL
32
31
  from starlette.middleware.cors import CORSMiddleware
33
32
  from typing_extensions import Annotated
34
33
  from watchfiles import awatch
35
34
 
36
- from chainlit.auth import create_jwt, get_configuration, get_current_user
35
+ from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
37
36
  from chainlit.auth.cookie import (
38
37
  clear_auth_cookie,
39
38
  clear_oauth_state_cookie,
@@ -49,6 +48,7 @@ from chainlit.config import (
49
48
  PACKAGE_ROOT,
50
49
  config,
51
50
  load_module,
51
+ public_dir,
52
52
  reload_config,
53
53
  )
54
54
  from chainlit.data import get_data_layer
@@ -58,11 +58,14 @@ from chainlit.markdown import get_markdown_str
58
58
  from chainlit.oauth_providers import get_oauth_provider
59
59
  from chainlit.secret import random_secret
60
60
  from chainlit.types import (
61
+ CallActionRequest,
61
62
  DeleteFeedbackRequest,
62
63
  DeleteThreadRequest,
64
+ ElementRequest,
63
65
  GetThreadsRequest,
64
66
  Theme,
65
67
  UpdateFeedbackRequest,
68
+ UpdateThreadRequest,
66
69
  )
67
70
  from chainlit.user import PersistedUser, User
68
71
 
@@ -213,29 +216,59 @@ app.add_middleware(
213
216
 
214
217
  router = APIRouter(prefix=PREFIX)
215
218
 
216
- app.mount(
217
- f"{PREFIX}/public",
218
- StaticFiles(directory="public", check_dir=False),
219
- name="public",
220
- )
221
219
 
222
- app.mount(
223
- f"{PREFIX}/assets",
224
- StaticFiles(
225
- packages=[("chainlit", os.path.join(build_dir, "assets"))],
226
- follow_symlink=config.project.follow_symlink,
227
- ),
228
- name="assets",
229
- )
220
+ @router.get("/public/{filename:path}")
221
+ async def serve_public_file(
222
+ filename: str,
223
+ ):
224
+ """Serve a file from public dir."""
230
225
 
231
- app.mount(
232
- f"{PREFIX}/copilot",
233
- StaticFiles(
234
- packages=[("chainlit", copilot_build_dir)],
235
- follow_symlink=config.project.follow_symlink,
236
- ),
237
- name="copilot",
238
- )
226
+ base_path = Path(public_dir)
227
+ file_path = (base_path / filename).resolve()
228
+
229
+ if not is_path_inside(file_path, base_path):
230
+ raise HTTPException(status_code=400, detail="Invalid filename")
231
+
232
+ if file_path.is_file():
233
+ return FileResponse(file_path)
234
+ else:
235
+ raise HTTPException(status_code=404, detail="File not found")
236
+
237
+
238
+ @router.get("/assets/{filename:path}")
239
+ async def serve_asset_file(
240
+ filename: str,
241
+ ):
242
+ """Serve a file from assets dir."""
243
+
244
+ base_path = Path(os.path.join(build_dir, "assets"))
245
+ file_path = (base_path / filename).resolve()
246
+
247
+ if not is_path_inside(file_path, base_path):
248
+ raise HTTPException(status_code=400, detail="Invalid filename")
249
+
250
+ if file_path.is_file():
251
+ return FileResponse(file_path)
252
+ else:
253
+ raise HTTPException(status_code=404, detail="File not found")
254
+
255
+
256
+ @router.get("/copilot/{filename:path}")
257
+ async def serve_copilot_file(
258
+ filename: str,
259
+ ):
260
+ """Serve a file from assets dir."""
261
+
262
+ base_path = Path(copilot_build_dir)
263
+ file_path = (base_path / filename).resolve()
264
+
265
+ if not is_path_inside(file_path, base_path):
266
+ raise HTTPException(status_code=400, detail="Invalid filename")
267
+
268
+ if file_path.is_file():
269
+ return FileResponse(file_path)
270
+ else:
271
+ raise HTTPException(status_code=404, detail="File not found")
239
272
 
240
273
 
241
274
  # -------------------------------------------------------------------------------
@@ -286,6 +319,16 @@ def get_html_template():
286
319
  """
287
320
  Get HTML template for the index view.
288
321
  """
322
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
323
+
324
+ custom_theme = None
325
+ custom_theme_file_path = Path(public_dir) / "theme.json"
326
+ if (
327
+ is_path_inside(custom_theme_file_path, Path(public_dir))
328
+ and custom_theme_file_path.is_file()
329
+ ):
330
+ custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
331
+
289
332
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
290
333
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
291
334
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
@@ -309,8 +352,8 @@ def get_html_template():
309
352
  <meta property="og:root_path" content="{ROOT_PATH}">"""
310
353
 
311
354
  js = f"""<script>
312
- {f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}
313
- {f"window.transports = {json.dumps(config.project.transports)}; " if config.project.transports else "undefined"}
355
+ {f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
356
+ {f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
314
357
  </script>"""
315
358
 
316
359
  css = None
@@ -323,8 +366,11 @@ def get_html_template():
323
366
  js += f"""<script src="{config.ui.custom_js}" defer></script>"""
324
367
 
325
368
  font = None
326
- if config.ui.custom_font:
327
- font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
369
+ if custom_theme and custom_theme.get("custom_fonts"):
370
+ font = "\n".join(
371
+ f"""<link rel="stylesheet" href="{font}">"""
372
+ for font in custom_theme.get("custom_fonts")
373
+ )
328
374
 
329
375
  index_html_file_path = os.path.join(build_dir, "index.html")
330
376
 
@@ -376,13 +422,6 @@ async def auth(request: Request):
376
422
  def _get_response_dict(access_token: str) -> dict:
377
423
  """Get the response dictionary for the auth response."""
378
424
 
379
- if not config.project.cookie_auth:
380
- # Legacy auth
381
- return {
382
- "access_token": access_token,
383
- "token_type": "bearer",
384
- }
385
-
386
425
  return {"success": True}
387
426
 
388
427
 
@@ -444,8 +483,7 @@ async def _authenticate_user(
444
483
 
445
484
  response = _get_auth_response(access_token, redirect_to_callback)
446
485
 
447
- if config.project.cookie_auth:
448
- set_auth_cookie(response, access_token)
486
+ set_auth_cookie(response, access_token)
449
487
 
450
488
  return response
451
489
 
@@ -470,8 +508,7 @@ async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depen
470
508
  @router.post("/logout")
471
509
  async def logout(request: Request, response: Response):
472
510
  """Logout the user by calling the on_logout callback."""
473
- if config.project.cookie_auth:
474
- clear_auth_cookie(response)
511
+ clear_auth_cookie(response)
475
512
 
476
513
  if config.code.on_logout:
477
514
  return await config.code.on_logout(request, response)
@@ -479,6 +516,35 @@ async def logout(request: Request, response: Response):
479
516
  return {"success": True}
480
517
 
481
518
 
519
+ @router.post("/auth/jwt")
520
+ async def jwt_auth(request: Request):
521
+ """Login a user using a valid jwt."""
522
+ from jwt import InvalidTokenError
523
+
524
+ auth_header: Optional[str] = request.headers.get("Authorization")
525
+ if not auth_header:
526
+ raise HTTPException(status_code=401, detail="Authorization header missing")
527
+
528
+ # Check if it starts with "Bearer "
529
+ try:
530
+ scheme, token = auth_header.split()
531
+ if scheme.lower() != "bearer":
532
+ raise HTTPException(
533
+ status_code=401,
534
+ detail="Invalid authentication scheme. Please use Bearer",
535
+ )
536
+ except ValueError:
537
+ raise HTTPException(
538
+ status_code=401, detail="Invalid authorization header format"
539
+ )
540
+
541
+ try:
542
+ user = decode_jwt(token)
543
+ return await _authenticate_user(user)
544
+ except InvalidTokenError:
545
+ raise HTTPException(status_code=401, detail="Invalid token")
546
+
547
+
482
548
  @router.post("/auth/header")
483
549
  async def header_auth(request: Request):
484
550
  """Login a user using the header_auth_callback."""
@@ -635,7 +701,7 @@ async def oauth_azure_hf_callback(
635
701
  return response
636
702
 
637
703
 
638
- GenericUser = Union[User, PersistedUser]
704
+ GenericUser = Union[User, PersistedUser, None]
639
705
  UserParam = Annotated[GenericUser, Depends(get_current_user)]
640
706
 
641
707
 
@@ -767,6 +833,9 @@ async def get_user_threads(
767
833
  if not data_layer:
768
834
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
769
835
 
836
+ if not current_user:
837
+ raise HTTPException(status_code=401, detail="Unauthorized")
838
+
770
839
  if not isinstance(current_user, PersistedUser):
771
840
  persisted_user = await data_layer.get_user(identifier=current_user.identifier)
772
841
  if not persisted_user:
@@ -791,6 +860,9 @@ async def get_thread(
791
860
  if not data_layer:
792
861
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
793
862
 
863
+ if not current_user:
864
+ raise HTTPException(status_code=401, detail="Unauthorized")
865
+
794
866
  await is_thread_author(current_user.identifier, thread_id)
795
867
 
796
868
  res = await data_layer.get_thread(thread_id)
@@ -810,12 +882,130 @@ async def get_thread_element(
810
882
  if not data_layer:
811
883
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
812
884
 
885
+ if not current_user:
886
+ raise HTTPException(status_code=401, detail="Unauthorized")
887
+
813
888
  await is_thread_author(current_user.identifier, thread_id)
814
889
 
815
890
  res = await data_layer.get_element(thread_id, element_id)
816
891
  return JSONResponse(content=res)
817
892
 
818
893
 
894
+ @router.put("/project/element")
895
+ async def update_thread_element(
896
+ payload: ElementRequest,
897
+ current_user: UserParam,
898
+ ):
899
+ """Update a specific thread element."""
900
+
901
+ from chainlit.context import init_ws_context
902
+ from chainlit.element import CustomElement, ElementDict
903
+ from chainlit.session import WebsocketSession
904
+
905
+ session = WebsocketSession.get_by_id(payload.sessionId)
906
+ context = init_ws_context(session)
907
+
908
+ element_dict = cast(ElementDict, payload.element)
909
+
910
+ if element_dict["type"] != "custom":
911
+ return {"success": False}
912
+
913
+ element = CustomElement(
914
+ id=element_dict["id"],
915
+ object_key=element_dict["objectKey"],
916
+ chainlit_key=element_dict["chainlitKey"],
917
+ url=element_dict["url"],
918
+ for_id=element_dict.get("forId") or "",
919
+ thread_id=element_dict.get("threadId") or "",
920
+ name=element_dict["name"],
921
+ props=element_dict.get("props") or {},
922
+ display=element_dict["display"],
923
+ )
924
+
925
+ if current_user:
926
+ if (
927
+ not context.session.user
928
+ or context.session.user.identifier != current_user.identifier
929
+ ):
930
+ raise HTTPException(
931
+ status_code=401,
932
+ detail="You are not authorized to update elements for this session",
933
+ )
934
+
935
+ await element.send(for_id=element.for_id or "")
936
+ return {"success": True}
937
+
938
+
939
+ @router.delete("/project/element")
940
+ async def delete_thread_element(
941
+ payload: ElementRequest,
942
+ current_user: UserParam,
943
+ ):
944
+ """Delete a specific thread element."""
945
+
946
+ from chainlit.context import init_ws_context
947
+ from chainlit.element import CustomElement, ElementDict
948
+ from chainlit.session import WebsocketSession
949
+
950
+ session = WebsocketSession.get_by_id(payload.sessionId)
951
+ context = init_ws_context(session)
952
+
953
+ element_dict = cast(ElementDict, payload.element)
954
+
955
+ if element_dict["type"] != "custom":
956
+ return {"success": False}
957
+
958
+ element = CustomElement(
959
+ id=element_dict["id"],
960
+ object_key=element_dict["objectKey"],
961
+ chainlit_key=element_dict["chainlitKey"],
962
+ url=element_dict["url"],
963
+ for_id=element_dict.get("forId") or "",
964
+ thread_id=element_dict.get("threadId") or "",
965
+ name=element_dict["name"],
966
+ props=element_dict.get("props") or {},
967
+ display=element_dict["display"],
968
+ )
969
+
970
+ if current_user:
971
+ if (
972
+ not context.session.user
973
+ or context.session.user.identifier != current_user.identifier
974
+ ):
975
+ raise HTTPException(
976
+ status_code=401,
977
+ detail="You are not authorized to remove elements for this session",
978
+ )
979
+
980
+ await element.remove()
981
+
982
+ return {"success": True}
983
+
984
+
985
+ @router.put("/project/thread")
986
+ async def rename_thread(
987
+ request: Request,
988
+ payload: UpdateThreadRequest,
989
+ current_user: UserParam,
990
+ ):
991
+ """Rename a thread."""
992
+
993
+ data_layer = get_data_layer()
994
+
995
+ if not data_layer:
996
+ raise HTTPException(status_code=400, detail="Data persistence is not enabled")
997
+
998
+ if not current_user:
999
+ raise HTTPException(status_code=401, detail="Unauthorized")
1000
+
1001
+ thread_id = payload.threadId
1002
+
1003
+ await is_thread_author(current_user.identifier, thread_id)
1004
+
1005
+ await data_layer.update_thread(thread_id, name=payload.name)
1006
+ return JSONResponse(content={"success": True})
1007
+
1008
+
819
1009
  @router.delete("/project/thread")
820
1010
  async def delete_thread(
821
1011
  request: Request,
@@ -829,6 +1019,9 @@ async def delete_thread(
829
1019
  if not data_layer:
830
1020
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
831
1021
 
1022
+ if not current_user:
1023
+ raise HTTPException(status_code=401, detail="Unauthorized")
1024
+
832
1025
  thread_id = payload.threadId
833
1026
 
834
1027
  await is_thread_author(current_user.identifier, thread_id)
@@ -837,6 +1030,48 @@ async def delete_thread(
837
1030
  return JSONResponse(content={"success": True})
838
1031
 
839
1032
 
1033
+ @router.post("/project/action")
1034
+ async def call_action(
1035
+ payload: CallActionRequest,
1036
+ current_user: UserParam,
1037
+ ):
1038
+ """Run an action."""
1039
+
1040
+ from chainlit.action import Action
1041
+ from chainlit.context import init_ws_context
1042
+ from chainlit.session import WebsocketSession
1043
+
1044
+ session = WebsocketSession.get_by_id(payload.sessionId)
1045
+ context = init_ws_context(session)
1046
+
1047
+ action = Action(**payload.action)
1048
+
1049
+ if current_user:
1050
+ if (
1051
+ not context.session.user
1052
+ or context.session.user.identifier != current_user.identifier
1053
+ ):
1054
+ raise HTTPException(
1055
+ status_code=401,
1056
+ detail="You are not authorized to upload files for this session",
1057
+ )
1058
+
1059
+ callback = config.code.action_callbacks.get(action.name)
1060
+ if callback:
1061
+ if not context.session.has_first_interaction:
1062
+ context.session.has_first_interaction = True
1063
+ asyncio.create_task(context.emitter.init_thread(action.name))
1064
+
1065
+ await callback(action)
1066
+ else:
1067
+ raise HTTPException(
1068
+ status_code=404,
1069
+ detail=f"No callback found for action {action.name}",
1070
+ )
1071
+
1072
+ return JSONResponse(content={"success": True})
1073
+
1074
+
840
1075
  @router.post("/project/file")
841
1076
  async def upload_file(
842
1077
  current_user: UserParam,
@@ -888,11 +1123,14 @@ def validate_file_upload(file: UploadFile):
888
1123
  Raises:
889
1124
  ValueError: If the file is not allowed.
890
1125
  """
891
- if config.features.spontaneous_file_upload is None:
892
- """Default for a missing config is to allow the fileupload without any restrictions"""
893
- return
894
- if config.features.spontaneous_file_upload.enabled is False:
895
- raise ValueError("File upload is not enabled")
1126
+ # TODO: This logic/endpoint is shared across spontaneous uploads and the AskFileMessage API.
1127
+ # Commenting this check until we find a better solution
1128
+
1129
+ # if config.features.spontaneous_file_upload is None:
1130
+ # """Default for a missing config is to allow the fileupload without any restrictions"""
1131
+ # return
1132
+ # if not config.features.spontaneous_file_upload.enabled:
1133
+ # raise ValueError("File upload is not enabled")
896
1134
 
897
1135
  validate_file_mime_type(file)
898
1136
  validate_file_size(file)
@@ -905,14 +1143,19 @@ def validate_file_mime_type(file: UploadFile):
905
1143
  Raises:
906
1144
  ValueError: If the file type is not allowed.
907
1145
  """
908
- accept = config.features.spontaneous_file_upload.accept
909
- if accept is None:
1146
+
1147
+ if (
1148
+ config.features.spontaneous_file_upload is None
1149
+ or config.features.spontaneous_file_upload.accept is None
1150
+ ):
910
1151
  "Accept is not configured, allowing all file types"
911
1152
  return
912
1153
 
913
- assert (
914
- isinstance(accept, List) or isinstance(accept, dict)
915
- ), "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
1154
+ accept = config.features.spontaneous_file_upload.accept
1155
+
1156
+ assert isinstance(accept, List) or isinstance(accept, dict), (
1157
+ "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
1158
+ )
916
1159
 
917
1160
  if isinstance(accept, List):
918
1161
  for pattern in accept:
@@ -936,7 +1179,10 @@ def validate_file_size(file: UploadFile):
936
1179
  Raises:
937
1180
  ValueError: If the file size is too large.
938
1181
  """
939
- if config.features.spontaneous_file_upload.max_size_mb is None:
1182
+ if (
1183
+ config.features.spontaneous_file_upload is None
1184
+ or config.features.spontaneous_file_upload.max_size_mb is None
1185
+ ):
940
1186
  return
941
1187
 
942
1188
  if (
@@ -954,14 +1200,6 @@ async def get_file(
954
1200
  current_user: UserParam,
955
1201
  ):
956
1202
  """Get a file from the session files directory."""
957
-
958
- if not config.project.cookie_auth:
959
- # We cannot make this work safely without cookie auth, so disable it.
960
- raise HTTPException(
961
- status_code=404,
962
- detail="File downloads unavailable.",
963
- )
964
-
965
1203
  from chainlit.session import WebsocketSession
966
1204
 
967
1205
  session = WebsocketSession.get_by_id(session_id) if session_id else None
@@ -986,25 +1224,6 @@ async def get_file(
986
1224
  raise HTTPException(status_code=404, detail="File not found")
987
1225
 
988
1226
 
989
- @router.get("/files/{filename:path}")
990
- async def serve_file(
991
- filename: str,
992
- current_user: UserParam,
993
- ):
994
- """Serve a file from the local filesystem."""
995
-
996
- base_path = Path(config.project.local_fs_path).resolve()
997
- file_path = (base_path / filename).resolve()
998
-
999
- if not is_path_inside(file_path, base_path):
1000
- raise HTTPException(status_code=400, detail="Invalid filename")
1001
-
1002
- if file_path.is_file():
1003
- return FileResponse(file_path)
1004
- else:
1005
- raise HTTPException(status_code=404, detail="File not found")
1006
-
1007
-
1008
1227
  @router.get("/favicon")
1009
1228
  async def get_favicon():
1010
1229
  """Get the favicon for the UI."""
chainlit/session.py CHANGED
@@ -64,6 +64,8 @@ class BaseSession:
64
64
  chat_profile: Optional[str] = None,
65
65
  # Origin of the request
66
66
  http_referer: Optional[str] = None,
67
+ # Cookie
68
+ http_cookie: Optional[str] = None,
67
69
  ):
68
70
  if thread_id:
69
71
  self.thread_id_to_resume = thread_id
@@ -75,6 +77,7 @@ class BaseSession:
75
77
  self.user_env = user_env or {}
76
78
  self.chat_profile = chat_profile
77
79
  self.http_referer = http_referer
80
+ self.http_cookie = http_cookie
78
81
 
79
82
  self.files: Dict[str, FileDict] = {}
80
83
 
@@ -167,6 +170,8 @@ class HTTPSession(BaseSession):
167
170
  user_env: Optional[Dict[str, str]] = None,
168
171
  # Origin of the request
169
172
  http_referer: Optional[str] = None,
173
+ # Cookie
174
+ http_cookie: Optional[str] = None,
170
175
  ):
171
176
  super().__init__(
172
177
  id=id,
@@ -176,6 +181,7 @@ class HTTPSession(BaseSession):
176
181
  client_type=client_type,
177
182
  user_env=user_env,
178
183
  http_referer=http_referer,
184
+ http_cookie=http_cookie,
179
185
  )
180
186
 
181
187
  def delete(self):
@@ -226,6 +232,8 @@ class WebsocketSession(BaseSession):
226
232
  languages: Optional[str] = None,
227
233
  # Origin of the request
228
234
  http_referer: Optional[str] = None,
235
+ # Cookie
236
+ http_cookie: Optional[str] = None,
229
237
  ):
230
238
  super().__init__(
231
239
  id=id,
@@ -236,6 +244,7 @@ class WebsocketSession(BaseSession):
236
244
  client_type=client_type,
237
245
  chat_profile=chat_profile,
238
246
  http_referer=http_referer,
247
+ http_cookie=http_cookie,
239
248
  )
240
249
 
241
250
  self.socket_id = socket_id