chainlit 1.0.401__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +98 -279
- chainlit/_utils.py +8 -0
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +28 -36
- chainlit/auth/cookie.py +122 -0
- chainlit/auth/jwt.py +39 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +362 -0
- chainlit/chat_context.py +64 -0
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +77 -8
- chainlit/config.py +181 -101
- chainlit/context.py +42 -13
- chainlit/copilot/dist/index.js +8750 -903
- chainlit/data/__init__.py +101 -416
- chainlit/data/acl.py +6 -2
- chainlit/data/base.py +107 -0
- chainlit/data/chainlit_data_layer.py +608 -0
- chainlit/data/dynamodb.py +590 -0
- chainlit/data/literalai.py +500 -0
- chainlit/data/sql_alchemy.py +721 -0
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure.py +81 -0
- chainlit/data/storage_clients/azure_blob.py +89 -0
- chainlit/data/storage_clients/base.py +26 -0
- chainlit/data/storage_clients/gcs.py +88 -0
- chainlit/data/storage_clients/s3.py +75 -0
- chainlit/data/utils.py +29 -0
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +354 -0
- chainlit/element.py +91 -33
- chainlit/emitter.py +80 -29
- chainlit/frontend/dist/assets/DailyMotion-C_XC7xJI.js +1 -0
- chainlit/frontend/dist/assets/Dataframe-Cs4l4hA1.js +22 -0
- chainlit/frontend/dist/assets/Facebook-CUeCH7hk.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-CB-fYkx8.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-YX6qaq72.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-DGV0ldjP.js +1 -0
- chainlit/frontend/dist/assets/Mux-CmRss5oc.js +1 -0
- chainlit/frontend/dist/assets/Preview-DBVJn7-H.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-qLUb18oY.js +1 -0
- chainlit/frontend/dist/assets/Streamable-BvYP7bFp.js +1 -0
- chainlit/frontend/dist/assets/Twitch-CTHt-sGZ.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-B-0mCJbm.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-Dnp7ri8q.js +1 -0
- chainlit/frontend/dist/assets/Wistia-DW0x_UBn.js +1 -0
- chainlit/frontend/dist/assets/YouTube--98FipvA.js +1 -0
- chainlit/frontend/dist/assets/index-D71nZ46o.js +8665 -0
- chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
- chainlit/frontend/dist/assets/react-plotly-Cn_BQTQw.js +3484 -0
- chainlit/frontend/dist/index.html +2 -4
- chainlit/haystack/callbacks.py +4 -7
- chainlit/input_widget.py +8 -4
- chainlit/langchain/callbacks.py +103 -68
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +65 -40
- chainlit/markdown.py +22 -6
- chainlit/message.py +54 -56
- chainlit/mistralai/__init__.py +50 -0
- chainlit/oauth_providers.py +266 -8
- chainlit/openai/__init__.py +10 -18
- chainlit/secret.py +1 -1
- chainlit/server.py +789 -228
- chainlit/session.py +108 -90
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +397 -0
- chainlit/socket.py +199 -116
- chainlit/step.py +141 -89
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +338 -0
- chainlit/translations/bn.json +235 -0
- chainlit/translations/en-US.json +83 -4
- chainlit/translations/gu.json +235 -0
- chainlit/translations/he-IL.json +235 -0
- chainlit/translations/hi.json +235 -0
- chainlit/translations/kn.json +235 -0
- chainlit/translations/ml.json +235 -0
- chainlit/translations/mr.json +235 -0
- chainlit/translations/nl-NL.json +233 -0
- chainlit/translations/ta.json +235 -0
- chainlit/translations/te.json +235 -0
- chainlit/translations/zh-CN.json +233 -0
- chainlit/translations.py +60 -0
- chainlit/types.py +133 -28
- chainlit/user.py +14 -3
- chainlit/user_session.py +6 -3
- chainlit/utils.py +52 -5
- chainlit/version.py +3 -2
- {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/METADATA +48 -50
- chainlit-2.0.3.dist-info/RECORD +106 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-9711593e.js +0 -723
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -40
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -13
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -408
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit/translations/pt-BR.json +0 -155
- chainlit-1.0.401.dist-info/RECORD +0 -66
- /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/WHEEL +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/entry_points.txt +0 -0
chainlit/server.py
CHANGED
|
@@ -1,24 +1,45 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import fnmatch
|
|
1
3
|
import glob
|
|
2
4
|
import json
|
|
3
5
|
import mimetypes
|
|
6
|
+
import os
|
|
4
7
|
import re
|
|
5
8
|
import shutil
|
|
6
9
|
import urllib.parse
|
|
7
|
-
from typing import Any, Optional, Union
|
|
8
|
-
|
|
9
|
-
from chainlit.oauth_providers import get_oauth_provider
|
|
10
|
-
from chainlit.secret import random_secret
|
|
11
|
-
|
|
12
|
-
mimetypes.add_type("application/javascript", ".js")
|
|
13
|
-
mimetypes.add_type("text/css", ".css")
|
|
14
|
-
|
|
15
|
-
import asyncio
|
|
16
|
-
import os
|
|
17
10
|
import webbrowser
|
|
18
11
|
from contextlib import asynccontextmanager
|
|
19
12
|
from pathlib import Path
|
|
13
|
+
from typing import List, Optional, Union, cast
|
|
14
|
+
|
|
15
|
+
import socketio
|
|
16
|
+
from fastapi import (
|
|
17
|
+
APIRouter,
|
|
18
|
+
Depends,
|
|
19
|
+
FastAPI,
|
|
20
|
+
Form,
|
|
21
|
+
HTTPException,
|
|
22
|
+
Query,
|
|
23
|
+
Request,
|
|
24
|
+
Response,
|
|
25
|
+
UploadFile,
|
|
26
|
+
status,
|
|
27
|
+
)
|
|
28
|
+
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
29
|
+
from fastapi.security import OAuth2PasswordRequestForm
|
|
30
|
+
from starlette.datastructures import URL
|
|
31
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
32
|
+
from typing_extensions import Annotated
|
|
33
|
+
from watchfiles import awatch
|
|
20
34
|
|
|
21
|
-
from chainlit.auth import create_jwt, get_configuration, get_current_user
|
|
35
|
+
from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
|
|
36
|
+
from chainlit.auth.cookie import (
|
|
37
|
+
clear_auth_cookie,
|
|
38
|
+
clear_oauth_state_cookie,
|
|
39
|
+
set_auth_cookie,
|
|
40
|
+
set_oauth_state_cookie,
|
|
41
|
+
validate_oauth_state_cookie,
|
|
42
|
+
)
|
|
22
43
|
from chainlit.config import (
|
|
23
44
|
APP_ROOT,
|
|
24
45
|
BACKEND_ROOT,
|
|
@@ -27,51 +48,48 @@ from chainlit.config import (
|
|
|
27
48
|
PACKAGE_ROOT,
|
|
28
49
|
config,
|
|
29
50
|
load_module,
|
|
51
|
+
public_dir,
|
|
30
52
|
reload_config,
|
|
31
53
|
)
|
|
32
54
|
from chainlit.data import get_data_layer
|
|
33
55
|
from chainlit.data.acl import is_thread_author
|
|
34
56
|
from chainlit.logger import logger
|
|
35
57
|
from chainlit.markdown import get_markdown_str
|
|
36
|
-
from chainlit.
|
|
37
|
-
from chainlit.
|
|
58
|
+
from chainlit.oauth_providers import get_oauth_provider
|
|
59
|
+
from chainlit.secret import random_secret
|
|
38
60
|
from chainlit.types import (
|
|
61
|
+
CallActionRequest,
|
|
62
|
+
DeleteFeedbackRequest,
|
|
39
63
|
DeleteThreadRequest,
|
|
40
|
-
|
|
64
|
+
ElementRequest,
|
|
41
65
|
GetThreadsRequest,
|
|
42
66
|
Theme,
|
|
43
67
|
UpdateFeedbackRequest,
|
|
68
|
+
UpdateThreadRequest,
|
|
44
69
|
)
|
|
45
70
|
from chainlit.user import PersistedUser, User
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
57
|
-
from fastapi.security import OAuth2PasswordRequestForm
|
|
58
|
-
from fastapi.staticfiles import StaticFiles
|
|
59
|
-
from fastapi_socketio import SocketManager
|
|
60
|
-
from starlette.datastructures import URL
|
|
61
|
-
from starlette.middleware.cors import CORSMiddleware
|
|
62
|
-
from typing_extensions import Annotated
|
|
63
|
-
from watchfiles import awatch
|
|
71
|
+
|
|
72
|
+
from ._utils import is_path_inside
|
|
73
|
+
|
|
74
|
+
mimetypes.add_type("application/javascript", ".js")
|
|
75
|
+
mimetypes.add_type("text/css", ".css")
|
|
76
|
+
|
|
77
|
+
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
78
|
+
IS_SUBMOUNT = os.environ.get("CHAINLIT_SUBMOUNT", "") == "true"
|
|
79
|
+
# If the app is a submount, no need to set the prefix
|
|
80
|
+
PREFIX = ROOT_PATH if ROOT_PATH and not IS_SUBMOUNT else ""
|
|
64
81
|
|
|
65
82
|
|
|
66
83
|
@asynccontextmanager
|
|
67
84
|
async def lifespan(app: FastAPI):
|
|
85
|
+
"""Context manager to handle app start and shutdown."""
|
|
68
86
|
host = config.run.host
|
|
69
87
|
port = config.run.port
|
|
70
88
|
|
|
71
89
|
if host == DEFAULT_HOST:
|
|
72
|
-
url = f"http://localhost:{port}"
|
|
90
|
+
url = f"http://localhost:{port}{ROOT_PATH}"
|
|
73
91
|
else:
|
|
74
|
-
url = f"http://{host}:{port}"
|
|
92
|
+
url = f"http://{host}:{port}{ROOT_PATH}"
|
|
75
93
|
|
|
76
94
|
logger.info(f"Your app is available at {url}")
|
|
77
95
|
|
|
@@ -112,22 +130,33 @@ async def lifespan(app: FastAPI):
|
|
|
112
130
|
logger.error(f"Error reloading module: {e}")
|
|
113
131
|
|
|
114
132
|
await asyncio.sleep(1)
|
|
115
|
-
await
|
|
133
|
+
await sio.emit("reload", {})
|
|
116
134
|
|
|
117
135
|
break
|
|
118
136
|
|
|
119
137
|
watch_task = asyncio.create_task(watch_files_for_changes())
|
|
120
138
|
|
|
139
|
+
discord_task = None
|
|
140
|
+
|
|
141
|
+
if discord_bot_token := os.environ.get("DISCORD_BOT_TOKEN"):
|
|
142
|
+
from chainlit.discord.app import client
|
|
143
|
+
|
|
144
|
+
discord_task = asyncio.create_task(client.start(discord_bot_token))
|
|
145
|
+
|
|
121
146
|
try:
|
|
122
147
|
yield
|
|
123
148
|
finally:
|
|
124
|
-
|
|
125
|
-
|
|
149
|
+
try:
|
|
150
|
+
if watch_task:
|
|
126
151
|
stop_event.set()
|
|
127
152
|
watch_task.cancel()
|
|
128
153
|
await watch_task
|
|
129
|
-
|
|
130
|
-
|
|
154
|
+
|
|
155
|
+
if discord_task:
|
|
156
|
+
discord_task.cancel()
|
|
157
|
+
await discord_task
|
|
158
|
+
except asyncio.exceptions.CancelledError:
|
|
159
|
+
pass
|
|
131
160
|
|
|
132
161
|
if FILES_DIRECTORY.is_dir():
|
|
133
162
|
shutil.rmtree(FILES_DIRECTORY)
|
|
@@ -136,10 +165,26 @@ async def lifespan(app: FastAPI):
|
|
|
136
165
|
os._exit(0)
|
|
137
166
|
|
|
138
167
|
|
|
139
|
-
def get_build_dir(local_target: str, packaged_target: str):
|
|
168
|
+
def get_build_dir(local_target: str, packaged_target: str) -> str:
|
|
169
|
+
"""
|
|
170
|
+
Get the build directory based on the UI build strategy.
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
local_target (str): The local target directory.
|
|
174
|
+
packaged_target (str): The packaged target directory.
|
|
175
|
+
|
|
176
|
+
Returns:
|
|
177
|
+
str: The build directory
|
|
178
|
+
"""
|
|
179
|
+
|
|
140
180
|
local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
|
|
141
181
|
packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")
|
|
142
|
-
|
|
182
|
+
|
|
183
|
+
if config.ui.custom_build and os.path.exists(
|
|
184
|
+
os.path.join(APP_ROOT, config.ui.custom_build)
|
|
185
|
+
):
|
|
186
|
+
return os.path.join(APP_ROOT, config.ui.custom_build)
|
|
187
|
+
elif os.path.exists(local_build_dir):
|
|
143
188
|
return local_build_dir
|
|
144
189
|
elif os.path.exists(packaged_build_dir):
|
|
145
190
|
return packaged_build_dir
|
|
@@ -150,28 +195,16 @@ def get_build_dir(local_target: str, packaged_target: str):
|
|
|
150
195
|
build_dir = get_build_dir("frontend", "frontend")
|
|
151
196
|
copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
|
|
152
197
|
|
|
153
|
-
|
|
154
198
|
app = FastAPI(lifespan=lifespan)
|
|
155
199
|
|
|
156
|
-
|
|
157
|
-
app.mount(
|
|
158
|
-
"/assets",
|
|
159
|
-
StaticFiles(
|
|
160
|
-
packages=[("chainlit", os.path.join(build_dir, "assets"))],
|
|
161
|
-
follow_symlink=config.project.follow_symlink,
|
|
162
|
-
),
|
|
163
|
-
name="assets",
|
|
164
|
-
)
|
|
200
|
+
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
|
|
165
201
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
packages=[("chainlit", copilot_build_dir)],
|
|
170
|
-
follow_symlink=config.project.follow_symlink,
|
|
171
|
-
),
|
|
172
|
-
name="copilot",
|
|
202
|
+
asgi_app = socketio.ASGIApp(
|
|
203
|
+
socketio_server=sio,
|
|
204
|
+
socketio_path="",
|
|
173
205
|
)
|
|
174
206
|
|
|
207
|
+
app.mount(f"{PREFIX}/ws/socket.io", asgi_app)
|
|
175
208
|
|
|
176
209
|
app.add_middleware(
|
|
177
210
|
CORSMiddleware,
|
|
@@ -181,11 +214,91 @@ app.add_middleware(
|
|
|
181
214
|
allow_headers=["*"],
|
|
182
215
|
)
|
|
183
216
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
217
|
+
router = APIRouter(prefix=PREFIX)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@router.get("/public/{filename:path}")
|
|
221
|
+
async def serve_public_file(
|
|
222
|
+
filename: str,
|
|
223
|
+
):
|
|
224
|
+
"""Serve a file from public dir."""
|
|
225
|
+
|
|
226
|
+
base_path = Path(public_dir)
|
|
227
|
+
file_path = (base_path / filename).resolve()
|
|
228
|
+
|
|
229
|
+
if not is_path_inside(file_path, base_path):
|
|
230
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
231
|
+
|
|
232
|
+
if file_path.is_file():
|
|
233
|
+
return FileResponse(file_path)
|
|
234
|
+
else:
|
|
235
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@router.get("/assets/{filename:path}")
|
|
239
|
+
async def serve_asset_file(
|
|
240
|
+
filename: str,
|
|
241
|
+
):
|
|
242
|
+
"""Serve a file from assets dir."""
|
|
243
|
+
|
|
244
|
+
base_path = Path(os.path.join(build_dir, "assets"))
|
|
245
|
+
file_path = (base_path / filename).resolve()
|
|
246
|
+
|
|
247
|
+
if not is_path_inside(file_path, base_path):
|
|
248
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
249
|
+
|
|
250
|
+
if file_path.is_file():
|
|
251
|
+
return FileResponse(file_path)
|
|
252
|
+
else:
|
|
253
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@router.get("/copilot/{filename:path}")
|
|
257
|
+
async def serve_copilot_file(
|
|
258
|
+
filename: str,
|
|
259
|
+
):
|
|
260
|
+
"""Serve a file from assets dir."""
|
|
261
|
+
|
|
262
|
+
base_path = Path(copilot_build_dir)
|
|
263
|
+
file_path = (base_path / filename).resolve()
|
|
264
|
+
|
|
265
|
+
if not is_path_inside(file_path, base_path):
|
|
266
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
267
|
+
|
|
268
|
+
if file_path.is_file():
|
|
269
|
+
return FileResponse(file_path)
|
|
270
|
+
else:
|
|
271
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# -------------------------------------------------------------------------------
|
|
275
|
+
# SLACK HANDLER
|
|
276
|
+
# -------------------------------------------------------------------------------
|
|
277
|
+
|
|
278
|
+
if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
|
|
279
|
+
from chainlit.slack.app import slack_app_handler
|
|
280
|
+
|
|
281
|
+
@router.post("/slack/events")
|
|
282
|
+
async def slack_endpoint(req: Request):
|
|
283
|
+
return await slack_app_handler.handle(req)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# -------------------------------------------------------------------------------
|
|
287
|
+
# TEAMS HANDLER
|
|
288
|
+
# -------------------------------------------------------------------------------
|
|
289
|
+
|
|
290
|
+
if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
|
|
291
|
+
from botbuilder.schema import Activity
|
|
292
|
+
|
|
293
|
+
from chainlit.teams.app import adapter, bot
|
|
294
|
+
|
|
295
|
+
@router.post("/teams/events")
|
|
296
|
+
async def teams_endpoint(req: Request):
|
|
297
|
+
body = await req.json()
|
|
298
|
+
activity = Activity().deserialize(body)
|
|
299
|
+
auth_header = req.headers.get("Authorization", "")
|
|
300
|
+
response = await adapter.process_activity(activity, auth_header, bot.on_turn)
|
|
301
|
+
return response
|
|
189
302
|
|
|
190
303
|
|
|
191
304
|
# -------------------------------------------------------------------------------
|
|
@@ -193,28 +306,55 @@ socket = SocketManager(
|
|
|
193
306
|
# -------------------------------------------------------------------------------
|
|
194
307
|
|
|
195
308
|
|
|
196
|
-
def replace_between_tags(
|
|
309
|
+
def replace_between_tags(
|
|
310
|
+
text: str, start_tag: str, end_tag: str, replacement: str
|
|
311
|
+
) -> str:
|
|
312
|
+
"""Replace text between two tags in a string."""
|
|
313
|
+
|
|
197
314
|
pattern = start_tag + ".*?" + end_tag
|
|
198
315
|
return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)
|
|
199
316
|
|
|
200
317
|
|
|
201
318
|
def get_html_template():
|
|
319
|
+
"""
|
|
320
|
+
Get HTML template for the index view.
|
|
321
|
+
"""
|
|
322
|
+
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
323
|
+
|
|
324
|
+
custom_theme = None
|
|
325
|
+
custom_theme_file_path = Path(public_dir) / "theme.json"
|
|
326
|
+
if (
|
|
327
|
+
is_path_inside(custom_theme_file_path, Path(public_dir))
|
|
328
|
+
and custom_theme_file_path.is_file()
|
|
329
|
+
):
|
|
330
|
+
custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
|
|
331
|
+
|
|
202
332
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
203
333
|
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
|
|
204
334
|
CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
|
|
205
335
|
|
|
206
336
|
default_url = "https://github.com/Chainlit/chainlit"
|
|
337
|
+
default_meta_image_url = (
|
|
338
|
+
"https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"
|
|
339
|
+
)
|
|
207
340
|
url = config.ui.github or default_url
|
|
341
|
+
meta_image_url = config.ui.custom_meta_image_url or default_meta_image_url
|
|
342
|
+
favicon_path = "/favicon"
|
|
208
343
|
|
|
209
344
|
tags = f"""<title>{config.ui.name}</title>
|
|
345
|
+
<link rel="icon" href="{favicon_path}" />
|
|
210
346
|
<meta name="description" content="{config.ui.description}">
|
|
211
347
|
<meta property="og:type" content="website">
|
|
212
348
|
<meta property="og:title" content="{config.ui.name}">
|
|
213
349
|
<meta property="og:description" content="{config.ui.description}">
|
|
214
|
-
<meta property="og:image" content="
|
|
215
|
-
<meta property="og:url" content="{url}">
|
|
350
|
+
<meta property="og:image" content="{meta_image_url}">
|
|
351
|
+
<meta property="og:url" content="{url}">
|
|
352
|
+
<meta property="og:root_path" content="{ROOT_PATH}">"""
|
|
216
353
|
|
|
217
|
-
js = f"""<script>
|
|
354
|
+
js = f"""<script>
|
|
355
|
+
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
|
|
356
|
+
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
|
|
357
|
+
</script>"""
|
|
218
358
|
|
|
219
359
|
css = None
|
|
220
360
|
if config.ui.custom_css:
|
|
@@ -226,12 +366,15 @@ def get_html_template():
|
|
|
226
366
|
js += f"""<script src="{config.ui.custom_js}" defer></script>"""
|
|
227
367
|
|
|
228
368
|
font = None
|
|
229
|
-
if
|
|
230
|
-
font =
|
|
369
|
+
if custom_theme and custom_theme.get("custom_fonts"):
|
|
370
|
+
font = "\n".join(
|
|
371
|
+
f"""<link rel="stylesheet" href="{font}">"""
|
|
372
|
+
for font in custom_theme.get("custom_fonts")
|
|
373
|
+
)
|
|
231
374
|
|
|
232
375
|
index_html_file_path = os.path.join(build_dir, "index.html")
|
|
233
376
|
|
|
234
|
-
with open(index_html_file_path,
|
|
377
|
+
with open(index_html_file_path, encoding="utf-8") as f:
|
|
235
378
|
content = f.read()
|
|
236
379
|
content = content.replace(PLACEHOLDER, tags)
|
|
237
380
|
if js:
|
|
@@ -242,6 +385,9 @@ def get_html_template():
|
|
|
242
385
|
content = replace_between_tags(
|
|
243
386
|
content, "<!-- FONT START -->", "<!-- FONT END -->", font
|
|
244
387
|
)
|
|
388
|
+
if ROOT_PATH:
|
|
389
|
+
content = content.replace('href="/', f'href="{ROOT_PATH}/')
|
|
390
|
+
content = content.replace('src="/', f'src="{ROOT_PATH}/')
|
|
245
391
|
return content
|
|
246
392
|
|
|
247
393
|
|
|
@@ -250,7 +396,6 @@ def get_user_facing_url(url: URL):
|
|
|
250
396
|
Return the user facing URL for a given URL.
|
|
251
397
|
Handles deployment with proxies (like cloud run).
|
|
252
398
|
"""
|
|
253
|
-
|
|
254
399
|
chainlit_url = os.environ.get("CHAINLIT_URL")
|
|
255
400
|
|
|
256
401
|
# No config, we keep the URL as is
|
|
@@ -269,49 +414,140 @@ def get_user_facing_url(url: URL):
|
|
|
269
414
|
return config_url.__str__() + url.path
|
|
270
415
|
|
|
271
416
|
|
|
272
|
-
@
|
|
417
|
+
@router.get("/auth/config")
|
|
273
418
|
async def auth(request: Request):
|
|
274
419
|
return get_configuration()
|
|
275
420
|
|
|
276
421
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
422
|
+
def _get_response_dict(access_token: str) -> dict:
|
|
423
|
+
"""Get the response dictionary for the auth response."""
|
|
424
|
+
|
|
425
|
+
return {"success": True}
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
|
|
429
|
+
"""Get the redirect params for the OAuth callback."""
|
|
430
|
+
|
|
431
|
+
response_dict = _get_response_dict(access_token)
|
|
432
|
+
|
|
433
|
+
if redirect_to_callback:
|
|
434
|
+
root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
435
|
+
redirect_url = (
|
|
436
|
+
f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
|
|
282
437
|
)
|
|
283
438
|
|
|
284
|
-
|
|
285
|
-
|
|
439
|
+
return RedirectResponse(
|
|
440
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
441
|
+
url=redirect_url,
|
|
442
|
+
status_code=302,
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
return JSONResponse(response_dict)
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def _get_oauth_redirect_error(error: str) -> Response:
|
|
449
|
+
"""Get the redirect response for an OAuth error."""
|
|
450
|
+
params = urllib.parse.urlencode(
|
|
451
|
+
{
|
|
452
|
+
"error": error,
|
|
453
|
+
}
|
|
286
454
|
)
|
|
455
|
+
response = RedirectResponse(
|
|
456
|
+
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
457
|
+
url=f"/login?{params}", # Shouldn't there be {root_path} here?
|
|
458
|
+
)
|
|
459
|
+
return response
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
async def _authenticate_user(
|
|
463
|
+
user: Optional[User], redirect_to_callback: bool = False
|
|
464
|
+
) -> Response:
|
|
465
|
+
"""Authenticate a user and return the response."""
|
|
287
466
|
|
|
288
467
|
if not user:
|
|
289
468
|
raise HTTPException(
|
|
290
469
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
291
470
|
detail="credentialssignin",
|
|
292
471
|
)
|
|
293
|
-
|
|
472
|
+
|
|
473
|
+
# If a data layer is defined, attempt to persist user.
|
|
294
474
|
if data_layer := get_data_layer():
|
|
295
475
|
try:
|
|
296
476
|
await data_layer.create_user(user)
|
|
297
477
|
except Exception as e:
|
|
478
|
+
# Catch and log exceptions during user creation.
|
|
479
|
+
# TODO: Make this catch only specific errors and allow others to propagate.
|
|
298
480
|
logger.error(f"Error creating user: {e}")
|
|
299
481
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
482
|
+
access_token = create_jwt(user)
|
|
483
|
+
|
|
484
|
+
response = _get_auth_response(access_token, redirect_to_callback)
|
|
485
|
+
|
|
486
|
+
set_auth_cookie(response, access_token)
|
|
487
|
+
|
|
488
|
+
return response
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
@router.post("/login")
|
|
492
|
+
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
|
|
493
|
+
"""
|
|
494
|
+
Login a user using the password auth callback.
|
|
495
|
+
"""
|
|
496
|
+
if not config.code.password_auth_callback:
|
|
497
|
+
raise HTTPException(
|
|
498
|
+
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
user = await config.code.password_auth_callback(
|
|
502
|
+
form_data.username, form_data.password
|
|
503
|
+
)
|
|
504
|
+
|
|
505
|
+
return await _authenticate_user(user)
|
|
304
506
|
|
|
305
507
|
|
|
306
|
-
@
|
|
508
|
+
@router.post("/logout")
|
|
307
509
|
async def logout(request: Request, response: Response):
|
|
510
|
+
"""Logout the user by calling the on_logout callback."""
|
|
511
|
+
clear_auth_cookie(response)
|
|
512
|
+
|
|
308
513
|
if config.code.on_logout:
|
|
309
514
|
return await config.code.on_logout(request, response)
|
|
515
|
+
|
|
310
516
|
return {"success": True}
|
|
311
517
|
|
|
312
518
|
|
|
313
|
-
@
|
|
519
|
+
@router.post("/auth/jwt")
|
|
520
|
+
async def jwt_auth(request: Request):
|
|
521
|
+
"""Login a user using a valid jwt."""
|
|
522
|
+
from jwt import InvalidTokenError
|
|
523
|
+
|
|
524
|
+
auth_header: Optional[str] = request.headers.get("Authorization")
|
|
525
|
+
if not auth_header:
|
|
526
|
+
raise HTTPException(status_code=401, detail="Authorization header missing")
|
|
527
|
+
|
|
528
|
+
# Check if it starts with "Bearer "
|
|
529
|
+
try:
|
|
530
|
+
scheme, token = auth_header.split()
|
|
531
|
+
if scheme.lower() != "bearer":
|
|
532
|
+
raise HTTPException(
|
|
533
|
+
status_code=401,
|
|
534
|
+
detail="Invalid authentication scheme. Please use Bearer",
|
|
535
|
+
)
|
|
536
|
+
except ValueError:
|
|
537
|
+
raise HTTPException(
|
|
538
|
+
status_code=401, detail="Invalid authorization header format"
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
try:
|
|
542
|
+
user = decode_jwt(token)
|
|
543
|
+
return await _authenticate_user(user)
|
|
544
|
+
except InvalidTokenError:
|
|
545
|
+
raise HTTPException(status_code=401, detail="Invalid token")
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
@router.post("/auth/header")
|
|
314
549
|
async def header_auth(request: Request):
|
|
550
|
+
"""Login a user using the header_auth_callback."""
|
|
315
551
|
if not config.code.header_auth_callback:
|
|
316
552
|
raise HTTPException(
|
|
317
553
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -320,27 +556,12 @@ async def header_auth(request: Request):
|
|
|
320
556
|
|
|
321
557
|
user = await config.code.header_auth_callback(request.headers)
|
|
322
558
|
|
|
323
|
-
|
|
324
|
-
raise HTTPException(
|
|
325
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
326
|
-
detail="Unauthorized",
|
|
327
|
-
)
|
|
328
|
-
|
|
329
|
-
access_token = create_jwt(user)
|
|
330
|
-
if data_layer := get_data_layer():
|
|
331
|
-
try:
|
|
332
|
-
await data_layer.create_user(user)
|
|
333
|
-
except Exception as e:
|
|
334
|
-
logger.error(f"Error creating user: {e}")
|
|
335
|
-
|
|
336
|
-
return {
|
|
337
|
-
"access_token": access_token,
|
|
338
|
-
"token_type": "bearer",
|
|
339
|
-
}
|
|
559
|
+
return await _authenticate_user(user)
|
|
340
560
|
|
|
341
561
|
|
|
342
|
-
@
|
|
562
|
+
@router.get("/auth/oauth/{provider_id}")
|
|
343
563
|
async def oauth_login(provider_id: str, request: Request):
|
|
564
|
+
"""Redirect the user to the oauth provider login page."""
|
|
344
565
|
if config.code.oauth_callback is None:
|
|
345
566
|
raise HTTPException(
|
|
346
567
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -367,20 +588,13 @@ async def oauth_login(provider_id: str, request: Request):
|
|
|
367
588
|
response = RedirectResponse(
|
|
368
589
|
url=f"{provider.authorize_url}?{params}",
|
|
369
590
|
)
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
"oauth_state",
|
|
374
|
-
random,
|
|
375
|
-
httponly=True,
|
|
376
|
-
samesite=samesite,
|
|
377
|
-
secure=secure,
|
|
378
|
-
max_age=3 * 60,
|
|
379
|
-
)
|
|
591
|
+
|
|
592
|
+
set_oauth_state_cookie(response, random)
|
|
593
|
+
|
|
380
594
|
return response
|
|
381
595
|
|
|
382
596
|
|
|
383
|
-
@
|
|
597
|
+
@router.get("/auth/oauth/{provider_id}/callback")
|
|
384
598
|
async def oauth_callback(
|
|
385
599
|
provider_id: str,
|
|
386
600
|
request: Request,
|
|
@@ -388,6 +602,8 @@ async def oauth_callback(
|
|
|
388
602
|
code: Optional[str] = None,
|
|
389
603
|
state: Optional[str] = None,
|
|
390
604
|
):
|
|
605
|
+
"""Handle the oauth callback and login the user."""
|
|
606
|
+
|
|
391
607
|
if config.code.oauth_callback is None:
|
|
392
608
|
raise HTTPException(
|
|
393
609
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -402,16 +618,7 @@ async def oauth_callback(
|
|
|
402
618
|
)
|
|
403
619
|
|
|
404
620
|
if error:
|
|
405
|
-
|
|
406
|
-
{
|
|
407
|
-
"error": error,
|
|
408
|
-
}
|
|
409
|
-
)
|
|
410
|
-
response = RedirectResponse(
|
|
411
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
412
|
-
url=f"/login?{params}",
|
|
413
|
-
)
|
|
414
|
-
return response
|
|
621
|
+
return _get_oauth_redirect_error(error)
|
|
415
622
|
|
|
416
623
|
if not code or not state:
|
|
417
624
|
raise HTTPException(
|
|
@@ -419,9 +626,11 @@ async def oauth_callback(
|
|
|
419
626
|
detail="Missing code or state",
|
|
420
627
|
)
|
|
421
628
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
629
|
+
try:
|
|
630
|
+
validate_oauth_state_cookie(request, state)
|
|
631
|
+
except Exception as e:
|
|
632
|
+
logger.exception("Unable to validate oauth state: %1", e)
|
|
633
|
+
|
|
425
634
|
raise HTTPException(
|
|
426
635
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
427
636
|
detail="Unauthorized",
|
|
@@ -436,83 +645,128 @@ async def oauth_callback(
|
|
|
436
645
|
provider_id, token, raw_user_data, default_user
|
|
437
646
|
)
|
|
438
647
|
|
|
439
|
-
|
|
440
|
-
raise HTTPException(
|
|
441
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
442
|
-
detail="Unauthorized",
|
|
443
|
-
)
|
|
648
|
+
response = await _authenticate_user(user, redirect_to_callback=True)
|
|
444
649
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
if data_layer := get_data_layer():
|
|
448
|
-
try:
|
|
449
|
-
await data_layer.create_user(user)
|
|
450
|
-
except Exception as e:
|
|
451
|
-
logger.error(f"Error creating user: {e}")
|
|
650
|
+
clear_oauth_state_cookie(response)
|
|
452
651
|
|
|
453
|
-
params = urllib.parse.urlencode(
|
|
454
|
-
{
|
|
455
|
-
"access_token": access_token,
|
|
456
|
-
"token_type": "bearer",
|
|
457
|
-
}
|
|
458
|
-
)
|
|
459
|
-
response = RedirectResponse(
|
|
460
|
-
# FIXME: redirect to the right frontend base url to improve the dev environment
|
|
461
|
-
url=f"/login/callback?{params}",
|
|
462
|
-
)
|
|
463
|
-
response.delete_cookie("oauth_state")
|
|
464
652
|
return response
|
|
465
653
|
|
|
466
654
|
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
655
|
+
# specific route for azure ad hybrid flow
|
|
656
|
+
@router.post("/auth/oauth/azure-ad-hybrid/callback")
|
|
657
|
+
async def oauth_azure_hf_callback(
|
|
658
|
+
request: Request,
|
|
659
|
+
error: Optional[str] = None,
|
|
660
|
+
code: Annotated[Optional[str], Form()] = None,
|
|
661
|
+
id_token: Annotated[Optional[str], Form()] = None,
|
|
471
662
|
):
|
|
472
|
-
"""Handle
|
|
663
|
+
"""Handle the azure ad hybrid flow callback and login the user."""
|
|
473
664
|
|
|
474
|
-
|
|
665
|
+
provider_id = "azure-ad-hybrid"
|
|
666
|
+
if config.code.oauth_callback is None:
|
|
667
|
+
raise HTTPException(
|
|
668
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
669
|
+
detail="No oauth_callback defined",
|
|
670
|
+
)
|
|
475
671
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
except IndexError:
|
|
672
|
+
provider = get_oauth_provider(provider_id)
|
|
673
|
+
if not provider:
|
|
479
674
|
raise HTTPException(
|
|
480
|
-
status_code=
|
|
481
|
-
detail=f"
|
|
675
|
+
status_code=status.HTTP_404_NOT_FOUND,
|
|
676
|
+
detail=f"Provider {provider_id} not found",
|
|
482
677
|
)
|
|
483
678
|
|
|
484
|
-
|
|
485
|
-
|
|
679
|
+
if error:
|
|
680
|
+
return _get_oauth_redirect_error(error)
|
|
681
|
+
|
|
682
|
+
if not code:
|
|
683
|
+
raise HTTPException(
|
|
684
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
685
|
+
detail="Missing code",
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
url = get_user_facing_url(request.url)
|
|
689
|
+
token = await provider.get_token(code, url)
|
|
690
|
+
|
|
691
|
+
(raw_user_data, default_user) = await provider.get_user_info(token)
|
|
692
|
+
|
|
693
|
+
user = await config.code.oauth_callback(
|
|
694
|
+
provider_id, token, raw_user_data, default_user, id_token
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
response = await _authenticate_user(user, redirect_to_callback=True)
|
|
698
|
+
|
|
699
|
+
clear_oauth_state_cookie(response)
|
|
486
700
|
|
|
487
701
|
return response
|
|
488
702
|
|
|
489
703
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
704
|
+
GenericUser = Union[User, PersistedUser, None]
|
|
705
|
+
UserParam = Annotated[GenericUser, Depends(get_current_user)]
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
@router.get("/user")
|
|
709
|
+
async def get_user(current_user: UserParam) -> GenericUser:
|
|
710
|
+
return current_user
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
_language_pattern = (
|
|
714
|
+
"^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,3})?(-[a-zA-Z0-9]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
|
|
715
|
+
)
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
@router.get("/project/translations")
|
|
719
|
+
async def project_translations(
|
|
720
|
+
language: str = Query(
|
|
721
|
+
default="en-US", description="Language code", pattern=_language_pattern
|
|
722
|
+
),
|
|
493
723
|
):
|
|
494
|
-
"""
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
724
|
+
"""Return project translations."""
|
|
725
|
+
|
|
726
|
+
# Load translation based on the provided language
|
|
727
|
+
translation = config.load_translation(language)
|
|
728
|
+
|
|
729
|
+
return JSONResponse(
|
|
730
|
+
content={
|
|
731
|
+
"translation": translation,
|
|
732
|
+
}
|
|
733
|
+
)
|
|
499
734
|
|
|
500
735
|
|
|
501
|
-
@
|
|
736
|
+
@router.get("/project/settings")
|
|
502
737
|
async def project_settings(
|
|
503
|
-
current_user:
|
|
504
|
-
language: str = Query(
|
|
738
|
+
current_user: UserParam,
|
|
739
|
+
language: str = Query(
|
|
740
|
+
default="en-US", description="Language code", pattern=_language_pattern
|
|
741
|
+
),
|
|
505
742
|
):
|
|
506
743
|
"""Return project settings. This is called by the UI before the establishing the websocket connection."""
|
|
507
744
|
|
|
508
|
-
# Load
|
|
509
|
-
|
|
745
|
+
# Load the markdown file based on the provided language
|
|
746
|
+
|
|
747
|
+
markdown = get_markdown_str(config.root, language)
|
|
510
748
|
|
|
511
749
|
profiles = []
|
|
512
750
|
if config.code.set_chat_profiles:
|
|
513
751
|
chat_profiles = await config.code.set_chat_profiles(current_user)
|
|
514
752
|
if chat_profiles:
|
|
515
753
|
profiles = [p.to_dict() for p in chat_profiles]
|
|
754
|
+
|
|
755
|
+
starters = []
|
|
756
|
+
if config.code.set_starters:
|
|
757
|
+
starters = await config.code.set_starters(current_user)
|
|
758
|
+
if starters:
|
|
759
|
+
starters = [s.to_dict() for s in starters]
|
|
760
|
+
|
|
761
|
+
if config.code.on_audio_chunk:
|
|
762
|
+
config.features.audio.enabled = True
|
|
763
|
+
|
|
764
|
+
debug_url = None
|
|
765
|
+
data_layer = get_data_layer()
|
|
766
|
+
|
|
767
|
+
if data_layer and config.run.debug:
|
|
768
|
+
debug_url = await data_layer.build_debug_url()
|
|
769
|
+
|
|
516
770
|
return JSONResponse(
|
|
517
771
|
content={
|
|
518
772
|
"ui": config.ui.to_dict(),
|
|
@@ -520,18 +774,19 @@ async def project_settings(
|
|
|
520
774
|
"userEnv": config.project.user_env,
|
|
521
775
|
"dataPersistence": get_data_layer() is not None,
|
|
522
776
|
"threadResumable": bool(config.code.on_chat_resume),
|
|
523
|
-
"markdown":
|
|
777
|
+
"markdown": markdown,
|
|
524
778
|
"chatProfiles": profiles,
|
|
525
|
-
"
|
|
779
|
+
"starters": starters,
|
|
780
|
+
"debugUrl": debug_url,
|
|
526
781
|
}
|
|
527
782
|
)
|
|
528
783
|
|
|
529
784
|
|
|
530
|
-
@
|
|
785
|
+
@router.put("/feedback")
|
|
531
786
|
async def update_feedback(
|
|
532
787
|
request: Request,
|
|
533
788
|
update: UpdateFeedbackRequest,
|
|
534
|
-
current_user:
|
|
789
|
+
current_user: UserParam,
|
|
535
790
|
):
|
|
536
791
|
"""Update the human feedback for a particular message."""
|
|
537
792
|
data_layer = get_data_layer()
|
|
@@ -541,36 +796,63 @@ async def update_feedback(
|
|
|
541
796
|
try:
|
|
542
797
|
feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
|
|
543
798
|
except Exception as e:
|
|
544
|
-
raise HTTPException(detail=str(e), status_code=500)
|
|
799
|
+
raise HTTPException(detail=str(e), status_code=500) from e
|
|
545
800
|
|
|
546
801
|
return JSONResponse(content={"success": True, "feedbackId": feedback_id})
|
|
547
802
|
|
|
548
803
|
|
|
549
|
-
@
|
|
804
|
+
@router.delete("/feedback")
|
|
805
|
+
async def delete_feedback(
|
|
806
|
+
request: Request,
|
|
807
|
+
payload: DeleteFeedbackRequest,
|
|
808
|
+
current_user: UserParam,
|
|
809
|
+
):
|
|
810
|
+
"""Delete a feedback."""
|
|
811
|
+
|
|
812
|
+
data_layer = get_data_layer()
|
|
813
|
+
|
|
814
|
+
if not data_layer:
|
|
815
|
+
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
816
|
+
|
|
817
|
+
feedback_id = payload.feedbackId
|
|
818
|
+
|
|
819
|
+
await data_layer.delete_feedback(feedback_id)
|
|
820
|
+
return JSONResponse(content={"success": True})
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
@router.post("/project/threads")
|
|
550
824
|
async def get_user_threads(
|
|
551
825
|
request: Request,
|
|
552
826
|
payload: GetThreadsRequest,
|
|
553
|
-
current_user:
|
|
827
|
+
current_user: UserParam,
|
|
554
828
|
):
|
|
555
829
|
"""Get the threads page by page."""
|
|
556
|
-
# Only show the current user threads
|
|
557
830
|
|
|
558
831
|
data_layer = get_data_layer()
|
|
559
832
|
|
|
560
833
|
if not data_layer:
|
|
561
834
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
562
835
|
|
|
563
|
-
|
|
836
|
+
if not current_user:
|
|
837
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
838
|
+
|
|
839
|
+
if not isinstance(current_user, PersistedUser):
|
|
840
|
+
persisted_user = await data_layer.get_user(identifier=current_user.identifier)
|
|
841
|
+
if not persisted_user:
|
|
842
|
+
raise HTTPException(status_code=404, detail="User not found")
|
|
843
|
+
payload.filter.userId = persisted_user.id
|
|
844
|
+
else:
|
|
845
|
+
payload.filter.userId = current_user.id
|
|
564
846
|
|
|
565
847
|
res = await data_layer.list_threads(payload.pagination, payload.filter)
|
|
566
848
|
return JSONResponse(content=res.to_dict())
|
|
567
849
|
|
|
568
850
|
|
|
569
|
-
@
|
|
851
|
+
@router.get("/project/thread/{thread_id}")
|
|
570
852
|
async def get_thread(
|
|
571
853
|
request: Request,
|
|
572
854
|
thread_id: str,
|
|
573
|
-
current_user:
|
|
855
|
+
current_user: UserParam,
|
|
574
856
|
):
|
|
575
857
|
"""Get a specific thread."""
|
|
576
858
|
data_layer = get_data_layer()
|
|
@@ -578,18 +860,21 @@ async def get_thread(
|
|
|
578
860
|
if not data_layer:
|
|
579
861
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
580
862
|
|
|
863
|
+
if not current_user:
|
|
864
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
865
|
+
|
|
581
866
|
await is_thread_author(current_user.identifier, thread_id)
|
|
582
867
|
|
|
583
868
|
res = await data_layer.get_thread(thread_id)
|
|
584
869
|
return JSONResponse(content=res)
|
|
585
870
|
|
|
586
871
|
|
|
587
|
-
@
|
|
872
|
+
@router.get("/project/thread/{thread_id}/element/{element_id}")
|
|
588
873
|
async def get_thread_element(
|
|
589
874
|
request: Request,
|
|
590
875
|
thread_id: str,
|
|
591
876
|
element_id: str,
|
|
592
|
-
current_user:
|
|
877
|
+
current_user: UserParam,
|
|
593
878
|
):
|
|
594
879
|
"""Get a specific thread element."""
|
|
595
880
|
data_layer = get_data_layer()
|
|
@@ -597,17 +882,135 @@ async def get_thread_element(
|
|
|
597
882
|
if not data_layer:
|
|
598
883
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
599
884
|
|
|
885
|
+
if not current_user:
|
|
886
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
887
|
+
|
|
600
888
|
await is_thread_author(current_user.identifier, thread_id)
|
|
601
889
|
|
|
602
890
|
res = await data_layer.get_element(thread_id, element_id)
|
|
603
891
|
return JSONResponse(content=res)
|
|
604
892
|
|
|
605
893
|
|
|
606
|
-
@
|
|
894
|
+
@router.put("/project/element")
|
|
895
|
+
async def update_thread_element(
|
|
896
|
+
payload: ElementRequest,
|
|
897
|
+
current_user: UserParam,
|
|
898
|
+
):
|
|
899
|
+
"""Update a specific thread element."""
|
|
900
|
+
|
|
901
|
+
from chainlit.context import init_ws_context
|
|
902
|
+
from chainlit.element import CustomElement, ElementDict
|
|
903
|
+
from chainlit.session import WebsocketSession
|
|
904
|
+
|
|
905
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
906
|
+
context = init_ws_context(session)
|
|
907
|
+
|
|
908
|
+
element_dict = cast(ElementDict, payload.element)
|
|
909
|
+
|
|
910
|
+
if element_dict["type"] != "custom":
|
|
911
|
+
return {"success": False}
|
|
912
|
+
|
|
913
|
+
element = CustomElement(
|
|
914
|
+
id=element_dict["id"],
|
|
915
|
+
object_key=element_dict["objectKey"],
|
|
916
|
+
chainlit_key=element_dict["chainlitKey"],
|
|
917
|
+
url=element_dict["url"],
|
|
918
|
+
for_id=element_dict.get("forId") or "",
|
|
919
|
+
thread_id=element_dict.get("threadId") or "",
|
|
920
|
+
name=element_dict["name"],
|
|
921
|
+
props=element_dict.get("props") or {},
|
|
922
|
+
display=element_dict["display"],
|
|
923
|
+
)
|
|
924
|
+
|
|
925
|
+
if current_user:
|
|
926
|
+
if (
|
|
927
|
+
not context.session.user
|
|
928
|
+
or context.session.user.identifier != current_user.identifier
|
|
929
|
+
):
|
|
930
|
+
raise HTTPException(
|
|
931
|
+
status_code=401,
|
|
932
|
+
detail="You are not authorized to update elements for this session",
|
|
933
|
+
)
|
|
934
|
+
|
|
935
|
+
await element.update()
|
|
936
|
+
return {"success": True}
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
@router.delete("/project/element")
|
|
940
|
+
async def delete_thread_element(
|
|
941
|
+
payload: ElementRequest,
|
|
942
|
+
current_user: UserParam,
|
|
943
|
+
):
|
|
944
|
+
"""Delete a specific thread element."""
|
|
945
|
+
|
|
946
|
+
from chainlit.context import init_ws_context
|
|
947
|
+
from chainlit.element import CustomElement, ElementDict
|
|
948
|
+
from chainlit.session import WebsocketSession
|
|
949
|
+
|
|
950
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
951
|
+
context = init_ws_context(session)
|
|
952
|
+
|
|
953
|
+
element_dict = cast(ElementDict, payload.element)
|
|
954
|
+
|
|
955
|
+
if element_dict["type"] != "custom":
|
|
956
|
+
return {"success": False}
|
|
957
|
+
|
|
958
|
+
element = CustomElement(
|
|
959
|
+
id=element_dict["id"],
|
|
960
|
+
object_key=element_dict["objectKey"],
|
|
961
|
+
chainlit_key=element_dict["chainlitKey"],
|
|
962
|
+
url=element_dict["url"],
|
|
963
|
+
for_id=element_dict.get("forId") or "",
|
|
964
|
+
thread_id=element_dict.get("threadId") or "",
|
|
965
|
+
name=element_dict["name"],
|
|
966
|
+
props=element_dict.get("props") or {},
|
|
967
|
+
display=element_dict["display"],
|
|
968
|
+
)
|
|
969
|
+
|
|
970
|
+
if current_user:
|
|
971
|
+
if (
|
|
972
|
+
not context.session.user
|
|
973
|
+
or context.session.user.identifier != current_user.identifier
|
|
974
|
+
):
|
|
975
|
+
raise HTTPException(
|
|
976
|
+
status_code=401,
|
|
977
|
+
detail="You are not authorized to remove elements for this session",
|
|
978
|
+
)
|
|
979
|
+
|
|
980
|
+
await element.remove()
|
|
981
|
+
|
|
982
|
+
return {"success": True}
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
@router.put("/project/thread")
|
|
986
|
+
async def rename_thread(
|
|
987
|
+
request: Request,
|
|
988
|
+
payload: UpdateThreadRequest,
|
|
989
|
+
current_user: UserParam,
|
|
990
|
+
):
|
|
991
|
+
"""Rename a thread."""
|
|
992
|
+
|
|
993
|
+
data_layer = get_data_layer()
|
|
994
|
+
|
|
995
|
+
if not data_layer:
|
|
996
|
+
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
997
|
+
|
|
998
|
+
if not current_user:
|
|
999
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1000
|
+
|
|
1001
|
+
thread_id = payload.threadId
|
|
1002
|
+
|
|
1003
|
+
await is_thread_author(current_user.identifier, thread_id)
|
|
1004
|
+
|
|
1005
|
+
await data_layer.update_thread(thread_id, name=payload.name)
|
|
1006
|
+
return JSONResponse(content={"success": True})
|
|
1007
|
+
|
|
1008
|
+
|
|
1009
|
+
@router.delete("/project/thread")
|
|
607
1010
|
async def delete_thread(
|
|
608
1011
|
request: Request,
|
|
609
1012
|
payload: DeleteThreadRequest,
|
|
610
|
-
current_user:
|
|
1013
|
+
current_user: UserParam,
|
|
611
1014
|
):
|
|
612
1015
|
"""Delete a thread."""
|
|
613
1016
|
|
|
@@ -616,6 +1019,9 @@ async def delete_thread(
|
|
|
616
1019
|
if not data_layer:
|
|
617
1020
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
618
1021
|
|
|
1022
|
+
if not current_user:
|
|
1023
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1024
|
+
|
|
619
1025
|
thread_id = payload.threadId
|
|
620
1026
|
|
|
621
1027
|
await is_thread_author(current_user.identifier, thread_id)
|
|
@@ -624,14 +1030,56 @@ async def delete_thread(
|
|
|
624
1030
|
return JSONResponse(content={"success": True})
|
|
625
1031
|
|
|
626
1032
|
|
|
627
|
-
@
|
|
1033
|
+
@router.post("/project/action")
|
|
1034
|
+
async def call_action(
|
|
1035
|
+
payload: CallActionRequest,
|
|
1036
|
+
current_user: UserParam,
|
|
1037
|
+
):
|
|
1038
|
+
"""Run an action."""
|
|
1039
|
+
|
|
1040
|
+
from chainlit.action import Action
|
|
1041
|
+
from chainlit.context import init_ws_context
|
|
1042
|
+
from chainlit.session import WebsocketSession
|
|
1043
|
+
|
|
1044
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
1045
|
+
context = init_ws_context(session)
|
|
1046
|
+
|
|
1047
|
+
action = Action(**payload.action)
|
|
1048
|
+
|
|
1049
|
+
if current_user:
|
|
1050
|
+
if (
|
|
1051
|
+
not context.session.user
|
|
1052
|
+
or context.session.user.identifier != current_user.identifier
|
|
1053
|
+
):
|
|
1054
|
+
raise HTTPException(
|
|
1055
|
+
status_code=401,
|
|
1056
|
+
detail="You are not authorized to upload files for this session",
|
|
1057
|
+
)
|
|
1058
|
+
|
|
1059
|
+
callback = config.code.action_callbacks.get(action.name)
|
|
1060
|
+
if callback:
|
|
1061
|
+
if not context.session.has_first_interaction:
|
|
1062
|
+
context.session.has_first_interaction = True
|
|
1063
|
+
asyncio.create_task(context.emitter.init_thread(action.name))
|
|
1064
|
+
|
|
1065
|
+
await callback(action)
|
|
1066
|
+
else:
|
|
1067
|
+
raise HTTPException(
|
|
1068
|
+
status_code=404,
|
|
1069
|
+
detail=f"No callback found for action {action.name}",
|
|
1070
|
+
)
|
|
1071
|
+
|
|
1072
|
+
return JSONResponse(content={"success": True})
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
@router.post("/project/file")
|
|
628
1076
|
async def upload_file(
|
|
1077
|
+
current_user: UserParam,
|
|
629
1078
|
session_id: str,
|
|
630
1079
|
file: UploadFile,
|
|
631
|
-
current_user: Annotated[
|
|
632
|
-
Union[None, User, PersistedUser], Depends(get_current_user)
|
|
633
|
-
],
|
|
634
1080
|
):
|
|
1081
|
+
"""Upload a file to the session files directory."""
|
|
1082
|
+
|
|
635
1083
|
from chainlit.session import WebsocketSession
|
|
636
1084
|
|
|
637
1085
|
session = WebsocketSession.get_by_id(session_id)
|
|
@@ -653,28 +1101,122 @@ async def upload_file(
|
|
|
653
1101
|
|
|
654
1102
|
content = await file.read()
|
|
655
1103
|
|
|
1104
|
+
assert file.filename, "No filename for uploaded file"
|
|
1105
|
+
assert file.content_type, "No content type for uploaded file"
|
|
1106
|
+
|
|
1107
|
+
try:
|
|
1108
|
+
validate_file_upload(file)
|
|
1109
|
+
except ValueError as e:
|
|
1110
|
+
raise HTTPException(status_code=400, detail=str(e))
|
|
1111
|
+
|
|
656
1112
|
file_response = await session.persist_file(
|
|
657
1113
|
name=file.filename, content=content, mime=file.content_type
|
|
658
1114
|
)
|
|
659
1115
|
|
|
660
|
-
return JSONResponse(file_response)
|
|
1116
|
+
return JSONResponse(content=file_response)
|
|
1117
|
+
|
|
1118
|
+
|
|
1119
|
+
def validate_file_upload(file: UploadFile):
|
|
1120
|
+
"""Validate the file upload as configured in config.features.spontaneous_file_upload.
|
|
1121
|
+
Args:
|
|
1122
|
+
file (UploadFile): The file to validate.
|
|
1123
|
+
Raises:
|
|
1124
|
+
ValueError: If the file is not allowed.
|
|
1125
|
+
"""
|
|
1126
|
+
# TODO: This logic/endpoint is shared across spontaneous uploads and the AskFileMessage API.
|
|
1127
|
+
# Commenting this check until we find a better solution
|
|
1128
|
+
|
|
1129
|
+
# if config.features.spontaneous_file_upload is None:
|
|
1130
|
+
# """Default for a missing config is to allow the fileupload without any restrictions"""
|
|
1131
|
+
# return
|
|
1132
|
+
# if not config.features.spontaneous_file_upload.enabled:
|
|
1133
|
+
# raise ValueError("File upload is not enabled")
|
|
1134
|
+
|
|
1135
|
+
validate_file_mime_type(file)
|
|
1136
|
+
validate_file_size(file)
|
|
1137
|
+
|
|
661
1138
|
|
|
1139
|
+
def validate_file_mime_type(file: UploadFile):
|
|
1140
|
+
"""Validate the file mime type as configured in config.features.spontaneous_file_upload.
|
|
1141
|
+
Args:
|
|
1142
|
+
file (UploadFile): The file to validate.
|
|
1143
|
+
Raises:
|
|
1144
|
+
ValueError: If the file type is not allowed.
|
|
1145
|
+
"""
|
|
1146
|
+
|
|
1147
|
+
if (
|
|
1148
|
+
config.features.spontaneous_file_upload is None
|
|
1149
|
+
or config.features.spontaneous_file_upload.accept is None
|
|
1150
|
+
):
|
|
1151
|
+
"Accept is not configured, allowing all file types"
|
|
1152
|
+
return
|
|
1153
|
+
|
|
1154
|
+
accept = config.features.spontaneous_file_upload.accept
|
|
1155
|
+
|
|
1156
|
+
assert isinstance(accept, List) or isinstance(accept, dict), (
|
|
1157
|
+
"Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
|
|
1158
|
+
)
|
|
1159
|
+
|
|
1160
|
+
if isinstance(accept, List):
|
|
1161
|
+
for pattern in accept:
|
|
1162
|
+
if fnmatch.fnmatch(file.content_type, pattern):
|
|
1163
|
+
return
|
|
1164
|
+
elif isinstance(accept, dict):
|
|
1165
|
+
for pattern, extensions in accept.items():
|
|
1166
|
+
if fnmatch.fnmatch(file.content_type, pattern):
|
|
1167
|
+
if len(extensions) == 0:
|
|
1168
|
+
return
|
|
1169
|
+
for extension in extensions:
|
|
1170
|
+
if file.filename is not None and file.filename.endswith(extension):
|
|
1171
|
+
return
|
|
1172
|
+
raise ValueError("File type not allowed")
|
|
1173
|
+
|
|
1174
|
+
|
|
1175
|
+
def validate_file_size(file: UploadFile):
|
|
1176
|
+
"""Validate the file size as configured in config.features.spontaneous_file_upload.
|
|
1177
|
+
Args:
|
|
1178
|
+
file (UploadFile): The file to validate.
|
|
1179
|
+
Raises:
|
|
1180
|
+
ValueError: If the file size is too large.
|
|
1181
|
+
"""
|
|
1182
|
+
if (
|
|
1183
|
+
config.features.spontaneous_file_upload is None
|
|
1184
|
+
or config.features.spontaneous_file_upload.max_size_mb is None
|
|
1185
|
+
):
|
|
1186
|
+
return
|
|
662
1187
|
|
|
663
|
-
|
|
1188
|
+
if (
|
|
1189
|
+
file.size is not None
|
|
1190
|
+
and file.size
|
|
1191
|
+
> config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024
|
|
1192
|
+
):
|
|
1193
|
+
raise ValueError("File size too large")
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
@router.get("/project/file/{file_id}")
|
|
664
1197
|
async def get_file(
|
|
665
1198
|
file_id: str,
|
|
666
|
-
session_id:
|
|
1199
|
+
session_id: str,
|
|
1200
|
+
current_user: UserParam,
|
|
667
1201
|
):
|
|
1202
|
+
"""Get a file from the session files directory."""
|
|
668
1203
|
from chainlit.session import WebsocketSession
|
|
669
1204
|
|
|
670
1205
|
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
671
1206
|
|
|
672
1207
|
if not session:
|
|
673
1208
|
raise HTTPException(
|
|
674
|
-
status_code=
|
|
675
|
-
detail="
|
|
1209
|
+
status_code=401,
|
|
1210
|
+
detail="Unauthorized",
|
|
676
1211
|
)
|
|
677
1212
|
|
|
1213
|
+
if current_user:
|
|
1214
|
+
if not session.user or session.user.identifier != current_user.identifier:
|
|
1215
|
+
raise HTTPException(
|
|
1216
|
+
status_code=401,
|
|
1217
|
+
detail="You are not authorized to download files from this session",
|
|
1218
|
+
)
|
|
1219
|
+
|
|
678
1220
|
if file_id in session.files:
|
|
679
1221
|
file = session.files[file_id]
|
|
680
1222
|
return FileResponse(file["path"], media_type=file["type"])
|
|
@@ -682,26 +1224,9 @@ async def get_file(
|
|
|
682
1224
|
raise HTTPException(status_code=404, detail="File not found")
|
|
683
1225
|
|
|
684
1226
|
|
|
685
|
-
@
|
|
686
|
-
async def serve_file(
|
|
687
|
-
filename: str,
|
|
688
|
-
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
689
|
-
):
|
|
690
|
-
base_path = Path(config.project.local_fs_path).resolve()
|
|
691
|
-
file_path = (base_path / filename).resolve()
|
|
692
|
-
|
|
693
|
-
# Check if the base path is a parent of the file path
|
|
694
|
-
if base_path not in file_path.parents:
|
|
695
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
696
|
-
|
|
697
|
-
if file_path.is_file():
|
|
698
|
-
return FileResponse(file_path)
|
|
699
|
-
else:
|
|
700
|
-
raise HTTPException(status_code=404, detail="File not found")
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
@app.get("/favicon")
|
|
1227
|
+
@router.get("/favicon")
|
|
704
1228
|
async def get_favicon():
|
|
1229
|
+
"""Get the favicon for the UI."""
|
|
705
1230
|
custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
|
|
706
1231
|
files = glob.glob(custom_favicon_path)
|
|
707
1232
|
|
|
@@ -715,8 +1240,9 @@ async def get_favicon():
|
|
|
715
1240
|
return FileResponse(favicon_path, media_type=media_type)
|
|
716
1241
|
|
|
717
1242
|
|
|
718
|
-
@
|
|
1243
|
+
@router.get("/logo")
|
|
719
1244
|
async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
|
|
1245
|
+
"""Get the default logo for the UI."""
|
|
720
1246
|
theme_value = theme.value if theme else Theme.light.value
|
|
721
1247
|
logo_path = None
|
|
722
1248
|
|
|
@@ -732,19 +1258,54 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
|
|
|
732
1258
|
|
|
733
1259
|
if not logo_path:
|
|
734
1260
|
raise HTTPException(status_code=404, detail="Missing default logo")
|
|
1261
|
+
|
|
735
1262
|
media_type, _ = mimetypes.guess_type(logo_path)
|
|
736
1263
|
|
|
737
1264
|
return FileResponse(logo_path, media_type=media_type)
|
|
738
1265
|
|
|
739
1266
|
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
1267
|
+
@router.get("/avatars/{avatar_id:str}")
|
|
1268
|
+
async def get_avatar(avatar_id: str):
|
|
1269
|
+
"""Get the avatar for the user based on the avatar_id."""
|
|
1270
|
+
if not re.match(r"^[a-zA-Z0-9_ -]+$", avatar_id):
|
|
1271
|
+
raise HTTPException(status_code=400, detail="Invalid avatar_id")
|
|
1272
|
+
|
|
1273
|
+
if avatar_id == "default":
|
|
1274
|
+
avatar_id = config.ui.name
|
|
1275
|
+
|
|
1276
|
+
avatar_id = avatar_id.strip().lower().replace(" ", "_")
|
|
1277
|
+
|
|
1278
|
+
base_path = Path(APP_ROOT) / "public" / "avatars"
|
|
1279
|
+
avatar_pattern = f"{avatar_id}.*"
|
|
1280
|
+
|
|
1281
|
+
matching_files = base_path.glob(avatar_pattern)
|
|
1282
|
+
|
|
1283
|
+
if avatar_path := next(matching_files, None):
|
|
1284
|
+
if not is_path_inside(avatar_path, base_path):
|
|
1285
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
1286
|
+
|
|
1287
|
+
media_type, _ = mimetypes.guess_type(str(avatar_path))
|
|
1288
|
+
|
|
1289
|
+
return FileResponse(avatar_path, media_type=media_type)
|
|
1290
|
+
|
|
1291
|
+
return await get_favicon()
|
|
1292
|
+
|
|
1293
|
+
|
|
1294
|
+
@router.head("/")
|
|
1295
|
+
def status_check():
|
|
1296
|
+
"""Check if the site is operational."""
|
|
1297
|
+
return {"message": "Site is operational"}
|
|
1298
|
+
|
|
1299
|
+
|
|
1300
|
+
@router.get("/{full_path:path}")
|
|
1301
|
+
async def serve():
|
|
1302
|
+
html_template = get_html_template()
|
|
1303
|
+
"""Serve the UI files."""
|
|
1304
|
+
response = HTMLResponse(content=html_template, status_code=200)
|
|
1305
|
+
|
|
1306
|
+
return response
|
|
746
1307
|
|
|
747
|
-
return response
|
|
748
1308
|
|
|
1309
|
+
app.include_router(router)
|
|
749
1310
|
|
|
750
1311
|
import chainlit.socket # noqa
|