chainlit 1.1.0rc1__py3-none-any.whl → 1.1.200__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; see the registry page for this release for more details.

chainlit/data/acl.py CHANGED
@@ -5,9 +5,12 @@ from fastapi import HTTPException
5
5
  async def is_thread_author(username: str, thread_id: str):
6
6
  data_layer = get_data_layer()
7
7
  if not data_layer:
8
- raise HTTPException(status_code=401, detail="Unauthorized")
8
+ raise HTTPException(status_code=400, detail="Data layer not initialized")
9
9
 
10
10
  thread_author = await data_layer.get_thread_author(thread_id)
11
+
12
+ if not thread_author:
13
+ raise HTTPException(status_code=404, detail="Thread not found")
11
14
 
12
15
  if thread_author != username:
13
16
  raise HTTPException(status_code=401, detail="Unauthorized")
@@ -39,9 +39,11 @@ class SQLAlchemyDataLayer(BaseDataLayer):
39
39
  ssl_require: bool = False,
40
40
  storage_provider: Optional[BaseStorageClient] = None,
41
41
  user_thread_limit: Optional[int] = 1000,
42
+ show_logger: Optional[bool] = False,
42
43
  ):
43
44
  self._conninfo = conninfo
44
45
  self.user_thread_limit = user_thread_limit
46
+ self.show_logger = show_logger
45
47
  ssl_args = {}
46
48
  if ssl_require:
47
49
  # Create an SSL context to require an SSL connection
@@ -55,7 +57,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
55
57
  self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
56
58
  if storage_provider:
57
59
  self.storage_provider: Optional[BaseStorageClient] = storage_provider
58
- logger.info("SQLAlchemyDataLayer storage client initialized")
60
+ if self.show_logger: logger.info("SQLAlchemyDataLayer storage client initialized")
59
61
  else:
60
62
  self.storage_provider = None
61
63
  logger.warn(
@@ -102,7 +104,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
102
104
 
103
105
  ###### User ######
104
106
  async def get_user(self, identifier: str) -> Optional[PersistedUser]:
105
- logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
107
+ if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
106
108
  query = "SELECT * FROM users WHERE identifier = :identifier"
107
109
  parameters = {"identifier": identifier}
108
110
  result = await self.execute_sql(query=query, parameters=parameters)
@@ -112,20 +114,20 @@ class SQLAlchemyDataLayer(BaseDataLayer):
112
114
  return None
113
115
 
114
116
  async def create_user(self, user: User) -> Optional[PersistedUser]:
115
- logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
117
+ if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
116
118
  existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
117
119
  user_dict: Dict[str, Any] = {
118
120
  "identifier": str(user.identifier),
119
121
  "metadata": json.dumps(user.metadata) or {},
120
122
  }
121
123
  if not existing_user: # create the user
122
- logger.info("SQLAlchemy: create_user, creating the user")
124
+ if self.show_logger: logger.info("SQLAlchemy: create_user, creating the user")
123
125
  user_dict["id"] = str(uuid.uuid4())
124
126
  user_dict["createdAt"] = await self.get_current_timestamp()
125
127
  query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
126
128
  await self.execute_sql(query=query, parameters=user_dict)
127
129
  else: # update the user
128
- logger.info("SQLAlchemy: update user metadata")
130
+ if self.show_logger: logger.info("SQLAlchemy: update user metadata")
129
131
  query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
130
132
  await self.execute_sql(
131
133
  query=query, parameters=user_dict
@@ -134,19 +136,18 @@ class SQLAlchemyDataLayer(BaseDataLayer):
134
136
 
135
137
  ###### Threads ######
136
138
  async def get_thread_author(self, thread_id: str) -> str:
137
- logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
139
+ if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
138
140
  query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
139
141
  parameters = {"id": thread_id}
140
142
  result = await self.execute_sql(query=query, parameters=parameters)
141
143
  if isinstance(result, list) and result[0]:
142
144
  author_identifier = result[0].get("userIdentifier")
143
145
  if author_identifier is not None:
144
- print(f"Author found: {author_identifier}")
145
146
  return author_identifier
146
147
  raise ValueError(f"Author not found for thread_id {thread_id}")
147
148
 
148
149
  async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
149
- logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
150
+ if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
150
151
  user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
151
152
  thread_id=thread_id
152
153
  )
@@ -163,7 +164,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
163
164
  metadata: Optional[Dict] = None,
164
165
  tags: Optional[List[str]] = None,
165
166
  ):
