chainlit 2.7.0__py3-none-any.whl → 2.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (85)
  1. {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/METADATA +1 -1
  2. chainlit-2.7.1.dist-info/RECORD +4 -0
  3. chainlit/__init__.py +0 -207
  4. chainlit/__main__.py +0 -4
  5. chainlit/_utils.py +0 -8
  6. chainlit/action.py +0 -33
  7. chainlit/auth/__init__.py +0 -95
  8. chainlit/auth/cookie.py +0 -197
  9. chainlit/auth/jwt.py +0 -42
  10. chainlit/cache.py +0 -45
  11. chainlit/callbacks.py +0 -433
  12. chainlit/chat_context.py +0 -64
  13. chainlit/chat_settings.py +0 -34
  14. chainlit/cli/__init__.py +0 -235
  15. chainlit/config.py +0 -621
  16. chainlit/context.py +0 -112
  17. chainlit/data/__init__.py +0 -111
  18. chainlit/data/acl.py +0 -19
  19. chainlit/data/base.py +0 -107
  20. chainlit/data/chainlit_data_layer.py +0 -687
  21. chainlit/data/dynamodb.py +0 -616
  22. chainlit/data/literalai.py +0 -501
  23. chainlit/data/sql_alchemy.py +0 -741
  24. chainlit/data/storage_clients/__init__.py +0 -0
  25. chainlit/data/storage_clients/azure.py +0 -84
  26. chainlit/data/storage_clients/azure_blob.py +0 -94
  27. chainlit/data/storage_clients/base.py +0 -28
  28. chainlit/data/storage_clients/gcs.py +0 -101
  29. chainlit/data/storage_clients/s3.py +0 -88
  30. chainlit/data/utils.py +0 -29
  31. chainlit/discord/__init__.py +0 -6
  32. chainlit/discord/app.py +0 -364
  33. chainlit/element.py +0 -454
  34. chainlit/emitter.py +0 -450
  35. chainlit/hello.py +0 -12
  36. chainlit/input_widget.py +0 -182
  37. chainlit/langchain/__init__.py +0 -6
  38. chainlit/langchain/callbacks.py +0 -682
  39. chainlit/langflow/__init__.py +0 -25
  40. chainlit/llama_index/__init__.py +0 -6
  41. chainlit/llama_index/callbacks.py +0 -206
  42. chainlit/logger.py +0 -16
  43. chainlit/markdown.py +0 -57
  44. chainlit/mcp.py +0 -99
  45. chainlit/message.py +0 -619
  46. chainlit/mistralai/__init__.py +0 -50
  47. chainlit/oauth_providers.py +0 -835
  48. chainlit/openai/__init__.py +0 -53
  49. chainlit/py.typed +0 -0
  50. chainlit/secret.py +0 -9
  51. chainlit/semantic_kernel/__init__.py +0 -111
  52. chainlit/server.py +0 -1616
  53. chainlit/session.py +0 -304
  54. chainlit/sidebar.py +0 -55
  55. chainlit/slack/__init__.py +0 -6
  56. chainlit/slack/app.py +0 -427
  57. chainlit/socket.py +0 -381
  58. chainlit/step.py +0 -490
  59. chainlit/sync.py +0 -43
  60. chainlit/teams/__init__.py +0 -6
  61. chainlit/teams/app.py +0 -348
  62. chainlit/translations/bn.json +0 -214
  63. chainlit/translations/el-GR.json +0 -214
  64. chainlit/translations/en-US.json +0 -214
  65. chainlit/translations/fr-FR.json +0 -214
  66. chainlit/translations/gu.json +0 -214
  67. chainlit/translations/he-IL.json +0 -214
  68. chainlit/translations/hi.json +0 -214
  69. chainlit/translations/ja.json +0 -214
  70. chainlit/translations/kn.json +0 -214
  71. chainlit/translations/ml.json +0 -214
  72. chainlit/translations/mr.json +0 -214
  73. chainlit/translations/nl.json +0 -214
  74. chainlit/translations/ta.json +0 -214
  75. chainlit/translations/te.json +0 -214
  76. chainlit/translations/zh-CN.json +0 -214
  77. chainlit/translations.py +0 -60
  78. chainlit/types.py +0 -334
  79. chainlit/user.py +0 -43
  80. chainlit/user_session.py +0 -153
  81. chainlit/utils.py +0 -173
  82. chainlit/version.py +0 -8
  83. chainlit-2.7.0.dist-info/RECORD +0 -84
  84. {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/WHEEL +0 -0
  85. {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/entry_points.txt +0 -0
chainlit/server.py DELETED
@@ -1,1616 +0,0 @@
1
- import asyncio
2
- import fnmatch
3
- import glob
4
- import json
5
- import mimetypes
6
- import os
7
- import re
8
- import shutil
9
- import urllib.parse
10
- import webbrowser
11
- from contextlib import AsyncExitStack, asynccontextmanager
12
- from pathlib import Path
13
- from typing import List, Optional, Union, cast
14
-
15
- import socketio
16
- from fastapi import (
17
- APIRouter,
18
- Depends,
19
- FastAPI,
20
- Form,
21
- HTTPException,
22
- Query,
23
- Request,
24
- Response,
25
- UploadFile,
26
- status,
27
- )
28
- from fastapi.middleware.gzip import GZipMiddleware
29
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
30
- from fastapi.security import OAuth2PasswordRequestForm
31
- from starlette.datastructures import URL
32
- from starlette.middleware.cors import CORSMiddleware
33
- from starlette.types import Receive, Scope, Send
34
- from typing_extensions import Annotated
35
- from watchfiles import awatch
36
-
37
- from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
38
- from chainlit.auth.cookie import (
39
- clear_auth_cookie,
40
- clear_oauth_state_cookie,
41
- set_auth_cookie,
42
- set_oauth_state_cookie,
43
- validate_oauth_state_cookie,
44
- )
45
- from chainlit.config import (
46
- APP_ROOT,
47
- BACKEND_ROOT,
48
- DEFAULT_HOST,
49
- FILES_DIRECTORY,
50
- PACKAGE_ROOT,
51
- ChainlitConfig,
52
- config,
53
- load_module,
54
- public_dir,
55
- reload_config,
56
- )
57
- from chainlit.data import get_data_layer
58
- from chainlit.data.acl import is_thread_author
59
- from chainlit.logger import logger
60
- from chainlit.markdown import get_markdown_str
61
- from chainlit.oauth_providers import get_oauth_provider
62
- from chainlit.secret import random_secret
63
- from chainlit.types import (
64
- AskFileSpec,
65
- CallActionRequest,
66
- ConnectMCPRequest,
67
- DeleteFeedbackRequest,
68
- DeleteThreadRequest,
69
- DisconnectMCPRequest,
70
- ElementRequest,
71
- GetThreadsRequest,
72
- Theme,
73
- UpdateFeedbackRequest,
74
- UpdateThreadRequest,
75
- )
76
- from chainlit.user import PersistedUser, User
77
-
78
- from ._utils import is_path_inside
79
-
80
# Register explicit MIME types for the frontend bundle's scripts and stylesheets
# so FileResponse advertises them with these Content-Types.
# NOTE(review): presumably guards against platform mimetype tables that map
# these extensions differently — confirm.
mimetypes.add_type(type="application/javascript", ext=".js", strict=True)
mimetypes.add_type(type="text/css", ext=".css", strict=True)
82
-
83
-
84
async def _cancel_background_task(task: Optional[asyncio.Task]) -> None:
    """Cancel a background task started by ``lifespan`` and wait for it to end.

    The expected ``CancelledError`` is swallowed. Any other error the task died
    with is logged. This way one task's shutdown never prevents the remaining
    cleanup from running.
    """
    if task is None:
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        logger.exception("Background task failed during shutdown")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Context manager to handle app start and shutdown.

    Startup: runs ``on_app_startup``, logs the app URL (and opens it in a
    browser unless headless), then starts the optional background tasks:
    a file watcher (``config.run.watch``), a Discord client
    (``DISCORD_BOT_TOKEN``) and Slack Socket Mode (``SLACK_BOT_TOKEN`` and
    ``SLACK_WEBSOCKET_TOKEN``).

    Shutdown: runs ``on_app_shutdown``, cancels each background task
    independently, removes ``FILES_DIRECTORY`` and force-exits the process.
    """
    if config.code.on_app_startup:
        await config.code.on_app_startup()

    host = config.run.host
    port = config.run.port
    root_path = os.getenv("CHAINLIT_ROOT_PATH", "")

    if host == DEFAULT_HOST:
        url = f"http://localhost:{port}{root_path}"
    else:
        url = f"http://{host}:{port}{root_path}"

    logger.info(f"Your app is available at {url}")

    if not config.run.headless:
        # Add a delay before opening the browser
        await asyncio.sleep(1)
        webbrowser.open(url)

    watch_task = None
    stop_event = asyncio.Event()

    if config.run.watch:

        async def watch_files_for_changes():
            extensions = [".py"]
            files = ["chainlit.md", "config.toml"]
            async for changes in awatch(config.root, stop_event=stop_event):
                for change_type, file_path in changes:
                    file_name = os.path.basename(file_path)
                    file_ext = os.path.splitext(file_name)[1]

                    if file_ext.lower() in extensions or file_name.lower() in files:
                        logger.info(
                            f"File {change_type.name}: {file_name}. Reloading app..."
                        )

                        try:
                            reload_config()
                        except Exception as e:
                            logger.error(f"Error reloading config: {e}")
                            break

                        # Reload the module if the module name is specified in the config
                        if config.run.module_name:
                            try:
                                load_module(config.run.module_name, force_refresh=True)
                            except Exception as e:
                                logger.error(f"Error reloading module: {e}")

                        await asyncio.sleep(1)
                        await sio.emit("reload", {})

                        # One reload per batch of changes is enough.
                        break

        watch_task = asyncio.create_task(watch_files_for_changes())

    discord_task = None

    if discord_bot_token := os.environ.get("DISCORD_BOT_TOKEN"):
        from chainlit.discord.app import client

        discord_task = asyncio.create_task(client.start(discord_bot_token))

    slack_task = None

    # Slack Socket Handler if env variable SLACK_WEBSOCKET_TOKEN is set
    if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_WEBSOCKET_TOKEN"):
        from chainlit.slack.app import start_socket_mode

        slack_task = asyncio.create_task(start_socket_mode())

    try:
        yield
    finally:
        # A failing shutdown hook must not prevent the cleanup below.
        try:
            if config.code.on_app_shutdown:
                await config.code.on_app_shutdown()
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("Error in on_app_shutdown callback")

        if watch_task is not None:
            # Ask awatch to stop gracefully before cancelling the task.
            stop_event.set()

        # Each task is cancelled on its own. A CancelledError from one must not
        # skip the cancellation of the others.
        for task in (watch_task, discord_task, slack_task):
            await _cancel_background_task(task)

        if FILES_DIRECTORY.is_dir():
            shutil.rmtree(FILES_DIRECTORY)

        # Force exit the process to avoid potential AnyIO threads still running
        os._exit(0)
186
-
187
-
188
def get_build_dir(local_target: str, packaged_target: str) -> str:
    """
    Resolve the directory containing a built UI bundle.

    Candidates are checked in priority order and the first one that exists
    wins:
      1. ``APP_ROOT/<config.ui.custom_build>`` (only if ``custom_build`` is set)
      2. ``PACKAGE_ROOT/<local_target>/dist``
      3. ``BACKEND_ROOT/<packaged_target>/dist``

    Args:
        local_target (str): Target directory relative to ``PACKAGE_ROOT``.
        packaged_target (str): Target directory relative to ``BACKEND_ROOT``.

    Returns:
        str: The first existing candidate directory.

    Raises:
        FileNotFoundError: If none of the candidates exist.
    """
    candidates = []
    if config.ui.custom_build:
        candidates.append(os.path.join(APP_ROOT, config.ui.custom_build))
    candidates.append(os.path.join(PACKAGE_ROOT, local_target, "dist"))
    candidates.append(os.path.join(BACKEND_ROOT, packaged_target, "dist"))

    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate

    raise FileNotFoundError(f"{local_target} built UI dir not found")
213
-
214
-
215
# Resolve the built frontend and copilot bundles at import time; get_build_dir
# raises FileNotFoundError if a bundle cannot be located.
build_dir = get_build_dir("frontend", "frontend")
copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")

app = FastAPI(lifespan=lifespan)

# NOTE(review): an empty cors_allowed_origins list presumably disables
# python-socketio's own CORS checks (leaving CORS to CORSMiddleware below) —
# confirm against the python-socketio docs.
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")

asgi_app = socketio.ASGIApp(socketio_server=sio, socketio_path="")

# config.run.root_path is only set when started with --root-path. Not on submounts.
SOCKET_IO_PATH = f"{config.run.root_path}/ws/socket.io"
app.mount(SOCKET_IO_PATH, asgi_app)

# Allow credentialed cross-origin requests from the configured origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.project.allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
235
-
236
-
237
class SafariWebSocketsCompatibleGZipMiddleware(GZipMiddleware):
    """GZip middleware that leaves some traffic uncompressed.

    Non-HTTP scopes, and HTTP requests under ``SOCKET_IO_PATH``, are forwarded
    to the wrapped app untouched. Compression is skipped on the socket.io path
    to work around a Safari bug. Everything else goes through the regular
    GZip handling.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # The URL is only built for HTTP scopes (short-circuit on the first test).
        bypass_compression = scope["type"] != "http" or URL(
            scope=scope
        ).path.startswith(SOCKET_IO_PATH)

        if bypass_compression:
            await self.app(scope, receive, send)
            return

        await super().__call__(scope, receive, send)
247
-
248
-
249
# GZip responses, except for the socket.io HTTP path (see the middleware class).
app.add_middleware(SafariWebSocketsCompatibleGZipMiddleware)

# config.run.root_path is only set when started with --root-path. Not on submounts.
# All HTTP routes below are registered on this router.
router = APIRouter(prefix=config.run.root_path)
253
-
254
-
255
@router.get("/public/{filename:path}")
async def serve_public_file(
    filename: str,
):
    """Return a file stored under the public directory.

    Raises:
        HTTPException: 400 if the resolved path is not inside ``public_dir``,
            404 if it is not an existing regular file.
    """
    root = Path(public_dir)
    # Resolve to collapse ".." segments before the containment check.
    # NOTE(review): ``root`` itself is not resolved; whether that matters
    # depends on how is_path_inside compares paths — confirm.
    target = (root / filename).resolve()

    if not is_path_inside(target, root):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(target)
271
-
272
-
273
@router.get("/assets/{filename:path}")
async def serve_asset_file(
    filename: str,
):
    """Return a file from the frontend build's ``assets`` directory.

    Raises:
        HTTPException: 400 if the resolved path escapes ``<build_dir>/assets``,
            404 if it is not an existing regular file.
    """
    assets_root = Path(os.path.join(build_dir, "assets"))
    requested = (assets_root / filename).resolve()

    if not is_path_inside(requested, assets_root):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if not requested.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(requested)
289
-
290
-
291
@router.get("/copilot/{filename:path}")
async def serve_copilot_file(
    filename: str,
):
    """Serve a file from the copilot build dir (``copilot_build_dir``).

    Raises:
        HTTPException: 400 if the resolved path escapes ``copilot_build_dir``,
            404 if it is not an existing regular file.
    """

    base_path = Path(copilot_build_dir)
    # Resolve to collapse ".." segments before the containment check.
    file_path = (base_path / filename).resolve()

    if not is_path_inside(file_path, base_path):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if file_path.is_file():
        return FileResponse(file_path)
    else:
        raise HTTPException(status_code=404, detail="File not found")
307
-
308
-
309
- # -------------------------------------------------------------------------------
310
- # SLACK HTTP HANDLER
311
- # -------------------------------------------------------------------------------
312
-
313
# HTTP (events endpoint) mode is only used when a bot token and signing secret
# are configured and Socket Mode (SLACK_WEBSOCKET_TOKEN) is not.
if all(
    os.environ.get(name) for name in ("SLACK_BOT_TOKEN", "SLACK_SIGNING_SECRET")
) and not os.environ.get("SLACK_WEBSOCKET_TOKEN"):
    from chainlit.slack.app import slack_app_handler

    @router.post("/slack/events")
    async def slack_endpoint(req: Request):
        """Forward an incoming Slack event request to the Slack app handler."""
        slack_response = await slack_app_handler.handle(req)
        return slack_response
323
-
324
-
325
- # -------------------------------------------------------------------------------
326
- # TEAMS HANDLER
327
- # -------------------------------------------------------------------------------
328
-
329
# The Teams endpoint is registered only when both app credentials are set.
if all(os.environ.get(name) for name in ("TEAMS_APP_ID", "TEAMS_APP_PASSWORD")):
    from botbuilder.schema import Activity

    from chainlit.teams.app import adapter, bot

    @router.post("/teams/events")
    async def teams_endpoint(req: Request):
        """Deserialize a Teams activity and hand it to the bot adapter.

        The raw ``Authorization`` header (empty string if absent) is passed to
        the adapter along with the activity.
        """
        payload = await req.json()
        activity = Activity().deserialize(payload)
        return await adapter.process_activity(
            activity, req.headers.get("Authorization", ""), bot.on_turn
        )
341
-
342
-
343
- # -------------------------------------------------------------------------------
344
- # HTTP HANDLERS
345
- # -------------------------------------------------------------------------------
346
-
347
-
348
def replace_between_tags(
    text: str, start_tag: str, end_tag: str, replacement: str
) -> str:
    """Replace text between two tags in a string.

    Every non-greedy ``start_tag ... end_tag`` span (matching across newlines)
    is replaced by ``start_tag + replacement + end_tag``. The tags themselves
    are kept. Text without both tags is returned unchanged.

    Tags and replacement are treated literally. The tags are regex-escaped, and
    the replacement is supplied through a function so that ``re.sub`` does not
    interpret backslashes or group references inside it (e.g. in font URLs).

    Args:
        text: The document to edit.
        start_tag: Literal opening marker.
        end_tag: Literal closing marker.
        replacement: Literal text to place between the markers.

    Returns:
        The edited text.
    """
    pattern = re.escape(start_tag) + ".*?" + re.escape(end_tag)
    substituted = start_tag + replacement + end_tag
    return re.sub(pattern, lambda _match: substituted, text, flags=re.DOTALL)
355
-
356
-
357
def get_html_template(root_path):
    """
    Get HTML template for the index view.

    Reads ``index.html`` from ``build_dir`` and fills its placeholder comments
    with page metadata tags, an inline script, optional custom CSS/JS links and
    optional custom font links. It then prefixes every root-relative
    ``href="/`` / ``src="/`` attribute with ``root_path``.

    Args:
        root_path: Path prefix the app is served under; trailing slashes are
            stripped before use.

    Returns:
        The rendered HTML document as a string.
    """
    root_path = root_path.rstrip("/")  # Avoid duplicated / when joining with root path.

    # Optional theme overrides read from <public_dir>/theme.json.
    custom_theme = None
    custom_theme_file_path = Path(public_dir) / "theme.json"
    if (
        is_path_inside(custom_theme_file_path, Path(public_dir))
        and custom_theme_file_path.is_file()
    ):
        custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))

    # Placeholder comments expected in the built index.html.
    PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
    JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
    CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"

    default_url = "https://github.com/Chainlit/chainlit"
    default_meta_image_url = (
        "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"
    )
    meta_image_url = config.ui.custom_meta_image_url or default_meta_image_url
    favicon_path = "/favicon"

    # NOTE(review): config values are interpolated into the HTML without
    # escaping. That is fine for trusted developer config, but a double quote
    # in ui.name/ui.description would break the attributes — confirm.
    tags = f"""<title>{config.ui.name}</title>
    <link rel="icon" href="{favicon_path}" />
    <meta name="description" content="{config.ui.description}">
    <meta property="og:type" content="website">
    <meta property="og:title" content="{config.ui.name}">
    <meta property="og:description" content="{config.ui.description}">
    <meta property="og:image" content="{meta_image_url}">
    <meta property="og:url" content="{default_url}">
    <meta property="og:root_path" content="{root_path}">"""

    # Inline script exposing theme variables and transports to the frontend.
    # An absent value renders as the bare expression `undefined` (a no-op).
    js = f"""<script>
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
</script>"""

    css = None
    if config.ui.custom_css:
        css = f"""<link rel="stylesheet" type="text/css" href="{config.ui.custom_css}" {config.ui.custom_css_attributes}>"""

    if config.ui.custom_js:
        js += f"""<script src="{config.ui.custom_js}" {config.ui.custom_js_attributes}></script>"""

    # One stylesheet <link> per URL in theme.json's "custom_fonts". The
    # generator's loop variable `font` is local to the generator and does not
    # clobber the outer name.
    font = None
    if custom_theme and custom_theme.get("custom_fonts"):
        font = "\n".join(
            f"""<link rel="stylesheet" href="{font}">"""
            for font in custom_theme.get("custom_fonts")
        )

    index_html_file_path = os.path.join(build_dir, "index.html")

    with open(index_html_file_path, encoding="utf-8") as f:
        content = f.read()
        content = content.replace(PLACEHOLDER, tags)
        # `js` always starts with "<script>", so this branch always runs.
        if js:
            content = content.replace(JS_PLACEHOLDER, js)
        if css:
            content = content.replace(CSS_PLACEHOLDER, css)
        if font:
            # Fonts replace whatever currently sits between the FONT markers.
            content = replace_between_tags(
                content, "<!-- FONT START -->", "<!-- FONT END -->", font
            )
        # Rewrite root-relative links so assets resolve under root_path.
        content = content.replace('href="/', f'href="{root_path}/')
        content = content.replace('src="/', f'src="{root_path}/')
        return content
427
-
428
-
429
def get_user_facing_url(url: URL):
    """
    Return the URL as seen by the end user.

    If ``CHAINLIT_URL`` is set (e.g. behind a proxy such as Cloud Run), its
    scheme/host/path, without query, fragment or trailing slash, is used as
    the base and the request path is appended. Otherwise the request URL is
    returned without its query and fragment.
    """
    override = os.environ.get("CHAINLIT_URL")

    if override:
        base = URL(override).replace(query="", fragment="")
        # Drop a single trailing slash so joining with url.path does not double it.
        if base.path.endswith("/"):
            base = base.replace(path=base.path[:-1])
        return str(base) + url.path

    return str(url.replace(query="", fragment=""))
450
-
451
-
452
@router.get("/auth/config")
async def auth(request: Request):
    """Return the authentication configuration produced by ``get_configuration()``."""
    return get_configuration()
455
-
456
-
457
- def _get_response_dict(access_token: str) -> dict:
458
- """Get the response dictionary for the auth response."""
459
-
460
- return {"success": True}
461
-
462
-
463
def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
    """Build the HTTP response returned after a successful authentication.

    Args:
        access_token: The issued JWT (forwarded to ``_get_response_dict``).
        redirect_to_callback: When True (OAuth flows), redirect the browser to
            ``<CHAINLIT_ROOT_PATH>/login/callback`` with the response dict
            encoded as the query string. Otherwise return it as JSON.

    Returns:
        A 302 ``RedirectResponse`` or a ``JSONResponse``.
    """
    response_dict = _get_response_dict(access_token)

    if not redirect_to_callback:
        return JSONResponse(response_dict)

    # Strip trailing slashes so "/" becomes "" and "/app/" becomes "/app",
    # avoiding a "//login/callback" URL.
    root_path = os.environ.get("CHAINLIT_ROOT_PATH", "").rstrip("/")
    redirect_url = f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"

    return RedirectResponse(
        # FIXME: redirect to the right frontend base url to improve the dev environment
        url=redirect_url,
        status_code=302,
    )
482
-
483
-
484
def _get_oauth_redirect_error(request: Request, error: str) -> Response:
    """Redirect back to the ``login`` route with the provider's ``error`` as a query parameter."""
    query = urllib.parse.urlencode({"error": error})
    login_url = str(request.url_for("login"))
    return RedirectResponse(url=f"{login_url}?{query}")
493
-
494
-
495
async def _authenticate_user(
    request: Request, user: Optional[User], redirect_to_callback: bool = False
) -> Response:
    """Issue a JWT for ``user`` and return a response carrying the auth cookie.

    If a data layer is configured, the user is persisted first (failures are
    logged, not raised).

    Raises:
        HTTPException: 401 ``credentialssignin`` when ``user`` is falsy.
    """
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="credentialssignin",
        )

    data_layer = get_data_layer()
    if data_layer:
        try:
            await data_layer.create_user(user)
        except Exception as err:
            # Catch and log exceptions during user creation.
            # TODO: Make this catch only specific errors and allow others to propagate.
            logger.error(f"Error creating user: {err}")

    token = create_jwt(user)
    auth_response = _get_auth_response(token, redirect_to_callback)
    set_auth_cookie(request, auth_response, token)
    return auth_response
