chainlit 1.1.404__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (52):
  1. chainlit/__init__.py +53 -305
  2. chainlit/_utils.py +8 -0
  3. chainlit/callbacks.py +308 -0
  4. chainlit/config.py +55 -29
  5. chainlit/copilot/dist/index.js +510 -629
  6. chainlit/data/__init__.py +6 -521
  7. chainlit/data/base.py +121 -0
  8. chainlit/data/dynamodb.py +2 -5
  9. chainlit/data/literalai.py +395 -0
  10. chainlit/data/sql_alchemy.py +10 -9
  11. chainlit/data/storage_clients.py +69 -15
  12. chainlit/data/utils.py +29 -0
  13. chainlit/frontend/dist/assets/{DailyMotion-e665b444.js → DailyMotion-05f4fe48.js} +1 -1
  14. chainlit/frontend/dist/assets/{Facebook-5207db92.js → Facebook-f25411d1.js} +1 -1
  15. chainlit/frontend/dist/assets/{FilePlayer-86937d6e.js → FilePlayer-40ff3414.js} +1 -1
  16. chainlit/frontend/dist/assets/{Kaltura-c96622c1.js → Kaltura-6cbf3897.js} +1 -1
  17. chainlit/frontend/dist/assets/{Mixcloud-57ae3e32.js → Mixcloud-34e7c912.js} +1 -1
  18. chainlit/frontend/dist/assets/{Mux-20373920.js → Mux-8aaff6ac.js} +1 -1
  19. chainlit/frontend/dist/assets/{Preview-c68c0613.js → Preview-2d3bf558.js} +1 -1
  20. chainlit/frontend/dist/assets/{SoundCloud-8a9e3eae.js → SoundCloud-b835f90f.js} +1 -1
  21. chainlit/frontend/dist/assets/{Streamable-1ed099af.js → Streamable-1293e4f3.js} +1 -1
  22. chainlit/frontend/dist/assets/{Twitch-6820039f.js → Twitch-c69660cd.js} +1 -1
  23. chainlit/frontend/dist/assets/{Vidyard-d39ab91d.js → Vidyard-43bda599.js} +1 -1
  24. chainlit/frontend/dist/assets/{Vimeo-017cd9a7.js → Vimeo-54150039.js} +1 -1
  25. chainlit/frontend/dist/assets/{Wistia-a509d9f2.js → Wistia-aa3c721b.js} +1 -1
  26. chainlit/frontend/dist/assets/{YouTube-42dfd82f.js → YouTube-dd0f3cc2.js} +1 -1
  27. chainlit/frontend/dist/assets/index-cf48bedd.js +729 -0
  28. chainlit/frontend/dist/assets/react-plotly-f52a41eb.js +3484 -0
  29. chainlit/frontend/dist/index.html +1 -1
  30. chainlit/langchain/callbacks.py +6 -1
  31. chainlit/llama_index/callbacks.py +20 -4
  32. chainlit/markdown.py +15 -9
  33. chainlit/message.py +0 -1
  34. chainlit/server.py +90 -36
  35. chainlit/session.py +4 -1
  36. chainlit/translations/bn.json +231 -0
  37. chainlit/translations/gu.json +231 -0
  38. chainlit/translations/he-IL.json +231 -0
  39. chainlit/translations/hi.json +231 -0
  40. chainlit/translations/kn.json +231 -0
  41. chainlit/translations/ml.json +231 -0
  42. chainlit/translations/mr.json +231 -0
  43. chainlit/translations/ta.json +231 -0
  44. chainlit/translations/te.json +231 -0
  45. chainlit/utils.py +1 -1
  46. {chainlit-1.1.404.dist-info → chainlit-1.2.0.dist-info}/METADATA +3 -3
  47. chainlit-1.2.0.dist-info/RECORD +96 -0
  48. chainlit/frontend/dist/assets/index-30df9b2b.js +0 -730
  49. chainlit/frontend/dist/assets/react-plotly-5bb34118.js +0 -3602
  50. chainlit-1.1.404.dist-info/RECORD +0 -82
  51. {chainlit-1.1.404.dist-info → chainlit-1.2.0.dist-info}/WHEEL +0 -0
  52. {chainlit-1.1.404.dist-info → chainlit-1.2.0.dist-info}/entry_points.txt +0 -0
