chainlit 1.1.404__py3-none-any.whl → 1.2.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; see the registry's advisory page for more details.

Files changed (63)
  1. chainlit/__init__.py +63 -305
  2. chainlit/_utils.py +8 -0
  3. chainlit/assistant.py +16 -0
  4. chainlit/assistant_settings.py +35 -0
  5. chainlit/callbacks.py +340 -0
  6. chainlit/cli/__init__.py +1 -1
  7. chainlit/config.py +58 -28
  8. chainlit/copilot/dist/index.js +512 -631
  9. chainlit/data/__init__.py +6 -521
  10. chainlit/data/base.py +121 -0
  11. chainlit/data/dynamodb.py +5 -8
  12. chainlit/data/literalai.py +395 -0
  13. chainlit/data/sql_alchemy.py +11 -9
  14. chainlit/data/storage_clients.py +69 -15
  15. chainlit/data/utils.py +29 -0
  16. chainlit/element.py +1 -1
  17. chainlit/emitter.py +7 -0
  18. chainlit/frontend/dist/assets/{DailyMotion-e665b444.js → DailyMotion-aa368b7e.js} +1 -1
  19. chainlit/frontend/dist/assets/{Facebook-5207db92.js → Facebook-0335db46.js} +1 -1
  20. chainlit/frontend/dist/assets/{FilePlayer-86937d6e.js → FilePlayer-8d04256c.js} +1 -1
  21. chainlit/frontend/dist/assets/{Kaltura-c96622c1.js → Kaltura-67c9dd31.js} +1 -1
  22. chainlit/frontend/dist/assets/{Mixcloud-57ae3e32.js → Mixcloud-6bbaccf5.js} +1 -1
  23. chainlit/frontend/dist/assets/{Mux-20373920.js → Mux-c2bcb757.js} +1 -1
  24. chainlit/frontend/dist/assets/{Preview-c68c0613.js → Preview-210f3955.js} +1 -1
  25. chainlit/frontend/dist/assets/{SoundCloud-8a9e3eae.js → SoundCloud-a0276b84.js} +1 -1
  26. chainlit/frontend/dist/assets/{Streamable-1ed099af.js → Streamable-a007323d.js} +1 -1
  27. chainlit/frontend/dist/assets/{Twitch-6820039f.js → Twitch-e6a88aa3.js} +1 -1
  28. chainlit/frontend/dist/assets/{Vidyard-d39ab91d.js → Vidyard-dfb88a35.js} +1 -1
  29. chainlit/frontend/dist/assets/{Vimeo-017cd9a7.js → Vimeo-3baa13d9.js} +1 -1
  30. chainlit/frontend/dist/assets/{Wistia-a509d9f2.js → Wistia-e52f7bef.js} +1 -1
  31. chainlit/frontend/dist/assets/{YouTube-42dfd82f.js → YouTube-1715f22b.js} +1 -1
  32. chainlit/frontend/dist/assets/index-bfdd8585.js +729 -0
  33. chainlit/frontend/dist/assets/react-plotly-55648373.js +3484 -0
  34. chainlit/frontend/dist/index.html +1 -1
  35. chainlit/input_widget.py +22 -0
  36. chainlit/langchain/callbacks.py +6 -1
  37. chainlit/llama_index/callbacks.py +20 -4
  38. chainlit/markdown.py +15 -9
  39. chainlit/message.py +0 -1
  40. chainlit/server.py +113 -37
  41. chainlit/session.py +27 -4
  42. chainlit/socket.py +50 -1
  43. chainlit/translations/bn.json +231 -0
  44. chainlit/translations/en-US.json +6 -0
  45. chainlit/translations/fr-FR.json +236 -0
  46. chainlit/translations/gu.json +231 -0
  47. chainlit/translations/he-IL.json +231 -0
  48. chainlit/translations/hi.json +231 -0
  49. chainlit/translations/kn.json +231 -0
  50. chainlit/translations/ml.json +231 -0
  51. chainlit/translations/mr.json +231 -0
  52. chainlit/translations/ta.json +231 -0
  53. chainlit/translations/te.json +231 -0
  54. chainlit/types.py +1 -1
  55. chainlit/user_session.py +4 -0
  56. chainlit/utils.py +1 -1
  57. {chainlit-1.1.404.dist-info → chainlit-1.2.0rc0.dist-info}/METADATA +10 -10
  58. chainlit-1.2.0rc0.dist-info/RECORD +99 -0
  59. chainlit/frontend/dist/assets/index-30df9b2b.js +0 -730
  60. chainlit/frontend/dist/assets/react-plotly-5bb34118.js +0 -3602
  61. chainlit-1.1.404.dist-info/RECORD +0 -82
  62. {chainlit-1.1.404.dist-info → chainlit-1.2.0rc0.dist-info}/WHEEL +0 -0
  63. {chainlit-1.1.404.dist-info → chainlit-1.2.0rc0.dist-info}/entry_points.txt +0 -0