522
-
523
-
524
@router.post("/login")
async def login(
    request: Request,
    response: Response,
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """
    Log a user in with the developer-supplied ``password_auth_callback``.

    Raises:
        HTTPException: 400 when no password auth callback is configured.
    """
    password_auth_callback = config.code.password_auth_callback
    if not password_auth_callback:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
        )

    authenticated = await password_auth_callback(form_data.username, form_data.password)
    return await _authenticate_user(request, authenticated)
543
-
544
-
545
@router.post("/logout")
async def logout(request: Request, response: Response):
    """Clear the auth cookie, then delegate to ``on_logout`` if one is configured."""
    clear_auth_cookie(request, response)

    on_logout = config.code.on_logout
    if not on_logout:
        return {"success": True}

    return await on_logout(request, response)
554
-
555
-
556
@router.post("/auth/jwt")
async def jwt_auth(request: Request):
    """Login a user using a valid jwt.

    Expects ``Authorization: Bearer <token>``.

    Raises:
        HTTPException: 401 when the header is missing or malformed, the scheme
            is not Bearer, or the token fails ``decode_jwt``.
    """
    from jwt import InvalidTokenError

    auth_header: Optional[str] = request.headers.get("Authorization")
    if not auth_header:
        raise HTTPException(status_code=401, detail="Authorization header missing")

    # The header must split into exactly "<scheme> <token>".
    try:
        scheme, token = auth_header.split()
    except ValueError:
        raise HTTPException(
            status_code=401, detail="Invalid authorization header format"
        ) from None

    if scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication scheme. Please use Bearer",
        )

    # Only token decoding is guarded; errors from the login step propagate as-is.
    try:
        user = decode_jwt(token)
    except InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token") from None

    return await _authenticate_user(request, user)
