chainlit 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

@@ -0,0 +1,440 @@
1
+ from typing import Dict, Any, Optional
2
+ import uuid
3
+
4
+ import asyncio
5
+ import aiohttp
6
+ from python_graphql_client import GraphqlClient
7
+
8
+ from .base import BaseClient, PaginatedResponse, PageInfo
9
+
10
+ from chainlit.logger import logger
11
+ from chainlit.config import config
12
+
13
+
14
class CloudClient(BaseClient):
    """Persistence client backed by the Chainlit cloud server.

    GraphQL requests are sent to ``{config.chainlit_server}/api/graphql``; role
    lookup and file upload use plain HTTP endpoints on the same server. Every
    request carries the project's access token in the ``Authorization`` header.
    """

    # Id of the conversation this client persists into; filled lazily by
    # get_conversation_id() and reused afterwards.
    conversation_id: Optional[int] = None
    # Serializes conversation creation so concurrent sends share one conversation.
    lock: asyncio.Lock

    def __init__(self, project_id: str, access_token: str):
        self.lock = asyncio.Lock()
        self.project_id = project_id
        self.headers = {
            "Authorization": access_token,
            "content-type": "application/json",
        }
        graphql_endpoint = f"{config.chainlit_server}/api/graphql"
        self.graphql_client = GraphqlClient(
            endpoint=graphql_endpoint, headers=self.headers
        )

    async def query(
        self, query: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL query.

        :param query: The GraphQL query string.
        :param variables: A dictionary of variables for the query (defaults to {}).
        :return: The raw GraphQL response; callers read its "data" / "errors" keys.
        """
        # A fresh dict per call instead of a shared mutable default argument.
        return await self.graphql_client.execute_async(
            query=query, variables=variables or {}
        )

    def check_for_errors(self, response: Dict[str, Any], raise_error: bool = False):
        """
        Report the first GraphQL error contained in ``response``, if any.

        :param response: A response returned by query() / mutation().
        :param raise_error: Raise instead of logging when an error is present.
        :return: True if the response contains errors, False otherwise.
        :raises Exception: If ``raise_error`` is set and the response has errors.
        """
        if "errors" not in response:
            return False
        first_error = response["errors"][0]
        if raise_error:
            raise Exception(first_error)
        logger.error(first_error)
        return True

    async def mutation(
        self, mutation: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL mutation.

        :param mutation: The GraphQL mutation string.
        :param variables: A dictionary of variables for the mutation (defaults to {}).
        :return: The raw GraphQL response; callers read its "data" / "errors" keys.
        """
        # A fresh dict per call instead of a shared mutable default argument.
        return await self.graphql_client.execute_async(
            query=mutation, variables=variables or {}
        )

    async def get_member_role(
        self,
    ):
        """
        Fetch the access token's role in the project from ``/api/role``.

        :return: The role string ("ANONYMOUS" when the response has no "role"
            key), or False if the HTTP request failed.
        """
        data = {"projectId": self.project_id}
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{config.chainlit_server}/api/role",
                json=data,
                headers=self.headers,
            ) as r:
                if not r.ok:
                    reason = await r.text()
                    logger.error(f"Failed to get user role. {r.status}: {reason}")
                    return False
                payload = await r.json()
                return payload.get("role", "ANONYMOUS")

    async def is_project_member(self) -> bool:
        """Return True only if a non-anonymous role was successfully retrieved."""
        role = await self.get_member_role()
        # get_member_role() returns False when the request fails; a failed
        # lookup must not be treated as membership.
        return bool(role) and role != "ANONYMOUS"

    async def get_project_members(self):
        """
        List the members of the project.

        :return: A list of {"role", "name", "email"} dicts.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query ($projectId: String!) {
            projectMembers(projectId: $projectId) {
                edges {
                    cursor
                    node {
                        role
                        user {
                            email
                            name
                        }
                    }
                }
            }
        }"""
        variables = {"projectId": self.project_id}
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return [
            {
                "role": edge["node"]["role"],
                "name": edge["node"]["user"]["name"],
                "email": edge["node"]["user"]["email"],
            }
            for edge in res["data"]["projectMembers"]["edges"]
        ]

    async def create_conversation(self) -> Optional[int]:
        """
        Create a conversation for this project unless one is already known.

        :return: The existing or newly created conversation id, or None if the
            server reported an error.
        """
        # If we run multiple send concurrently, we need to make sure we don't create multiple conversations.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            mutation = """
            mutation ($projectId: String!, $sessionId: String) {
                createConversation(projectId: $projectId, sessionId: $sessionId) {
                    id
                }
            }
            """
            variables = {"projectId": self.project_id}
            res = await self.mutation(mutation, variables)

            if self.check_for_errors(res):
                logger.warning("Could not create conversation.")
                return None

            return int(res["data"]["createConversation"]["id"])

    async def get_conversation_id(self):
        """Return the conversation id, creating the conversation on first use."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def delete_conversation(self, conversation_id: int):
        """
        Delete a conversation.

        :return: True on success.
        :raises Exception: If the GraphQL response contains errors.
        """
        mutation = """mutation ($id: ID!) {
            deleteConversation(id: $id) {
                id
            }
        }"""
        variables = {"id": conversation_id}
        res = await self.mutation(mutation, variables)
        self.check_for_errors(res, raise_error=True)

        return True

    async def get_conversation(self, conversation_id: int):
        """
        Fetch a conversation with its messages and elements.

        :return: The "conversation" object from the GraphQL response.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query ($id: ID!) {
            conversation(id: $id) {
                id
                createdAt
                messages {
                    id
                    isError
                    indent
                    author
                    content
                    waitForAnswer
                    humanFeedback
                    language
                    prompt
                    llmSettings
                    authorIsUser
                    createdAt
                }
                elements {
                    id
                    conversationId
                    type
                    name
                    url
                    display
                    language
                    size
                    forIds
                }
            }
        }"""
        variables = {
            "id": conversation_id,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return res["data"]["conversation"]

    async def get_conversations(self, pagination, filter):
        """
        Fetch one page of the project's conversations.

        :param pagination: Object exposing ``first`` and ``cursor``.
        :param filter: Object exposing ``feedback``, ``authorEmail`` and ``search``.
        :return: A PaginatedResponse of conversation nodes.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query (
            $first: Int
            $projectId: String!
            $cursor: String
            $withFeedback: Int
            $authorEmail: String
            $search: String
        ) {
            conversations(
                first: $first
                cursor: $cursor
                projectId: $projectId
                withFeedback: $withFeedback
                authorEmail: $authorEmail
                search: $search
            ) {
                pageInfo {
                    endCursor
                    hasNextPage
                }
                edges {
                    cursor
                    node {
                        id
                        createdAt
                        elementCount
                        messageCount
                        author {
                            name
                            email
                        }
                        messages {
                            content
                        }
                    }
                }
            }
        }"""

        variables = {
            "projectId": self.project_id,
            "first": pagination.first,
            "cursor": pagination.cursor,
            "withFeedback": filter.feedback,
            "authorEmail": filter.authorEmail,
            "search": filter.search,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        connection = res["data"]["conversations"]
        conversations = [edge["node"] for edge in connection["edges"]]
        page_info = connection["pageInfo"]

        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=page_info["hasNextPage"],
                endCursor=page_info["endCursor"],
            ),
            data=conversations,
        )

    async def set_human_feedback(self, message_id, feedback):
        """
        Store the human feedback value for a message.

        :return: True on success.
        :raises Exception: If the GraphQL response contains errors.
        """
        mutation = """mutation ($messageId: ID!, $humanFeedback: Int!) {
            setHumanFeedback(messageId: $messageId, humanFeedback: $humanFeedback) {
                id
                humanFeedback
            }
        }"""
        variables = {"messageId": message_id, "humanFeedback": feedback}
        res = await self.mutation(mutation, variables)
        self.check_for_errors(res, raise_error=True)

        return True

    async def get_message(self, message_id=None):
        """Not supported by the cloud client."""
        raise NotImplementedError

    async def create_message(self, variables: Dict[str, Any]) -> Optional[int]:
        """
        Persist a message in the current conversation.

        :param variables: Message fields matching the mutation variables below.
            The caller's dict is not modified.
        :return: The new message id, or None on failure.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        # Copy so the caller's dict is not mutated.
        variables = {**variables, "conversationId": c_id}

        mutation = """
        mutation ($conversationId: ID!, $author: String!, $content: String!, $language: String, $prompt: String, $llmSettings: Json, $isError: Boolean, $indent: Int, $authorIsUser: Boolean, $waitForAnswer: Boolean, $createdAt: StringOrFloat) {
            createMessage(conversationId: $conversationId, author: $author, content: $content, language: $language, prompt: $prompt, llmSettings: $llmSettings, isError: $isError, indent: $indent, authorIsUser: $authorIsUser, waitForAnswer: $waitForAnswer, createdAt: $createdAt) {
                id
            }
        }
        """
        res = await self.mutation(mutation, variables)
        if self.check_for_errors(res):
            logger.warning("Could not create message.")
            return None

        return int(res["data"]["createMessage"]["id"])

    async def update_message(self, message_id: int, variables: Dict[str, Any]) -> bool:
        """
        Update an existing message.

        :param variables: Message fields matching the mutation variables below.
            The caller's dict is not modified.
        :return: True on success, False if the server reported an error.
        """
        mutation = """
        mutation ($messageId: ID!, $author: String!, $content: String!, $language: String, $prompt: String, $llmSettings: Json) {
            updateMessage(messageId: $messageId, author: $author, content: $content, language: $language, prompt: $prompt, llmSettings: $llmSettings) {
                id
            }
        }
        """
        # Copy so the caller's dict is not mutated.
        variables = {**variables, "messageId": message_id}
        res = await self.mutation(mutation, variables)

        if self.check_for_errors(res):
            logger.warning("Could not update message.")
            return False

        return True

    async def delete_message(self, message_id: int) -> bool:
        """
        Delete a message.

        :return: True on success, False if the server reported an error.
        """
        mutation = """
        mutation ($messageId: ID!) {
            deleteMessage(messageId: $messageId) {
                id
            }
        }
        """
        res = await self.mutation(mutation, {"messageId": message_id})

        if self.check_for_errors(res):
            logger.warning("Could not delete message.")
            return False

        return True

    async def get_element(self, conversation_id, element_id):
        """
        Fetch a single element of a conversation.

        :return: The "element" object from the GraphQL response.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query (
            $conversationId: ID!
            $id: ID!
        ) {
            element(
                conversationId: $conversationId,
                id: $id
            ) {
                id
                conversationId
                type
                name
                url
                display
                language
                size
                forIds
            }
        }"""

        variables = {
            "conversationId": conversation_id,
            "id": element_id,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return res["data"]["element"]

    async def upsert_element(self, variables):
        """
        Create an element, or update its ``forIds`` when ``variables`` has an "id".

        :param variables: Element fields matching the mutation variables below.
            The caller's dict is not modified.
        :return: The created/updated element fields, or None on failure.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        if "id" in variables:
            mutation_name = "updateElement"
            mutation = """
            mutation ($conversationId: ID!, $id: ID!, $forIds: [String!]!) {
                updateElement(conversationId: $conversationId, id: $id, forIds: $forIds) {
                    id,
                }
            }
            """
        else:
            mutation_name = "createElement"
            mutation = """
            mutation ($conversationId: ID!, $type: String!, $url: String!, $name: String!, $display: String!, $forIds: [String!]!, $size: String, $language: String) {
                createElement(conversationId: $conversationId, type: $type, url: $url, name: $name, display: $display, size: $size, language: $language, forIds: $forIds) {
                    id,
                    type,
                    url,
                    name,
                    display,
                    size,
                    language,
                    forIds
                }
            }
            """

        # Copy so the caller's dict is not mutated.
        res = await self.mutation(mutation, {**variables, "conversationId": c_id})

        if self.check_for_errors(res):
            logger.warning("Could not persist element.")
            return None

        return res["data"][mutation_name]

    async def upload_element(self, content: bytes, mime: str) -> str:
        """
        Upload file content and return its permanent URL.

        The server is first asked (``/api/upload/file``) for upload details
        (a target "url" plus form "fields") and a "permanentUrl"; the content
        is then POSTed as a form to that target url.

        :param content: Raw file bytes.
        :param mime: MIME type of ``content``.
        :return: The permanent URL, or "" if either request failed.
        """
        file_name = f"{uuid.uuid4()}"
        body = {"projectId": self.project_id, "fileName": file_name, "contentType": mime}

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{config.chainlit_server}/api/upload/file",
                json=body,
                headers=self.headers,
            ) as r:
                if not r.ok:
                    reason = await r.text()
                    logger.error(f"Failed to upload file: {reason}")
                    return ""
                json_res = await r.json()

        upload_details = json_res["post"]
        permanent_url = json_res["permanentUrl"]

        form_data = aiohttp.FormData()

        # Add the form fields provided by the server first.
        for field_name, field_value in upload_details["fields"].items():
            form_data.add_field(field_name, field_value)

        # Add the file itself, labelled with its real MIME type.
        form_data.add_field("file", content, content_type=mime)
        async with aiohttp.ClientSession() as session:
            async with session.post(
                upload_details["url"],
                data=form_data,
            ) as upload_response:
                if not upload_response.ok:
                    reason = await upload_response.text()
                    logger.error(f"Failed to upload file: {reason}")
                    return ""

        return permanent_url
