chainlit 1.0.401__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (112)
  1. chainlit/__init__.py +98 -279
  2. chainlit/_utils.py +8 -0
  3. chainlit/action.py +12 -10
  4. chainlit/{auth.py → auth/__init__.py} +28 -36
  5. chainlit/auth/cookie.py +122 -0
  6. chainlit/auth/jwt.py +39 -0
  7. chainlit/cache.py +4 -6
  8. chainlit/callbacks.py +362 -0
  9. chainlit/chat_context.py +64 -0
  10. chainlit/chat_settings.py +3 -1
  11. chainlit/cli/__init__.py +77 -8
  12. chainlit/config.py +181 -101
  13. chainlit/context.py +42 -13
  14. chainlit/copilot/dist/index.js +8750 -903
  15. chainlit/data/__init__.py +101 -416
  16. chainlit/data/acl.py +6 -2
  17. chainlit/data/base.py +107 -0
  18. chainlit/data/chainlit_data_layer.py +608 -0
  19. chainlit/data/dynamodb.py +590 -0
  20. chainlit/data/literalai.py +500 -0
  21. chainlit/data/sql_alchemy.py +721 -0
  22. chainlit/data/storage_clients/__init__.py +0 -0
  23. chainlit/data/storage_clients/azure.py +81 -0
  24. chainlit/data/storage_clients/azure_blob.py +89 -0
  25. chainlit/data/storage_clients/base.py +26 -0
  26. chainlit/data/storage_clients/gcs.py +88 -0
  27. chainlit/data/storage_clients/s3.py +75 -0
  28. chainlit/data/utils.py +29 -0
  29. chainlit/discord/__init__.py +6 -0
  30. chainlit/discord/app.py +354 -0
  31. chainlit/element.py +91 -33
  32. chainlit/emitter.py +80 -29
  33. chainlit/frontend/dist/assets/DailyMotion-C_XC7xJI.js +1 -0
  34. chainlit/frontend/dist/assets/Dataframe-Cs4l4hA1.js +22 -0
  35. chainlit/frontend/dist/assets/Facebook-CUeCH7hk.js +1 -0
  36. chainlit/frontend/dist/assets/FilePlayer-CB-fYkx8.js +1 -0
  37. chainlit/frontend/dist/assets/Kaltura-YX6qaq72.js +1 -0
  38. chainlit/frontend/dist/assets/Mixcloud-DGV0ldjP.js +1 -0
  39. chainlit/frontend/dist/assets/Mux-CmRss5oc.js +1 -0
  40. chainlit/frontend/dist/assets/Preview-DBVJn7-H.js +1 -0
  41. chainlit/frontend/dist/assets/SoundCloud-qLUb18oY.js +1 -0
  42. chainlit/frontend/dist/assets/Streamable-BvYP7bFp.js +1 -0
  43. chainlit/frontend/dist/assets/Twitch-CTHt-sGZ.js +1 -0
  44. chainlit/frontend/dist/assets/Vidyard-B-0mCJbm.js +1 -0
  45. chainlit/frontend/dist/assets/Vimeo-Dnp7ri8q.js +1 -0
  46. chainlit/frontend/dist/assets/Wistia-DW0x_UBn.js +1 -0
  47. chainlit/frontend/dist/assets/YouTube--98FipvA.js +1 -0
  48. chainlit/frontend/dist/assets/index-D71nZ46o.js +8665 -0
  49. chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
  50. chainlit/frontend/dist/assets/react-plotly-Cn_BQTQw.js +3484 -0
  51. chainlit/frontend/dist/index.html +2 -4
  52. chainlit/haystack/callbacks.py +4 -7
  53. chainlit/input_widget.py +8 -4
  54. chainlit/langchain/callbacks.py +103 -68
  55. chainlit/langflow/__init__.py +1 -0
  56. chainlit/llama_index/callbacks.py +65 -40
  57. chainlit/markdown.py +22 -6
  58. chainlit/message.py +54 -56
  59. chainlit/mistralai/__init__.py +50 -0
  60. chainlit/oauth_providers.py +266 -8
  61. chainlit/openai/__init__.py +10 -18
  62. chainlit/secret.py +1 -1
  63. chainlit/server.py +789 -228
  64. chainlit/session.py +108 -90
  65. chainlit/slack/__init__.py +6 -0
  66. chainlit/slack/app.py +397 -0
  67. chainlit/socket.py +199 -116
  68. chainlit/step.py +141 -89
  69. chainlit/sync.py +2 -1
  70. chainlit/teams/__init__.py +6 -0
  71. chainlit/teams/app.py +338 -0
  72. chainlit/translations/bn.json +235 -0
  73. chainlit/translations/en-US.json +83 -4
  74. chainlit/translations/gu.json +235 -0
  75. chainlit/translations/he-IL.json +235 -0
  76. chainlit/translations/hi.json +235 -0
  77. chainlit/translations/kn.json +235 -0
  78. chainlit/translations/ml.json +235 -0
  79. chainlit/translations/mr.json +235 -0
  80. chainlit/translations/nl-NL.json +233 -0
  81. chainlit/translations/ta.json +235 -0
  82. chainlit/translations/te.json +235 -0
  83. chainlit/translations/zh-CN.json +233 -0
  84. chainlit/translations.py +60 -0
  85. chainlit/types.py +133 -28
  86. chainlit/user.py +14 -3
  87. chainlit/user_session.py +6 -3
  88. chainlit/utils.py +52 -5
  89. chainlit/version.py +3 -2
  90. {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/METADATA +48 -50
  91. chainlit-2.0.3.dist-info/RECORD +106 -0
  92. chainlit/cli/utils.py +0 -24
  93. chainlit/frontend/dist/assets/index-9711593e.js +0 -723
  94. chainlit/frontend/dist/assets/index-d088547c.css +0 -1
  95. chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
  96. chainlit/playground/__init__.py +0 -2
  97. chainlit/playground/config.py +0 -40
  98. chainlit/playground/provider.py +0 -108
  99. chainlit/playground/providers/__init__.py +0 -13
  100. chainlit/playground/providers/anthropic.py +0 -118
  101. chainlit/playground/providers/huggingface.py +0 -75
  102. chainlit/playground/providers/langchain.py +0 -89
  103. chainlit/playground/providers/openai.py +0 -408
  104. chainlit/playground/providers/vertexai.py +0 -171
  105. chainlit/translations/pt-BR.json +0 -155
  106. chainlit-1.0.401.dist-info/RECORD +0 -66
  107. /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  108. /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  109. /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  110. /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  111. {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/WHEEL +0 -0
  112. {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/entry_points.txt +0 -0
chainlit/data/__init__.py CHANGED
@@ -1,428 +1,113 @@
1
- import functools
2
- import json
3
1
  import os
4
- from collections import deque
5
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
2
+ import warnings
3
+ from typing import Optional
6
4
 
7
- import aiofiles
8
- from chainlit.config import config
9
- from chainlit.context import context
10
- from chainlit.logger import logger
11
- from chainlit.session import WebsocketSession
12
- from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
13
- from chainlit.user import PersistedUser, User, UserDict
14
- from literalai import Attachment
15
- from literalai import Feedback as ClientFeedback
16
- from literalai import PageInfo, PaginatedResponse
17
- from literalai import Step as ClientStep
18
- from literalai.step import StepDict as ClientStepDict
19
- from literalai.thread import NumberListFilter, StringFilter, StringListFilter
20
- from literalai.thread import ThreadFilter as ClientThreadFilter
5
+ from .base import BaseDataLayer
6
+ from .utils import (
7
+ queue_until_user_message as queue_until_user_message, # TODO: Consider deprecating re-export.; Redundant alias tells type checkers to STFU.
8
+ )
21
9
 
22
- if TYPE_CHECKING:
23
- from chainlit.element import Element, ElementDict
24
- from chainlit.step import FeedbackDict, StepDict
10
+ _data_layer: Optional[BaseDataLayer] = None
11
+ _data_layer_initialized = False
25
12
 
26
- _data_layer = None
27
13
 
14
+ def get_data_layer():
15
+ global _data_layer, _data_layer_initialized
28
16
 
29
- def queue_until_user_message():
30
- def decorator(method):
31
- @functools.wraps(method)
32
- async def wrapper(self, *args, **kwargs):
33
- if (
34
- isinstance(context.session, WebsocketSession)
35
- and not context.session.has_first_interaction
36
- ):
37
- # Queue the method invocation waiting for the first user message
38
- queues = context.session.thread_queues
39
- method_name = method.__name__
40
- if method_name not in queues:
41
- queues[method_name] = deque()
42
- queues[method_name].append((method, self, args, kwargs))
43
-
44
- else:
45
- # Otherwise, Execute the method immediately
46
- return await method(self, *args, **kwargs)
47
-
48
- return wrapper
49
-
50
- return decorator
51
-
52
-
53
- class BaseDataLayer:
54
- """Base class for data persistence."""
55
-
56
- async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
57
- return None
58
-
59
- async def create_user(self, user: "User") -> Optional["PersistedUser"]:
60
- pass
61
-
62
- async def upsert_feedback(
63
- self,
64
- feedback: Feedback,
65
- ) -> str:
66
- return ""
67
-
68
- @queue_until_user_message()
69
- async def create_element(self, element_dict: "ElementDict"):
70
- pass
71
-
72
- async def get_element(
73
- self, thread_id: str, element_id: str
74
- ) -> Optional["ElementDict"]:
75
- pass
76
-
77
- @queue_until_user_message()
78
- async def delete_element(self, element_id: str):
79
- pass
80
-
81
- @queue_until_user_message()
82
- async def create_step(self, step_dict: "StepDict"):
83
- pass
84
-
85
- @queue_until_user_message()
86
- async def update_step(self, step_dict: "StepDict"):
87
- pass
88
-
89
- @queue_until_user_message()
90
- async def delete_step(self, step_id: str):
91
- pass
92
-
93
- async def get_thread_author(self, thread_id: str) -> str:
94
- return ""
95
-
96
- async def delete_thread(self, thread_id: str):
97
- pass
98
-
99
- async def list_threads(
100
- self, pagination: "Pagination", filters: "ThreadFilter"
101
- ) -> "PaginatedResponse[ThreadDict]":
102
- return PaginatedResponse(
103
- data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)
104
- )
105
-
106
- async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
107
- return None
108
-
109
- async def update_thread(
110
- self,
111
- thread_id: str,
112
- name: Optional[str] = None,
113
- user_id: Optional[str] = None,
114
- metadata: Optional[Dict] = None,
115
- tags: Optional[List[str]] = None,
116
- ):
117
- pass
118
-
119
- async def delete_user_session(self, id: str) -> bool:
120
- return True
121
-
122
-
123
- class ChainlitDataLayer:
124
- def __init__(self, api_key: str, server: Optional[str]):
125
- from literalai import LiteralClient
126
-
127
- self.client = LiteralClient(api_key=api_key, url=server)
128
- logger.info("Chainlit data layer initialized")
129
-
130
- def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
131
- metadata = attachment.metadata or {}
132
- return {
133
- "chainlitKey": None,
134
- "display": metadata.get("display", "side"),
135
- "language": metadata.get("language"),
136
- "page": metadata.get("page"),
137
- "size": metadata.get("size"),
138
- "type": metadata.get("type", "file"),
139
- "forId": attachment.step_id,
140
- "id": attachment.id or "",
141
- "mime": attachment.mime,
142
- "name": attachment.name or "",
143
- "objectKey": attachment.object_key,
144
- "url": attachment.url,
145
- "threadId": attachment.thread_id,
146
- }
147
-
148
- def feedback_to_feedback_dict(
149
- self, feedback: Optional[ClientFeedback]
150
- ) -> "Optional[FeedbackDict]":
151
- if not feedback:
152
- return None
153
- return {
154
- "id": feedback.id or "",
155
- "forId": feedback.step_id or "",
156
- "value": feedback.value or 0, # type: ignore
157
- "comment": feedback.comment,
158
- "strategy": "BINARY",
159
- }
160
-
161
- def step_to_step_dict(self, step: ClientStep) -> "StepDict":
162
- metadata = step.metadata or {}
163
- input = (step.input or {}).get("content") or (
164
- json.dumps(step.input) if step.input and step.input != {} else ""
165
- )
166
- output = (step.output or {}).get("content") or (
167
- json.dumps(step.output) if step.output and step.output != {} else ""
168
- )
169
- return {
170
- "createdAt": step.created_at,
171
- "id": step.id or "",
172
- "threadId": step.thread_id or "",
173
- "parentId": step.parent_id,
174
- "feedback": self.feedback_to_feedback_dict(step.feedback),
175
- "start": step.start_time,
176
- "end": step.end_time,
177
- "type": step.type or "undefined",
178
- "name": step.name or "",
179
- "generation": step.generation.to_dict() if step.generation else None,
180
- "input": input,
181
- "output": output,
182
- "showInput": metadata.get("showInput", False),
183
- "disableFeedback": metadata.get("disableFeedback", False),
184
- "indent": metadata.get("indent"),
185
- "language": metadata.get("language"),
186
- "isError": metadata.get("isError", False),
187
- "waitForAnswer": metadata.get("waitForAnswer", False),
188
- "feedback": self.feedback_to_feedback_dict(step.feedback),
189
- }
190
-
191
- async def get_user(self, identifier: str) -> Optional[PersistedUser]:
192
- user = await self.client.api.get_user(identifier=identifier)
193
- if not user:
194
- return None
195
- return PersistedUser(
196
- id=user.id or "",
197
- identifier=user.identifier or "",
198
- metadata=user.metadata,
199
- createdAt=user.created_at or "",
200
- )
17
+ if not _data_layer_initialized:
18
+ if _data_layer:
19
+ # Data layer manually set, warn user that this is deprecated.
201
20
 
202
- async def create_user(self, user: User) -> Optional[PersistedUser]:
203
- _user = await self.client.api.get_user(identifier=user.identifier)
204
- if not _user:
205
- _user = await self.client.api.create_user(
206
- identifier=user.identifier, metadata=user.metadata
21
+ warnings.warn(
22
+ "Setting data layer manually is deprecated. Use @data_layer instead.",
23
+ DeprecationWarning,
207
24
  )
208
- elif _user.id:
209
- await self.client.api.update_user(id=_user.id, metadata=user.metadata)
210
- return PersistedUser(
211
- id=_user.id or "",
212
- identifier=_user.identifier or "",
213
- metadata=_user.metadata,
214
- createdAt=_user.created_at or "",
215
- )
216
25
 
217
- async def upsert_feedback(
218
- self,
219
- feedback: Feedback,
220
- ):
221
- if feedback.id:
222
- await self.client.api.update_feedback(
223
- id=feedback.id,
224
- update_params={
225
- "comment": feedback.comment,
226
- "strategy": feedback.strategy,
227
- "value": feedback.value,
228
- },
229
- )
230
- return feedback.id
231
26
  else:
232
- created = await self.client.api.create_feedback(
233
- step_id=feedback.forId,
234
- value=feedback.value,
235
- comment=feedback.comment,
236
- strategy=feedback.strategy,
237
- )
238
- return created.id or ""
239
-
240
- @queue_until_user_message()
241
- async def create_element(self, element: "Element"):
242
- metadata = {
243
- "size": element.size,
244
- "language": element.language,
245
- "display": element.display,
246
- "type": element.type,
247
- "page": getattr(element, "page", None),
248
- }
249
-
250
- if not element.for_id:
251
- return
252
-
253
- object_key = None
254
-
255
- if not element.url:
256
- if element.path:
257
- async with aiofiles.open(element.path, "rb") as f:
258
- content = await f.read() # type: Union[bytes, str]
259
- elif element.content:
260
- content = element.content
261
- else:
262
- raise ValueError("Either path or content must be provided")
263
- uploaded = await self.client.api.upload_file(
264
- content=content, mime=element.mime, thread_id=element.thread_id
265
- )
266
- object_key = uploaded["object_key"]
267
-
268
- await self.client.api.send_steps(
269
- [
270
- {
271
- "id": element.for_id,
272
- "threadId": element.thread_id,
273
- "attachments": [
274
- {
275
- "id": element.id,
276
- "name": element.name,
277
- "metadata": metadata,
278
- "mime": element.mime,
279
- "url": element.url,
280
- "objectKey": object_key,
281
- }
282
- ],
283
- }
284
- ]
285
- )
286
-
287
- async def get_element(
288
- self, thread_id: str, element_id: str
289
- ) -> Optional["ElementDict"]:
290
- attachment = await self.client.api.get_attachment(id=element_id)
291
- if not attachment:
292
- return None
293
- return self.attachment_to_element_dict(attachment)
294
-
295
- @queue_until_user_message()
296
- async def delete_element(self, element_id: str):
297
- await self.client.api.delete_attachment(id=element_id)
298
-
299
- @queue_until_user_message()
300
- async def create_step(self, step_dict: "StepDict"):
301
- metadata = {
302
- "disableFeedback": step_dict.get("disableFeedback"),
303
- "isError": step_dict.get("isError"),
304
- "waitForAnswer": step_dict.get("waitForAnswer"),
305
- "language": step_dict.get("language"),
306
- "showInput": step_dict.get("showInput"),
307
- }
308
-
309
- step: ClientStepDict = {
310
- "createdAt": step_dict.get("createdAt"),
311
- "startTime": step_dict.get("start"),
312
- "endTime": step_dict.get("end"),
313
- "generation": step_dict.get("generation"),
314
- "id": step_dict.get("id"),
315
- "parentId": step_dict.get("parentId"),
316
- "name": step_dict.get("name"),
317
- "threadId": step_dict.get("threadId"),
318
- "type": step_dict.get("type"),
319
- "metadata": metadata,
320
- }
321
- if step_dict.get("input"):
322
- step["input"] = {"content": step_dict.get("input")}
323
- if step_dict.get("output"):
324
- step["output"] = {"content": step_dict.get("output")}
325
-
326
- await self.client.api.send_steps([step])
327
-
328
- @queue_until_user_message()
329
- async def update_step(self, step_dict: "StepDict"):
330
- await self.create_step(step_dict)
331
-
332
- @queue_until_user_message()
333
- async def delete_step(self, step_id: str):
334
- await self.client.api.delete_step(id=step_id)
27
+ from chainlit.config import config
28
+
29
+ if config.code.data_layer:
30
+ # When @data_layer is configured, call it to get data layer.
31
+ _data_layer = config.code.data_layer()
32
+ elif database_url := os.environ.get("DATABASE_URL"):
33
+ # Default to Chainlit data layer if DATABASE_URL specified.
34
+ from .chainlit_data_layer import ChainlitDataLayer
35
+
36
+ if os.environ.get("LITERAL_API_KEY"):
37
+ warnings.warn(
38
+ "Both LITERAL_API_KEY and DATABASE_URL specified. Ignoring Literal AI data layer and relying on data layer pointing to DATABASE_URL."
39
+ )
40
+ bucket_name = os.environ.get("BUCKET_NAME")
41
+
42
+ # AWS S3
43
+ aws_region = os.getenv("APP_AWS_REGION")
44
+ aws_access_key = os.getenv("APP_AWS_ACCESS_KEY")
45
+ aws_secret_key = os.getenv("APP_AWS_SECRET_KEY")
46
+ dev_aws_endpoint = os.getenv("DEV_AWS_ENDPOINT")
47
+ is_using_s3 = bool(aws_access_key and aws_secret_key and aws_region)
48
+
49
+ # Google Cloud Storage
50
+ gcs_project_id = os.getenv("APP_GCS_PROJECT_ID")
51
+ gcs_client_email = os.getenv("APP_GCS_CLIENT_EMAIL")
52
+ gcs_private_key = os.getenv("APP_GCS_PRIVATE_KEY")
53
+ is_using_gcs = bool(
54
+ gcs_project_id and gcs_client_email and gcs_private_key
55
+ )
56
+
57
+ # Azure Storage
58
+ azure_storage_account = os.getenv("APP_AZURE_STORAGE_ACCOUNT")
59
+ azure_storage_key = os.getenv("APP_AZURE_STORAGE_ACCESS_KEY")
60
+ is_using_azure = bool(azure_storage_account and azure_storage_key)
61
+
62
+ storage_client = None
63
+
64
+ if sum([is_using_s3, is_using_gcs, is_using_azure]) > 1:
65
+ warnings.warn(
66
+ "Multiple storage configurations detected. Please use only one."
67
+ )
68
+ elif is_using_s3:
69
+ from chainlit.data.storage_clients.s3 import S3StorageClient
70
+
71
+ storage_client = S3StorageClient(
72
+ bucket=bucket_name,
73
+ region_name=aws_region,
74
+ aws_access_key_id=aws_access_key,
75
+ aws_secret_access_key=aws_secret_key,
76
+ endpoint_url=dev_aws_endpoint,
77
+ )
78
+ elif is_using_gcs:
79
+ from chainlit.data.storage_clients.gcs import GCSStorageClient
80
+
81
+ storage_client = GCSStorageClient(
82
+ project_id=gcs_project_id,
83
+ client_email=gcs_client_email,
84
+ private_key=gcs_private_key,
85
+ bucket_name=bucket_name,
86
+ )
87
+ elif is_using_azure:
88
+ from chainlit.data.storage_clients.azure_blob import (
89
+ AzureBlobStorageClient,
90
+ )
91
+
92
+ storage_client = AzureBlobStorageClient(
93
+ container_name=bucket_name,
94
+ storage_account=azure_storage_account,
95
+ storage_key=azure_storage_key,
96
+ )
97
+
98
+ _data_layer = ChainlitDataLayer(
99
+ database_url=database_url, storage_client=storage_client
100
+ )
101
+ elif api_key := os.environ.get("LITERAL_API_KEY"):
102
+ # When LITERAL_API_KEY is defined, use Literal AI data layer
103
+ from .literalai import LiteralDataLayer
104
+
105
+ # support legacy LITERAL_SERVER variable as fallback
106
+ server = os.environ.get("LITERAL_API_URL") or os.environ.get(
107
+ "LITERAL_SERVER"
108
+ )
109
+ _data_layer = LiteralDataLayer(api_key=api_key, server=server)
110
+
111
+ _data_layer_initialized = True
335
112
 
336
- async def get_thread_author(self, thread_id: str) -> str:
337
- thread = await self.get_thread(thread_id)
338
- if not thread:
339
- return ""
340
- user = thread.get("user")
341
- if not user:
342
- return ""
343
- return user.get("identifier") or ""
344
-
345
- async def delete_thread(self, thread_id: str):
346
- await self.client.api.delete_thread(id=thread_id)
347
-
348
- async def list_threads(
349
- self, pagination: "Pagination", filters: "ThreadFilter"
350
- ) -> "PaginatedResponse[ThreadDict]":
351
- if not filters.userIdentifier:
352
- raise ValueError("userIdentifier is required")
353
-
354
- client_filters = ClientThreadFilter(
355
- participantsIdentifier=StringListFilter(
356
- operator="in", value=[filters.userIdentifier]
357
- ),
358
- )
359
- if filters.search:
360
- client_filters.search = StringFilter(operator="ilike", value=filters.search)
361
- if filters.feedback:
362
- client_filters.feedbacksValue = NumberListFilter(
363
- operator="in", value=[filters.feedback]
364
- )
365
- return await self.client.api.list_threads(
366
- first=pagination.first, after=pagination.cursor, filters=client_filters
367
- )
368
-
369
- async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
370
- thread = await self.client.api.get_thread(id=thread_id)
371
- if not thread:
372
- return None
373
- elements = [] # List[ElementDict]
374
- steps = [] # List[StepDict]
375
- if thread.steps:
376
- for step in thread.steps:
377
- if config.ui.hide_cot and step.parent_id:
378
- continue
379
- for attachment in step.attachments:
380
- elements.append(self.attachment_to_element_dict(attachment))
381
- if not config.features.prompt_playground and step.generation:
382
- step.generation = None
383
- steps.append(self.step_to_step_dict(step))
384
-
385
- user = None # type: Optional["UserDict"]
386
-
387
- if thread.user:
388
- user = {
389
- "id": thread.user.id or "",
390
- "identifier": thread.user.identifier or "",
391
- "metadata": thread.user.metadata,
392
- }
393
-
394
- return {
395
- "createdAt": thread.created_at or "",
396
- "id": thread.id,
397
- "name": thread.name or None,
398
- "steps": steps,
399
- "elements": elements,
400
- "metadata": thread.metadata,
401
- "user": user,
402
- "tags": thread.tags,
403
- }
404
-
405
- async def update_thread(
406
- self,
407
- thread_id: str,
408
- name: Optional[str] = None,
409
- user_id: Optional[str] = None,
410
- metadata: Optional[Dict] = None,
411
- tags: Optional[List[str]] = None,
412
- ):
413
- await self.client.api.upsert_thread(
414
- thread_id=thread_id,
415
- name=name,
416
- participant_id=user_id,
417
- metadata=metadata,
418
- tags=tags,
419
- )
420
-
421
-
422
- if api_key := os.environ.get("LITERAL_API_KEY"):
423
- server = os.environ.get("LITERAL_SERVER")
424
- _data_layer = ChainlitDataLayer(api_key=api_key, server=server)
425
-
426
-
427
- def get_data_layer():
428
113
  return _data_layer
chainlit/data/acl.py CHANGED
@@ -1,14 +1,18 @@
1
- from chainlit.data import get_data_layer
2
1
  from fastapi import HTTPException
3
2
 
3
+ from chainlit.data import get_data_layer
4
+
4
5
 
5
6
  async def is_thread_author(username: str, thread_id: str):
6
7
  data_layer = get_data_layer()
7
8
  if not data_layer:
8
- raise HTTPException(status_code=401, detail="Unauthorized")
9
+ raise HTTPException(status_code=400, detail="Data layer not initialized")
9
10
 
10
11
  thread_author = await data_layer.get_thread_author(thread_id)
11
12
 
13
+ if not thread_author:
14
+ raise HTTPException(status_code=404, detail="Thread not found")
15
+
12
16
  if thread_author != username:
13
17
  raise HTTPException(status_code=401, detail="Unauthorized")
14
18
  else:
chainlit/data/base.py ADDED
@@ -0,0 +1,107 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import TYPE_CHECKING, Dict, List, Optional
3
+
4
+ from chainlit.types import (
5
+ Feedback,
6
+ PaginatedResponse,
7
+ Pagination,
8
+ ThreadDict,
9
+ ThreadFilter,
10
+ )
11
+
12
+ from .utils import queue_until_user_message
13
+
14
+ if TYPE_CHECKING:
15
+ from chainlit.element import Element, ElementDict
16
+ from chainlit.step import StepDict
17
+ from chainlit.user import PersistedUser, User
18
+
19
+
20
+ class BaseDataLayer(ABC):
21
+ """Base class for data persistence."""
22
+
23
+ @abstractmethod
24
+ async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
25
+ pass
26
+
27
+ @abstractmethod
28
+ async def create_user(self, user: "User") -> Optional["PersistedUser"]:
29
+ pass
30
+
31
+ @abstractmethod
32
+ async def delete_feedback(
33
+ self,
34
+ feedback_id: str,
35
+ ) -> bool:
36
+ pass
37
+
38
+ @abstractmethod
39
+ async def upsert_feedback(
40
+ self,
41
+ feedback: Feedback,
42
+ ) -> str:
43
+ pass
44
+
45
+ @queue_until_user_message()
46
+ @abstractmethod
47
+ async def create_element(self, element: "Element"):
48
+ pass
49
+
50
+ @abstractmethod
51
+ async def get_element(
52
+ self, thread_id: str, element_id: str
53
+ ) -> Optional["ElementDict"]:
54
+ pass
55
+
56
+ @queue_until_user_message()
57
+ @abstractmethod
58
+ async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
59
+ pass
60
+
61
+ @queue_until_user_message()
62
+ @abstractmethod
63
+ async def create_step(self, step_dict: "StepDict"):
64
+ pass
65
+
66
+ @queue_until_user_message()
67
+ @abstractmethod
68
+ async def update_step(self, step_dict: "StepDict"):
69
+ pass
70
+
71
+ @queue_until_user_message()
72
+ @abstractmethod
73
+ async def delete_step(self, step_id: str):
74
+ pass
75
+
76
+ @abstractmethod
77
+ async def get_thread_author(self, thread_id: str) -> str:
78
+ return ""
79
+
80
+ @abstractmethod
81
+ async def delete_thread(self, thread_id: str):
82
+ pass
83
+
84
+ @abstractmethod
85
+ async def list_threads(
86
+ self, pagination: "Pagination", filters: "ThreadFilter"
87
+ ) -> "PaginatedResponse[ThreadDict]":
88
+ pass
89
+
90
+ @abstractmethod
91
+ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
92
+ pass
93
+
94
+ @abstractmethod
95
+ async def update_thread(
96
+ self,
97
+ thread_id: str,
98
+ name: Optional[str] = None,
99
+ user_id: Optional[str] = None,
100
+ metadata: Optional[Dict] = None,
101
+ tags: Optional[List[str]] = None,
102
+ ):
103
+ pass
104
+
105
+ @abstractmethod
106
+ async def build_debug_url(self) -> str:
107
+ pass