chainlit/frontend/dist/index.html CHANGED
@@ -21,7 +21,7 @@
21
21
  <script>
22
22
  const global = globalThis;
23
23
  </script>
24
- <script type="module" crossorigin src="/assets/index-30df9b2b.js"></script>
24
+ <script type="module" crossorigin src="/assets/index-cf48bedd.js"></script>
25
25
  <link rel="stylesheet" href="/assets/index-aaf974a9.css">
26
26
  </head>
27
27
  <body>
chainlit/langchain/callbacks.py CHANGED
@@ -587,12 +587,17 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
587
587
  outputs = run.outputs or {}
588
588
  output_keys = list(outputs.keys())
589
589
  output = outputs
590
+
590
591
  if output_keys:
591
592
  output = outputs.get(output_keys[0], outputs)
592
593
 
593
594
  if current_step:
594
595
  current_step.output = (
595
- output[0] if isinstance(output, Sequence) and len(output) else output
596
+ output[0]
597
+ if isinstance(output, Sequence)
598
+ and not isinstance(output, str)
599
+ and len(output)
600
+ else output
596
601
  )
597
602
  current_step.end = utc_now()
598
603
  self._run_sync(current_step.update())
chainlit/llama_index/callbacks.py CHANGED
@@ -8,6 +8,7 @@ from literalai.helper import utc_now
8
8
  from llama_index.core.callbacks import TokenCountingHandler
9
9
  from llama_index.core.callbacks.schema import CBEventType, EventPayload
10
10
  from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
11
+ from llama_index.core.tools.types import ToolMetadata
11
12
 
12
13
  DEFAULT_IGNORE = [
13
14
  CBEventType.CHUNKING,
@@ -54,7 +55,16 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
54
55
  ) -> str:
55
56
  """Run when an event starts and return id of event."""
56
57
  step_type: StepType = "undefined"
57
- if event_type == CBEventType.RETRIEVE:
58
+ step_name: str = event_type.value
59
+ step_input: Optional[Dict[str, Any]] = payload
60
+ if event_type == CBEventType.FUNCTION_CALL:
61
+ step_type = "tool"
62
+ if payload:
63
+ metadata: Optional[ToolMetadata] = payload.get(EventPayload.TOOL)
64
+ if metadata:
65
+ step_name = getattr(metadata, "name", step_name)
66
+ step_input = payload.get(EventPayload.FUNCTION_CALL)
67
+ elif event_type == CBEventType.RETRIEVE:
58
68
  step_type = "tool"
59
69
  elif event_type == CBEventType.QUERY:
60
70
  step_type = "tool"
@@ -64,7 +74,7 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
64
74
  return event_id
65
75
 
