chainlit 1.0.506__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +34 -1
- chainlit/config.py +43 -21
- chainlit/context.py +19 -7
- chainlit/copilot/dist/index.js +650 -528
- chainlit/data/__init__.py +4 -2
- chainlit/data/acl.py +4 -1
- chainlit/data/sql_alchemy.py +39 -31
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +304 -0
- chainlit/element.py +9 -3
- chainlit/emitter.py +11 -2
- chainlit/frontend/dist/assets/{index-d4233b49.js → index-0a52365d.js} +189 -185
- chainlit/frontend/dist/assets/react-plotly-509d26a7.js +3602 -0
- chainlit/frontend/dist/index.html +1 -1
- chainlit/llama_index/callbacks.py +7 -6
- chainlit/message.py +3 -3
- chainlit/server.py +31 -4
- chainlit/session.py +83 -62
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +368 -0
- chainlit/socket.py +91 -33
- chainlit/step.py +25 -1
- chainlit/types.py +21 -1
- chainlit/user_session.py +6 -2
- chainlit/utils.py +2 -1
- {chainlit-1.0.506.dist-info → chainlit-1.1.0.dist-info}/METADATA +4 -3
- {chainlit-1.0.506.dist-info → chainlit-1.1.0.dist-info}/RECORD +29 -25
- chainlit/frontend/dist/assets/react-plotly-2b7fa4f9.js +0 -3484
- {chainlit-1.0.506.dist-info → chainlit-1.1.0.dist-info}/WHEEL +0 -0
- {chainlit-1.0.506.dist-info → chainlit-1.1.0.dist-info}/entry_points.txt +0 -0
chainlit/slack/app.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import uuid
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import Dict, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from chainlit.config import config
|
|
10
|
+
from chainlit.context import ChainlitContext, HTTPSession, context_var
|
|
11
|
+
from chainlit.data import get_data_layer
|
|
12
|
+
from chainlit.element import Element, ElementDict
|
|
13
|
+
from chainlit.emitter import BaseChainlitEmitter
|
|
14
|
+
from chainlit.message import Message, StepDict
|
|
15
|
+
from chainlit.types import Feedback
|
|
16
|
+
from chainlit.user import PersistedUser, User
|
|
17
|
+
from chainlit.user_session import user_session
|
|
18
|
+
from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
|
|
19
|
+
from slack_bolt.async_app import AsyncApp
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SlackEmitter(BaseChainlitEmitter):
    """Emitter that forwards Chainlit output to a Slack channel or thread.

    Nothing is sent while ``enabled`` is False. ``process_slack_message`` only
    turns it on after the user's incoming message has been sent, so that
    message is not echoed back into Slack.
    """

    # Slack rejects a section block whose text is longer than 3000 characters.
    SECTION_TEXT_LIMIT = 3000

    def __init__(
        self,
        session: HTTPSession,
        app: AsyncApp,
        channel_id: str,
        say,
        enabled=False,
        thread_ts: Optional[str] = None,
    ):
        super().__init__(session)
        self.app = app
        self.channel_id = channel_id
        self.say = say
        self.enabled = enabled
        self.thread_ts = thread_ts

    @classmethod
    def _text_sections(cls, text: str) -> List[Dict]:
        """Split ``text`` into mrkdwn section blocks that fit Slack's size limit.

        Cuts prefer the last newline before the limit so paragraphs and lists
        stay intact; a hard cut is used when a single line is too long. Text
        within the limit yields exactly one section block.

        NOTE(review): Slack also caps a message at 50 blocks, so extremely long
        outputs (~145k+ characters) would still be rejected.
        """
        limit = cls.SECTION_TEXT_LIMIT
        chunks: List[str] = []
        rest = text
        while len(rest) > limit:
            cut = rest.rfind("\n", 0, limit)
            if cut <= 0:
                cut = limit
            chunks.append(rest[:cut])
            # The block boundary replaces the newline(s) we cut at.
            rest = rest[cut:].lstrip("\n")
        if rest:
            chunks.append(rest)
        return [
            {"type": "section", "text": {"type": "mrkdwn", "text": chunk}}
            for chunk in chunks
        ]

    async def send_element(self, element_dict: ElementDict):
        """Upload an inline element to the Slack channel/thread as a file.

        Elements whose ``display`` is not ``"inline"`` are ignored. Content is
        taken from the session's persisted file (looked up by ``chainlitKey``)
        or, failing that, downloaded from the element's ``url``. Nothing is
        uploaded when neither source yields content.
        """
        if not self.enabled or element_dict.get("display") != "inline":
            return

        persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
        file: Optional[Union[bytes, str]] = None

        if persisted_file:
            # A str is passed to files_upload_v2 as a local file path.
            file = str(persisted_file["path"])
        elif file_url := element_dict.get("url"):
            async with httpx.AsyncClient() as client:
                response = await client.get(file_url)
                if response.status_code == 200:
                    file = response.content

        if not file:
            return

        await self.app.client.files_upload_v2(
            channel=self.channel_id,
            thread_ts=self.thread_ts,
            file=file,
            title=element_dict.get("name"),
        )

    async def send_step(self, step_dict: StepDict):
        """Post a step's output to Slack as a new message.

        Nested steps (those with a ``parentId``) and steps with empty output
        are skipped. Long outputs are spread over several section blocks to
        respect Slack's per-section size limit. When feedback is not disabled
        for the step and a data layer is configured, thumbs-down / thumbs-up
        buttons carrying the step id are appended.
        """
        if not self.enabled:
            return

        is_chain_of_thought = bool(step_dict.get("parentId"))
        is_empty_output = not step_dict.get("output")

        if is_chain_of_thought or is_empty_output:
            return

        output = step_dict["output"]
        enable_feedback = not step_dict.get("disableFeedback") and get_data_layer()
        blocks: List[Dict] = self._text_sections(output)
        if enable_feedback:
            blocks.append(
                {
                    "type": "actions",
                    "elements": [
                        {
                            "action_id": "thumbdown",
                            "type": "button",
                            "text": {
                                "type": "plain_text",
                                "emoji": True,
                                "text": ":thumbsdown:",
                            },
                            "value": step_dict.get("id"),
                        },
                        {
                            "action_id": "thumbup",
                            "type": "button",
                            "text": {
                                "type": "plain_text",
                                "emoji": True,
                                "text": ":thumbsup:",
                            },
                            "value": step_dict.get("id"),
                        },
                    ],
                }
            )
        await self.say(text=output, blocks=blocks, thread_ts=self.thread_ts)

    async def update_step(self, step_dict: StepDict):
        """Re-post an updated step.

        Slack messages are not edited in place: the updated step is sent again
        through ``send_step``.
        """
        if not self.enabled:
            return

        await self.send_step(step_dict)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# Bolt application shared by every Slack handler in this module. Credentials are
