chainlit 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +41 -130
- chainlit/action.py +2 -4
- chainlit/cli/__init__.py +64 -9
- chainlit/cli/mock.py +1571 -7
- chainlit/client/base.py +152 -0
- chainlit/client/cloud.py +440 -0
- chainlit/client/local.py +257 -0
- chainlit/client/utils.py +23 -0
- chainlit/config.py +31 -5
- chainlit/context.py +29 -0
- chainlit/db/__init__.py +35 -0
- chainlit/db/prisma/schema.prisma +48 -0
- chainlit/element.py +54 -41
- chainlit/emitter.py +1 -30
- chainlit/frontend/dist/assets/{index-51a1a88f.js → index-37b5009c.js} +1 -1
- chainlit/frontend/dist/assets/index-51393291.js +523 -0
- chainlit/frontend/dist/index.html +1 -1
- chainlit/langflow/__init__.py +75 -0
- chainlit/lc/__init__.py +85 -0
- chainlit/lc/agent.py +9 -5
- chainlit/lc/callbacks.py +9 -24
- chainlit/llama_index/__init__.py +34 -0
- chainlit/llama_index/callbacks.py +99 -0
- chainlit/llama_index/run.py +34 -0
- chainlit/logger.py +7 -2
- chainlit/message.py +25 -19
- chainlit/server.py +149 -38
- chainlit/session.py +3 -3
- chainlit/sync.py +20 -27
- chainlit/types.py +26 -1
- chainlit/user_session.py +1 -1
- chainlit/utils.py +51 -0
- {chainlit-0.4.1.dist-info → chainlit-0.4.3.dist-info}/METADATA +7 -3
- chainlit-0.4.3.dist-info/RECORD +49 -0
- chainlit/client.py +0 -287
- chainlit/frontend/dist/assets/index-68c36c96.js +0 -707
- chainlit-0.4.1.dist-info/RECORD +0 -38
- {chainlit-0.4.1.dist-info → chainlit-0.4.3.dist-info}/WHEEL +0 -0
- {chainlit-0.4.1.dist-info → chainlit-0.4.3.dist-info}/entry_points.txt +0 -0
chainlit/client/local.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
from typing import Optional, Dict
|
|
2
|
+
import uuid
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import aiofiles
|
|
8
|
+
|
|
9
|
+
from chainlit.client.base import PaginatedResponse, PageInfo
|
|
10
|
+
|
|
11
|
+
from .base import BaseClient
|
|
12
|
+
|
|
13
|
+
from chainlit.logger import logger
|
|
14
|
+
from chainlit.config import config
|
|
15
|
+
from chainlit.element import mime_to_ext
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LocalClient(BaseClient):
    """Chat persistence client backed by the local SQLite database (via Prisma).

    One instance persists into a single conversation, which is created lazily
    the first time something is written. SQLite has no JSON or list column
    types, so ``llmSettings`` and ``forIds`` are stored as JSON strings by
    ``before_write`` and decoded again by ``after_read``.
    """

    # ID of the conversation this client persists into; created on first write.
    conversation_id: Optional[str] = None
    # Guards conversation creation so concurrent writers share one conversation.
    lock: asyncio.Lock

    def __init__(self):
        self.lock = asyncio.Lock()

    def before_write(self, variables: Dict):
        """Convert ``variables`` in place into a shape the SQLite schema accepts."""
        if "llmSettings" in variables:
            # Sqlite doesn't support json fields, so we need to serialize it.
            variables["llmSettings"] = json.dumps(variables["llmSettings"])

        if "forIds" in variables:
            # Sqlite doesn't support list of primitives, so we need to serialize it.
            variables["forIds"] = json.dumps(variables["forIds"])

        # tempId is a client-side identifier; the schema has no column for it.
        variables.pop("tempId", None)

    def after_read(self, variables: Dict):
        """Decode, in place, the JSON-string columns written by ``before_write``.

        Both columns are nullable, so ``None``/empty values are left untouched
        instead of being handed to ``json.loads`` (which rejects ``None``).
        """
        if variables.get("llmSettings"):
            # Sqlite doesn't support json fields, so we need to parse it.
            variables["llmSettings"] = json.loads(variables["llmSettings"])

        if variables.get("forIds"):
            # Sqlite doesn't support list of primitives, so we need to parse it.
            variables["forIds"] = json.loads(variables["forIds"])

    async def is_project_member(self):
        # Local usage has no access control: everyone is a member.
        return True

    async def get_member_role(self):
        return "OWNER"

    async def get_project_members(self):
        return []

    async def get_conversation_id(self):
        """Return this client's conversation ID, creating the conversation if needed."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def create_conversation(self):
        """Create this client's conversation once and return its ID.

        Later calls return the already-created ID.
        """
        from prisma.models import Conversation

        # If we run multiple send concurrently, we need to make sure we don't create multiple conversations.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            res = await Conversation.prisma().create(data={})
            # Record the ID while still holding the lock so the next waiter
            # sees it, instead of relying on the caller to store it.
            self.conversation_id = res.id

            return res.id

    async def delete_conversation(self, conversation_id):
        """Delete a conversation; related rows go with it (schema uses onDelete: Cascade)."""
        from prisma.models import Conversation

        await Conversation.prisma().delete(where={"id": conversation_id})

        return True

    async def get_conversation(self, conversation_id: int):
        """Return a conversation with all its messages and elements as plain dicts.

        Raises whatever ``find_unique_or_raise`` raises when the ID is unknown.
        """
        from prisma.models import Conversation

        c = await Conversation.prisma().find_unique_or_raise(
            where={"id": conversation_id}, include={"messages": True, "elements": True}
        )

        conversation = json.loads(c.json())

        for message in conversation.get("messages") or []:
            self.after_read(message)

        for element in conversation.get("elements") or []:
            self.after_read(element)

        return conversation

    async def get_conversations(self, pagination, filter):
        """Return a page of conversations, newest first, with their first user message.

        Only conversations having at least one message matching the filter
        (feedback value and/or content substring) are returned.
        """
        from prisma.models import Conversation

        some_messages = {}

        if filter.feedback is not None:
            some_messages["humanFeedback"] = filter.feedback

        # An empty search string means "no search", not "contains None".
        if filter.search:
            some_messages["content"] = {"contains": filter.search}

        if pagination.cursor:
            cursor = {"id": pagination.cursor}
        else:
            cursor = None

        page_size = pagination.first
        # Fetch one extra row so we know whether another page really exists.
        take = page_size + 1 if page_size else page_size

        conversations = await Conversation.prisma().find_many(
            take=take,
            # Skip the cursor row itself, which was the last row of the previous page.
            skip=1 if pagination.cursor else None,
            cursor=cursor,
            include={
                "messages": {
                    "take": 1,
                    "where": {
                        "authorIsUser": True,
                    },
                    "orderBy": [
                        {
                            "createdAt": "asc",
                        }
                    ],
                }
            },
            where={"messages": {"some": some_messages}},
            order={
                "createdAt": "desc",
            },
        )

        has_more = bool(page_size) and len(conversations) > page_size

        if has_more:
            conversations = conversations[:page_size]
            end_cursor = conversations[-1].id
        else:
            end_cursor = None

        conversations = [json.loads(c.json()) for c in conversations]

        return PaginatedResponse(
            pageInfo=PageInfo(hasNextPage=has_more, endCursor=end_cursor),
            data=conversations,
        )

    async def create_message(self, variables):
        """Persist a message in this client's conversation and return its ID.

        Returns None if no conversation ID is available.
        """
        from prisma.models import Message

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        # Work on a copy so the caller's dict is not mutated by before_write.
        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        res = await Message.prisma().create(data=variables)
        return res.id

    async def get_message(self, message_id):
        """Return a message as a dict with decoded JSON columns, or None if missing."""
        from prisma.models import Message

        res = await Message.prisma().find_first(where={"id": message_id})
        if res is None:
            return None

        message = res.dict()
        self.after_read(message)
        return message

    async def update_message(self, message_id, variables):
        """Update a message with ``variables`` (copied, then encoded by before_write)."""
        from prisma.models import Message

        variables = variables.copy()

        self.before_write(variables)

        await Message.prisma().update(data=variables, where={"id": message_id})

        return True

    async def delete_message(self, message_id):
        from prisma.models import Message

        await Message.prisma().delete(where={"id": message_id})

        return True

    async def upsert_element(
        self,
        variables,
    ):
        """Update the element if ``variables`` has an "id", otherwise create it.

        Returns the stored element as a dict, or None when there is no
        conversation ID or the element to update was not found.
        """
        from prisma.models import Element

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        # Work on a copy, consistent with create_message/update_message.
        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        if "id" in variables:
            res = await Element.prisma().update(
                data=variables, where={"id": variables["id"]}
            )
        else:
            res = await Element.prisma().create(data=variables)

        if res is None:
            # update() found no row with this id.
            return None

        return res.dict()

    async def get_element(
        self,
        conversation_id,
        element_id,
    ):
        """Return an element as a plain dict with ``forIds`` decoded to a list.

        ``conversation_id`` is unused: element IDs are globally unique.
        """
        from prisma.models import Element

        res = await Element.prisma().find_unique_or_raise(where={"id": element_id})
        element = json.loads(res.json())
        # Decode forIds the same way get_conversation does for its elements.
        self.after_read(element)
        return element

    async def upload_element(self, content: bytes, mime: str):
        """Write ``content`` to the local files directory and return its URL.

        The file is written to ``<local_fs_path>/<conversation id>/<uuid>.<ext>``.
        The returned URL is ``/files/<conversation id>/<uuid>.<ext>``.
        Returns None if no conversation ID is available.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        file_ext = mime_to_ext.get(mime, "bin")
        file_name = f"{uuid.uuid4()}.{file_ext}"

        dir_path = os.path.join(config.project.local_fs_path, str(c_id))
        # exist_ok avoids a check-then-create race between concurrent uploads.
        os.makedirs(dir_path, exist_ok=True)

        async with aiofiles.open(os.path.join(dir_path, file_name), "wb") as out:
            await out.write(content)
            await out.flush()

        # URLs always use "/" regardless of the OS path separator.
        url = f"/files/{c_id}/{file_name}"
        return url

    async def set_human_feedback(self, message_id, feedback):
        """Store the human feedback value on a message."""
        from prisma.models import Message

        await Message.prisma().update(
            where={"id": message_id},
            data={
                "humanFeedback": feedback,
            },
        )

        return True
