chainlit 1.0.401__py3-none-any.whl → 1.0.501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit has been flagged as potentially problematic; see the registry's advisory page for this release for more details.

chainlit/data/__init__.py CHANGED
@@ -2,7 +2,7 @@ import functools
2
2
  import json
3
3
  import os
4
4
  from collections import deque
5
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
5
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast
6
6
 
7
7
  import aiofiles
8
8
  from chainlit.config import config
@@ -10,21 +10,15 @@ from chainlit.context import context
10
10
  from chainlit.logger import logger
11
11
  from chainlit.session import WebsocketSession
12
12
  from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
13
- from chainlit.user import PersistedUser, User, UserDict
14
- from literalai import Attachment
15
- from literalai import Feedback as ClientFeedback
16
- from literalai import PageInfo, PaginatedResponse
17
- from literalai import Step as ClientStep
18
- from literalai.step import StepDict as ClientStepDict
19
- from literalai.thread import NumberListFilter, StringFilter, StringListFilter
20
- from literalai.thread import ThreadFilter as ClientThreadFilter
13
+ from chainlit.user import PersistedUser, User
14
+ from literalai import Attachment, PageInfo, PaginatedResponse, Score as LiteralScore, Step as LiteralStep
15
+ from literalai.filter import threads_filters as LiteralThreadsFilters
16
+ from literalai.step import StepDict as LiteralStepDict
21
17
 
22
18
  if TYPE_CHECKING:
23
19
  from chainlit.element import Element, ElementDict
24
20
  from chainlit.step import FeedbackDict, StepDict
25
21
 
26
- _data_layer = None
27
-
28
22
 
29
23
  def queue_until_user_message():
30
24
  def decorator(method):
@@ -59,6 +53,12 @@ class BaseDataLayer:
59
53
  async def create_user(self, user: "User") -> Optional["PersistedUser"]:
60
54
  pass
61
55
 
56
+ async def delete_feedback(
57
+ self,
58
+ feedback_id: str,
59
+ ) -> bool:
60
+ return True
61
+
62
62
  async def upsert_feedback(
63
63
  self,
64
64
  feedback: Feedback,
@@ -66,7 +66,7 @@ class BaseDataLayer:
66
66
  return ""
67
67
 
68
68
  @queue_until_user_message()
69
- async def create_element(self, element_dict: "ElementDict"):
69
+ async def create_element(self, element: "Element"):
70
70
  pass
71
71
 
72
72
  async def get_element(
@@ -100,7 +100,8 @@ class BaseDataLayer:
100
100
  self, pagination: "Pagination", filters: "ThreadFilter"
101
101
  ) -> "PaginatedResponse[ThreadDict]":
102
102
  return PaginatedResponse(
103
- data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)
103
+ data=[],
104
+ pageInfo=PageInfo(hasNextPage=False, startCursor=None, endCursor=None),
104
105
  )
105
106
 
106
107
  async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
@@ -120,11 +121,14 @@ class BaseDataLayer:
120
121
  return True
121
122
 
122
123
 
123
- class ChainlitDataLayer:
124
+ _data_layer: Optional[BaseDataLayer] = None
125
+
126
+
127
+ class ChainlitDataLayer(BaseDataLayer):
124
128
  def __init__(self, api_key: str, server: Optional[str]):
125
- from literalai import LiteralClient
129
+ from literalai import AsyncLiteralClient
126
130
 
127
- self.client = LiteralClient(api_key=api_key, url=server)
131
+ self.client = AsyncLiteralClient(api_key=api_key, url=server)
128
132
  logger.info("Chainlit data layer initialized")
129
133
 
130
134
  def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
@@ -145,20 +149,19 @@ class ChainlitDataLayer:
145
149
  "threadId": attachment.thread_id,
146
150
  }
147
151
 