chainlit/frontend/dist/index.html CHANGED
@@ -21,7 +21,7 @@
21
21
  <script>
22
22
  const global = globalThis;
23
23
  </script>
24
- <script type="module" crossorigin src="/assets/index-30df9b2b.js"></script>
24
+ <script type="module" crossorigin src="/assets/index-bfdd8585.js"></script>
25
25
  <link rel="stylesheet" href="/assets/index-aaf974a9.css">
26
26
  </head>
27
27
  <body>
chainlit/input_widget.py CHANGED
@@ -161,6 +161,28 @@ class NumberInput(InputWidget):
161
161
  "description": self.description,
162
162
  }
163
163
 
164
+ @dataclass
165
+ class FileUploadInput(InputWidget):
166
+ """Useful to create a file upload input."""
167
+
168
+ type: InputWidgetType = "fileupload"
169
+ initial: Optional[str] = None
170
+ placeholder: Optional[str] = None
171
+ accept: List[str] = Field(default_factory=lambda: [])
172
+ max_size_mb: Optional[int] = None
173
+ max_files: Optional[int] = None
174
+
175
+ def to_dict(self) -> Dict[str, Any]:
176
+ return {
177
+ "type": self.type,
178
+ "id": self.id,
179
+ "label": self.label,
180
+ "initial": self.initial,
181
+ "placeholder": self.placeholder,
182
+ "tooltip": self.tooltip,
183
+ "description": self.description,
184
+ }
185
+
164
186
 
165
187
  @dataclass
166
188
  class Tags(InputWidget):
chainlit/langchain/callbacks.py CHANGED
@@ -587,12 +587,17 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
587
587
  outputs = run.outputs or {}
588
588
  output_keys = list(outputs.keys())
589
589
  output = outputs
590
+
590
591
  if output_keys:
591
592
  output = outputs.get(output_keys[0], outputs)
592
593
 
593
594
  if current_step:
594
595
  current_step.output = (
595
- output[0] if isinstance(output, Sequence) and len(output) else output
596
+ output[0]
597
+ if isinstance(output, Sequence)
598
+ and not isinstance(output, str)
599
+ and len(output)
600
+ else output
596
601
  )
597
602
  current_step.end = utc_now()
598
603
  self._run_sync(current_step.update())
chainlit/llama_index/callbacks.py CHANGED
@@ -8,6 +8,7 @@ from literalai.helper import utc_now
8
8
  from llama_index.core.callbacks import TokenCountingHandler
9
9
  from llama_index.core.callbacks.schema import CBEventType, EventPayload
10
10
  from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
11
+ from llama_index.core.tools.types import ToolMetadata
11
12
 
12
13
  DEFAULT_IGNORE = [
13
14
  CBEventType.CHUNKING,
@@ -54,7 +55,16 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
54
55
  ) -> str:
55
56
  """Run when an event starts and return id of event."""
56
57
  step_type: StepType = "undefined"