583
-
584
-
585
@router.post("/auth/header")
async def header_auth(request: Request):
    """Log a user in by passing the request headers to ``header_auth_callback``.

    Raises:
        HTTPException: 400 when no header auth callback is configured.
    """
    callback = config.code.header_auth_callback
    if not callback:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No header_auth_callback defined",
        )

    resolved_user = await callback(request.headers)
    return await _authenticate_user(request, resolved_user)
597
-
598
-
599
@router.get("/auth/oauth/{provider_id}")
async def oauth_login(provider_id: str, request: Request):
    """Redirect the browser to the OAuth provider's authorization page.

    A random ``state`` value is sent to the provider and also stored in the
    oauth state cookie, so the callback can check it.

    Raises:
        HTTPException: 400 if no ``oauth_callback`` is configured, 404 if the
            provider is unknown.
    """
    if config.code.oauth_callback is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No oauth_callback defined",
        )

    provider = get_oauth_provider(provider_id)
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider {provider_id} not found",
        )

    state = random_secret(32)
    # Dict literal so provider-specific params may override the defaults.
    query = urllib.parse.urlencode(
        {
            "client_id": provider.client_id,
            "redirect_uri": f"{get_user_facing_url(request.url)}/callback",
            "state": state,
            **provider.authorize_params,
        }
    )

    redirect = RedirectResponse(url=f"{provider.authorize_url}?{query}")
    set_oauth_state_cookie(redirect, state)
    return redirect
