chainlit 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

@@ -0,0 +1,152 @@
1
+ from typing import (
2
+ Dict,
3
+ Any,
4
+ List,
5
+ TypedDict,
6
+ Optional,
7
+ Union,
8
+ Literal,
9
+ TypeVar,
10
+ Generic,
11
+ )
12
+
13
+ from abc import ABC, abstractmethod
14
+ from pydantic.dataclasses import dataclass
15
+ from dataclasses_json import dataclass_json
16
+
17
+ from chainlit.types import (
18
+ Pagination,
19
+ ConversationFilter,
20
+ ElementType,
21
+ ElementSize,
22
+ ElementDisplay,
23
+ )
24
+
25
+
26
class MessageDict(TypedDict):
    """Serialized chat message handed to a persistence client.

    Keys use the camelCase names of the GraphQL ``createMessage`` /
    ``updateMessage`` variables.
    """

    # The cloud client fills this with the int id returned by
    # ``create_conversation``; string ids remain accepted.
    conversationId: Optional[Union[str, int]]
    id: Optional[int]
    tempId: Optional[str]
    createdAt: Optional[int]
    content: str
    author: str
    prompt: Optional[str]
    llmSettings: Dict
    language: Optional[str]
    indent: Optional[int]
    authorIsUser: Optional[bool]
    waitForAnswer: Optional[bool]
    isError: Optional[bool]
    # Feedback value; ``set_human_feedback`` restricts it to -1, 0 or 1.
    humanFeedback: Optional[int]
41
+
42
+
43
# A project member as listed by ``get_project_members``: display name,
# e-mail address and the member's role in the project.
UserDict = TypedDict("UserDict", {"name": str, "email": str, "role": str})
47
+
48
+
49
# Serialized element stored alongside a conversation. ``forIds`` presumably
# holds the ids of the messages the element is attached to (TODO confirm).
ElementDict = TypedDict(
    "ElementDict",
    {
        "id": Optional[int],
        "conversationId": Optional[int],
        "type": ElementType,
        "url": str,
        "name": str,
        "display": ElementDisplay,
        "size": ElementSize,
        "language": str,
        "forIds": Optional[List[Union[str, int]]],
    },
)
59
+
60
+
61
# Conversation record. ``messages`` is a required, non-optional list; the
# counts, the author and ``elements`` may be None.
ConversationDict = TypedDict(
    "ConversationDict",
    {
        "id": Optional[int],
        "createdAt": Optional[int],
        "elementCount": Optional[int],
        "messageCount": Optional[int],
        "author": Optional[UserDict],
        "messages": List[MessageDict],
        "elements": Optional[List[ElementDict]],
    },
)
69
+
70
+
71
T = TypeVar("T")


@dataclass
class PageInfo:
    """Cursor-pagination state reported for one page of results."""

    # Whether the server has more results after this page.
    hasNextPage: bool
    # Server-provided cursor marking the end of this page.
    endCursor: Any


@dataclass_json
@dataclass
class PaginatedResponse(Generic[T]):
    """One page of ``data`` items together with its :class:`PageInfo`."""

    pageInfo: PageInfo
    data: List[T]
