chainlit 1.0.401__py3-none-any.whl → 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (113)
  1. chainlit/__init__.py +98 -279
  2. chainlit/_utils.py +8 -0
  3. chainlit/action.py +12 -10
  4. chainlit/{auth.py → auth/__init__.py} +28 -36
  5. chainlit/auth/cookie.py +123 -0
  6. chainlit/auth/jwt.py +39 -0
  7. chainlit/cache.py +4 -6
  8. chainlit/callbacks.py +362 -0
  9. chainlit/chat_context.py +64 -0
  10. chainlit/chat_settings.py +3 -1
  11. chainlit/cli/__init__.py +77 -8
  12. chainlit/config.py +191 -102
  13. chainlit/context.py +42 -13
  14. chainlit/copilot/dist/index.js +8750 -903
  15. chainlit/data/__init__.py +101 -416
  16. chainlit/data/acl.py +6 -2
  17. chainlit/data/base.py +107 -0
  18. chainlit/data/chainlit_data_layer.py +614 -0
  19. chainlit/data/dynamodb.py +590 -0
  20. chainlit/data/literalai.py +500 -0
  21. chainlit/data/sql_alchemy.py +721 -0
  22. chainlit/data/storage_clients/__init__.py +0 -0
  23. chainlit/data/storage_clients/azure.py +81 -0
  24. chainlit/data/storage_clients/azure_blob.py +89 -0
  25. chainlit/data/storage_clients/base.py +26 -0
  26. chainlit/data/storage_clients/gcs.py +88 -0
  27. chainlit/data/storage_clients/s3.py +75 -0
  28. chainlit/data/utils.py +29 -0
  29. chainlit/discord/__init__.py +6 -0
  30. chainlit/discord/app.py +354 -0
  31. chainlit/element.py +91 -33
  32. chainlit/emitter.py +81 -29
  33. chainlit/frontend/dist/assets/DailyMotion-Ce9dQoqZ.js +1 -0
  34. chainlit/frontend/dist/assets/Dataframe-C1XonMcV.js +22 -0
  35. chainlit/frontend/dist/assets/Facebook-DVVt6lrr.js +1 -0
  36. chainlit/frontend/dist/assets/FilePlayer-c7stW4vz.js +1 -0
  37. chainlit/frontend/dist/assets/Kaltura-BmMmgorA.js +1 -0
  38. chainlit/frontend/dist/assets/Mixcloud-Cw8hDmiO.js +1 -0
  39. chainlit/frontend/dist/assets/Mux-DiRZfeUf.js +1 -0
  40. chainlit/frontend/dist/assets/Preview-6Jt2mRHx.js +1 -0
  41. chainlit/frontend/dist/assets/SoundCloud-DKwcT58_.js +1 -0
  42. chainlit/frontend/dist/assets/Streamable-BVdxrEeX.js +1 -0
  43. chainlit/frontend/dist/assets/Twitch-DFqZR7Gu.js +1 -0
  44. chainlit/frontend/dist/assets/Vidyard-0BQAAtVk.js +1 -0
  45. chainlit/frontend/dist/assets/Vimeo-CRFSH0Vu.js +1 -0
  46. chainlit/frontend/dist/assets/Wistia-CKrmdQaG.js +1 -0
  47. chainlit/frontend/dist/assets/YouTube-CQpL-rvU.js +1 -0
  48. chainlit/frontend/dist/assets/index-DQmLRKyv.css +1 -0
  49. chainlit/frontend/dist/assets/index-QdmxtIMQ.js +8665 -0
  50. chainlit/frontend/dist/assets/react-plotly-B9hvVpUG.js +3484 -0
  51. chainlit/frontend/dist/index.html +2 -4
  52. chainlit/haystack/callbacks.py +4 -7
  53. chainlit/input_widget.py +8 -4
  54. chainlit/langchain/callbacks.py +103 -68
  55. chainlit/langflow/__init__.py +1 -0
  56. chainlit/llama_index/callbacks.py +65 -40
  57. chainlit/markdown.py +22 -6
  58. chainlit/message.py +54 -56
  59. chainlit/mistralai/__init__.py +50 -0
  60. chainlit/oauth_providers.py +266 -8
  61. chainlit/openai/__init__.py +10 -18
  62. chainlit/secret.py +1 -1
  63. chainlit/server.py +789 -228
  64. chainlit/session.py +108 -90
  65. chainlit/slack/__init__.py +6 -0
  66. chainlit/slack/app.py +397 -0
  67. chainlit/socket.py +199 -116
  68. chainlit/step.py +141 -89
  69. chainlit/sync.py +2 -1
  70. chainlit/teams/__init__.py +6 -0
  71. chainlit/teams/app.py +338 -0
  72. chainlit/translations/bn.json +244 -0
  73. chainlit/translations/en-US.json +122 -8
  74. chainlit/translations/gu.json +244 -0
  75. chainlit/translations/he-IL.json +244 -0
  76. chainlit/translations/hi.json +244 -0
  77. chainlit/translations/ja.json +242 -0
  78. chainlit/translations/kn.json +244 -0
  79. chainlit/translations/ml.json +244 -0
  80. chainlit/translations/mr.json +244 -0
  81. chainlit/translations/nl-NL.json +242 -0
  82. chainlit/translations/ta.json +244 -0
  83. chainlit/translations/te.json +244 -0
  84. chainlit/translations/zh-CN.json +243 -0
  85. chainlit/translations.py +60 -0
  86. chainlit/types.py +133 -28
  87. chainlit/user.py +14 -3
  88. chainlit/user_session.py +6 -3
  89. chainlit/utils.py +52 -5
  90. chainlit/version.py +3 -2
  91. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/METADATA +48 -50
  92. chainlit-2.0.4.dist-info/RECORD +107 -0
  93. chainlit/cli/utils.py +0 -24
  94. chainlit/frontend/dist/assets/index-9711593e.js +0 -723
  95. chainlit/frontend/dist/assets/index-d088547c.css +0 -1
  96. chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
  97. chainlit/playground/__init__.py +0 -2
  98. chainlit/playground/config.py +0 -40
  99. chainlit/playground/provider.py +0 -108
  100. chainlit/playground/providers/__init__.py +0 -13
  101. chainlit/playground/providers/anthropic.py +0 -118
  102. chainlit/playground/providers/huggingface.py +0 -75
  103. chainlit/playground/providers/langchain.py +0 -89
  104. chainlit/playground/providers/openai.py +0 -408
  105. chainlit/playground/providers/vertexai.py +0 -171
  106. chainlit/translations/pt-BR.json +0 -155
  107. chainlit-1.0.401.dist-info/RECORD +0 -66
  108. /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  109. /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  110. /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  111. /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  112. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/WHEEL +0 -0
  113. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -1,24 +1,45 @@
1
+ import asyncio
2
+ import fnmatch
1
3
  import glob
2
4
  import json
3
5
  import mimetypes
6
+ import os
4
7
  import re
5
8
  import shutil
6
9
  import urllib.parse
