chainlit 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit has been flagged as potentially problematic; see the package's registry page for details.
- chainlit/__init__.py +30 -7
- chainlit/action.py +2 -4
- chainlit/cache.py +24 -1
- chainlit/cli/__init__.py +64 -21
- chainlit/client/base.py +152 -0
- chainlit/client/cloud.py +440 -0
- chainlit/client/local.py +257 -0
- chainlit/client/utils.py +23 -0
- chainlit/config.py +92 -29
- chainlit/context.py +29 -0
- chainlit/db/__init__.py +35 -0
- chainlit/db/prisma/schema.prisma +48 -0
- chainlit/element.py +54 -41
- chainlit/emitter.py +1 -30
- chainlit/frontend/dist/assets/index-995e21ad.js +11 -0
- chainlit/frontend/dist/assets/index-f93cc942.css +1 -0
- chainlit/frontend/dist/assets/index-fb1e167a.js +523 -0
- chainlit/frontend/dist/index.html +2 -2
- chainlit/lc/agent.py +1 -0
- chainlit/lc/callbacks.py +6 -21
- chainlit/logger.py +7 -2
- chainlit/message.py +22 -16
- chainlit/server.py +169 -59
- chainlit/session.py +1 -3
- chainlit/sync.py +16 -28
- chainlit/types.py +26 -1
- chainlit/user_session.py +1 -1
- {chainlit-0.4.0.dist-info → chainlit-0.4.2.dist-info}/METADATA +8 -3
- chainlit-0.4.2.dist-info/RECORD +44 -0
- chainlit/client.py +0 -287
- chainlit/frontend/dist/assets/index-0cc9e355.css +0 -1
- chainlit/frontend/dist/assets/index-9e4bccd1.js +0 -717
- chainlit-0.4.0.dist-info/RECORD +0 -37
- {chainlit-0.4.0.dist-info → chainlit-0.4.2.dist-info}/WHEEL +0 -0
- {chainlit-0.4.0.dist-info → chainlit-0.4.2.dist-info}/entry_points.txt +0 -0
chainlit/client/cloud.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
from typing import Dict, Any, Optional
|
|
2
|
+
import uuid
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import aiohttp
|
|
6
|
+
from python_graphql_client import GraphqlClient
|
|
7
|
+
|
|
8
|
+
from .base import BaseClient, PaginatedResponse, PageInfo
|
|
9
|
+
|
|
10
|
+
from chainlit.logger import logger
|
|
11
|
+
from chainlit.config import config
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CloudClient(BaseClient):
    """Persistence client backed by the Chainlit cloud server.

    Conversations, messages, elements and feedback go through the GraphQL
    endpoint at ``{config.chainlit_server}/api/graphql``; role lookups and
    file uploads use plain REST calls on the same server.
    """

    # Conversation lazily created on first persist; the server id is cast with int().
    conversation_id: Optional[int] = None
    # Guards create_conversation so concurrent sends share one conversation.
    lock: asyncio.Lock

    def __init__(self, project_id: str, access_token: str):
        self.lock = asyncio.Lock()
        self.project_id = project_id
        # Sent on every GraphQL and REST request.
        self.headers = {
            "Authorization": access_token,
            "content-type": "application/json",
        }
        graphql_endpoint = f"{config.chainlit_server}/api/graphql"
        self.graphql_client = GraphqlClient(
            endpoint=graphql_endpoint, headers=self.headers
        )

    async def query(
        self, query: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL query.

        :param query: The GraphQL query string.
        :param variables: A dictionary of variables for the query; ``None``
            (the default) sends no variables.
        :return: The decoded response, which may contain an ``errors`` key.
        """
        # A fresh dict per call instead of a shared mutable default argument.
        return await self.graphql_client.execute_async(
            query=query, variables=variables or {}
        )

    def check_for_errors(self, response: Dict[str, Any], raise_error: bool = False):
        """
        Report the first GraphQL error in ``response``, if any.

        :param response: A decoded GraphQL response.
        :param raise_error: Raise ``Exception`` with the first error instead
            of logging it.
        :return: True if the response contains errors, False otherwise.
        """
        if "errors" not in response:
            return False
        first_error = response["errors"][0]
        if raise_error:
            raise Exception(first_error)
        logger.error(first_error)
        return True

    async def mutation(
        self, mutation: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL mutation.

        :param mutation: The GraphQL mutation string.
        :param variables: A dictionary of variables for the mutation; ``None``
            (the default) sends no variables.
        :return: The decoded response, which may contain an ``errors`` key.
        """
        # Queries and mutations share the same GraphQL transport.
        return await self.query(mutation, variables)

    async def get_member_role(
        self,
    ):
        """
        Fetch the current user's role in the project from ``/api/role``.

        :return: The ``role`` from the response (``"ANONYMOUS"`` when absent),
            or ``False`` if the request failed.
        """
        data = {"projectId": self.project_id}
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{config.chainlit_server}/api/role",
                json=data,
                headers=self.headers,
            ) as r:
                if not r.ok:
                    reason = await r.text()
                    logger.error(f"Failed to get user role. {r.status}: {reason}")
                    return False
                payload = await r.json()
                return payload.get("role", "ANONYMOUS")

    async def is_project_member(self) -> bool:
        """
        Return True if the user has a non-anonymous role in the project.

        A failed role lookup (``get_member_role`` returning ``False``) counts
        as "not a member" instead of granting access.
        """
        role = await self.get_member_role()
        return bool(role) and role != "ANONYMOUS"

    async def get_project_members(self):
        """
        List the project's members.

        :return: A list of ``{"role", "name", "email"}`` dicts.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query ($projectId: String!) {
    projectMembers(projectId: $projectId) {
        edges {
            cursor
            node {
                role
                user {
                    email
                    name
                }
            }
        }
    }
}"""
        variables = {"projectId": self.project_id}
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return [
            {
                "role": edge["node"]["role"],
                "name": edge["node"]["user"]["name"],
                "email": edge["node"]["user"]["email"],
            }
            for edge in res["data"]["projectMembers"]["edges"]
        ]

    async def create_conversation(self) -> Optional[int]:
        """
        Return the cached conversation id, creating the conversation if needed.

        :return: The conversation id, or None if the server reported errors.
        """
        # If we run multiple send concurrently, we need to make sure we don't create multiple conversations.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            mutation = """