57
- if event_type == CBEventType.RETRIEVE:
58
+ step_name: str = event_type.value
59
+ step_input: Optional[Dict[str, Any]] = payload
60
+ if event_type == CBEventType.FUNCTION_CALL:
61
+ step_type = "tool"
62
+ if payload:
63
+ metadata: Optional[ToolMetadata] = payload.get(EventPayload.TOOL)
64
+ if metadata:
65
+ step_name = getattr(metadata, "name", step_name)
66
+ step_input = payload.get(EventPayload.FUNCTION_CALL)
67
+ elif event_type == CBEventType.RETRIEVE:
58
68
  step_type = "tool"
59
69
  elif event_type == CBEventType.QUERY:
60
70
  step_type = "tool"
@@ -64,7 +74,7 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
64
74
  return event_id
65
75
 
66
76
  step = Step(
67
- name=event_type.value,
77
+ name=step_name,
68
78
  type=step_type,
69
79
  parent_id=self._get_parent_id(parent_id),
70
80
  id=event_id,
@@ -72,7 +82,7 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
72
82
 
73
83
  self.steps[event_id] = step
74
84
  step.start = utc_now()
75
- step.input = payload or {}
85
+ step.input = step_input or {}
76
86
  context_var.get().loop.create_task(step.send())
77
87
  return event_id
78
88
 
@@ -91,7 +101,13 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
91
101
 
92
102
  step.end = utc_now()
93
103
 
94
- if event_type == CBEventType.QUERY:
104
+ if event_type == CBEventType.FUNCTION_CALL:
105
+ response = payload.get(EventPayload.FUNCTION_OUTPUT)
106
+ if response:
107
+ step.output = f"{response}"
108
+ context_var.get().loop.create_task(step.update())
109
+
110
+ elif event_type == CBEventType.QUERY:
95
111
  response = payload.get(EventPayload.RESPONSE)
96
112
  source_nodes = getattr(response, "source_nodes", None)
97
113
  if source_nodes:
chainlit/markdown.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import os
2
+ from pathlib import Path
3
+ from typing import Optional
2
4
 
3
5
  from chainlit.logger import logger
4
6
 
7
+ from ._utils import is_path_inside
8
+
5
9
  # Default chainlit.md file created if none exists
6
10
  DEFAULT_MARKDOWN_STR = """# Welcome to Chainlit! 🚀🤖
7
11
 
@@ -30,12 +34,16 @@ def init_markdown(root: str):
30
34
  logger.info(f"Created default chainlit markdown file at {chainlit_md_file}")
31
35
 
32
36
 
33
- def get_markdown_str(root: str, language: str):
37
+ def get_markdown_str(root: str, language: str) -> Optional[str]:
34
38
  """Get the chainlit.md file as a string."""
35
- translated_chainlit_md_path = os.path.join(root, f"chainlit_{language}.md")
36
- default_chainlit_md_path = os.path.join(root, "chainlit.md")
37
-
38
- if os.path.exists(translated_chainlit_md_path):
39
+ root_path = Path(root)
40
+ translated_chainlit_md_path = root_path / f"chainlit_{language}.md"
41
+ default_chainlit_md_path = root_path / "chainlit.md"
42
+
43
+ if (
44
+ is_path_inside(translated_chainlit_md_path, root_path)
45
+ and translated_chainlit_md_path.is_file()
46
+ ):
39
47
  chainlit_md_path = translated_chainlit_md_path
40
48
  else:
41
49
  chainlit_md_path = default_chainlit_md_path
@@ -43,9 +51,7 @@ def get_markdown_str(root: str, language: str):
43
51
  f"Translated markdown file for {language} not found. Defaulting to chainlit.md."
44
52
  )
45
53
 
46
- if os.path.exists(chainlit_md_path):
47
- with open(chainlit_md_path, "r", encoding="utf-8") as f:
48
- chainlit_md = f.read()
49
- return chainlit_md
54
+ if chainlit_md_path.is_file():
55
+ return chainlit_md_path.read_text(encoding="utf-8")
50
56
  else:
51
57
  return None
chainlit/message.py CHANGED
@@ -82,7 +82,6 @@ class MessageBase(ABC):
82
82
  "output": self.content,
83
83
  "name": self.author,
84
84
  "type": self.type,
85
- "createdAt": self.created_at,
86
85
  "language": self.language,
87
86
  "streaming": self.streaming,
88
87
  "isError": self.is_error,
chainlit/server.py CHANGED
@@ -1,22 +1,15 @@
1
+ import asyncio
1
2
  import glob
2
3
  import json
3
4
  import mimetypes
5
+ import os
4
6
  import re
5
7
  import shutil
6
8
  import urllib.parse
7
- from typing import Any, Optional, Union
8
-
9
- from chainlit.oauth_providers import get_oauth_provider
10
- from chainlit.secret import random_secret
11
-
12
- mimetypes.add_type("application/javascript", ".js")
13
- mimetypes.add_type("text/css", ".css")
14
-
15
- import asyncio
16
- import os
17
9
  import webbrowser
18
10
  from contextlib import asynccontextmanager
19
11
  from pathlib import Path
12
+ from typing import Any, Optional, Union
20
13
 
21
14
  import socketio
22
15
  from chainlit.auth import create_jwt, get_configuration, get_current_user
@@ -34,6 +27,8 @@ from chainlit.data import get_data_layer
34
27
  from chainlit.data.acl import is_thread_author
35
28
  from chainlit.logger import logger
36
29
  from chainlit.markdown import get_markdown_str
30
+ from chainlit.oauth_providers import get_oauth_provider
31
+ from chainlit.secret import random_secret
37
32
  from chainlit.types import (
38
33
  DeleteFeedbackRequest,
39
34
  DeleteThreadRequest,
@@ -62,12 +57,20 @@ from starlette.middleware.cors import CORSMiddleware
62
57
  from typing_extensions import Annotated
63
58
  from watchfiles import awatch
64
59
 
60
+ from ._utils import is_path_inside
61
+
62
+ mimetypes.add_type("application/javascript", ".js")
63
+ mimetypes.add_type("text/css", ".css")
64
+
65
65
  ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
66
66
  IS_SUBMOUNT = os.environ.get("CHAINLIT_SUBMOUNT", "") == "true"
67
+ # If the app is a submount, no need to set the prefix
68
+ PREFIX = ROOT_PATH if ROOT_PATH and not IS_SUBMOUNT else ""
67
69
 
68
70
 
69
71
  @asynccontextmanager
70
72
  async def lifespan(app: FastAPI):
73
+ """Context manager to handle app start and shutdown."""
71
74
  host = config.run.host
72
75
  port = config.run.port
73
76
 
@@ -150,7 +153,18 @@ async def lifespan(app: FastAPI):
150
153
  os._exit(0)
151
154
 
152
155
 
153
- def get_build_dir(local_target: str, packaged_target: str):
156
+ def get_build_dir(local_target: str, packaged_target: str) -> str:
157
+ """
158
+ Get the build directory based on the UI build strategy.
159
+
160
+ Args:
161
+ local_target (str): The local target directory.
162
+ packaged_target (str): The packaged target directory.
163
+
164
+ Returns:
165
+ str: The build directory
166
+ """
167
+
154
168
  local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
155
169
  packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")
156
170
 
@@ -171,18 +185,14 @@ copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
171
185
 
172
186
  app = FastAPI(lifespan=lifespan)
173
187
 
174
- sio = socketio.AsyncServer(
175
- cors_allowed_origins=[], async_mode="asgi"
176
- )
177
-
178
- sio_mount_location = f"{ROOT_PATH}/ws" if ROOT_PATH else "ws"
188
+ sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
179
189
 
180
190
  asgi_app = socketio.ASGIApp(
181
191
  socketio_server=sio,
182
- socketio_path=f"{sio_mount_location}/socket.io",
192
+ socketio_path="",
183
193
  )
184
194
 
185
- app.mount(f"/{sio_mount_location}", asgi_app)
195
+ app.mount(f"{PREFIX}/ws/socket.io", asgi_app)
186
196
 
187
197
  app.add_middleware(
188
198
  CORSMiddleware,
@@ -192,16 +202,16 @@ app.add_middleware(
192
202
  allow_headers=["*"],
193
203
  )
194
204
 
195
- router = APIRouter(prefix=ROOT_PATH)
205
+ router = APIRouter(prefix=PREFIX)
196
206
 
197
207
  app.mount(
198
- f"{ROOT_PATH}/public",
208
+ f"{PREFIX}/public",
199
209
  StaticFiles(directory="public", check_dir=False),
200
210
  name="public",
201
211
  )
202
212
 
203
213
  app.mount(
204
- f"{ROOT_PATH}/assets",
214
+ f"{PREFIX}/assets",
205
215
  StaticFiles(
206
216
  packages=[("chainlit", os.path.join(build_dir, "assets"))],
207
217
  follow_symlink=config.project.follow_symlink,
@@ -210,7 +220,7 @@ app.mount(
210
220
  )
211
221
 
212
222
  app.mount(
213
- f"{ROOT_PATH}/copilot",
223
+ f"{PREFIX}/copilot",
214
224
  StaticFiles(
215
225
  packages=[("chainlit", copilot_build_dir)],
216
226
  follow_symlink=config.project.follow_symlink,
@@ -218,7 +228,6 @@ app.mount(
218
228
  name="copilot",
219
229
  )
220
230
 
221
-
222
231
  # -------------------------------------------------------------------------------
223
232
  # SLACK HANDLER
224
233
  # -------------------------------------------------------------------------------
@@ -253,12 +262,19 @@ if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
253
262
  # -------------------------------------------------------------------------------
254
263
 
255
264
 
256
- def replace_between_tags(text: str, start_tag: str, end_tag: str, replacement: str):
265
+ def replace_between_tags(
266
+ text: str, start_tag: str, end_tag: str, replacement: str
267
+ ) -> str:
268
+ """Replace text between two tags in a string."""
269
+
257
270
  pattern = start_tag + ".*?" + end_tag
258
271
  return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)
259
272
 
260
273
 
261
274
  def get_html_template():
275
+ """
276
+ Get HTML template for the index view.
277
+ """
262
278
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
263
279
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
264
280
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
@@ -345,6 +361,9 @@ async def auth(request: Request):
345
361
 
346
362
  @router.post("/login")
347
363
  async def login(form_data: OAuth2PasswordRequestForm = Depends()):
364
+ """
365
+ Login a user using the password auth callback.
366
+ """
348
367
  if not config.code.password_auth_callback:
349
368
  raise HTTPException(
350
369
  status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
@@ -374,6 +393,7 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
374
393
 
375
394
  @router.post("/logout")
376
395
  async def logout(request: Request, response: Response):
396
+ """Logout the user by calling the on_logout callback."""
377
397
  if config.code.on_logout:
378
398
  return await config.code.on_logout(request, response)
379
399
  return {"success": True}
@@ -381,6 +401,7 @@ async def logout(request: Request, response: Response):
381
401
 
382
402
  @router.post("/auth/header")
383
403
  async def header_auth(request: Request):
404
+ """Login a user using the header_auth_callback."""
384
405
  if not config.code.header_auth_callback:
385
406
  raise HTTPException(
386
407
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -410,6 +431,7 @@ async def header_auth(request: Request):
410
431
 
411
432
  @router.get("/auth/oauth/{provider_id}")
412
433
  async def oauth_login(provider_id: str, request: Request):
434
+ """Redirect the user to the oauth provider login page."""
413
435
  if config.code.oauth_callback is None:
414
436
  raise HTTPException(
415
437
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -436,7 +458,7 @@ async def oauth_login(provider_id: str, request: Request):
436
458
  response = RedirectResponse(
437
459
  url=f"{provider.authorize_url}?{params}",
438
460
  )
439
- samesite = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax") # type: Any
461
+ samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
440
462
  secure = samesite.lower() == "none"
441
463
  response.set_cookie(
442
464
  "oauth_state",
@@ -457,6 +479,8 @@ async def oauth_callback(
457
479
  code: Optional[str] = None,
458
480
  state: Optional[str] = None,
459
481
  ):
482
+ """Handle the oauth callback and login the user."""
483
+
460
484
  if config.code.oauth_callback is None:
461
485
  raise HTTPException(
462
486
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -544,6 +568,8 @@ async def oauth_azure_hf_callback(
544
568
  code: Annotated[Optional[str], Form()] = None,
545
569
  id_token: Annotated[Optional[str], Form()] = None,
546
570
  ):
571
+ """Handle the azure ad hybrid flow callback and login the user."""
572
+
547
573
  provider_id = "azure-ad-hybrid"
548
574
  if config.code.oauth_callback is None:
549
575
  raise HTTPException(
@@ -617,9 +643,16 @@ async def oauth_azure_hf_callback(
617
643
  return response
618
644
 
619
645
 
646
+ _language_pattern = (
647
+ "^[a-zA-Z]{2,3}(-[a-zA-Z]{2,3})?(-[a-zA-Z]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
648
+ )
649
+
650
+
620
651
  @router.get("/project/translations")
621
652
  async def project_translations(
622
- language: str = Query(default="en-US", description="Language code"),
653
+ language: str = Query(
654
+ default="en-US", description="Language code", pattern=_language_pattern
655
+ ),
623
656
  ):
624
657
  """Return project translations."""
625
658
 
@@ -636,11 +669,14 @@ async def project_translations(
636
669
  @router.get("/project/settings")
637
670
  async def project_settings(
638
671
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
639
- language: str = Query(default="en-US", description="Language code"),
672
+ language: str = Query(
673
+ default="en-US", description="Language code", pattern=_language_pattern
674
+ ),
640
675
  ):
641
676
  """Return project settings. This is called by the UI before the establishing the websocket connection."""
642
677
 
643
678
  # Load the markdown file based on the provided language
679
+
644
680
  markdown = get_markdown_str(config.root, language)
645
681
 
646
682
  profiles = []
@@ -808,6 +844,8 @@ async def upload_file(
808
844
  Union[None, User, PersistedUser], Depends(get_current_user)
809
845
  ],
810
846
  ):
847
+ """Upload a file to the session files directory."""
848
+
811
849
  from chainlit.session import WebsocketSession
812
850
 
813
851
  session = WebsocketSession.get_by_id(session_id)
@@ -841,6 +879,8 @@ async def get_file(
841
879
  file_id: str,
842
880
  session_id: Optional[str] = None,
843
881
  ):
882
+ """Get a file from the session files directory."""
883
+
844
884
  from chainlit.session import WebsocketSession
845
885
 
846
886
  session = WebsocketSession.get_by_id(session_id) if session_id else None
@@ -863,11 +903,12 @@ async def serve_file(
863
903
  filename: str,
864
904
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
865
905
  ):
906
+ """Serve a file from the local filesystem."""
907
+
866
908
  base_path = Path(config.project.local_fs_path).resolve()
867
909
  file_path = (base_path / filename).resolve()
868
910
 
869
- # Check if the base path is a parent of the file path
870
- if base_path not in file_path.parents:
911
+ if not is_path_inside(file_path, base_path):
871
912
  raise HTTPException(status_code=400, detail="Invalid filename")
872
913
 
873
914
  if file_path.is_file():
@@ -878,6 +919,7 @@ async def serve_file(
878
919
 
879
920
  @router.get("/favicon")
880
921
  async def get_favicon():
922
+ """Get the favicon for the UI."""
881
923
  custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
882
924
  files = glob.glob(custom_favicon_path)
883
925
 
@@ -893,6 +935,7 @@ async def get_favicon():
893
935
 
894
936
  @router.get("/logo")
895
937
  async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
938
+ """Get the default logo for the UI."""
896
939
  theme_value = theme.value if theme else Theme.light.value
897
940
  logo_path = None
898
941
 
@@ -908,32 +951,65 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
908
951
 
909
952
  if not logo_path:
910
953
  raise HTTPException(status_code=404, detail="Missing default logo")
954
+
911
955
  media_type, _ = mimetypes.guess_type(logo_path)
912
956
 
913
957
  return FileResponse(logo_path, media_type=media_type)
914
958
 
915
959
 
916
- @router.get("/avatars/{avatar_id}")
960
+ @router.get("/avatars/{avatar_id:str}")
917
961
  async def get_avatar(avatar_id: str):
962
+ """Get the avatar for the user based on the avatar_id."""
963
+ if not re.match(r"^[a-zA-Z0-9_-]+$", avatar_id):
964
+ raise HTTPException(status_code=400, detail="Invalid avatar_id")
965
+
918
966
  if avatar_id == "default":
919
967
  avatar_id = config.ui.name
920
968
 
921
969
  avatar_id = avatar_id.strip().lower().replace(" ", "_")
922
970
 
923
- avatar_path = os.path.join(APP_ROOT, "public", "avatars", f"{avatar_id}.*")
971
+ base_path = Path(APP_ROOT) / "public" / "avatars"
972
+ avatar_pattern = f"{avatar_id}.*"
924
973
 
925
- files = glob.glob(avatar_path)
974
+ matching_files = base_path.glob(avatar_pattern)
975
+
976
+ if avatar_path := next(matching_files, None):
977
+ if not is_path_inside(avatar_path, base_path):
978
+ raise HTTPException(status_code=400, detail="Invalid filename")
979
+
980
+ media_type, _ = mimetypes.guess_type(str(avatar_path))
926
981
 
927
- if files:
928
- avatar_path = files[0]
929
- media_type, _ = mimetypes.guess_type(avatar_path)
930
982
  return FileResponse(avatar_path, media_type=media_type)
931
- else:
932
- return await get_favicon()
983
+
984
+ return await get_favicon()
985
+
986
+
987
+ # post avatar/{avatar_id} (only for authenticated users)
988
+ @router.post("/avatars/{avatar_id}")
989
+ async def upload_avatar(
990
+ avatar_id: str,
991
+ file: UploadFile,
992
+ current_user: Annotated[
993
+ Union[None, User, PersistedUser], Depends(get_current_user)
994
+ ],
995
+ ):
996
+ try:
997
+ avatar_path = os.path.join(APP_ROOT, "public", "avatars", avatar_id)
998
+
999
+ # Ensure the avatars directory exists
1000
+ os.makedirs(os.path.dirname(avatar_path), exist_ok=True)
1001
+
1002
+ with open(avatar_path, "wb") as f:
1003
+ f.write(await file.read())
1004
+ except Exception as e:
1005
+ raise HTTPException(status_code=500, detail=str(e))
1006
+
1007
+ return {"id": avatar_id}
933
1008
 
934
1009
 
935
1010
  @router.head("/")
936
1011
  def status_check():
1012
+ """Check if the site is operational."""
937
1013
  return {"message": "Site is operational"}
938
1014
 
939
1015
 
chainlit/session.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
16
16
  )
17
17
 
18
18
  import aiofiles
19
+ from chainlit.assistant import Assistant
19
20
  from chainlit.logger import logger
20
21
 
21
22
  if TYPE_CHECKING:
@@ -64,7 +65,7 @@ class BaseSession:
64
65
  client_type: ClientType,
65
66
  # Thread id
66
67
  thread_id: Optional[str],
67
- # Logged-in user informations
68
+ # Logged-in user information
68
69
  user: Optional[Union["User", "PersistedUser"]],
69
70
  # Logged-in user token
70
71
  token: Optional[str],
@@ -72,8 +73,12 @@ class BaseSession:
72
73
  user_env: Optional[Dict[str, str]],
73
74
  # Chat profile selected before the session was created
74
75
  chat_profile: Optional[str] = None,
76
+ # Selected assistant
77
+ selected_assistant: Optional[Assistant] = None,
75
78
  # Origin of the request
76
79
  http_referer: Optional[str] = None,
80
+ # assistant settings
81
+ assistant_settings: Optional[Dict[str, Any]] = None,
77
82
  ):
78
83
  if thread_id:
79
84
  self.thread_id_to_resume = thread_id
@@ -90,7 +95,9 @@ class BaseSession:
90
95
 
91
96
  self.id = id
92
97
 
98
+ self.assistant_settings = assistant_settings
93
99
  self.chat_settings: Dict[str, Any] = {}
100
+ self.selected_assistant = selected_assistant
94
101
 
95
102
  @property
96
103
  def files_dir(self):
@@ -153,6 +160,7 @@ class BaseSession:
153
160
  user_session = user_sessions.get(self.id) or {} # type: Dict
154
161
  user_session["chat_settings"] = self.chat_settings
155
162
  user_session["chat_profile"] = self.chat_profile
163
+ user_session["selected_assistant"] = self.selected_assistant
156
164
  user_session["http_referer"] = self.http_referer
157
165
  user_session["client_type"] = self.client_type
158
166
  metadata = clean_metadata(user_session)
@@ -169,13 +177,17 @@ class HTTPSession(BaseSession):
169
177
  client_type: ClientType,
170
178
  # Thread id
171
179
  thread_id: Optional[str] = None,
172
- # Logged-in user informations
180
+ # Logged-in user information
173
181
  user: Optional[Union["User", "PersistedUser"]] = None,
174
182
  # Logged-in user token
175
183
  token: Optional[str] = None,
176
184
  user_env: Optional[Dict[str, str]] = None,
177
185
  # Origin of the request
178
186
  http_referer: Optional[str] = None,
187
+ # assistant settings
188
+ assistant_settings: Optional[Dict[str, Any]] = None,
189
+ # selected assistant
190
+ selected_assistant: Optional[Assistant] = None,
179
191
  ):
180
192
  super().__init__(
181
193
  id=id,
@@ -185,6 +197,8 @@ class HTTPSession(BaseSession):
185
197
  client_type=client_type,
186
198
  user_env=user_env,
187
199
  http_referer=http_referer,
200
+ assistant_settings=assistant_settings,
201
+ selected_assistant=selected_assistant,
188
202
  )
189
203
 
190
204
  def delete(self):
@@ -193,6 +207,9 @@ class HTTPSession(BaseSession):
193
207
  shutil.rmtree(self.files_dir)
194
208
 
195
209
 
210
+ ThreadQueue = Deque[tuple[Callable, object, tuple, Dict]]
211
+
212
+
196
213
  class WebsocketSession(BaseSession):
197
214
  """Internal web socket session object.
198
215
 
@@ -222,16 +239,20 @@ class WebsocketSession(BaseSession):
222
239
  client_type: ClientType,
223
240
  # Thread id
224
241
  thread_id: Optional[str] = None,
225
- # Logged-in user informations
242
+ # Logged-in user information
226
243
  user: Optional[Union["User", "PersistedUser"]] = None,
227
244
  # Logged-in user token
228
245
  token: Optional[str] = None,
229
246
  # Chat profile selected before the session was created
230
247
  chat_profile: Optional[str] = None,
248
+ # Selected assistant
249
+ selected_assistant: Optional[Assistant] = None,
231
250
  # Languages of the user's browser
232
251
  languages: Optional[str] = None,
233
252
  # Origin of the request
234
253
  http_referer: Optional[str] = None,
254
+ # chat settings
255
+ assistant_settings: Optional[Dict[str, Any]] = None,
235
256
  ):
236
257
  super().__init__(
237
258
  id=id,
@@ -241,7 +262,9 @@ class WebsocketSession(BaseSession):
241
262
  user_env=user_env,
242
263
  client_type=client_type,
243
264
  chat_profile=chat_profile,
265
+ selected_assistant=selected_assistant,
244
266
  http_referer=http_referer,
267
+ assistant_settings=assistant_settings,
245
268
  )
246
269
 
247
270
  self.socket_id = socket_id
@@ -250,7 +273,7 @@ class WebsocketSession(BaseSession):
250
273
 
251
274
  self.restored = False
252
275
 
253
- self.thread_queues = {} # type: Dict[str, Deque[Callable]]
276
+ self.thread_queues: Dict[str, ThreadQueue] = {}
254
277
 
255
278
  ws_sessions_id[self.id] = self
256
279
  ws_sessions_sid[socket_id] = self