632
-
633
-
634
@router.get("/auth/oauth/{provider_id}/callback")
async def oauth_callback(
    provider_id: str,
    request: Request,
    error: Optional[str] = None,
    code: Optional[str] = None,
    state: Optional[str] = None,
):
    """Handle the oauth callback and login the user.

    Steps: validate the ``state`` against the oauth state cookie, exchange
    ``code`` for a token, fetch the user info, pass it to the developer's
    ``oauth_callback``, then log the resulting user in and clear the state
    cookie. A provider-reported ``error`` redirects back to the login page.

    Raises:
        HTTPException: 400 (no oauth_callback / missing code or state),
            404 (unknown provider), 401 (state validation failed).
    """

    if config.code.oauth_callback is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No oauth_callback defined",
        )

    provider = get_oauth_provider(provider_id)
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider {provider_id} not found",
        )

    if error:
        return _get_oauth_redirect_error(request, error)

    if not code or not state:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing code or state",
        )

    try:
        validate_oauth_state_cookie(request, state)
    except Exception as e:
        logger.exception("Unable to validate oauth state: %s", e)

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unauthorized",
        ) from e

    url = get_user_facing_url(request.url)
    token = await provider.get_token(code, url)

    (raw_user_data, default_user) = await provider.get_user_info(token)

    user = await config.code.oauth_callback(
        provider_id, token, raw_user_data, default_user
    )

    response = await _authenticate_user(request, user, redirect_to_callback=True)

    clear_oauth_state_cookie(response)

    return response
690
-
691
-
692
# specific route for azure ad hybrid flow
@router.post("/auth/oauth/azure-ad-hybrid/callback")
async def oauth_azure_hf_callback(
    request: Request,
    error: Optional[str] = None,
    code: Annotated[Optional[str], Form()] = None,
    id_token: Annotated[Optional[str], Form()] = None,
):
    """Handle the Azure AD hybrid-flow callback (form POST) and log the user in.

    Unlike the generic callback, ``code`` and ``id_token`` arrive as form
    fields, and ``id_token`` is forwarded to the developer's ``oauth_callback``.

    NOTE(review): this handler does not validate a ``state`` value against the
    oauth state cookie before exchanging ``code`` (the GET callback does).
    Confirm whether CSRF protection is handled elsewhere for this flow.

    Raises:
        HTTPException: 400 (no oauth_callback / missing code), 404 (provider
            not configured).
    """
    provider_id = "azure-ad-hybrid"

    if config.code.oauth_callback is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No oauth_callback defined",
        )

    provider = get_oauth_provider(provider_id)
    if not provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider {provider_id} not found",
        )

    if error:
        return _get_oauth_redirect_error(request, error)

    if not code:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing code",
        )

    token = await provider.get_token(code, get_user_facing_url(request.url))
    raw_user_data, default_user = await provider.get_user_info(token)

    app_user = await config.code.oauth_callback(
        provider_id, token, raw_user_data, default_user, id_token
    )

    login_response = await _authenticate_user(
        request, app_user, redirect_to_callback=True
    )
    clear_oauth_state_cookie(login_response)
    return login_response