166
- logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
167
+ if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
167
168
  if context.session.user is not None:
168
169
  user_identifier = context.session.user.identifier
169
170
  else:
@@ -200,7 +201,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
200
201
  await self.execute_sql(query=query, parameters=parameters)
201
202
 
202
203
  async def delete_thread(self, thread_id: str):
203
- logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
204
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
204
205
  # Delete feedbacks/elements/steps/thread
205
206
  feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
206
207
  elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
@@ -215,7 +216,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
215
216
  async def list_threads(
216
217
  self, pagination: Pagination, filters: ThreadFilter
217
218
  ) -> PaginatedResponse:
218
- logger.info(
219
+ if self.show_logger: logger.info(
219
220
  f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
220
221
  )
221
222
  if not filters.userId:
@@ -275,7 +276,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
275
276
  ###### Steps ######
276
277
  @queue_until_user_message()
277
278
  async def create_step(self, step_dict: "StepDict"):
278
- logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
279
+ if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
279
280
  if not getattr(context.session.user, "id", None):
280
281
  raise ValueError("No authenticated user in context")
281
282
  step_dict["showInput"] = (
@@ -305,12 +306,12 @@ class SQLAlchemyDataLayer(BaseDataLayer):
305
306
 
306
307
  @queue_until_user_message()
307
308
  async def update_step(self, step_dict: "StepDict"):
308
- logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
309
+ if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
309
310
  await self.create_step(step_dict)
310
311
 
311
312
  @queue_until_user_message()
312
313
  async def delete_step(self, step_id: str):
313
- logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
314
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
314
315
  # Delete feedbacks/elements/steps
315
316
  feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
316
317
  elements_query = """DELETE FROM elements WHERE "forId" = :id"""
@@ -322,7 +323,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
322
323
 
323
324
  ###### Feedback ######
324
325
  async def upsert_feedback(self, feedback: Feedback) -> str:
325
- logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
326
+ if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
326
327
  feedback.id = feedback.id or str(uuid.uuid4())
327
328
  feedback_dict = asdict(feedback)
328
329
  parameters = {
@@ -344,7 +345,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
344
345
  return feedback.id
345
346
 
346
347
  async def delete_feedback(self, feedback_id: str) -> bool:
347
- logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
348
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
348
349
  query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
349
350
  parameters = {"feedback_id": feedback_id}
350
351
  await self.execute_sql(query=query, parameters=parameters)
@@ -353,7 +354,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
353
354
  ###### Elements ######
354
355
  @queue_until_user_message()
355
356
  async def create_element(self, element: "Element"):
356
- logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
357
+ if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
357
358
  if not getattr(context.session.user, "id", None):
358
359
  raise ValueError("No authenticated user in context")
359
360
  if isinstance(element, Avatar): # Skip creating elements of type avatar
@@ -416,7 +417,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
416
417
 
417
418
  @queue_until_user_message()
418
419
  async def delete_element(self, element_id: str):
419
- logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
420
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
420
421
  query = """DELETE FROM elements WHERE "id" = :id"""
421
422
  parameters = {"id": element_id}
422
423
  await self.execute_sql(query=query, parameters=parameters)
@@ -428,7 +429,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
428
429
  self, user_id: Optional[str] = None, thread_id: Optional[str] = None
429
430
  ) -> Optional[List[ThreadDict]]:
430
431
  """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
431
- logger.info(f"SQLAlchemy: get_all_user_threads")
432
+ if self.show_logger: logger.info(f"SQLAlchemy: get_all_user_threads")
432
433
  user_threads_query = """
433
434
  SELECT
434
435
  "id" AS thread_id,
@@ -562,8 +563,8 @@ class SQLAlchemyDataLayer(BaseDataLayer):
562
563
  tags=step_feedback.get("step_tags"),
563
564
  input=(
564
565
  step_feedback.get("step_input", "")
565
- if step_feedback["step_showinput"]
566
- else ""
566
+ if step_feedback["step_showinput"] == "true"
567
+ else None
567
568
  ),
568
569
  output=step_feedback.get("step_output", ""),
569
570
  createdAt=step_feedback.get("step_createdat"),
@@ -0,0 +1,6 @@
1
# Fail fast with an actionable message when the optional discord dependency
# is not installed.
try:
    import discord  # noqa: F401
except ModuleNotFoundError as e:
    raise ValueError(
        "The discord package is required to integrate Chainlit with a Discord app. Run `pip install discord --upgrade`"
    ) from e
@@ -0,0 +1,331 @@
1
+ import asyncio
2
+ import mimetypes
3
+ import re
4
+ import uuid
5
+ from io import BytesIO
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
7
+
8
+ if TYPE_CHECKING:
9
+ from discord.abc import MessageableChannel
10
+
11
+ import discord
12
+ import filetype
13
+ import httpx
14
+ from chainlit.config import config
15
+ from chainlit.context import ChainlitContext, HTTPSession, context_var
16
+ from chainlit.data import get_data_layer
17
+ from chainlit.element import Element, ElementDict
18
+ from chainlit.emitter import BaseChainlitEmitter
19
+ from chainlit.logger import logger
20
+ from chainlit.message import Message, StepDict
21
+ from chainlit.telemetry import trace
22
+ from chainlit.types import Feedback
23
+ from chainlit.user import PersistedUser, User
24
+ from chainlit.user_session import user_session
25
+ from discord.ui import Button, View
26
+
27
+
28
class FeedbackView(View):
    """Discord view with thumbs-down / thumbs-up buttons that record step feedback.

    Args:
        step_id: Chainlit step the feedback is attached to.
        thread_id: Chainlit thread of that step. It is captured when the view
            is built because button callbacks run in discord.py's interaction
            handling, outside the task where the Chainlit context variable was
            set. When omitted, the thread id is read from the current context.
    """

    def __init__(self, step_id: str, thread_id: Optional[str] = None):
        super().__init__(timeout=None)
        self.step_id = step_id
        self.thread_id = thread_id

    async def _record_feedback(
        self, interaction: discord.Interaction, value: int, emoji: str
    ) -> None:
        """Persist the feedback (best effort), drop the buttons and react with ``emoji``."""
        if data_layer := get_data_layer():
            try:
                thread_id = self.thread_id or context_var.get().session.thread_id
                feedback = Feedback(forId=self.step_id, threadId=thread_id, value=value)
                await data_layer.upsert_feedback(feedback)
            except Exception as e:
                logger.error(f"Error upserting feedback: {e}")
        if interaction.message:
            # Answer the interaction by editing its message; an interaction
            # that is never answered is reported as failed by Discord.
            await interaction.response.edit_message(view=None)
            await interaction.message.add_reaction(emoji)
        else:
            await interaction.response.defer()

    @discord.ui.button(label="👎")
    async def thumbs_down(self, interaction: discord.Interaction, button: Button):
        await self._record_feedback(interaction, value=0, emoji="👎")

    @discord.ui.button(label="👍")
    async def thumbs_up(self, interaction: discord.Interaction, button: Button):
        await self._record_feedback(interaction, value=1, emoji="👍")


class DiscordEmitter(BaseChainlitEmitter):
    """Emitter that forwards Chainlit messages and inline elements to a Discord channel.

    Nothing is sent while ``enabled`` is False (the constructor default).
    """

    # Discord rejects message content longer than this many characters.
    message_limit: int = 2000

    def __init__(
        self, session: HTTPSession, channel: "MessageableChannel", enabled=False
    ):
        super().__init__(session)
        self.channel = channel
        self.enabled = enabled

    async def send_element(self, element_dict: ElementDict):
        """Upload an inline element to the channel as a file attachment."""
        if not self.enabled or element_dict.get("display") != "inline":
            return

        persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
        file: Optional[Union[BytesIO, str]] = None
        mime: Optional[str] = None

        if persisted_file:
            file = str(persisted_file["path"])
            mime = element_dict.get("mime")
        elif file_url := element_dict.get("url"):
            async with httpx.AsyncClient() as client:
                response = await client.get(file_url)
                if response.status_code == 200:
                    file = BytesIO(response.content)
                    # Sniff the raw bytes so the upload buffer's position is untouched.
                    mime = filetype.guess_mime(response.content)

        if not file:
            return

        element_name: str = element_dict.get("name", "Untitled")

        if mime:
            file_extension = mimetypes.guess_extension(mime)
            # Avoid "image.png.png": only append when the name does not already
            # carry an extension for this MIME type.
            if file_extension and mimetypes.guess_type(element_name)[0] != mime:
                element_name += file_extension

        file_obj = discord.File(file, filename=element_name)
        await self.channel.send(file=file_obj)

    def _split_content(self, content: str) -> List[str]:
        """Split ``content`` into chunks of at most ``message_limit`` characters.

        Each cut is made at the last newline inside the window when there is
        one, so lines stay intact; a single over-long line is cut hard.
        """
        limit = self.message_limit
        chunks: List[str] = []
        while len(content) > limit:
            cut = content.rfind("\n", 0, limit)
            if cut <= 0:
                cut = limit
            chunks.append(content[:cut])
            content = content[cut:].lstrip("\n")
        if content:
            chunks.append(content)
        return chunks

    async def send_step(self, step_dict: StepDict):
        """Post a top-level, non-empty message step to the channel.

        Output longer than ``message_limit`` is sent as several Discord
        messages; the feedback buttons (when enabled) go on the last one.
        """
        if not self.enabled:
            return

        step_type = step_dict.get("type")
        is_message = step_type in [
            "user_message",
            "assistant_message",
            "system_message",
        ]
        is_chain_of_thought = bool(step_dict.get("parentId"))
        is_empty_output = not step_dict.get("output")

        if is_chain_of_thought or is_empty_output or not is_message:
            return

        enable_feedback = not step_dict.get("disableFeedback") and get_data_layer()
        view = (
            FeedbackView(step_dict.get("id", ""), thread_id=self.session.thread_id)
            if enable_feedback
            else None
        )

        chunks = self._split_content(step_dict["output"])
        for index, chunk in enumerate(chunks):
            if view is not None and index == len(chunks) - 1:
                await self.channel.send(chunk, view=view)
            else:
                await self.channel.send(chunk)

    async def update_step(self, step_dict: StepDict):
        """Re-send an updated step through ``send_step`` (no in-place edit)."""
        if not self.enabled:
            return

        await self.send_step(step_dict)
127
+
128
+
129
# Start from discord.py's default gateway intents and additionally enable the
# message_content intent, since the handlers below read message text.
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)  # shared client used by the @client.event handlers
133
+
134
+
135
@trace
def init_discord_context(
    session: HTTPSession,
    channel: "MessageableChannel",
    message: discord.Message,
) -> ChainlitContext:
    """Build a Chainlit context bound to a Discord channel and make it current.

    The context uses a ``DiscordEmitter`` created with its default
    ``enabled`` value. The triggering Discord message and channel are stored
    in the user session under ``discord_message`` / ``discord_channel``.
    """
    ctx = ChainlitContext(
        session=session,
        emitter=DiscordEmitter(session=session, channel=channel),
    )
    context_var.set(ctx)
    for key, value in (("discord_message", message), ("discord_channel", channel)):
        user_session.set(key, value)
    return ctx
147
+
148
+
149
# Prefix prepended to the Discord username to form the Chainlit user identifier.
USER_PREFIX = "discord_"

# In-process cache of Chainlit users, keyed by Discord user id.
users_by_discord_id: Dict[int, Union[User, PersistedUser]] = {}
152
+
153
+
154
async def get_user(discord_user: Union[discord.User, discord.Member]):
    """Return the Chainlit user for ``discord_user``, creating and caching it if needed.

    The identifier is ``USER_PREFIX`` + the Discord username. When a data layer
    is configured the user is persisted and the returned ``PersistedUser`` is
    cached. If persisting raises, the plain ``User`` is returned but NOT
    cached, so persistence is retried on the user's next message instead of
    being skipped for the rest of the process lifetime.
    """
    if discord_user.id in users_by_discord_id:
        return users_by_discord_id[discord_user.id]

    metadata = {
        "name": discord_user.name,
        "id": discord_user.id,
    }
    user = User(identifier=USER_PREFIX + str(discord_user.name), metadata=metadata)

    if data_layer := get_data_layer():
        try:
            persisted_user = await data_layer.create_user(user)
        except Exception as e:
            logger.error(f"Error creating user: {e}")
            return user
        if persisted_user:
            users_by_discord_id[discord_user.id] = persisted_user
            return persisted_user

    users_by_discord_id[discord_user.id] = user
    return user
175
+
176
+
177
async def download_discord_file(url: str):
    """Download ``url`` and return its bytes, or ``None`` if the download fails.

    Transport errors are logged and reported as ``None`` instead of raised.
    Several downloads are gathered together by the caller, so one broken
    attachment must not abort the whole message.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
    except httpx.HTTPError as e:
        logger.error(f"Error downloading Discord file: {e}")
        return None

    if response.status_code == 200:
        return response.content
    return None
184
+
185
+
186
async def download_discord_files(
    session: HTTPSession, attachments: List[discord.Attachment]
):
    """Download Discord attachments, persist them in ``session`` and wrap them as Elements.

    Downloads run concurrently. Attachments whose download produced no bytes
    are skipped; a missing content type falls back to ``application/octet-stream``.
    """
    contents = await asyncio.gather(
        *(download_discord_file(attachment.url) for attachment in attachments)
    )

    persisted_refs = []
    for attachment, content in zip(attachments, contents):
        if not content:
            continue
        persisted_refs.append(
            await session.persist_file(
                name=attachment.filename,
                mime=attachment.content_type or "application/octet-stream",
                content=content,
            )
        )

    return [
        Element.from_dict(session.files[ref["id"]])
        for ref in persisted_refs
        if ref["id"] in session.files
    ]
210
+
211
+
212
def clean_content(message: discord.Message):
    """Return the message text with mentions of this bot removed, then stripped.

    Before the client has a logged-in user, the raw content is returned as-is.
    """
    bot_user = client.user
    if bot_user:
        # Matches the bot mention written with or without the "!" marker.
        pattern = f"<@!?{bot_user.id}>"
        return re.sub(pattern, "", message.content).strip()
    return message.content
220
+
221
+
222
async def process_discord_message(
    message: discord.Message,
    thread_name: str,
    channel: "MessageableChannel",
    bind_thread_to_user=False,
):
    """Run one Discord message through the Chainlit app callbacks.

    A fresh ``HTTPSession`` is created for the message. Its thread id is a
    UUID5 of the Discord channel id, so every message in a channel maps to
    the same Chainlit thread. The user message (with downloaded attachments)
    is sent, then ``on_chat_start`` / ``on_message`` / ``on_chat_end`` run,
    and the thread is updated in the data layer when one is configured. The
    session is always deleted afterwards, even if a step raises.

    Args:
        message: The incoming Discord message.
        thread_name: Name to store on the Chainlit thread.
        channel: Channel that replies are sent to (may differ from
            ``message.channel``, e.g. a newly created Discord thread).
        bind_thread_to_user: Attach the thread to the user when that user is
            a ``PersistedUser``.
    """
    user = await get_user(message.author)

    thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(channel.id)))

    text = clean_content(message)
    discord_files = message.attachments

    session_id = str(uuid.uuid4())
    session = HTTPSession(
        id=session_id,
        thread_id=thread_id,
        user=user,
        client_type="discord",
    )

    ctx = init_discord_context(
        session=session,
        channel=channel,
        message=message,
    )

    try:
        file_elements = await download_discord_files(session, discord_files)

        msg = Message(
            content=text,
            elements=file_elements,
            type="user_message",
            author=user.metadata.get("name"),
        )

        await msg.send()

        # Only app replies should reach Discord; the user's own message above
        # is sent while the emitter is still disabled.
        ctx.emitter.enabled = True

        if on_chat_start := config.code.on_chat_start:
            await on_chat_start()

        if on_message := config.code.on_message:
            async with channel.typing():
                await on_message(msg)

        if on_chat_end := config.code.on_chat_end:
            await on_chat_end()

        if data_layer := get_data_layer():
            user_id = None
            if isinstance(user, PersistedUser):
                user_id = user.id if bind_thread_to_user else None

            try:
                await data_layer.update_thread(
                    thread_id=thread_id,
                    name=thread_name,
                    metadata=ctx.session.to_persistable(),
                    user_id=user_id,
                )
            except Exception as e:
                logger.error(f"Error updating thread: {e}")
    finally:
        # Release the per-message session even when a download or an app
        # callback raises; otherwise it would never be cleaned up.
        ctx.session.delete()
288
+
289
+
290
@client.event
async def on_ready():
    """Log which account the bot is logged in as once discord.py reports ready."""
    me = client.user
    logger.info(f"Logged in as {me}")
293
+
294
+
295
@client.event
async def on_message(message: discord.Message):
    """Route an incoming Discord message to Chainlit.

    Ignores the bot's own messages and, outside DMs, messages that do not
    mention the bot. In a plain text channel a new Discord thread is started
    from the message and replies go there. Other supported channel types
    reuse their own name as the Chainlit thread name. DM threads are bound
    to the user.
    """
    if not client.user or message.author == client.user:
        return

    is_dm = isinstance(message.channel, discord.DMChannel)
    if not client.user.mentioned_in(message) and not is_dm:
        return

    thread_name: str = ""
    bind_thread_to_user = False
    channel = message.channel

    if isinstance(message.channel, discord.Thread):
        thread_name = f"{message.channel.name}"
    elif isinstance(message.channel, discord.ForumChannel):
        thread_name = f"{message.channel.name}"
    elif isinstance(message.channel, discord.DMChannel):
        thread_name = f"{message.author} Discord DM"
        bind_thread_to_user = True
    elif isinstance(message.channel, discord.GroupChannel):
        thread_name = f"{message.channel.name}"
    elif isinstance(message.channel, discord.TextChannel):
        # Discord thread names must be 1-100 characters. A message that only
        # mentions the bot (e.g. with just attachments) is empty once cleaned,
        # and long prompts exceed the limit.
        thread_title = clean_content(message)[:100] or "Untitled"
        channel = await message.channel.create_thread(
            name=thread_title, message=message
        )
        thread_name = f"{channel.name}"
    else:
        logger.warning(f"Unsupported channel type: {message.channel.type}")
        return

    await process_discord_message(
        message=message,
        thread_name=thread_name,
        channel=channel,
        bind_thread_to_user=bind_thread_to_user,
    )
chainlit/element.py CHANGED
@@ -5,6 +5,7 @@ from io import BytesIO
5
5
  from typing import Any, ClassVar, List, Literal, Optional, TypedDict, TypeVar, Union
6
6
 
7
7
  import filetype
8
+ import mimetypes
8
9
  from chainlit.context import context
9
10
  from chainlit.data import get_data_layer
10
11
  from chainlit.logger import logger
@@ -165,14 +166,16 @@ class Element:
165
166
  if self.type in mime_types
166
167
  else filetype.guess_mime(self.path or self.content)
167
168
  )
168
-
169
+ if not self.mime and self.url:
170
+ self.mime = mimetypes.guess_type(self.url)[0]
171
+
169
172
  await self._create()
170
173
 
171
174
  if not self.url and not self.chainlit_key:
172
175
  raise ValueError("Must provide url or chainlit key to send element")
173
176
 
174
177
  trace_event(f"send {self.__class__.__name__}")
175
- await context.emitter.emit("element", self.to_dict())
178
+ await context.emitter.send_element(self.to_dict())
176
179
 
177
180
 
178
181
  ElementBased = TypeVar("ElementBased", bound=Element)
chainlit/emitter.py CHANGED
@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast
4
4
 
5
5
  from chainlit.config import config
6
6
  from chainlit.data import get_data_layer
7
- from chainlit.element import Element, File
7
+ from chainlit.element import Element, ElementDict, File
8
8
  from chainlit.logger import logger
9
9
  from chainlit.message import Message
10
- from chainlit.session import BaseSession, WebsocketSession
10
+ from chainlit.session import BaseSession, HTTPSession, WebsocketSession
11
11
  from chainlit.step import StepDict
12
12
  from chainlit.types import (
13
13
  AskActionResponse,
@@ -29,6 +29,7 @@ class BaseChainlitEmitter:
29
29
  """
30
30
 
31
31
  session: BaseSession
32
+ enabled: bool = True
32
33
 
33
34
  def __init__(self, session: BaseSession) -> None:
34
35
  """Initialize with the user session."""
@@ -46,6 +47,10 @@ class BaseChainlitEmitter:
46
47
  """Stub method to resume a thread."""
47
48
  pass
48
49
 
50
+ async def send_element(self, element_dict: ElementDict):
51
+ """Stub method to send an element to the UI."""
52
+ pass
53
+
49
54
  async def send_step(self, step_dict: StepDict):
50
55
  """Stub method to send a message to the UI."""
51
56
  pass
@@ -151,6 +156,10 @@ class ChainlitEmitter(BaseChainlitEmitter):
151
156
  """Send a thread to the UI to resume it"""
152
157
  return self.emit("resume_thread", thread_dict)
153
158
 
159
+ async def send_element(self, element_dict: ElementDict):
160
+ """Stub method to send an element to the UI."""
161
+ await self.emit("element", element_dict)
162
+
154
163
  def send_step(self, step_dict: StepDict):
155
164
  """Send a message to the UI."""
156
165
  return self.emit("new_message", step_dict)