mutation ($projectId: String!, $sessionId: String) {
    createConversation(projectId: $projectId, sessionId: $sessionId) {
        id
    }
}
"""
            variables = {"projectId": self.project_id}
            res = await self.mutation(mutation, variables)

            if self.check_for_errors(res):
                logger.warning("Could not create conversation.")
                return None

            # Cache while still holding the lock so the next waiter reuses it.
            self.conversation_id = int(res["data"]["createConversation"]["id"])
            return self.conversation_id

    async def get_conversation_id(self):
        """Return the current conversation id, creating the conversation on first use."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def delete_conversation(self, conversation_id: int):
        """
        Delete a conversation by id.

        :return: True on success.
        :raises Exception: If the GraphQL response contains errors.
        """
        mutation = """mutation ($id: ID!) {
    deleteConversation(id: $id) {
        id
    }
}"""
        variables = {"id": conversation_id}
        res = await self.mutation(mutation, variables)
        self.check_for_errors(res, raise_error=True)

        return True

    async def get_conversation(self, conversation_id: int):
        """
        Fetch one conversation with its messages and elements.

        :return: The ``conversation`` object from the GraphQL response.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query ($id: ID!) {
    conversation(id: $id) {
        id
        createdAt
        messages {
            id
            isError
            indent
            author
            content
            waitForAnswer
            humanFeedback
            language
            prompt
            llmSettings
            authorIsUser
            createdAt
        }
        elements {
            id
            conversationId
            type
            name
            url
            display
            language
            size
            forIds
        }
    }
}"""
        variables = {
            "id": conversation_id,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return res["data"]["conversation"]

    async def get_conversations(self, pagination, filter):
        """
        Fetch one page of the project's conversations.

        :param pagination: Provides ``first`` (page size) and ``cursor``.
        :param filter: Provides ``feedback``, ``authorEmail`` and ``search``.
        :return: A ``PaginatedResponse`` of conversation nodes.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query (
    $first: Int
    $projectId: String!
    $cursor: String
    $withFeedback: Int
    $authorEmail: String
    $search: String
) {
    conversations(
        first: $first
        cursor: $cursor
        projectId: $projectId
        withFeedback: $withFeedback
        authorEmail: $authorEmail
        search: $search
    ) {
        pageInfo {
            endCursor
            hasNextPage
        }
        edges {
            cursor
            node {
                id
                createdAt
                elementCount
                messageCount
                author {
                    name
                    email
                }
                messages {
                    content
                }
            }
        }
    }
}"""

        variables = {
            "projectId": self.project_id,
            "first": pagination.first,
            "cursor": pagination.cursor,
            "withFeedback": filter.feedback,
            "authorEmail": filter.authorEmail,
            "search": filter.search,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        conversations_data = res["data"]["conversations"]
        conversations = [edge["node"] for edge in conversations_data["edges"]]
        page_info = conversations_data["pageInfo"]

        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=page_info["hasNextPage"],
                endCursor=page_info["endCursor"],
            ),
            data=conversations,
        )

    async def set_human_feedback(self, message_id, feedback):
        """
        Store a human feedback score on a message.

        :return: True on success.
        :raises Exception: If the GraphQL response contains errors.
        """
        mutation = """mutation ($messageId: ID!, $humanFeedback: Int!) {
    setHumanFeedback(messageId: $messageId, humanFeedback: $humanFeedback) {
        id
        humanFeedback
    }
}"""
        variables = {"messageId": message_id, "humanFeedback": feedback}
        res = await self.mutation(mutation, variables)
        self.check_for_errors(res, raise_error=True)

        return True

    async def get_message(self, message_id=None):
        """
        Not supported by the cloud client.

        ``message_id`` is accepted, matching ``LocalClient.get_message``, so
        callers get ``NotImplementedError`` rather than a ``TypeError``.
        """
        raise NotImplementedError

    async def create_message(self, variables: Dict[str, Any]) -> Optional[int]:
        """
        Persist a message in the current conversation.

        :param variables: Message fields; the caller's dict is not modified.
        :return: The new message id, or None if it could not be persisted.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        # Work on a copy so the caller's dict is not mutated (as LocalClient does).
        variables = {**variables, "conversationId": c_id}

        mutation = """