739
-
740
-
741
# A user as produced by the auth dependency: User, PersistedUser, or None
# (presumably None when the request is unauthenticated — confirm in get_current_user).
GenericUser = Optional[Union[User, PersistedUser]]
# Route-parameter type that FastAPI fills in by calling get_current_user.
UserParam = Annotated[GenericUser, Depends(get_current_user)]
743
-
744
-
745
@router.get("/user")
async def get_user(current_user: UserParam) -> GenericUser:
    """Return the current user as resolved by the ``get_current_user`` dependency."""
    return current_user
748
-
749
-
750
- _language_pattern = (
751
- "^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,4})?(-[a-zA-Z0-9]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
752
- )
753
-
754
-
755
@router.post("/set-session-cookie")
async def set_session_cookie(request: Request, response: Response):
    """Store the client's ``session_id`` (from the JSON body) in an HTTP-only cookie.

    Clients at 127.0.0.1/localhost get a non-Secure, SameSite=Lax cookie.
    Every other request, including one with no client info, gets a Secure,
    SameSite=None cookie.

    NOTE(review): the IPv6 loopback (::1) is not treated as local here.
    """
    payload = await request.json()
    local_client = bool(request.client) and request.client.host in (
        "127.0.0.1",
        "localhost",
    )

    if local_client:
        secure, samesite = False, "lax"
    else:
        secure, samesite = True, "none"

    response.set_cookie(
        key="X-Chainlit-Session-id",
        value=payload.get("session_id"),
        path="/",
        httponly=True,
        secure=secure,
        samesite=samesite,
    )

    return {"message": "Session cookie set"}
772
-
773
-
774
@router.get("/project/translations")
async def project_translations(
    language: str = Query(
        default="en-US", description="Language code", pattern=_language_pattern
    ),
):
    """Return the project's translation for ``language`` as ``{"translation": ...}``."""
    return JSONResponse(content={"translation": config.load_translation(language)})
790
-
791
-
792
@router.get("/project/settings")
async def project_settings(
    current_user: UserParam,
    language: str = Query(
        default="en-US", description="Language code", pattern=_language_pattern
    ),
    chat_profile: Optional[str] = Query(
        default=None, description="Current chat profile name"
    ),
):
    """Return project settings. This is called by the UI before establishing the websocket connection.

    The payload includes the UI/feature config (with the selected chat
    profile's ``config_overrides`` applied, if any), the localized markdown,
    the chat profiles, the starters and an optional debug URL.
    """

    # Load the markdown file based on the provided language
    markdown = get_markdown_str(config.root, language)

    # Call the developer's set_chat_profiles callback once and reuse the result
    # for both the serialized list and the override lookup below.
    chat_profiles = None
    if config.code.set_chat_profiles:
        chat_profiles = await config.code.set_chat_profiles(current_user)

    profiles = []
    for p in chat_profiles or []:
        profile_dict = p.to_dict()
        # Remove config_overrides from the serialized profile since it's used server-side only
        profile_dict.pop("config_overrides", None)
        profiles.append(profile_dict)

    starters = []
    if config.code.set_starters:
        starters = await config.code.set_starters(current_user)
        if starters:
            starters = [s.to_dict() for s in starters]

    data_layer = get_data_layer()

    debug_url = None
    if data_layer and config.run.debug:
        debug_url = await data_layer.build_debug_url()

    config_with_overrides = config

    # Apply profile-specific configuration overrides
    if chat_profile and chat_profiles:
        current_profile = next(
            (p for p in chat_profiles if p.name == chat_profile), None
        )
        if current_profile and current_profile.config_overrides:
            config_with_overrides = ChainlitConfig.model_validate(
                config.model_copy(
                    update=current_profile.config_overrides.model_dump(
                        exclude_none=True
                    ),
                    deep=True,
                )
            )

    return JSONResponse(
        content={
            "ui": config_with_overrides.ui.model_dump(),
            "features": config_with_overrides.features.model_dump(),
            "userEnv": config_with_overrides.project.user_env,
            "dataPersistence": data_layer is not None,
            "threadResumable": bool(config.code.on_chat_resume),
            "markdown": markdown,
            "chatProfiles": profiles,
            "starters": starters,
            "debugUrl": debug_url,
        }
    )
864
-
865
-
866
@router.put("/feedback")
async def update_feedback(
    request: Request,
    update: UpdateFeedbackRequest,
    current_user: UserParam,
):
    """Upsert human feedback for a message, then run ``on_feedback`` if configured.

    Errors from the user callback are logged and do not fail the request.

    NOTE(review): ``current_user`` is not used to check ownership of the
    feedback here — confirm the data layer enforces it.

    Raises:
        HTTPException: 500 when no data layer is configured or the upsert fails.
    """
    data_layer = get_data_layer()
    if not data_layer:
        raise HTTPException(status_code=500, detail="Data persistence is not enabled")

    try:
        feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)

        on_feedback = config.code.on_feedback
        if on_feedback:
            try:
                await on_feedback(update.feedback)
            except Exception as callback_error:
                # Best effort: a failing user callback must not fail the endpoint.
                logger.error(
                    f"Error in user-provided on_feedback callback: {callback_error}"
                )
    except Exception as upsert_error:
        raise HTTPException(detail=str(upsert_error), status_code=500) from upsert_error

    return JSONResponse(content={"success": True, "feedbackId": feedback_id})
892
-
893
-
894
@router.delete("/feedback")
async def delete_feedback(
    request: Request,
    payload: DeleteFeedbackRequest,
    current_user: UserParam,
):
    """Remove the feedback whose id is given in ``payload.feedbackId``.

    Responds with HTTP 400 when no data layer is configured.
    """
    layer = get_data_layer()
    if not layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")

    await layer.delete_feedback(payload.feedbackId)
    return JSONResponse(content={"success": True})
911
-
912
-
913
@router.post("/project/threads")
async def get_user_threads(
    request: Request,
    payload: GetThreadsRequest,
    current_user: UserParam,
):
    """Return one page of threads belonging to the authenticated user.

    ``payload.filter.userId`` is always overwritten with the caller's persisted
    user id, so clients cannot list another user's threads through the filter.
    """
    layer = get_data_layer()
    if not layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if isinstance(current_user, PersistedUser):
        owner_id = current_user.id
    else:
        # A non-persisted user has no id yet; resolve it through the data layer.
        stored_user = await layer.get_user(identifier=current_user.identifier)
        if not stored_user:
            raise HTTPException(status_code=404, detail="User not found")
        owner_id = stored_user.id
    payload.filter.userId = owner_id

    page = await layer.list_threads(payload.pagination, payload.filter)
    return JSONResponse(content=page.to_dict())
939
-
940
-
941
@router.get("/project/thread/{thread_id}")
async def get_thread(
    request: Request,
    thread_id: str,
    current_user: UserParam,
):
    """Return one thread once ``is_thread_author`` has vetted the caller.

    Responds 400 without a data layer and 401 without an authenticated user.
    """
    layer = get_data_layer()
    if not layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    await is_thread_author(current_user.identifier, thread_id)

    thread = await layer.get_thread(thread_id)
    return JSONResponse(content=thread)
960
-
961
-
962
@router.get("/project/thread/{thread_id}/element/{element_id}")
async def get_thread_element(
    request: Request,
    thread_id: str,
    element_id: str,
    current_user: UserParam,
):
    """Return one element of a thread once ``is_thread_author`` has vetted the caller.

    Responds 400 without a data layer and 401 without an authenticated user.
    """
    layer = get_data_layer()
    if not layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    await is_thread_author(current_user.identifier, thread_id)

    element = await layer.get_element(thread_id, element_id)
    return JSONResponse(content=element)