@@ -0,0 +1,257 @@
1
+ from typing import Optional, Dict
2
+ import uuid
3
+ import json
4
+ import os
5
+
6
+ import asyncio
7
+ import aiofiles
8
+
9
+ from chainlit.client.base import PaginatedResponse, PageInfo
10
+
11
+ from .base import BaseClient
12
+
13
+ from chainlit.logger import logger
14
+ from chainlit.config import config
15
+ from chainlit.element import mime_to_ext
16
+
17
+
18
class LocalClient(BaseClient):
    """Persistence client for local mode.

    Conversations, messages and elements are stored through the Prisma client
    (SQLite, per the comments in before_write/after_read); uploaded files are
    written under ``config.project.local_fs_path``.
    """

    # Id of the conversation this client persists into; filled lazily by
    # get_conversation_id() and reused afterwards.
    conversation_id: Optional[int] = None
    # Serializes conversation creation so concurrent sends share one conversation.
    lock: asyncio.Lock

    def __init__(self):
        self.lock = asyncio.Lock()

    def before_write(self, variables: Dict):
        """Prepare ``variables`` for storage, in place.

        ``llmSettings`` and ``forIds`` are JSON-encoded, and the ``tempId`` key
        (which is not persisted) is removed.
        """
        if "llmSettings" in variables:
            # Sqlite doesn't support json fields, so we need to serialize it.
            variables["llmSettings"] = json.dumps(variables["llmSettings"])

        if "forIds" in variables:
            # Sqlite doesn't support list of primitives, so we need to serialize it.
            variables["forIds"] = json.dumps(variables["forIds"])

        variables.pop("tempId", None)

    def after_read(self, variables: Dict):
        """Decode the JSON-encoded ``llmSettings`` field of a stored row, in place."""
        # Sqlite doesn't support json fields, so we need to parse it.
        # The column may be NULL/empty (e.g. messages created without settings);
        # json.loads(None) would raise TypeError.
        if variables.get("llmSettings"):
            variables["llmSettings"] = json.loads(variables["llmSettings"])

    async def is_project_member(self):
        """Local mode has no access control: always a member."""
        return True

    async def get_member_role(self):
        """Local mode has no access control: always the owner."""
        return "OWNER"

    async def get_project_members(self):
        """Local mode has no member list."""
        return []

    async def get_conversation_id(self):
        """Return the conversation id, creating the conversation on first use."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def create_conversation(self):
        """Create a conversation row unless one is already known; return its id."""
        from prisma.models import Conversation

        # If we run multiple send concurrently, we need to make sure we don't create multiple conversations.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            res = await Conversation.prisma().create(data={})

            return res.id

    async def delete_conversation(self, conversation_id):
        """Delete a conversation; return True."""
        from prisma.models import Conversation

        await Conversation.prisma().delete(where={"id": conversation_id})

        return True

    async def get_conversation(self, conversation_id: int):
        """
        Fetch a conversation with its messages and elements as plain JSON data.

        :raises: Whatever ``find_unique_or_raise`` raises when no row matches.
        """
        from prisma.models import Conversation

        c = await Conversation.prisma().find_unique_or_raise(
            where={"id": conversation_id}, include={"messages": True, "elements": True}
        )

        # Decode the JSON-encoded columns written by before_write().
        for m in c.messages:
            if m.llmSettings:
                m.llmSettings = json.loads(m.llmSettings)

        for e in c.elements:
            if e.forIds:
                e.forIds = json.loads(e.forIds)

        return json.loads(c.json())

    async def get_conversations(self, pagination, filter):
        """
        Fetch one page of conversations, newest first.

        Only conversations having at least one message that matches the
        feedback/search filters are returned; each carries its first user
        message. ``filter.authorEmail`` is not used in local mode.

        :param pagination: Object exposing ``first`` (page size) and ``cursor``
            (id of the last conversation of the previous page).
        :param filter: Object exposing ``feedback`` and ``search``.
        :return: A PaginatedResponse of conversation dicts.
        """
        from prisma.models import Conversation

        some_messages = {}

        if filter.feedback is not None:
            some_messages["humanFeedback"] = filter.feedback

        if filter.search is not None:
            some_messages["content"] = {"contains": filter.search or None}

        if pagination.cursor:
            cursor = {"id": pagination.cursor}
        else:
            cursor = None

        page_size = pagination.first

        conversations = await Conversation.prisma().find_many(
            # Fetch one extra row to know whether another page exists, instead
            # of guessing from a full page (which misreports an exactly-full
            # last page as having a successor).
            take=page_size + 1 if page_size is not None else None,
            # Skip the cursor row itself, which was on the previous page.
            skip=1 if pagination.cursor else None,
            cursor=cursor,
            include={
                "messages": {
                    "take": 1,
                    "where": {
                        "authorIsUser": True,
                    },
                    "orderBy": [
                        {
                            "createdAt": "asc",
                        }
                    ],
                }
            },
            where={"messages": {"some": some_messages}},
            order={
                "createdAt": "desc",
            },
        )

        has_more = page_size is not None and len(conversations) > page_size
        if has_more:
            conversations = conversations[:page_size]

        end_cursor = conversations[-1].id if has_more and conversations else None

        conversations = [json.loads(c.json()) for c in conversations]

        return PaginatedResponse(
            pageInfo=PageInfo(hasNextPage=has_more, endCursor=end_cursor),
            data=conversations,
        )

    async def create_message(self, variables):
        """
        Persist a message in the current conversation.

        The caller's dict is not modified.

        :return: The new message id, or None if no conversation id is available.
        """
        from prisma.models import Message

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        res = await Message.prisma().create(data=variables)
        return res.id

    async def get_message(self, message_id):
        """Return a message as a dict with decoded ``llmSettings``, or None if missing."""
        from prisma.models import Message

        res = await Message.prisma().find_first(where={"id": message_id})
        if res is None:
            return None
        res = res.dict()
        self.after_read(res)
        return res

    async def update_message(self, message_id, variables):
        """Update a message; the caller's dict is not modified. Return True."""
        from prisma.models import Message

        variables = variables.copy()

        self.before_write(variables)

        await Message.prisma().update(data=variables, where={"id": message_id})

        return True

    async def delete_message(self, message_id):
        """Delete a message; return True."""
        from prisma.models import Message

        await Message.prisma().delete(where={"id": message_id})

        return True

    async def upsert_element(
        self,
        variables,
    ):
        """
        Create an element, or update it when ``variables`` has an "id".

        The caller's dict is not modified (before_write() rewrites fields in place).

        :return: The stored element as a dict, or None if no conversation id
            is available.
        """
        from prisma.models import Element

        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        # Copy so that JSON-encoding forIds / dropping tempId does not leak
        # back into the caller's dict.
        variables = variables.copy()

        variables["conversationId"] = c_id

        self.before_write(variables)

        if "id" in variables:
            res = await Element.prisma().update(
                data=variables, where={"id": variables.get("id")}
            )
        else:
            res = await Element.prisma().create(data=variables)

        return res.dict()

    async def get_element(
        self,
        conversation_id,
        element_id,
    ):
        """
        Fetch an element by id as plain JSON data.

        ``conversation_id`` is accepted for interface parity but not used here.
        """
        from prisma.models import Element

        res = await Element.prisma().find_unique_or_raise(where={"id": element_id})
        return json.loads(res.json())

    async def upload_element(self, content: bytes, mime: str):
        """
        Write ``content`` under ``<local_fs_path>/<conversation id>/`` and
        return its ``/files/...`` URL.

        The file name is a random UUID with an extension derived from ``mime``
        ("bin" when unknown).

        :return: The URL path of the stored file, or None if no conversation id
            is available.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not upload the element.")
            return None

        file_ext = mime_to_ext.get(mime, "bin")
        file_name = f"{uuid.uuid4()}.{file_ext}"

        sub_path = os.path.join(str(c_id), file_name)
        full_path = os.path.join(config.project.local_fs_path, sub_path)

        # exist_ok avoids the check-then-create race when two uploads for the
        # same conversation run concurrently.
        os.makedirs(os.path.dirname(full_path), exist_ok=True)

        async with aiofiles.open(full_path, "wb") as out:
            await out.write(content)
            await out.flush()

        url = f"/files/{sub_path}"
        return url

    async def set_human_feedback(self, message_id, feedback):
        """Store the human feedback value for a message; return True."""
        from prisma.models import Message

        await Message.prisma().update(
            where={"id": message_id},
            data={
                "humanFeedback": feedback,
            },
        )

        return True