chainlit 1.0.506__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; see the registry's advisory page for this release for more details.

chainlit/data/__init__.py CHANGED
@@ -156,6 +156,7 @@ class ChainlitDataLayer(BaseDataLayer):
156
156
  "chainlitKey": None,
157
157
  "display": metadata.get("display", "side"),
158
158
  "language": metadata.get("language"),
159
+ "autoPlay": metadata.get("autoPlay", None),
159
160
  "page": metadata.get("page"),
160
161
  "size": metadata.get("size"),
161
162
  "type": metadata.get("type", "file"),
@@ -219,7 +220,7 @@ class ChainlitDataLayer(BaseDataLayer):
219
220
  "disableFeedback": metadata.get("disableFeedback", False),
220
221
  "indent": metadata.get("indent"),
221
222
  "language": metadata.get("language"),
222
- "isError": metadata.get("isError", False),
223
+ "isError": bool(step.error),
223
224
  "waitForAnswer": metadata.get("waitForAnswer", False),
224
225
  }
225
226
 
@@ -348,7 +349,6 @@ class ChainlitDataLayer(BaseDataLayer):
348
349
  step_dict.get("metadata", {}),
349
350
  **{
350
351
  "disableFeedback": step_dict.get("disableFeedback"),
351
- "isError": step_dict.get("isError"),
352
352
  "waitForAnswer": step_dict.get("waitForAnswer"),
353
353
  "language": step_dict.get("language"),
354
354
  "showInput": step_dict.get("showInput"),
@@ -372,6 +372,8 @@ class ChainlitDataLayer(BaseDataLayer):
372
372
  step["input"] = {"content": step_dict.get("input")}
373
373
  if step_dict.get("output"):
374
374
  step["output"] = {"content": step_dict.get("output")}
375
+ if step_dict.get("isError"):
376
+ step["error"] = step_dict.get("output")
375
377
 
376
378
  await self.client.api.send_steps([step])
377
379
 
chainlit/data/acl.py CHANGED
@@ -5,9 +5,12 @@ from fastapi import HTTPException
5
5
  async def is_thread_author(username: str, thread_id: str):
6
6
  data_layer = get_data_layer()
7
7
  if not data_layer:
8
- raise HTTPException(status_code=401, detail="Unauthorized")
8
+ raise HTTPException(status_code=400, detail="Data layer not initialized")
9
9
 
10
10
  thread_author = await data_layer.get_thread_author(thread_id)
11
+
12
+ if not thread_author:
13
+ raise HTTPException(status_code=404, detail="Thread not found")
11
14
 
12
15
  if thread_author != username:
13
16
  raise HTTPException(status_code=401, detail="Unauthorized")
chainlit/data/sql_alchemy.py CHANGED
@@ -39,9 +39,11 @@ class SQLAlchemyDataLayer(BaseDataLayer):
39
39
  ssl_require: bool = False,
40
40
  storage_provider: Optional[BaseStorageClient] = None,
41
41
  user_thread_limit: Optional[int] = 1000,
42
+ show_logger: Optional[bool] = False,
42
43
  ):
43
44
  self._conninfo = conninfo
44
45
  self.user_thread_limit = user_thread_limit
46
+ self.show_logger = show_logger
45
47
  ssl_args = {}
46
48
  if ssl_require:
47
49
  # Create an SSL context to require an SSL connection
@@ -55,7 +57,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
55
57
  self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
56
58
  if storage_provider:
57
59
  self.storage_provider: Optional[BaseStorageClient] = storage_provider
58
- logger.info("SQLAlchemyDataLayer storage client initialized")
60
+ if self.show_logger: logger.info("SQLAlchemyDataLayer storage client initialized")
59
61
  else:
60
62
  self.storage_provider = None