982
-
983
-
984
@router.put("/project/element")
async def update_thread_element(
    payload: ElementRequest,
    current_user: UserParam,
):
    """Update a custom element that belongs to a live websocket session.

    Only elements whose ``type`` is ``"custom"`` may be updated. Any other (or
    missing) type is answered with ``{"success": False}``.

    Raises:
        HTTPException: 404 if ``payload.sessionId`` matches no session, 401 if
            the authenticated user does not own that session.
    """
    from chainlit.context import init_ws_context
    from chainlit.element import Element, ElementDict
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    if not session:
        # Same handling as upload_file: an unknown session is a client error,
        # not something to hand to init_ws_context.
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    # Authorize before parsing the element payload.
    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to update elements for this session",
            )

    element_dict = cast(ElementDict, payload.element)

    # .get(): a payload without "type" is rejected instead of raising KeyError.
    if element_dict.get("type") != "custom":
        return {"success": False}

    element = Element.from_dict(element_dict)
    await element.update()
    return {"success": True}
1017
-
1018
-
1019
@router.delete("/project/element")
async def delete_thread_element(
    payload: ElementRequest,
    current_user: UserParam,
):
    """Remove a custom element from a live websocket session.

    Only elements whose ``type`` is ``"custom"`` may be removed. Any other (or
    missing) type is answered with ``{"success": False}``.

    Raises:
        HTTPException: 404 if ``payload.sessionId`` matches no session, 401 if
            the authenticated user does not own that session.
    """
    from chainlit.context import init_ws_context
    from chainlit.element import CustomElement, ElementDict
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    if not session:
        # Same handling as upload_file: an unknown session is a client error,
        # not something to hand to init_ws_context.
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    # Authorize before parsing the element payload.
    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to remove elements for this session",
            )

    element_dict = cast(ElementDict, payload.element)

    # .get(): a payload without "type" is rejected instead of raising KeyError.
    if element_dict.get("type") != "custom":
        return {"success": False}

    element = CustomElement(
        id=element_dict["id"],
        object_key=element_dict["objectKey"],
        chainlit_key=element_dict["chainlitKey"],
        url=element_dict["url"],
        for_id=element_dict.get("forId") or "",
        thread_id=element_dict.get("threadId") or "",
        name=element_dict["name"],
        props=element_dict.get("props") or {},
        display=element_dict["display"],
    )

    await element.remove()
    return {"success": True}
1063
-
1064
-
1065
@router.put("/project/thread")
async def rename_thread(
    request: Request,
    payload: UpdateThreadRequest,
    current_user: UserParam,
):
    """Give ``payload.threadId`` the name ``payload.name``.

    The caller is vetted with ``is_thread_author`` first. Responds 400 without a
    data layer and 401 without an authenticated user.
    """
    layer = get_data_layer()
    if not layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    target_thread = payload.threadId
    await is_thread_author(current_user.identifier, target_thread)
    await layer.update_thread(target_thread, name=payload.name)

    return JSONResponse(content={"success": True})
1087
-
1088
-
1089
@router.delete("/project/thread")
async def delete_thread(
    request: Request,
    payload: DeleteThreadRequest,
    current_user: UserParam,
):
    """Delete ``payload.threadId`` after vetting the caller with ``is_thread_author``.

    Responds 400 without a data layer and 401 without an authenticated user.
    """
    layer = get_data_layer()
    if not layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    target_thread = payload.threadId
    await is_thread_author(current_user.identifier, target_thread)
    await layer.delete_thread(target_thread)

    return JSONResponse(content={"success": True})
1111
-
1112
-
1113
@router.post("/project/action")
async def call_action(
    payload: CallActionRequest,
    current_user: UserParam,
):
    """Run the registered callback for an action fired from the UI.

    The callback is looked up in ``config.code.action_callbacks`` by action
    name. On the session's first interaction, thread initialisation is
    scheduled in the background before the callback runs.

    Raises:
        HTTPException: 404 if ``payload.sessionId`` matches no session or no
            callback is registered for the action, 401 if the authenticated
            user does not own the session.

    Returns:
        JSONResponse ``{"success": True, "response": <callback return value>}``.
    """
    from chainlit.action import Action
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    if not session:
        # Same handling as upload_file: an unknown session is a client error.
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to call actions for this session",
            )

    action = Action(**payload.action)

    callback = config.code.action_callbacks.get(action.name)
    if not callback:
        raise HTTPException(
            status_code=404,
            detail=f"No callback found for action {action.name}",
        )

    if not context.session.has_first_interaction:
        context.session.has_first_interaction = True
        # NOTE(review): the Task handle is not kept; the asyncio docs warn that
        # the event loop holds only weak references to tasks. Consider storing
        # it somewhere long-lived until it completes.
        asyncio.create_task(context.emitter.init_thread(action.name))

    response = await callback(action)
    return JSONResponse(content={"success": True, "response": response})
1153
-
1154
-
1155
@router.post("/mcp")
async def connect_mcp(
    payload: ConnectMCPRequest,
    current_user: UserParam,
):
    """Connect a websocket session to an MCP server and report its tools.

    Supported ``payload.clientType`` values are ``"sse"``, ``"stdio"`` and
    ``"streamable-http"``. Each is gated by its own
    ``config.features.mcp.<transport>.enabled`` flag.

    An existing connection with the same name is torn down first, awaiting
    ``config.code.on_mcp_disconnect`` when set. On success, the
    ``(ClientSession, AsyncExitStack)`` pair is stored in
    ``session.mcp_sessions[payload.name]`` and ``config.code.on_mcp_connect``
    is awaited. On any failure, everything opened for the new connection is
    closed and unregistered.

    Raises:
        HTTPException:
            - 404 for an unknown session.
            - 401 if the authenticated user does not own the session.
            - 400 when MCP or the requested transport is disabled, the client
              type is unsupported, or connecting fails.
    """
    from mcp import ClientSession
    from mcp.client.sse import sse_client
    from mcp.client.stdio import (
        StdioServerParameters,
        get_default_environment,
        stdio_client,
    )
    from mcp.client.streamable_http import streamablehttp_client

    from chainlit.context import init_ws_context
    from chainlit.mcp import (
        HttpMcpConnection,
        McpConnection,
        SseMcpConnection,
        StdioMcpConnection,
        validate_mcp_command,
    )
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    if not session:
        # Same handling as upload_file: an unknown session is a client error.
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
            )

    if not config.features.mcp.enabled:
        raise HTTPException(
            status_code=400,
            detail="This app does not support MCP.",
        )

    # Replace any connection already registered under this name.
    if payload.name in session.mcp_sessions:
        old_client_session, old_exit_stack = session.mcp_sessions[payload.name]
        if on_mcp_disconnect := config.code.on_mcp_disconnect:
            await on_mcp_disconnect(payload.name, old_client_session)
        try:
            await old_exit_stack.aclose()
        except Exception as close_error:
            # Best effort: don't block the new connection, but don't hide it.
            logger.warning(
                f"Error closing previous MCP connection {payload.name!r}: {close_error}"
            )
        # The old connection is closed; never leave it registered.
        del session.mcp_sessions[payload.name]

    exit_stack = AsyncExitStack()
    registered = False

    async def _abort() -> None:
        """Unregister and close a partially established connection."""
        # The connection objects are built with name=payload.name, so that is
        # the key they are stored under.
        if registered:
            session.mcp_sessions.pop(payload.name, None)
        try:
            await exit_stack.aclose()
        except Exception as close_error:
            logger.warning(
                f"Error cleaning up failed MCP connection {payload.name!r}: {close_error}"
            )

    try:
        mcp_connection: McpConnection

        if payload.clientType == "sse":
            if not config.features.mcp.sse.enabled:
                raise HTTPException(
                    status_code=400,
                    detail="SSE MCP is not enabled",
                )

            mcp_connection = SseMcpConnection(
                url=payload.url,
                name=payload.name,
                headers=getattr(payload, "headers", None),
            )
            transport = await exit_stack.enter_async_context(
                sse_client(
                    url=mcp_connection.url,
                    headers=mcp_connection.headers,
                )
            )
        elif payload.clientType == "stdio":
            if not config.features.mcp.stdio.enabled:
                raise HTTPException(
                    status_code=400,
                    detail="Stdio MCP is not enabled",
                )

            env_from_cmd, command, args = validate_mcp_command(payload.fullCommand)
            mcp_connection = StdioMcpConnection(
                command=command, args=args, name=payload.name
            )

            env = get_default_environment()
            env.update(env_from_cmd)
            server_params = StdioServerParameters(command=command, args=args, env=env)

            transport = await exit_stack.enter_async_context(
                stdio_client(server_params)
            )
        elif payload.clientType == "streamable-http":
            if not config.features.mcp.streamable_http.enabled:
                raise HTTPException(
                    status_code=400,
                    detail="HTTP MCP is not enabled",
                )

            mcp_connection = HttpMcpConnection(
                url=payload.url,
                name=payload.name,
                headers=getattr(payload, "headers", None),
            )
            transport = await exit_stack.enter_async_context(
                streamablehttp_client(
                    url=mcp_connection.url,
                    headers=mcp_connection.headers,
                )
            )
        else:
            # Without this branch `transport` would be unbound below.
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported MCP client type: {payload.clientType}",
            )

        # stdio/sse transports yield (read, write); streamable-http yields
        # (read, write, get_session_id). Only the two streams are needed.
        read, write = transport[:2]

        mcp_session: ClientSession = await exit_stack.enter_async_context(
            ClientSession(read_stream=read, write_stream=write, sampling_callback=None)
        )
        await mcp_session.initialize()

        session.mcp_sessions[mcp_connection.name] = (mcp_session, exit_stack)
        registered = True

        await config.code.on_mcp_connect(mcp_connection, mcp_session)
    except HTTPException:
        # Deliberate HTTP errors (e.g. a disabled transport) pass through
        # unchanged instead of being re-wrapped as "Could not connect".
        await _abort()
        raise
    except Exception as e:
        await _abort()
        raise HTTPException(
            status_code=400,
            detail=f"Could not connect to the MCP: {e!s}",
        ) from e

    tool_list = await mcp_session.list_tools()

    return JSONResponse(
        content={
            "success": True,
            "mcp": {
                "name": payload.name,
                "tools": [{"name": t.name} for t in tool_list.tools],
                "clientType": payload.clientType,
                "command": payload.fullCommand
                if payload.clientType == "stdio"
                else None,
                "url": getattr(payload, "url", None)
                if payload.clientType in ["sse", "streamable-http"]
                else None,
                # Include optional headers for SSE and streamable-http connections
                "headers": getattr(payload, "headers", None)
                if payload.clientType in ["sse", "streamable-http"]
                else None,
            },
        }
    )