mutation ($conversationId: ID!, $author: String!, $content: String!, $language: String, $prompt: String, $llmSettings: Json, $isError: Boolean, $indent: Int, $authorIsUser: Boolean, $waitForAnswer: Boolean, $createdAt: StringOrFloat) {
    createMessage(conversationId: $conversationId, author: $author, content: $content, language: $language, prompt: $prompt, llmSettings: $llmSettings, isError: $isError, indent: $indent, authorIsUser: $authorIsUser, waitForAnswer: $waitForAnswer, createdAt: $createdAt) {
        id
    }
}
"""
        res = await self.mutation(mutation, variables)
        if self.check_for_errors(res):
            logger.warning("Could not create message.")
            return None

        return int(res["data"]["createMessage"]["id"])

    async def update_message(self, message_id: int, variables: Dict[str, Any]) -> bool:
        """
        Update an existing message.

        :param variables: Message fields; the caller's dict is not modified.
        :return: True on success, False if the server reported errors.
        """
        mutation = """
mutation ($messageId: ID!, $author: String!, $content: String!, $language: String, $prompt: String, $llmSettings: Json) {
    updateMessage(messageId: $messageId, author: $author, content: $content, language: $language, prompt: $prompt, llmSettings: $llmSettings) {
        id
    }
}
"""
        # Work on a copy so the caller's dict is not mutated.
        variables = {**variables, "messageId": message_id}
        res = await self.mutation(mutation, variables)

        if self.check_for_errors(res):
            logger.warning("Could not update message.")
            return False

        return True

    async def delete_message(self, message_id: int) -> bool:
        """
        Delete a message by id.

        :return: True on success, False if the server reported errors.
        """
        mutation = """