61
63
  logger.warn(
@@ -102,7 +104,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
102
104
 
103
105
  ###### User ######
104
106
  async def get_user(self, identifier: str) -> Optional[PersistedUser]:
105
- logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
107
+ if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
106
108
  query = "SELECT * FROM users WHERE identifier = :identifier"
107
109
  parameters = {"identifier": identifier}
108
110
  result = await self.execute_sql(query=query, parameters=parameters)
@@ -112,20 +114,20 @@ class SQLAlchemyDataLayer(BaseDataLayer):
112
114
  return None
113
115
 
114
116
  async def create_user(self, user: User) -> Optional[PersistedUser]:
115
- logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
117
+ if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
116
118
  existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
117
119
  user_dict: Dict[str, Any] = {
118
120
  "identifier": str(user.identifier),
119
121
  "metadata": json.dumps(user.metadata) or {},
120
122
  }
121
123
  if not existing_user: # create the user
122
- logger.info("SQLAlchemy: create_user, creating the user")
124
+ if self.show_logger: logger.info("SQLAlchemy: create_user, creating the user")
123
125
  user_dict["id"] = str(uuid.uuid4())
124
126
  user_dict["createdAt"] = await self.get_current_timestamp()
125
127
  query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
126
128
  await self.execute_sql(query=query, parameters=user_dict)
127
129
  else: # update the user
128
- logger.info("SQLAlchemy: update user metadata")
130
+ if self.show_logger: logger.info("SQLAlchemy: update user metadata")
129
131
  query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
130
132
  await self.execute_sql(
131
133
  query=query, parameters=user_dict
@@ -134,19 +136,18 @@ class SQLAlchemyDataLayer(BaseDataLayer):
134
136
 
135
137
  ###### Threads ######
136
138
  async def get_thread_author(self, thread_id: str) -> str:
137
- logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
139
+ if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
138
140
  query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
139
141
  parameters = {"id": thread_id}
140
142
  result = await self.execute_sql(query=query, parameters=parameters)
141
143
  if isinstance(result, list) and result[0]:
142
144
  author_identifier = result[0].get("userIdentifier")
143
145
  if author_identifier is not None:
144
- print(f"Author found: {author_identifier}")
145
146
  return author_identifier
146
147
  raise ValueError(f"Author not found for thread_id {thread_id}")
147
148
 
148
149
  async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
149
- logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
150
+ if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
150
151
  user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
151
152
  thread_id=thread_id
152
153
  )
@@ -163,19 +164,21 @@ class SQLAlchemyDataLayer(BaseDataLayer):
163
164
  metadata: Optional[Dict] = None,
164
165
  tags: Optional[List[str]] = None,
165
166
  ):
166
- logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
167
+ if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
167
168
  if context.session.user is not None:
168
169
  user_identifier = context.session.user.identifier
169
170
  else:
170
171
  raise ValueError("User not found in session context")
