chainlit 1.1.0rc1__py3-none-any.whl → 1.1.101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit has been flagged as potentially problematic; see the package registry page for more details.
- chainlit/config.py +4 -2
- chainlit/context.py +19 -7
- chainlit/copilot/dist/index.js +639 -521
- chainlit/data/acl.py +4 -1
- chainlit/data/sql_alchemy.py +22 -21
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +322 -0
- chainlit/element.py +5 -2
- chainlit/emitter.py +11 -2
- chainlit/frontend/dist/assets/{index-032fca02.js → index-37c9a5a9.js} +120 -120
- chainlit/frontend/dist/assets/react-plotly-c55d0c95.js +3602 -0
- chainlit/frontend/dist/index.html +1 -1
- chainlit/server.py +27 -4
- chainlit/session.py +72 -61
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +379 -0
- chainlit/socket.py +21 -15
- chainlit/step.py +25 -1
- chainlit/types.py +2 -1
- chainlit/user_session.py +5 -2
- {chainlit-1.1.0rc1.dist-info → chainlit-1.1.101.dist-info}/METADATA +4 -3
- {chainlit-1.1.0rc1.dist-info → chainlit-1.1.101.dist-info}/RECORD +24 -20
- chainlit/frontend/dist/assets/react-plotly-8c993614.js +0 -3484
- {chainlit-1.1.0rc1.dist-info → chainlit-1.1.101.dist-info}/WHEEL +0 -0
- {chainlit-1.1.0rc1.dist-info → chainlit-1.1.101.dist-info}/entry_points.txt +0 -0
chainlit/data/acl.py
CHANGED
|
@@ -5,9 +5,12 @@ from fastapi import HTTPException
|
|
|
5
5
|
async def is_thread_author(username: str, thread_id: str):
|
|
6
6
|
data_layer = get_data_layer()
|
|
7
7
|
if not data_layer:
|
|
8
|
-
raise HTTPException(status_code=
|
|
8
|
+
raise HTTPException(status_code=400, detail="Data layer not initialized")
|
|
9
9
|
|
|
10
10
|
thread_author = await data_layer.get_thread_author(thread_id)
|
|
11
|
+
|
|
12
|
+
if not thread_author:
|
|
13
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
11
14
|
|
|
12
15
|
if thread_author != username:
|
|
13
16
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
chainlit/data/sql_alchemy.py
CHANGED
|
@@ -39,9 +39,11 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
39
39
|
ssl_require: bool = False,
|
|
40
40
|
storage_provider: Optional[BaseStorageClient] = None,
|
|
41
41
|
user_thread_limit: Optional[int] = 1000,
|
|
42
|
+
show_logger: Optional[bool] = False,
|
|
42
43
|
):
|
|
43
44
|
self._conninfo = conninfo
|
|
44
45
|
self.user_thread_limit = user_thread_limit
|
|
46
|
+
self.show_logger = show_logger
|
|
45
47
|
ssl_args = {}
|
|
46
48
|
if ssl_require:
|
|
47
49
|
# Create an SSL context to require an SSL connection
|
|
@@ -55,7 +57,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
55
57
|
self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
|
|
56
58
|
if storage_provider:
|
|
57
59
|
self.storage_provider: Optional[BaseStorageClient] = storage_provider
|
|
58
|
-
logger.info("SQLAlchemyDataLayer storage client initialized")
|
|
60
|
+
if self.show_logger: logger.info("SQLAlchemyDataLayer storage client initialized")
|
|
59
61
|
else:
|
|
60
62
|
self.storage_provider = None
|
|
61
63
|
logger.warn(
|
|
@@ -102,7 +104,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
102
104
|
|
|
103
105
|
###### User ######
|
|
104
106
|
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
|
|
105
|
-
logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
|
|
107
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
|
|
106
108
|
query = "SELECT * FROM users WHERE identifier = :identifier"
|
|
107
109
|
parameters = {"identifier": identifier}
|
|
108
110
|
result = await self.execute_sql(query=query, parameters=parameters)
|
|
@@ -112,20 +114,20 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
112
114
|
return None
|
|
113
115
|
|
|
114
116
|
async def create_user(self, user: User) -> Optional[PersistedUser]:
|
|
115
|
-
logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
|
|
117
|
+
if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
|
|
116
118
|
existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
|
|
117
119
|
user_dict: Dict[str, Any] = {
|
|
118
120
|
"identifier": str(user.identifier),
|
|
119
121
|
"metadata": json.dumps(user.metadata) or {},
|
|
120
122
|
}
|
|
121
123
|
if not existing_user: # create the user
|
|
122
|
-
logger.info("SQLAlchemy: create_user, creating the user")
|
|
124
|
+
if self.show_logger: logger.info("SQLAlchemy: create_user, creating the user")
|
|
123
125
|
user_dict["id"] = str(uuid.uuid4())
|
|
124
126
|
user_dict["createdAt"] = await self.get_current_timestamp()
|
|
125
127
|
query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
|
|
126
128
|
await self.execute_sql(query=query, parameters=user_dict)
|
|
127
129
|
else: # update the user
|
|
128
|
-
logger.info("SQLAlchemy: update user metadata")
|
|
130
|
+
if self.show_logger: logger.info("SQLAlchemy: update user metadata")
|
|
129
131
|
query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
|
|
130
132
|
await self.execute_sql(
|
|
131
133
|
query=query, parameters=user_dict
|
|
@@ -134,19 +136,18 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
134
136
|
|
|
135
137
|
###### Threads ######
|
|
136
138
|
async def get_thread_author(self, thread_id: str) -> str:
|
|
137
|
-
logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
|
|
139
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
|
|
138
140
|
query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
|
|
139
141
|
parameters = {"id": thread_id}
|
|
140
142
|
result = await self.execute_sql(query=query, parameters=parameters)
|
|
141
143
|
if isinstance(result, list) and result[0]:
|
|
142
144
|
author_identifier = result[0].get("userIdentifier")
|
|
143
145
|
if author_identifier is not None:
|
|
144
|
-
print(f"Author found: {author_identifier}")
|
|
145
146
|
return author_identifier
|
|
146
147
|
raise ValueError(f"Author not found for thread_id {thread_id}")
|
|
147
148
|
|
|
148
149
|
async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
|
|
149
|
-
logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
|
|
150
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
|
|
150
151
|
user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
|
|
151
152
|
thread_id=thread_id
|
|
152
153
|
)
|
|
@@ -163,7 +164,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
163
164
|
metadata: Optional[Dict] = None,
|
|
164
165
|
tags: Optional[List[str]] = None,
|
|
165
166
|
):
|
|
166
|
-
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
|
|
167
|
+
if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
|
|
167
168
|
if context.session.user is not None:
|
|
168
169
|
user_identifier = context.session.user.identifier
|
|
169
170
|
else:
|
|
@@ -200,7 +201,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
200
201
|
await self.execute_sql(query=query, parameters=parameters)
|
|
201
202
|
|
|
202
203
|
async def delete_thread(self, thread_id: str):
|
|
203
|
-
logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
|
|
204
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
|
|
204
205
|
# Delete feedbacks/elements/steps/thread
|
|
205
206
|
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
|
|
206
207
|
elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
|
|
@@ -215,7 +216,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
215
216
|
async def list_threads(
|
|
216
217
|
self, pagination: Pagination, filters: ThreadFilter
|
|
217
218
|
) -> PaginatedResponse:
|
|
218
|
-
logger.info(
|
|
219
|
+
if self.show_logger: logger.info(
|
|
219
220
|
f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
|
|
220
221
|
)
|
|
221
222
|
if not filters.userId:
|
|
@@ -275,7 +276,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
275
276
|
###### Steps ######
|
|
276
277
|
@queue_until_user_message()
|
|
277
278
|
async def create_step(self, step_dict: "StepDict"):
|
|
278
|
-
logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
|
|
279
|
+
if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
|
|
279
280
|
if not getattr(context.session.user, "id", None):
|
|
280
281
|
raise ValueError("No authenticated user in context")
|
|
281
282
|
step_dict["showInput"] = (
|
|
@@ -305,12 +306,12 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
305
306
|
|
|
306
307
|
@queue_until_user_message()
|
|
307
308
|
async def update_step(self, step_dict: "StepDict"):
|
|
308
|
-
logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
|
|
309
|
+
if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
|
|
309
310
|
await self.create_step(step_dict)
|
|
310
311
|
|
|
311
312
|
@queue_until_user_message()
|
|
312
313
|
async def delete_step(self, step_id: str):
|
|
313
|
-
logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
|
|
314
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
|
|
314
315
|
# Delete feedbacks/elements/steps
|
|
315
316
|
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
|
|
316
317
|
elements_query = """DELETE FROM elements WHERE "forId" = :id"""
|
|
@@ -322,7 +323,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
322
323
|
|
|
323
324
|
###### Feedback ######
|
|
324
325
|
async def upsert_feedback(self, feedback: Feedback) -> str:
|
|
325
|
-
logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
|
|
326
|
+
if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
|
|
326
327
|
feedback.id = feedback.id or str(uuid.uuid4())
|
|
327
328
|
feedback_dict = asdict(feedback)
|
|
328
329
|
parameters = {
|
|
@@ -344,7 +345,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
344
345
|
return feedback.id
|
|
345
346
|
|
|
346
347
|
async def delete_feedback(self, feedback_id: str) -> bool:
|
|
347
|
-
logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
|
|
348
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
|
|
348
349
|
query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
|
|
349
350
|
parameters = {"feedback_id": feedback_id}
|
|
350
351
|
await self.execute_sql(query=query, parameters=parameters)
|
|
@@ -353,7 +354,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
353
354
|
###### Elements ######
|
|
354
355
|
@queue_until_user_message()
|
|
355
356
|
async def create_element(self, element: "Element"):
|
|
356
|
-
logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
|
|
357
|
+
if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
|
|
357
358
|
if not getattr(context.session.user, "id", None):
|
|
358
359
|
raise ValueError("No authenticated user in context")
|
|
359
360
|
if isinstance(element, Avatar): # Skip creating elements of type avatar
|
|
@@ -416,7 +417,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
416
417
|
|
|
417
418
|
@queue_until_user_message()
|
|
418
419
|
async def delete_element(self, element_id: str):
|
|
419
|
-
logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
|
|
420
|
+
if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
|
|
420
421
|
query = """DELETE FROM elements WHERE "id" = :id"""
|
|
421
422
|
parameters = {"id": element_id}
|
|
422
423
|
await self.execute_sql(query=query, parameters=parameters)
|
|
@@ -428,7 +429,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
428
429
|
self, user_id: Optional[str] = None, thread_id: Optional[str] = None
|
|
429
430
|
) -> Optional[List[ThreadDict]]:
|
|
430
431
|
"""Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
|
|
431
|
-
logger.info(f"SQLAlchemy: get_all_user_threads")
|
|
432
|
+
if self.show_logger: logger.info(f"SQLAlchemy: get_all_user_threads")
|
|
432
433
|
user_threads_query = """
|
|
433
434
|
SELECT
|
|
434
435
|
"id" AS thread_id,
|
|
@@ -562,8 +563,8 @@ class SQLAlchemyDataLayer(BaseDataLayer):
|
|
|
562
563
|
tags=step_feedback.get("step_tags"),
|
|
563
564
|
input=(
|
|
564
565
|
step_feedback.get("step_input", "")
|
|
565
|
-
if step_feedback["step_showinput"]
|
|
566
|
-
else
|
|
566
|
+
if step_feedback["step_showinput"] == "true"
|
|
567
|
+
else None
|
|
567
568
|
),
|
|
568
569
|
output=step_feedback.get("step_output", ""),
|
|
569
570
|
createdAt=step_feedback.get("step_createdat"),
|
chainlit/discord/app.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import mimetypes
|
|
3
|
+
import re
|
|
4
|
+
import uuid
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from typing import Dict, List, Optional, Union, TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from discord.abc import MessageableChannel
|
|
10
|
+
|
|
11
|
+
import discord
|
|
12
|
+
|
|
13
|
+
import filetype
|
|
14
|
+
import httpx
|
|
15
|
+
from chainlit.config import config
|
|
16
|
+
from chainlit.context import ChainlitContext, HTTPSession, context_var
|
|
17
|
+
from chainlit.data import get_data_layer
|
|
18
|
+
from chainlit.element import Element, ElementDict
|
|
19
|
+
from chainlit.emitter import BaseChainlitEmitter
|
|
20
|
+
from chainlit.logger import logger
|
|
21
|
+
from chainlit.message import Message, StepDict
|
|
22
|
+
from chainlit.types import Feedback
|
|
23
|
+
from chainlit.user import PersistedUser, User
|
|
24
|
+
from chainlit.user_session import user_session
|
|
25
|
+
from chainlit.telemetry import trace
|
|
26
|
+
|
|
27
|
+
from discord.ui import Button, View
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class FeedbackView(View):
    """Thumbs-down / thumbs-up buttons attached to a bot reply.

    Clicking a button removes the buttons from the message, records the
    feedback for ``step_id`` through the data layer (when one is configured)
    and adds the matching emoji as a reaction on the message.
    """

    def __init__(self, step_id: str):
        # timeout=None: no inactivity timeout for the buttons.
        super().__init__(timeout=None)
        self.step_id = step_id

    async def _record_feedback(
        self, interaction: discord.Interaction, value: int, emoji: str
    ):
        """Acknowledge the click, persist *value* for this step, react with *emoji*."""
        # Answer the interaction first (Discord expects component interactions
        # to be answered within ~3 seconds, and the DB write below may be slow).
        # Editing through the response also removes the buttons.
        await interaction.response.edit_message(view=None)

        if data_layer := get_data_layer():
            try:
                await data_layer.upsert_feedback(
                    Feedback(forId=self.step_id, value=value)
                )
            except Exception as e:
                logger.error(f"Error upserting feedback: {e}")

        if interaction.message:
            await interaction.message.add_reaction(emoji)

    @discord.ui.button(label="👎")
    async def thumbs_down(self, interaction: discord.Interaction, button: Button):
        await self._record_feedback(interaction, 0, "👎")

    @discord.ui.button(label="👍")
    async def thumbs_up(self, interaction: discord.Interaction, button: Button):
        await self._record_feedback(interaction, 1, "👍")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class DiscordEmitter(BaseChainlitEmitter):
    """Emitter that mirrors Chainlit output into a Discord channel.

    Output is gated by ``enabled`` (``False`` unless passed explicitly): while
    disabled, ``send_element``, ``send_step`` and ``update_step`` do nothing.
    """

    # Discord rejects message content longer than this many characters.
    MAX_MESSAGE_LENGTH = 2000

    def __init__(
        self, session: HTTPSession, channel: "MessageableChannel", enabled=False
    ):
        super().__init__(session)
        self.channel = channel
        self.enabled = enabled

    @staticmethod
    def _split_output(text: str, limit: int) -> List[str]:
        """Split *text* into pieces of at most *limit* characters.

        Each piece ends at the last newline inside the window when there is
        one (so lines stay whole where possible); otherwise the text is cut at
        exactly *limit* characters. Text within the limit is returned as a
        single piece, unchanged.
        """
        chunks: List[str] = []
        while len(text) > limit:
            cut = text.rfind("\n", 0, limit)
            if cut <= 0:
                cut = limit
            chunks.append(text[:cut])
            text = text[cut:].lstrip("\n")
        if text:
            chunks.append(text)
        return chunks

    async def send_element(self, element_dict: ElementDict):
        """Upload an inline element to the channel as a file attachment.

        The file is taken from the session's persisted files when the element
        has a ``chainlitKey``; otherwise it is downloaded from the element's
        ``url``. Non-inline elements, and elements whose content cannot be
        obtained, are skipped.
        """
        if not self.enabled or element_dict.get("display") != "inline":
            return

        persisted_file = self.session.files.get(element_dict.get("chainlitKey") or "")
        file: Optional[Union[BytesIO, str]] = None
        mime: Optional[str] = None

        if persisted_file:
            file = str(persisted_file["path"])
            mime = element_dict.get("mime")
        elif file_url := element_dict.get("url"):
            async with httpx.AsyncClient() as http_client:
                response = await http_client.get(file_url)
                if response.status_code == 200:
                    file = BytesIO(response.content)
                    mime = filetype.guess_mime(file)

        if not file:
            return

        element_name: str = element_dict.get("name", "Untitled")

        # Append an extension matching the MIME type to the attachment name.
        if mime:
            file_extension = mimetypes.guess_extension(mime)
            if file_extension:
                element_name += file_extension

        file_obj = discord.File(file, filename=element_name)
        await self.channel.send(file=file_obj)

    async def send_step(self, step_dict: StepDict):
        """Post a top-level step's output to the channel.

        Child steps (those with a ``parentId``) and steps without output are
        skipped. Output longer than Discord's message limit is split across
        several messages. Feedback buttons are attached to the last message
        when feedback is not disabled for the step and a data layer exists.
        """
        if not self.enabled:
            return

        is_chain_of_thought = bool(step_dict.get("parentId"))
        is_empty_output = not step_dict.get("output")
        if is_chain_of_thought or is_empty_output:
            return

        enable_feedback = not step_dict.get("disableFeedback") and get_data_layer()
        chunks = self._split_output(
            str(step_dict["output"]), self.MAX_MESSAGE_LENGTH
        )
        last_index = len(chunks) - 1

        for index, chunk in enumerate(chunks):
            if index == last_index and enable_feedback:
                await self.channel.send(
                    chunk, view=FeedbackView(step_dict.get("id", ""))
                )
            else:
                await self.channel.send(chunk)

    async def update_step(self, step_dict: StepDict):
        """Handle a step update by posting it again through ``send_step``.

        Previously posted Discord messages are not edited.
        """
        if not self.enabled:
            return

        await self.send_step(step_dict)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# Gateway intents: discord.py's defaults plus message content, which the
# handlers in this module read through ``message.content``.
intents = discord.Intents.default()
intents.message_content = True

# The bot client that the ``@client.event`` handlers in this module register on.
client = discord.Client(intents=intents)


@trace
def init_discord_context(
    session: HTTPSession,
    channel: "MessageableChannel",
    message: discord.Message,
) -> ChainlitContext:
    """Create and activate the Chainlit context for one Discord message.

    Builds a ``DiscordEmitter`` bound to *channel*, wraps it together with
    *session* in a ``ChainlitContext``, makes that context current through
    ``context_var``, then stores the triggering message and channel in the
    user session under ``"discord_message"`` and ``"discord_channel"``.

    Returns:
        The newly activated context.
    """
    new_context = ChainlitContext(
        session=session,
        emitter=DiscordEmitter(session=session, channel=channel),
    )
    context_var.set(new_context)

    # Written after context_var.set so the values land in this message's
    # session (presumably user_session resolves the session from the active
    # context — confirm in chainlit.user_session).
    for key, value in (("discord_message", message), ("discord_channel", channel)):
        user_session.set(key, value)

    return new_context
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# Process-wide cache of resolved users keyed by Discord user id. Entries are
# never evicted, so it grows with the number of distinct Discord users seen.
users_by_discord_id: Dict[int, Union[User, PersistedUser]] = {}

# Prefix prepended to the Discord username to build the Chainlit identifier.
USER_PREFIX = "discord_"


async def get_user(discord_user: Union[discord.User, discord.Member]):
    """Resolve (and cache) the Chainlit user for a Discord user.

    The identifier is ``USER_PREFIX`` plus the Discord username, and the
    metadata holds the Discord ``name`` and ``id``. When a data layer is
    configured, the user is persisted with ``create_user``, and a truthy
    result replaces the cached entry.

    If persisting raises, the error is logged and the plain ``User`` is
    returned. Its cache entry is dropped so that a later message retries
    persistence, rather than keeping an unpersisted user for the rest of
    the process.
    """
    if discord_user.id in users_by_discord_id:
        return users_by_discord_id[discord_user.id]

    metadata = {
        "name": discord_user.name,
        "id": discord_user.id,
    }
    user = User(identifier=USER_PREFIX + str(discord_user.name), metadata=metadata)

    # Cache the placeholder before awaiting the data layer, so concurrent
    # messages from the same user hit the cache instead of creating it again.
    users_by_discord_id[discord_user.id] = user

    if data_layer := get_data_layer():
        try:
            persisted_user = await data_layer.create_user(user)
        except Exception as e:
            logger.error(f"Error creating user: {e}")
            # Drop only our own placeholder, not an entry written meanwhile.
            if users_by_discord_id.get(discord_user.id) is user:
                del users_by_discord_id[discord_user.id]
            return user
        if persisted_user:
            users_by_discord_id[discord_user.id] = persisted_user

    return users_by_discord_id[discord_user.id]
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
async def download_discord_file(url: str):
    """Fetch *url*; return the body bytes on HTTP 200, otherwise ``None``."""
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get(url)
        return response.content if response.status_code == 200 else None
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
async def download_discord_files(
    session: HTTPSession, attachments: List[discord.Attachment]
):
    """Download Discord attachments and turn them into Chainlit elements.

    All attachments are fetched concurrently. Each one that downloaded
    successfully is stored with ``session.persist_file``, using the
    attachment's filename and content type (defaulting to
    ``application/octet-stream``). Returns an ``Element`` for every persisted
    file that is present in ``session.files``.
    """
    payloads = await asyncio.gather(
        *(download_discord_file(attachment.url) for attachment in attachments)
    )

    persisted_refs = []
    for attachment, payload in zip(attachments, payloads):
        if not payload:
            continue
        persisted_refs.append(
            await session.persist_file(
                name=attachment.filename,
                mime=attachment.content_type or "application/octet-stream",
                content=payload,
            )
        )

    return [
        Element.from_dict(session.files[ref["id"]])
        for ref in persisted_refs
        if ref["id"] in session.files
    ]
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def clean_content(message: discord.Message):
    """Return the message text with mentions of this bot removed.

    Mentions of the form ``<@ID>`` or ``<@!ID>`` for the bot's own user id
    are deleted and the result is stripped. While ``client.user`` is unset
    (not logged in yet), the raw content is returned unchanged.
    """
    bot_user = client.user
    if bot_user:
        mention_pattern = f"<@!?{bot_user.id}>"
        return re.sub(mention_pattern, "", message.content).strip()
    return message.content
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
async def _persist_thread(
    ctx: ChainlitContext,
    user: Union[User, PersistedUser],
    thread_id: str,
    thread_name: str,
    bind_thread_to_user: bool,
):
    """Best-effort save of the thread's name and session metadata.

    The thread is tied to the user only when *bind_thread_to_user* is set and
    the user is persisted. Errors from the data layer are logged, not raised.
    """
    data_layer = get_data_layer()
    if not data_layer:
        return

    user_id = None
    if bind_thread_to_user and isinstance(user, PersistedUser):
        user_id = user.id

    try:
        await data_layer.update_thread(
            thread_id=thread_id,
            name=thread_name,
            metadata=ctx.session.to_persistable(),
            user_id=user_id,
        )
    except Exception as e:
        logger.error(f"Error updating thread: {e}")


async def process_discord_message(
    message: discord.Message,
    thread_name: str,
    channel: "MessageableChannel",
    bind_thread_to_user=False,
):
    """Run the Chainlit app callbacks for one incoming Discord message.

    Creates a throwaway ``HTTPSession`` for the message and sends the user's
    message (with its attachments) as a Chainlit ``user_message``. It then
    runs ``on_chat_start``, ``on_message`` (while showing the typing
    indicator) and ``on_chat_end`` when they are defined. Thread persistence
    and session deletion run in a ``finally`` block, so they also happen when
    a callback raises; the exception still propagates to the caller.
    """
    user = await get_user(message.author)

    # Derive the thread id from the channel id, so every message in the same
    # Discord channel/thread/DM maps to the same Chainlit thread.
    thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(channel.id)))

    text = clean_content(message)
    discord_files = message.attachments

    session_id = str(uuid.uuid4())
    session = HTTPSession(
        id=session_id,
        thread_id=thread_id,
        user=user,
        client_type="discord",
    )

    ctx = init_discord_context(
        session=session,
        channel=channel,
        message=message,
    )

    try:
        file_elements = await download_discord_files(session, discord_files)

        msg = Message(
            content=text,
            elements=file_elements,
            type="user_message",
            author=user.metadata.get("name"),
        )

        await msg.send()

        # The Discord emitter is created disabled. Enabling it only after
        # msg.send() presumably keeps the user's own message from being
        # posted back into the channel.
        ctx.emitter.enabled = True

        if chat_start_handler := config.code.on_chat_start:
            await chat_start_handler()

        if message_handler := config.code.on_message:
            async with channel.typing():
                await message_handler(msg)

        if chat_end_handler := config.code.on_chat_end:
            await chat_end_handler()
    finally:
        # Without this, an exception in a callback would skip the thread
        # update and ctx.session.delete().
        try:
            await _persist_thread(
                ctx, user, thread_id, thread_name, bind_thread_to_user
            )
        finally:
            ctx.session.delete()
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@client.event
async def on_ready():
    """Log which bot account the client is connected as."""
    logger.info(f"Logged in as {client.user}")