# read from the environment at import time; a missing variable yields None.
slack_app = AsyncApp(
    signing_secret=os.getenv("SLACK_SIGNING_SECRET"),
    token=os.getenv("SLACK_BOT_TOKEN"),
)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def init_slack_context(
    session: HTTPSession,
    slack_channel_id: str,
    event,
    say,
    thread_ts: Optional[str] = None,
) -> ChainlitContext:
    """Create a Chainlit context for one Slack event and make it current.

    The context pairs ``session`` with a ``SlackEmitter`` targeting
    ``slack_channel_id`` (and ``thread_ts`` when given). After binding it to
    ``context_var``, two user-session entries are stored:

    * ``slack_event`` - the raw Slack event.
    * ``fetch_slack_message_history`` - a ``fetch_message_history`` partial
      pre-bound to this channel and thread.

    Returns the new context.
    """
    slack_emitter = SlackEmitter(
        session=session,
        app=slack_app,
        channel_id=slack_channel_id,
        say=say,
        thread_ts=thread_ts,
    )
    new_context = ChainlitContext(session=session, emitter=slack_emitter)
    # user_session reads the current context, so bind it before writing.
    context_var.set(new_context)

    history_fetcher = partial(
        fetch_message_history, channel_id=slack_channel_id, thread_ts=thread_ts
    )
    user_session.set("slack_event", event)
    user_session.set("fetch_slack_message_history", history_fetcher)

    return new_context
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# FastAPI adapter (slack_bolt.adapter.fastapi) wrapping slack_app; presumably
# mounted by the server to receive Slack requests - confirm in server.py.
slack_app_handler = AsyncSlackRequestHandler(slack_app)

# Process-local cache of resolved Chainlit users, keyed by Slack user id.
users_by_slack_id: Dict[str, Union[User, PersistedUser]] = dict()