85
+
86
+
87
class BaseClient(ABC):
    """Abstract persistence backend for conversations, messages and elements.

    Every operation is a coroutine. Implementations also report the current
    user's membership and role in the project.
    """

    # Identifier of the project whose data this client reads and writes.
    project_id: str

    @abstractmethod
    async def is_project_member(self, access_token: Optional[str] = None) -> bool:
        """Return True if the current user belongs to the project.

        ``access_token`` is optional because an implementation may already
        hold the user's credentials (e.g. receive them at construction time).
        """

    @abstractmethod
    async def get_member_role(self, access_token: Optional[str] = None) -> str:
        """Return the current user's role in the project.

        ``access_token`` is optional for the same reason as in
        :meth:`is_project_member`.
        """

    @abstractmethod
    async def get_project_members(self) -> List[UserDict]:
        """Return every member of the project."""

    @abstractmethod
    async def create_conversation(self) -> Optional[int]:
        """Create a conversation and return its id, or None on failure."""

    @abstractmethod
    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation; return True on success."""

    @abstractmethod
    async def get_conversation(self, conversation_id: int) -> ConversationDict:
        """Fetch a conversation together with its messages and elements."""

    @abstractmethod
    async def get_conversations(
        self, pagination: "Pagination", filter: "ConversationFilter"
    ) -> PaginatedResponse[ConversationDict]:
        """Return one page of the project's conversations matching ``filter``."""

    @abstractmethod
    async def get_message(self, conversation_id: str, message_id: str) -> Dict:
        """Fetch a single message of a conversation."""

    @abstractmethod
    async def create_message(self, variables: MessageDict) -> Optional[int]:
        """Persist a new message and return its id, or None on failure."""

    @abstractmethod
    async def update_message(self, message_id: int, variables: MessageDict) -> bool:
        """Update an existing message; return True on success."""

    @abstractmethod
    async def delete_message(self, message_id: int) -> bool:
        """Delete a message; return True on success."""

    @abstractmethod
    async def upload_element(self, content: bytes, mime: str) -> str:
        """Upload ``content`` (of MIME type ``mime``) and return its URL."""

    @abstractmethod
    async def upsert_element(self, variables: ElementDict) -> Optional[ElementDict]:
        """Create or update an element; return the stored element or None."""

    @abstractmethod
    async def get_element(self, conversation_id: int, element_id: int) -> ElementDict:
        """Fetch a single element of a conversation."""

    @abstractmethod
    async def set_human_feedback(
        self, message_id: int, feedback: Literal[-1, 0, 1]
    ) -> bool:
        """Record feedback (-1, 0 or 1) on a message; return True on success."""