7
- from typing import Any, Optional, Union
8
-
9
- from chainlit.oauth_providers import get_oauth_provider
10
- from chainlit.secret import random_secret
11
-
12
- mimetypes.add_type("application/javascript", ".js")
13
- mimetypes.add_type("text/css", ".css")
14
-
15
- import asyncio
16
- import os
17
10
  import webbrowser
18
11
  from contextlib import asynccontextmanager
19
12
  from pathlib import Path
13
+ from typing import List, Optional, Union, cast
14
+
15
+ import socketio
16
+ from fastapi import (
17
+ APIRouter,
18
+ Depends,
19
+ FastAPI,
20
+ Form,
21
+ HTTPException,
22
+ Query,
23
+ Request,
24
+ Response,
25
+ UploadFile,
26
+ status,
27
+ )
28
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
29
+ from fastapi.security import OAuth2PasswordRequestForm
30
+ from starlette.datastructures import URL
31
+ from starlette.middleware.cors import CORSMiddleware
32
+ from typing_extensions import Annotated
33
+ from watchfiles import awatch
20
34
 
21
- from chainlit.auth import create_jwt, get_configuration, get_current_user
35
+ from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
36
+ from chainlit.auth.cookie import (
37
+ clear_auth_cookie,
38
+ clear_oauth_state_cookie,
39
+ set_auth_cookie,
40
+ set_oauth_state_cookie,
41
+ validate_oauth_state_cookie,
42
+ )
22
43
  from chainlit.config import (
23
44
  APP_ROOT,
24
45
  BACKEND_ROOT,
@@ -27,51 +48,48 @@ from chainlit.config import (
27
48
  PACKAGE_ROOT,
28
49
  config,
29
50
  load_module,
51
+ public_dir,
30
52
  reload_config,
31
53
  )
32
54
  from chainlit.data import get_data_layer
33
55
  from chainlit.data.acl import is_thread_author
34
56
  from chainlit.logger import logger
35
57
  from chainlit.markdown import get_markdown_str
36
- from chainlit.playground.config import get_llm_providers
37
- from chainlit.telemetry import trace_event
58
+ from chainlit.oauth_providers import get_oauth_provider
59
+ from chainlit.secret import random_secret
38
60
  from chainlit.types import (
61
+ CallActionRequest,
62
+ DeleteFeedbackRequest,
39
63
  DeleteThreadRequest,
40
- GenerationRequest,
64
+ ElementRequest,
41
65
  GetThreadsRequest,
42
66
  Theme,
43
67
  UpdateFeedbackRequest,
68
+ UpdateThreadRequest,
44
69
  )
45
70
  from chainlit.user import PersistedUser, User
46
- from fastapi import (
47
- Depends,
48
- FastAPI,
49
- HTTPException,
50
- Query,
51
- Request,
52
- Response,
53
- UploadFile,
54
- status,
55
- )
56
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
57
- from fastapi.security import OAuth2PasswordRequestForm
58
- from fastapi.staticfiles import StaticFiles
59
- from fastapi_socketio import SocketManager
60
- from starlette.datastructures import URL
61
- from starlette.middleware.cors import CORSMiddleware
62
- from typing_extensions import Annotated
63
- from watchfiles import awatch
71
+
72
+ from ._utils import is_path_inside
73
+
74
+ mimetypes.add_type("application/javascript", ".js")
75
+ mimetypes.add_type("text/css", ".css")
76
+
77
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
78
+ IS_SUBMOUNT = os.environ.get("CHAINLIT_SUBMOUNT", "") == "true"
79
+ # If the app is a submount, no need to set the prefix
80
+ PREFIX = ROOT_PATH if ROOT_PATH and not IS_SUBMOUNT else ""
64
81
 
65
82
 
66
83
  @asynccontextmanager
67
84
  async def lifespan(app: FastAPI):
85
+ """Context manager to handle app start and shutdown."""
68
86
  host = config.run.host
69
87
  port = config.run.port
70
88
 
71
89
  if host == DEFAULT_HOST:
72
- url = f"http://localhost:{port}"
90
+ url = f"http://localhost:{port}{ROOT_PATH}"
73
91
  else:
74
- url = f"http://{host}:{port}"
92
+ url = f"http://{host}:{port}{ROOT_PATH}"
75
93
 
76
94
  logger.info(f"Your app is available at {url}")
77
95
 
@@ -112,22 +130,33 @@ async def lifespan(app: FastAPI):
112
130
  logger.error(f"Error reloading module: {e}")
113
131
 
114
132
  await asyncio.sleep(1)
115
- await socket.emit("reload", {})
133
+ await sio.emit("reload", {})
116
134
 
117
135
  break
118
136
 
119
137
  watch_task = asyncio.create_task(watch_files_for_changes())
120
138
 
139
+ discord_task = None
140
+
141
+ if discord_bot_token := os.environ.get("DISCORD_BOT_TOKEN"):
142
+ from chainlit.discord.app import client
143
+
144
+ discord_task = asyncio.create_task(client.start(discord_bot_token))
145
+
121
146
  try:
122
147
  yield
123
148
  finally:
124
- if watch_task:
125
- try:
149
+ try:
150
+ if watch_task:
126
151
  stop_event.set()
127
152
  watch_task.cancel()
128
153
  await watch_task
129
- except asyncio.exceptions.CancelledError:
130
- pass
154
+
155
+ if discord_task:
156
+ discord_task.cancel()
157
+ await discord_task
158
+ except asyncio.exceptions.CancelledError:
159
+ pass
131
160
 
132
161
  if FILES_DIRECTORY.is_dir():
133
162
  shutil.rmtree(FILES_DIRECTORY)
@@ -136,10 +165,26 @@ async def lifespan(app: FastAPI):
136
165
  os._exit(0)
137
166
 
138
167
 
139
- def get_build_dir(local_target: str, packaged_target: str):
168
+ def get_build_dir(local_target: str, packaged_target: str) -> str:
169
+ """
170
+ Get the build directory based on the UI build strategy.
171
+
172
+ Args:
173
+ local_target (str): The local target directory.
174
+ packaged_target (str): The packaged target directory.
175
+
176
+ Returns:
177
+ str: The build directory
178
+ """
179
+
140
180
  local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
141
181
  packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")
142
- if os.path.exists(local_build_dir):
182
+
183
+ if config.ui.custom_build and os.path.exists(
184
+ os.path.join(APP_ROOT, config.ui.custom_build)
185
+ ):
186
+ return os.path.join(APP_ROOT, config.ui.custom_build)
187
+ elif os.path.exists(local_build_dir):
143
188
  return local_build_dir
144
189
  elif os.path.exists(packaged_build_dir):
145
190
  return packaged_build_dir
@@ -150,28 +195,16 @@ def get_build_dir(local_target: str, packaged_target: str):
150
195
  build_dir = get_build_dir("frontend", "frontend")
151
196
  copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
152
197
 
153
-
154
198
  app = FastAPI(lifespan=lifespan)
155
199
 
156
- app.mount("/public", StaticFiles(directory="public", check_dir=False), name="public")
157
- app.mount(
158
- "/assets",
159
- StaticFiles(
160
- packages=[("chainlit", os.path.join(build_dir, "assets"))],
161
- follow_symlink=config.project.follow_symlink,
162
- ),
163
- name="assets",
164
- )
200
+ sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
165
201
 
166
- app.mount(
167
- "/copilot",
168
- StaticFiles(
169
- packages=[("chainlit", copilot_build_dir)],
170
- follow_symlink=config.project.follow_symlink,
171
- ),
172
- name="copilot",
202
+ asgi_app = socketio.ASGIApp(
203
+ socketio_server=sio,
204
+ socketio_path="",
173
205
  )
174
206
 
207
+ app.mount(f"{PREFIX}/ws/socket.io", asgi_app)
175
208
 
176
209
  app.add_middleware(
177
210
  CORSMiddleware,
@@ -181,11 +214,91 @@ app.add_middleware(
181
214
  allow_headers=["*"],
182
215
  )
183
216
 
184
- socket = SocketManager(
185
- app,
186
- cors_allowed_origins=[],
187
- async_mode="asgi",
188
- )
217
+ router = APIRouter(prefix=PREFIX)
218
+
219
+
220
+ @router.get("/public/{filename:path}")
221
+ async def serve_public_file(
222
+ filename: str,
223
+ ):
224
+ """Serve a file from public dir."""
225
+
226
+ base_path = Path(public_dir)
227
+ file_path = (base_path / filename).resolve()
228
+
229
+ if not is_path_inside(file_path, base_path):
230
+ raise HTTPException(status_code=400, detail="Invalid filename")
231
+
232
+ if file_path.is_file():
233
+ return FileResponse(file_path)
234
+ else:
235
+ raise HTTPException(status_code=404, detail="File not found")
236
+
237
+
238
+ @router.get("/assets/{filename:path}")
239
+ async def serve_asset_file(
240
+ filename: str,
241
+ ):
242
+ """Serve a file from assets dir."""
243
+
244
+ base_path = Path(os.path.join(build_dir, "assets"))
245
+ file_path = (base_path / filename).resolve()
246
+
247
+ if not is_path_inside(file_path, base_path):
248
+ raise HTTPException(status_code=400, detail="Invalid filename")
249
+
250
+ if file_path.is_file():
251
+ return FileResponse(file_path)
252
+ else:
253
+ raise HTTPException(status_code=404, detail="File not found")
254
+
255
+
256
+ @router.get("/copilot/{filename:path}")
257
+ async def serve_copilot_file(
258
+ filename: str,
259
+ ):
260
+ """Serve a file from assets dir."""
261
+
262
+ base_path = Path(copilot_build_dir)
263
+ file_path = (base_path / filename).resolve()
264
+
265
+ if not is_path_inside(file_path, base_path):
266
+ raise HTTPException(status_code=400, detail="Invalid filename")
267
+
268
+ if file_path.is_file():
269
+ return FileResponse(file_path)
270
+ else:
271
+ raise HTTPException(status_code=404, detail="File not found")
272
+
273
+
274
+ # -------------------------------------------------------------------------------
275
+ # SLACK HANDLER
276
+ # -------------------------------------------------------------------------------
277
+
278
+ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
279
+ from chainlit.slack.app import slack_app_handler
280
+
281
+ @router.post("/slack/events")
282
+ async def slack_endpoint(req: Request):
283
+ return await slack_app_handler.handle(req)
284
+
285
+
286
+ # -------------------------------------------------------------------------------
287
+ # TEAMS HANDLER
288
+ # -------------------------------------------------------------------------------
289
+
290
+ if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
291
+ from botbuilder.schema import Activity
292
+
293
+ from chainlit.teams.app import adapter, bot
294
+
295
+ @router.post("/teams/events")
296
+ async def teams_endpoint(req: Request):
297
+ body = await req.json()
298
+ activity = Activity().deserialize(body)
299
+ auth_header = req.headers.get("Authorization", "")
300
+ response = await adapter.process_activity(activity, auth_header, bot.on_turn)
301
+ return response
189
302
 
190
303
 
191
304
  # -------------------------------------------------------------------------------
@@ -193,28 +306,55 @@ socket = SocketManager(
193
306
  # -------------------------------------------------------------------------------
194
307
 
195
308
 
196
- def replace_between_tags(text: str, start_tag: str, end_tag: str, replacement: str):
309
+ def replace_between_tags(
310
+ text: str, start_tag: str, end_tag: str, replacement: str
311
+ ) -> str:
312
+ """Replace text between two tags in a string."""
313
+
197
314
  pattern = start_tag + ".*?" + end_tag
198
315
  return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)
199
316
 
200
317
 
201
318
  def get_html_template():
319
+ """
320
+ Get HTML template for the index view.
321
+ """
322
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
323
+
324
+ custom_theme = None
325
+ custom_theme_file_path = Path(public_dir) / "theme.json"
326
+ if (
327
+ is_path_inside(custom_theme_file_path, Path(public_dir))
328
+ and custom_theme_file_path.is_file()
329
+ ):
330
+ custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
331
+
202
332
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
203
333
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
204
334
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
205
335
 
206
336
  default_url = "https://github.com/Chainlit/chainlit"
337
+ default_meta_image_url = (
338
+ "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"
339
+ )
207
340
  url = config.ui.github or default_url
341
+ meta_image_url = config.ui.custom_meta_image_url or default_meta_image_url
342
+ favicon_path = "/favicon"
208
343
 
209
344
  tags = f"""<title>{config.ui.name}</title>
345
+ <link rel="icon" href="{favicon_path}" />
210
346
  <meta name="description" content="{config.ui.description}">
211
347
  <meta property="og:type" content="website">
212
348
  <meta property="og:title" content="{config.ui.name}">
213
349
  <meta property="og:description" content="{config.ui.description}">
214
- <meta property="og:image" content="https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png">
215
- <meta property="og:url" content="{url}">"""
350
+ <meta property="og:image" content="{meta_image_url}">
351
+ <meta property="og:url" content="{url}">
352
+ <meta property="og:root_path" content="{ROOT_PATH}">"""
216
353
 
217
- js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
354
+ js = f"""<script>
355
+ {f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
356
+ {f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
357
+ </script>"""
218
358
 
219
359
  css = None
220
360
  if config.ui.custom_css:
@@ -226,12 +366,15 @@ def get_html_template():
226
366
  js += f"""<script src="{config.ui.custom_js}" defer></script>"""
227
367
 
228
368
  font = None
229
- if config.ui.custom_font:
230
- font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
369
+ if custom_theme and custom_theme.get("custom_fonts"):
370
+ font = "\n".join(
371
+ f"""<link rel="stylesheet" href="{font}">"""
372
+ for font in custom_theme.get("custom_fonts")
373
+ )
231
374
 
232
375
  index_html_file_path = os.path.join(build_dir, "index.html")
233
376
 
234
- with open(index_html_file_path, "r", encoding="utf-8") as f:
377
+ with open(index_html_file_path, encoding="utf-8") as f:
235
378
  content = f.read()
236
379
  content = content.replace(PLACEHOLDER, tags)
237
380
  if js:
@@ -242,6 +385,9 @@ def get_html_template():
242
385
  content = replace_between_tags(
243
386
  content, "<!-- FONT START -->", "<!-- FONT END -->", font
244
387
  )
388
+ if ROOT_PATH:
389
+ content = content.replace('href="/', f'href="{ROOT_PATH}/')
390
+ content = content.replace('src="/', f'src="{ROOT_PATH}/')
245
391
  return content
246
392
 
247
393
 
@@ -250,7 +396,6 @@ def get_user_facing_url(url: URL):
250
396
  Return the user facing URL for a given URL.
251
397
  Handles deployment with proxies (like cloud run).
252
398
  """
253
-
254
399
  chainlit_url = os.environ.get("CHAINLIT_URL")
255
400
 
256
401
  # No config, we keep the URL as is
@@ -269,49 +414,140 @@ def get_user_facing_url(url: URL):
269
414
  return config_url.__str__() + url.path
270
415
 
271
416
 
272
- @app.get("/auth/config")
417
+ @router.get("/auth/config")
273
418
  async def auth(request: Request):
274
419
  return get_configuration()
275
420
 
276
421
 
277
- @app.post("/login")
278
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
279
- if not config.code.password_auth_callback:
280
- raise HTTPException(
281
- status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
422
+ def _get_response_dict(access_token: str) -> dict:
423
+ """Get the response dictionary for the auth response."""
424
+
425
+ return {"success": True}
426
+
427
+
428
+ def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
429
+ """Get the redirect params for the OAuth callback."""
430
+
431
+ response_dict = _get_response_dict(access_token)
432
+
433
+ if redirect_to_callback:
434
+ root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
435
+ redirect_url = (
436
+ f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
282
437
  )
283
438
 
284
- user = await config.code.password_auth_callback(
285
- form_data.username, form_data.password
439
+ return RedirectResponse(
440
+ # FIXME: redirect to the right frontend base url to improve the dev environment
441
+ url=redirect_url,
442
+ status_code=302,
443
+ )
444
+
445
+ return JSONResponse(response_dict)
446
+
447
+
448
+ def _get_oauth_redirect_error(error: str) -> Response:
449
+ """Get the redirect response for an OAuth error."""
450
+ params = urllib.parse.urlencode(
451
+ {
452
+ "error": error,
453
+ }
286
454
  )
455
+ response = RedirectResponse(
456
+ # FIXME: redirect to the right frontend base url to improve the dev environment
457
+ url=f"/login?{params}", # Shouldn't there be {root_path} here?
458
+ )
459
+ return response
460
+
461
+
462
+ async def _authenticate_user(
463
+ user: Optional[User], redirect_to_callback: bool = False
464
+ ) -> Response:
465
+ """Authenticate a user and return the response."""
287
466
 
288
467
  if not user:
289
468
  raise HTTPException(
290
469
  status_code=status.HTTP_401_UNAUTHORIZED,
291
470
  detail="credentialssignin",
292
471
  )
293
- access_token = create_jwt(user)
472
+
473
+ # If a data layer is defined, attempt to persist user.
294
474
  if data_layer := get_data_layer():
295
475
  try:
296
476
  await data_layer.create_user(user)
297
477
  except Exception as e:
478
+ # Catch and log exceptions during user creation.
479
+ # TODO: Make this catch only specific errors and allow others to propagate.
298
480
  logger.error(f"Error creating user: {e}")
299
481
 
300
- return {
301
- "access_token": access_token,
302
- "token_type": "bearer",
303
- }
482
+ access_token = create_jwt(user)
483
+
484
+ response = _get_auth_response(access_token, redirect_to_callback)
485
+
486
+ set_auth_cookie(response, access_token)
487
+
488
+ return response
489
+
304
490
 
491
+ @router.post("/login")
492
+ async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
493
+ """
494
+ Login a user using the password auth callback.
495
+ """
496
+ if not config.code.password_auth_callback:
497
+ raise HTTPException(
498
+ status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
499
+ )
500
+
501
+ user = await config.code.password_auth_callback(
502
+ form_data.username, form_data.password
503
+ )
504
+
505
+ return await _authenticate_user(user)
305
506
 
306
- @app.post("/logout")
507
+
508
+ @router.post("/logout")
307
509
  async def logout(request: Request, response: Response):
510
+ """Logout the user by calling the on_logout callback."""
511
+ clear_auth_cookie(response)
512
+
308
513
  if config.code.on_logout:
309
514
  return await config.code.on_logout(request, response)
515
+
310
516
  return {"success": True}
311
517
 
312
518
 
313
- @app.post("/auth/header")
519
+ @router.post("/auth/jwt")
520
+ async def jwt_auth(request: Request):
521
+ """Login a user using a valid jwt."""
522
+ from jwt import InvalidTokenError
523
+
524
+ auth_header: Optional[str] = request.headers.get("Authorization")
525
+ if not auth_header:
526
+ raise HTTPException(status_code=401, detail="Authorization header missing")
527
+
528
+ # Check if it starts with "Bearer "
529
+ try:
530
+ scheme, token = auth_header.split()
531
+ if scheme.lower() != "bearer":
532
+ raise HTTPException(
533
+ status_code=401,
534
+ detail="Invalid authentication scheme. Please use Bearer",
535
+ )
536
+ except ValueError:
537
+ raise HTTPException(
538
+ status_code=401, detail="Invalid authorization header format"
539
+ )
540
+
541
+ try:
542
+ user = decode_jwt(token)
543
+ return await _authenticate_user(user)
544
+ except InvalidTokenError:
545
+ raise HTTPException(status_code=401, detail="Invalid token")
546
+
547
+
548
+ @router.post("/auth/header")
314
549
  async def header_auth(request: Request):
550
+ """Login a user using the header_auth_callback."""
315
551
  if not config.code.header_auth_callback:
316
552
  raise HTTPException(
317
553
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -320,27 +556,12 @@ async def header_auth(request: Request):
320
556
 
321
557
  user = await config.code.header_auth_callback(request.headers)
322
558
 
323
- if not user:
324
- raise HTTPException(
325
- status_code=status.HTTP_401_UNAUTHORIZED,
326
- detail="Unauthorized",
327
- )
328
-
329
- access_token = create_jwt(user)
330
- if data_layer := get_data_layer():
331
- try:
332
- await data_layer.create_user(user)
333
- except Exception as e:
334
- logger.error(f"Error creating user: {e}")
335
-
336
- return {
337
- "access_token": access_token,
338
- "token_type": "bearer",
339
- }
559
+ return await _authenticate_user(user)
340
560
 
341
561
 
342
- @app.get("/auth/oauth/{provider_id}")
562
+ @router.get("/auth/oauth/{provider_id}")
343
563
  async def oauth_login(provider_id: str, request: Request):
564
+ """Redirect the user to the oauth provider login page."""
344
565
  if config.code.oauth_callback is None:
345
566
  raise HTTPException(
346
567
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -367,20 +588,13 @@ async def oauth_login(provider_id: str, request: Request):
367
588
  response = RedirectResponse(
368
589
  url=f"{provider.authorize_url}?{params}",
369
590
  )
370
- samesite = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax") # type: Any
371
- secure = samesite.lower() == "none"
372
- response.set_cookie(
373
- "oauth_state",
374
- random,
375
- httponly=True,
376
- samesite=samesite,
377
- secure=secure,
378
- max_age=3 * 60,
379
- )
591
+
592
+ set_oauth_state_cookie(response, random)
593
+
380
594
  return response
381
595
 
382
596
 
383
- @app.get("/auth/oauth/{provider_id}/callback")
597
+ @router.get("/auth/oauth/{provider_id}/callback")
384
598
  async def oauth_callback(
385
599
  provider_id: str,
386
600
  request: Request,
@@ -388,6 +602,8 @@ async def oauth_callback(
388
602
  code: Optional[str] = None,
389
603
  state: Optional[str] = None,
390
604
  ):
605
+ """Handle the oauth callback and login the user."""
606
+
391
607
  if config.code.oauth_callback is None:
392
608
  raise HTTPException(
393
609
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -402,16 +618,7 @@ async def oauth_callback(
402
618
  )
403
619
 
404
620
  if error:
405
- params = urllib.parse.urlencode(
406
- {
407
- "error": error,
408
- }
409
- )
410
- response = RedirectResponse(
411
- # FIXME: redirect to the right frontend base url to improve the dev environment
412
- url=f"/login?{params}",
413
- )
414
- return response
621
+ return _get_oauth_redirect_error(error)
415
622
 
416
623
  if not code or not state:
417
624
  raise HTTPException(
@@ -419,9 +626,11 @@ async def oauth_callback(
419
626
  detail="Missing code or state",
420
627
  )
421
628
 
422
- # Check the state from the oauth provider against the browser cookie
423
- oauth_state = request.cookies.get("oauth_state")
424
- if oauth_state != state:
629
+ try:
630
+ validate_oauth_state_cookie(request, state)
631
+ except Exception as e:
632
+ logger.exception("Unable to validate oauth state: %1", e)
633
+
425
634
  raise HTTPException(
426
635
  status_code=status.HTTP_401_UNAUTHORIZED,
427
636
  detail="Unauthorized",
@@ -436,83 +645,128 @@ async def oauth_callback(
436
645
  provider_id, token, raw_user_data, default_user
437
646
  )
438
647
 
439
- if not user:
440
- raise HTTPException(
441
- status_code=status.HTTP_401_UNAUTHORIZED,
442
- detail="Unauthorized",
443
- )
444
-
445
- access_token = create_jwt(user)
648
+ response = await _authenticate_user(user, redirect_to_callback=True)
446
649
 
447
- if data_layer := get_data_layer():
448
- try:
449
- await data_layer.create_user(user)
450
- except Exception as e:
451
- logger.error(f"Error creating user: {e}")
650
+ clear_oauth_state_cookie(response)
452
651
 
453
- params = urllib.parse.urlencode(
454
- {
455
- "access_token": access_token,
456
- "token_type": "bearer",
457
- }
458
- )
459
- response = RedirectResponse(
460
- # FIXME: redirect to the right frontend base url to improve the dev environment
461
- url=f"/login/callback?{params}",
462
- )
463
- response.delete_cookie("oauth_state")
464
652
  return response
465
653
 
466
654
 
467
- @app.post("/generation")
468
- async def generation(
469
- request: GenerationRequest,
470
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
655
+ # specific route for azure ad hybrid flow
656
+ @router.post("/auth/oauth/azure-ad-hybrid/callback")
657
+ async def oauth_azure_hf_callback(
658
+ request: Request,
659
+ error: Optional[str] = None,
660
+ code: Annotated[Optional[str], Form()] = None,
661
+ id_token: Annotated[Optional[str], Form()] = None,
471
662
  ):
472
- """Handle a completion request from the prompt playground."""
663
+ """Handle the azure ad hybrid flow callback and login the user."""
473
664
 
474
- providers = get_llm_providers()
665
+ provider_id = "azure-ad-hybrid"
666
+ if config.code.oauth_callback is None:
667
+ raise HTTPException(
668
+ status_code=status.HTTP_400_BAD_REQUEST,
669
+ detail="No oauth_callback defined",
670
+ )
475
671
 
476
- try:
477
- provider = [p for p in providers if p.id == request.generation.provider][0]
478
- except IndexError:
672
+ provider = get_oauth_provider(provider_id)
673
+ if not provider:
479
674
  raise HTTPException(
480
- status_code=404,
481
- detail=f"LLM provider '{request.generation.provider}' not found",
675
+ status_code=status.HTTP_404_NOT_FOUND,
676
+ detail=f"Provider {provider_id} not found",
677
+ )
678
+
679
+ if error:
680
+ return _get_oauth_redirect_error(error)
681
+
682
+ if not code:
683
+ raise HTTPException(
684
+ status_code=status.HTTP_400_BAD_REQUEST,
685
+ detail="Missing code",
482
686
  )
483
687
 
484
- trace_event("pp_create_completion")
485
- response = await provider.create_completion(request)
688
+ url = get_user_facing_url(request.url)
689
+ token = await provider.get_token(code, url)
690
+
691
+ (raw_user_data, default_user) = await provider.get_user_info(token)
692
+
693
+ user = await config.code.oauth_callback(
694
+ provider_id, token, raw_user_data, default_user, id_token
695
+ )
696
+
697
+ response = await _authenticate_user(user, redirect_to_callback=True)
698
+
699
+ clear_oauth_state_cookie(response)
486
700
 
487
701
  return response
488
702
 
489
703
 
490
- @app.get("/project/llm-providers")
491
- async def get_providers(
492
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)]
704
+ GenericUser = Union[User, PersistedUser, None]
705
+ UserParam = Annotated[GenericUser, Depends(get_current_user)]
706
+
707
+
708
+ @router.get("/user")
709
+ async def get_user(current_user: UserParam) -> GenericUser:
710
+ return current_user
711
+
712
+
713
+ _language_pattern = (
714
+ "^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,3})?(-[a-zA-Z0-9]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
715
+ )
716
+
717
+
718
+ @router.get("/project/translations")
719
+ async def project_translations(
720
+ language: str = Query(
721
+ default="en-US", description="Language code", pattern=_language_pattern
722
+ ),
493
723
  ):
494
- """List the providers."""
495
- trace_event("pp_get_llm_providers")
496
- providers = get_llm_providers()
497
- providers = [p.to_dict() for p in providers]
498
- return JSONResponse(content={"providers": providers})
724
+ """Return project translations."""
725
+
726
+ # Load translation based on the provided language
727
+ translation = config.load_translation(language)
728
+
729
+ return JSONResponse(
730
+ content={
731
+ "translation": translation,
732
+ }
733
+ )
499
734
 
500
735
 
501
- @app.get("/project/settings")
736
+ @router.get("/project/settings")
502
737
  async def project_settings(
503
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
504
- language: str = Query(default="en-US", description="Language code"),
738
+ current_user: UserParam,
739
+ language: str = Query(
740
+ default="en-US", description="Language code", pattern=_language_pattern
741
+ ),
505
742
  ):
506
743
  """Return project settings. This is called by the UI before the establishing the websocket connection."""
507
744
 
508
- # Load translation based on the provided language
509
- translation = config.load_translation(language)
745
+ # Load the markdown file based on the provided language
746
+
747
+ markdown = get_markdown_str(config.root, language)
510
748
 
511
749
  profiles = []
512
750
  if config.code.set_chat_profiles:
513
751
  chat_profiles = await config.code.set_chat_profiles(current_user)
514
752
  if chat_profiles:
515
753
  profiles = [p.to_dict() for p in chat_profiles]
754
+
755
+ starters = []
756
+ if config.code.set_starters:
757
+ starters = await config.code.set_starters(current_user)
758
+ if starters:
759
+ starters = [s.to_dict() for s in starters]
760
+
761
+ if config.code.on_audio_chunk:
762
+ config.features.audio.enabled = True
763
+
764
+ debug_url = None
765
+ data_layer = get_data_layer()
766
+
767
+ if data_layer and config.run.debug:
768
+ debug_url = await data_layer.build_debug_url()
769
+
516
770
  return JSONResponse(
517
771
  content={
518
772
  "ui": config.ui.to_dict(),
@@ -520,18 +774,19 @@ async def project_settings(
520
774
  "userEnv": config.project.user_env,
521
775
  "dataPersistence": get_data_layer() is not None,
522
776
  "threadResumable": bool(config.code.on_chat_resume),
523
- "markdown": get_markdown_str(config.root),
777
+ "markdown": markdown,
524
778
  "chatProfiles": profiles,
525
- "translation": translation,
779
+ "starters": starters,
780
+ "debugUrl": debug_url,
526
781
  }
527
782
  )
528
783
 
529
784
 
530
- @app.put("/feedback")
785
+ @router.put("/feedback")
531
786
  async def update_feedback(
532
787
  request: Request,
533
788
  update: UpdateFeedbackRequest,
534
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
789
+ current_user: UserParam,
535
790
  ):
536
791
  """Update the human feedback for a particular message."""
537
792
  data_layer = get_data_layer()
@@ -541,36 +796,63 @@ async def update_feedback(
541
796
  try:
542
797
  feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
543
798
  except Exception as e:
544
- raise HTTPException(detail=str(e), status_code=500)
799
+ raise HTTPException(detail=str(e), status_code=500) from e
545
800
 
546
801
  return JSONResponse(content={"success": True, "feedbackId": feedback_id})
547
802
 
548
803
 
549
- @app.post("/project/threads")
804
+ @router.delete("/feedback")
805
+ async def delete_feedback(
806
+ request: Request,
807
+ payload: DeleteFeedbackRequest,
808
+ current_user: UserParam,
809
+ ):
810
+ """Delete a feedback."""
811
+
812
+ data_layer = get_data_layer()
813
+
814
+ if not data_layer:
815
+ raise HTTPException(status_code=400, detail="Data persistence is not enabled")
816
+
817
+ feedback_id = payload.feedbackId
818
+
819
+ await data_layer.delete_feedback(feedback_id)
820
+ return JSONResponse(content={"success": True})
821
+
822
+
823
+ @router.post("/project/threads")
550
824
  async def get_user_threads(
551
825
  request: Request,
552
826
  payload: GetThreadsRequest,
553
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
827
+ current_user: UserParam,
554
828
  ):
555
829
  """Get the threads page by page."""
556
- # Only show the current user threads
557
830
 
558
831
  data_layer = get_data_layer()
559
832
 
560
833
  if not data_layer:
561
834
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
562
835
 
563
- payload.filter.userIdentifier = current_user.identifier
836
+ if not current_user:
837
+ raise HTTPException(status_code=401, detail="Unauthorized")
838
+
839
+ if not isinstance(current_user, PersistedUser):
840
+ persisted_user = await data_layer.get_user(identifier=current_user.identifier)
841
+ if not persisted_user:
842
+ raise HTTPException(status_code=404, detail="User not found")
843
+ payload.filter.userId = persisted_user.id
844
+ else:
845
+ payload.filter.userId = current_user.id
564
846
 
565
847
  res = await data_layer.list_threads(payload.pagination, payload.filter)
566
848
  return JSONResponse(content=res.to_dict())
567
849
 
568
850
 
569
- @app.get("/project/thread/{thread_id}")
851
+ @router.get("/project/thread/{thread_id}")
570
852
  async def get_thread(
571
853
  request: Request,
572
854
  thread_id: str,
573
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
855
+ current_user: UserParam,
574
856
  ):
575
857
  """Get a specific thread."""
576
858
  data_layer = get_data_layer()
@@ -578,18 +860,21 @@ async def get_thread(
578
860
  if not data_layer:
579
861
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
580
862
 
863
+ if not current_user:
864
+ raise HTTPException(status_code=401, detail="Unauthorized")
865
+
581
866
  await is_thread_author(current_user.identifier, thread_id)
582
867
 
583
868
  res = await data_layer.get_thread(thread_id)
584
869
  return JSONResponse(content=res)
585
870
 
586
871
 
587
- @app.get("/project/thread/{thread_id}/element/{element_id}")
872
+ @router.get("/project/thread/{thread_id}/element/{element_id}")
588
873
  async def get_thread_element(
589
874
  request: Request,
590
875
  thread_id: str,
591
876
  element_id: str,
592
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
877
+ current_user: UserParam,
593
878
  ):
594
879
  """Get a specific thread element."""
595
880
  data_layer = get_data_layer()
@@ -597,17 +882,135 @@ async def get_thread_element(
597
882
  if not data_layer:
598
883
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
599
884
 
885
+ if not current_user:
886
+ raise HTTPException(status_code=401, detail="Unauthorized")
887
+
600
888
  await is_thread_author(current_user.identifier, thread_id)
601
889
 
602
890
  res = await data_layer.get_element(thread_id, element_id)
603
891
  return JSONResponse(content=res)
604
892
 
605
893
 
606
- @app.delete("/project/thread")
894
+ @router.put("/project/element")
895
+ async def update_thread_element(
896
+ payload: ElementRequest,
897
+ current_user: UserParam,
898
+ ):
899
+ """Update a specific thread element."""
900
+
901
+ from chainlit.context import init_ws_context
902
+ from chainlit.element import CustomElement, ElementDict
903
+ from chainlit.session import WebsocketSession
904
+
905
+ session = WebsocketSession.get_by_id(payload.sessionId)
906
+ context = init_ws_context(session)
907
+
908
+ element_dict = cast(ElementDict, payload.element)
909
+
910
+ if element_dict["type"] != "custom":
911
+ return {"success": False}
912
+
913
+ element = CustomElement(
914
+ id=element_dict["id"],
915
+ object_key=element_dict["objectKey"],
916
+ chainlit_key=element_dict["chainlitKey"],
917
+ url=element_dict["url"],
918
+ for_id=element_dict.get("forId") or "",
919
+ thread_id=element_dict.get("threadId") or "",
920
+ name=element_dict["name"],
921
+ props=element_dict.get("props") or {},
922
+ display=element_dict["display"],
923
+ )
924
+
925
+ if current_user:
926
+ if (
927
+ not context.session.user
928
+ or context.session.user.identifier != current_user.identifier
929
+ ):
930
+ raise HTTPException(
931
+ status_code=401,
932
+ detail="You are not authorized to update elements for this session",
933
+ )
934
+
935
+ await element.update()
936
+ return {"success": True}
937
+
938
+
939
+ @router.delete("/project/element")
940
+ async def delete_thread_element(
941
+ payload: ElementRequest,
942
+ current_user: UserParam,
943
+ ):
944
+ """Delete a specific thread element."""
945
+
946
+ from chainlit.context import init_ws_context
947
+ from chainlit.element import CustomElement, ElementDict
948
+ from chainlit.session import WebsocketSession
949
+
950
+ session = WebsocketSession.get_by_id(payload.sessionId)
951
+ context = init_ws_context(session)
952
+
953
+ element_dict = cast(ElementDict, payload.element)
954
+
955
+ if element_dict["type"] != "custom":
956
+ return {"success": False}
957
+
958
+ element = CustomElement(
959
+ id=element_dict["id"],
960
+ object_key=element_dict["objectKey"],
961
+ chainlit_key=element_dict["chainlitKey"],
962
+ url=element_dict["url"],
963
+ for_id=element_dict.get("forId") or "",
964
+ thread_id=element_dict.get("threadId") or "",
965
+ name=element_dict["name"],
966
+ props=element_dict.get("props") or {},
967
+ display=element_dict["display"],
968
+ )
969
+
970
+ if current_user:
971
+ if (
972
+ not context.session.user
973
+ or context.session.user.identifier != current_user.identifier
974
+ ):
975
+ raise HTTPException(
976
+ status_code=401,
977
+ detail="You are not authorized to remove elements for this session",
978
+ )
979
+
980
+ await element.remove()
981
+
982
+ return {"success": True}
983
+
984
+
985
+ @router.put("/project/thread")
986
+ async def rename_thread(
987
+ request: Request,
988
+ payload: UpdateThreadRequest,
989
+ current_user: UserParam,
990
+ ):
991
+ """Rename a thread."""
992
+
993
+ data_layer = get_data_layer()
994
+
995
+ if not data_layer:
996
+ raise HTTPException(status_code=400, detail="Data persistence is not enabled")
997
+
998
+ if not current_user:
999
+ raise HTTPException(status_code=401, detail="Unauthorized")
1000
+
1001
+ thread_id = payload.threadId
1002
+
1003
+ await is_thread_author(current_user.identifier, thread_id)
1004
+
1005
+ await data_layer.update_thread(thread_id, name=payload.name)
1006
+ return JSONResponse(content={"success": True})
1007
+
1008
+
1009
+ @router.delete("/project/thread")
607
1010
  async def delete_thread(
608
1011
  request: Request,
609
1012
  payload: DeleteThreadRequest,
610
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1013
+ current_user: UserParam,
611
1014
  ):
612
1015
  """Delete a thread."""
613
1016
 
@@ -616,6 +1019,9 @@ async def delete_thread(
616
1019
  if not data_layer:
617
1020
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
618
1021
 
1022
+ if not current_user:
1023
+ raise HTTPException(status_code=401, detail="Unauthorized")
1024
+
619
1025
  thread_id = payload.threadId
620
1026
 
621
1027
  await is_thread_author(current_user.identifier, thread_id)
@@ -624,14 +1030,56 @@ async def delete_thread(
624
1030
  return JSONResponse(content={"success": True})
625
1031
 
626
1032
 
627
- @app.post("/project/file")
1033
+ @router.post("/project/action")
1034
+ async def call_action(
1035
+ payload: CallActionRequest,
1036
+ current_user: UserParam,
1037
+ ):
1038
+ """Run an action."""
1039
+
1040
+ from chainlit.action import Action
1041
+ from chainlit.context import init_ws_context
1042
+ from chainlit.session import WebsocketSession
1043
+
1044
+ session = WebsocketSession.get_by_id(payload.sessionId)
1045
+ context = init_ws_context(session)
1046
+
1047
+ action = Action(**payload.action)
1048
+
1049
+ if current_user:
1050
+ if (
1051
+ not context.session.user
1052
+ or context.session.user.identifier != current_user.identifier
1053
+ ):
1054
+ raise HTTPException(
1055
+ status_code=401,
1056
+ detail="You are not authorized to upload files for this session",
1057
+ )
1058
+
1059
+ callback = config.code.action_callbacks.get(action.name)
1060
+ if callback:
1061
+ if not context.session.has_first_interaction:
1062
+ context.session.has_first_interaction = True
1063
+ asyncio.create_task(context.emitter.init_thread(action.name))
1064
+
1065
+ response = await callback(action)
1066
+ else:
1067
+ raise HTTPException(
1068
+ status_code=404,
1069
+ detail=f"No callback found for action {action.name}",
1070
+ )
1071
+
1072
+ return JSONResponse(content={"success": True, "response": response})
1073
+
1074
+
1075
+ @router.post("/project/file")
628
1076
  async def upload_file(
1077
+ current_user: UserParam,
629
1078
  session_id: str,
630
1079
  file: UploadFile,
631
- current_user: Annotated[
632
- Union[None, User, PersistedUser], Depends(get_current_user)
633
- ],
634
1080
  ):
1081
+ """Upload a file to the session files directory."""
1082
+
635
1083
  from chainlit.session import WebsocketSession
636
1084
 
637
1085
  session = WebsocketSession.get_by_id(session_id)
@@ -653,28 +1101,122 @@ async def upload_file(
653
1101
 
654
1102
  content = await file.read()
655
1103
 
1104
+ assert file.filename, "No filename for uploaded file"
1105
+ assert file.content_type, "No content type for uploaded file"
1106
+
1107
+ try:
1108
+ validate_file_upload(file)
1109
+ except ValueError as e:
1110
+ raise HTTPException(status_code=400, detail=str(e))
1111
+
656
1112
  file_response = await session.persist_file(
657
1113
  name=file.filename, content=content, mime=file.content_type
658
1114
  )
659
1115
 
660
- return JSONResponse(file_response)
1116
+ return JSONResponse(content=file_response)
1117
+
1118
+
1119
+ def validate_file_upload(file: UploadFile):
1120
+ """Validate the file upload as configured in config.features.spontaneous_file_upload.
1121
+ Args:
1122
+ file (UploadFile): The file to validate.
1123
+ Raises:
1124
+ ValueError: If the file is not allowed.
1125
+ """
1126
+ # TODO: This logic/endpoint is shared across spontaneous uploads and the AskFileMessage API.
1127
+ # Commenting this check until we find a better solution
1128
+
1129
+ # if config.features.spontaneous_file_upload is None:
1130
+ # """Default for a missing config is to allow the fileupload without any restrictions"""
1131
+ # return
1132
+ # if not config.features.spontaneous_file_upload.enabled:
1133
+ # raise ValueError("File upload is not enabled")
1134
+
1135
+ validate_file_mime_type(file)
1136
+ validate_file_size(file)
1137
+
661
1138
 
1139
+ def validate_file_mime_type(file: UploadFile):
1140
+ """Validate the file mime type as configured in config.features.spontaneous_file_upload.
1141
+ Args:
1142
+ file (UploadFile): The file to validate.
1143
+ Raises:
1144
+ ValueError: If the file type is not allowed.
1145
+ """
1146
+
1147
+ if (
1148
+ config.features.spontaneous_file_upload is None
1149
+ or config.features.spontaneous_file_upload.accept is None
1150
+ ):
1151
+ "Accept is not configured, allowing all file types"
1152
+ return
1153
+
1154
+ accept = config.features.spontaneous_file_upload.accept
1155
+
1156
+ assert isinstance(accept, List) or isinstance(accept, dict), (
1157
+ "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
1158
+ )
1159
+
1160
+ if isinstance(accept, List):
1161
+ for pattern in accept:
1162
+ if fnmatch.fnmatch(file.content_type, pattern):
1163
+ return
1164
+ elif isinstance(accept, dict):
1165
+ for pattern, extensions in accept.items():
1166
+ if fnmatch.fnmatch(file.content_type, pattern):
1167
+ if len(extensions) == 0:
1168
+ return
1169
+ for extension in extensions:
1170
+ if file.filename is not None and file.filename.endswith(extension):
1171
+ return
1172
+ raise ValueError("File type not allowed")
1173
+
1174
+
1175
+ def validate_file_size(file: UploadFile):
1176
+ """Validate the file size as configured in config.features.spontaneous_file_upload.
1177
+ Args:
1178
+ file (UploadFile): The file to validate.
1179
+ Raises:
1180
+ ValueError: If the file size is too large.
1181
+ """
1182
+ if (
1183
+ config.features.spontaneous_file_upload is None
1184
+ or config.features.spontaneous_file_upload.max_size_mb is None
1185
+ ):
1186
+ return
662
1187
 
663
- @app.get("/project/file/{file_id}")
1188
+ if (
1189
+ file.size is not None
1190
+ and file.size
1191
+ > config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024
1192
+ ):
1193
+ raise ValueError("File size too large")
1194
+
1195
+
1196
+ @router.get("/project/file/{file_id}")
664
1197
  async def get_file(
665
1198
  file_id: str,
666
- session_id: Optional[str] = None,
1199
+ session_id: str,
1200
+ current_user: UserParam,
667
1201
  ):
1202
+ """Get a file from the session files directory."""
668
1203
  from chainlit.session import WebsocketSession
669
1204
 
670
1205
  session = WebsocketSession.get_by_id(session_id) if session_id else None
671
1206
 
672
1207
  if not session:
673
1208
  raise HTTPException(
674
- status_code=404,
675
- detail="Session not found",
1209
+ status_code=401,
1210
+ detail="Unauthorized",
676
1211
  )
677
1212
 
1213
+ if current_user:
1214
+ if not session.user or session.user.identifier != current_user.identifier:
1215
+ raise HTTPException(
1216
+ status_code=401,
1217
+ detail="You are not authorized to download files from this session",
1218
+ )
1219
+
678
1220
  if file_id in session.files:
679
1221
  file = session.files[file_id]
680
1222
  return FileResponse(file["path"], media_type=file["type"])
@@ -682,26 +1224,9 @@ async def get_file(
682
1224
  raise HTTPException(status_code=404, detail="File not found")
683
1225
 
684
1226
 
685
- @app.get("/files/{filename:path}")
686
- async def serve_file(
687
- filename: str,
688
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
689
- ):
690
- base_path = Path(config.project.local_fs_path).resolve()
691
- file_path = (base_path / filename).resolve()
692
-
693
- # Check if the base path is a parent of the file path
694
- if base_path not in file_path.parents:
695
- raise HTTPException(status_code=400, detail="Invalid filename")
696
-
697
- if file_path.is_file():
698
- return FileResponse(file_path)
699
- else:
700
- raise HTTPException(status_code=404, detail="File not found")
701
-
702
-
703
- @app.get("/favicon")
1227
+ @router.get("/favicon")
704
1228
  async def get_favicon():
1229
+ """Get the favicon for the UI."""
705
1230
  custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
706
1231
  files = glob.glob(custom_favicon_path)
707
1232
 
@@ -715,8 +1240,9 @@ async def get_favicon():
715
1240
  return FileResponse(favicon_path, media_type=media_type)
716
1241
 
717
1242
 
718
- @app.get("/logo")
1243
+ @router.get("/logo")
719
1244
  async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
1245
+ """Get the default logo for the UI."""
720
1246
  theme_value = theme.value if theme else Theme.light.value
721
1247
  logo_path = None
722
1248
 
@@ -732,19 +1258,54 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
732
1258
 
733
1259
  if not logo_path:
734
1260
  raise HTTPException(status_code=404, detail="Missing default logo")
1261
+
735
1262
  media_type, _ = mimetypes.guess_type(logo_path)
736
1263
 
737
1264
  return FileResponse(logo_path, media_type=media_type)
738
1265
 
739
1266
 
740
- def register_wildcard_route_handler():
741
- @app.get("/{path:path}")
742
- async def serve(request: Request, path: str):
743
- html_template = get_html_template()
744
- """Serve the UI files."""
745
- response = HTMLResponse(content=html_template, status_code=200)
1267
+ @router.get("/avatars/{avatar_id:str}")
1268
+ async def get_avatar(avatar_id: str):
1269
+ """Get the avatar for the user based on the avatar_id."""
1270
+ if not re.match(r"^[a-zA-Z0-9_ -]+$", avatar_id):
1271
+ raise HTTPException(status_code=400, detail="Invalid avatar_id")
1272
+
1273
+ if avatar_id == "default":
1274
+ avatar_id = config.ui.name
1275
+
1276
+ avatar_id = avatar_id.strip().lower().replace(" ", "_")
1277
+
1278
+ base_path = Path(APP_ROOT) / "public" / "avatars"
1279
+ avatar_pattern = f"{avatar_id}.*"
1280
+
1281
+ matching_files = base_path.glob(avatar_pattern)
1282
+
1283
+ if avatar_path := next(matching_files, None):
1284
+ if not is_path_inside(avatar_path, base_path):
1285
+ raise HTTPException(status_code=400, detail="Invalid filename")
1286
+
1287
+ media_type, _ = mimetypes.guess_type(str(avatar_path))
1288
+
1289
+ return FileResponse(avatar_path, media_type=media_type)
1290
+
1291
+ return await get_favicon()
1292
+
1293
+
1294
+ @router.head("/")
1295
+ def status_check():
1296
+ """Check if the site is operational."""
1297
+ return {"message": "Site is operational"}
1298
+
1299
+
1300
+ @router.get("/{full_path:path}")
1301
+ async def serve():
1302
+ html_template = get_html_template()
1303
+ """Serve the UI files."""
1304
+ response = HTMLResponse(content=html_template, status_code=200)
1305
+
1306
+ return response
746
1307
 
747
- return response
748
1308
 
1309
+ app.include_router(router)
749
1310
 
750
1311
  import chainlit.socket # noqa