66
76
  step = Step(
67
- name=event_type.value,
77
+ name=step_name,
68
78
  type=step_type,
69
79
  parent_id=self._get_parent_id(parent_id),
70
80
  id=event_id,
@@ -72,7 +82,7 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
72
82
 
73
83
  self.steps[event_id] = step
74
84
  step.start = utc_now()
75
- step.input = payload or {}
85
+ step.input = step_input or {}
76
86
  context_var.get().loop.create_task(step.send())
77
87
  return event_id
78
88
 
@@ -91,7 +101,13 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
91
101
 
92
102
  step.end = utc_now()
93
103
 
94
- if event_type == CBEventType.QUERY:
104
+ if event_type == CBEventType.FUNCTION_CALL:
105
+ response = payload.get(EventPayload.FUNCTION_OUTPUT)
106
+ if response:
107
+ step.output = f"{response}"
108
+ context_var.get().loop.create_task(step.update())
109
+
110
+ elif event_type == CBEventType.QUERY:
95
111
  response = payload.get(EventPayload.RESPONSE)
96
112
  source_nodes = getattr(response, "source_nodes", None)
97
113
  if source_nodes:
chainlit/markdown.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import os
2
+ from pathlib import Path
3
+ from typing import Optional
2
4
 
3
5
  from chainlit.logger import logger
4
6
 
7
+ from ._utils import is_path_inside
8
+
5
9
  # Default chainlit.md file created if none exists
6
10
  DEFAULT_MARKDOWN_STR = """# Welcome to Chainlit! 🚀🤖
7
11
 
@@ -30,12 +34,16 @@ def init_markdown(root: str):
30
34
  logger.info(f"Created default chainlit markdown file at {chainlit_md_file}")
31
35
 
32
36
 
33
- def get_markdown_str(root: str, language: str):
37
+ def get_markdown_str(root: str, language: str) -> Optional[str]:
34
38
  """Get the chainlit.md file as a string."""
35
- translated_chainlit_md_path = os.path.join(root, f"chainlit_{language}.md")
36
- default_chainlit_md_path = os.path.join(root, "chainlit.md")
37
-
38
- if os.path.exists(translated_chainlit_md_path):
39
+ root_path = Path(root)
40
+ translated_chainlit_md_path = root_path / f"chainlit_{language}.md"
41
+ default_chainlit_md_path = root_path / "chainlit.md"
42
+
43
+ if (
44
+ is_path_inside(translated_chainlit_md_path, root_path)
45
+ and translated_chainlit_md_path.is_file()
46
+ ):
39
47
  chainlit_md_path = translated_chainlit_md_path
40
48
  else:
41
49
  chainlit_md_path = default_chainlit_md_path
@@ -43,9 +51,7 @@ def get_markdown_str(root: str, language: str):
43
51
  f"Translated markdown file for {language} not found. Defaulting to chainlit.md."
44
52
  )
45
53
 
46
- if os.path.exists(chainlit_md_path):
47
- with open(chainlit_md_path, "r", encoding="utf-8") as f:
48
- chainlit_md = f.read()
49
- return chainlit_md
54
+ if chainlit_md_path.is_file():
55
+ return chainlit_md_path.read_text(encoding="utf-8")
50
56
  else:
51
57
  return None
chainlit/message.py CHANGED
@@ -82,7 +82,6 @@ class MessageBase(ABC):
82
82
  "output": self.content,
83
83
  "name": self.author,
84
84
  "type": self.type,
85
- "createdAt": self.created_at,
86
85
  "language": self.language,
87
86
  "streaming": self.streaming,
88
87
  "isError": self.is_error,
chainlit/server.py CHANGED
@@ -1,22 +1,15 @@
1
+ import asyncio
1
2
  import glob
2
3
  import json
3
4
  import mimetypes
5
+ import os
4
6
  import re
5
7
  import shutil
6
8
  import urllib.parse
7
- from typing import Any, Optional, Union
8
-
9
- from chainlit.oauth_providers import get_oauth_provider
10
- from chainlit.secret import random_secret
11
-
12
- mimetypes.add_type("application/javascript", ".js")
13
- mimetypes.add_type("text/css", ".css")
14
-
15
- import asyncio
16
- import os
17
9
  import webbrowser
18
10
  from contextlib import asynccontextmanager
19
11
  from pathlib import Path
12
+ from typing import Any, Optional, Union
20
13
 
21
14
  import socketio
22
15
  from chainlit.auth import create_jwt, get_configuration, get_current_user
@@ -34,6 +27,8 @@ from chainlit.data import get_data_layer
34
27
  from chainlit.data.acl import is_thread_author
35
28
  from chainlit.logger import logger
36
29
  from chainlit.markdown import get_markdown_str
30
+ from chainlit.oauth_providers import get_oauth_provider
31
+ from chainlit.secret import random_secret
37
32
  from chainlit.types import (
38
33
  DeleteFeedbackRequest,
39
34
  DeleteThreadRequest,
@@ -62,12 +57,20 @@ from starlette.middleware.cors import CORSMiddleware
62
57
  from typing_extensions import Annotated
63
58
  from watchfiles import awatch
64
59
 
60
+ from ._utils import is_path_inside
61
+
62
+ mimetypes.add_type("application/javascript", ".js")
63
+ mimetypes.add_type("text/css", ".css")
64
+
65
65
  ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
66
66
  IS_SUBMOUNT = os.environ.get("CHAINLIT_SUBMOUNT", "") == "true"
67
+ # If the app is a submount, no need to set the prefix
68
+ PREFIX = ROOT_PATH if ROOT_PATH and not IS_SUBMOUNT else ""
67
69
 
68
70
 
69
71
  @asynccontextmanager
70
72
  async def lifespan(app: FastAPI):
73
+ """Context manager to handle app start and shutdown."""
71
74
  host = config.run.host
72
75
  port = config.run.port
73
76
 
@@ -150,7 +153,18 @@ async def lifespan(app: FastAPI):
150
153
  os._exit(0)
151
154
 
152
155
 
153
- def get_build_dir(local_target: str, packaged_target: str):
156
+ def get_build_dir(local_target: str, packaged_target: str) -> str:
157
+ """
158
+ Get the build directory based on the UI build strategy.
159
+
160
+ Args:
161
+ local_target (str): The local target directory.
162
+ packaged_target (str): The packaged target directory.
163
+
164
+ Returns:
165
+ str: The build directory
166
+ """
167
+
154
168
  local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
155
169
  packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")
156
170
 
@@ -171,18 +185,14 @@ copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
171
185
 
172
186
  app = FastAPI(lifespan=lifespan)
173
187
 
174
- sio = socketio.AsyncServer(
175
- cors_allowed_origins=[], async_mode="asgi"
176
- )
177
-
178
- sio_mount_location = f"{ROOT_PATH}/ws" if ROOT_PATH else "ws"
188
+ sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
179
189
 
180
190
  asgi_app = socketio.ASGIApp(
181
191
  socketio_server=sio,
182
- socketio_path=f"{sio_mount_location}/socket.io",
192
+ socketio_path="",
183
193
  )
184
194
 
185
- app.mount(f"/{sio_mount_location}", asgi_app)
195
+ app.mount(f"{PREFIX}/ws/socket.io", asgi_app)
186
196
 
187
197
  app.add_middleware(
188
198
  CORSMiddleware,
@@ -192,16 +202,16 @@ app.add_middleware(
192
202
  allow_headers=["*"],
193
203
  )
194
204
 
195
- router = APIRouter(prefix=ROOT_PATH)
205
+ router = APIRouter(prefix=PREFIX)
196
206
 
197
207
  app.mount(
198
- f"{ROOT_PATH}/public",
208
+ f"{PREFIX}/public",
199
209
  StaticFiles(directory="public", check_dir=False),
200
210
  name="public",
201
211
  )
202
212
 
203
213
  app.mount(
204
- f"{ROOT_PATH}/assets",
214
+ f"{PREFIX}/assets",
205
215
  StaticFiles(
206
216
  packages=[("chainlit", os.path.join(build_dir, "assets"))],
207
217
  follow_symlink=config.project.follow_symlink,
@@ -210,7 +220,7 @@ app.mount(
210
220
  )
211
221
 
212
222
  app.mount(
213
- f"{ROOT_PATH}/copilot",
223
+ f"{PREFIX}/copilot",
214
224
  StaticFiles(
215
225
  packages=[("chainlit", copilot_build_dir)],
216
226
  follow_symlink=config.project.follow_symlink,
@@ -253,12 +263,19 @@ if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
253
263
  # -------------------------------------------------------------------------------
254
264
 
255
265
 
256
- def replace_between_tags(text: str, start_tag: str, end_tag: str, replacement: str):
266
+ def replace_between_tags(
267
+ text: str, start_tag: str, end_tag: str, replacement: str
268
+ ) -> str:
269
+ """Replace text between two tags in a string."""
270
+
257
271
  pattern = start_tag + ".*?" + end_tag
258
272
  return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)
259
273
 
260
274
 
261
275
  def get_html_template():
276
+ """
277
+ Get HTML template for the index view.
278
+ """
262
279
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
263
280
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
264
281
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
@@ -345,6 +362,9 @@ async def auth(request: Request):
345
362
 
346
363
  @router.post("/login")
347
364
  async def login(form_data: OAuth2PasswordRequestForm = Depends()):
365
+ """
366
+ Login a user using the password auth callback.
367
+ """
348
368
  if not config.code.password_auth_callback:
349
369
  raise HTTPException(
350
370
  status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
@@ -374,6 +394,7 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
374
394
 
375
395
  @router.post("/logout")
376
396
  async def logout(request: Request, response: Response):
397
+ """Logout the user by calling the on_logout callback."""
377
398
  if config.code.on_logout:
378
399
  return await config.code.on_logout(request, response)
379
400
  return {"success": True}
@@ -381,6 +402,7 @@ async def logout(request: Request, response: Response):
381
402
 
382
403
  @router.post("/auth/header")
383
404
  async def header_auth(request: Request):
405
+ """Login a user using the header_auth_callback."""
384
406
  if not config.code.header_auth_callback:
385
407
  raise HTTPException(
386
408
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -410,6 +432,7 @@ async def header_auth(request: Request):
410
432
 
411
433
  @router.get("/auth/oauth/{provider_id}")
412
434
  async def oauth_login(provider_id: str, request: Request):
435
+ """Redirect the user to the oauth provider login page."""
413
436
  if config.code.oauth_callback is None:
414
437
  raise HTTPException(
415
438
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -436,7 +459,7 @@ async def oauth_login(provider_id: str, request: Request):
436
459
  response = RedirectResponse(
437
460
  url=f"{provider.authorize_url}?{params}",
438
461
  )
439
- samesite = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax") # type: Any
462
+ samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
440
463
  secure = samesite.lower() == "none"
441
464
  response.set_cookie(
442
465
  "oauth_state",
@@ -457,6 +480,8 @@ async def oauth_callback(
457
480
  code: Optional[str] = None,
458
481
  state: Optional[str] = None,
459
482
  ):
483
+ """Handle the oauth callback and login the user."""
484
+
460
485
  if config.code.oauth_callback is None:
461
486
  raise HTTPException(
462
487
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -544,6 +569,8 @@ async def oauth_azure_hf_callback(
544
569
  code: Annotated[Optional[str], Form()] = None,
545
570
  id_token: Annotated[Optional[str], Form()] = None,
546
571
  ):
572
+ """Handle the azure ad hybrid flow callback and login the user."""
573
+
547
574
  provider_id = "azure-ad-hybrid"
548
575
  if config.code.oauth_callback is None:
549
576
  raise HTTPException(
@@ -617,9 +644,16 @@ async def oauth_azure_hf_callback(
617
644
  return response
618
645
 
619
646
 
647
+ _language_pattern = (
648
+ "^[a-zA-Z]{2,3}(-[a-zA-Z]{2,3})?(-[a-zA-Z]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
649
+ )
650
+
651
+
620
652
  @router.get("/project/translations")
621
653
  async def project_translations(
622
- language: str = Query(default="en-US", description="Language code"),
654
+ language: str = Query(
655
+ default="en-US", description="Language code", pattern=_language_pattern
656
+ ),
623
657
  ):
624
658
  """Return project translations."""
625
659
 
@@ -636,11 +670,14 @@ async def project_translations(
636
670
  @router.get("/project/settings")
637
671
  async def project_settings(
638
672
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
639
- language: str = Query(default="en-US", description="Language code"),
673
+ language: str = Query(
674
+ default="en-US", description="Language code", pattern=_language_pattern
675
+ ),
640
676
  ):
641
677
  """Return project settings. This is called by the UI before the establishing the websocket connection."""
642
678
 
643
679
  # Load the markdown file based on the provided language
680
+
644
681
  markdown = get_markdown_str(config.root, language)
645
682
 
646
683
  profiles = []
@@ -808,6 +845,8 @@ async def upload_file(
808
845
  Union[None, User, PersistedUser], Depends(get_current_user)
809
846
  ],
810
847
  ):
848
+ """Upload a file to the session files directory."""
849
+
811
850
  from chainlit.session import WebsocketSession
812
851
 
813
852
  session = WebsocketSession.get_by_id(session_id)
@@ -841,6 +880,8 @@ async def get_file(
841
880
  file_id: str,
842
881
  session_id: Optional[str] = None,
843
882
  ):
883
+ """Get a file from the session files directory."""
884
+
844
885
  from chainlit.session import WebsocketSession
845
886
 
846
887
  session = WebsocketSession.get_by_id(session_id) if session_id else None
@@ -863,11 +904,12 @@ async def serve_file(
863
904
  filename: str,
864
905
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
865
906
  ):
907
+ """Serve a file from the local filesystem."""
908
+
866
909
  base_path = Path(config.project.local_fs_path).resolve()
867
910
  file_path = (base_path / filename).resolve()
868
911
 
869
- # Check if the base path is a parent of the file path
870
- if base_path not in file_path.parents:
912
+ if not is_path_inside(file_path, base_path):
871
913
  raise HTTPException(status_code=400, detail="Invalid filename")
872
914
 
873
915
  if file_path.is_file():
@@ -878,6 +920,7 @@ async def serve_file(
878
920
 
879
921
  @router.get("/favicon")
880
922
  async def get_favicon():
923
+ """Get the favicon for the UI."""
881
924
  custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
882
925
  files = glob.glob(custom_favicon_path)
883
926
 
@@ -893,6 +936,7 @@ async def get_favicon():
893
936
 
894
937
  @router.get("/logo")
895
938
  async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
939
+ """Get the default logo for the UI."""
896
940
  theme_value = theme.value if theme else Theme.light.value
897
941
  logo_path = None
898
942
 
@@ -908,32 +952,42 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
908
952
 
909
953
  if not logo_path:
910
954
  raise HTTPException(status_code=404, detail="Missing default logo")
955
+
911
956
  media_type, _ = mimetypes.guess_type(logo_path)
912
957
 
913
958
  return FileResponse(logo_path, media_type=media_type)
914
959
 
915
960
 
916
- @router.get("/avatars/{avatar_id}")
961
+ @router.get("/avatars/{avatar_id:str}")
917
962
  async def get_avatar(avatar_id: str):
963
+ """Get the avatar for the user based on the avatar_id."""
964
+ if not re.match(r"^[a-zA-Z0-9_-]+$", avatar_id):
965
+ raise HTTPException(status_code=400, detail="Invalid avatar_id")
966
+
918
967
  if avatar_id == "default":
919
968
  avatar_id = config.ui.name
920
969
 
921
970
  avatar_id = avatar_id.strip().lower().replace(" ", "_")
922
971
 
923
- avatar_path = os.path.join(APP_ROOT, "public", "avatars", f"{avatar_id}.*")
972
+ base_path = Path(APP_ROOT) / "public" / "avatars"
973
+ avatar_pattern = f"{avatar_id}.*"
924
974
 
925
- files = glob.glob(avatar_path)
975
+ matching_files = base_path.glob(avatar_pattern)
976
+
977
+ if avatar_path := next(matching_files, None):
978
+ if not is_path_inside(avatar_path, base_path):
979
+ raise HTTPException(status_code=400, detail="Invalid filename")
980
+
981
+ media_type, _ = mimetypes.guess_type(str(avatar_path))
926
982
 
927
- if files:
928
- avatar_path = files[0]
929
- media_type, _ = mimetypes.guess_type(avatar_path)
930
983
  return FileResponse(avatar_path, media_type=media_type)
931
- else:
932
- return await get_favicon()
984
+
985
+ return await get_favicon()
933
986
 
934
987
 
935
988
  @router.head("/")
936
989
  def status_check():
990
+ """Check if the site is operational."""
937
991
  return {"message": "Site is operational"}
938
992
 
939
993
 
chainlit/session.py CHANGED
@@ -193,6 +193,9 @@ class HTTPSession(BaseSession):
193
193
  shutil.rmtree(self.files_dir)
194
194
 
195
195
 
196
+ ThreadQueue = Deque[tuple[Callable, object, tuple, Dict]]
197
+
198
+
196
199
  class WebsocketSession(BaseSession):
197
200
  """Internal web socket session object.
198
201
 
@@ -250,7 +253,7 @@ class WebsocketSession(BaseSession):
250
253
 
251
254
  self.restored = False
252
255
 
253
- self.thread_queues = {} # type: Dict[str, Deque[Callable]]
256
+ self.thread_queues: Dict[str, ThreadQueue] = {}
254
257
 
255
258
  ws_sessions_id[self.id] = self
256
259
  ws_sessions_sid[socket_id] = self