|
chainlit/client/utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from fastapi import HTTPException, Request
|
|
2
|
+
|
|
3
|
+
from chainlit.config import config
|
|
4
|
+
from chainlit.client.base import BaseClient
|
|
5
|
+
from chainlit.client.local import LocalClient
|
|
6
|
+
from chainlit.client.cloud import CloudClient
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
async def get_client(request: Request) -> BaseClient:
    """Build the persistence client matching ``config.project.database``.

    Args:
        request: the incoming request; its ``Authorization`` header is
            forwarded to the cloud client.

    Returns:
        A ``LocalClient``, a ``CloudClient``, or the client produced by the
        user-registered ``config.code.client_factory``.

    Raises:
        HTTPException: status 500 if the database type is unknown, or if it is
            ``"custom"`` and no client factory has been registered.
    """
    db = config.project.database

    if db == "local":
        return LocalClient()

    if db == "cloud":
        auth_header = request.headers.get("Authorization")
        return CloudClient(config.project.id, auth_header)

    if db == "custom":
        if not config.code.client_factory:
            # Fail with a clear error instead of "'NoneType' object is not callable".
            raise HTTPException(
                status_code=500, detail="No custom client factory registered"
            )
        return await config.code.client_factory()

    raise HTTPException(status_code=500, detail="Invalid database type")