# Prefix for Chainlit user identifiers derived from Slack profiles.
USER_PREFIX = "slack_"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def clean_content(message: Optional[str]) -> str:
    """Remove Slack user-mention tokens from ``message`` and trim whitespace.

    Slack encodes mentions as ``<@U123ABC>``; older payloads may also use
    ``<@U123ABC|display-name>``. Both forms are removed so the bot's own
    mention is not passed on as message content. A missing message (``None``
    or empty, e.g. an event whose ``text`` key is absent) yields ``""``
    instead of raising.
    """
    if not message:
        return ""
    return re.sub(r"<@\w+(?:\|[^>]*)?>", "", message).strip()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
async def get_user(slack_user_id: str):
    """Resolve a Slack user id to a Chainlit ``User`` / ``PersistedUser``.

    Results are cached in ``users_by_slack_id``, so the Slack ``users.info``
    API and the data layer are hit only the first time a user is seen in
    this process. The identifier is ``USER_PREFIX`` plus the profile email.
    When the profile has no email (e.g. the app lacks the
    ``users:read.email`` scope, or the account is a bot), the Slack user id
    is used instead of raising a TypeError. When a data layer is configured,
    the user is created there and the persisted user is cached and returned
    if one comes back.
    """
    cached_user = users_by_slack_id.get(slack_user_id)
    if cached_user is not None:
        return cached_user

    slack_user = await slack_app.client.users_info(user=slack_user_id)
    slack_user_profile = slack_user["user"]["profile"]

    user_identifier = slack_user_profile.get("email") or slack_user_id
    user = User(identifier=USER_PREFIX + user_identifier, metadata=slack_user_profile)

    users_by_slack_id[slack_user_id] = user

    if data_layer := get_data_layer():
        persisted_user = await data_layer.create_user(user)
        if persisted_user:
            users_by_slack_id[slack_user_id] = persisted_user

    return users_by_slack_id[slack_user_id]
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
async def fetch_message_history(
    channel_id: str, thread_ts: Optional[str] = None, limit=30
):
    """Return up to ``limit`` Slack messages for a channel or a thread.

    With ``thread_ts`` the thread's replies are fetched
    (``conversations.replies``); otherwise the channel history is fetched
    (``conversations.history``). Raises ``Exception`` when the response is
    not ``ok``.
    """
    client = slack_app.client
    if thread_ts:
        response = await client.conversations_replies(
            channel=channel_id, ts=thread_ts, limit=limit
        )
    else:
        response = await client.conversations_history(
            channel=channel_id, limit=limit
        )

    if not response["ok"]:
        raise Exception(f"Failed to fetch messages: {response['error']}")
    return response["messages"]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
async def download_slack_file(url, token):
    """Download a private Slack file using the bot token as a Bearer token.

    Returns the response body on HTTP 200 and ``None`` otherwise.
    ``download_slack_files`` passes ``file.get("url_private")``, which is None
    when the key is absent. A falsy ``url`` therefore also returns ``None``
    instead of raising, so one unusable attachment does not abort the whole
    concurrent batch.
    """
    if not url:
        return None

    async with httpx.AsyncClient() as client:
        response = await client.get(url, headers={"Authorization": f"Bearer {token}"})
        return response.content if response.status_code == 200 else None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