148
- def feedback_to_feedback_dict(
149
- self, feedback: Optional[ClientFeedback]
152
+ def score_to_feedback_dict(
153
+ self, score: Optional[LiteralScore]
150
154
  ) -> "Optional[FeedbackDict]":
151
- if not feedback:
155
+ if not score:
152
156
  return None
153
157
  return {
154
- "id": feedback.id or "",
155
- "forId": feedback.step_id or "",
156
- "value": feedback.value or 0, # type: ignore
157
- "comment": feedback.comment,
158
- "strategy": "BINARY",
158
+ "id": score.id or "",
159
+ "forId": score.step_id or "",
160
+ "value": cast(Literal[0, 1], score.value),
161
+ "comment": score.comment,
159
162
  }
160
163
 
161
- def step_to_step_dict(self, step: ClientStep) -> "StepDict":
164
+ def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
162
165
  metadata = step.metadata or {}
163
166
  input = (step.input or {}).get("content") or (
164
167
  json.dumps(step.input) if step.input and step.input != {} else ""
@@ -166,12 +169,26 @@ class ChainlitDataLayer:
166
169
  output = (step.output or {}).get("content") or (
167
170
  json.dumps(step.output) if step.output and step.output != {} else ""
168
171
  )
172
+
173
+ user_feedback = (
174
+ next(
175
+ (
176
+ s
177
+ for s in step.scores
178
+ if s.type == "HUMAN" and s.name == "user-feedback"
179
+ ),
180
+ None,
181
+ )
182
+ if step.scores
183
+ else None
184
+ )
185
+
169
186
  return {
170
187
  "createdAt": step.created_at,
171
188
  "id": step.id or "",
172
189
  "threadId": step.thread_id or "",
173
190
  "parentId": step.parent_id,
174
- "feedback": self.feedback_to_feedback_dict(step.feedback),
191
+ "feedback": self.score_to_feedback_dict(user_feedback),
175
192
  "start": step.start_time,
176
193
  "end": step.end_time,
177
194
  "type": step.type or "undefined",
@@ -185,7 +202,6 @@ class ChainlitDataLayer:
185
202
  "language": metadata.get("language"),
186
203
  "isError": metadata.get("isError", False),
187
204
  "waitForAnswer": metadata.get("waitForAnswer", False),
188
- "feedback": self.feedback_to_feedback_dict(step.feedback),
189
205
  }
190
206
 
191
207
  async def get_user(self, identifier: str) -> Optional[PersistedUser]:
@@ -214,26 +230,37 @@ class ChainlitDataLayer:
214
230
  createdAt=_user.created_at or "",
215
231
  )
216
232
 
233
+ async def delete_feedback(
234
+ self,
235
+ feedback_id: str,
236
+ ):
237
+ if feedback_id:
238
+ await self.client.api.delete_score(
239
+ id=feedback_id,
240
+ )
241
+ return True
242
+ return False
243
+
217
244
  async def upsert_feedback(
218
245
  self,
219
246
  feedback: Feedback,
220
247
  ):
221
248
  if feedback.id:
222
- await self.client.api.update_feedback(
249
+ await self.client.api.update_score(
223
250
  id=feedback.id,
224
251
  update_params={
225
252
  "comment": feedback.comment,
226
- "strategy": feedback.strategy,
227
253
  "value": feedback.value,
228
254
  },
229
255
  )
230
256
  return feedback.id
231
257
  else:
232
- created = await self.client.api.create_feedback(
258
+ created = await self.client.api.create_score(
233
259
  step_id=feedback.forId,
234
260
  value=feedback.value,
235
261
  comment=feedback.comment,
236
- strategy=feedback.strategy,
262
+ name="user-feedback",
263
+ type="HUMAN",
237
264
  )
238
265
  return created.id or ""
239
266
 
@@ -298,15 +325,18 @@ class ChainlitDataLayer:
298
325
 
299
326
  @queue_until_user_message()
300
327
  async def create_step(self, step_dict: "StepDict"):
301
- metadata = {
302
- "disableFeedback": step_dict.get("disableFeedback"),
303
- "isError": step_dict.get("isError"),
304
- "waitForAnswer": step_dict.get("waitForAnswer"),
305
- "language": step_dict.get("language"),
306
- "showInput": step_dict.get("showInput"),
307
- }
328
+ metadata = dict(
329
+ step_dict.get("metadata", {}),
330
+ **{
331
+ "disableFeedback": step_dict.get("disableFeedback"),
332
+ "isError": step_dict.get("isError"),
333
+ "waitForAnswer": step_dict.get("waitForAnswer"),
334
+ "language": step_dict.get("language"),
335
+ "showInput": step_dict.get("showInput"),
336
+ },
337
+ )
308
338
 
309
- step: ClientStepDict = {
339
+ step: LiteralStepDict = {
310
340
  "createdAt": step_dict.get("createdAt"),
311
341
  "startTime": step_dict.get("start"),
312
342
  "endTime": step_dict.get("end"),
@@ -316,6 +346,7 @@ class ChainlitDataLayer:
316
346
  "name": step_dict.get("name"),
317
347
  "threadId": step_dict.get("threadId"),
318
348
  "type": step_dict.get("type"),
349
+ "tags": step_dict.get("tags"),
319
350
  "metadata": metadata,
320
351
  }
321
352
  if step_dict.get("input"):
@@ -337,10 +368,11 @@ class ChainlitDataLayer:
337
368
  thread = await self.get_thread(thread_id)
338
369
  if not thread:
339
370
  return ""
340
- user = thread.get("user")
341
- if not user:
371
+ user_identifier = thread.get("userIdentifier")
372
+ if not user_identifier:
342
373
  return ""
343
- return user.get("identifier") or ""
374
+
375
+ return user_identifier
344
376
 
345
377
  async def delete_thread(self, thread_id: str):
346
378
  await self.client.api.delete_thread(id=thread_id)
@@ -348,22 +380,42 @@ class ChainlitDataLayer:
348
380
  async def list_threads(
349
381
  self, pagination: "Pagination", filters: "ThreadFilter"
350
382
  ) -> "PaginatedResponse[ThreadDict]":
351
- if not filters.userIdentifier:
352
- raise ValueError("userIdentifier is required")
383
+ if not filters.userId:
384
+ raise ValueError("userId is required")
385
+
386
+ literal_filters: LiteralThreadsFilters = [
387
+ {
388
+ "field": "participantId",
389
+ "operator": "eq",
390
+ "value": filters.userId,
391
+ }
392
+ ]
353
393
 
354
- client_filters = ClientThreadFilter(
355
- participantsIdentifier=StringListFilter(
356
- operator="in", value=[filters.userIdentifier]
357
- ),
358
- )
359
394
  if filters.search:
360
- client_filters.search = StringFilter(operator="ilike", value=filters.search)
361
- if filters.feedback:
362
- client_filters.feedbacksValue = NumberListFilter(
363
- operator="in", value=[filters.feedback]
395
+ literal_filters.append(
396
+ {
397
+ "field": "stepOutput",
398
+ "operator": "ilike",
399
+ "value": filters.search,
400
+ "path": "content",
401
+ }
364
402
  )
403
+
404
+ if filters.feedback is not None:
405
+ literal_filters.append(
406
+ {
407
+ "field": "scoreValue",
408
+ "operator": "eq",
409
+ "value": filters.feedback,
410
+ "path": "user-feedback",
411
+ }
412
+ )
413
+
365
414
  return await self.client.api.list_threads(
366
- first=pagination.first, after=pagination.cursor, filters=client_filters
415
+ first=pagination.first,
416
+ after=pagination.cursor,
417
+ filters=literal_filters,
418
+ order_by={"column": "createdAt", "direction": "DESC"},
367
419
  )
368
420
 
369
421
  async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
@@ -382,15 +434,6 @@ class ChainlitDataLayer:
382
434
  step.generation = None
383
435
  steps.append(self.step_to_step_dict(step))
384
436
 
385
- user = None # type: Optional["UserDict"]
386
-
387
- if thread.user:
388
- user = {
389
- "id": thread.user.id or "",
390
- "identifier": thread.user.identifier or "",
391
- "metadata": thread.user.metadata,
392
- }
393
-
394
437
  return {
395
438
  "createdAt": thread.created_at or "",
396
439
  "id": thread.id,
@@ -398,7 +441,8 @@ class ChainlitDataLayer:
398
441
  "steps": steps,
399
442
  "elements": elements,
400
443
  "metadata": thread.metadata,
401
- "user": user,
444
+ "userId": thread.participant_id,
445
+ "userIdentifier": thread.participant_identifier,
402
446
  "tags": thread.tags,
403
447
  }
404
448
 
@@ -411,7 +455,7 @@ class ChainlitDataLayer:
411
455
  tags: Optional[List[str]] = None,
412
456
  ):
413
457
  await self.client.api.upsert_thread(
414
- thread_id=thread_id,
458
+ id=thread_id,
415
459
  name=name,
416
460
  participant_id=user_id,
417
461
  metadata=metadata,
@@ -420,7 +464,8 @@ class ChainlitDataLayer:
420
464
 
421
465
 
422
466
  if api_key := os.environ.get("LITERAL_API_KEY"):
423
- server = os.environ.get("LITERAL_SERVER")
467
+ # support legacy LITERAL_SERVER variable as fallback
468
+ server = os.environ.get("LITERAL_API_URL", os.environ.get("LITERAL_SERVER"))
424
469
  _data_layer = ChainlitDataLayer(api_key=api_key, server=server)
425
470
 
426
471