chainlit 1.0.401__py3-none-any.whl → 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +98 -279
- chainlit/_utils.py +8 -0
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +28 -36
- chainlit/auth/cookie.py +123 -0
- chainlit/auth/jwt.py +39 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +362 -0
- chainlit/chat_context.py +64 -0
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +77 -8
- chainlit/config.py +191 -102
- chainlit/context.py +42 -13
- chainlit/copilot/dist/index.js +8750 -903
- chainlit/data/__init__.py +101 -416
- chainlit/data/acl.py +6 -2
- chainlit/data/base.py +107 -0
- chainlit/data/chainlit_data_layer.py +614 -0
- chainlit/data/dynamodb.py +590 -0
- chainlit/data/literalai.py +500 -0
- chainlit/data/sql_alchemy.py +721 -0
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure.py +81 -0
- chainlit/data/storage_clients/azure_blob.py +89 -0
- chainlit/data/storage_clients/base.py +26 -0
- chainlit/data/storage_clients/gcs.py +88 -0
- chainlit/data/storage_clients/s3.py +75 -0
- chainlit/data/utils.py +29 -0
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +354 -0
- chainlit/element.py +91 -33
- chainlit/emitter.py +81 -29
- chainlit/frontend/dist/assets/DailyMotion-Ce9dQoqZ.js +1 -0
- chainlit/frontend/dist/assets/Dataframe-C1XonMcV.js +22 -0
- chainlit/frontend/dist/assets/Facebook-DVVt6lrr.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-c7stW4vz.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-BmMmgorA.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-Cw8hDmiO.js +1 -0
- chainlit/frontend/dist/assets/Mux-DiRZfeUf.js +1 -0
- chainlit/frontend/dist/assets/Preview-6Jt2mRHx.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-DKwcT58_.js +1 -0
- chainlit/frontend/dist/assets/Streamable-BVdxrEeX.js +1 -0
- chainlit/frontend/dist/assets/Twitch-DFqZR7Gu.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-0BQAAtVk.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-CRFSH0Vu.js +1 -0
- chainlit/frontend/dist/assets/Wistia-CKrmdQaG.js +1 -0
- chainlit/frontend/dist/assets/YouTube-CQpL-rvU.js +1 -0
- chainlit/frontend/dist/assets/index-DQmLRKyv.css +1 -0
- chainlit/frontend/dist/assets/index-QdmxtIMQ.js +8665 -0
- chainlit/frontend/dist/assets/react-plotly-B9hvVpUG.js +3484 -0
- chainlit/frontend/dist/index.html +2 -4
- chainlit/haystack/callbacks.py +4 -7
- chainlit/input_widget.py +8 -4
- chainlit/langchain/callbacks.py +103 -68
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +65 -40
- chainlit/markdown.py +22 -6
- chainlit/message.py +54 -56
- chainlit/mistralai/__init__.py +50 -0
- chainlit/oauth_providers.py +266 -8
- chainlit/openai/__init__.py +10 -18
- chainlit/secret.py +1 -1
- chainlit/server.py +789 -228
- chainlit/session.py +108 -90
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +397 -0
- chainlit/socket.py +199 -116
- chainlit/step.py +141 -89
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +338 -0
- chainlit/translations/bn.json +244 -0
- chainlit/translations/en-US.json +122 -8
- chainlit/translations/gu.json +244 -0
- chainlit/translations/he-IL.json +244 -0
- chainlit/translations/hi.json +244 -0
- chainlit/translations/ja.json +242 -0
- chainlit/translations/kn.json +244 -0
- chainlit/translations/ml.json +244 -0
- chainlit/translations/mr.json +244 -0
- chainlit/translations/nl-NL.json +242 -0
- chainlit/translations/ta.json +244 -0
- chainlit/translations/te.json +244 -0
- chainlit/translations/zh-CN.json +243 -0
- chainlit/translations.py +60 -0
- chainlit/types.py +133 -28
- chainlit/user.py +14 -3
- chainlit/user_session.py +6 -3
- chainlit/utils.py +52 -5
- chainlit/version.py +3 -2
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/METADATA +48 -50
- chainlit-2.0.4.dist-info/RECORD +107 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-9711593e.js +0 -723
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -40
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -13
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -408
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit/translations/pt-BR.json +0 -155
- chainlit-1.0.401.dist-info/RECORD +0 -66
- /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/WHEEL +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/entry_points.txt +0 -0
chainlit/discord/app.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import mimetypes
|
|
3
|
+
import re
|
|
4
|
+
import uuid
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from io import BytesIO
|
|
7
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from discord.abc import MessageableChannel
|
|
11
|
+
|
|
12
|
+
import discord
|
|
13
|
+
import filetype
|
|
14
|
+
import httpx
|
|
15
|
+
from discord.ui import Button, View
|
|
16
|
+
|
|
17
|
+
from chainlit.config import config
|
|
18
|
+
from chainlit.context import ChainlitContext, HTTPSession, context, context_var
|
|
19
|
+
from chainlit.data import get_data_layer
|
|
20
|
+
from chainlit.element import Element, ElementDict
|
|
21
|
+
from chainlit.emitter import BaseChainlitEmitter
|
|
22
|
+
from chainlit.logger import logger
|
|
23
|
+
from chainlit.message import Message, StepDict
|
|
24
|
+
from chainlit.telemetry import trace
|
|
25
|
+
from chainlit.types import Feedback
|
|
26
|
+
from chainlit.user import PersistedUser, User
|
|
27
|
+
from chainlit.user_session import user_session
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class FeedbackView(View):
|
|
31
|
+
def __init__(self, step_id: str):
|
|
32
|
+
super().__init__(timeout=None)
|
|
33
|
+
self.step_id = step_id
|
|
34
|
+
|
|
35
|
+
@discord.ui.button(label="👎")
|
|
36
|
+
async def thumbs_down(self, interaction: discord.Interaction, button: Button):
|
|
37
|
+
if data_layer := get_data_layer():
|
|
38
|
+
try:
|
|
39
|
+
feedback = Feedback(forId=self.step_id, value=0)
|
|
40
|
+
await data_layer.upsert_feedback(feedback)
|
|
41
|
+
except Exception as e:
|
|
42
|
+
logger.error(f"Error upserting feedback: {e}")
|
|
43
|
+
if interaction.message:
|
|
44
|
+
await interaction.message.edit(view=None)
|
|
45
|
+
await interaction.message.add_reaction("👎")
|
|
46
|
+
|
|
47
|
+
@discord.ui.button(label="👍")
|
|
48
|
+
async def thumbs_up(self, interaction: discord.Interaction, button: Button):
|
|
49
|
+
if data_layer := get_data_layer():
|
|
50
|
+
try:
|
|
51
|
+
feedback = Feedback(forId=self.step_id, value=1)
|
|
52
|
+
await data_layer.upsert_feedback(feedback)
|
|
53
|
+
except Exception as e:
|
|
54
|
+
logger.error(f"Error upserting feedback: {e}")
|
|
55
|
+
if interaction.message:
|
|
56
|
+
await interaction.message.edit(view=None)
|
|
57
|
+
await interaction.message.add_reaction("👍")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DiscordEmitter(BaseChainlitEmitter):
|
|
61
|
+
def __init__(self, session: HTTPSession, channel: "MessageableChannel"):
|
|
62
|
+
super().__init__(session)
|
|
63
|
+
self.channel = channel
|
|
64
|
+
|
|
65
|
+
async def send_element(self, element_dict: ElementDict):
|
|
66
|
+
if element_dict.get("display") != "inline":
|
|
67
|
+
return
|
|
68
|
+
|
|
69
|
+
persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
|
|
70
|
+
file: Optional[Union[BytesIO, str]] = None
|
|
71
|
+
mime: Optional[str] = None
|
|
72
|
+
|
|
73
|
+
if persisted_file:
|
|
74
|
+
file = str(persisted_file["path"])
|
|
75
|
+
mime = element_dict.get("mime")
|
|
76
|
+
elif file_url := element_dict.get("url"):
|
|
77
|
+
async with httpx.AsyncClient() as client:
|
|
78
|
+
response = await client.get(file_url)
|
|
79
|
+
if response.status_code == 200:
|
|
80
|
+
file = BytesIO(response.content)
|
|
81
|
+
mime = filetype.guess_mime(file)
|
|
82
|
+
|
|
83
|
+
if not file:
|
|
84
|
+
return
|
|
85
|
+
|
|
86
|
+
element_name: str = element_dict.get("name", "Untitled")
|
|
87
|
+
|
|
88
|
+
if mime:
|
|
89
|
+
file_extension = mimetypes.guess_extension(mime)
|
|
90
|
+
if file_extension:
|
|
91
|
+
element_name += file_extension
|
|
92
|
+
|
|
93
|
+
file_obj = discord.File(file, filename=element_name)
|
|
94
|
+
await self.channel.send(file=file_obj)
|
|
95
|
+
|
|
96
|
+
async def send_step(self, step_dict: StepDict):
|
|
97
|
+
if not step_dict["type"] == "assistant_message":
|
|
98
|
+
return
|
|
99
|
+
|
|
100
|
+
step_type = step_dict.get("type")
|
|
101
|
+
is_message = step_type in [
|
|
102
|
+
"user_message",
|
|
103
|
+
"assistant_message",
|
|
104
|
+
]
|
|
105
|
+
is_empty_output = not step_dict.get("output")
|
|
106
|
+
|
|
107
|
+
if is_empty_output or not is_message:
|
|
108
|
+
return
|
|
109
|
+
else:
|
|
110
|
+
enable_feedback = get_data_layer()
|
|
111
|
+
message = await self.channel.send(step_dict["output"])
|
|
112
|
+
|
|
113
|
+
if enable_feedback:
|
|
114
|
+
current_run = context.current_run
|
|
115
|
+
scorable_id = current_run.id if current_run else step_dict.get("id")
|
|
116
|
+
if not scorable_id:
|
|
117
|
+
return
|
|
118
|
+
view = FeedbackView(scorable_id)
|
|
119
|
+
await message.edit(view=view)
|
|
120
|
+
|
|
121
|
+
async def update_step(self, step_dict: StepDict):
|
|
122
|
+
if not step_dict["type"] == "assistant_message":
|
|
123
|
+
return
|
|
124
|
+
|
|
125
|
+
await self.send_step(step_dict)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
intents = discord.Intents.default()
|
|
129
|
+
intents.message_content = True
|
|
130
|
+
|
|
131
|
+
client = discord.Client(intents=intents)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@trace
|
|
135
|
+
def init_discord_context(
|
|
136
|
+
session: HTTPSession,
|
|
137
|
+
channel: "MessageableChannel",
|
|
138
|
+
message: discord.Message,
|
|
139
|
+
) -> ChainlitContext:
|
|
140
|
+
emitter = DiscordEmitter(session=session, channel=channel)
|
|
141
|
+
context = ChainlitContext(session=session, emitter=emitter)
|
|
142
|
+
context_var.set(context)
|
|
143
|
+
user_session.set("discord_message", message)
|
|
144
|
+
user_session.set("discord_channel", channel)
|
|
145
|
+
return context
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
users_by_discord_id: Dict[int, Union[User, PersistedUser]] = {}
|
|
149
|
+
|
|
150
|
+
USER_PREFIX = "discord_"
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
async def get_user(discord_user: Union[discord.User, discord.Member]):
|
|
154
|
+
if discord_user.id in users_by_discord_id:
|
|
155
|
+
return users_by_discord_id[discord_user.id]
|
|
156
|
+
|
|
157
|
+
metadata = {
|
|
158
|
+
"name": discord_user.name,
|
|
159
|
+
"id": discord_user.id,
|
|
160
|
+
}
|
|
161
|
+
user = User(identifier=USER_PREFIX + str(discord_user.name), metadata=metadata)
|
|
162
|
+
|
|
163
|
+
users_by_discord_id[discord_user.id] = user
|
|
164
|
+
|
|
165
|
+
if data_layer := get_data_layer():
|
|
166
|
+
try:
|
|
167
|
+
persisted_user = await data_layer.create_user(user)
|
|
168
|
+
if persisted_user:
|
|
169
|
+
users_by_discord_id[discord_user.id] = persisted_user
|
|
170
|
+
except Exception as e:
|
|
171
|
+
logger.error(f"Error creating user: {e}")
|
|
172
|
+
|
|
173
|
+
return users_by_discord_id[discord_user.id]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
async def download_discord_file(url: str):
|
|
177
|
+
async with httpx.AsyncClient() as client:
|
|
178
|
+
response = await client.get(url)
|
|
179
|
+
if response.status_code == 200:
|
|
180
|
+
return response.content
|
|
181
|
+
else:
|
|
182
|
+
return None
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
async def download_discord_files(
|
|
186
|
+
session: HTTPSession, attachments: List[discord.Attachment]
|
|
187
|
+
):
|
|
188
|
+
download_coros = [
|
|
189
|
+
download_discord_file(attachment.url) for attachment in attachments
|
|
190
|
+
]
|
|
191
|
+
file_bytes_list = await asyncio.gather(*download_coros)
|
|
192
|
+
file_refs = []
|
|
193
|
+
for idx, file_bytes in enumerate(file_bytes_list):
|
|
194
|
+
if file_bytes:
|
|
195
|
+
name = attachments[idx].filename
|
|
196
|
+
mime_type = attachments[idx].content_type or "application/octet-stream"
|
|
197
|
+
file_ref = await session.persist_file(
|
|
198
|
+
name=name, mime=mime_type, content=file_bytes
|
|
199
|
+
)
|
|
200
|
+
file_refs.append(file_ref)
|
|
201
|
+
|
|
202
|
+
files_dicts = [
|
|
203
|
+
session.files[file["id"]] for file in file_refs if file["id"] in session.files
|
|
204
|
+
]
|
|
205
|
+
|
|
206
|
+
file_elements = [Element.from_dict(file_dict) for file_dict in files_dicts]
|
|
207
|
+
|
|
208
|
+
return file_elements
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def clean_content(message: discord.Message):
|
|
212
|
+
if not client.user:
|
|
213
|
+
return message.content
|
|
214
|
+
|
|
215
|
+
# Regex to find mentions of the bot
|
|
216
|
+
bot_mention = f"<@!?{client.user.id}>"
|
|
217
|
+
# Replace the bot's mention with nothing
|
|
218
|
+
return re.sub(bot_mention, "", message.content).strip()
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
async def process_discord_message(
|
|
222
|
+
message: discord.Message,
|
|
223
|
+
thread_id: str,
|
|
224
|
+
thread_name: str,
|
|
225
|
+
channel: "MessageableChannel",
|
|
226
|
+
bind_thread_to_user=False,
|
|
227
|
+
):
|
|
228
|
+
user = await get_user(message.author)
|
|
229
|
+
|
|
230
|
+
text = clean_content(message)
|
|
231
|
+
discord_files = message.attachments
|
|
232
|
+
|
|
233
|
+
session_id = str(uuid.uuid4())
|
|
234
|
+
session = HTTPSession(
|
|
235
|
+
id=session_id,
|
|
236
|
+
thread_id=thread_id,
|
|
237
|
+
user=user,
|
|
238
|
+
client_type="discord",
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
ctx = init_discord_context(
|
|
242
|
+
session=session,
|
|
243
|
+
channel=channel,
|
|
244
|
+
message=message,
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
file_elements = await download_discord_files(session, discord_files)
|
|
248
|
+
|
|
249
|
+
if on_chat_start := config.code.on_chat_start:
|
|
250
|
+
await on_chat_start()
|
|
251
|
+
|
|
252
|
+
msg = Message(
|
|
253
|
+
content=text,
|
|
254
|
+
elements=file_elements,
|
|
255
|
+
type="user_message",
|
|
256
|
+
author=user.metadata.get("name"),
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
await msg.send()
|
|
260
|
+
|
|
261
|
+
if on_message := config.code.on_message:
|
|
262
|
+
async with channel.typing():
|
|
263
|
+
await on_message(msg)
|
|
264
|
+
|
|
265
|
+
if on_chat_end := config.code.on_chat_end:
|
|
266
|
+
await on_chat_end()
|
|
267
|
+
|
|
268
|
+
if data_layer := get_data_layer():
|
|
269
|
+
user_id = None
|
|
270
|
+
if isinstance(user, PersistedUser):
|
|
271
|
+
user_id = user.id if bind_thread_to_user else None
|
|
272
|
+
|
|
273
|
+
try:
|
|
274
|
+
await data_layer.update_thread(
|
|
275
|
+
thread_id=thread_id,
|
|
276
|
+
name=thread_name,
|
|
277
|
+
metadata=ctx.session.to_persistable(),
|
|
278
|
+
user_id=user_id,
|
|
279
|
+
)
|
|
280
|
+
except Exception as e:
|
|
281
|
+
logger.error(f"Error updating thread: {e}")
|
|
282
|
+
|
|
283
|
+
ctx.session.delete()
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
@client.event
|
|
287
|
+
async def on_ready():
|
|
288
|
+
logger.info(f"Logged in as {client.user}")
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@client.event
|
|
292
|
+
async def on_message(message: discord.Message):
|
|
293
|
+
if not client.user or message.author == client.user:
|
|
294
|
+
return
|
|
295
|
+
|
|
296
|
+
is_dm = isinstance(message.channel, discord.DMChannel)
|
|
297
|
+
if not client.user.mentioned_in(message) and not is_dm:
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
thread_name: str = ""
|
|
301
|
+
thread_id: str = ""
|
|
302
|
+
bind_thread_to_user = False
|
|
303
|
+
channel = message.channel
|
|
304
|
+
|
|
305
|
+
if isinstance(message.channel, discord.Thread):
|
|
306
|
+
thread_name = f"{message.channel.name}"
|
|
307
|
+
thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(channel.id)))
|
|
308
|
+
elif isinstance(message.channel, discord.ForumChannel):
|
|
309
|
+
thread_name = f"{message.channel.name}"
|
|
310
|
+
thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(channel.id)))
|
|
311
|
+
elif isinstance(message.channel, discord.DMChannel):
|
|
312
|
+
thread_id = str(
|
|
313
|
+
uuid.uuid5(
|
|
314
|
+
uuid.NAMESPACE_DNS,
|
|
315
|
+
str(channel.id) + datetime.today().strftime("%Y-%m-%d"),
|
|
316
|
+
)
|
|
317
|
+
)
|
|
318
|
+
thread_name = (
|
|
319
|
+
f"{message.author} Discord DM {datetime.today().strftime('%Y-%m-%d')}"
|
|
320
|
+
)
|
|
321
|
+
bind_thread_to_user = True
|
|
322
|
+
elif isinstance(message.channel, discord.GroupChannel):
|
|
323
|
+
thread_id = str(
|
|
324
|
+
uuid.uuid5(
|
|
325
|
+
uuid.NAMESPACE_DNS,
|
|
326
|
+
str(channel.id) + datetime.today().strftime("%Y-%m-%d"),
|
|
327
|
+
)
|
|
328
|
+
)
|
|
329
|
+
thread_name = f"{message.channel.name}"
|
|
330
|
+
elif isinstance(message.channel, discord.TextChannel):
|
|
331
|
+
# Discord limits thread names to 100 characters and does not create
|
|
332
|
+
# threads from empty messages.
|
|
333
|
+
thread_id = str(
|
|
334
|
+
uuid.uuid5(
|
|
335
|
+
uuid.NAMESPACE_DNS,
|
|
336
|
+
str(channel.id) + datetime.today().strftime("%Y-%m-%d"),
|
|
337
|
+
)
|
|
338
|
+
)
|
|
339
|
+
discord_thread_name = clean_content(message)[:100] or "Untitled"
|
|
340
|
+
channel = await message.channel.create_thread(
|
|
341
|
+
name=discord_thread_name, message=message
|
|
342
|
+
)
|
|
343
|
+
thread_name = f"{channel.name}"
|
|
344
|
+
else:
|
|
345
|
+
logger.warning(f"Unsupported channel type: {message.channel.type}")
|
|
346
|
+
return
|
|
347
|
+
|
|
348
|
+
await process_discord_message(
|
|
349
|
+
message=message,
|
|
350
|
+
thread_id=thread_id,
|
|
351
|
+
thread_name=thread_name,
|
|
352
|
+
channel=channel,
|
|
353
|
+
bind_thread_to_user=bind_thread_to_user,
|
|
354
|
+
)
|
chainlit/element.py
CHANGED
|
@@ -1,17 +1,30 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import mimetypes
|
|
2
3
|
import uuid
|
|
3
4
|
from enum import Enum
|
|
4
5
|
from io import BytesIO
|
|
5
|
-
from typing import
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
ClassVar,
|
|
9
|
+
Dict,
|
|
10
|
+
List,
|
|
11
|
+
Literal,
|
|
12
|
+
Optional,
|
|
13
|
+
TypedDict,
|
|
14
|
+
TypeVar,
|
|
15
|
+
Union,
|
|
16
|
+
)
|
|
6
17
|
|
|
7
18
|
import filetype
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic.dataclasses import dataclass
|
|
21
|
+
from syncer import asyncio
|
|
22
|
+
|
|
8
23
|
from chainlit.context import context
|
|
9
24
|
from chainlit.data import get_data_layer
|
|
10
25
|
from chainlit.logger import logger
|
|
11
26
|
from chainlit.telemetry import trace_event
|
|
12
27
|
from chainlit.types import FileDict
|
|
13
|
-
from pydantic.dataclasses import Field, dataclass
|
|
14
|
-
from syncer import asyncio
|
|
15
28
|
|
|
16
29
|
mime_types = {
|
|
17
30
|
"text": "text/plain",
|
|
@@ -20,7 +33,16 @@ mime_types = {
|
|
|
20
33
|
}
|
|
21
34
|
|
|
22
35
|
ElementType = Literal[
|
|
23
|
-
"image",
|
|
36
|
+
"image",
|
|
37
|
+
"text",
|
|
38
|
+
"pdf",
|
|
39
|
+
"tasklist",
|
|
40
|
+
"audio",
|
|
41
|
+
"video",
|
|
42
|
+
"file",
|
|
43
|
+
"plotly",
|
|
44
|
+
"dataframe",
|
|
45
|
+
"custom",
|
|
24
46
|
]
|
|
25
47
|
ElementDisplay = Literal["inline", "side", "page"]
|
|
26
48
|
ElementSize = Literal["small", "medium", "large"]
|
|
@@ -38,12 +60,17 @@ class ElementDict(TypedDict):
|
|
|
38
60
|
size: Optional[ElementSize]
|
|
39
61
|
language: Optional[str]
|
|
40
62
|
page: Optional[int]
|
|
63
|
+
props: Optional[Dict]
|
|
64
|
+
autoPlay: Optional[bool]
|
|
65
|
+
playerConfig: Optional[dict]
|
|
41
66
|
forId: Optional[str]
|
|
42
67
|
mime: Optional[str]
|
|
43
68
|
|
|
44
69
|
|
|
45
70
|
@dataclass
|
|
46
71
|
class Element:
|
|
72
|
+
# Thread id
|
|
73
|
+
thread_id: str = Field(default_factory=lambda: context.session.thread_id)
|
|
47
74
|
# The type of the element. This will be used to determine how to display the element in the UI.
|
|
48
75
|
type: ClassVar[ElementType]
|
|
49
76
|
# Name of the element, this will be used to reference the element in the UI.
|
|
@@ -52,7 +79,7 @@ class Element:
|
|
|
52
79
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
53
80
|
# The key of the element hosted on Chainlit.
|
|
54
81
|
chainlit_key: Optional[str] = None
|
|
55
|
-
# The URL of the element if already hosted
|
|
82
|
+
# The URL of the element if already hosted somewhere else.
|
|
56
83
|
url: Optional[str] = None
|
|
57
84
|
# The S3 object key.
|
|
58
85
|
object_key: Optional[str] = None
|
|
@@ -61,7 +88,7 @@ class Element:
|
|
|
61
88
|
# The byte content of the element.
|
|
62
89
|
content: Optional[Union[bytes, str]] = None
|
|
63
90
|
# Controls how the image element should be displayed in the UI. Choices are “side” (default), “inline”, or “page”.
|
|
64
|
-
display: ElementDisplay = Field(default="
|
|
91
|
+
display: ElementDisplay = Field(default="inline")
|
|
65
92
|
# Controls element size
|
|
66
93
|
size: Optional[ElementSize] = None
|
|
67
94
|
# The ID of the message this element is associated with.
|
|
@@ -75,7 +102,6 @@ class Element:
|
|
|
75
102
|
trace_event(f"init {self.__class__.__name__}")
|
|
76
103
|
self.persisted = False
|
|
77
104
|
self.updatable = False
|
|
78
|
-
self.thread_id = context.session.thread_id
|
|
79
105
|
|
|
80
106
|
if not self.url and not self.path and not self.content:
|
|
81
107
|
raise ValueError("Must provide url, path or content to instantiate element")
|
|
@@ -92,7 +118,10 @@ class Element:
|
|
|
92
118
|
"display": self.display,
|
|
93
119
|
"objectKey": getattr(self, "object_key", None),
|
|
94
120
|
"size": getattr(self, "size", None),
|
|
121
|
+
"props": getattr(self, "props", None),
|
|
95
122
|
"page": getattr(self, "page", None),
|
|
123
|
+
"autoPlay": getattr(self, "auto_play", None),
|
|
124
|
+
"playerConfig": getattr(self, "player_config", None),
|
|
96
125
|
"language": getattr(self, "language", None),
|
|
97
126
|
"forId": getattr(self, "for_id", None),
|
|
98
127
|
"mime": getattr(self, "mime", None),
|
|
@@ -129,7 +158,7 @@ class Element:
|
|
|
129
158
|
try:
|
|
130
159
|
asyncio.create_task(data_layer.create_element(self))
|
|
131
160
|
except Exception as e:
|
|
132
|
-
logger.error(f"Failed to create element: {
|
|
161
|
+
logger.error(f"Failed to create element: {e!s}")
|
|
133
162
|
if not self.url and (not self.chainlit_key or self.updatable):
|
|
134
163
|
file_dict = await context.session.persist_file(
|
|
135
164
|
name=self.name,
|
|
@@ -147,7 +176,7 @@ class Element:
|
|
|
147
176
|
trace_event(f"remove {self.__class__.__name__}")
|
|
148
177
|
data_layer = get_data_layer()
|
|
149
178
|
if data_layer and self.persisted:
|
|
150
|
-
await data_layer.delete_element(self.id)
|
|
179
|
+
await data_layer.delete_element(self.id, self.thread_id)
|
|
151
180
|
await context.emitter.emit("remove_element", {"id": self.id})
|
|
152
181
|
|
|
153
182
|
async def send(self, for_id: str):
|
|
@@ -157,12 +186,14 @@ class Element:
|
|
|
157
186
|
self.for_id = for_id
|
|
158
187
|
|
|
159
188
|
if not self.mime:
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
189
|
+
if self.type in mime_types:
|
|
190
|
+
self.mime = mime_types[self.type]
|
|
191
|
+
elif self.path or isinstance(self.content, (bytes, bytearray)):
|
|
192
|
+
file_type = filetype.guess(self.path or self.content)
|
|
193
|
+
if file_type:
|
|
194
|
+
self.mime = file_type.mime
|
|
195
|
+
elif self.url:
|
|
196
|
+
self.mime = mimetypes.guess_type(self.url)[0]
|
|
166
197
|
|
|
167
198
|
await self._create()
|
|
168
199
|
|
|
@@ -170,7 +201,7 @@ class Element:
|
|
|
170
201
|
raise ValueError("Must provide url or chainlit key to send element")
|
|
171
202
|
|
|
172
203
|
trace_event(f"send {self.__class__.__name__}")
|
|
173
|
-
await context.emitter.
|
|
204
|
+
await context.emitter.send_element(self.to_dict())
|
|
174
205
|
|
|
175
206
|
|
|
176
207
|
ElementBased = TypeVar("ElementBased", bound=Element)
|
|
@@ -183,14 +214,6 @@ class Image(Element):
|
|
|
183
214
|
size: ElementSize = "medium"
|
|
184
215
|
|
|
185
216
|
|
|
186
|
-
@dataclass
|
|
187
|
-
class Avatar(Element):
|
|
188
|
-
type: ClassVar[ElementType] = "avatar"
|
|
189
|
-
|
|
190
|
-
async def send(self):
|
|
191
|
-
await super().send(for_id="")
|
|
192
|
-
|
|
193
|
-
|
|
194
217
|
@dataclass
|
|
195
218
|
class Text(Element):
|
|
196
219
|
"""Useful to send a text (not a message) to the UI."""
|
|
@@ -226,14 +249,10 @@ class Pyplot(Element):
|
|
|
226
249
|
if not isinstance(self.figure, Figure):
|
|
227
250
|
raise TypeError("figure must be a matplotlib.figure.Figure")
|
|
228
251
|
|
|
229
|
-
options = {
|
|
230
|
-
"dpi": 200,
|
|
231
|
-
"bbox_inches": "tight",
|
|
232
|
-
"backend": "Agg",
|
|
233
|
-
"format": "png",
|
|
234
|
-
}
|
|
235
252
|
image = BytesIO()
|
|
236
|
-
self.figure.savefig(
|
|
253
|
+
self.figure.savefig(
|
|
254
|
+
image, dpi=200, bbox_inches="tight", backend="Agg", format="png"
|
|
255
|
+
)
|
|
237
256
|
self.content = image.getvalue()
|
|
238
257
|
|
|
239
258
|
super().__post_init__()
|
|
@@ -306,6 +325,7 @@ class TaskList(Element):
|
|
|
306
325
|
@dataclass
|
|
307
326
|
class Audio(Element):
|
|
308
327
|
type: ClassVar[ElementType] = "audio"
|
|
328
|
+
auto_play: bool = False
|
|
309
329
|
|
|
310
330
|
|
|
311
331
|
@dataclass
|
|
@@ -313,6 +333,9 @@ class Video(Element):
|
|
|
313
333
|
type: ClassVar[ElementType] = "video"
|
|
314
334
|
|
|
315
335
|
size: ElementSize = "medium"
|
|
336
|
+
# Override settings for each type of player in ReactPlayer
|
|
337
|
+
# https://github.com/cookpete/react-player?tab=readme-ov-file#config-prop
|
|
338
|
+
player_config: Optional[dict] = None
|
|
316
339
|
|
|
317
340
|
|
|
318
341
|
@dataclass
|
|
@@ -333,8 +356,7 @@ class Plotly(Element):
|
|
|
333
356
|
content: str = ""
|
|
334
357
|
|
|
335
358
|
def __post_init__(self) -> None:
|
|
336
|
-
from plotly import graph_objects as go
|
|
337
|
-
from plotly import io as pio
|
|
359
|
+
from plotly import graph_objects as go, io as pio
|
|
338
360
|
|
|
339
361
|
if not isinstance(self.figure, go.Figure):
|
|
340
362
|
raise TypeError("figure must be a plotly.graph_objects.Figure")
|
|
@@ -346,3 +368,39 @@ class Plotly(Element):
|
|
|
346
368
|
self.mime = "application/json"
|
|
347
369
|
|
|
348
370
|
super().__post_init__()
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
@dataclass
|
|
374
|
+
class Dataframe(Element):
|
|
375
|
+
"""Useful to send a pandas DataFrame to the UI."""
|
|
376
|
+
|
|
377
|
+
type: ClassVar[ElementType] = "dataframe"
|
|
378
|
+
size: ElementSize = "large"
|
|
379
|
+
data: Any = None # The type is Any because it is checked in __post_init__.
|
|
380
|
+
|
|
381
|
+
def __post_init__(self) -> None:
|
|
382
|
+
"""Ensures the data is a pandas DataFrame and converts it to JSON."""
|
|
383
|
+
from pandas import DataFrame
|
|
384
|
+
|
|
385
|
+
if not isinstance(self.data, DataFrame):
|
|
386
|
+
raise TypeError("data must be a pandas.DataFrame")
|
|
387
|
+
|
|
388
|
+
self.content = self.data.to_json(orient="split", date_format="iso")
|
|
389
|
+
super().__post_init__()
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
@dataclass
|
|
393
|
+
class CustomElement(Element):
|
|
394
|
+
"""Useful to send a custom element to the UI."""
|
|
395
|
+
|
|
396
|
+
type: ClassVar[ElementType] = "custom"
|
|
397
|
+
mime: str = "application/json"
|
|
398
|
+
props: Dict = Field(default_factory=dict)
|
|
399
|
+
|
|
400
|
+
def __post_init__(self) -> None:
|
|
401
|
+
self.content = json.dumps(self.props)
|
|
402
|
+
super().__post_init__()
|
|
403
|
+
self.updatable = True
|
|
404
|
+
|
|
405
|
+
async def update(self):
|
|
406
|
+
await super().send(self.for_id)
|