|
chainlit/config.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Optional, Any, Callable, List, Dict, TYPE_CHECKING
|
|
1
|
+
from typing import Optional, Any, Callable, Union, Literal, List, Dict, TYPE_CHECKING
|
|
2
2
|
import os
|
|
3
3
|
import sys
|
|
4
4
|
import tomli
|
|
@@ -10,6 +10,7 @@ from chainlit.version import __version__
|
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
12
|
from chainlit.action import Action
|
|
13
|
+
from chainlit.client.base import BaseClient
|
|
13
14
|
|
|
14
15
|
PACKAGE_ROOT = os.path.dirname(__file__)
|
|
15
16
|
|
|
@@ -26,10 +27,15 @@ DEFAULT_CONFIG_STR = f"""[project]
|
|
|
26
27
|
public = true
|
|
27
28
|
|
|
28
29
|
# The project ID (found on https://cloud.chainlit.io).
|
|
29
|
-
#
|
|
30
|
-
# The project ID is required when public is set to false.
|
|
30
|
+
# The project ID is required when public is set to false or when using the cloud database.
|
|
31
31
|
#id = ""
|
|
32
32
|
|
|
33
|
+
# Uncomment if you want to persist the chats.
|
|
34
|
+
# local will create a database in your .chainlit directory (requires node.js installed).
|
|
35
|
+
# cloud will use the Chainlit cloud database.
|
|
36
|
+
# custom will load use your custom client.
|
|
37
|
+
# database = "local"
|
|
38
|
+
|
|
33
39
|
# Whether to enable telemetry (default: true). No personal data is collected.
|
|
34
40
|
enable_telemetry = true
|
|
35
41
|
|
|
@@ -102,15 +108,19 @@ class CodeSettings:
|
|
|
102
108
|
lc_postprocess: Optional[Callable[[Any], str]] = None
|
|
103
109
|
lc_factory: Optional[Callable[[], Any]] = None
|
|
104
110
|
lc_rename: Optional[Callable[[str], str]] = None
|
|
111
|
+
llama_index_factory: Optional[Callable[[], Any]] = None
|
|
112
|
+
langflow_schema: Union[Dict, str] = None
|
|
113
|
+
client_factory: Optional[Callable[[str], "BaseClient"]] = None
|
|
105
114
|
|
|
106
115
|
def validate(self):
|
|
107
116
|
requires_one_of = [
|
|
108
117
|
"lc_factory",
|
|
118
|
+
"llama_index_factory",
|
|
109
119
|
"on_message",
|
|
110
120
|
"on_chat_start",
|
|
111
121
|
]
|
|
112
122
|
|
|
113
|
-
mutually_exclusive = ["lc_factory"]
|
|
123
|
+
mutually_exclusive = ["lc_factory", "llama_index_factory"]
|
|
114
124
|
|
|
115
125
|
# Check if at least one of the required attributes is set
|
|
116
126
|
if not any(getattr(self, attr) for attr in requires_one_of):
|
|
@@ -136,12 +146,18 @@ class ProjectSettings:
|
|
|
136
146
|
id: Optional[str] = None
|
|
137
147
|
# Whether the app is available to anonymous users or only to team members.
|
|
138
148
|
public: bool = True
|
|
149
|
+
# Storage type
|
|
150
|
+
database: Optional[Literal["local", "cloud", "custom"]] = None
|
|
139
151
|
# Whether to enable telemetry. No personal data is collected.
|
|
140
152
|
enable_telemetry: bool = True
|
|
141
153
|
# List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user.
|
|
142
154
|
user_env: List[str] = None
|
|
143
155
|
# Path to the local langchain cache database
|
|
144
156
|
lc_cache_path: str = None
|
|
157
|
+
# Path to the local chat db
|
|
158
|
+
local_db_path: str = None
|
|
159
|
+
# Path to the local file system
|
|
160
|
+
local_fs_path: str = None
|
|
145
161
|
|
|
146
162
|
|
|
147
163
|
@dataclass()
|
|
@@ -203,10 +219,20 @@ def load_settings():
|
|
|
203
219
|
)
|
|
204
220
|
|
|
205
221
|
lc_cache_path = os.path.join(config_dir, ".langchain.db")
|
|
222
|
+
local_db_path = os.path.join(config_dir, "chat.db")
|
|
223
|
+
local_fs_path = os.path.join(config_dir, "chat_files")
|
|
224
|
+
|
|
225
|
+
os.environ[
|
|
226
|
+
"LOCAL_DB_PATH"
|
|
227
|
+
] = f"file:{local_db_path}?socket_timeout=10&connection_limit=1"
|
|
206
228
|
|
|
207
229
|
project_settings = ProjectSettings(
|
|
208
|
-
lc_cache_path=lc_cache_path,
|
|
230
|
+
lc_cache_path=lc_cache_path,
|
|
231
|
+
local_db_path=local_db_path,
|
|
232
|
+
local_fs_path=local_fs_path,
|
|
233
|
+
**project_config,
|
|
209
234
|
)
|
|
235
|
+
|
|
210
236
|
ui_settings = UISettings(**ui_settings)
|
|
211
237
|
|
|
212
238
|
if not project_settings.public and not project_settings.id:
|
chainlit/context.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import contextvars
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
from asyncio import AbstractEventLoop
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from chainlit.emitter import ChainlitEmitter
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ChainlitContextException(Exception):
    """Raised by the context accessors when their context variable is unset.

    The first positional argument is the message and defaults to
    "Chainlit context not found"; any extra arguments are passed on to
    ``Exception``.
    """

    def __init__(self, msg="Chainlit context not found", *args, **kwargs):
        super(ChainlitContextException, self).__init__(msg, *args, **kwargs)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Context-local slots read by get_emitter() / get_loop(). No default is given,
# so .get() raises LookupError until a value is set in the current context.
emitter_var: contextvars.ContextVar = contextvars.ContextVar("emitter")
loop_var: contextvars.ContextVar = contextvars.ContextVar("loop")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_emitter() -> "ChainlitEmitter":
    """Return the emitter stored in ``emitter_var`` for the current context.

    Raises:
        ChainlitContextException: if ``emitter_var`` has no value in the
            current context.
    """
    try:
        current = emitter_var.get()
    except LookupError:
        raise ChainlitContextException()
    return current
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_loop() -> AbstractEventLoop:
    """Return the event loop stored in ``loop_var`` for the current context.

    Raises:
        ChainlitContextException: if ``loop_var`` has no value in the current
            context.
    """
    try:
        current = loop_var.get()
    except LookupError:
        raise ChainlitContextException()
    return current
|
chainlit/db/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from chainlit.logger import logger
|
|
3
|
+
from chainlit.config import config, PACKAGE_ROOT
|
|
4
|
+
|
|
5
|
+
# Absolute path of the Prisma schema bundled with the package. The path is
# joined per component so it uses the platform's native separator.
SCHEMA_PATH = os.path.join(PACKAGE_ROOT, "db", "prisma", "schema.prisma")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def db_push():
    """Sync the local database with the bundled schema using ``prisma db push``.

    The ``LOCAL_DB_PATH`` environment variable (the datasource URL in the
    schema) is forwarded to the Prisma CLI, and the ``prisma`` package is
    reloaded afterwards.
    """
    import prisma
    from importlib import reload
    from prisma.cli.prisma import run

    push_args = ["db", "push", f"--schema={SCHEMA_PATH}"]
    run(push_args, env={"LOCAL_DB_PATH": os.environ.get("LOCAL_DB_PATH")})

    # Without this the client will fail to initialize the first time.
    reload(prisma)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def init_local_db():
    """Create the local database on first use.

    Does nothing unless ``config.project.database`` is ``"local"`` and no file
    exists yet at ``config.project.local_db_path``.
    """
    if config.project.database != "local":
        return

    if os.path.exists(config.project.local_db_path):
        return

    db_push()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def migrate_local_db():
    """Re-apply the bundled schema to an existing local database.

    Only acts when ``config.project.database`` is ``"local"``. If the database
    file does not exist, logs that the migration is skipped; creation is
    handled separately.
    """
    use_local_db = config.project.database == "local"
    if use_local_db:
        if os.path.exists(config.project.local_db_path):
            db_push()
            # Plain strings: these messages have no placeholders (ruff F541).
            logger.info("Local db migrated")
        else:
            logger.info("Local db does not exist, skipping migration")
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
// Local chat-history database (SQLite). The connection URL comes from the
// LOCAL_DB_PATH environment variable.
datasource db {
provider = "sqlite"
url = env("LOCAL_DB_PATH")
}

// Generates the Python client (prisma-client-py) imported as prisma.models.
generator client {
provider = "prisma-client-py"
recursive_type_depth = 5
}

// A chat session. Its messages and elements are deleted with it
// (onDelete: Cascade on both relations).
model Conversation {
id Int @id @default(autoincrement())
createdAt DateTime @default(now())
messages Message[]
elements Element[]
}

// A piece of content (file, image, text...) attached to a conversation.
model Element {
id Int @id @default(autoincrement())
createdAt DateTime @default(now())
conversationId Int
conversation Conversation @relation(fields: [conversationId], references: [id], onDelete: Cascade)
type String
url String
name String
display String
size String?
language String?
// JSON-encoded list of the IDs of the messages this element is attached to
// (SQLite has no list column type).
forIds String?
}

// A single chat message.
model Message {
id Int @id @default(autoincrement())
createdAt DateTime @default(now())
conversationId Int
conversation Conversation @relation(fields: [conversationId], references: [id], onDelete: Cascade)
authorIsUser Boolean @default(false)
isError Boolean @default(false)
waitForAnswer Boolean @default(false)
indent Int @default(0)
author String
content String
humanFeedback Int @default(0)
language String?
prompt String?
// Sqlite does not support JSON
// (stored as a JSON-encoded string and decoded on read)
llmSettings String?
}
|
chainlit/element.py
CHANGED
|
@@ -1,22 +1,27 @@
|
|
|
1
1
|
from pydantic.dataclasses import dataclass
|
|
2
|
-
from
|
|
3
|
-
from typing import Dict, Union, Any
|
|
2
|
+
from typing import Dict, List, Union, Any
|
|
4
3
|
import uuid
|
|
5
4
|
import aiofiles
|
|
6
5
|
from io import BytesIO
|
|
7
6
|
|
|
8
|
-
from chainlit.
|
|
7
|
+
from chainlit.context import get_emitter
|
|
8
|
+
from chainlit.client.base import BaseClient
|
|
9
9
|
from chainlit.telemetry import trace_event
|
|
10
10
|
from chainlit.types import ElementType, ElementDisplay, ElementSize
|
|
11
11
|
|
|
12
12
|
type_to_mime = {
|
|
13
|
-
"image": "
|
|
13
|
+
"image": "image/png",
|
|
14
14
|
"text": "text/plain",
|
|
15
15
|
"pdf": "application/pdf",
|
|
16
16
|
}
|
|
17
17
|
|
|
18
|
+
mime_to_ext = {
|
|
19
|
+
"image/png": "png",
|
|
20
|
+
"text/plain": "txt",
|
|
21
|
+
"application/pdf": "pdf",
|
|
22
|
+
}
|
|
23
|
+
|
|
18
24
|
|
|
19
|
-
@dataclass_json
|
|
20
25
|
@dataclass
|
|
21
26
|
class Element:
|
|
22
27
|
# Name of the element, this will be used to reference the element in the UI.
|
|
@@ -34,20 +39,35 @@ class Element:
|
|
|
34
39
|
# The ID of the element. This is set automatically when the element is sent to the UI if cloud is enabled.
|
|
35
40
|
id: int = None
|
|
36
41
|
# The ID of the element if cloud is disabled.
|
|
37
|
-
|
|
42
|
+
temp_id: str = None
|
|
38
43
|
# The ID of the message this element is associated with.
|
|
39
|
-
|
|
44
|
+
for_ids: List[str] = None
|
|
40
45
|
|
|
41
46
|
def __post_init__(self) -> None:
|
|
42
47
|
trace_event(f"init {self.__class__.__name__}")
|
|
43
48
|
self.emitter = get_emitter()
|
|
44
|
-
|
|
45
|
-
|
|
49
|
+
self.for_ids = []
|
|
50
|
+
self.temp_id = str(uuid.uuid4())
|
|
46
51
|
|
|
47
52
|
if not self.url and not self.path and not self.content:
|
|
48
53
|
raise ValueError("Must provide url, path or content to instantiate element")
|
|
49
54
|
|
|
50
|
-
|
|
55
|
+
def to_dict(self) -> Dict:
|
|
56
|
+
_dict = {
|
|
57
|
+
"tempId": self.temp_id,
|
|
58
|
+
"type": self.type,
|
|
59
|
+
"url": self.url,
|
|
60
|
+
"name": self.name,
|
|
61
|
+
"display": self.display,
|
|
62
|
+
"size": getattr(self, "size", None),
|
|
63
|
+
"language": getattr(self, "language", None),
|
|
64
|
+
"forIds": getattr(self, "for_ids", None),
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
if self.id:
|
|
68
|
+
_dict["id"] = self.id
|
|
69
|
+
|
|
70
|
+
return _dict
|
|
51
71
|
|
|
52
72
|
async def preprocess_content(self):
|
|
53
73
|
pass
|
|
@@ -56,29 +76,15 @@ class Element:
|
|
|
56
76
|
if self.path:
|
|
57
77
|
async with aiofiles.open(self.path, "rb") as f:
|
|
58
78
|
self.content = await f.read()
|
|
59
|
-
await self.preprocess_content()
|
|
60
|
-
elif self.content:
|
|
61
|
-
await self.preprocess_content()
|
|
62
79
|
else:
|
|
63
80
|
raise ValueError("Must provide path or content to load element")
|
|
64
81
|
|
|
65
|
-
async def persist(self, client: BaseClient
|
|
66
|
-
if not self.url and self.content:
|
|
82
|
+
async def persist(self, client: BaseClient):
|
|
83
|
+
if not self.url and self.content and not self.id:
|
|
67
84
|
self.url = await client.upload_element(
|
|
68
85
|
content=self.content, mime=type_to_mime[self.type]
|
|
69
86
|
)
|
|
70
|
-
|
|
71
|
-
size = getattr(self, "size", None)
|
|
72
|
-
language = getattr(self, "language", None)
|
|
73
|
-
element = await client.create_element(
|
|
74
|
-
name=self.name,
|
|
75
|
-
url=self.url,
|
|
76
|
-
type=self.type,
|
|
77
|
-
display=self.display,
|
|
78
|
-
size=size,
|
|
79
|
-
language=language,
|
|
80
|
-
for_id=for_id,
|
|
81
|
-
)
|
|
87
|
+
element = await client.upsert_element(self.to_dict())
|
|
82
88
|
return element
|
|
83
89
|
|
|
84
90
|
async def before_emit(self, element: Dict) -> Dict:
|
|
@@ -90,22 +96,34 @@ class Element:
|
|
|
90
96
|
if not self.content and not self.url and self.path:
|
|
91
97
|
await self.load()
|
|
92
98
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
self.
|
|
99
|
+
await self.preprocess_content()
|
|
100
|
+
|
|
101
|
+
if for_id:
|
|
102
|
+
self.for_ids.append(for_id)
|
|
103
|
+
|
|
104
|
+
# We have a client, persist the element
|
|
105
|
+
if self.emitter.client:
|
|
106
|
+
element = await self.persist(self.emitter.client)
|
|
107
|
+
self.id = element and element.get("id")
|
|
97
108
|
|
|
98
109
|
elif not self.url and not self.content:
|
|
99
110
|
raise ValueError("Must provide url or content to send element")
|
|
100
111
|
|
|
101
112
|
element = self.to_dict()
|
|
102
|
-
|
|
103
|
-
|
|
113
|
+
|
|
114
|
+
element["content"] = self.content
|
|
104
115
|
|
|
105
116
|
if self.emitter.emit and element:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
117
|
+
if len(self.for_ids) > 1:
|
|
118
|
+
trace_event(f"update {self.__class__.__name__}")
|
|
119
|
+
await self.emitter.emit(
|
|
120
|
+
"update_element",
|
|
121
|
+
{"id": self.id or self.temp_id, "forIds": self.for_ids},
|
|
122
|
+
)
|
|
123
|
+
else:
|
|
124
|
+
trace_event(f"send {self.__class__.__name__}")
|
|
125
|
+
element = await self.before_emit(element)
|
|
126
|
+
await self.emitter.emit("element", element)
|
|
109
127
|
|
|
110
128
|
|
|
111
129
|
@dataclass
|
|
@@ -187,8 +205,3 @@ class Pyplot(Element):
|
|
|
187
205
|
self.content = image.getvalue()
|
|
188
206
|
|
|
189
207
|
super().__post_init__()
|
|
190
|
-
|
|
191
|
-
async def before_emit(self, element: Dict) -> Dict:
|
|
192
|
-
# Prevent the figure from being serialized
|
|
193
|
-
del element["figure"]
|
|
194
|
-
return element
|
chainlit/emitter.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from typing import Union, Dict
|
|
2
2
|
from chainlit.session import Session
|
|
3
3
|
from chainlit.types import AskSpec
|
|
4
|
-
from chainlit.client import BaseClient
|
|
4
|
+
from chainlit.client.base import BaseClient
|
|
5
5
|
from socketio.exceptions import TimeoutError
|
|
6
|
-
import inspect
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class ChainlitEmitter:
|
|
@@ -129,31 +128,3 @@ class ChainlitEmitter:
|
|
|
129
128
|
def send_token(self, id: Union[str, int], token: str):
|
|
130
129
|
"""Send a message token to the UI."""
|
|
131
130
|
return self.emit("stream_token", {"id": id, "token": token})
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
def get_emitter() -> Union[ChainlitEmitter, None]:
|
|
135
|
-
"""
|
|
136
|
-
Get the Chainlit Emitter instance from the current call stack.
|
|
137
|
-
This unusual approach is necessary because:
|
|
138
|
-
- we need to get the right Emitter instance with the right websocket connection
|
|
139
|
-
- to preserve a lean developer experience, we do not pass the Emitter instance to every function call
|
|
140
|
-
|
|
141
|
-
What happens is that we set __chainlit_emitter__ in the local variables when we receive a websocket message.
|
|
142
|
-
Then we can retrieve it from the call stack when we need it, even if the developer's code has no idea about it.
|
|
143
|
-
"""
|
|
144
|
-
attr = "__chainlit_emitter__"
|
|
145
|
-
candidates = [i[0].f_locals.get(attr) for i in inspect.stack()]
|
|
146
|
-
emitter = None
|
|
147
|
-
for candidate in candidates:
|
|
148
|
-
if candidate:
|
|
149
|
-
emitter = candidate
|
|
150
|
-
break
|
|
151
|
-
|
|
152
|
-
return emitter
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
def get_emit_fn():
|
|
156
|
-
emitter = get_emitter()
|
|
157
|
-
if emitter:
|
|
158
|
-
return emitter.emit
|
|
159
|
-
return None
|