async def download_slack_files(session: HTTPSession, files, token):
    """Download Slack attachments, persist them in ``session``, return Elements.

    All downloads run concurrently. Attachments whose download yields no
    content are skipped. Only files the session actually registered (present
    in ``session.files``) are turned into ``Element`` objects.
    """
    contents = await asyncio.gather(
        *(download_slack_file(slack_file.get("url_private"), token) for slack_file in files)
    )

    file_refs = []
    for slack_file, content in zip(files, contents):
        if not content:
            continue
        file_refs.append(
            await session.persist_file(
                name=slack_file.get("name"),
                mime=slack_file.get("mimetype"),
                content=content,
            )
        )

    registered = [
        session.files[ref["id"]] for ref in file_refs if ref["id"] in session.files
    ]
    return [Element.from_dict(file_dict) for file_dict in registered]
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
async def process_slack_message(
    event,
    say,
    thread_name: Optional[str] = None,
    bind_thread_to_user=False,
    thread_ts: Optional[str] = None,
):
    """Run the Chainlit app for a single incoming Slack message.

    Steps:

    1. Build a throwaway ``HTTPSession`` and a Slack-bound context.
    2. Download any attached files and send the user's message through
       Chainlit.
    3. Run the ``on_chat_start`` / ``on_message`` / ``on_chat_end`` callbacks
       that are configured.
    4. Update the thread in the data layer, if one exists.

    The session is deleted in all cases, including when a download or a
    callback raises; the exception itself still propagates.

    Args:
        event: Slack event payload; ``user`` and ``channel`` are required.
        say: Bolt ``say`` function used by the emitter to post replies.
        thread_name: Name stored on the thread; defaults to the message text.
        bind_thread_to_user: Attach the thread to the persisted user's id.
        thread_ts: Slack thread timestamp to reply in, if any.
    """
    user = await get_user(event["user"])

    channel_id = event["channel"]
    # Deterministic id: the same Slack thread (or the same channel when not
    # threaded) always maps to the same Chainlit thread.
    thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, thread_ts or channel_id))

    text = event.get("text")
    slack_files = event.get("files", [])

    session = HTTPSession(
        id=str(uuid.uuid4()),
        thread_id=thread_id,
        user=user,
        client_type="slack",
    )

    try:
        ctx = init_slack_context(
            session=session,
            slack_channel_id=channel_id,
            event=event,
            say=say,
            thread_ts=thread_ts,
        )

        file_elements = await download_slack_files(
            session, slack_files, slack_app.client.token
        )

        msg = Message(
            content=clean_content(text),
            elements=file_elements,
            type="user_message",
            author=user.metadata.get("real_name"),
        )

        # The emitter is still disabled here, so the user's own message is
        # not echoed back to Slack; replies are only posted after this point.
        await msg.send()

        ctx.emitter.enabled = True

        if on_chat_start := config.code.on_chat_start:
            await on_chat_start()

        if on_message := config.code.on_message:
            await on_message(msg)

        if on_chat_end := config.code.on_chat_end:
            await on_chat_end()

        if data_layer := get_data_layer():
            user_id = None
            if isinstance(user, PersistedUser):
                user_id = user.id if bind_thread_to_user else None

            await data_layer.update_thread(
                thread_id=thread_id,
                name=thread_name or msg.content,
                metadata=ctx.session.to_persistable(),
                user_id=user_id,
            )
    finally:
        # Same object as ctx.session; deleting it here avoids leaking the
        # per-message session when anything above raises.
        session.delete()
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@slack_app.event("app_home_opened")
|
|
304
|
+
async def handle_app_home_opened(event, say):
    """Deliberate no-op listener for Slack's ``app_home_opened`` event.

    Registering a listener presumably keeps Bolt from logging the event as
    unhandled. No App Home view is published.
    """
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@slack_app.event("app_mention")
|
|
309
|
+
async def handle_app_mentions(event, say):
    """Answer an @-mention of the app inside a Slack thread.

    If the mention was posted inside a thread, the reply goes to that thread.
    Otherwise a new thread is started, rooted at the mention message itself.
    """
    mention_ts = event["ts"]
    reply_thread = event.get("thread_ts", mention_ts)
    await process_slack_message(event, say, thread_ts=reply_thread)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
@slack_app.event("message")
|
|
315
|
+
async def handle_message(message, say):
    """Handle a Slack ``message`` event and answer it outside any thread.

    The Chainlit thread is named "<user identifier> Slack DM". Events with no
    ``user`` field are ignored instead of raising ``KeyError``: they have no
    author to answer (e.g. ``message_changed`` / ``message_deleted`` subtypes
    in Slack's Events API).
    """
    slack_user_id = message.get("user")
    if not slack_user_id:
        return

    user = await get_user(slack_user_id)
    thread_name = f"{user.identifier} Slack DM"
    await process_slack_message(message, say, thread_name)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
@slack_app.block_action("thumbdown")
|
|
322
|
+
async def thumb_down(ack, context, body):
    """Record negative feedback for the step referenced by the clicked button."""
    await _handle_feedback_action(ack, context, body, value=0, emoji=":thumbsdown:")


