chainlit 1.1.404__py3-none-any.whl → 1.2.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +63 -305
- chainlit/_utils.py +8 -0
- chainlit/assistant.py +16 -0
- chainlit/assistant_settings.py +35 -0
- chainlit/callbacks.py +340 -0
- chainlit/cli/__init__.py +1 -1
- chainlit/config.py +58 -28
- chainlit/copilot/dist/index.js +512 -631
- chainlit/data/__init__.py +6 -521
- chainlit/data/base.py +121 -0
- chainlit/data/dynamodb.py +5 -8
- chainlit/data/literalai.py +395 -0
- chainlit/data/sql_alchemy.py +11 -9
- chainlit/data/storage_clients.py +69 -15
- chainlit/data/utils.py +29 -0
- chainlit/element.py +1 -1
- chainlit/emitter.py +7 -0
- chainlit/frontend/dist/assets/{DailyMotion-e665b444.js → DailyMotion-aa368b7e.js} +1 -1
- chainlit/frontend/dist/assets/{Facebook-5207db92.js → Facebook-0335db46.js} +1 -1
- chainlit/frontend/dist/assets/{FilePlayer-86937d6e.js → FilePlayer-8d04256c.js} +1 -1
- chainlit/frontend/dist/assets/{Kaltura-c96622c1.js → Kaltura-67c9dd31.js} +1 -1
- chainlit/frontend/dist/assets/{Mixcloud-57ae3e32.js → Mixcloud-6bbaccf5.js} +1 -1
- chainlit/frontend/dist/assets/{Mux-20373920.js → Mux-c2bcb757.js} +1 -1
- chainlit/frontend/dist/assets/{Preview-c68c0613.js → Preview-210f3955.js} +1 -1
- chainlit/frontend/dist/assets/{SoundCloud-8a9e3eae.js → SoundCloud-a0276b84.js} +1 -1
- chainlit/frontend/dist/assets/{Streamable-1ed099af.js → Streamable-a007323d.js} +1 -1
- chainlit/frontend/dist/assets/{Twitch-6820039f.js → Twitch-e6a88aa3.js} +1 -1
- chainlit/frontend/dist/assets/{Vidyard-d39ab91d.js → Vidyard-dfb88a35.js} +1 -1
- chainlit/frontend/dist/assets/{Vimeo-017cd9a7.js → Vimeo-3baa13d9.js} +1 -1
- chainlit/frontend/dist/assets/{Wistia-a509d9f2.js → Wistia-e52f7bef.js} +1 -1
- chainlit/frontend/dist/assets/{YouTube-42dfd82f.js → YouTube-1715f22b.js} +1 -1
- chainlit/frontend/dist/assets/index-bfdd8585.js +729 -0
- chainlit/frontend/dist/assets/react-plotly-55648373.js +3484 -0
- chainlit/frontend/dist/index.html +1 -1
- chainlit/input_widget.py +22 -0
- chainlit/langchain/callbacks.py +6 -1
- chainlit/llama_index/callbacks.py +20 -4
- chainlit/markdown.py +15 -9
- chainlit/message.py +0 -1
- chainlit/server.py +113 -37
- chainlit/session.py +27 -4
- chainlit/socket.py +50 -1
- chainlit/translations/bn.json +231 -0
- chainlit/translations/en-US.json +6 -0
- chainlit/translations/fr-FR.json +236 -0
- chainlit/translations/gu.json +231 -0
- chainlit/translations/he-IL.json +231 -0
- chainlit/translations/hi.json +231 -0
- chainlit/translations/kn.json +231 -0
- chainlit/translations/ml.json +231 -0
- chainlit/translations/mr.json +231 -0
- chainlit/translations/ta.json +231 -0
- chainlit/translations/te.json +231 -0
- chainlit/types.py +1 -1
- chainlit/user_session.py +4 -0
- chainlit/utils.py +1 -1
- {chainlit-1.1.404.dist-info → chainlit-1.2.0rc0.dist-info}/METADATA +10 -10
- chainlit-1.2.0rc0.dist-info/RECORD +99 -0
- chainlit/frontend/dist/assets/index-30df9b2b.js +0 -730
- chainlit/frontend/dist/assets/react-plotly-5bb34118.js +0 -3602
- chainlit-1.1.404.dist-info/RECORD +0 -82
- {chainlit-1.1.404.dist-info → chainlit-1.2.0rc0.dist-info}/WHEEL +0 -0
- {chainlit-1.1.404.dist-info → chainlit-1.2.0rc0.dist-info}/entry_points.txt +0 -0
|
@@ -21,7 +21,7 @@
|
|
|
21
21
|
<script>
|
|
22
22
|
const global = globalThis;
|
|
23
23
|
</script>
|
|
24
|
-
<script type="module" crossorigin src="/assets/index-
|
|
24
|
+
<script type="module" crossorigin src="/assets/index-bfdd8585.js"></script>
|
|
25
25
|
<link rel="stylesheet" href="/assets/index-aaf974a9.css">
|
|
26
26
|
</head>
|
|
27
27
|
<body>
|
chainlit/input_widget.py
CHANGED
|
@@ -161,6 +161,28 @@ class NumberInput(InputWidget):
|
|
|
161
161
|
"description": self.description,
|
|
162
162
|
}
|
|
163
163
|
|
|
164
|
+
@dataclass
|
|
165
|
+
class FileUploadInput(InputWidget):
|
|
166
|
+
"""Useful to create a file upload input."""
|
|
167
|
+
|
|
168
|
+
type: InputWidgetType = "fileupload"
|
|
169
|
+
initial: Optional[str] = None
|
|
170
|
+
placeholder: Optional[str] = None
|
|
171
|
+
accept: List[str] = Field(default_factory=lambda: [])
|
|
172
|
+
max_size_mb: Optional[int] = None
|
|
173
|
+
max_files: Optional[int] = None
|
|
174
|
+
|
|
175
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
176
|
+
return {
|
|
177
|
+
"type": self.type,
|
|
178
|
+
"id": self.id,
|
|
179
|
+
"label": self.label,
|
|
180
|
+
"initial": self.initial,
|
|
181
|
+
"placeholder": self.placeholder,
|
|
182
|
+
"tooltip": self.tooltip,
|
|
183
|
+
"description": self.description,
|
|
184
|
+
}
|
|
185
|
+
|
|
164
186
|
|
|
165
187
|
@dataclass
|
|
166
188
|
class Tags(InputWidget):
|
chainlit/langchain/callbacks.py
CHANGED
|
@@ -587,12 +587,17 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
|
|
|
587
587
|
outputs = run.outputs or {}
|
|
588
588
|
output_keys = list(outputs.keys())
|
|
589
589
|
output = outputs
|
|
590
|
+
|
|
590
591
|
if output_keys:
|
|
591
592
|
output = outputs.get(output_keys[0], outputs)
|
|
592
593
|
|
|
593
594
|
if current_step:
|
|
594
595
|
current_step.output = (
|
|
595
|
-
output[0]
|
|
596
|
+
output[0]
|
|
597
|
+
if isinstance(output, Sequence)
|
|
598
|
+
and not isinstance(output, str)
|
|
599
|
+
and len(output)
|
|
600
|
+
else output
|
|
596
601
|
)
|
|
597
602
|
current_step.end = utc_now()
|
|
598
603
|
self._run_sync(current_step.update())
|
|
@@ -8,6 +8,7 @@ from literalai.helper import utc_now
|
|
|
8
8
|
from llama_index.core.callbacks import TokenCountingHandler
|
|
9
9
|
from llama_index.core.callbacks.schema import CBEventType, EventPayload
|
|
10
10
|
from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
|
|
11
|
+
from llama_index.core.tools.types import ToolMetadata
|
|
11
12
|
|
|
12
13
|
DEFAULT_IGNORE = [
|
|
13
14
|
CBEventType.CHUNKING,
|
|
@@ -54,7 +55,16 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
54
55
|
) -> str:
|
|
55
56
|
"""Run when an event starts and return id of event."""
|
|
56
57
|
step_type: StepType = "undefined"
|
|
57
|
-
|
|
58
|
+
step_name: str = event_type.value
|
|
59
|
+
step_input: Optional[Dict[str, Any]] = payload
|
|
60
|
+
if event_type == CBEventType.FUNCTION_CALL:
|
|
61
|
+
step_type = "tool"
|
|
62
|
+
if payload:
|
|
63
|
+
metadata: Optional[ToolMetadata] = payload.get(EventPayload.TOOL)
|
|
64
|
+
if metadata:
|
|
65
|
+
step_name = getattr(metadata, "name", step_name)
|
|
66
|
+
step_input = payload.get(EventPayload.FUNCTION_CALL)
|
|
67
|
+
elif event_type == CBEventType.RETRIEVE:
|
|
58
68
|
step_type = "tool"
|
|
59
69
|
elif event_type == CBEventType.QUERY:
|
|
60
70
|
step_type = "tool"
|
|
@@ -64,7 +74,7 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
64
74
|
return event_id
|
|
65
75
|
|
|
66
76
|
step = Step(
|
|
67
|
-
name=
|
|
77
|
+
name=step_name,
|
|
68
78
|
type=step_type,
|
|
69
79
|
parent_id=self._get_parent_id(parent_id),
|
|
70
80
|
id=event_id,
|
|
@@ -72,7 +82,7 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
72
82
|
|
|
73
83
|
self.steps[event_id] = step
|
|
74
84
|
step.start = utc_now()
|
|
75
|
-
step.input =
|
|
85
|
+
step.input = step_input or {}
|
|
76
86
|
context_var.get().loop.create_task(step.send())
|
|
77
87
|
return event_id
|
|
78
88
|
|
|
@@ -91,7 +101,13 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
|
|
|
91
101
|
|
|
92
102
|
step.end = utc_now()
|
|
93
103
|
|
|
94
|
-
if event_type == CBEventType.
|
|
104
|
+
if event_type == CBEventType.FUNCTION_CALL:
|
|
105
|
+
response = payload.get(EventPayload.FUNCTION_OUTPUT)
|
|
106
|
+
if response:
|
|
107
|
+
step.output = f"{response}"
|
|
108
|
+
context_var.get().loop.create_task(step.update())
|
|
109
|
+
|
|
110
|
+
elif event_type == CBEventType.QUERY:
|
|
95
111
|
response = payload.get(EventPayload.RESPONSE)
|
|
96
112
|
source_nodes = getattr(response, "source_nodes", None)
|
|
97
113
|
if source_nodes:
|
chainlit/markdown.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
2
4
|
|
|
3
5
|
from chainlit.logger import logger
|
|
4
6
|
|
|
7
|
+
from ._utils import is_path_inside
|
|
8
|
+
|
|
5
9
|
# Default chainlit.md file created if none exists
|
|
6
10
|
DEFAULT_MARKDOWN_STR = """# Welcome to Chainlit! 🚀🤖
|
|
7
11
|
|
|
@@ -30,12 +34,16 @@ def init_markdown(root: str):
|
|
|
30
34
|
logger.info(f"Created default chainlit markdown file at {chainlit_md_file}")
|
|
31
35
|
|
|
32
36
|
|
|
33
|
-
def get_markdown_str(root: str, language: str):
|
|
37
|
+
def get_markdown_str(root: str, language: str) -> Optional[str]:
|
|
34
38
|
"""Get the chainlit.md file as a string."""
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
+
root_path = Path(root)
|
|
40
|
+
translated_chainlit_md_path = root_path / f"chainlit_{language}.md"
|
|
41
|
+
default_chainlit_md_path = root_path / "chainlit.md"
|
|
42
|
+
|
|
43
|
+
if (
|
|
44
|
+
is_path_inside(translated_chainlit_md_path, root_path)
|
|
45
|
+
and translated_chainlit_md_path.is_file()
|
|
46
|
+
):
|
|
39
47
|
chainlit_md_path = translated_chainlit_md_path
|
|
40
48
|
else:
|
|
41
49
|
chainlit_md_path = default_chainlit_md_path
|
|
@@ -43,9 +51,7 @@ def get_markdown_str(root: str, language: str):
|
|
|
43
51
|
f"Translated markdown file for {language} not found. Defaulting to chainlit.md."
|
|
44
52
|
)
|
|
45
53
|
|
|
46
|
-
if
|
|
47
|
-
|
|
48
|
-
chainlit_md = f.read()
|
|
49
|
-
return chainlit_md
|
|
54
|
+
if chainlit_md_path.is_file():
|
|
55
|
+
return chainlit_md_path.read_text(encoding="utf-8")
|
|
50
56
|
else:
|
|
51
57
|
return None
|
chainlit/message.py
CHANGED
chainlit/server.py
CHANGED
|
@@ -1,22 +1,15 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import glob
|
|
2
3
|
import json
|
|
3
4
|
import mimetypes
|
|
5
|
+
import os
|
|
4
6
|
import re
|
|
5
7
|
import shutil
|
|
6
8
|
import urllib.parse
|
|
7
|
-
from typing import Any, Optional, Union
|
|
8
|
-
|
|
9
|
-
from chainlit.oauth_providers import get_oauth_provider
|
|
10
|
-
from chainlit.secret import random_secret
|
|
11
|
-
|
|
12
|
-
mimetypes.add_type("application/javascript", ".js")
|
|
13
|
-
mimetypes.add_type("text/css", ".css")
|
|
14
|
-
|
|
15
|
-
import asyncio
|
|
16
|
-
import os
|
|
17
9
|
import webbrowser
|
|
18
10
|
from contextlib import asynccontextmanager
|
|
19
11
|
from pathlib import Path
|
|
12
|
+
from typing import Any, Optional, Union
|
|
20
13
|
|
|
21
14
|
import socketio
|
|
22
15
|
from chainlit.auth import create_jwt, get_configuration, get_current_user
|
|
@@ -34,6 +27,8 @@ from chainlit.data import get_data_layer
|
|
|
34
27
|
from chainlit.data.acl import is_thread_author
|
|
35
28
|
from chainlit.logger import logger
|
|
36
29
|
from chainlit.markdown import get_markdown_str
|
|
30
|
+
from chainlit.oauth_providers import get_oauth_provider
|
|
31
|
+
from chainlit.secret import random_secret
|
|
37
32
|
from chainlit.types import (
|
|
38
33
|
DeleteFeedbackRequest,
|
|
39
34
|
DeleteThreadRequest,
|
|
@@ -62,12 +57,20 @@ from starlette.middleware.cors import CORSMiddleware
|
|
|
62
57
|
from typing_extensions import Annotated
|
|
63
58
|
from watchfiles import awatch
|
|
64
59
|
|
|
60
|
+
from ._utils import is_path_inside
|
|
61
|
+
|
|
62
|
+
mimetypes.add_type("application/javascript", ".js")
|
|
63
|
+
mimetypes.add_type("text/css", ".css")
|
|
64
|
+
|
|
65
65
|
ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
|
|
66
66
|
IS_SUBMOUNT = os.environ.get("CHAINLIT_SUBMOUNT", "") == "true"
|
|
67
|
+
# If the app is a submount, no need to set the prefix
|
|
68
|
+
PREFIX = ROOT_PATH if ROOT_PATH and not IS_SUBMOUNT else ""
|
|
67
69
|
|
|
68
70
|
|
|
69
71
|
@asynccontextmanager
|
|
70
72
|
async def lifespan(app: FastAPI):
|
|
73
|
+
"""Context manager to handle app start and shutdown."""
|
|
71
74
|
host = config.run.host
|
|
72
75
|
port = config.run.port
|
|
73
76
|
|
|
@@ -150,7 +153,18 @@ async def lifespan(app: FastAPI):
|
|
|
150
153
|
os._exit(0)
|
|
151
154
|
|
|
152
155
|
|
|
153
|
-
def get_build_dir(local_target: str, packaged_target: str):
|
|
156
|
+
def get_build_dir(local_target: str, packaged_target: str) -> str:
|
|
157
|
+
"""
|
|
158
|
+
Get the build directory based on the UI build strategy.
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
local_target (str): The local target directory.
|
|
162
|
+
packaged_target (str): The packaged target directory.
|
|
163
|
+
|
|
164
|
+
Returns:
|
|
165
|
+
str: The build directory
|
|
166
|
+
"""
|
|
167
|
+
|
|
154
168
|
local_build_dir = os.path.join(PACKAGE_ROOT, local_target, "dist")
|
|
155
169
|
packaged_build_dir = os.path.join(BACKEND_ROOT, packaged_target, "dist")
|
|
156
170
|
|
|
@@ -171,18 +185,14 @@ copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
|
|
|
171
185
|
|
|
172
186
|
app = FastAPI(lifespan=lifespan)
|
|
173
187
|
|
|
174
|
-
sio = socketio.AsyncServer(
|
|
175
|
-
cors_allowed_origins=[], async_mode="asgi"
|
|
176
|
-
)
|
|
177
|
-
|
|
178
|
-
sio_mount_location = f"{ROOT_PATH}/ws" if ROOT_PATH else "ws"
|
|
188
|
+
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
|
|
179
189
|
|
|
180
190
|
asgi_app = socketio.ASGIApp(
|
|
181
191
|
socketio_server=sio,
|
|
182
|
-
socketio_path=
|
|
192
|
+
socketio_path="",
|
|
183
193
|
)
|
|
184
194
|
|
|
185
|
-
app.mount(f"
|
|
195
|
+
app.mount(f"{PREFIX}/ws/socket.io", asgi_app)
|
|
186
196
|
|
|
187
197
|
app.add_middleware(
|
|
188
198
|
CORSMiddleware,
|
|
@@ -192,16 +202,16 @@ app.add_middleware(
|
|
|
192
202
|
allow_headers=["*"],
|
|
193
203
|
)
|
|
194
204
|
|
|
195
|
-
router = APIRouter(prefix=
|
|
205
|
+
router = APIRouter(prefix=PREFIX)
|
|
196
206
|
|
|
197
207
|
app.mount(
|
|
198
|
-
f"{
|
|
208
|
+
f"{PREFIX}/public",
|
|
199
209
|
StaticFiles(directory="public", check_dir=False),
|
|
200
210
|
name="public",
|
|
201
211
|
)
|
|
202
212
|
|
|
203
213
|
app.mount(
|
|
204
|
-
f"{
|
|
214
|
+
f"{PREFIX}/assets",
|
|
205
215
|
StaticFiles(
|
|
206
216
|
packages=[("chainlit", os.path.join(build_dir, "assets"))],
|
|
207
217
|
follow_symlink=config.project.follow_symlink,
|
|
@@ -210,7 +220,7 @@ app.mount(
|
|
|
210
220
|
)
|
|
211
221
|
|
|
212
222
|
app.mount(
|
|
213
|
-
f"{
|
|
223
|
+
f"{PREFIX}/copilot",
|
|
214
224
|
StaticFiles(
|
|
215
225
|
packages=[("chainlit", copilot_build_dir)],
|
|
216
226
|
follow_symlink=config.project.follow_symlink,
|
|
@@ -218,7 +228,6 @@ app.mount(
|
|
|
218
228
|
name="copilot",
|
|
219
229
|
)
|
|
220
230
|
|
|
221
|
-
|
|
222
231
|
# -------------------------------------------------------------------------------
|
|
223
232
|
# SLACK HANDLER
|
|
224
233
|
# -------------------------------------------------------------------------------
|
|
@@ -253,12 +262,19 @@ if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
|
|
|
253
262
|
# -------------------------------------------------------------------------------
|
|
254
263
|
|
|
255
264
|
|
|
256
|
-
def replace_between_tags(
|
|
265
|
+
def replace_between_tags(
|
|
266
|
+
text: str, start_tag: str, end_tag: str, replacement: str
|
|
267
|
+
) -> str:
|
|
268
|
+
"""Replace text between two tags in a string."""
|
|
269
|
+
|
|
257
270
|
pattern = start_tag + ".*?" + end_tag
|
|
258
271
|
return re.sub(pattern, start_tag + replacement + end_tag, text, flags=re.DOTALL)
|
|
259
272
|
|
|
260
273
|
|
|
261
274
|
def get_html_template():
|
|
275
|
+
"""
|
|
276
|
+
Get HTML template for the index view.
|
|
277
|
+
"""
|
|
262
278
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
263
279
|
JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
|
|
264
280
|
CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
|
|
@@ -345,6 +361,9 @@ async def auth(request: Request):
|
|
|
345
361
|
|
|
346
362
|
@router.post("/login")
|
|
347
363
|
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
364
|
+
"""
|
|
365
|
+
Login a user using the password auth callback.
|
|
366
|
+
"""
|
|
348
367
|
if not config.code.password_auth_callback:
|
|
349
368
|
raise HTTPException(
|
|
350
369
|
status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
|
|
@@ -374,6 +393,7 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
|
374
393
|
|
|
375
394
|
@router.post("/logout")
|
|
376
395
|
async def logout(request: Request, response: Response):
|
|
396
|
+
"""Logout the user by calling the on_logout callback."""
|
|
377
397
|
if config.code.on_logout:
|
|
378
398
|
return await config.code.on_logout(request, response)
|
|
379
399
|
return {"success": True}
|
|
@@ -381,6 +401,7 @@ async def logout(request: Request, response: Response):
|
|
|
381
401
|
|
|
382
402
|
@router.post("/auth/header")
|
|
383
403
|
async def header_auth(request: Request):
|
|
404
|
+
"""Login a user using the header_auth_callback."""
|
|
384
405
|
if not config.code.header_auth_callback:
|
|
385
406
|
raise HTTPException(
|
|
386
407
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -410,6 +431,7 @@ async def header_auth(request: Request):
|
|
|
410
431
|
|
|
411
432
|
@router.get("/auth/oauth/{provider_id}")
|
|
412
433
|
async def oauth_login(provider_id: str, request: Request):
|
|
434
|
+
"""Redirect the user to the oauth provider login page."""
|
|
413
435
|
if config.code.oauth_callback is None:
|
|
414
436
|
raise HTTPException(
|
|
415
437
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -436,7 +458,7 @@ async def oauth_login(provider_id: str, request: Request):
|
|
|
436
458
|
response = RedirectResponse(
|
|
437
459
|
url=f"{provider.authorize_url}?{params}",
|
|
438
460
|
)
|
|
439
|
-
samesite = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
|
|
461
|
+
samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
|
|
440
462
|
secure = samesite.lower() == "none"
|
|
441
463
|
response.set_cookie(
|
|
442
464
|
"oauth_state",
|
|
@@ -457,6 +479,8 @@ async def oauth_callback(
|
|
|
457
479
|
code: Optional[str] = None,
|
|
458
480
|
state: Optional[str] = None,
|
|
459
481
|
):
|
|
482
|
+
"""Handle the oauth callback and login the user."""
|
|
483
|
+
|
|
460
484
|
if config.code.oauth_callback is None:
|
|
461
485
|
raise HTTPException(
|
|
462
486
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
@@ -544,6 +568,8 @@ async def oauth_azure_hf_callback(
|
|
|
544
568
|
code: Annotated[Optional[str], Form()] = None,
|
|
545
569
|
id_token: Annotated[Optional[str], Form()] = None,
|
|
546
570
|
):
|
|
571
|
+
"""Handle the azure ad hybrid flow callback and login the user."""
|
|
572
|
+
|
|
547
573
|
provider_id = "azure-ad-hybrid"
|
|
548
574
|
if config.code.oauth_callback is None:
|
|
549
575
|
raise HTTPException(
|
|
@@ -617,9 +643,16 @@ async def oauth_azure_hf_callback(
|
|
|
617
643
|
return response
|
|
618
644
|
|
|
619
645
|
|
|
646
|
+
_language_pattern = (
|
|
647
|
+
"^[a-zA-Z]{2,3}(-[a-zA-Z]{2,3})?(-[a-zA-Z]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
|
|
620
651
|
@router.get("/project/translations")
|
|
621
652
|
async def project_translations(
|
|
622
|
-
language: str = Query(
|
|
653
|
+
language: str = Query(
|
|
654
|
+
default="en-US", description="Language code", pattern=_language_pattern
|
|
655
|
+
),
|
|
623
656
|
):
|
|
624
657
|
"""Return project translations."""
|
|
625
658
|
|
|
@@ -636,11 +669,14 @@ async def project_translations(
|
|
|
636
669
|
@router.get("/project/settings")
|
|
637
670
|
async def project_settings(
|
|
638
671
|
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
639
|
-
language: str = Query(
|
|
672
|
+
language: str = Query(
|
|
673
|
+
default="en-US", description="Language code", pattern=_language_pattern
|
|
674
|
+
),
|
|
640
675
|
):
|
|
641
676
|
"""Return project settings. This is called by the UI before the establishing the websocket connection."""
|
|
642
677
|
|
|
643
678
|
# Load the markdown file based on the provided language
|
|
679
|
+
|
|
644
680
|
markdown = get_markdown_str(config.root, language)
|
|
645
681
|
|
|
646
682
|
profiles = []
|
|
@@ -808,6 +844,8 @@ async def upload_file(
|
|
|
808
844
|
Union[None, User, PersistedUser], Depends(get_current_user)
|
|
809
845
|
],
|
|
810
846
|
):
|
|
847
|
+
"""Upload a file to the session files directory."""
|
|
848
|
+
|
|
811
849
|
from chainlit.session import WebsocketSession
|
|
812
850
|
|
|
813
851
|
session = WebsocketSession.get_by_id(session_id)
|
|
@@ -841,6 +879,8 @@ async def get_file(
|
|
|
841
879
|
file_id: str,
|
|
842
880
|
session_id: Optional[str] = None,
|
|
843
881
|
):
|
|
882
|
+
"""Get a file from the session files directory."""
|
|
883
|
+
|
|
844
884
|
from chainlit.session import WebsocketSession
|
|
845
885
|
|
|
846
886
|
session = WebsocketSession.get_by_id(session_id) if session_id else None
|
|
@@ -863,11 +903,12 @@ async def serve_file(
|
|
|
863
903
|
filename: str,
|
|
864
904
|
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
|
|
865
905
|
):
|
|
906
|
+
"""Serve a file from the local filesystem."""
|
|
907
|
+
|
|
866
908
|
base_path = Path(config.project.local_fs_path).resolve()
|
|
867
909
|
file_path = (base_path / filename).resolve()
|
|
868
910
|
|
|
869
|
-
|
|
870
|
-
if base_path not in file_path.parents:
|
|
911
|
+
if not is_path_inside(file_path, base_path):
|
|
871
912
|
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
872
913
|
|
|
873
914
|
if file_path.is_file():
|
|
@@ -878,6 +919,7 @@ async def serve_file(
|
|
|
878
919
|
|
|
879
920
|
@router.get("/favicon")
|
|
880
921
|
async def get_favicon():
|
|
922
|
+
"""Get the favicon for the UI."""
|
|
881
923
|
custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
|
|
882
924
|
files = glob.glob(custom_favicon_path)
|
|
883
925
|
|
|
@@ -893,6 +935,7 @@ async def get_favicon():
|
|
|
893
935
|
|
|
894
936
|
@router.get("/logo")
|
|
895
937
|
async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
|
|
938
|
+
"""Get the default logo for the UI."""
|
|
896
939
|
theme_value = theme.value if theme else Theme.light.value
|
|
897
940
|
logo_path = None
|
|
898
941
|
|
|
@@ -908,32 +951,65 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
|
|
|
908
951
|
|
|
909
952
|
if not logo_path:
|
|
910
953
|
raise HTTPException(status_code=404, detail="Missing default logo")
|
|
954
|
+
|
|
911
955
|
media_type, _ = mimetypes.guess_type(logo_path)
|
|
912
956
|
|
|
913
957
|
return FileResponse(logo_path, media_type=media_type)
|
|
914
958
|
|
|
915
959
|
|
|
916
|
-
@router.get("/avatars/{avatar_id}")
|
|
960
|
+
@router.get("/avatars/{avatar_id:str}")
|
|
917
961
|
async def get_avatar(avatar_id: str):
|
|
962
|
+
"""Get the avatar for the user based on the avatar_id."""
|
|
963
|
+
if not re.match(r"^[a-zA-Z0-9_-]+$", avatar_id):
|
|
964
|
+
raise HTTPException(status_code=400, detail="Invalid avatar_id")
|
|
965
|
+
|
|
918
966
|
if avatar_id == "default":
|
|
919
967
|
avatar_id = config.ui.name
|
|
920
968
|
|
|
921
969
|
avatar_id = avatar_id.strip().lower().replace(" ", "_")
|
|
922
970
|
|
|
923
|
-
|
|
971
|
+
base_path = Path(APP_ROOT) / "public" / "avatars"
|
|
972
|
+
avatar_pattern = f"{avatar_id}.*"
|
|
924
973
|
|
|
925
|
-
|
|
974
|
+
matching_files = base_path.glob(avatar_pattern)
|
|
975
|
+
|
|
976
|
+
if avatar_path := next(matching_files, None):
|
|
977
|
+
if not is_path_inside(avatar_path, base_path):
|
|
978
|
+
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
979
|
+
|
|
980
|
+
media_type, _ = mimetypes.guess_type(str(avatar_path))
|
|
926
981
|
|
|
927
|
-
if files:
|
|
928
|
-
avatar_path = files[0]
|
|
929
|
-
media_type, _ = mimetypes.guess_type(avatar_path)
|
|
930
982
|
return FileResponse(avatar_path, media_type=media_type)
|
|
931
|
-
|
|
932
|
-
|
|
983
|
+
|
|
984
|
+
return await get_favicon()
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
# post avatar/{avatar_id} (only for authenticated users)
|
|
988
|
+
@router.post("/avatars/{avatar_id}")
|
|
989
|
+
async def upload_avatar(
|
|
990
|
+
avatar_id: str,
|
|
991
|
+
file: UploadFile,
|
|
992
|
+
current_user: Annotated[
|
|
993
|
+
Union[None, User, PersistedUser], Depends(get_current_user)
|
|
994
|
+
],
|
|
995
|
+
):
|
|
996
|
+
try:
|
|
997
|
+
avatar_path = os.path.join(APP_ROOT, "public", "avatars", avatar_id)
|
|
998
|
+
|
|
999
|
+
# Ensure the avatars directory exists
|
|
1000
|
+
os.makedirs(os.path.dirname(avatar_path), exist_ok=True)
|
|
1001
|
+
|
|
1002
|
+
with open(avatar_path, "wb") as f:
|
|
1003
|
+
f.write(await file.read())
|
|
1004
|
+
except Exception as e:
|
|
1005
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1006
|
+
|
|
1007
|
+
return {"id": avatar_id}
|
|
933
1008
|
|
|
934
1009
|
|
|
935
1010
|
@router.head("/")
|
|
936
1011
|
def status_check():
|
|
1012
|
+
"""Check if the site is operational."""
|
|
937
1013
|
return {"message": "Site is operational"}
|
|
938
1014
|
|
|
939
1015
|
|
chainlit/session.py
CHANGED
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
import aiofiles
|
|
19
|
+
from chainlit.assistant import Assistant
|
|
19
20
|
from chainlit.logger import logger
|
|
20
21
|
|
|
21
22
|
if TYPE_CHECKING:
|
|
@@ -64,7 +65,7 @@ class BaseSession:
|
|
|
64
65
|
client_type: ClientType,
|
|
65
66
|
# Thread id
|
|
66
67
|
thread_id: Optional[str],
|
|
67
|
-
# Logged-in user
|
|
68
|
+
# Logged-in user information
|
|
68
69
|
user: Optional[Union["User", "PersistedUser"]],
|
|
69
70
|
# Logged-in user token
|
|
70
71
|
token: Optional[str],
|
|
@@ -72,8 +73,12 @@ class BaseSession:
|
|
|
72
73
|
user_env: Optional[Dict[str, str]],
|
|
73
74
|
# Chat profile selected before the session was created
|
|
74
75
|
chat_profile: Optional[str] = None,
|
|
76
|
+
# Selected assistant
|
|
77
|
+
selected_assistant: Optional[Assistant] = None,
|
|
75
78
|
# Origin of the request
|
|
76
79
|
http_referer: Optional[str] = None,
|
|
80
|
+
# assistant settings
|
|
81
|
+
assistant_settings: Optional[Dict[str, Any]] = None,
|
|
77
82
|
):
|
|
78
83
|
if thread_id:
|
|
79
84
|
self.thread_id_to_resume = thread_id
|
|
@@ -90,7 +95,9 @@ class BaseSession:
|
|
|
90
95
|
|
|
91
96
|
self.id = id
|
|
92
97
|
|
|
98
|
+
self.assistant_settings = assistant_settings
|
|
93
99
|
self.chat_settings: Dict[str, Any] = {}
|
|
100
|
+
self.selected_assistant = selected_assistant
|
|
94
101
|
|
|
95
102
|
@property
|
|
96
103
|
def files_dir(self):
|
|
@@ -153,6 +160,7 @@ class BaseSession:
|
|
|
153
160
|
user_session = user_sessions.get(self.id) or {} # type: Dict
|
|
154
161
|
user_session["chat_settings"] = self.chat_settings
|
|
155
162
|
user_session["chat_profile"] = self.chat_profile
|
|
163
|
+
user_session["selected_assistant"] = self.selected_assistant
|
|
156
164
|
user_session["http_referer"] = self.http_referer
|
|
157
165
|
user_session["client_type"] = self.client_type
|
|
158
166
|
metadata = clean_metadata(user_session)
|
|
@@ -169,13 +177,17 @@ class HTTPSession(BaseSession):
|
|
|
169
177
|
client_type: ClientType,
|
|
170
178
|
# Thread id
|
|
171
179
|
thread_id: Optional[str] = None,
|
|
172
|
-
# Logged-in user
|
|
180
|
+
# Logged-in user information
|
|
173
181
|
user: Optional[Union["User", "PersistedUser"]] = None,
|
|
174
182
|
# Logged-in user token
|
|
175
183
|
token: Optional[str] = None,
|
|
176
184
|
user_env: Optional[Dict[str, str]] = None,
|
|
177
185
|
# Origin of the request
|
|
178
186
|
http_referer: Optional[str] = None,
|
|
187
|
+
# assistant settings
|
|
188
|
+
assistant_settings: Optional[Dict[str, Any]] = None,
|
|
189
|
+
# selected assistant
|
|
190
|
+
selected_assistant: Optional[Assistant] = None,
|
|
179
191
|
):
|
|
180
192
|
super().__init__(
|
|
181
193
|
id=id,
|
|
@@ -185,6 +197,8 @@ class HTTPSession(BaseSession):
|
|
|
185
197
|
client_type=client_type,
|
|
186
198
|
user_env=user_env,
|
|
187
199
|
http_referer=http_referer,
|
|
200
|
+
assistant_settings=assistant_settings,
|
|
201
|
+
selected_assistant=selected_assistant,
|
|
188
202
|
)
|
|
189
203
|
|
|
190
204
|
def delete(self):
|
|
@@ -193,6 +207,9 @@ class HTTPSession(BaseSession):
|
|
|
193
207
|
shutil.rmtree(self.files_dir)
|
|
194
208
|
|
|
195
209
|
|
|
210
|
+
ThreadQueue = Deque[tuple[Callable, object, tuple, Dict]]
|
|
211
|
+
|
|
212
|
+
|
|
196
213
|
class WebsocketSession(BaseSession):
|
|
197
214
|
"""Internal web socket session object.
|
|
198
215
|
|
|
@@ -222,16 +239,20 @@ class WebsocketSession(BaseSession):
|
|
|
222
239
|
client_type: ClientType,
|
|
223
240
|
# Thread id
|
|
224
241
|
thread_id: Optional[str] = None,
|
|
225
|
-
# Logged-in user
|
|
242
|
+
# Logged-in user information
|
|
226
243
|
user: Optional[Union["User", "PersistedUser"]] = None,
|
|
227
244
|
# Logged-in user token
|
|
228
245
|
token: Optional[str] = None,
|
|
229
246
|
# Chat profile selected before the session was created
|
|
230
247
|
chat_profile: Optional[str] = None,
|
|
248
|
+
# Selected assistant
|
|
249
|
+
selected_assistant: Optional[Assistant] = None,
|
|
231
250
|
# Languages of the user's browser
|
|
232
251
|
languages: Optional[str] = None,
|
|
233
252
|
# Origin of the request
|
|
234
253
|
http_referer: Optional[str] = None,
|
|
254
|
+
# chat settings
|
|
255
|
+
assistant_settings: Optional[Dict[str, Any]] = None,
|
|
235
256
|
):
|
|
236
257
|
super().__init__(
|
|
237
258
|
id=id,
|
|
@@ -241,7 +262,9 @@ class WebsocketSession(BaseSession):
|
|
|
241
262
|
user_env=user_env,
|
|
242
263
|
client_type=client_type,
|
|
243
264
|
chat_profile=chat_profile,
|
|
265
|
+
selected_assistant=selected_assistant,
|
|
244
266
|
http_referer=http_referer,
|
|
267
|
+
assistant_settings=assistant_settings,
|
|
245
268
|
)
|
|
246
269
|
|
|
247
270
|
self.socket_id = socket_id
|
|
@@ -250,7 +273,7 @@ class WebsocketSession(BaseSession):
|
|
|
250
273
|
|
|
251
274
|
self.restored = False
|
|
252
275
|
|
|
253
|
-
self.thread_queues
|
|
276
|
+
self.thread_queues: Dict[str, ThreadQueue] = {}
|
|
254
277
|
|
|
255
278
|
ws_sessions_id[self.id] = self
|
|
256
279
|
ws_sessions_sid[socket_id] = self
|