mutation ($messageId: ID!) {
    deleteMessage(messageId: $messageId) {
        id
    }
}
"""
        res = await self.mutation(mutation, {"messageId": message_id})

        if self.check_for_errors(res):
            logger.warning("Could not delete message.")
            return False

        return True

    async def get_element(self, conversation_id, element_id):
        """
        Fetch one element of a conversation.

        :return: The ``element`` object from the GraphQL response.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query (
    $conversationId: ID!
    $id: ID!
) {
    element(
        conversationId: $conversationId,
        id: $id
    ) {
        id
        conversationId
        type
        name
        url
        display
        language
        size
        forIds
    }
}"""

        variables = {
            "conversationId": conversation_id,
            "id": element_id,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return res["data"]["element"]

    async def upsert_element(self, variables):
        """
        Update the element if ``variables`` has an ``id``, otherwise create it.

        :param variables: Element fields; the caller's dict is not modified.
        :return: The mutation's result object, or None if it could not be persisted.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        # Work on a copy so the caller's dict is not mutated.
        variables = {**variables, "conversationId": c_id}

        if "id" in variables:
            mutation_name = "updateElement"
            mutation = """
mutation ($conversationId: ID!, $id: ID!, $forIds: [String!]!) {
    updateElement(conversationId: $conversationId, id: $id, forIds: $forIds) {
        id,
    }
}
"""
        else:
            mutation_name = "createElement"
            mutation = """
mutation ($conversationId: ID!, $type: String!, $url: String!, $name: String!, $display: String!, $forIds: [String!]!, $size: String, $language: String) {
    createElement(conversationId: $conversationId, type: $type, url: $url, name: $name, display: $display, size: $size, language: $language, forIds: $forIds) {
        id,
        type,
        url,
        name,
        display,
        size,
        language,
        forIds
    }
}
"""
        res = await self.mutation(mutation, variables)

        if self.check_for_errors(res):
            logger.warning("Could not persist element.")
            return None

        return res["data"][mutation_name]

    async def upload_element(self, content: bytes, mime: str) -> str:
        """
        Upload raw bytes and return their permanent URL.

        First asks the server for upload details (a target ``url`` plus form
        ``fields``) for a fresh uuid file name. It then POSTs the bytes as
        multipart form data to that URL.

        :return: The server-provided ``permanentUrl``, or ``""`` if either
            request failed.
        """
        file_name = f"{uuid.uuid4()}"
        body = {"projectId": self.project_id, "fileName": file_name, "contentType": mime}

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{config.chainlit_server}/api/upload/file",
                json=body,
                headers=self.headers,
            ) as r:
                if not r.ok:
                    reason = await r.text()
                    logger.error(f"Failed to upload file: {reason}")
                    return ""
                json_res = await r.json()

        upload_details = json_res["post"]
        permanent_url = json_res["permanentUrl"]

        form_data = aiohttp.FormData()
        # The fields returned by the server go first...
        for field_name, field_value in upload_details["fields"].items():
            form_data.add_field(field_name, field_value)
        # ...followed by the file itself.
        form_data.add_field("file", content, content_type="multipart/form-data")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                upload_details["url"],
                data=form_data,
            ) as upload_response:
                if not upload_response.ok:
                    reason = await upload_response.text()
                    logger.error(f"Failed to upload file: {reason}")
                    return ""

        return permanent_url