171
172
  data = {
172
173
  "id": thread_id,
173
- "createdAt": await self.get_current_timestamp()
174
- if metadata is None
175
- else None,
176
- "name": name
177
- if name is not None
178
- else (metadata.get("name") if metadata and "name" in metadata else None),
174
+ "createdAt": (
175
+ await self.get_current_timestamp() if metadata is None else None
176
+ ),
177
+ "name": (
178
+ name
179
+ if name is not None
180
+ else (metadata.get("name") if metadata and "name" in metadata else None)
181
+ ),
179
182
  "userId": user_id,
180
183
  "userIdentifier": user_identifier,
181
184
  "tags": tags,
@@ -198,7 +201,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
198
201
  await self.execute_sql(query=query, parameters=parameters)
199
202
 
200
203
  async def delete_thread(self, thread_id: str):
201
- logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
204
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
202
205
  # Delete feedbacks/elements/steps/thread
203
206
  feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
204
207
  elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
@@ -213,7 +216,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
213
216
  async def list_threads(
214
217
  self, pagination: Pagination, filters: ThreadFilter
215
218
  ) -> PaginatedResponse:
216
- logger.info(
219
+ if self.show_logger: logger.info(
217
220
  f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
218
221
  )
219
222
  if not filters.userId:
@@ -273,7 +276,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
273
276
  ###### Steps ######
274
277
  @queue_until_user_message()
275
278
  async def create_step(self, step_dict: "StepDict"):
276
- logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
279
+ if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
277
280
  if not getattr(context.session.user, "id", None):
278
281
  raise ValueError("No authenticated user in context")
279
282
  step_dict["showInput"] = (
@@ -303,12 +306,12 @@ class SQLAlchemyDataLayer(BaseDataLayer):
303
306
 
304
307
  @queue_until_user_message()
305
308
  async def update_step(self, step_dict: "StepDict"):
306
- logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
309
+ if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
307
310
  await self.create_step(step_dict)
308
311
 
309
312
  @queue_until_user_message()
310
313
  async def delete_step(self, step_id: str):
311
- logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
314
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
312
315
  # Delete feedbacks/elements/steps
313
316
  feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
314
317
  elements_query = """DELETE FROM elements WHERE "forId" = :id"""
@@ -320,7 +323,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
320
323
 
321
324
  ###### Feedback ######
322
325
  async def upsert_feedback(self, feedback: Feedback) -> str:
323
- logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
326
+ if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
324
327
  feedback.id = feedback.id or str(uuid.uuid4())
325
328
  feedback_dict = asdict(feedback)
326
329
  parameters = {
@@ -342,7 +345,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
342
345
  return feedback.id
343
346
 
344
347
  async def delete_feedback(self, feedback_id: str) -> bool:
345
- logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
348
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
346
349
  query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
347
350
  parameters = {"feedback_id": feedback_id}
348
351
  await self.execute_sql(query=query, parameters=parameters)
@@ -351,7 +354,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
351
354
  ###### Elements ######
352
355
  @queue_until_user_message()
353
356
  async def create_element(self, element: "Element"):
354
- logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
357
+ if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
355
358
  if not getattr(context.session.user, "id", None):
356
359
  raise ValueError("No authenticated user in context")
357
360
  if isinstance(element, Avatar): # Skip creating elements of type avatar
@@ -414,7 +417,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
414
417
 
415
418
  @queue_until_user_message()
416
419
  async def delete_element(self, element_id: str):
417
- logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
420
+ if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
418
421
  query = """DELETE FROM elements WHERE "id" = :id"""
419
422
  parameters = {"id": element_id}
420
423
  await self.execute_sql(query=query, parameters=parameters)
@@ -426,7 +429,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
426
429
  self, user_id: Optional[str] = None, thread_id: Optional[str] = None
427
430
  ) -> Optional[List[ThreadDict]]:
428
431
  """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
429
- logger.info(f"SQLAlchemy: get_all_user_threads")
432
+ if self.show_logger: logger.info(f"SQLAlchemy: get_all_user_threads")
430
433
  user_threads_query = """
431
434
  SELECT
432
435
  "id" AS thread_id,
@@ -552,13 +555,17 @@ class SQLAlchemyDataLayer(BaseDataLayer):
552
555
  streaming=step_feedback.get("step_streaming", False),
553
556
  waitForAnswer=step_feedback.get("step_waitforanswer"),
554
557
  isError=step_feedback.get("step_iserror"),
555
- metadata=step_feedback["step_metadata"]
556
- if step_feedback.get("step_metadata") is not None
557
- else {},
558
+ metadata=(
559
+ step_feedback["step_metadata"]
560
+ if step_feedback.get("step_metadata") is not None
561
+ else {}
562
+ ),
558
563
  tags=step_feedback.get("step_tags"),
559
- input=step_feedback.get("step_input", "")
560
- if step_feedback["step_showinput"]
561
- else "",
564
+ input=(
565
+ step_feedback.get("step_input", "")
566
+ if step_feedback["step_showinput"] == "true"
567
+ else None
568
+ ),
562
569
  output=step_feedback.get("step_output", ""),
563
570
  createdAt=step_feedback.get("step_createdat"),
564
571
  start=step_feedback.get("step_start"),
@@ -587,6 +594,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
587
594
  display=element["element_display"],
588
595
  size=element.get("element_size"),
589
596
  language=element.get("element_language"),
597
+ autoPlay=element.get("element_autoPlay"),
590
598
  page=element.get("element_page"),
591
599
  forId=element.get("element_forid"),
592
600
  mime=element.get("element_mime"),
chainlit/discord/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ try:
2
+ import discord
3
+ except ModuleNotFoundError:
4
+ raise ValueError(
5
+ "The discord package is required to integrate Chainlit with a Slack app. Run `pip install discord --upgrade`"
6
+ )
chainlit/discord/app.py ADDED
@@ -0,0 +1,304 @@
1
+ import asyncio
2
+ import mimetypes
3
+ import re
4
+ import uuid
5
+ from io import BytesIO
6
+ from typing import Dict, List, Optional, Union, TYPE_CHECKING
7
+
8
+ if TYPE_CHECKING:
9
+ from discord.abc import MessageableChannel
10
+
11
+ import discord
12
+
13
+ import filetype
14
+ import httpx
15
+ from chainlit.config import config
16
+ from chainlit.context import ChainlitContext, HTTPSession, context_var
17
+ from chainlit.data import get_data_layer
18
+ from chainlit.element import Element, ElementDict
19
+ from chainlit.emitter import BaseChainlitEmitter
20
+ from chainlit.logger import logger
21
+ from chainlit.message import Message, StepDict
22
+ from chainlit.types import Feedback
23
+ from chainlit.user import PersistedUser, User
24
+ from chainlit.user_session import user_session
25
+ from discord.ui import Button, View
26
+
27
+
28
+ class FeedbackView(View):
29
+ def __init__(self, step_id: str):
30
+ super().__init__(timeout=None)
31
+ self.step_id = step_id
32
+
33
+ @discord.ui.button(label="👎")
34
+ async def thumbs_down(self, interaction: discord.Interaction, button: Button):
35
+ if data_layer := get_data_layer():
36
+ await data_layer.upsert_feedback(Feedback(forId=self.step_id, value=0))
37
+ if interaction.message:
38
+ await interaction.message.edit(view=None)
39
+ await interaction.message.add_reaction("👎")
40
+
41
+ @discord.ui.button(label="👍")
42
+ async def thumbs_up(self, interaction: discord.Interaction, button: Button):
43
+ if data_layer := get_data_layer():
44
+ await data_layer.upsert_feedback(Feedback(forId=self.step_id, value=1))
45
+ if interaction.message:
46
+ await interaction.message.edit(view=None)
47
+ await interaction.message.add_reaction("👍")
48
+
49
+
50
+ class DiscordEmitter(BaseChainlitEmitter):
51
+ def __init__(
52
+ self, session: HTTPSession, channel: "MessageableChannel", enabled=False
53
+ ):
54
+ super().__init__(session)
55
+ self.channel = channel
56
+ self.enabled = enabled
57
+
58
+ async def send_element(self, element_dict: ElementDict):
59
+ if not self.enabled or element_dict.get("display") != "inline":
60
+ return
61
+
62
+ persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
63
+ file: Optional[Union[BytesIO, str]] = None
64
+ mime: Optional[str] = None
65
+
66
+ if persisted_file:
67
+ file = str(persisted_file["path"])
68
+ mime = element_dict.get("mime")
69
+ elif file_url := element_dict.get("url"):
70
+ async with httpx.AsyncClient() as client:
71
+ response = await client.get(file_url)
72
+ if response.status_code == 200:
73
+ file = BytesIO(response.content)
74
+ mime = filetype.guess_mime(file)
75
+
76
+ if not file:
77
+ return
78
+
79
+ element_name: str = element_dict.get("name", "Untitled")
80
+
81
+ if mime:
82
+ file_extension = mimetypes.guess_extension(mime)
83
+ if file_extension:
84
+ element_name += file_extension
85
+
86
+ file_obj = discord.File(file, filename=element_name)
87
+ await self.channel.send(file=file_obj)
88
+
89
+ async def send_step(self, step_dict: StepDict):
90
+ if not self.enabled:
91
+ return
92
+
93
+ is_chain_of_thought = bool(step_dict.get("parentId"))
94
+ is_empty_output = not step_dict.get("output")
95
+
96
+ if is_chain_of_thought or is_empty_output:
97
+ return
98
+ else:
99
+ enable_feedback = not step_dict.get("disableFeedback") and get_data_layer()
100
+ message = await self.channel.send(step_dict["output"])
101
+
102
+ if enable_feedback:
103
+ view = FeedbackView(step_dict.get("id", ""))
104
+ await message.edit(view=view)
105
+
106
+ async def update_step(self, step_dict: StepDict):
107
+ if not self.enabled:
108
+ return
109
+
110
+ await self.send_step(step_dict)
111
+
112
+
113
+ intents = discord.Intents.default()
114
+ intents.message_content = True
115
+
116
+ client = discord.Client(intents=intents)
117
+
118
+
119
+ def init_discord_context(
120
+ session: HTTPSession,
121
+ channel: "MessageableChannel",
122
+ message: discord.Message,
123
+ ) -> ChainlitContext:
124
+ emitter = DiscordEmitter(session=session, channel=channel)
125
+ context = ChainlitContext(session=session, emitter=emitter)
126
+ context_var.set(context)
127
+ user_session.set("discord_message", message)
128
+ user_session.set("discord_channel", channel)
129
+ return context
130
+
131
+
132
+ users_by_discord_id: Dict[int, Union[User, PersistedUser]] = {}
133
+
134
+ USER_PREFIX = "discord_"
135
+
136
+
137
+ async def get_user(discord_user: Union[discord.User, discord.Member]):
138
+ metadata = {
139
+ "name": discord_user.name,
140
+ "id": discord_user.id,
141
+ }
142
+ user = User(identifier=USER_PREFIX + str(discord_user.name), metadata=metadata)
143
+
144
+ users_by_discord_id[discord_user.id] = user
145
+
146
+ if data_layer := get_data_layer():
147
+ persisted_user = await data_layer.create_user(user)
148
+ if persisted_user:
149
+ users_by_discord_id[discord_user.id] = persisted_user
150
+
151
+ return users_by_discord_id[discord_user.id]
152
+
153
+
154
+ async def download_discord_file(url: str):
155
+ async with httpx.AsyncClient() as client:
156
+ response = await client.get(url)
157
+ if response.status_code == 200:
158
+ return response.content
159
+ else:
160
+ return None
161
+
162
+
163
+ async def download_discord_files(
164
+ session: HTTPSession, attachments: List[discord.Attachment]
165
+ ):
166
+ download_coros = [
167
+ download_discord_file(attachment.url) for attachment in attachments
168
+ ]
169
+ file_bytes_list = await asyncio.gather(*download_coros)
170
+ file_refs = []
171
+ for idx, file_bytes in enumerate(file_bytes_list):
172
+ if file_bytes:
173
+ name = attachments[idx].filename
174
+ mime_type = attachments[idx].content_type or "application/octet-stream"
175
+ file_ref = await session.persist_file(
176
+ name=name, mime=mime_type, content=file_bytes
177
+ )
178
+ file_refs.append(file_ref)
179
+
180
+ files_dicts = [
181
+ session.files[file["id"]] for file in file_refs if file["id"] in session.files
182
+ ]
183
+
184
+ file_elements = [Element.from_dict(file_dict) for file_dict in files_dicts]
185
+
186
+ return file_elements
187
+
188
+
189
+ def clean_content(message: discord.Message):
190
+ if not client.user:
191
+ return message.content
192
+
193
+ # Regex to find mentions of the bot
194
+ bot_mention = f"<@!?{client.user.id}>"
195
+ # Replace the bot's mention with nothing
196
+ return re.sub(bot_mention, "", message.content).strip()
197
+
198
+
199
+ async def process_discord_message(
200
+ message: discord.Message,
201
+ thread_name: str,
202
+ channel: "MessageableChannel",
203
+ bind_thread_to_user=False,
204
+ ):
205
+ user = await get_user(message.author)
206
+
207
+ thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(channel.id)))
208
+
209
+ text = clean_content(message)
210
+ discord_files = message.attachments
211
+
212
+ session_id = str(uuid.uuid4())
213
+ session = HTTPSession(
214
+ id=session_id,
215
+ thread_id=thread_id,
216
+ user=user,
217
+ client_type="discord",
218
+ )
219
+
220
+ ctx = init_discord_context(
221
+ session=session,
222
+ channel=channel,
223
+ message=message,
224
+ )
225
+
226
+ file_elements = await download_discord_files(session, discord_files)
227
+
228
+ msg = Message(
229
+ content=text,
230
+ elements=file_elements,
231
+ type="user_message",
232
+ author=user.metadata.get("name"),
233
+ )
234
+
235
+ await msg.send()
236
+
237
+ ctx.emitter.enabled = True
238
+
239
+ if on_chat_start := config.code.on_chat_start:
240
+ await on_chat_start()
241
+
242
+ if on_message := config.code.on_message:
243
+ await on_message(msg)
244
+
245
+ if on_chat_end := config.code.on_chat_end:
246
+ await on_chat_end()
247
+
248
+ if data_layer := get_data_layer():
249
+ user_id = None
250
+ if isinstance(user, PersistedUser):
251
+ user_id = user.id if bind_thread_to_user else None
252
+
253
+ await data_layer.update_thread(
254
+ thread_id=thread_id,
255
+ name=thread_name,
256
+ metadata=ctx.session.to_persistable(),
257
+ user_id=user_id,
258
+ )
259
+
260
+ ctx.session.delete()
261
+
262
+
263
+ @client.event
264
+ async def on_ready():
265
+ logger.info(f"Logged in as {client.user}")
266
+
267
+
268
+ @client.event
269
+ async def on_message(message: discord.Message):
270
+ if not client.user or message.author == client.user:
271
+ return
272
+
273
+ is_dm = isinstance(message.channel, discord.DMChannel)
274
+ if not client.user.mentioned_in(message) and not is_dm:
275
+ return
276
+
277
+ thread_name: str = ""
278
+ bind_thread_to_user = False
279
+ channel = message.channel
280
+
281
+ if isinstance(message.channel, discord.Thread):
282
+ thread_name = f"{message.channel.name}"
283
+ elif isinstance(message.channel, discord.ForumChannel):
284
+ thread_name = f"{message.channel.name}"
285
+ elif isinstance(message.channel, discord.DMChannel):
286
+ thread_name = f"{message.author} Discord DM"
287
+ bind_thread_to_user = True
288
+ elif isinstance(message.channel, discord.GroupChannel):
289
+ thread_name = f"{message.channel.name}"
290
+ elif isinstance(message.channel, discord.TextChannel):
291
+ channel = await message.channel.create_thread(
292
+ name=clean_content(message), message=message
293
+ )
294
+ thread_name = f"{channel.name}"
295
+ else:
296
+ logger.warning(f"Unsupported channel type: {message.channel.type}")
297
+ return
298
+
299
+ await process_discord_message(
300
+ message=message,
301
+ thread_name=thread_name,
302
+ channel=channel,
303
+ bind_thread_to_user=bind_thread_to_user,
304
+ )
chainlit/element.py CHANGED
@@ -5,6 +5,7 @@ from io import BytesIO
5
5
  from typing import Any, ClassVar, List, Literal, Optional, TypedDict, TypeVar, Union
6
6
 
7
7
  import filetype
8
+ import mimetypes
8
9
  from chainlit.context import context
9
10
  from chainlit.data import get_data_layer
10
11
  from chainlit.logger import logger
@@ -38,6 +39,7 @@ class ElementDict(TypedDict):
38
39
  size: Optional[ElementSize]
39
40
  language: Optional[str]
40
41
  page: Optional[int]
42
+ autoPlay: Optional[bool]
41
43
  forId: Optional[str]
42
44
  mime: Optional[str]
43
45
 
@@ -61,7 +63,7 @@ class Element:
61
63
  # The byte content of the element.
62
64
  content: Optional[Union[bytes, str]] = None
63
65
  # Controls how the image element should be displayed in the UI. Choices are “side” (default), “inline”, or “page”.
64
- display: ElementDisplay = Field(default="side")
66
+ display: ElementDisplay = Field(default="inline")
65
67
  # Controls element size
66
68
  size: Optional[ElementSize] = None
67
69
  # The ID of the message this element is associated with.
@@ -93,6 +95,7 @@ class Element:
93
95
  "objectKey": getattr(self, "object_key", None),
94
96
  "size": getattr(self, "size", None),
95
97
  "page": getattr(self, "page", None),
98
+ "autoPlay": getattr(self, "auto_play", None),
96
99
  "language": getattr(self, "language", None),
97
100
  "forId": getattr(self, "for_id", None),
98
101
  "mime": getattr(self, "mime", None),
@@ -163,14 +166,16 @@ class Element:
163
166
  if self.type in mime_types
164
167
  else filetype.guess_mime(self.path or self.content)
165
168
  )