async def _handle_feedback_action(ack, context, body, value: int, emoji: str):
    """Shared implementation of the thumbs-up / thumbs-down block actions.

    Steps:

    1. Acknowledge the interaction.
    2. When a data layer is configured, store ``Feedback`` with ``value`` for
       the step id carried in the button's ``value``.
    3. Rewrite the original Slack message: drop its actions block (the
       buttons) and append a "<emoji> Feedback received." section.
    """
    await ack()
    step_id = body["actions"][0]["value"]

    if data_layer := get_data_layer():
        await data_layer.upsert_feedback(Feedback(forId=step_id, value=value))

    original_message = body["message"]
    updated_blocks = [
        block for block in original_message["blocks"] if block["type"] != "actions"
    ]
    updated_blocks.append(
        {
            "type": "section",
            "text": {"type": "mrkdwn", "text": f"{emoji} Feedback received."},
        }
    )
    await context.client.chat_update(
        channel=body["channel"]["id"],
        ts=body["container"]["message_ts"],
        text=original_message["text"],
        blocks=updated_blocks,
    )


@slack_app.block_action("thumbup")
async def thumb_up(ack, context, body):
    """Record positive feedback for the step referenced by the clicked button."""
    await _handle_feedback_action(ack, context, body, value=1, emoji=":thumbsup:")
