chainlit-1.0.505-py3-none-any.whl → chainlit-1.1.0-py3-none-any.whl
This diff shows the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic; the original report linked to further details about this warning.
- chainlit/__init__.py +34 -1
- chainlit/config.py +43 -21
- chainlit/context.py +19 -7
- chainlit/copilot/dist/index.js +650 -528
- chainlit/data/__init__.py +4 -2
- chainlit/data/acl.py +4 -1
- chainlit/data/sql_alchemy.py +39 -31
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +304 -0
- chainlit/element.py +9 -3
- chainlit/emitter.py +11 -2
- chainlit/frontend/dist/assets/{index-d200e7ad.js → index-0a52365d.js} +189 -185
- chainlit/frontend/dist/assets/react-plotly-509d26a7.js +3602 -0
- chainlit/frontend/dist/index.html +1 -1
- chainlit/input_widget.py +2 -0
- chainlit/llama_index/callbacks.py +7 -6
- chainlit/message.py +3 -3
- chainlit/server.py +31 -7
- chainlit/session.py +83 -62
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +368 -0
- chainlit/socket.py +91 -33
- chainlit/step.py +25 -1
- chainlit/types.py +21 -1
- chainlit/user_session.py +6 -2
- chainlit/utils.py +2 -1
- {chainlit-1.0.505.dist-info → chainlit-1.1.0.dist-info}/METADATA +4 -3
- {chainlit-1.0.505.dist-info → chainlit-1.1.0.dist-info}/RECORD +30 -26
- chainlit/frontend/dist/assets/react-plotly-10f4012e.js +0 -3484
- {chainlit-1.0.505.dist-info → chainlit-1.1.0.dist-info}/WHEEL +0 -0
- {chainlit-1.0.505.dist-info → chainlit-1.1.0.dist-info}/entry_points.txt +0 -0
chainlit/data/__init__.py
CHANGED
|
@@ -156,6 +156,7 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
156
156
|
"chainlitKey": None,
|
|
157
157
|
"display": metadata.get("display", "side"),
|
|
158
158
|
"language": metadata.get("language"),
|
|
159
|
+
"autoPlay": metadata.get("autoPlay", None),
|
|
159
160
|
"page": metadata.get("page"),
|
|
160
161
|
"size": metadata.get("size"),
|
|
161
162
|
"type": metadata.get("type", "file"),
|
|
@@ -219,7 +220,7 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
219
220
|
"disableFeedback": metadata.get("disableFeedback", False),
|
|
220
221
|
"indent": metadata.get("indent"),
|
|
221
222
|
"language": metadata.get("language"),
|
|
222
|
-
"isError":
|
|
223
|
+
"isError": bool(step.error),
|
|
223
224
|
"waitForAnswer": metadata.get("waitForAnswer", False),
|
|
224
225
|
}
|
|
225
226
|
|
|
@@ -348,7 +349,6 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
348
349
|
step_dict.get("metadata", {}),
|
|
349
350
|
**{
|
|
350
351
|
"disableFeedback": step_dict.get("disableFeedback"),
|
|
351
|
-
"isError": step_dict.get("isError"),
|
|
352
352
|
"waitForAnswer": step_dict.get("waitForAnswer"),
|
|
353
353
|
"language": step_dict.get("language"),
|
|
354
354
|
"showInput": step_dict.get("showInput"),
|
|
@@ -372,6 +372,8 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
372
372
|
step["input"] = {"content": step_dict.get("input")}
|
|
373
373
|
if step_dict.get("output"):
|
|
374
374
|
step["output"] = {"content": step_dict.get("output")}
|
|
375
|
+
if step_dict.get("isError"):
|
|
376
|
+
step["error"] = step_dict.get("output")
|
|
375
377
|
|
|
376
378
|
await self.client.api.send_steps([step])
|
|
377
379
|
|
chainlit/data/acl.py
CHANGED
|
@@ -5,9 +5,12 @@ from fastapi import HTTPException
|
|
|
5
5
|
async def is_thread_author(username: str, thread_id: str):
|
|
6
6
|
data_layer = get_data_layer()
|
|
7
7
|
if not data_layer:
|
|
8
|
-
raise HTTPException(status_code=
|
|
8
|
+
raise HTTPException(status_code=400, detail="Data layer not initialized")
|
|
9
9
|
|
|
10
10
|
thread_author = await data_layer.get_thread_author(thread_id)
|
|
11
|
+
|
|
12
|
+
if not thread_author:
|
|
13
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
11
14
|
|
|
12
15
|
if thread_author != username:
|
|
13
16
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
chainlit/data/sql_alchemy.py
CHANGED
|
@@ -39,9 +39,11 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
39
39
|
ssl_require: bool = False,
|
|
40
40
|
storage_provider: Optional[BaseStorageClient] = None,
|
|
41
41
|
user_thread_limit: Optional[int] = 1000,
|
|
42
|
+
show_logger: Optional[bool] = False,
|
|
42
43
|
):
|
|
43
44
|
self._conninfo = conninfo
|
|
44
45
|
self.user_thread_limit = user_thread_limit
|
|
46
|
+
self.show_logger = show_logger
|
|
45
47
|
ssl_args = {}
|
|
46
48
|
if ssl_require:
|
|
47
49
|
# Create an SSL context to require an SSL connection
|
|
@@ -55,7 +57,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
55
57
|
self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
|
|
56
58
|
if storage_provider:
|
|
57
59
|
self.storage_provider: Optional[BaseStorageClient] = storage_provider
|
|
58
|
-
logger.info("SQLAlchemyDataLayer storage client initialized")
|
|
60
|
+
if self.show_logger: logger.info("SQLAlchemyDataLayer storage client initialized")
|
|
59
61
|
else:
|
|
60
62
|
self.storage_provider = None
|
|
61
63
|
logger.warn(
|
|
@@ -102,7 +104,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
102
104
|
|
|
103
105
|
###### User ######
|
|
104
106
|
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
|
|
105
|
-
logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
|
|
107
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
|
|
106
108
|
query = "SELECT * FROM users WHERE identifier = :identifier"
|
|
107
109
|
parameters = {"identifier": identifier}
|
|
108
110
|
result = await self.execute_sql(query=query, parameters=parameters)
|
|
@@ -112,20 +114,20 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
112
114
|
return None
|
|
113
115
|
|
|
114
116
|
async def create_user(self, user: User) -> Optional[PersistedUser]:
|
|
115
|
-
logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
|
|
117
|
+
if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
|
|
116
118
|
existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
|
|
117
119
|
user_dict: Dict[str, Any] = {
|
|
118
120
|
"identifier": str(user.identifier),
|
|
119
121
|
"metadata": json.dumps(user.metadata) or {},
|
|
120
122
|
}
|
|
121
123
|
if not existing_user: # create the user
|
|
122
|
-
logger.info("SQLAlchemy: create_user, creating the user")
|
|
124
|
+
if self.show_logger: logger.info("SQLAlchemy: create_user, creating the user")
|
|
123
125
|
user_dict["id"] = str(uuid.uuid4())
|
|
124
126
|
user_dict["createdAt"] = await self.get_current_timestamp()
|
|
125
127
|
query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
|
|
126
128
|
await self.execute_sql(query=query, parameters=user_dict)
|
|
127
129
|
else: # update the user
|
|
128
|
-
logger.info("SQLAlchemy: update user metadata")
|
|
130
|
+
if self.show_logger: logger.info("SQLAlchemy: update user metadata")
|
|
129
131
|
query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
|
|
130
132
|
await self.execute_sql(
|
|
131
133
|
query=query, parameters=user_dict
|
|
@@ -134,19 +136,18 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
134
136
|
|
|
135
137
|
###### Threads ######
|
|
136
138
|
async def get_thread_author(self, thread_id: str) -> str:
|
|
137
|
-
logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
|
|
139
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
|
|
138
140
|
query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
|
|
139
141
|
parameters = {"id": thread_id}
|
|
140
142
|
result = await self.execute_sql(query=query, parameters=parameters)
|
|
141
143
|
if isinstance(result, list) and result[0]:
|
|
142
144
|
author_identifier = result[0].get("userIdentifier")
|
|
143
145
|
if author_identifier is not None:
|
|
144
|
-
print(f"Author found: {author_identifier}")
|
|
145
146
|
return author_identifier
|
|
146
147
|
raise ValueError(f"Author not found for thread_id {thread_id}")
|
|
147
148
|
|
|
148
149
|
async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
|
|
149
|
-
logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
|
|
150
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
|
|
150
151
|
user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
|
|
151
152
|
thread_id=thread_id
|
|
152
153
|
)
|
|
@@ -163,19 +164,21 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
163
164
|
metadata: Optional[Dict] = None,
|
|
164
165
|
tags: Optional[List[str]] = None,
|
|
165
166
|
):
|
|
166
|
-
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
|
|
167
|
+
if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
|
|
167
168
|
if context.session.user is not None:
|
|
168
169
|
user_identifier = context.session.user.identifier
|
|
169
170
|
else:
|
|
170
171
|
raise ValueError("User not found in session context")
|
|
171
172
|
data = {
|
|
172
173
|
"id": thread_id,
|
|
173
|
-
"createdAt":
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
"name":
|
|
177
|
-
|
|
178
|
-
|
|
174
|
+
"createdAt": (
|
|
175
|
+
await self.get_current_timestamp() if metadata is None else None
|
|
176
|
+
),
|
|
177
|
+
"name": (
|
|
178
|
+
name
|
|
179
|
+
if name is not None
|
|
180
|
+
else (metadata.get("name") if metadata and "name" in metadata else None)
|
|
181
|
+
),
|
|
179
182
|
"userId": user_id,
|
|
180
183
|
"userIdentifier": user_identifier,
|
|
181
184
|
"tags": tags,
|
|
@@ -198,7 +201,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
198
201
|
await self.execute_sql(query=query, parameters=parameters)
|
|
199
202
|
|
|
200
203
|
async def delete_thread(self, thread_id: str):
|
|
201
|
-
logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
|
|
204
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
|
|
202
205
|
# Delete feedbacks/elements/steps/thread
|
|
203
206
|
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
|
|
204
207
|
elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
|
|
@@ -213,7 +216,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
213
216
|
async def list_threads(
|
|
214
217
|
self, pagination: Pagination, filters: ThreadFilter
|
|
215
218
|
) -> PaginatedResponse:
|
|
216
|
-
logger.info(
|
|
219
|
+
if self.show_logger: logger.info(
|
|
217
220
|
f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
|
|
218
221
|
)
|
|
219
222
|
if not filters.userId:
|
|
@@ -273,7 +276,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
273
276
|
###### Steps ######
|
|
274
277
|
@queue_until_user_message()
|
|
275
278
|
async def create_step(self, step_dict: "StepDict"):
|
|
276
|
-
logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
|
|
279
|
+
if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
|
|
277
280
|
if not getattr(context.session.user, "id", None):
|
|
278
281
|
raise ValueError("No authenticated user in context")
|
|
279
282
|
step_dict["showInput"] = (
|
|
@@ -303,12 +306,12 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
303
306
|
|
|
304
307
|
@queue_until_user_message()
|
|
305
308
|
async def update_step(self, step_dict: "StepDict"):
|
|
306
|
-
logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
|
|
309
|
+
if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
|
|
307
310
|
await self.create_step(step_dict)
|
|
308
311
|
|
|
309
312
|
@queue_until_user_message()
|
|
310
313
|
async def delete_step(self, step_id: str):
|
|
311
|
-
logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
|
|
314
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
|
|
312
315
|
# Delete feedbacks/elements/steps
|
|
313
316
|
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
|
|
314
317
|
elements_query = """DELETE FROM elements WHERE "forId" = :id"""
|
|
@@ -320,7 +323,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
320
323
|
|
|
321
324
|
###### Feedback ######
|
|
322
325
|
async def upsert_feedback(self, feedback: Feedback) -> str:
|
|
323
|
-
logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
|
|
326
|
+
if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
|
|
324
327
|
feedback.id = feedback.id or str(uuid.uuid4())
|
|
325
328
|
feedback_dict = asdict(feedback)
|
|
326
329
|
parameters = {
|
|
@@ -342,7 +345,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
342
345
|
return feedback.id
|
|
343
346
|
|
|
344
347
|
async def delete_feedback(self, feedback_id: str) -> bool:
|
|
345
|
-
logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
|
|
348
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
|
|
346
349
|
query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
|
|
347
350
|
parameters = {"feedback_id": feedback_id}
|
|
348
351
|
await self.execute_sql(query=query, parameters=parameters)
|
|
@@ -351,7 +354,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
351
354
|
###### Elements ######
|
|
352
355
|
@queue_until_user_message()
|
|
353
356
|
async def create_element(self, element: "Element"):
|
|
354
|
-
logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
|
|
357
|
+
if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
|
|
355
358
|
if not getattr(context.session.user, "id", None):
|
|
356
359
|
raise ValueError("No authenticated user in context")
|
|
357
360
|
if isinstance(element, Avatar): # Skip creating elements of type avatar
|
|
@@ -414,7 +417,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
414
417
|
|
|
415
418
|
@queue_until_user_message()
|
|
416
419
|
async def delete_element(self, element_id: str):
|
|
417
|
-
logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
|
|
420
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
|
|
418
421
|
query = """DELETE FROM elements WHERE "id" = :id"""
|
|
419
422
|
parameters = {"id": element_id}
|
|
420
423
|
await self.execute_sql(query=query, parameters=parameters)
|
|
@@ -426,7 +429,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
426
429
|
self, user_id: Optional[str] = None, thread_id: Optional[str] = None
|
|
427
430
|
) -> Optional[List[ThreadDict]]:
|
|
428
431
|
"""Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
|
|
429
|
-
logger.info(f"SQLAlchemy: get_all_user_threads")
|
|
432
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_all_user_threads")
|
|
430
433
|
user_threads_query = """
|
|
431
434
|
SELECT
|
|
432
435
|
"id" AS thread_id,
|
|
@@ -552,13 +555,17 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
552
555
|
streaming=step_feedback.get("step_streaming", False),
|
|
553
556
|
waitForAnswer=step_feedback.get("step_waitforanswer"),
|
|
554
557
|
isError=step_feedback.get("step_iserror"),
|
|
555
|
-
metadata=
|
|
556
|
-
|
|
557
|
-
|
|
558
|
+
metadata=(
|
|
559
|
+
step_feedback["step_metadata"]
|
|
560
|
+
if step_feedback.get("step_metadata") is not None
|
|
561
|
+
else {}
|
|
562
|
+
),
|
|
558
563
|
tags=step_feedback.get("step_tags"),
|
|
559
|
-
input=
|
|
560
|
-
|
|
561
|
-
|
|
564
|
+
input=(
|
|
565
|
+
step_feedback.get("step_input", "")
|
|
566
|
+
if step_feedback["step_showinput"] == "true"
|
|
567
|
+
else None
|
|
568
|
+
),
|
|
562
569
|
output=step_feedback.get("step_output", ""),
|
|
563
570
|
createdAt=step_feedback.get("step_createdat"),
|
|
564
571
|
start=step_feedback.get("step_start"),
|
|
@@ -587,6 +594,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
587
594
|
display=element["element_display"],
|
|
588
595
|
size=element.get("element_size"),
|
|
589
596
|
language=element.get("element_language"),
|
|
597
|
+
autoPlay=element.get("element_autoPlay"),
|
|
590
598
|
page=element.get("element_page"),
|
|
591
599
|
forId=element.get("element_forid"),
|
|
592
600
|
mime=element.get("element_mime"),
|
chainlit/discord/app.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import mimetypes
|
|
3
|
+
import re
|
|
4
|
+
import uuid
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from typing import Dict, List, Optional, Union, TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from discord.abc import MessageableChannel
|
|
10
|
+
|
|
11
|
+
import discord
|
|
12
|
+
|
|
13
|
+
import filetype
|
|
14
|
+
import httpx
|
|
15
|
+
from chainlit.config import config
|
|
16
|
+
from chainlit.context import ChainlitContext, HTTPSession, context_var
|
|
17
|
+
from chainlit.data import get_data_layer
|
|
18
|
+
from chainlit.element import Element, ElementDict
|
|
19
|
+
from chainlit.emitter import BaseChainlitEmitter
|
|
20
|
+
from chainlit.logger import logger
|
|
21
|
+
from chainlit.message import Message, StepDict
|
|
22
|
+
from chainlit.types import Feedback
|
|
23
|
+
from chainlit.user import PersistedUser, User
|
|
24
|
+
from chainlit.user_session import user_session
|
|
25
|
+
from discord.ui import Button, View
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FeedbackView(View):
|
|
29
|
+
def __init__(self, step_id: str):
|
|
30
|
+
super().__init__(timeout=None)
|
|
31
|
+
self.step_id = step_id
|
|
32
|
+
|
|
33
|
+
@discord.ui.button(label="👎")
|
|
34
|
+
async def thumbs_down(self, interaction: discord.Interaction, button: Button):
|
|
35
|
+
if data_layer := get_data_layer():
|
|
36
|
+
await data_layer.upsert_feedback(Feedback(forId=self.step_id, value=0))
|
|
37
|
+
if interaction.message:
|
|
38
|
+
await interaction.message.edit(view=None)
|
|
39
|
+
await interaction.message.add_reaction("👎")
|
|
40
|
+
|
|
41
|
+
@discord.ui.button(label="👍")
|
|
42
|
+
async def thumbs_up(self, interaction: discord.Interaction, button: Button):
|
|
43
|
+
if data_layer := get_data_layer():
|
|
44
|
+
await data_layer.upsert_feedback(Feedback(forId=self.step_id, value=1))
|
|
45
|
+
if interaction.message:
|
|
46
|
+
await interaction.message.edit(view=None)
|
|
47
|
+
await interaction.message.add_reaction("👍")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class DiscordEmitter(BaseChainlitEmitter):
|
|
51
|
+
def __init__(
|
|
52
|
+
self, session: HTTPSession, channel: "MessageableChannel", enabled=False
|
|
53
|
+
):
|
|
54
|
+
super().__init__(session)
|
|
55
|
+
self.channel = channel
|
|
56
|
+
self.enabled = enabled
|
|
57
|
+
|
|
58
|
+
async def send_element(self, element_dict: ElementDict):
|
|
59
|
+
if not self.enabled or element_dict.get("display") != "inline":
|
|
60
|
+
return
|
|
61
|
+
|
|
62
|
+
persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
|
|
63
|
+
file: Optional[Union[BytesIO, str]] = None
|
|
64
|
+
mime: Optional[str] = None
|
|
65
|
+
|
|
66
|
+
if persisted_file:
|
|
67
|
+
file = str(persisted_file["path"])
|
|
68
|
+
mime = element_dict.get("mime")
|
|
69
|
+
elif file_url := element_dict.get("url"):
|
|
70
|
+
async with httpx.AsyncClient() as client:
|
|
71
|
+
response = await client.get(file_url)
|
|
72
|
+
if response.status_code == 200:
|
|
73
|
+
file = BytesIO(response.content)
|
|
74
|
+
mime = filetype.guess_mime(file)
|
|
75
|
+
|
|
76
|
+
if not file:
|
|
77
|
+
return
|
|
78
|
+
|
|
79
|
+
element_name: str = element_dict.get("name", "Untitled")
|
|
80
|
+
|
|
81
|
+
if mime:
|
|
82
|
+
file_extension = mimetypes.guess_extension(mime)
|
|
83
|
+
if file_extension:
|
|
84
|
+
element_name += file_extension
|
|
85
|
+
|
|
86
|
+
file_obj = discord.File(file, filename=element_name)
|
|
87
|
+
await self.channel.send(file=file_obj)
|
|
88
|
+
|
|
89
|
+
async def send_step(self, step_dict: StepDict):
|
|
90
|
+
if not self.enabled:
|
|
91
|
+
return
|
|
92
|
+
|
|
93
|
+
is_chain_of_thought = bool(step_dict.get("parentId"))
|
|
94
|
+
is_empty_output = not step_dict.get("output")
|
|
95
|
+
|
|
96
|
+
if is_chain_of_thought or is_empty_output:
|
|
97
|
+
return
|
|
98
|
+
else:
|
|
99
|
+
enable_feedback = not step_dict.get("disableFeedback") and get_data_layer()
|
|
100
|
+
message = await self.channel.send(step_dict["output"])
|
|
101
|
+
|
|
102
|
+
if enable_feedback:
|
|
103
|
+
view = FeedbackView(step_dict.get("id", ""))
|
|
104
|
+
await message.edit(view=view)
|
|
105
|
+
|
|
106
|
+
async def update_step(self, step_dict: StepDict):
|
|
107
|
+
if not self.enabled:
|
|
108
|
+
return
|
|
109
|
+
|
|
110
|
+
await self.send_step(step_dict)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
intents = discord.Intents.default()
|
|
114
|
+
intents.message_content = True
|
|
115
|
+
|
|
116
|
+
client = discord.Client(intents=intents)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def init_discord_context(
|
|
120
|
+
session: HTTPSession,
|
|
121
|
+
channel: "MessageableChannel",
|
|
122
|
+
message: discord.Message,
|
|
123
|
+
) -> ChainlitContext:
|
|
124
|
+
emitter = DiscordEmitter(session=session, channel=channel)
|
|
125
|
+
context = ChainlitContext(session=session, emitter=emitter)
|
|
126
|
+
context_var.set(context)
|
|
127
|
+
user_session.set("discord_message", message)
|
|
128
|
+
user_session.set("discord_channel", channel)
|
|
129
|
+
return context
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
users_by_discord_id: Dict[int, Union[User, PersistedUser]] = {}
|
|
133
|
+
|
|
134
|
+
USER_PREFIX = "discord_"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
async def get_user(discord_user: Union[discord.User, discord.Member]):
|
|
138
|
+
metadata = {
|
|
139
|
+
"name": discord_user.name,
|
|
140
|
+
"id": discord_user.id,
|
|
141
|
+
}
|
|
142
|
+
user = User(identifier=USER_PREFIX + str(discord_user.name), metadata=metadata)
|
|
143
|
+
|
|
144
|
+
users_by_discord_id[discord_user.id] = user
|
|
145
|
+
|
|
146
|
+
if data_layer := get_data_layer():
|
|
147
|
+
persisted_user = await data_layer.create_user(user)
|
|
148
|
+
if persisted_user:
|
|
149
|
+
users_by_discord_id[discord_user.id] = persisted_user
|
|
150
|
+
|
|
151
|
+
return users_by_discord_id[discord_user.id]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
async def download_discord_file(url: str):
|
|
155
|
+
async with httpx.AsyncClient() as client:
|
|
156
|
+
response = await client.get(url)
|
|
157
|
+
if response.status_code == 200:
|
|
158
|
+
return response.content
|
|
159
|
+
else:
|
|
160
|
+
return None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
async def download_discord_files(
|
|
164
|
+
session: HTTPSession, attachments: List[discord.Attachment]
|
|
165
|
+
):
|
|
166
|
+
download_coros = [
|
|
167
|
+
download_discord_file(attachment.url) for attachment in attachments
|
|
168
|
+
]
|
|
169
|
+
file_bytes_list = await asyncio.gather(*download_coros)
|
|
170
|
+
file_refs = []
|
|
171
|
+
for idx, file_bytes in enumerate(file_bytes_list):
|
|
172
|
+
if file_bytes:
|
|
173
|
+
name = attachments[idx].filename
|
|
174
|
+
mime_type = attachments[idx].content_type or "application/octet-stream"
|
|
175
|
+
file_ref = await session.persist_file(
|
|
176
|
+
name=name, mime=mime_type, content=file_bytes
|
|
177
|
+
)
|
|
178
|
+
file_refs.append(file_ref)
|
|
179
|
+
|
|
180
|
+
files_dicts = [
|
|
181
|
+
session.files[file["id"]] for file in file_refs if file["id"] in session.files
|
|
182
|
+
]
|
|
183
|
+
|
|
184
|
+
file_elements = [Element.from_dict(file_dict) for file_dict in files_dicts]
|
|
185
|
+
|
|
186
|
+
return file_elements
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def clean_content(message: discord.Message):
|
|
190
|
+
if not client.user:
|
|
191
|
+
return message.content
|
|
192
|
+
|
|
193
|
+
# Regex to find mentions of the bot
|
|
194
|
+
bot_mention = f"<@!?{client.user.id}>"
|
|
195
|
+
# Replace the bot's mention with nothing
|
|
196
|
+
return re.sub(bot_mention, "", message.content).strip()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
async def process_discord_message(
|
|
200
|
+
message: discord.Message,
|
|
201
|
+
thread_name: str,
|
|
202
|
+
channel: "MessageableChannel",
|
|
203
|
+
bind_thread_to_user=False,
|
|
204
|
+
):
|
|
205
|
+
user = await get_user(message.author)
|
|
206
|
+
|
|
207
|
+
thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(channel.id)))
|
|
208
|
+
|
|
209
|
+
text = clean_content(message)
|
|
210
|
+
discord_files = message.attachments
|
|
211
|
+
|
|
212
|
+
session_id = str(uuid.uuid4())
|
|
213
|
+
session = HTTPSession(
|
|
214
|
+
id=session_id,
|
|
215
|
+
thread_id=thread_id,
|
|
216
|
+
user=user,
|
|
217
|
+
client_type="discord",
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
ctx = init_discord_context(
|
|
221
|
+
session=session,
|
|
222
|
+
channel=channel,
|
|
223
|
+
message=message,
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
file_elements = await download_discord_files(session, discord_files)
|
|
227
|
+
|
|
228
|
+
msg = Message(
|
|
229
|
+
content=text,
|
|
230
|
+
elements=file_elements,
|
|
231
|
+
type="user_message",
|
|
232
|
+
author=user.metadata.get("name"),
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
await msg.send()
|
|
236
|
+
|
|
237
|
+
ctx.emitter.enabled = True
|
|
238
|
+
|
|
239
|
+
if on_chat_start := config.code.on_chat_start:
|
|
240
|
+
await on_chat_start()
|
|
241
|
+
|
|
242
|
+
if on_message := config.code.on_message:
|
|
243
|
+
await on_message(msg)
|
|
244
|
+
|
|
245
|
+
if on_chat_end := config.code.on_chat_end:
|
|
246
|
+
await on_chat_end()
|
|
247
|
+
|
|
248
|
+
if data_layer := get_data_layer():
|
|
249
|
+
user_id = None
|
|
250
|
+
if isinstance(user, PersistedUser):
|
|
251
|
+
user_id = user.id if bind_thread_to_user else None
|
|
252
|
+
|
|
253
|
+
await data_layer.update_thread(
|
|
254
|
+
thread_id=thread_id,
|
|
255
|
+
name=thread_name,
|
|
256
|
+
metadata=ctx.session.to_persistable(),
|
|
257
|
+
user_id=user_id,
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
ctx.session.delete()
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@client.event
|
|
264
|
+
async def on_ready():
|
|
265
|
+
logger.info(f"Logged in as {client.user}")
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@client.event
|
|
269
|
+
async def on_message(message: discord.Message):
|
|
270
|
+
if not client.user or message.author == client.user:
|
|
271
|
+
return
|
|
272
|
+
|
|
273
|
+
is_dm = isinstance(message.channel, discord.DMChannel)
|
|
274
|
+
if not client.user.mentioned_in(message) and not is_dm:
|
|
275
|
+
return
|
|
276
|
+
|
|
277
|
+
thread_name: str = ""
|
|
278
|
+
bind_thread_to_user = False
|
|
279
|
+
channel = message.channel
|
|
280
|
+
|
|
281
|
+
if isinstance(message.channel, discord.Thread):
|
|
282
|
+
thread_name = f"{message.channel.name}"
|
|
283
|
+
elif isinstance(message.channel, discord.ForumChannel):
|
|
284
|
+
thread_name = f"{message.channel.name}"
|
|
285
|
+
elif isinstance(message.channel, discord.DMChannel):
|
|
286
|
+
thread_name = f"{message.author} Discord DM"
|
|
287
|
+
bind_thread_to_user = True
|
|
288
|
+
elif isinstance(message.channel, discord.GroupChannel):
|
|
289
|
+
thread_name = f"{message.channel.name}"
|
|
290
|
+
elif isinstance(message.channel, discord.TextChannel):
|
|
291
|
+
channel = await message.channel.create_thread(
|
|
292
|
+
name=clean_content(message), message=message
|
|
293
|
+
)
|
|
294
|
+
thread_name = f"{channel.name}"
|
|
295
|
+
else:
|
|
296
|
+
logger.warning(f"Unsupported channel type: {message.channel.type}")
|
|
297
|
+
return
|
|
298
|
+
|
|
299
|
+
await process_discord_message(
|
|
300
|
+
message=message,
|
|
301
|
+
thread_name=thread_name,
|
|
302
|
+
channel=channel,
|
|
303
|
+
bind_thread_to_user=bind_thread_to_user,
|
|
304
|
+
)
|
chainlit/element.py
CHANGED
|
@@ -5,6 +5,7 @@ from io import BytesIO
|
|
|
5
5
|
from typing import Any, ClassVar, List, Literal, Optional, TypedDict, TypeVar, Union
|
|
6
6
|
|
|
7
7
|
import filetype
|
|
8
|
+
import mimetypes
|
|
8
9
|
from chainlit.context import context
|
|
9
10
|
from chainlit.data import get_data_layer
|
|
10
11
|
from chainlit.logger import logger
|
|
@@ -38,6 +39,7 @@ class ElementDict(TypedDict):
|
|
|
38
39
|
size: Optional[ElementSize]
|
|
39
40
|
language: Optional[str]
|
|
40
41
|
page: Optional[int]
|
|
42
|
+
autoPlay: Optional[bool]
|
|
41
43
|
forId: Optional[str]
|
|
42
44
|
mime: Optional[str]
|
|
43
45
|
|
|
@@ -61,7 +63,7 @@ class Element:
|
|
|
61
63
|
# The byte content of the element.
|
|
62
64
|
content: Optional[Union[bytes, str]] = None
|
|
63
65
|
# Controls how the image element should be displayed in the UI. Choices are “side” (default), “inline”, or “page”.
|
|
64
|
-
display: ElementDisplay = Field(default="
|
|
66
|
+
display: ElementDisplay = Field(default="inline")
|
|
65
67
|
# Controls element size
|
|
66
68
|
size: Optional[ElementSize] = None
|
|
67
69
|
# The ID of the message this element is associated with.
|
|
@@ -93,6 +95,7 @@ class Element:
|
|
|
93
95
|
"objectKey": getattr(self, "object_key", None),
|
|
94
96
|
"size": getattr(self, "size", None),
|
|
95
97
|
"page": getattr(self, "page", None),
|
|
98
|
+
"autoPlay": getattr(self, "auto_play", None),
|
|
96
99
|
"language": getattr(self, "language", None),
|
|
97
100
|
"forId": getattr(self, "for_id", None),
|
|
98
101
|
"mime": getattr(self, "mime", None),
|
|
@@ -163,14 +166,16 @@ class Element:
|
|
|
163
166
|
if self.type in mime_types
|
|
164
167
|
else filetype.guess_mime(self.path or self.content)
|
|
165
168
|
)
|
|
166
|
-
|
|
169
|
+
if not self.mime and self.url:
|
|
170
|
+
self.mime = mimetypes.guess_type(self.url)[0]
|
|
171
|
+
|
|
167
172
|
await self._create()
|
|
168
173
|
|
|
169
174
|
if not self.url and not self.chainlit_key:
|
|
170
175
|
raise ValueError("Must provide url or chainlit key to send element")
|
|
171
176
|
|
|
172
177
|
trace_event(f"send {self.__class__.__name__}")
|
|
173
|
-
await context.emitter.
|
|
178
|
+
await context.emitter.send_element(self.to_dict())
|
|
174
179
|
|
|
175
180
|
|
|
176
181
|
ElementBased = TypeVar("ElementBased", bound=Element)
|
|
@@ -306,6 +311,7 @@ class TaskList(Element):
|
|
|
306
311
|
@dataclass
|
|
307
312
|
class Audio(Element):
|
|
308
313
|
type: ClassVar[ElementType] = "audio"
|
|
314
|
+
auto_play: bool = False
|
|
309
315
|
|
|
310
316
|
|
|
311
317
|
@dataclass
|
chainlit/emitter.py
CHANGED
|
@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast
|
|
|
4
4
|
|
|
5
5
|
from chainlit.config import config
|
|
6
6
|
from chainlit.data import get_data_layer
|
|
7
|
-
from chainlit.element import Element, File
|
|
7
|
+
from chainlit.element import Element, ElementDict, File
|
|
8
8
|
from chainlit.logger import logger
|
|
9
9
|
from chainlit.message import Message
|
|
10
|
-
from chainlit.session import BaseSession, WebsocketSession
|
|
10
|
+
from chainlit.session import BaseSession, HTTPSession, WebsocketSession
|
|
11
11
|
from chainlit.step import StepDict
|
|
12
12
|
from chainlit.types import (
|
|
13
13
|
AskActionResponse,
|
|
@@ -29,6 +29,7 @@ class BaseChainlitEmitter:
|
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
31
|
session: BaseSession
|
|
32
|
+
enabled: bool = True
|
|
32
33
|
|
|
33
34
|
def __init__(self, session: BaseSession) -> None:
|
|
34
35
|
"""Initialize with the user session."""
|
|
@@ -46,6 +47,10 @@ class BaseChainlitEmitter:
|
|
|
46
47
|
"""Stub method to resume a thread."""
|
|
47
48
|
pass
|
|
48
49
|
|
|
50
|
+
async def send_element(self, element_dict: ElementDict):
|
|
51
|
+
"""Stub method to send an element to the UI."""
|
|
52
|
+
pass
|
|
53
|
+
|
|
49
54
|
async def send_step(self, step_dict: StepDict):
|
|
50
55
|
"""Stub method to send a message to the UI."""
|
|
51
56
|
pass
|
|
@@ -151,6 +156,10 @@ class ChainlitEmitter(BaseChainlitEmitter):
|
|
|
151
156
|
"""Send a thread to the UI to resume it"""
|
|
152
157
|
return self.emit("resume_thread", thread_dict)
|
|
153
158
|
|
|
159
|
+
async def send_element(self, element_dict: ElementDict):
|
|
160
|
+
"""Stub method to send an element to the UI."""
|
|
161
|
+
await self.emit("element", element_dict)
|
|
162
|
+
|
|
154
163
|
def send_step(self, step_dict: StepDict):
|
|
155
164
|
"""Send a message to the UI."""
|
|
156
165
|
return self.emit("new_message", step_dict)
|