166
-
169
+ if not self.mime and self.url:
170
+ self.mime = mimetypes.guess_type(self.url)[0]
171
+
167
172
  await self._create()
168
173
 
169
174
  if not self.url and not self.chainlit_key:
170
175
  raise ValueError("Must provide url or chainlit key to send element")
171
176
 
172
177
  trace_event(f"send {self.__class__.__name__}")
173
- await context.emitter.emit("element", self.to_dict())
178
+ await context.emitter.send_element(self.to_dict())
174
179
 
175
180
 
176
181
  ElementBased = TypeVar("ElementBased", bound=Element)
@@ -306,6 +311,7 @@ class TaskList(Element):
306
311
  @dataclass
307
312
  class Audio(Element):
308
313
  type: ClassVar[ElementType] = "audio"
314
+ auto_play: bool = False
309
315
 
310
316
 
311
317
  @dataclass
chainlit/emitter.py CHANGED
@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast
4
4
 
5
5
  from chainlit.config import config
6
6
  from chainlit.data import get_data_layer
7
- from chainlit.element import Element, File
7
+ from chainlit.element import Element, ElementDict, File
8
8
  from chainlit.logger import logger
9
9
  from chainlit.message import Message
10
- from chainlit.session import BaseSession, WebsocketSession
10
+ from chainlit.session import BaseSession, HTTPSession, WebsocketSession
11
11
  from chainlit.step import StepDict