# Discord only accepts thread names of 1 to 100 characters.
_MAX_THREAD_NAME_LENGTH = 100


@client.event
async def on_message(message: discord.Message):
    """Route an incoming Discord message to the Chainlit app.

    Ignores messages sent by the bot itself, and messages outside DMs that do
    not mention the bot. The Chainlit thread name depends on the channel type.
    In a text channel, a new Discord thread is started from the message and
    the reply goes there. For DMs, ``bind_thread_to_user`` is set.
    Unsupported channel types are logged and skipped.
    """
    if not client.user or message.author == client.user:
        return

    is_dm = isinstance(message.channel, discord.DMChannel)
    if not client.user.mentioned_in(message) and not is_dm:
        return

    thread_name: str = ""
    bind_thread_to_user = False
    channel = message.channel

    if isinstance(
        message.channel, (discord.Thread, discord.ForumChannel, discord.GroupChannel)
    ):
        thread_name = f"{message.channel.name}"
    elif isinstance(message.channel, discord.DMChannel):
        thread_name = f"{message.author} Discord DM"
        bind_thread_to_user = True
    elif isinstance(message.channel, discord.TextChannel):
        # The cleaned text can be empty (e.g. a bare mention with an
        # attachment) or longer than Discord allows, and create_thread would
        # reject either; fall back to a generic name and truncate.
        discord_thread_name = (
            clean_content(message) or f"{message.author} Discord thread"
        )[:_MAX_THREAD_NAME_LENGTH]
        channel = await message.channel.create_thread(
            name=discord_thread_name, message=message
        )
        thread_name = f"{channel.name}"
    else:
        logger.warning(f"Unsupported channel type: {message.channel.type}")
        return

    await process_discord_message(
        message=message,
        thread_name=thread_name,
        channel=channel,
        bind_thread_to_user=bind_thread_to_user,
    )