1319
-
1320
-
1321
@router.delete("/mcp")
async def disconnect_mcp(
    payload: DisconnectMCPRequest,
    current_user: UserParam,
):
    """Disconnect the MCP connection named ``payload.name`` from a session.

    Awaits ``config.code.on_mcp_disconnect`` (if set) with the client session,
    then closes the connection's exit stack. The connection is always
    unregistered and closed, even if the callback fails. An unknown name is a
    no-op that still reports success.

    Raises:
        HTTPException:
            - 404 for an unknown session.
            - 401 if the authenticated user does not own the session.
            - 400 if the disconnect callback raises.
    """
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    if not session:
        # Same handling as upload_file: an unknown session is a client error.
        raise HTTPException(status_code=404, detail="Session not found")

    context = init_ws_context(session)

    if current_user:
        if (
            not context.session.user
            or context.session.user.identifier != current_user.identifier
        ):
            raise HTTPException(
                status_code=401,
            )

    if payload.name not in session.mcp_sessions:
        return JSONResponse(content={"success": True})

    # Unregister up front so a failing callback cannot leave a half-closed
    # connection registered.
    client_session, exit_stack = session.mcp_sessions.pop(payload.name)
    try:
        if callback := config.code.on_mcp_disconnect:
            await callback(payload.name, client_session)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Could not disconnect from the MCP: {e!s}",
        ) from e
    finally:
        # Always release the transport; closing errors are best effort.
        try:
            await exit_stack.aclose()
        except Exception as close_error:
            logger.warning(
                f"Error closing MCP connection {payload.name!r}: {close_error}"
            )

    return JSONResponse(content={"success": True})
1361
-
1362
-
1363
@router.post("/project/file")
async def upload_file(
    current_user: UserParam,
    session_id: str,
    file: UploadFile,
    ask_parent_id: Optional[str] = None,
):
    """Upload a file into a websocket session's files directory.

    The upload is checked by ``validate_file_upload``. If ``ask_parent_id`` is
    given, the ``AskFileSpec`` registered for it is used; otherwise the
    spontaneous-upload configuration applies.

    Raises:
        HTTPException:
            - 404 if the session or the parent ask message is unknown.
            - 401 if the authenticated user does not own the session.
            - 400 if the file has no name or content type, or fails validation.

    Returns:
        JSONResponse with whatever ``session.persist_file`` returned.
    """
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(session_id)

    if not session:
        raise HTTPException(
            status_code=404,
            detail="Session not found",
        )

    if current_user:
        if not session.user or session.user.identifier != current_user.identifier:
            raise HTTPException(
                status_code=401,
                detail="You are not authorized to upload files for this session",
            )

    session.files_dir.mkdir(exist_ok=True)

    try:
        # Explicit checks rather than `assert`: asserts vanish under `python -O`
        # and otherwise surface as 500s instead of client errors.
        if not file.filename:
            raise HTTPException(status_code=400, detail="No filename for uploaded file")
        if not file.content_type:
            raise HTTPException(
                status_code=400, detail="No content type for uploaded file"
            )

        spec: Optional[AskFileSpec] = session.files_spec.get(ask_parent_id, None)
        if not spec and ask_parent_id:
            raise HTTPException(
                status_code=404,
                detail="Parent message not found",
            )

        try:
            validate_file_upload(file, spec=spec)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Read the body only once validation passed, so rejected uploads
        # (e.g. oversized ones) are never read into memory here.
        content = await file.read()

        file_response = await session.persist_file(
            name=file.filename, content=content, mime=file.content_type
        )

        return JSONResponse(content=file_response)
    finally:
        await file.close()
1416
-
1417
-
1418
def validate_file_upload(file: UploadFile, spec: Optional[AskFileSpec] = None):
    """Check an uploaded file against whichever upload rules apply.

    A per-message ``spec`` (from an ask-file request) takes precedence. Without
    one, ``config.features.spontaneous_file_upload`` governs:
    - if that setting is absent, every upload is allowed unchecked;
    - if it is present but disabled, every upload is refused.

    Args:
        file (UploadFile): The uploaded file.
        spec (AskFileSpec, optional): Per-message constraints, if any.

    Raises:
        ValueError: If uploading is disabled or the file breaks a constraint.
    """
    if not spec:
        upload_settings = config.features.spontaneous_file_upload
        if upload_settings is None:
            # No configuration at all: uploads are unrestricted.
            return
        if not upload_settings.enabled:
            raise ValueError("File upload is not enabled")

    validate_file_mime_type(file, spec)
    validate_file_size(file, spec)