@@ -0,0 +1,440 @@
1
+ from typing import Dict, Any, Optional
2
+ import uuid
3
+
4
+ import asyncio
5
+ import aiohttp
6
+ from python_graphql_client import GraphqlClient
7
+
8
+ from .base import BaseClient, PaginatedResponse, PageInfo
9
+
10
+ from chainlit.logger import logger
11
+ from chainlit.config import config
12
+
13
+
14
class CloudClient(BaseClient):
    """Persistence client for the server at ``config.chainlit_server``.

    Conversations, messages and elements go through the ``/api/graphql``
    endpoint. The member role (``/api/role``) and file uploads
    (``/api/upload/file``) use dedicated REST routes. Requests to that server
    are authenticated with the access token given to the constructor.
    """

    # Conversation lazily created by ``get_conversation_id``; None until the
    # first successful creation.
    conversation_id: Optional[int] = None
    # Guards conversation creation so concurrent callers share one conversation.
    lock: asyncio.Lock

    def __init__(self, project_id: str, access_token: str):
        """Build a client for ``project_id`` authenticated by ``access_token``."""
        self.lock = asyncio.Lock()
        self.project_id = project_id
        self.headers = {
            "Authorization": access_token,
            "content-type": "application/json",
        }
        graphql_endpoint = f"{config.chainlit_server}/api/graphql"
        self.graphql_client = GraphqlClient(
            endpoint=graphql_endpoint, headers=self.headers
        )

    async def query(
        self, query: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL query.

        :param query: The GraphQL query string.
        :param variables: A dictionary of variables for the query (none by default).
        :return: The response data as a dictionary.
        """
        # A fresh dict per call instead of a shared mutable default argument.
        return await self.graphql_client.execute_async(
            query=query, variables={} if variables is None else variables
        )

    def check_for_errors(
        self, response: Dict[str, Any], raise_error: bool = False
    ) -> bool:
        """
        Report the first GraphQL error contained in ``response``, if any.

        :param response: The decoded GraphQL response.
        :param raise_error: Raise instead of logging when an error is present.
        :return: True if the response carries errors, False otherwise.
        :raises Exception: If ``raise_error`` is set and the response has errors.
        """
        # ``.get`` plus truthiness also tolerates ``"errors": null`` / ``[]``,
        # which previously raised TypeError / IndexError here.
        errors = response.get("errors")
        if not errors:
            return False
        if raise_error:
            raise Exception(errors[0])
        logger.error(errors[0])
        return True

    async def mutation(
        self, mutation: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a GraphQL mutation.

        :param mutation: The GraphQL mutation string.
        :param variables: A dictionary of variables for the mutation (none by default).
        :return: The response data as a dictionary.
        """
        return await self.graphql_client.execute_async(
            query=mutation, variables={} if variables is None else variables
        )

    async def get_member_role(
        self,
    ):
        """Return the current user's role in the project.

        Returns ``"ANONYMOUS"`` when the server response has no ``role`` key,
        and ``False`` (after logging) when the request fails.
        """
        data = {"projectId": self.project_id}
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{config.chainlit_server}/api/role",
                json=data,
                headers=self.headers,
            ) as r:
                if not r.ok:
                    reason = await r.text()
                    logger.error(f"Failed to get user role. {r.status}: {reason}")
                    return False
                payload = await r.json()
                return payload.get("role", "ANONYMOUS")

    async def is_project_member(self) -> bool:
        """Return True if the server reports a non-anonymous role for the user."""
        role = await self.get_member_role()
        # get_member_role returns False when the request fails (and the role
        # may be null); without the truthiness check, ``False != "ANONYMOUS"``
        # would wrongly grant membership.
        return bool(role) and role != "ANONYMOUS"

    async def get_project_members(self):
        """Return the project's members as ``{"role", "name", "email"}`` dicts.

        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query ($projectId: String!) {
            projectMembers(projectId: $projectId) {
                edges {
                    cursor
                    node {
                        role
                        user {
                            email
                            name
                        }
                    }
                }
            }
        }"""
        variables = {"projectId": self.project_id}
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return [
            {
                "role": edge["node"]["role"],
                "name": edge["node"]["user"]["name"],
                "email": edge["node"]["user"]["email"],
            }
            for edge in res["data"]["projectMembers"]["edges"]
        ]

    async def create_conversation(self) -> Optional[int]:
        """Create a conversation and return its id, or None on failure.

        If ``conversation_id`` is already set, it is returned without
        contacting the server. This method does not store the new id itself;
        ``get_conversation_id`` does.
        """
        # If we run multiple send concurrently, we need to make sure we don't create multiple conversations.
        async with self.lock:
            if self.conversation_id:
                return self.conversation_id

            mutation = """
            mutation ($projectId: String!, $sessionId: String) {
                createConversation(projectId: $projectId, sessionId: $sessionId) {
                    id
                }
            }
            """
            variables = {"projectId": self.project_id}
            res = await self.mutation(mutation, variables)

            if self.check_for_errors(res):
                logger.warning("Could not create conversation.")
                return None

            return int(res["data"]["createConversation"]["id"])

    async def get_conversation_id(self) -> Optional[int]:
        """Return this client's conversation id, creating the conversation if needed."""
        self.conversation_id = await self.create_conversation()

        return self.conversation_id

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation and return True.

        :raises Exception: If the GraphQL response contains errors.
        """
        mutation = """mutation ($id: ID!) {
            deleteConversation(id: $id) {
                id
            }
        }"""
        variables = {"id": conversation_id}
        res = await self.mutation(mutation, variables)
        self.check_for_errors(res, raise_error=True)

        return True

    async def get_conversation(self, conversation_id: int):
        """Fetch a conversation with its messages and elements.

        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query ($id: ID!) {
            conversation(id: $id) {
                id
                createdAt
                messages {
                    id
                    isError
                    indent
                    author
                    content
                    waitForAnswer
                    humanFeedback
                    language
                    prompt
                    llmSettings
                    authorIsUser
                    createdAt
                }
                elements {
                    id
                    conversationId
                    type
                    name
                    url
                    display
                    language
                    size
                    forIds
                }
            }
        }"""
        variables = {
            "id": conversation_id,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return res["data"]["conversation"]

    async def get_conversations(self, pagination, filter):
        """Return one page of the project's conversations matching ``filter``.

        :param pagination: Provides ``first`` (page size) and ``cursor``.
        :param filter: Provides ``feedback``, ``authorEmail`` and ``search``.
        :return: A ``PaginatedResponse`` of conversation nodes.
        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query (
            $first: Int
            $projectId: String!
            $cursor: String
            $withFeedback: Int
            $authorEmail: String
            $search: String
        ) {
            conversations(
                first: $first
                cursor: $cursor
                projectId: $projectId
                withFeedback: $withFeedback
                authorEmail: $authorEmail
                search: $search
            ) {
                pageInfo {
                    endCursor
                    hasNextPage
                }
                edges {
                    cursor
                    node {
                        id
                        createdAt
                        elementCount
                        messageCount
                        author {
                            name
                            email
                        }
                        messages {
                            content
                        }
                    }
                }
            }
        }"""

        variables = {
            "projectId": self.project_id,
            "first": pagination.first,
            "cursor": pagination.cursor,
            "withFeedback": filter.feedback,
            "authorEmail": filter.authorEmail,
            "search": filter.search,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        connection = res["data"]["conversations"]
        conversations = [edge["node"] for edge in connection["edges"]]
        page_info = connection["pageInfo"]

        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=page_info["hasNextPage"],
                endCursor=page_info["endCursor"],
            ),
            data=conversations,
        )

    async def set_human_feedback(self, message_id, feedback) -> bool:
        """Record ``feedback`` on a message and return True.

        :raises Exception: If the GraphQL response contains errors.
        """
        mutation = """mutation ($messageId: ID!, $humanFeedback: Int!) {
            setHumanFeedback(messageId: $messageId, humanFeedback: $humanFeedback) {
                id
                humanFeedback
            }
        }"""
        variables = {"messageId": message_id, "humanFeedback": feedback}
        res = await self.mutation(mutation, variables)
        self.check_for_errors(res, raise_error=True)

        return True

    async def get_message(
        self,
        conversation_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ):
        """Not supported by this client.

        Accepts the ``BaseClient`` arguments so callers get NotImplementedError
        rather than a TypeError.
        """
        raise NotImplementedError

    async def create_message(self, variables: Dict[str, Any]) -> Optional[int]:
        """Persist a message in this client's conversation.

        ``variables`` is modified in place: its ``conversationId`` is set.

        :return: The new message id, or None if it could not be persisted.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the message.")
            return None

        variables["conversationId"] = c_id

        mutation = """
        mutation ($conversationId: ID!, $author: String!, $content: String!, $language: String, $prompt: String, $llmSettings: Json, $isError: Boolean, $indent: Int, $authorIsUser: Boolean, $waitForAnswer: Boolean, $createdAt: StringOrFloat) {
            createMessage(conversationId: $conversationId, author: $author, content: $content, language: $language, prompt: $prompt, llmSettings: $llmSettings, isError: $isError, indent: $indent, authorIsUser: $authorIsUser, waitForAnswer: $waitForAnswer, createdAt: $createdAt) {
                id
            }
        }
        """
        res = await self.mutation(mutation, variables)
        if self.check_for_errors(res):
            logger.warning("Could not create message.")
            return None

        return int(res["data"]["createMessage"]["id"])

    async def update_message(self, message_id: int, variables: Dict[str, Any]) -> bool:
        """Update a message; return True on success, False otherwise.

        ``variables`` is modified in place: its ``messageId`` is set.
        """
        mutation = """
        mutation ($messageId: ID!, $author: String!, $content: String!, $language: String, $prompt: String, $llmSettings: Json) {
            updateMessage(messageId: $messageId, author: $author, content: $content, language: $language, prompt: $prompt, llmSettings: $llmSettings) {
                id
            }
        }
        """
        variables["messageId"] = message_id
        res = await self.mutation(mutation, variables)

        if self.check_for_errors(res):
            logger.warning("Could not update message.")
            return False

        return True

    async def delete_message(self, message_id: int) -> bool:
        """Delete a message; return True on success, False otherwise."""
        mutation = """
        mutation ($messageId: ID!) {
            deleteMessage(messageId: $messageId) {
                id
            }
        }
        """
        res = await self.mutation(mutation, {"messageId": message_id})

        if self.check_for_errors(res):
            logger.warning("Could not delete message.")
            return False

        return True

    async def get_element(self, conversation_id, element_id):
        """Fetch a single element of a conversation.

        :raises Exception: If the GraphQL response contains errors.
        """
        query = """query (
            $conversationId: ID!
            $id: ID!
        ) {
            element(
                conversationId: $conversationId,
                id: $id
            ) {
                id
                conversationId
                type
                name
                url
                display
                language
                size
                forIds
            }
        }"""

        variables = {
            "conversationId": conversation_id,
            "id": element_id,
        }
        res = await self.query(query, variables)
        self.check_for_errors(res, raise_error=True)

        return res["data"]["element"]

    async def upsert_element(self, variables):
        """Create an element, or update its ``forIds`` when ``variables`` has an ``id``.

        ``variables`` is modified in place: its ``conversationId`` is set.

        :return: The created/updated element data, or None on failure.
        """
        c_id = await self.get_conversation_id()

        if not c_id:
            logger.warning("Missing conversation ID, could not persist the element.")
            return None

        if "id" in variables:
            mutation_name = "updateElement"
            mutation = """
            mutation ($conversationId: ID!, $id: ID!, $forIds: [String!]!) {
                updateElement(conversationId: $conversationId, id: $id, forIds: $forIds) {
                    id,
                }
            }
            """
        else:
            mutation_name = "createElement"
            mutation = """
            mutation ($conversationId: ID!, $type: String!, $url: String!, $name: String!, $display: String!, $forIds: [String!]!, $size: String, $language: String) {
                createElement(conversationId: $conversationId, type: $type, url: $url, name: $name, display: $display, size: $size, language: $language, forIds: $forIds) {
                    id,
                    type,
                    url,
                    name,
                    display,
                    size,
                    language,
                    forIds
                }
            }
            """

        # Both mutations need the conversation id; set it once after the
        # ``"id"`` membership test above, which it does not affect.
        variables["conversationId"] = c_id
        res = await self.mutation(mutation, variables)

        if self.check_for_errors(res):
            logger.warning("Could not persist element.")
            return None

        return res["data"][mutation_name]

    async def upload_element(self, content: bytes, mime: str) -> str:
        """Upload ``content`` and return its permanent URL.

        The Chainlit server is first asked for upload details (a target URL
        plus form fields). The file is then POSTed to that URL as form data.

        :param content: Raw file bytes.
        :param mime: MIME type announced to the server when requesting the upload.
        :return: The permanent URL, or ``""`` if either request fails.
        """
        file_name = f"{uuid.uuid4()}"
        body = {"projectId": self.project_id, "fileName": file_name, "contentType": mime}

        # One session serves both requests.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{config.chainlit_server}/api/upload/file",
                json=body,
                headers=self.headers,
            ) as r:
                if not r.ok:
                    reason = await r.text()
                    logger.error(f"Failed to upload file: {reason}")
                    return ""
                json_res = await r.json()

            upload_details = json_res["post"]
            permanent_url = json_res["permanentUrl"]

            form_data = aiohttp.FormData()

            # Add the server-provided fields to the form_data
            for field_name, field_value in upload_details["fields"].items():
                form_data.add_field(field_name, field_value)

            # Add file to the form_data.
            # NOTE(review): the part's content type is the generic multipart
            # type rather than ``mime``; the stored object's type presumably
            # comes from the server-provided fields — confirm.
            form_data.add_field("file", content, content_type="multipart/form-data")

            # No auth headers are sent to the upload URL; presumably the
            # server-provided fields authorize the upload.
            async with session.post(
                upload_details["url"],
                data=form_data,
            ) as upload_response:
                if not upload_response.ok:
                    reason = await upload_response.text()
                    logger.error(f"Failed to upload file: {reason}")
                    return ""

        return permanent_url