12
12
  from chainlit.types import (
13
13
  AskActionResponse,
@@ -29,6 +29,7 @@ class BaseChainlitEmitter:
29
29
  """
30
30
 
31
31
  session: BaseSession
32
+ enabled: bool = True
32
33
 
33
34
  def __init__(self, session: BaseSession) -> None:
34
35
  """Initialize with the user session."""
@@ -46,6 +47,10 @@ class BaseChainlitEmitter:
46
47
  """Stub method to resume a thread."""
47
48
  pass
48
49
 
50
+ async def send_element(self, element_dict: ElementDict):
51
+ """Stub method to send an element to the UI."""
52
+ pass
53
+
49
54
  async def send_step(self, step_dict: StepDict):
50
55
  """Stub method to send a message to the UI."""
51
56
  pass
@@ -151,6 +156,10 @@ class ChainlitEmitter(BaseChainlitEmitter):
151
156
  """Send a thread to the UI to resume it"""
152
157
  return self.emit("resume_thread", thread_dict)
153
158
 
159
+ async def send_element(self, element_dict: ElementDict):
160
+ """Stub method to send an element to the UI."""
161
+ await self.emit("element", element_dict)
162
+
154
163
  def send_step(self, step_dict: StepDict):
155
164
  """Send a message to the UI."""
156
165
  return self.emit("new_message", step_dict)