chainlit 1.0.401__py3-none-any.whl → 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. See the package's registry page for more details.
- chainlit/__init__.py +98 -279
- chainlit/_utils.py +8 -0
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +28 -36
- chainlit/auth/cookie.py +123 -0
- chainlit/auth/jwt.py +39 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +362 -0
- chainlit/chat_context.py +64 -0
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +77 -8
- chainlit/config.py +191 -102
- chainlit/context.py +42 -13
- chainlit/copilot/dist/index.js +8750 -903
- chainlit/data/__init__.py +101 -416
- chainlit/data/acl.py +6 -2
- chainlit/data/base.py +107 -0
- chainlit/data/chainlit_data_layer.py +614 -0
- chainlit/data/dynamodb.py +590 -0
- chainlit/data/literalai.py +500 -0
- chainlit/data/sql_alchemy.py +721 -0
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure.py +81 -0
- chainlit/data/storage_clients/azure_blob.py +89 -0
- chainlit/data/storage_clients/base.py +26 -0
- chainlit/data/storage_clients/gcs.py +88 -0
- chainlit/data/storage_clients/s3.py +75 -0
- chainlit/data/utils.py +29 -0
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +354 -0
- chainlit/element.py +91 -33
- chainlit/emitter.py +81 -29
- chainlit/frontend/dist/assets/DailyMotion-Ce9dQoqZ.js +1 -0
- chainlit/frontend/dist/assets/Dataframe-C1XonMcV.js +22 -0
- chainlit/frontend/dist/assets/Facebook-DVVt6lrr.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-c7stW4vz.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-BmMmgorA.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-Cw8hDmiO.js +1 -0
- chainlit/frontend/dist/assets/Mux-DiRZfeUf.js +1 -0
- chainlit/frontend/dist/assets/Preview-6Jt2mRHx.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-DKwcT58_.js +1 -0
- chainlit/frontend/dist/assets/Streamable-BVdxrEeX.js +1 -0
- chainlit/frontend/dist/assets/Twitch-DFqZR7Gu.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-0BQAAtVk.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-CRFSH0Vu.js +1 -0
- chainlit/frontend/dist/assets/Wistia-CKrmdQaG.js +1 -0
- chainlit/frontend/dist/assets/YouTube-CQpL-rvU.js +1 -0
- chainlit/frontend/dist/assets/index-DQmLRKyv.css +1 -0
- chainlit/frontend/dist/assets/index-QdmxtIMQ.js +8665 -0
- chainlit/frontend/dist/assets/react-plotly-B9hvVpUG.js +3484 -0
- chainlit/frontend/dist/index.html +2 -4
- chainlit/haystack/callbacks.py +4 -7
- chainlit/input_widget.py +8 -4
- chainlit/langchain/callbacks.py +103 -68
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +65 -40
- chainlit/markdown.py +22 -6
- chainlit/message.py +54 -56
- chainlit/mistralai/__init__.py +50 -0
- chainlit/oauth_providers.py +266 -8
- chainlit/openai/__init__.py +10 -18
- chainlit/secret.py +1 -1
- chainlit/server.py +789 -228
- chainlit/session.py +108 -90
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +397 -0
- chainlit/socket.py +199 -116
- chainlit/step.py +141 -89
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +338 -0
- chainlit/translations/bn.json +244 -0
- chainlit/translations/en-US.json +122 -8
- chainlit/translations/gu.json +244 -0
- chainlit/translations/he-IL.json +244 -0
- chainlit/translations/hi.json +244 -0
- chainlit/translations/ja.json +242 -0
- chainlit/translations/kn.json +244 -0
- chainlit/translations/ml.json +244 -0
- chainlit/translations/mr.json +244 -0
- chainlit/translations/nl-NL.json +242 -0
- chainlit/translations/ta.json +244 -0
- chainlit/translations/te.json +244 -0
- chainlit/translations/zh-CN.json +243 -0
- chainlit/translations.py +60 -0
- chainlit/types.py +133 -28
- chainlit/user.py +14 -3
- chainlit/user_session.py +6 -3
- chainlit/utils.py +52 -5
- chainlit/version.py +3 -2
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/METADATA +48 -50
- chainlit-2.0.4.dist-info/RECORD +107 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-9711593e.js +0 -723
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -40
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -13
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -408
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit/translations/pt-BR.json +0 -155
- chainlit-1.0.401.dist-info/RECORD +0 -66
- /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/WHEEL +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/entry_points.txt +0 -0
chainlit/data/__init__.py
CHANGED
|
@@ -1,428 +1,113 @@
|
|
|
1
|
-
import functools
|
|
2
|
-
import json
|
|
3
1
|
import os
|
|
4
|
-
|
|
5
|
-
from typing import
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Optional
|
|
6
4
|
|
|
7
|
-
import
|
|
8
|
-
from
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
from chainlit.session import WebsocketSession
|
|
12
|
-
from chainlit.types import Feedback, Pagination, ThreadDict, ThreadFilter
|
|
13
|
-
from chainlit.user import PersistedUser, User, UserDict
|
|
14
|
-
from literalai import Attachment
|
|
15
|
-
from literalai import Feedback as ClientFeedback
|
|
16
|
-
from literalai import PageInfo, PaginatedResponse
|
|
17
|
-
from literalai import Step as ClientStep
|
|
18
|
-
from literalai.step import StepDict as ClientStepDict
|
|
19
|
-
from literalai.thread import NumberListFilter, StringFilter, StringListFilter
|
|
20
|
-
from literalai.thread import ThreadFilter as ClientThreadFilter
|
|
5
|
+
from .base import BaseDataLayer
|
|
6
|
+
from .utils import (
|
|
7
|
+
queue_until_user_message as queue_until_user_message, # TODO: Consider deprecating re-export.; Redundant alias tells type checkers to STFU.
|
|
8
|
+
)
|
|
21
9
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
from chainlit.step import FeedbackDict, StepDict
|
|
10
|
+
_data_layer: Optional[BaseDataLayer] = None
|
|
11
|
+
_data_layer_initialized = False
|
|
25
12
|
|
|
26
|
-
_data_layer = None
|
|
27
13
|
|
|
14
|
+
def get_data_layer():
|
|
15
|
+
global _data_layer, _data_layer_initialized
|
|
28
16
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
async def wrapper(self, *args, **kwargs):
|
|
33
|
-
if (
|
|
34
|
-
isinstance(context.session, WebsocketSession)
|
|
35
|
-
and not context.session.has_first_interaction
|
|
36
|
-
):
|
|
37
|
-
# Queue the method invocation waiting for the first user message
|
|
38
|
-
queues = context.session.thread_queues
|
|
39
|
-
method_name = method.__name__
|
|
40
|
-
if method_name not in queues:
|
|
41
|
-
queues[method_name] = deque()
|
|
42
|
-
queues[method_name].append((method, self, args, kwargs))
|
|
43
|
-
|
|
44
|
-
else:
|
|
45
|
-
# Otherwise, Execute the method immediately
|
|
46
|
-
return await method(self, *args, **kwargs)
|
|
47
|
-
|
|
48
|
-
return wrapper
|
|
49
|
-
|
|
50
|
-
return decorator
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class BaseDataLayer:
|
|
54
|
-
"""Base class for data persistence."""
|
|
55
|
-
|
|
56
|
-
async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
|
|
57
|
-
return None
|
|
58
|
-
|
|
59
|
-
async def create_user(self, user: "User") -> Optional["PersistedUser"]:
|
|
60
|
-
pass
|
|
61
|
-
|
|
62
|
-
async def upsert_feedback(
|
|
63
|
-
self,
|
|
64
|
-
feedback: Feedback,
|
|
65
|
-
) -> str:
|
|
66
|
-
return ""
|
|
67
|
-
|
|
68
|
-
@queue_until_user_message()
|
|
69
|
-
async def create_element(self, element_dict: "ElementDict"):
|
|
70
|
-
pass
|
|
71
|
-
|
|
72
|
-
async def get_element(
|
|
73
|
-
self, thread_id: str, element_id: str
|
|
74
|
-
) -> Optional["ElementDict"]:
|
|
75
|
-
pass
|
|
76
|
-
|
|
77
|
-
@queue_until_user_message()
|
|
78
|
-
async def delete_element(self, element_id: str):
|
|
79
|
-
pass
|
|
80
|
-
|
|
81
|
-
@queue_until_user_message()
|
|
82
|
-
async def create_step(self, step_dict: "StepDict"):
|
|
83
|
-
pass
|
|
84
|
-
|
|
85
|
-
@queue_until_user_message()
|
|
86
|
-
async def update_step(self, step_dict: "StepDict"):
|
|
87
|
-
pass
|
|
88
|
-
|
|
89
|
-
@queue_until_user_message()
|
|
90
|
-
async def delete_step(self, step_id: str):
|
|
91
|
-
pass
|
|
92
|
-
|
|
93
|
-
async def get_thread_author(self, thread_id: str) -> str:
|
|
94
|
-
return ""
|
|
95
|
-
|
|
96
|
-
async def delete_thread(self, thread_id: str):
|
|
97
|
-
pass
|
|
98
|
-
|
|
99
|
-
async def list_threads(
|
|
100
|
-
self, pagination: "Pagination", filters: "ThreadFilter"
|
|
101
|
-
) -> "PaginatedResponse[ThreadDict]":
|
|
102
|
-
return PaginatedResponse(
|
|
103
|
-
data=[], pageInfo=PageInfo(hasNextPage=False, endCursor=None)
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
|
|
107
|
-
return None
|
|
108
|
-
|
|
109
|
-
async def update_thread(
|
|
110
|
-
self,
|
|
111
|
-
thread_id: str,
|
|
112
|
-
name: Optional[str] = None,
|
|
113
|
-
user_id: Optional[str] = None,
|
|
114
|
-
metadata: Optional[Dict] = None,
|
|
115
|
-
tags: Optional[List[str]] = None,
|
|
116
|
-
):
|
|
117
|
-
pass
|
|
118
|
-
|
|
119
|
-
async def delete_user_session(self, id: str) -> bool:
|
|
120
|
-
return True
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
class ChainlitDataLayer:
|
|
124
|
-
def __init__(self, api_key: str, server: Optional[str]):
|
|
125
|
-
from literalai import LiteralClient
|
|
126
|
-
|
|
127
|
-
self.client = LiteralClient(api_key=api_key, url=server)
|
|
128
|
-
logger.info("Chainlit data layer initialized")
|
|
129
|
-
|
|
130
|
-
def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
|
|
131
|
-
metadata = attachment.metadata or {}
|
|
132
|
-
return {
|
|
133
|
-
"chainlitKey": None,
|
|
134
|
-
"display": metadata.get("display", "side"),
|
|
135
|
-
"language": metadata.get("language"),
|
|
136
|
-
"page": metadata.get("page"),
|
|
137
|
-
"size": metadata.get("size"),
|
|
138
|
-
"type": metadata.get("type", "file"),
|
|
139
|
-
"forId": attachment.step_id,
|
|
140
|
-
"id": attachment.id or "",
|
|
141
|
-
"mime": attachment.mime,
|
|
142
|
-
"name": attachment.name or "",
|
|
143
|
-
"objectKey": attachment.object_key,
|
|
144
|
-
"url": attachment.url,
|
|
145
|
-
"threadId": attachment.thread_id,
|
|
146
|
-
}
|
|
147
|
-
|
|
148
|
-
def feedback_to_feedback_dict(
|
|
149
|
-
self, feedback: Optional[ClientFeedback]
|
|
150
|
-
) -> "Optional[FeedbackDict]":
|
|
151
|
-
if not feedback:
|
|
152
|
-
return None
|
|
153
|
-
return {
|
|
154
|
-
"id": feedback.id or "",
|
|
155
|
-
"forId": feedback.step_id or "",
|
|
156
|
-
"value": feedback.value or 0, # type: ignore
|
|
157
|
-
"comment": feedback.comment,
|
|
158
|
-
"strategy": "BINARY",
|
|
159
|
-
}
|
|
160
|
-
|
|
161
|
-
def step_to_step_dict(self, step: ClientStep) -> "StepDict":
|
|
162
|
-
metadata = step.metadata or {}
|
|
163
|
-
input = (step.input or {}).get("content") or (
|
|
164
|
-
json.dumps(step.input) if step.input and step.input != {} else ""
|
|
165
|
-
)
|
|
166
|
-
output = (step.output or {}).get("content") or (
|
|
167
|
-
json.dumps(step.output) if step.output and step.output != {} else ""
|
|
168
|
-
)
|
|
169
|
-
return {
|
|
170
|
-
"createdAt": step.created_at,
|
|
171
|
-
"id": step.id or "",
|
|
172
|
-
"threadId": step.thread_id or "",
|
|
173
|
-
"parentId": step.parent_id,
|
|
174
|
-
"feedback": self.feedback_to_feedback_dict(step.feedback),
|
|
175
|
-
"start": step.start_time,
|
|
176
|
-
"end": step.end_time,
|
|
177
|
-
"type": step.type or "undefined",
|
|
178
|
-
"name": step.name or "",
|
|
179
|
-
"generation": step.generation.to_dict() if step.generation else None,
|
|
180
|
-
"input": input,
|
|
181
|
-
"output": output,
|
|
182
|
-
"showInput": metadata.get("showInput", False),
|
|
183
|
-
"disableFeedback": metadata.get("disableFeedback", False),
|
|
184
|
-
"indent": metadata.get("indent"),
|
|
185
|
-
"language": metadata.get("language"),
|
|
186
|
-
"isError": metadata.get("isError", False),
|
|
187
|
-
"waitForAnswer": metadata.get("waitForAnswer", False),
|
|
188
|
-
"feedback": self.feedback_to_feedback_dict(step.feedback),
|
|
189
|
-
}
|
|
190
|
-
|
|
191
|
-
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
|
|
192
|
-
user = await self.client.api.get_user(identifier=identifier)
|
|
193
|
-
if not user:
|
|
194
|
-
return None
|
|
195
|
-
return PersistedUser(
|
|
196
|
-
id=user.id or "",
|
|
197
|
-
identifier=user.identifier or "",
|
|
198
|
-
metadata=user.metadata,
|
|
199
|
-
createdAt=user.created_at or "",
|
|
200
|
-
)
|
|
17
|
+
if not _data_layer_initialized:
|
|
18
|
+
if _data_layer:
|
|
19
|
+
# Data layer manually set, warn user that this is deprecated.
|
|
201
20
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
_user = await self.client.api.create_user(
|
|
206
|
-
identifier=user.identifier, metadata=user.metadata
|
|
21
|
+
warnings.warn(
|
|
22
|
+
"Setting data layer manually is deprecated. Use @data_layer instead.",
|
|
23
|
+
DeprecationWarning,
|
|
207
24
|
)
|
|
208
|
-
elif _user.id:
|
|
209
|
-
await self.client.api.update_user(id=_user.id, metadata=user.metadata)
|
|
210
|
-
return PersistedUser(
|
|
211
|
-
id=_user.id or "",
|
|
212
|
-
identifier=_user.identifier or "",
|
|
213
|
-
metadata=_user.metadata,
|
|
214
|
-
createdAt=_user.created_at or "",
|
|
215
|
-
)
|
|
216
25
|
|
|
217
|
-
async def upsert_feedback(
|
|
218
|
-
self,
|
|
219
|
-
feedback: Feedback,
|
|
220
|
-
):
|
|
221
|
-
if feedback.id:
|
|
222
|
-
await self.client.api.update_feedback(
|
|
223
|
-
id=feedback.id,
|
|
224
|
-
update_params={
|
|
225
|
-
"comment": feedback.comment,
|
|
226
|
-
"strategy": feedback.strategy,
|
|
227
|
-
"value": feedback.value,
|
|
228
|
-
},
|
|
229
|
-
)
|
|
230
|
-
return feedback.id
|
|
231
26
|
else:
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
)
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
"threadId": step_dict.get("threadId"),
|
|
318
|
-
"type": step_dict.get("type"),
|
|
319
|
-
"metadata": metadata,
|
|
320
|
-
}
|
|
321
|
-
if step_dict.get("input"):
|
|
322
|
-
step["input"] = {"content": step_dict.get("input")}
|
|
323
|
-
if step_dict.get("output"):
|
|
324
|
-
step["output"] = {"content": step_dict.get("output")}
|
|
325
|
-
|
|
326
|
-
await self.client.api.send_steps([step])
|
|
327
|
-
|
|
328
|
-
@queue_until_user_message()
|
|
329
|
-
async def update_step(self, step_dict: "StepDict"):
|
|
330
|
-
await self.create_step(step_dict)
|
|
331
|
-
|
|
332
|
-
@queue_until_user_message()
|
|
333
|
-
async def delete_step(self, step_id: str):
|
|
334
|
-
await self.client.api.delete_step(id=step_id)
|
|
27
|
+
from chainlit.config import config
|
|
28
|
+
|
|
29
|
+
if config.code.data_layer:
|
|
30
|
+
# When @data_layer is configured, call it to get data layer.
|
|
31
|
+
_data_layer = config.code.data_layer()
|
|
32
|
+
elif database_url := os.environ.get("DATABASE_URL"):
|
|
33
|
+
# Default to Chainlit data layer if DATABASE_URL specified.
|
|
34
|
+
from .chainlit_data_layer import ChainlitDataLayer
|
|
35
|
+
|
|
36
|
+
if os.environ.get("LITERAL_API_KEY"):
|
|
37
|
+
warnings.warn(
|
|
38
|
+
"Both LITERAL_API_KEY and DATABASE_URL specified. Ignoring Literal AI data layer and relying on data layer pointing to DATABASE_URL."
|
|
39
|
+
)
|
|
40
|
+
bucket_name = os.environ.get("BUCKET_NAME")
|
|
41
|
+
|
|
42
|
+
# AWS S3
|
|
43
|
+
aws_region = os.getenv("APP_AWS_REGION")
|
|
44
|
+
aws_access_key = os.getenv("APP_AWS_ACCESS_KEY")
|
|
45
|
+
aws_secret_key = os.getenv("APP_AWS_SECRET_KEY")
|
|
46
|
+
dev_aws_endpoint = os.getenv("DEV_AWS_ENDPOINT")
|
|
47
|
+
is_using_s3 = bool(aws_access_key and aws_secret_key and aws_region)
|
|
48
|
+
|
|
49
|
+
# Google Cloud Storage
|
|
50
|
+
gcs_project_id = os.getenv("APP_GCS_PROJECT_ID")
|
|
51
|
+
gcs_client_email = os.getenv("APP_GCS_CLIENT_EMAIL")
|
|
52
|
+
gcs_private_key = os.getenv("APP_GCS_PRIVATE_KEY")
|
|
53
|
+
is_using_gcs = bool(
|
|
54
|
+
gcs_project_id and gcs_client_email and gcs_private_key
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
# Azure Storage
|
|
58
|
+
azure_storage_account = os.getenv("APP_AZURE_STORAGE_ACCOUNT")
|
|
59
|
+
azure_storage_key = os.getenv("APP_AZURE_STORAGE_ACCESS_KEY")
|
|
60
|
+
is_using_azure = bool(azure_storage_account and azure_storage_key)
|
|
61
|
+
|
|
62
|
+
storage_client = None
|
|
63
|
+
|
|
64
|
+
if sum([is_using_s3, is_using_gcs, is_using_azure]) > 1:
|
|
65
|
+
warnings.warn(
|
|
66
|
+
"Multiple storage configurations detected. Please use only one."
|
|
67
|
+
)
|
|
68
|
+
elif is_using_s3:
|
|
69
|
+
from chainlit.data.storage_clients.s3 import S3StorageClient
|
|
70
|
+
|
|
71
|
+
storage_client = S3StorageClient(
|
|
72
|
+
bucket=bucket_name,
|
|
73
|
+
region_name=aws_region,
|
|
74
|
+
aws_access_key_id=aws_access_key,
|
|
75
|
+
aws_secret_access_key=aws_secret_key,
|
|
76
|
+
endpoint_url=dev_aws_endpoint,
|
|
77
|
+
)
|
|
78
|
+
elif is_using_gcs:
|
|
79
|
+
from chainlit.data.storage_clients.gcs import GCSStorageClient
|
|
80
|
+
|
|
81
|
+
storage_client = GCSStorageClient(
|
|
82
|
+
project_id=gcs_project_id,
|
|
83
|
+
client_email=gcs_client_email,
|
|
84
|
+
private_key=gcs_private_key,
|
|
85
|
+
bucket_name=bucket_name,
|
|
86
|
+
)
|
|
87
|
+
elif is_using_azure:
|
|
88
|
+
from chainlit.data.storage_clients.azure_blob import (
|
|
89
|
+
AzureBlobStorageClient,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
storage_client = AzureBlobStorageClient(
|
|
93
|
+
container_name=bucket_name,
|
|
94
|
+
storage_account=azure_storage_account,
|
|
95
|
+
storage_key=azure_storage_key,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
_data_layer = ChainlitDataLayer(
|
|
99
|
+
database_url=database_url, storage_client=storage_client
|
|
100
|
+
)
|
|
101
|
+
elif api_key := os.environ.get("LITERAL_API_KEY"):
|
|
102
|
+
# When LITERAL_API_KEY is defined, use Literal AI data layer
|
|
103
|
+
from .literalai import LiteralDataLayer
|
|
104
|
+
|
|
105
|
+
# support legacy LITERAL_SERVER variable as fallback
|
|
106
|
+
server = os.environ.get("LITERAL_API_URL") or os.environ.get(
|
|
107
|
+
"LITERAL_SERVER"
|
|
108
|
+
)
|
|
109
|
+
_data_layer = LiteralDataLayer(api_key=api_key, server=server)
|
|
110
|
+
|
|
111
|
+
_data_layer_initialized = True
|
|
335
112
|
|
|
336
|
-
async def get_thread_author(self, thread_id: str) -> str:
|
|
337
|
-
thread = await self.get_thread(thread_id)
|
|
338
|
-
if not thread:
|
|
339
|
-
return ""
|
|
340
|
-
user = thread.get("user")
|
|
341
|
-
if not user:
|
|
342
|
-
return ""
|
|
343
|
-
return user.get("identifier") or ""
|
|
344
|
-
|
|
345
|
-
async def delete_thread(self, thread_id: str):
|
|
346
|
-
await self.client.api.delete_thread(id=thread_id)
|
|
347
|
-
|
|
348
|
-
async def list_threads(
|
|
349
|
-
self, pagination: "Pagination", filters: "ThreadFilter"
|
|
350
|
-
) -> "PaginatedResponse[ThreadDict]":
|
|
351
|
-
if not filters.userIdentifier:
|
|
352
|
-
raise ValueError("userIdentifier is required")
|
|
353
|
-
|
|
354
|
-
client_filters = ClientThreadFilter(
|
|
355
|
-
participantsIdentifier=StringListFilter(
|
|
356
|
-
operator="in", value=[filters.userIdentifier]
|
|
357
|
-
),
|
|
358
|
-
)
|
|
359
|
-
if filters.search:
|
|
360
|
-
client_filters.search = StringFilter(operator="ilike", value=filters.search)
|
|
361
|
-
if filters.feedback:
|
|
362
|
-
client_filters.feedbacksValue = NumberListFilter(
|
|
363
|
-
operator="in", value=[filters.feedback]
|
|
364
|
-
)
|
|
365
|
-
return await self.client.api.list_threads(
|
|
366
|
-
first=pagination.first, after=pagination.cursor, filters=client_filters
|
|
367
|
-
)
|
|
368
|
-
|
|
369
|
-
async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
|
|
370
|
-
thread = await self.client.api.get_thread(id=thread_id)
|
|
371
|
-
if not thread:
|
|
372
|
-
return None
|
|
373
|
-
elements = [] # List[ElementDict]
|
|
374
|
-
steps = [] # List[StepDict]
|
|
375
|
-
if thread.steps:
|
|
376
|
-
for step in thread.steps:
|
|
377
|
-
if config.ui.hide_cot and step.parent_id:
|
|
378
|
-
continue
|
|
379
|
-
for attachment in step.attachments:
|
|
380
|
-
elements.append(self.attachment_to_element_dict(attachment))
|
|
381
|
-
if not config.features.prompt_playground and step.generation:
|
|
382
|
-
step.generation = None
|
|
383
|
-
steps.append(self.step_to_step_dict(step))
|
|
384
|
-
|
|
385
|
-
user = None # type: Optional["UserDict"]
|
|
386
|
-
|
|
387
|
-
if thread.user:
|
|
388
|
-
user = {
|
|
389
|
-
"id": thread.user.id or "",
|
|
390
|
-
"identifier": thread.user.identifier or "",
|
|
391
|
-
"metadata": thread.user.metadata,
|
|
392
|
-
}
|
|
393
|
-
|
|
394
|
-
return {
|
|
395
|
-
"createdAt": thread.created_at or "",
|
|
396
|
-
"id": thread.id,
|
|
397
|
-
"name": thread.name or None,
|
|
398
|
-
"steps": steps,
|
|
399
|
-
"elements": elements,
|
|
400
|
-
"metadata": thread.metadata,
|
|
401
|
-
"user": user,
|
|
402
|
-
"tags": thread.tags,
|
|
403
|
-
}
|
|
404
|
-
|
|
405
|
-
async def update_thread(
|
|
406
|
-
self,
|
|
407
|
-
thread_id: str,
|
|
408
|
-
name: Optional[str] = None,
|
|
409
|
-
user_id: Optional[str] = None,
|
|
410
|
-
metadata: Optional[Dict] = None,
|
|
411
|
-
tags: Optional[List[str]] = None,
|
|
412
|
-
):
|
|
413
|
-
await self.client.api.upsert_thread(
|
|
414
|
-
thread_id=thread_id,
|
|
415
|
-
name=name,
|
|
416
|
-
participant_id=user_id,
|
|
417
|
-
metadata=metadata,
|
|
418
|
-
tags=tags,
|
|
419
|
-
)
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
if api_key := os.environ.get("LITERAL_API_KEY"):
|
|
423
|
-
server = os.environ.get("LITERAL_SERVER")
|
|
424
|
-
_data_layer = ChainlitDataLayer(api_key=api_key, server=server)
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
def get_data_layer():
|
|
428
113
|
return _data_layer
|
chainlit/data/acl.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
|
1
|
-
from chainlit.data import get_data_layer
|
|
2
1
|
from fastapi import HTTPException
|
|
3
2
|
|
|
3
|
+
from chainlit.data import get_data_layer
|
|
4
|
+
|
|
4
5
|
|
|
5
6
|
async def is_thread_author(username: str, thread_id: str):
|
|
6
7
|
data_layer = get_data_layer()
|
|
7
8
|
if not data_layer:
|
|
8
|
-
raise HTTPException(status_code=
|
|
9
|
+
raise HTTPException(status_code=400, detail="Data layer not initialized")
|
|
9
10
|
|
|
10
11
|
thread_author = await data_layer.get_thread_author(thread_id)
|
|
11
12
|
|
|
13
|
+
if not thread_author:
|
|
14
|
+
raise HTTPException(status_code=404, detail="Thread not found")
|
|
15
|
+
|
|
12
16
|
if thread_author != username:
|
|
13
17
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
14
18
|
else:
|
chainlit/data/base.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from chainlit.types import (
|
|
5
|
+
Feedback,
|
|
6
|
+
PaginatedResponse,
|
|
7
|
+
Pagination,
|
|
8
|
+
ThreadDict,
|
|
9
|
+
ThreadFilter,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from .utils import queue_until_user_message
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from chainlit.element import Element, ElementDict
|
|
16
|
+
from chainlit.step import StepDict
|
|
17
|
+
from chainlit.user import PersistedUser, User
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseDataLayer(ABC):
|
|
21
|
+
"""Base class for data persistence."""
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
@abstractmethod
|
|
28
|
+
async def create_user(self, user: "User") -> Optional["PersistedUser"]:
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
async def delete_feedback(
|
|
33
|
+
self,
|
|
34
|
+
feedback_id: str,
|
|
35
|
+
) -> bool:
|
|
36
|
+
pass
|
|
37
|
+
|
|
38
|
+
@abstractmethod
|
|
39
|
+
async def upsert_feedback(
|
|
40
|
+
self,
|
|
41
|
+
feedback: Feedback,
|
|
42
|
+
) -> str:
|
|
43
|
+
pass
|
|
44
|
+
|
|
45
|
+
@queue_until_user_message()
|
|
46
|
+
@abstractmethod
|
|
47
|
+
async def create_element(self, element: "Element"):
|
|
48
|
+
pass
|
|
49
|
+
|
|
50
|
+
@abstractmethod
|
|
51
|
+
async def get_element(
|
|
52
|
+
self, thread_id: str, element_id: str
|
|
53
|
+
) -> Optional["ElementDict"]:
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
@queue_until_user_message()
|
|
57
|
+
@abstractmethod
|
|
58
|
+
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
|
|
59
|
+
pass
|
|
60
|
+
|
|
61
|
+
@queue_until_user_message()
|
|
62
|
+
@abstractmethod
|
|
63
|
+
async def create_step(self, step_dict: "StepDict"):
|
|
64
|
+
pass
|
|
65
|
+
|
|
66
|
+
@queue_until_user_message()
|
|
67
|
+
@abstractmethod
|
|
68
|
+
async def update_step(self, step_dict: "StepDict"):
|
|
69
|
+
pass
|
|
70
|
+
|
|
71
|
+
@queue_until_user_message()
|
|
72
|
+
@abstractmethod
|
|
73
|
+
async def delete_step(self, step_id: str):
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
@abstractmethod
|
|
77
|
+
async def get_thread_author(self, thread_id: str) -> str:
|
|
78
|
+
return ""
|
|
79
|
+
|
|
80
|
+
@abstractmethod
|
|
81
|
+
async def delete_thread(self, thread_id: str):
|
|
82
|
+
pass
|
|
83
|
+
|
|
84
|
+
@abstractmethod
|
|
85
|
+
async def list_threads(
|
|
86
|
+
self, pagination: "Pagination", filters: "ThreadFilter"
|
|
87
|
+
) -> "PaginatedResponse[ThreadDict]":
|
|
88
|
+
pass
|
|
89
|
+
|
|
90
|
+
@abstractmethod
|
|
91
|
+
async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
|
|
92
|
+
pass
|
|
93
|
+
|
|
94
|
+
@abstractmethod
|
|
95
|
+
async def update_thread(
|
|
96
|
+
self,
|
|
97
|
+
thread_id: str,
|
|
98
|
+
name: Optional[str] = None,
|
|
99
|
+
user_id: Optional[str] = None,
|
|
100
|
+
metadata: Optional[Dict] = None,
|
|
101
|
+
tags: Optional[List[str]] = None,
|
|
102
|
+
):
|
|
103
|
+
pass
|
|
104
|
+
|
|
105
|
+
@abstractmethod
|
|
106
|
+
async def build_debug_url(self) -> str:
|
|
107
|
+
pass
|