|
chainlit/element.py
CHANGED
|
@@ -5,6 +5,7 @@ from io import BytesIO
|
|
|
5
5
|
from typing import Any, ClassVar, List, Literal, Optional, TypedDict, TypeVar, Union
|
|
6
6
|
|
|
7
7
|
import filetype
|
|
8
|
+
import mimetypes
|
|
8
9
|
from chainlit.context import context
|
|
9
10
|
from chainlit.data import get_data_layer
|
|
10
11
|
from chainlit.logger import logger
|
|
@@ -165,14 +166,16 @@ class Element:
|
|
|
165
166
|
if self.type in mime_types
|
|
166
167
|
else filetype.guess_mime(self.path or self.content)
|
|
167
168
|
)
|
|
168
|
-
|
|
169
|
+
if not self.mime and self.url:
|
|
170
|
+
self.mime = mimetypes.guess_type(self.url)[0]
|
|
171
|
+
|
|
169
172
|
await self._create()
|
|
170
173
|
|
|
171
174
|
if not self.url and not self.chainlit_key:
|
|
172
175
|
raise ValueError("Must provide url or chainlit key to send element")
|
|
173
176
|
|
|
174
177
|
trace_event(f"send {self.__class__.__name__}")
|
|
175
|
-
await context.emitter.
|
|
178
|
+
await context.emitter.send_element(self.to_dict())
|
|
176
179
|
|
|
177
180
|
|
|
178
181
|
ElementBased = TypeVar("ElementBased", bound=Element)
|
chainlit/emitter.py
CHANGED
|
@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast
|
|
|
4
4
|
|
|
5
5
|
from chainlit.config import config
|
|
6
6
|
from chainlit.data import get_data_layer
|
|
7
|
-
from chainlit.element import Element, File
|
|
7
|
+
from chainlit.element import Element, ElementDict, File
|
|
8
8
|
from chainlit.logger import logger
|
|
9
9
|
from chainlit.message import Message
|
|
10
|
-
from chainlit.session import BaseSession, WebsocketSession
|
|
10
|
+
from chainlit.session import BaseSession, HTTPSession, WebsocketSession
|
|
11
11
|
from chainlit.step import StepDict
|
|
12
12
|
from chainlit.types import (
|
|
13
13
|
AskActionResponse,
|
|
@@ -29,6 +29,7 @@ class BaseChainlitEmitter:
|
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
31
|
session: BaseSession
|
|
32
|
+
enabled: bool = True
|
|
32
33
|
|
|
33
34
|
def __init__(self, session: BaseSession) -> None:
|
|
34
35
|
"""Initialize with the user session."""
|
|
@@ -46,6 +47,10 @@ class BaseChainlitEmitter:
|
|
|
46
47
|
"""Stub method to resume a thread."""
|
|
47
48
|
pass
|
|
48
49
|
|
|
50
|
+
async def send_element(self, element_dict: ElementDict):
    """Deliver an element to the client; the base emitter ignores it.

    Emitters that are connected to a client override this method.
    """
    return None
|
|
53
|
+
|
|
49
54
|
async def send_step(self, step_dict: StepDict):
|
|
50
55
|
"""Stub method to send a message to the UI."""
|
|
51
56
|
pass
|
|
@@ -151,6 +156,10 @@ class ChainlitEmitter(BaseChainlitEmitter):
|
|
|
151
156
|
"""Send a thread to the UI to resume it"""
|
|
152
157
|
return self.emit("resume_thread", thread_dict)
|
|
153
158
|
|
|
159
|
+
async def send_element(self, element_dict: ElementDict):
    """Send an element to the UI by emitting an ``"element"`` event with it."""
    await self.emit("element", element_dict)
|
|
162
|
+
|
|
154
163
|
def send_step(self, step_dict: StepDict):
|
|
155
164
|
"""Send a message to the UI."""
|
|
156
165
|
return self.emit("new_message", step_dict)
|