chainlit 2.0rc1__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +47 -57
- chainlit/action.py +8 -8
- chainlit/auth/__init__.py +1 -1
- chainlit/auth/cookie.py +7 -9
- chainlit/auth/jwt.py +5 -3
- chainlit/callbacks.py +1 -1
- chainlit/config.py +8 -59
- chainlit/copilot/dist/index.js +8319 -1019
- chainlit/data/__init__.py +71 -2
- chainlit/data/chainlit_data_layer.py +608 -0
- chainlit/data/literalai.py +1 -1
- chainlit/data/sql_alchemy.py +26 -2
- chainlit/data/storage_clients/azure_blob.py +89 -0
- chainlit/data/storage_clients/base.py +10 -0
- chainlit/data/storage_clients/gcs.py +88 -0
- chainlit/data/storage_clients/s3.py +42 -4
- chainlit/element.py +7 -4
- chainlit/emitter.py +9 -14
- chainlit/frontend/dist/assets/{DailyMotion-C-_sjrtO.js → DailyMotion-DFvM941y.js} +1 -1
- chainlit/frontend/dist/assets/Dataframe-CA6SlUSB.js +22 -0
- chainlit/frontend/dist/assets/{Facebook-bB34P03l.js → Facebook-BM4MwXR1.js} +1 -1
- chainlit/frontend/dist/assets/{FilePlayer-BWgqGrXv.js → FilePlayer-CfjB8iXr.js} +1 -1
- chainlit/frontend/dist/assets/{Kaltura-OY4P9Ofd.js → Kaltura-Bg-U6Xkz.js} +1 -1
- chainlit/frontend/dist/assets/{Mixcloud-9CtT8w5Y.js → Mixcloud-xJrfoMTv.js} +1 -1
- chainlit/frontend/dist/assets/{Mux-BH9A0qEi.js → Mux-CKnKDBmk.js} +1 -1
- chainlit/frontend/dist/assets/{Preview-Og00EJ05.js → Preview-DwHPdmIg.js} +1 -1
- chainlit/frontend/dist/assets/{SoundCloud-D7resGfn.js → SoundCloud-Crd5dwXV.js} +1 -1
- chainlit/frontend/dist/assets/{Streamable-6f_6bYz1.js → Streamable-Dq0c8lyx.js} +1 -1
- chainlit/frontend/dist/assets/{Twitch-BZJl3peM.js → Twitch-DIDvP936.js} +1 -1
- chainlit/frontend/dist/assets/{Vidyard-B7tv4b8_.js → Vidyard-B1dz9WN4.js} +1 -1
- chainlit/frontend/dist/assets/{Vimeo-F-eA4zQI.js → Vimeo-22Su6q2w.js} +1 -1
- chainlit/frontend/dist/assets/Wistia-C7adXRjN.js +1 -0
- chainlit/frontend/dist/assets/{YouTube-aFdJGjI1.js → YouTube-Dt4UMtQI.js} +1 -1
- chainlit/frontend/dist/assets/index-DbdLVHtZ.js +8665 -0
- chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
- chainlit/frontend/dist/assets/{react-plotly-DoUJXMgz.js → react-plotly-DvpXYYRJ.js} +1 -1
- chainlit/frontend/dist/index.html +2 -2
- chainlit/message.py +1 -3
- chainlit/server.py +297 -78
- chainlit/session.py +9 -0
- chainlit/socket.py +5 -53
- chainlit/step.py +0 -1
- chainlit/translations/en-US.json +1 -1
- chainlit/types.py +17 -3
- chainlit/user_session.py +1 -0
- {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/METADATA +4 -35
- {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/RECORD +49 -45
- chainlit/frontend/dist/assets/Wistia-Dhxhn3IB.js +0 -1
- chainlit/frontend/dist/assets/index-Ba33_hdJ.js +0 -1091
- chainlit/frontend/dist/assets/index-CwmincdQ.css +0 -1
- {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/WHEEL +0 -0
- {chainlit-2.0rc1.dist-info → chainlit-2.0.2.dist-info}/entry_points.txt +0 -0
|
@@ -21,8 +21,8 @@
|
|
|
21
21
|
<script>
|
|
22
22
|
const global = globalThis;
|
|
23
23
|
</script>
|
|
24
|
-
<script type="module" crossorigin src="/assets/index-
|
|
25
|
-
<link rel="stylesheet" crossorigin href="/assets/index-
|
|
24
|
+
<script type="module" crossorigin src="/assets/index-DbdLVHtZ.js"></script>
|
|
25
|
+
<link rel="stylesheet" crossorigin href="/assets/index-g8LTJwwr.css">
|
|
26
26
|
</head>
|
|
27
27
|
<body>
|
|
28
28
|
<div id="root"></div>
|
chainlit/message.py
CHANGED
|
@@ -43,7 +43,6 @@ class MessageBase(ABC):
|
|
|
43
43
|
metadata: Optional[Dict] = None
|
|
44
44
|
tags: Optional[List[str]] = None
|
|
45
45
|
wait_for_answer = False
|
|
46
|
-
indent: Optional[int] = None
|
|
47
46
|
|
|
48
47
|
def __post_init__(self) -> None:
|
|
49
48
|
trace_event(f"init {self.__class__.__name__}")
|
|
@@ -86,7 +85,6 @@ class MessageBase(ABC):
|
|
|
86
85
|
"streaming": self.streaming,
|
|
87
86
|
"isError": self.is_error,
|
|
88
87
|
"waitForAnswer": self.wait_for_answer,
|
|
89
|
-
"indent": self.indent,
|
|
90
88
|
"metadata": self.metadata or {},
|
|
91
89
|
"tags": self.tags,
|
|
92
90
|
}
|
|
@@ -542,7 +540,7 @@ class AskActionMessage(AskMessageBase):
|
|
|
542
540
|
if res is None:
|
|
543
541
|
self.content = "Timed out: no action was taken"
|
|
544
542
|
else:
|
|
545
|
-
self.content = f
|
|
543
|
+
self.content = f"**Selected:** {res['label']}"
|
|
546
544
|
|
|
547
545
|
self.wait_for_answer = False
|
|
548
546
|
|
chainlit/server.py
CHANGED
|
@@ -10,7 +10,7 @@ import urllib.parse
|
|
|
10
10
|
import webbrowser
|
|
11
11
|
from contextlib import asynccontextmanager
|
|
12
12
|
from pathlib import Path
|
|
13
|
-
from typing import List, Optional, Union
|
|
13
|
+
from typing import List, Optional, Union, cast
|
|
14
14
|
|
|
15
15
|
import socketio
|
|
16
16
|
from fastapi import (
|
|
@@ -27,13 +27,12 @@ from fastapi import (
|
|
|
27
27
|
)
|
|
28
28
|
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
|
|
29
29
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
30
|
-
from fastapi.staticfiles import StaticFiles
|
|
31
30
|
from starlette.datastructures import URL
|
|
32
31
|
from starlette.middleware.cors import CORSMiddleware
|
|
33
32
|
from typing_extensions import Annotated
|
|
34
33
|
from watchfiles import awatch
|
|
35
34
|
|
|
36
|
-
from chainlit.auth import create_jwt, get_configuration, get_current_user
|
|
35
|
+
from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
|
|
37
36
|
from chainlit.auth.cookie import (
|
|
38
37
|
clear_auth_cookie,
|
|
39
38
|
clear_oauth_state_cookie,
|
|
@@ -49,6 +48,7 @@ from chainlit.config import (
|
|
|
49
48
|
PACKAGE_ROOT,
|
|
50
49
|
config,
|
|
51
50
|
load_module,
|
|
51
|
+
public_dir,
|
|
52
52
|
reload_config,
|
|
53
53
|
)
|
|
54
54
|
from chainlit.data import get_data_layer
|
|
@@ -58,11 +58,14 @@ from chainlit.markdown import get_markdown_str
|
|
|
58
58
|
from chainlit.oauth_providers import get_oauth_provider
|
|
59
59
|
from chainlit.secret import random_secret
|
|
60
60
|
from chainlit.types import (
|
|
61
|
+
CallActionRequest,
|
|
61
62
|
DeleteFeedbackRequest,
|
|
62
63
|
DeleteThreadRequest,
|
|
64
|
+
ElementRequest,
|
|
63
65
|
GetThreadsRequest,
|
|
64
66
|
Theme,
|
|
65
67
|
UpdateFeedbackRequest,
|
|
68
|
+
UpdateThreadRequest,
|
|
66
69
|
)
|
|
67
70
|
from chainlit.user import PersistedUser, User
|
|
68
71
|
|
|
@@ -213,29 +216,59 @@ app.add_middleware(
|
|
|
213
216
|
|
|
214
217
|
router = APIRouter(prefix=PREFIX)
|
|
215
218
|
|
|
216
|
-
app.mount(
|
|
217
|
-
f"{PREFIX}/public",
|
|
218
|
-
StaticFiles(directory="public", check_dir=False),
|
|
219
|
-
name="public",
|
|
220
|
-
)
|
|
221
219
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
),
|
|
228
|
-
name="assets",
|
|
229
|
-
)
|
|
220
|
+
@router.get("/public/{filename:path}")
|
|
221
|
+
async def serve_public_file(
|
|
222
|
+
filename: str,
|
|
223
|
+
):
|
|
224
|
+
"""Serve a file from public dir."""
|
|
230
225
|
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
)
|
|
226
|
+
base_path = Path(public_dir)
|
|
227
|
+
file_path = (base_path / filename).resolve()
|
|
228
|
+
|
|
229
|
+
if not is_path_inside(file_path, base_path):
|
|
230
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
231
|
+
|
|
232
|
+
if file_path.is_file():
|
|
233
|
+
return FileResponse(file_path)
|
|
234
|
+
else:
|
|
235
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@router.get("/assets/{filename:path}")
|
|
239
|
+
async def serve_asset_file(
|
|
240
|
+
filename: str,
|
|
241
|
+
):
|
|
242
|
+
"""Serve a file from assets dir."""
|
|
243
|
+
|
|
244
|
+
base_path = Path(os.path.join(build_dir, "assets"))
|
|
245
|
+
file_path = (base_path / filename).resolve()
|
|
246
|
+
|
|
247
|
+
if not is_path_inside(file_path, base_path):
|
|
248
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
249
|
+
|
|
250
|
+
if file_path.is_file():
|
|
251
|
+
return FileResponse(file_path)
|
|
252
|
+
else:
|
|
253
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@router.get("/copilot/{filename:path}")
|
|
257
|
+
async def serve_copilot_file(
|
|
258
|
+
filename: str,
|
|
259
|
+
):
|
|
260
|
+
"""Serve a file from assets dir."""
|
|
261
|
+
|
|
262
|
+
base_path = Path(copilot_build_dir)
|
|
263
|
+
file_path = (base_path / filename).resolve()
|
|
264
|
+
|
|
265
|
+
if not is_path_inside(file_path, base_path):
|
|
266
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
267
|
+
|
|
268
|
+
if file_path.is_file():
|
|
269
|
+
return FileResponse(file_path)
|
|
270
|
+
else:
|
|
271
|
+
raise HTTPException(status_code=404, detail="File not found")
|
|
239
272
|
|
|
240
273
|
|
|
241
274
|
# -------------------------------------------------------------------------------
|
|
@@ -286,6 +319,16 @@ def get_html_template():
|
|
|
286
319
|
"""
|
|
287
320
|
Get HTML template for the index view.
|
|
288
321
|
"""
|
|
322
|
+
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
323
|
+
|
|
324
|
+
custom_theme = None
|
|
325
|
+
custom_theme_file_path = Path(public_dir) / "theme.json"
|
|
326
|
+
if (
|
|
327
|
+
is_path_inside(custom_theme_file_path, Path(public_dir))
|
|
328
|
+
and custom_theme_file_path.is_file()
|
|
329
|
+
):
|
|
330
|
+
custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
|
|
331
|
+
|
|
289
332
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
290
333
|
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
|
|
291
334
|
CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
|
|
@@ -309,8 +352,8 @@ def get_html_template():
|
|
|
309
352
|
<meta property="og:root_path" content="{ROOT_PATH}">"""
|
|
310
353
|
|
|
311
354
|
js = f"""<script>
|
|
312
|
-
{f"window.theme = {json.dumps(
|
|
313
|
-
{f"window.transports = {json.dumps(config.project.transports)};
|
|
355
|
+
{f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
|
|
356
|
+
{f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
|
|
314
357
|
</script>"""
|
|
315
358
|
|
|
316
359
|
css = None
|
|
@@ -323,8 +366,11 @@ def get_html_template():
|
|
|
323
366
|
js += f"""<script src="{config.ui.custom_js}" defer></script>"""
|
|
324
367
|
|
|
325
368
|
font = None
|
|
326
|
-
if
|
|
327
|
-
font =
|
|
369
|
+
if custom_theme and custom_theme.get("custom_fonts"):
|
|
370
|
+
font = "\n".join(
|
|
371
|
+
f"""<link rel="stylesheet" href="{font}">"""
|
|
372
|
+
for font in custom_theme.get("custom_fonts")
|
|
373
|
+
)
|
|
328
374
|
|
|
329
375
|
index_html_file_path = os.path.join(build_dir, "index.html")
|
|
330
376
|
|
|
@@ -376,13 +422,6 @@ async def auth(request: Request):
|
|
|
376
422
|
def _get_response_dict(access_token: str) -> dict:
|
|
377
423
|
"""Get the response dictionary for the auth response."""
|
|
378
424
|
|
|
379
|
-
if not config.project.cookie_auth:
|
|
380
|
-
# Legacy auth
|
|
381
|
-
return {
|
|
382
|
-
"access_token": access_token,
|
|
383
|
-
"token_type": "bearer",
|
|
384
|
-
}
|
|
385
|
-
|
|
386
425
|
return {"success": True}
|
|
387
426
|
|
|
388
427
|
|
|
@@ -444,8 +483,7 @@ async def _authenticate_user(
|
|
|
444
483
|
|
|
445
484
|
response = _get_auth_response(access_token, redirect_to_callback)
|
|
446
485
|
|
|
447
|
-
|
|
448
|
-
set_auth_cookie(response, access_token)
|
|
486
|
+
set_auth_cookie(response, access_token)
|
|
449
487
|
|
|
450
488
|
return response
|
|
451
489
|
|
|
@@ -470,8 +508,7 @@ async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depen
|
|
|
470
508
|
@router.post("/logout")
|
|
471
509
|
async def logout(request: Request, response: Response):
|
|
472
510
|
"""Logout the user by calling the on_logout callback."""
|
|
473
|
-
|
|
474
|
-
clear_auth_cookie(response)
|
|
511
|
+
clear_auth_cookie(response)
|
|
475
512
|
|
|
476
513
|
if config.code.on_logout:
|
|
477
514
|
return await config.code.on_logout(request, response)
|
|
@@ -479,6 +516,35 @@ async def logout(request: Request, response: Response):
|
|
|
479
516
|
return {"success": True}
|
|
480
517
|
|
|
481
518
|
|
|
519
|
+
@router.post("/auth/jwt")
|
|
520
|
+
async def jwt_auth(request: Request):
|
|
521
|
+
"""Login a user using a valid jwt."""
|
|
522
|
+
from jwt import InvalidTokenError
|
|
523
|
+
|
|
524
|
+
auth_header: Optional[str] = request.headers.get("Authorization")
|
|
525
|
+
if not auth_header:
|
|
526
|
+
raise HTTPException(status_code=401, detail="Authorization header missing")
|
|
527
|
+
|
|
528
|
+
# Check if it starts with "Bearer "
|
|
529
|
+
try:
|
|
530
|
+
scheme, token = auth_header.split()
|
|
531
|
+
if scheme.lower() != "bearer":
|
|
532
|
+
raise HTTPException(
|
|
533
|
+
status_code=401,
|
|
534
|
+
detail="Invalid authentication scheme. Please use Bearer",
|
|
535
|
+
)
|
|
536
|
+
except ValueError:
|
|
537
|
+
raise HTTPException(
|
|
538
|
+
status_code=401, detail="Invalid authorization header format"
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
try:
|
|
542
|
+
user = decode_jwt(token)
|
|
543
|
+
return await _authenticate_user(user)
|
|
544
|
+
except InvalidTokenError:
|
|
545
|
+
raise HTTPException(status_code=401, detail="Invalid token")
|
|
546
|
+
|
|
547
|
+
|
|
482
548
|
@router.post("/auth/header")
|
|
483
549
|
async def header_auth(request: Request):
|
|
484
550
|
"""Login a user using the header_auth_callback."""
|
|
@@ -635,7 +701,7 @@ async def oauth_azure_hf_callback(
|
|
|
635
701
|
return response
|
|
636
702
|
|
|
637
703
|
|
|
638
|
-
GenericUser = Union[User, PersistedUser]
|
|
704
|
+
GenericUser = Union[User, PersistedUser, None]
|
|
639
705
|
UserParam = Annotated[GenericUser, Depends(get_current_user)]
|
|
640
706
|
|
|
641
707
|
|
|
@@ -767,6 +833,9 @@ async def get_user_threads(
|
|
|
767
833
|
if not data_layer:
|
|
768
834
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
769
835
|
|
|
836
|
+
if not current_user:
|
|
837
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
838
|
+
|
|
770
839
|
if not isinstance(current_user, PersistedUser):
|
|
771
840
|
persisted_user = await data_layer.get_user(identifier=current_user.identifier)
|
|
772
841
|
if not persisted_user:
|
|
@@ -791,6 +860,9 @@ async def get_thread(
|
|
|
791
860
|
if not data_layer:
|
|
792
861
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
793
862
|
|
|
863
|
+
if not current_user:
|
|
864
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
865
|
+
|
|
794
866
|
await is_thread_author(current_user.identifier, thread_id)
|
|
795
867
|
|
|
796
868
|
res = await data_layer.get_thread(thread_id)
|
|
@@ -810,12 +882,130 @@ async def get_thread_element(
|
|
|
810
882
|
if not data_layer:
|
|
811
883
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
812
884
|
|
|
885
|
+
if not current_user:
|
|
886
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
887
|
+
|
|
813
888
|
await is_thread_author(current_user.identifier, thread_id)
|
|
814
889
|
|
|
815
890
|
res = await data_layer.get_element(thread_id, element_id)
|
|
816
891
|
return JSONResponse(content=res)
|
|
817
892
|
|
|
818
893
|
|
|
894
|
+
@router.put("/project/element")
|
|
895
|
+
async def update_thread_element(
|
|
896
|
+
payload: ElementRequest,
|
|
897
|
+
current_user: UserParam,
|
|
898
|
+
):
|
|
899
|
+
"""Update a specific thread element."""
|
|
900
|
+
|
|
901
|
+
from chainlit.context import init_ws_context
|
|
902
|
+
from chainlit.element import CustomElement, ElementDict
|
|
903
|
+
from chainlit.session import WebsocketSession
|
|
904
|
+
|
|
905
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
906
|
+
context = init_ws_context(session)
|
|
907
|
+
|
|
908
|
+
element_dict = cast(ElementDict, payload.element)
|
|
909
|
+
|
|
910
|
+
if element_dict["type"] != "custom":
|
|
911
|
+
return {"success": False}
|
|
912
|
+
|
|
913
|
+
element = CustomElement(
|
|
914
|
+
id=element_dict["id"],
|
|
915
|
+
object_key=element_dict["objectKey"],
|
|
916
|
+
chainlit_key=element_dict["chainlitKey"],
|
|
917
|
+
url=element_dict["url"],
|
|
918
|
+
for_id=element_dict.get("forId") or "",
|
|
919
|
+
thread_id=element_dict.get("threadId") or "",
|
|
920
|
+
name=element_dict["name"],
|
|
921
|
+
props=element_dict.get("props") or {},
|
|
922
|
+
display=element_dict["display"],
|
|
923
|
+
)
|
|
924
|
+
|
|
925
|
+
if current_user:
|
|
926
|
+
if (
|
|
927
|
+
not context.session.user
|
|
928
|
+
or context.session.user.identifier != current_user.identifier
|
|
929
|
+
):
|
|
930
|
+
raise HTTPException(
|
|
931
|
+
status_code=401,
|
|
932
|
+
detail="You are not authorized to update elements for this session",
|
|
933
|
+
)
|
|
934
|
+
|
|
935
|
+
await element.send(for_id=element.for_id or "")
|
|
936
|
+
return {"success": True}
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
@router.delete("/project/element")
|
|
940
|
+
async def delete_thread_element(
|
|
941
|
+
payload: ElementRequest,
|
|
942
|
+
current_user: UserParam,
|
|
943
|
+
):
|
|
944
|
+
"""Delete a specific thread element."""
|
|
945
|
+
|
|
946
|
+
from chainlit.context import init_ws_context
|
|
947
|
+
from chainlit.element import CustomElement, ElementDict
|
|
948
|
+
from chainlit.session import WebsocketSession
|
|
949
|
+
|
|
950
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
951
|
+
context = init_ws_context(session)
|
|
952
|
+
|
|
953
|
+
element_dict = cast(ElementDict, payload.element)
|
|
954
|
+
|
|
955
|
+
if element_dict["type"] != "custom":
|
|
956
|
+
return {"success": False}
|
|
957
|
+
|
|
958
|
+
element = CustomElement(
|
|
959
|
+
id=element_dict["id"],
|
|
960
|
+
object_key=element_dict["objectKey"],
|
|
961
|
+
chainlit_key=element_dict["chainlitKey"],
|
|
962
|
+
url=element_dict["url"],
|
|
963
|
+
for_id=element_dict.get("forId") or "",
|
|
964
|
+
thread_id=element_dict.get("threadId") or "",
|
|
965
|
+
name=element_dict["name"],
|
|
966
|
+
props=element_dict.get("props") or {},
|
|
967
|
+
display=element_dict["display"],
|
|
968
|
+
)
|
|
969
|
+
|
|
970
|
+
if current_user:
|
|
971
|
+
if (
|
|
972
|
+
not context.session.user
|
|
973
|
+
or context.session.user.identifier != current_user.identifier
|
|
974
|
+
):
|
|
975
|
+
raise HTTPException(
|
|
976
|
+
status_code=401,
|
|
977
|
+
detail="You are not authorized to remove elements for this session",
|
|
978
|
+
)
|
|
979
|
+
|
|
980
|
+
await element.remove()
|
|
981
|
+
|
|
982
|
+
return {"success": True}
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
@router.put("/project/thread")
|
|
986
|
+
async def rename_thread(
|
|
987
|
+
request: Request,
|
|
988
|
+
payload: UpdateThreadRequest,
|
|
989
|
+
current_user: UserParam,
|
|
990
|
+
):
|
|
991
|
+
"""Rename a thread."""
|
|
992
|
+
|
|
993
|
+
data_layer = get_data_layer()
|
|
994
|
+
|
|
995
|
+
if not data_layer:
|
|
996
|
+
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
997
|
+
|
|
998
|
+
if not current_user:
|
|
999
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1000
|
+
|
|
1001
|
+
thread_id = payload.threadId
|
|
1002
|
+
|
|
1003
|
+
await is_thread_author(current_user.identifier, thread_id)
|
|
1004
|
+
|
|
1005
|
+
await data_layer.update_thread(thread_id, name=payload.name)
|
|
1006
|
+
return JSONResponse(content={"success": True})
|
|
1007
|
+
|
|
1008
|
+
|
|
819
1009
|
@router.delete("/project/thread")
|
|
820
1010
|
async def delete_thread(
|
|
821
1011
|
request: Request,
|
|
@@ -829,6 +1019,9 @@ async def delete_thread(
|
|
|
829
1019
|
if not data_layer:
|
|
830
1020
|
raise HTTPException(status_code=400, detail="Data persistence is not enabled")
|
|
831
1021
|
|
|
1022
|
+
if not current_user:
|
|
1023
|
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
1024
|
+
|
|
832
1025
|
thread_id = payload.threadId
|
|
833
1026
|
|
|
834
1027
|
await is_thread_author(current_user.identifier, thread_id)
|
|
@@ -837,6 +1030,48 @@ async def delete_thread(
|
|
|
837
1030
|
return JSONResponse(content={"success": True})
|
|
838
1031
|
|
|
839
1032
|
|
|
1033
|
+
@router.post("/project/action")
|
|
1034
|
+
async def call_action(
|
|
1035
|
+
payload: CallActionRequest,
|
|
1036
|
+
current_user: UserParam,
|
|
1037
|
+
):
|
|
1038
|
+
"""Run an action."""
|
|
1039
|
+
|
|
1040
|
+
from chainlit.action import Action
|
|
1041
|
+
from chainlit.context import init_ws_context
|
|
1042
|
+
from chainlit.session import WebsocketSession
|
|
1043
|
+
|
|
1044
|
+
session = WebsocketSession.get_by_id(payload.sessionId)
|
|
1045
|
+
context = init_ws_context(session)
|
|
1046
|
+
|
|
1047
|
+
action = Action(**payload.action)
|
|
1048
|
+
|
|
1049
|
+
if current_user:
|
|
1050
|
+
if (
|
|
1051
|
+
not context.session.user
|
|
1052
|
+
or context.session.user.identifier != current_user.identifier
|
|
1053
|
+
):
|
|
1054
|
+
raise HTTPException(
|
|
1055
|
+
status_code=401,
|
|
1056
|
+
detail="You are not authorized to upload files for this session",
|
|
1057
|
+
)
|
|
1058
|
+
|
|
1059
|
+
callback = config.code.action_callbacks.get(action.name)
|
|
1060
|
+
if callback:
|
|
1061
|
+
if not context.session.has_first_interaction:
|
|
1062
|
+
context.session.has_first_interaction = True
|
|
1063
|
+
asyncio.create_task(context.emitter.init_thread(action.name))
|
|
1064
|
+
|
|
1065
|
+
await callback(action)
|
|
1066
|
+
else:
|
|
1067
|
+
raise HTTPException(
|
|
1068
|
+
status_code=404,
|
|
1069
|
+
detail=f"No callback found for action {action.name}",
|
|
1070
|
+
)
|
|
1071
|
+
|
|
1072
|
+
return JSONResponse(content={"success": True})
|
|
1073
|
+
|
|
1074
|
+
|
|
840
1075
|
@router.post("/project/file")
|
|
841
1076
|
async def upload_file(
|
|
842
1077
|
current_user: UserParam,
|
|
@@ -888,11 +1123,14 @@ def validate_file_upload(file: UploadFile):
|
|
|
888
1123
|
Raises:
|
|
889
1124
|
ValueError: If the file is not allowed.
|
|
890
1125
|
"""
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
if config.features.spontaneous_file_upload
|
|
895
|
-
|
|
1126
|
+
# TODO: This logic/endpoint is shared across spontaneous uploads and the AskFileMessage API.
|
|
1127
|
+
# Commenting this check until we find a better solution
|
|
1128
|
+
|
|
1129
|
+
# if config.features.spontaneous_file_upload is None:
|
|
1130
|
+
# """Default for a missing config is to allow the fileupload without any restrictions"""
|
|
1131
|
+
# return
|
|
1132
|
+
# if not config.features.spontaneous_file_upload.enabled:
|
|
1133
|
+
# raise ValueError("File upload is not enabled")
|
|
896
1134
|
|
|
897
1135
|
validate_file_mime_type(file)
|
|
898
1136
|
validate_file_size(file)
|
|
@@ -905,14 +1143,19 @@ def validate_file_mime_type(file: UploadFile):
|
|
|
905
1143
|
Raises:
|
|
906
1144
|
ValueError: If the file type is not allowed.
|
|
907
1145
|
"""
|
|
908
|
-
|
|
909
|
-
if
|
|
1146
|
+
|
|
1147
|
+
if (
|
|
1148
|
+
config.features.spontaneous_file_upload is None
|
|
1149
|
+
or config.features.spontaneous_file_upload.accept is None
|
|
1150
|
+
):
|
|
910
1151
|
"Accept is not configured, allowing all file types"
|
|
911
1152
|
return
|
|
912
1153
|
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
1154
|
+
accept = config.features.spontaneous_file_upload.accept
|
|
1155
|
+
|
|
1156
|
+
assert isinstance(accept, List) or isinstance(accept, dict), (
|
|
1157
|
+
"Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
|
|
1158
|
+
)
|
|
916
1159
|
|
|
917
1160
|
if isinstance(accept, List):
|
|
918
1161
|
for pattern in accept:
|
|
@@ -936,7 +1179,10 @@ def validate_file_size(file: UploadFile):
|
|
|
936
1179
|
Raises:
|
|
937
1180
|
ValueError: If the file size is too large.
|
|
938
1181
|
"""
|
|
939
|
-
if
|
|
1182
|
+
if (
|
|
1183
|
+
config.features.spontaneous_file_upload is None
|
|
1184
|
+
or config.features.spontaneous_file_upload.max_size_mb is None
|
|
1185
|
+
):
|
|
940
1186
|
return
|
|
941
1187
|
|
|
942
1188
|
if (
|
|
@@ -954,14 +1200,6 @@ async def get_file(
|
|
|
954
1200
|
current_user: UserParam,
|
|
955
1201
|
):
|
|
956
1202
|
"""Get a file from the session files directory."""
|
|
957
|
-
|
|
958
|
-
if not config.project.cookie_auth:
|
|
959
|
-
# We cannot make this work safely without cookie auth, so disable it.
|
|
960
|
-
raise HTTPException(
|
|
961
|
-
status_code=404,
|
|
962
|
-
detail="File downloads unavailable.",
|
|
963
|
-
)
|
|
964
|
-
|
|
965
1203
|
from chainlit.session import WebsocketSession
|
|
966
1204
|
|
|
967
1205
|
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
@@ -986,25 +1224,6 @@ async def get_file(
|
|
|
986
1224
|
raise HTTPException(status_code=404, detail="File not found")
|
|
987
1225
|
|
|
988
1226
|
|
|
989
|
-
@router.get("/files/{filename:path}")
|
|
990
|
-
async def serve_file(
|
|
991
|
-
filename: str,
|
|
992
|
-
current_user: UserParam,
|
|
993
|
-
):
|
|
994
|
-
"""Serve a file from the local filesystem."""
|
|
995
|
-
|
|
996
|
-
base_path = Path(config.project.local_fs_path).resolve()
|
|
997
|
-
file_path = (base_path / filename).resolve()
|
|
998
|
-
|
|
999
|
-
if not is_path_inside(file_path, base_path):
|
|
1000
|
-
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
1001
|
-
|
|
1002
|
-
if file_path.is_file():
|
|
1003
|
-
return FileResponse(file_path)
|
|
1004
|
-
else:
|
|
1005
|
-
raise HTTPException(status_code=404, detail="File not found")
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
1227
|
@router.get("/favicon")
|
|
1009
1228
|
async def get_favicon():
|
|
1010
1229
|
"""Get the favicon for the UI."""
|
chainlit/session.py
CHANGED
|
@@ -64,6 +64,8 @@ class BaseSession:
|
|
|
64
64
|
chat_profile: Optional[str] = None,
|
|
65
65
|
# Origin of the request
|
|
66
66
|
http_referer: Optional[str] = None,
|
|
67
|
+
# Cookie
|
|
68
|
+
http_cookie: Optional[str] = None,
|
|
67
69
|
):
|
|
68
70
|
if thread_id:
|
|
69
71
|
self.thread_id_to_resume = thread_id
|
|
@@ -75,6 +77,7 @@ class BaseSession:
|
|
|
75
77
|
self.user_env = user_env or {}
|
|
76
78
|
self.chat_profile = chat_profile
|
|
77
79
|
self.http_referer = http_referer
|
|
80
|
+
self.http_cookie = http_cookie
|
|
78
81
|
|
|
79
82
|
self.files: Dict[str, FileDict] = {}
|
|
80
83
|
|
|
@@ -167,6 +170,8 @@ class HTTPSession(BaseSession):
|
|
|
167
170
|
user_env: Optional[Dict[str, str]] = None,
|
|
168
171
|
# Origin of the request
|
|
169
172
|
http_referer: Optional[str] = None,
|
|
173
|
+
# Cookie
|
|
174
|
+
http_cookie: Optional[str] = None,
|
|
170
175
|
):
|
|
171
176
|
super().__init__(
|
|
172
177
|
id=id,
|
|
@@ -176,6 +181,7 @@ class HTTPSession(BaseSession):
|
|
|
176
181
|
client_type=client_type,
|
|
177
182
|
user_env=user_env,
|
|
178
183
|
http_referer=http_referer,
|
|
184
|
+
http_cookie=http_cookie,
|
|
179
185
|
)
|
|
180
186
|
|
|
181
187
|
def delete(self):
|
|
@@ -226,6 +232,8 @@ class WebsocketSession(BaseSession):
|
|
|
226
232
|
languages: Optional[str] = None,
|
|
227
233
|
# Origin of the request
|
|
228
234
|
http_referer: Optional[str] = None,
|
|
235
|
+
# Cookie
|
|
236
|
+
http_cookie: Optional[str] = None,
|
|
229
237
|
):
|
|
230
238
|
super().__init__(
|
|
231
239
|
id=id,
|
|
@@ -236,6 +244,7 @@ class WebsocketSession(BaseSession):
|
|
|
236
244
|
client_type=client_type,
|
|
237
245
|
chat_profile=chat_profile,
|
|
238
246
|
http_referer=http_referer,
|
|
247
|
+
http_cookie=http_cookie,
|
|
239
248
|
)
|
|
240
249
|
|
|
241
250
|
self.socket_id = socket_id
|