|
chainlit/client/local.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
from typing import Optional, Dict
|
|
2
|
+
import uuid
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import aiofiles
|
|
8
|
+
|
|
9
|
+
from chainlit.client.base import PaginatedResponse, PageInfo
|
|
10
|
+
|
|
11
|
+
from .base import BaseClient
|
|
12
|
+
|
|
13
|
+
from chainlit.logger import logger
|
|
14
|
+
from chainlit.config import config
|
|
15
|
+
from chainlit.element import mime_to_ext
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LocalClient(BaseClient):
    """Persistence client that stores data through the local Prisma models.

    Uploaded files are written to disk under ``config.project.local_fs_path``
    and referenced by ``/files/...`` URLs. The local database has no JSON or
    list-of-primitives columns, so ``llmSettings`` and ``forIds`` are stored
    as JSON strings (see ``before_write`` / ``after_read``).
    """

    # Conversation lazily created on first persist and cached by create_conversation.
    conversation_id: Optional[str] = None
    # Guards create_conversation so concurrent sends share one conversation.
    lock: asyncio.Lock

    def __init__(self):
        self.lock = asyncio.Lock()

    def before_write(self, variables: Dict):
        """Prepare ``variables`` in place for a database write.

        Serializes ``llmSettings`` and ``forIds`` to JSON strings and drops
        ``tempId``.
        """
        if "llmSettings" in variables:
            # Sqlite doesn't support json fields, so we need to serialize it.
            variables["llmSettings"] = json.dumps(variables["llmSettings"])

        if "forIds" in variables:
            # Sqlite doesn't support list of primitives, so we need to serialize it.
            variables["forIds"] = json.dumps(variables["forIds"])

        variables.pop("tempId", None)

    def after_read(self, variables: Dict):
        """Undo ``before_write`` in place on a row read back from the database.

        NULL/empty columns are left untouched rather than passed to
        ``json.loads`` (which would raise).
        """
        if variables.get("llmSettings"):
            # Sqlite doesn't support json fields, so we need to parse it.
            variables["llmSettings"] = json.loads(variables["llmSettings"])

        if variables.get("forIds"):
            # Sqlite doesn't support list of primitives, so we need to parse it.
            variables["forIds"] = json.loads(variables["forIds"])

    async def is_project_member(self):
        """Local projects have no access control: everyone is a member."""
        return True

    async def get_member_role(self):
        """Local users always act as the project owner."""
        return "OWNER"

    async def get_project_members(self):
        """Local projects have no member list."""
        return []

    async def get_conversation_id(self):
        """Return the current conversation id, creating the conversation on first use."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def create_conversation(self):
        """Return the cached conversation id, creating the conversation if needed."""
        from prisma.models import Conversation

        # If we run multiple send concurrently, we need to make sure we don't create multiple conversations.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            res = await Conversation.prisma().create(data={})

            # Cache while still holding the lock so the next waiter reuses it.
            self.conversation_id = res.id
            return self.conversation_id

    async def delete_conversation(self, conversation_id):
        """Delete a conversation by id and return True."""
        from prisma.models import Conversation

        await Conversation.prisma().delete(where={"id": conversation_id})

        return True

    async def get_conversation(self, conversation_id: int):
        """Return a conversation with its messages and elements as plain JSON data.

        Raises if no conversation has this id (``find_unique_or_raise``).
        """
        from prisma.models import Conversation

        c = await Conversation.prisma().find_unique_or_raise(
            where={"id": conversation_id}, include={"messages": True, "elements": True}
        )

        for m in c.messages:
            if m.llmSettings:
                m.llmSettings = json.loads(m.llmSettings)

        for e in c.elements:
            if e.forIds:
                e.forIds = json.loads(e.forIds)

        return json.loads(c.json())

    async def get_conversations(self, pagination, filter):
        """Return a page of conversations, newest first.

        Only conversations with at least one message matching the feedback and
        search filters are returned. Each conversation includes its first
        user-authored message.
        """
        from prisma.models import Conversation

        some_messages = {}

        if filter.feedback is not None:
            some_messages["humanFeedback"] = filter.feedback

        if filter.search is not None:
            some_messages["content"] = {"contains": filter.search or None}

        if pagination.cursor:
            cursor = {"id": pagination.cursor}
        else:
            cursor = None

        conversations = await Conversation.prisma().find_many(
            take=pagination.first,
            # Skip the cursor row itself, which belongs to the previous page.
            skip=1 if pagination.cursor else None,
            cursor=cursor,
            include={
                "messages": {
                    "take": 1,
                    "where": {
                        "authorIsUser": True,
                    },
                    "orderBy": [
                        {
                            "createdAt": "asc",
                        }
                    ],
                }
            },
            where={"messages": {"some": some_messages}},
            order={
                "createdAt": "desc",
            },
        )

        # Heuristic: a full page suggests more rows may follow.
        has_more = len(conversations) == pagination.first

        if has_more:
            end_cursor = conversations[-1].id
        else:
            end_cursor = None

        conversations = [json.loads(c.json()) for c in conversations]

        return PaginatedResponse(
            pageInfo=PageInfo(hasNextPage=has_more, endCursor=end_cursor),
            data=conversations,
        )

    async def create_message(self, variables):
        """Persist a message in the current conversation and return its id.

        Returns None if no conversation id is available. The caller's dict is
        not modified.
        """
        from prisma.models import Message

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        res = await Message.prisma().create(data=variables)
        return res.id

    async def get_message(self, message_id):
        """Return a message as a dict with ``llmSettings`` parsed, or None if not found."""
        from prisma.models import Message

        res = await Message.prisma().find_first(where={"id": message_id})
        if res is None:
            return None
        message = res.dict()
        self.after_read(message)
        return message

    async def update_message(self, message_id, variables):
        """Update a message and return True. The caller's dict is not modified."""
        from prisma.models import Message

        variables = variables.copy()

        self.before_write(variables)

        await Message.prisma().update(data=variables, where={"id": message_id})

        return True

    async def delete_message(self, message_id):
        """Delete a message by id and return True."""
        from prisma.models import Message

        await Message.prisma().delete(where={"id": message_id})

        return True

    async def upsert_element(
        self,
        variables,
    ):
        """Update the element if ``variables`` has an ``id``, otherwise create it.

        The caller's dict is not modified. Returns the stored element as a dict
        with ``forIds`` parsed back to a list. Returns None if no conversation
        id is available or the element to update does not exist.
        """
        from prisma.models import Element

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        # Copy so serialization and tempId removal don't leak into the caller's dict.
        variables = variables.copy()
        variables["conversationId"] = c_id

        self.before_write(variables)

        if "id" in variables:
            res = await Element.prisma().update(
                data=variables, where={"id": variables.get("id")}
            )
        else:
            res = await Element.prisma().create(data=variables)

        if res is None:
            # prisma's update() yields None when no row matches the id.
            logger.warning("Could not persist element.")
            return None

        element = res.dict()
        self.after_read(element)
        return element

    async def get_element(
        self,
        conversation_id,
        element_id,
    ):
        """Return an element as plain JSON data with ``forIds`` parsed.

        ``conversation_id`` is accepted for interface parity but is not used
        in the lookup. Raises if no element has this id.
        """
        from prisma.models import Element

        res = await Element.prisma().find_unique_or_raise(where={"id": element_id})
        element = json.loads(res.json())
        self.after_read(element)
        return element

    async def upload_element(self, content: bytes, mime: str):
        """Write ``content`` to ``<local_fs_path>/<conversation id>/<uuid>.<ext>``.

        Returns the ``/files/...`` URL of the written file, or None if no
        conversation id is available.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        file_ext = mime_to_ext.get(mime, "bin")
        file_name = f"{uuid.uuid4()}.{file_ext}"

        dir_path = os.path.join(config.project.local_fs_path, str(c_id))
        # exist_ok avoids a race when concurrent uploads create the same folder.
        os.makedirs(dir_path, exist_ok=True)
        full_path = os.path.join(dir_path, file_name)

        async with aiofiles.open(full_path, "wb") as out:
            await out.write(content)
            await out.flush()

        # URLs always use "/", whereas os.path.join uses "\\" on Windows.
        return f"/files/{c_id}/{file_name}"

    async def set_human_feedback(self, message_id, feedback):
        """Store a human feedback score on a message and return True."""
        from prisma.models import Message

        await Message.prisma().update(
            where={"id": message_id},
            data={
                "humanFeedback": feedback,
            },
        )

        return True
|