|
chainlit/socket.py
CHANGED
|
@@ -9,12 +9,18 @@ from chainlit.auth import get_current_user, require_login
|
|
|
9
9
|
from chainlit.config import config
|
|
10
10
|
from chainlit.context import init_ws_context
|
|
11
11
|
from chainlit.data import get_data_layer
|
|
12
|
+
from chainlit.element import Element
|
|
12
13
|
from chainlit.logger import logger
|
|
13
14
|
from chainlit.message import ErrorMessage, Message
|
|
14
15
|
from chainlit.server import socket
|
|
15
16
|
from chainlit.session import WebsocketSession
|
|
16
17
|
from chainlit.telemetry import trace_event
|
|
17
|
-
from chainlit.types import
|
|
18
|
+
from chainlit.types import (
|
|
19
|
+
AudioChunk,
|
|
20
|
+
AudioChunkPayload,
|
|
21
|
+
AudioEndPayload,
|
|
22
|
+
UIMessagePayload,
|
|
23
|
+
)
|
|
18
24
|
from chainlit.user_session import user_sessions
|
|
19
25
|
|
|
20
26
|
|
|
@@ -93,9 +99,13 @@ def build_anon_user_identifier(environ):
|
|
|
93
99
|
|
|
94
100
|
@socket.on("connect")
|
|
95
101
|
async def connect(sid, environ, auth):
|
|
96
|
-
if
|
|
102
|
+
if (
|
|
103
|
+
not config.code.on_chat_start
|
|
104
|
+
and not config.code.on_message
|
|
105
|
+
and not config.code.on_audio_chunk
|
|
106
|
+
):
|
|
97
107
|
logger.warning(
|
|
98
|
-
"You need to configure at least
|
|
108
|
+
"You need to configure at least one of on_chat_start, on_message or on_audio_chunk callback"
|
|
99
109
|
)
|
|
100
110
|
return False
|
|
101
111
|
user = None
|
|
@@ -113,18 +123,10 @@ async def connect(sid, environ, auth):
|
|
|
113
123
|
|
|
114
124
|
# Session scoped function to emit to the client
|
|
115
125
|
def emit_fn(event, data):
|
|
116
|
-
if session := WebsocketSession.get(sid):
|
|
117
|
-
if session.should_stop:
|
|
118
|
-
session.should_stop = False
|
|
119
|
-
raise InterruptedError("Task stopped by user")
|
|
120
126
|
return socket.emit(event, data, to=sid)
|
|
121
127
|
|
|
122
128
|
# Session scoped function to emit to the client and wait for a response
|
|
123
129
|
def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
|
|
124
|
-
if session := WebsocketSession.get(sid):
|
|
125
|
-
if session.should_stop:
|
|
126
|
-
session.should_stop = False
|
|
127
|
-
raise InterruptedError("Task stopped by user")
|
|
128
130
|
return socket.call(event, data, timeout=timeout, to=sid)
|
|
129
131
|
|
|
130
132
|
session_id = environ.get("HTTP_X_CHAINLIT_SESSION_ID")
|
|
@@ -135,6 +137,7 @@ async def connect(sid, environ, auth):
|
|
|
135
137
|
user_env = load_user_env(user_env_string)
|
|
136
138
|
|
|
137
139
|
client_type = environ.get("HTTP_X_CHAINLIT_CLIENT_TYPE")
|
|
140
|
+
http_referer = environ.get("HTTP_REFERER")
|
|
138
141
|
|
|
139
142
|
ws_session = WebsocketSession(
|
|
140
143
|
id=session_id,
|
|
@@ -148,6 +151,7 @@ async def connect(sid, environ, auth):
|
|
|
148
151
|
chat_profile=environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE"),
|
|
149
152
|
thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"),
|
|
150
153
|
languages=environ.get("HTTP_ACCEPT_LANGUAGE"),
|
|
154
|
+
http_referer=http_referer,
|
|
151
155
|
)
|
|
152
156
|
|
|
153
157
|
trace_event("connection_successful")
|
|
@@ -173,46 +177,53 @@ async def connection_successful(sid):
|
|
|
173
177
|
"first_interaction",
|
|
174
178
|
{"interaction": "resume", "thread_id": thread.get("id")},
|
|
175
179
|
)
|
|
176
|
-
await context.emitter.resume_thread(thread)
|
|
177
180
|
await config.code.on_chat_resume(thread)
|
|
181
|
+
await context.emitter.resume_thread(thread)
|
|
178
182
|
return
|
|
179
183
|
|
|
180
184
|
if config.code.on_chat_start:
|
|
181
|
-
|
|
185
|
+
task = asyncio.create_task(config.code.on_chat_start())
|
|
186
|
+
context.session.current_task = task
|
|
182
187
|
|
|
183
188
|
|
|
184
189
|
@socket.on("clear_session")
|
|
185
190
|
async def clean_session(sid):
|
|
186
|
-
|
|
191
|
+
session = WebsocketSession.get(sid)
|
|
192
|
+
if session:
|
|
193
|
+
session.to_clear = True
|
|
187
194
|
|
|
188
195
|
|
|
189
196
|
@socket.on("disconnect")
|
|
190
|
-
async def disconnect(sid
|
|
197
|
+
async def disconnect(sid):
|
|
191
198
|
session = WebsocketSession.get(sid)
|
|
192
|
-
if session:
|
|
193
|
-
init_ws_context(session)
|
|
194
199
|
|
|
195
|
-
if
|
|
200
|
+
if not session:
|
|
201
|
+
return
|
|
202
|
+
|
|
203
|
+
init_ws_context(session)
|
|
204
|
+
|
|
205
|
+
if config.code.on_chat_end:
|
|
196
206
|
await config.code.on_chat_end()
|
|
197
207
|
|
|
198
|
-
if session
|
|
208
|
+
if session.thread_id and session.has_first_interaction:
|
|
199
209
|
await persist_user_session(session.thread_id, session.to_persistable())
|
|
200
210
|
|
|
201
|
-
def clear():
|
|
202
|
-
if session := WebsocketSession.get(
|
|
211
|
+
def clear(_sid):
|
|
212
|
+
if session := WebsocketSession.get(_sid):
|
|
203
213
|
# Clean up the user session
|
|
204
214
|
if session.id in user_sessions:
|
|
205
215
|
user_sessions.pop(session.id)
|
|
206
216
|
# Clean up the session
|
|
207
217
|
session.delete()
|
|
208
218
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
clear()
|
|
212
|
-
|
|
213
|
-
if force_clear:
|
|
214
|
-
clear()
|
|
219
|
+
if session.to_clear:
|
|
220
|
+
clear(sid)
|
|
215
221
|
else:
|
|
222
|
+
|
|
223
|
+
async def clear_on_timeout(_sid):
|
|
224
|
+
await asyncio.sleep(config.project.session_timeout)
|
|
225
|
+
clear(_sid)
|
|
226
|
+
|
|
216
227
|
asyncio.ensure_future(clear_on_timeout(sid))
|
|
217
228
|
|
|
218
229
|
|
|
@@ -223,10 +234,11 @@ async def stop(sid):
|
|
|
223
234
|
|
|
224
235
|
init_ws_context(session)
|
|
225
236
|
await Message(
|
|
226
|
-
author="System", content="Task stopped
|
|
237
|
+
author="System", content="Task manually stopped.", disable_feedback=True
|
|
227
238
|
).send()
|
|
228
239
|
|
|
229
|
-
session.
|
|
240
|
+
if session.current_task:
|
|
241
|
+
session.current_task.cancel()
|
|
230
242
|
|
|
231
243
|
if config.code.on_stop:
|
|
232
244
|
await config.code.on_stop()
|
|
@@ -243,7 +255,7 @@ async def process_message(session: WebsocketSession, payload: UIMessagePayload):
|
|
|
243
255
|
# Sleep 1ms to make sure any children step starts after the message step start
|
|
244
256
|
time.sleep(0.001)
|
|
245
257
|
await config.code.on_message(message)
|
|
246
|
-
except
|
|
258
|
+
except asyncio.CancelledError:
|
|
247
259
|
pass
|
|
248
260
|
except Exception as e:
|
|
249
261
|
logger.exception(e)
|
|
@@ -258,9 +270,55 @@ async def process_message(session: WebsocketSession, payload: UIMessagePayload):
|
|
|
258
270
|
async def message(sid, payload: UIMessagePayload):
|
|
259
271
|
"""Handle a message sent by the User."""
|
|
260
272
|
session = WebsocketSession.require(sid)
|
|
261
|
-
session.should_stop = False
|
|
262
273
|
|
|
263
|
-
|
|
274
|
+
task = asyncio.create_task(process_message(session, payload))
|
|
275
|
+
session.current_task = task
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@socket.on("audio_chunk")
|
|
279
|
+
async def audio_chunk(sid, payload: AudioChunkPayload):
    """Handle an audio chunk sent by the user.

    Binds the websocket context for ``sid``. When an ``on_audio_chunk``
    callback is registered, it is scheduled as a background task with the
    payload wrapped in an ``AudioChunk``; this handler does not await it.
    """
    init_ws_context(WebsocketSession.require(sid))

    on_audio_chunk = config.code.on_audio_chunk
    if on_audio_chunk:
        # NOTE(review): the task is not referenced after creation; asyncio's
        # docs recommend keeping a reference so it cannot be collected early.
        asyncio.create_task(on_audio_chunk(AudioChunk(**payload)))
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
@socket.on("audio_end")
|
|
290
|
+
async def audio_end(sid, payload: AudioEndPayload):
    """Handle the end of the audio stream.

    Steps:

    1. Mark the task as started.
    2. On the session's first interaction, initialise the thread as "audio".
    3. If ``on_audio_end`` is configured, call it with ``Element`` objects for
       the referenced files that the session knows about.

    Cancellation is swallowed. Any other exception is logged and reported to
    the user as an ``ErrorMessage``. The task is always marked as ended.
    """
    session = WebsocketSession.require(sid)
    # Bound before ``try`` so the ``finally`` clause never refers to an
    # unbound ``context`` (which would mask the original error).
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()

        if not session.has_first_interaction:
            session.has_first_interaction = True
            asyncio.create_task(context.emitter.init_thread("audio"))

        if config.code.on_audio_end:
            file_elements = []
            if file_refs := payload.get("fileReferences"):
                file_elements = [
                    Element.from_dict(session.files[ref["id"]])
                    for ref in file_refs
                    if ref["id"] in session.files
                ]
            await config.code.on_audio_end(file_elements)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()
|
|
264
322
|
|
|
265
323
|
|
|
266
324
|
async def process_action(action: Action):
|
|
@@ -288,7 +346,7 @@ async def call_action(sid, action):
|
|
|
288
346
|
id=action.id, status=True, response=res if isinstance(res, str) else None
|
|
289
347
|
)
|
|
290
348
|
|
|
291
|
-
except
|
|
349
|
+
except asyncio.CancelledError:
|
|
292
350
|
await context.emitter.send_action_response(
|
|
293
351
|
id=action.id, status=False, response="Action interrupted by the user"
|
|
294
352
|
)
|
chainlit/step.py
CHANGED
|
@@ -193,10 +193,34 @@ class Step:
|
|
|
193
193
|
self.persisted = False
|
|
194
194
|
self.fail_on_persist_error = False
|
|
195
195
|
|
|
196
|
+
def _clean_content(self, content):
|
|
197
|
+
"""
|
|
198
|
+
Recursively checks and converts bytes objects in content.
|
|
199
|
+
"""
|
|
200
|
+
|
|
201
|
+
def handle_bytes(item):
|
|
202
|
+
if isinstance(item, bytes):
|
|
203
|
+
return "STRIPPED_BINARY_DATA"
|
|
204
|
+
elif isinstance(item, dict):
|
|
205
|
+
return {k: handle_bytes(v) for k, v in item.items()}
|
|
206
|
+
elif isinstance(item, list):
|
|
207
|
+
return [handle_bytes(i) for i in item]
|
|
208
|
+
elif isinstance(item, tuple):
|
|
209
|
+
return tuple(handle_bytes(i) for i in item)
|
|
210
|
+
return item
|
|
211
|
+
|
|
212
|
+
return handle_bytes(content)
|
|
213
|
+
|
|
196
214
|
def _process_content(self, content, set_language=False):
|
|
197
215
|
if content is None:
|
|
198
216
|
return ""
|
|
199
|
-
|
|
217
|
+
content = self._clean_content(content)
|
|
218
|
+
|
|
219
|
+
if (
|
|
220
|
+
isinstance(content, dict)
|
|
221
|
+
or isinstance(content, list)
|
|
222
|
+
or isinstance(content, tuple)
|
|
223
|
+
):
|
|
200
224
|
try:
|
|
201
225
|
processed_content = json.dumps(content, indent=4, ensure_ascii=False)
|
|
202
226
|
if set_language:
|
chainlit/types.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
+
from pathlib import Path
|
|
2
3
|
from typing import (
|
|
3
4
|
TYPE_CHECKING,
|
|
4
5
|
Any,
|
|
@@ -144,7 +145,7 @@ class FileReference(TypedDict):
|
|
|
144
145
|
class FileDict(TypedDict):
|
|
145
146
|
id: str
|
|
146
147
|
name: str
|
|
147
|
-
path:
|
|
148
|
+
path: Path
|
|
148
149
|
size: int
|
|
149
150
|
type: str
|
|
150
151
|
|
|
@@ -154,6 +155,25 @@ class UIMessagePayload(TypedDict):
|
|
|
154
155
|
fileReferences: Optional[List[FileReference]]
|
|
155
156
|
|
|
156
157
|
|
|
158
|
+
class AudioChunkPayload(TypedDict):
    """Socket payload for one audio chunk, as received by the ``audio_chunk`` handler."""

    # NOTE(review): presumably True only for the first chunk of a recording.
    isStart: bool
    # MIME type of the encoded audio, e.g. "audio/webm" (assumed; verify client).
    mimeType: str
    # Time since recording started; units not established here - confirm.
    elapsedTime: float
    # Encoded audio bytes for this chunk.
    data: bytes
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@dataclass
class AudioChunk:
    """Typed form of an ``AudioChunkPayload`` handed to ``on_audio_chunk``.

    Field names mirror the payload keys so it can be built with
    ``AudioChunk(**payload)``.
    """

    # NOTE(review): presumably True only for the first chunk of a recording.
    isStart: bool
    # MIME type of the encoded audio.
    mimeType: str
    # Time since recording started; units not established here - confirm.
    elapsedTime: float
    # Encoded audio bytes for this chunk.
    data: bytes
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class AudioEndPayload(TypedDict):
    """Socket payload sent when the client finishes an audio stream."""

    # Files uploaded for this recording, if any; resolved via session.files.
    fileReferences: Optional[List[FileReference]]
|
|
175
|
+
|
|
176
|
+
|
|
157
177
|
@dataclass
|
|
158
178
|
class AskFileResponse:
|
|
159
179
|
id: str
|
chainlit/user_session.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import Dict
|
|
2
2
|
|
|
3
|
-
from chainlit.context import context
|
|
3
|
+
from chainlit.context import WebsocketSession, context
|
|
4
4
|
|
|
5
5
|
user_sessions: Dict[str, Dict] = {}
|
|
6
6
|
|
|
@@ -27,7 +27,11 @@ class UserSession:
|
|
|
27
27
|
user_session["chat_settings"] = context.session.chat_settings
|
|
28
28
|
user_session["user"] = context.session.user
|
|
29
29
|
user_session["chat_profile"] = context.session.chat_profile
|
|
30
|
-
user_session["
|
|
30
|
+
user_session["http_referer"] = context.session.http_referer
|
|
31
|
+
user_session["client_type"] = context.session.client_type
|
|
32
|
+
|
|
33
|
+
if isinstance(context.session, WebsocketSession):
|
|
34
|
+
user_session["languages"] = context.session.languages
|
|
31
35
|
|
|
32
36
|
if context.session.root_message:
|
|
33
37
|
user_session["root_message"] = context.session.root_message
|