1437
-
1438
-
1439
def validate_file_mime_type(file: UploadFile, spec: Optional[AskFileSpec]):
    """Ensure the file's MIME type (and extension, when required) is accepted.

    The rules come from ``spec.accept`` when a spec is given, otherwise from
    ``config.features.spontaneous_file_upload.accept``. Two shapes are allowed:
    - a list of MIME glob patterns;
    - a dict mapping MIME glob patterns to allowed filename extensions, where
      an empty extension list accepts any extension.

    Without a spec and without a configured ``accept``, every type is allowed.

    Args:
        file (UploadFile): The file to validate.
        spec (AskFileSpec, optional): Per-message constraints, if any.

    Raises:
        ValueError: If the file type is not allowed.
    """
    if spec:
        accept = spec.accept
    else:
        upload_settings = config.features.spontaneous_file_upload
        if upload_settings is None or upload_settings.accept is None:
            # No accept rule configured: every file type is allowed.
            return
        accept = upload_settings.accept

    assert isinstance(accept, (list, dict)), (
        "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
    )

    mime = str(file.content_type)
    if isinstance(accept, list):
        if any(fnmatch.fnmatch(mime, pattern) for pattern in accept):
            return
    elif isinstance(accept, dict):
        name = file.filename
        for pattern, extensions in accept.items():
            if not fnmatch.fnmatch(mime, pattern):
                continue
            if len(extensions) == 0:
                return
            if name is not None and any(
                name.lower().endswith(ext.lower()) for ext in extensions
            ):
                return

    raise ValueError("File type not allowed")
1475
-
1476
-
1477
def validate_file_size(file: UploadFile, spec: Optional[AskFileSpec]):
    """Reject files larger than the applicable ``max_size_mb`` limit.

    The limit comes from ``spec.max_size_mb`` when a spec is given, otherwise
    from ``config.features.spontaneous_file_upload.max_size_mb``. Without a spec
    and without a configured limit, any size is accepted. A file whose ``size``
    is unknown (None) is never rejected.

    Args:
        file (UploadFile): The file to validate.
        spec (AskFileSpec, optional): Per-message constraints, if any.

    Raises:
        ValueError: If the file size is too large.
    """
    if spec:
        limit_mb = spec.max_size_mb
    else:
        upload_settings = config.features.spontaneous_file_upload
        if upload_settings is None or upload_settings.max_size_mb is None:
            return
        limit_mb = upload_settings.max_size_mb

    size = file.size
    if size is None:
        return
    if size > limit_mb * 1024 * 1024:
        raise ValueError("File size too large")
1497
-
1498
-
1499
@router.get("/project/file/{file_id}")
async def get_file(
    file_id: str,
    session_id: str,
    current_user: UserParam,
):
    """Serve a file previously stored in a websocket session.

    Raises:
        HTTPException:
            - 401 if the session id is empty or unknown, or the session is not
              owned by the authenticated user.
            - 404 if the session holds no file with ``file_id``.
    """
    from chainlit.session import WebsocketSession

    session = None
    if session_id:
        session = WebsocketSession.get_by_id(session_id)

    if not session:
        raise HTTPException(
            status_code=401,
            detail="Unauthorized",
        )

    owner_mismatch = not session.user or (
        current_user and session.user.identifier != current_user.identifier
    )
    if current_user and owner_mismatch:
        raise HTTPException(
            status_code=401,
            detail="You are not authorized to download files from this session",
        )

    if file_id not in session.files:
        raise HTTPException(status_code=404, detail="File not found")

    entry = session.files[file_id]
    return FileResponse(entry["path"], media_type=entry["type"])
1528
-
1529
-
1530
@router.get("/favicon")
async def get_favicon():
    """Serve the UI favicon.

    A ``favicon.*`` file in the app's ``public`` directory wins. Otherwise the
    bundled ``favicon.svg`` from ``build_dir`` is served. The media type is
    guessed from the chosen file's extension.
    """
    # Escape the directory so glob metacharacters in APP_ROOT (e.g. "[") are
    # taken literally. Sort so the pick is deterministic when several
    # favicon.* files exist (raw glob order depends on the filesystem).
    public_dir = glob.escape(os.path.join(APP_ROOT, "public"))
    files = sorted(glob.glob(os.path.join(public_dir, "favicon.*")))

    if files:
        favicon_path = files[0]
    else:
        favicon_path = os.path.join(build_dir, "favicon.svg")

    media_type, _ = mimetypes.guess_type(favicon_path)

    return FileResponse(favicon_path, media_type=media_type)
1544
-
1545
-
1546
@router.get("/logo")
async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
    """Serve the logo matching the requested UI theme (light by default).

    Lookup order:
    1. ``public/logo_<theme>.*`` under the app root.
    2. ``assets/logo_<theme>*.*`` under the frontend build directory.

    Raises:
        HTTPException: 404 if no matching logo file exists.
    """
    theme_value = theme.value if theme else Theme.light.value

    # Directories are glob-escaped so metacharacters in APP_ROOT/build_dir are
    # literal. Matches are sorted so the pick is deterministic.
    candidate_patterns = [
        os.path.join(
            glob.escape(os.path.join(APP_ROOT, "public")), f"logo_{theme_value}.*"
        ),
        os.path.join(
            glob.escape(os.path.join(build_dir, "assets")), f"logo_{theme_value}*.*"
        ),
    ]

    logo_path = None
    for pattern in candidate_patterns:
        files = sorted(glob.glob(pattern))
        if files:
            logo_path = files[0]
            break

    if not logo_path:
        raise HTTPException(status_code=404, detail="Missing default logo")

    media_type, _ = mimetypes.guess_type(logo_path)

    return FileResponse(logo_path, media_type=media_type)
1568
-
1569
-
1570
@router.get("/avatars/{avatar_id:str}")
async def get_avatar(avatar_id: str):
    """Serve the avatar image for ``avatar_id``, falling back to the favicon.

    ``"default"`` maps to the app name (``config.ui.name``). The id is then
    normalized (trimmed, lower-cased, spaces and dots replaced by underscores)
    and looked up as ``<APP_ROOT>/public/avatars/<id>.*``.

    Raises:
        HTTPException: 400 for an id with disallowed characters, or if the
            match resolves outside the avatars directory.
    """
    # fullmatch: with re.match, "$" would also accept a trailing newline.
    if not re.fullmatch(r"[a-zA-Z0-9_ .-]+", avatar_id):
        raise HTTPException(status_code=400, detail="Invalid avatar_id")

    if avatar_id == "default":
        avatar_id = config.ui.name

    avatar_id = avatar_id.strip().lower().replace(" ", "_").replace(".", "_")

    base_path = Path(APP_ROOT) / "public" / "avatars"
    # config.ui.name is not checked by the regex above, so escape glob
    # metacharacters. Sort so the choice among several matches is
    # deterministic.
    avatar_pattern = f"{glob.escape(avatar_id)}.*"
    matching_files = sorted(base_path.glob(avatar_pattern))

    if matching_files:
        avatar_path = matching_files[0]
        if not is_path_inside(avatar_path, base_path):
            raise HTTPException(status_code=400, detail="Invalid filename")
        media_type, _ = mimetypes.guess_type(str(avatar_path))

        return FileResponse(avatar_path, media_type=media_type)

    return await get_favicon()
1594
-
1595
-
1596
@router.head("/")
def status_check():
    """Answer HEAD requests on the root path to signal the site is up."""
    status = {"message": "Site is operational"}
    return status
1600
-
1601
-
1602
@router.get("/{full_path:path}")
async def serve(request: Request):
    """Catch-all GET route that returns the UI's HTML template.

    The root path passed to the template is ``CHAINLIT_PARENT_ROOT_PATH``
    followed by ``CHAINLIT_ROOT_PATH``; each defaults to an empty string.
    """
    parent_root = os.getenv("CHAINLIT_PARENT_ROOT_PATH", "")
    own_root = os.getenv("CHAINLIT_ROOT_PATH", "")

    return HTMLResponse(
        content=get_html_template(parent_root + own_root),
        status_code=200,
    )
1612
-
1613
-
1614
# Attach every route registered on `router` to the application `app`.
app.include_router(router)

# Imported only for its import-time side effects (hence `noqa` for the unused
# name). Presumably this registers the socket event handlers; confirm in
# chainlit/socket.py.
import chainlit.socket  # noqa