chainlit 1.1.202__py3-none-any.whl → 1.1.300__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +22 -4
- chainlit/cli/__init__.py +53 -6
- chainlit/config.py +25 -18
- chainlit/context.py +9 -0
- chainlit/copilot/dist/index.js +443 -410
- chainlit/data/__init__.py +19 -5
- chainlit/data/dynamodb.py +586 -0
- chainlit/data/sql_alchemy.py +47 -28
- chainlit/discord/app.py +4 -2
- chainlit/element.py +36 -20
- chainlit/emitter.py +8 -7
- chainlit/frontend/dist/assets/{DailyMotion-53376209.js → DailyMotion-578b63e6.js} +1 -1
- chainlit/frontend/dist/assets/{Facebook-aee41f5b.js → Facebook-b825e5bb.js} +1 -1
- chainlit/frontend/dist/assets/{FilePlayer-b2cdb30f.js → FilePlayer-bcba3b4e.js} +1 -1
- chainlit/frontend/dist/assets/{Kaltura-51db0377.js → Kaltura-fc1c9497.js} +1 -1
- chainlit/frontend/dist/assets/{Mixcloud-cb900886.js → Mixcloud-4cfb2724.js} +1 -1
- chainlit/frontend/dist/assets/{Mux-79ac59e6.js → Mux-aa92055c.js} +1 -1
- chainlit/frontend/dist/assets/{Preview-cfe7584c.js → Preview-9f55905a.js} +1 -1
- chainlit/frontend/dist/assets/{SoundCloud-a985707c.js → SoundCloud-f991fe03.js} +1 -1
- chainlit/frontend/dist/assets/{Streamable-3d89aab5.js → Streamable-53128f49.js} +1 -1
- chainlit/frontend/dist/assets/{Twitch-bf016588.js → Twitch-fce8b9f5.js} +1 -1
- chainlit/frontend/dist/assets/{Vidyard-1891ecd7.js → Vidyard-e35c6102.js} +1 -1
- chainlit/frontend/dist/assets/{Vimeo-0645662c.js → Vimeo-fff35f8e.js} +1 -1
- chainlit/frontend/dist/assets/{Wistia-3b449fe2.js → Wistia-ec07dc64.js} +1 -1
- chainlit/frontend/dist/assets/{YouTube-5ea2381e.js → YouTube-ad068e2a.js} +1 -1
- chainlit/frontend/dist/assets/index-aaf974a9.css +1 -0
- chainlit/frontend/dist/assets/index-d40d41cc.js +727 -0
- chainlit/frontend/dist/assets/{react-plotly-2ff19c9f.js → react-plotly-b2c6442b.js} +1 -1
- chainlit/frontend/dist/index.html +2 -3
- chainlit/langchain/callbacks.py +4 -2
- chainlit/llama_index/callbacks.py +2 -2
- chainlit/message.py +30 -25
- chainlit/oauth_providers.py +118 -0
- chainlit/server.py +208 -83
- chainlit/slack/app.py +2 -3
- chainlit/socket.py +27 -23
- chainlit/step.py +44 -30
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +332 -0
- chainlit/translations/en-US.json +2 -4
- chainlit/types.py +17 -17
- chainlit/user.py +9 -1
- chainlit/utils.py +47 -3
- {chainlit-1.1.202.dist-info → chainlit-1.1.300.dist-info}/METADATA +22 -14
- chainlit-1.1.300.dist-info/RECORD +79 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-a0c5a67e.js +0 -698
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -36
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -11
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -386
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit-1.1.202.dist-info/RECORD +0 -86
- {chainlit-1.1.202.dist-info → chainlit-1.1.300.dist-info}/WHEEL +0 -0
- {chainlit-1.1.202.dist-info → chainlit-1.1.300.dist-info}/entry_points.txt +0 -0
chainlit/data/__init__.py
CHANGED
|
@@ -94,7 +94,7 @@ class BaseDataLayer:
|
|
|
94
94
|
pass
|
|
95
95
|
|
|
96
96
|
@queue_until_user_message()
|
|
97
|
-
async def delete_element(self, element_id: str):
|
|
97
|
+
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
|
|
98
98
|
pass
|
|
99
99
|
|
|
100
100
|
@queue_until_user_message()
|
|
@@ -139,6 +139,9 @@ class BaseDataLayer:
|
|
|
139
139
|
async def delete_user_session(self, id: str) -> bool:
|
|
140
140
|
return True
|
|
141
141
|
|
|
142
|
+
async def build_debug_url(self) -> str:
    """Return the debug-URL template for this data layer.

    The base data layer has no external debugging UI, so the template is
    empty; subclasses backed by a hosted service override this.
    """
    no_debug_url = ""
    return no_debug_url
|
|
144
|
+
|
|
142
145
|
|
|
143
146
|
_data_layer: Optional[BaseDataLayer] = None
|
|
144
147
|
|
|
@@ -225,6 +228,14 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
225
228
|
"waitForAnswer": metadata.get("waitForAnswer", False),
|
|
226
229
|
}
|
|
227
230
|
|
|
231
|
+
async def build_debug_url(self) -> str:
    """Build the Literal AI thread-debug URL template for this project.

    The returned string keeps the literal placeholders ``[thread_id]`` and
    ``[step_id]`` for the caller to substitute.

    Returns:
        The URL template, or ``""`` if the project id could not be fetched
        (the error is logged).
    """
    try:
        project_id = await self.client.api.get_my_project_id()
        # Query parameters are joined with "&" (the scraped text had
        # "&curren" mis-decoded into the "¤" character).
        return f"{self.client.api.url}/projects/{project_id}/threads?threadId=[thread_id]&currentStepId=[step_id]"
    except Exception as e:
        logger.error(f"Error building debug url: {e}")
        return ""
|
|
238
|
+
|
|
228
239
|
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
|
|
229
240
|
user = await self.client.api.get_user(identifier=identifier)
|
|
230
241
|
if not user:
|
|
@@ -341,7 +352,7 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
341
352
|
return self.attachment_to_element_dict(attachment)
|
|
342
353
|
|
|
343
354
|
@queue_until_user_message()
|
|
344
|
-
async def delete_element(self, element_id: str):
|
|
355
|
+
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
|
|
345
356
|
await self.client.api.delete_attachment(id=element_id)
|
|
346
357
|
|
|
347
358
|
@queue_until_user_message()
|
|
@@ -456,12 +467,15 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
456
467
|
steps = [] # List[StepDict]
|
|
457
468
|
if thread.steps:
|
|
458
469
|
for step in thread.steps:
|
|
459
|
-
if
|
|
470
|
+
if step.type == "system_message":
|
|
471
|
+
continue
|
|
472
|
+
if config.ui.hide_cot and step.type not in [
|
|
473
|
+
"user_message",
|
|
474
|
+
"assistant_message",
|
|
475
|
+
]:
|
|
460
476
|
continue
|
|
461
477
|
for attachment in step.attachments:
|
|
462
478
|
elements.append(self.attachment_to_element_dict(attachment))
|
|
463
|
-
if not config.features.prompt_playground and step.generation:
|
|
464
|
-
step.generation = None
|
|
465
479
|
steps.append(self.step_to_step_dict(step))
|
|
466
480
|
|
|
467
481
|
return {
|
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
import asyncio
import json
import logging
import os
import random
from dataclasses import asdict
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import aiofiles
import aiohttp
import boto3  # type: ignore
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from chainlit.context import context
from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
from chainlit.element import ElementDict
from chainlit.logger import logger
from chainlit.step import StepDict
from chainlit.types import (
    Feedback,
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User

if TYPE_CHECKING:
    from chainlit.element import Element
    from mypy_boto3_dynamodb import DynamoDBClient
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Dedicated child of the chainlit logger for this data layer. Its level is
# pinned to WARNING, so the info/debug tracing emitted by the methods below
# is suppressed unless someone lowers the level on this logger explicitly.
_logger = logger.getChild("DynamoDB")
_logger.setLevel(level=logging.WARNING)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class DynamoDBDataLayer(BaseDataLayer):
    """Chainlit data layer storing everything in a single DynamoDB table.

    Item key layout used by this class:

    * user:    ``PK=USER#{identifier}``,   ``SK=USER``
    * thread:  ``PK=THREAD#{thread_id}``,  ``SK=THREAD``
    * step:    ``PK=THREAD#{thread_id}``,  ``SK=STEP#{step_id}``
    * element: ``PK=THREAD#{thread_id}``,  ``SK=ELEMENT#{element_id}``

    Thread items also carry ``UserThreadPK`` / ``UserThreadSK``, which
    :meth:`list_threads` queries through an index named ``UserThread``.
    Element payloads are uploaded through ``storage_provider``; only their
    metadata is written to DynamoDB.

    NOTE(review): all boto3 calls below are synchronous and execute on the
    event-loop thread.
    """

    def __init__(
        self,
        table_name: str,
        client: Optional["DynamoDBClient"] = None,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: int = 10,
    ):
        """Create the data layer.

        Args:
            table_name: DynamoDB table holding all items.
            client: Pre-built boto3 DynamoDB client. If omitted, one is
                created for ``$AWS_REGION`` (default ``us-east-1``).
            storage_provider: Blob storage for element content. Without it,
                :meth:`create_element` logs a warning and stores nothing.
            user_thread_limit: ``Limit`` passed to each ``list_threads`` query.
        """
        if client:
            self.client = client
        else:
            region_name = os.environ.get("AWS_REGION", "us-east-1")
            self.client = boto3.client("dynamodb", region_name=region_name)  # type: ignore

        self.table_name = table_name
        self.storage_provider = storage_provider
        self.user_thread_limit = user_thread_limit

        self._type_deserializer = TypeDeserializer()
        self._type_serializer = TypeSerializer()

    def _get_current_timestamp(self) -> str:
        """Return the current UTC time as an ISO-8601 string ending in ``Z``."""
        # datetime.now() is *local* time; suffixing it with "Z" mislabelled it
        # as UTC. Use a naive UTC time so the output shape is unchanged.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    @staticmethod
    def _strip_prefix(value: str, prefix: str) -> str:
        """Remove ``prefix`` from the start of ``value`` when present.

        ``str.strip(prefix)`` removes any of the prefix's *characters* from
        both ends, which corrupts ids that start or end with them.
        """
        if value.startswith(prefix):
            return value[len(prefix) :]  # noqa: E203
        return value

    @classmethod
    def _floats_to_decimals(cls, obj: Any) -> Any:
        """Recursively replace ``float`` values with ``Decimal``.

        boto3's ``TypeSerializer`` raises on Python floats, so any payload
        containing one (e.g. generation settings) would fail to be written.
        """
        if isinstance(obj, float):
            # str() gives the shortest round-tripping repr, avoiding binary noise.
            return Decimal(str(obj))
        if isinstance(obj, dict):
            return {k: cls._floats_to_decimals(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [cls._floats_to_decimals(v) for v in obj]
        return obj

    @classmethod
    def _decimals_to_numbers(cls, obj: Any) -> Any:
        """Recursively turn ``Decimal`` values read from DynamoDB into int/float.

        ``Decimal`` is not JSON serializable, which breaks returning thread or
        user data to the frontend.
        """
        if isinstance(obj, Decimal):
            return int(obj) if obj == obj.to_integral_value() else float(obj)
        if isinstance(obj, dict):
            return {k: cls._decimals_to_numbers(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [cls._decimals_to_numbers(v) for v in obj]
        return obj

    def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a plain dict into DynamoDB typed attribute values."""
        return {
            key: self._type_serializer.serialize(self._floats_to_decimals(value))
            for key, value in item.items()
        }

    def _deserialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a DynamoDB typed item back into plain Python values."""
        return {
            key: self._decimals_to_numbers(self._type_deserializer.deserialize(value))
            for key, value in item.items()
        }

    def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
        """``SET`` every non-``None`` entry of ``updates`` on the item at ``key``.

        ``None`` values are skipped so partial updates keep existing
        attributes. Falsy but meaningful values (``False``, ``0``, ``""``,
        ``[]``) ARE written. When nothing is left to set, no request is sent.
        """
        update_expr: List[str] = []
        expression_attribute_names: Dict[str, str] = {}
        expression_attribute_values: Dict[str, Any] = {}

        for index, (attr, value) in enumerate(updates.items()):
            # `if not value` used to drop False/0/"" too, so e.g. a step's
            # isError could never be reset to False.
            if value is None:
                continue

            # Positional placeholders avoid clashes with DynamoDB reserved words.
            k, v = f"#{index}", f":{index}"
            update_expr.append(f"{k} = {v}")
            expression_attribute_names[k] = attr
            expression_attribute_values[v] = value

        if not update_expr:
            # "SET " with no clauses is an invalid UpdateExpression.
            return

        self.client.update_item(
            TableName=self.table_name,
            Key=self._serialize_item(key),
            UpdateExpression="SET " + ", ".join(update_expr),
            ExpressionAttributeNames=expression_attribute_names,
            ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
        )

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        """Fetch the user item for ``identifier``, or ``None`` if absent."""
        _logger.info("DynamoDB: get_user identifier=%s", identifier)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"USER#{identifier}"},
                "SK": {"S": "USER"},
            },
        )

        if "Item" not in response:
            return None

        user = self._deserialize_item(response["Item"])

        return PersistedUser(
            id=user["id"],
            identifier=user["identifier"],
            createdAt=user["createdAt"],
            metadata=user["metadata"],
        )

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        """Write (or overwrite) the user item; the identifier doubles as id."""
        _logger.info("DynamoDB: create_user user.identifier=%s", user.identifier)

        ts = self._get_current_timestamp()
        metadata: Dict[Any, Any] = user.metadata  # type: ignore

        item = {
            "PK": f"USER#{user.identifier}",
            "SK": "USER",
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": metadata,
            "createdAt": ts,
        }

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Remove the ``feedback`` attribute from the step named by ``feedback_id``.

        Raises:
            ValueError: if ``feedback_id`` lacks the ``::`` separator.
        """
        _logger.info("DynamoDB: delete_feedback feedback_id=%s", feedback_id)

        # feedback id = THREAD#{thread_id}::STEP#{step_id}
        thread_part, sep, step_part = feedback_id.partition("::")
        if not sep:
            raise ValueError(f"Malformed feedback id: {feedback_id}")
        thread_id = self._strip_prefix(thread_part, "THREAD#")
        step_id = self._strip_prefix(step_part, "STEP#")

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
            UpdateExpression="REMOVE #feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
        )

        return True

    async def upsert_feedback(self, feedback: Feedback) -> str:
        """Store ``feedback`` on its step item and return the feedback id.

        The id is derived as ``THREAD#{threadId}::STEP#{forId}``.

        Raises:
            ValueError: if ``feedback.forId`` or ``feedback.threadId`` is empty
                (UpdateItem would otherwise create an orphan item).
        """
        _logger.info(
            "DynamoDB: upsert_feedback thread=%s step=%s value=%s",
            feedback.threadId,
            feedback.forId,
            feedback.value,
        )

        if not feedback.forId:
            raise ValueError(
                "DynamoDB datalayer expects value for feedback.forId got None"
            )
        if not feedback.threadId:
            raise ValueError(
                "DynamoDB datalayer expects value for feedback.threadId got None"
            )

        feedback.id = f"THREAD#{feedback.threadId}::STEP#{feedback.forId}"
        serialized_feedback = self._type_serializer.serialize(
            self._floats_to_decimals(asdict(feedback))
        )

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{feedback.threadId}"},
                "SK": {"S": f"STEP#{feedback.forId}"},
            },
            UpdateExpression="SET #feedback = :feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
            ExpressionAttributeValues={":feedback": serialized_feedback},
        )

        return feedback.id

    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        """Upload an element's content to storage and record it in DynamoDB.

        Content comes from ``element.content``, else the file at
        ``element.path``, else an HTTP GET of ``element.url``. Elements
        without ``for_id`` are ignored; without a storage provider a warning
        is logged and nothing is stored.

        Raises:
            ValueError: if no content source is set, the URL fetch is not
                HTTP 200, or the storage upload returns nothing.
        """
        _logger.info(
            "DynamoDB: create_element thread=%s step=%s type=%s",
            element.thread_id,
            element.for_id,
            element.type,
        )
        _logger.debug("DynamoDB: create_element: %s", element.to_dict())

        if not element.for_id:
            return

        if not self.storage_provider:
            _logger.warning(
                "DynamoDB: create_element error. No storage_provider is configured!"
            )
            return

        content: Optional[Union[bytes, str]] = None

        if element.content:
            content = element.content

        elif element.path:
            _logger.debug("DynamoDB: create_element reading file %s", element.path)
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()

        elif element.url:
            _logger.debug("DynamoDB: create_element http %s", element.url)
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        raise ValueError(
                            f"Failed to read content from {element.url} status {response.status}",
                        )

        else:
            raise ValueError("Element url, path or content must be provided")

        if content is None:
            raise ValueError("Content is None, cannot upload file")

        if not element.mime:
            element.mime = "application/octet-stream"

        context_user = context.session.user
        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key,
            data=content,
            mime=element.mime,
            overwrite=True,
        )
        if not uploaded_file:
            raise ValueError(
                "DynamoDB Error: create_element, Failed to persist data in storage_provider",
            )

        element_dict: Dict[str, Any] = element.to_dict()  # type: ignore
        element_dict.update(
            {
                "PK": f"THREAD#{element.thread_id}",
                "SK": f"ELEMENT#{element.id}",
                "url": uploaded_file.get("url"),
                "objectKey": uploaded_file.get("object_key"),
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(element_dict),
        )

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        """Return the stored element item, or ``None`` if it does not exist."""
        _logger.info(
            "DynamoDB: get_element thread=%s element=%s", thread_id, element_id
        )

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

        if "Item" not in response:
            return None

        return self._deserialize_item(response["Item"])  # type: ignore

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        """Delete an element item.

        Args:
            element_id: Id of the element to delete.
            thread_id: Owning thread; falls back to the current session's
                thread when not given (it was previously always ignored).
        """
        thread_id = thread_id or context.session.thread_id
        _logger.info(
            "DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
        )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        """Write the full step dict as a step item (overwriting any existing one)."""
        _logger.info(
            "DynamoDB: create_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: create_step: %s", step_dict)

        item = dict(step_dict)
        item.update(
            {
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        """SET the step's non-``None`` fields on its existing item."""
        _logger.info(
            "DynamoDB: update_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: update_step: %s", step_dict)

        self._update_item(
            key={
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            },
            updates=step_dict,  # type: ignore
        )

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        """Delete a step item belonging to the current session's thread."""
        thread_id = context.session.thread_id
        _logger.info("DynamoDB: delete_step thread=%s step=%s", thread_id, step_id)

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
        )

    async def get_thread_author(self, thread_id: str) -> str:
        """Return the ``userId`` stored on the thread item.

        Raises:
            ValueError: if the thread item is missing or has no ``userId``.
        """
        _logger.info("DynamoDB: get_thread_author thread=%s", thread_id)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
            ProjectionExpression="userId",
        )

        if "Item" not in response:
            raise ValueError(f"Author not found for thread_id {thread_id}")

        item = self._deserialize_item(response["Item"])
        author = item.get("userId")
        if not author:
            # Item exists but the projection came back empty.
            raise ValueError(f"Author not found for thread_id {thread_id}")
        return author

    async def delete_thread(self, thread_id: str):
        """Delete a thread item together with all of its step and element items."""
        _logger.info("DynamoDB: delete_thread thread=%s", thread_id)

        thread = await self.get_thread(thread_id)
        if not thread:
            return

        # Copy so the thread dict returned by get_thread is not mutated.
        items: List[Any] = list(thread["steps"])
        if thread["elements"]:
            items.extend(thread["elements"])

        delete_requests = []
        for item in items:
            key = self._serialize_item({"PK": item["PK"], "SK": item["SK"]})
            req = {"DeleteRequest": {"Key": key}}
            delete_requests.append(req)

        BATCH_ITEM_SIZE = 25  # pylint: disable=invalid-name
        for i in range(0, len(delete_requests), BATCH_ITEM_SIZE):
            chunk = delete_requests[i : i + BATCH_ITEM_SIZE]  # noqa: E203
            response = self.client.batch_write_item(
                RequestItems={
                    self.table_name: chunk,  # type: ignore
                }
            )

            # Retry whatever DynamoDB throttled, with capped exponential backoff.
            backoff_time = 1
            while "UnprocessedItems" in response and response["UnprocessedItems"]:
                backoff_time *= 2
                # Cap the backoff time at 32 seconds & add jitter
                delay = min(backoff_time, 32) + random.uniform(0, 1)
                await asyncio.sleep(delay)

                response = self.client.batch_write_item(
                    RequestItems=response["UnprocessedItems"]
                )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
        )

    async def list_threads(
        self, pagination: "Pagination", filters: "ThreadFilter"
    ) -> "PaginatedResponse[ThreadDict]":
        """List a user's threads, newest ``UserThreadSK`` first.

        Pages hold at most ``user_thread_limit`` items evaluated by DynamoDB.
        ``filters.search`` is applied as a ``FilterExpression`` after that
        limit, so pages may be short. Feedback filters are not supported.
        The cursor is the JSON-encoded ``LastEvaluatedKey``.
        """
        _logger.info("DynamoDB: list_threads filters.userId=%s", filters.userId)

        if filters.feedback:
            _logger.warning("DynamoDB: filters on feedback not supported")

        paginated_response: PaginatedResponse[ThreadDict] = PaginatedResponse(
            data=[],
            pageInfo=PageInfo(
                hasNextPage=False, startCursor=pagination.cursor, endCursor=None
            ),
        )

        query_args: Dict[str, Any] = {
            "TableName": self.table_name,
            "IndexName": "UserThread",
            "ScanIndexForward": False,
            "Limit": self.user_thread_limit,
            "KeyConditionExpression": "#UserThreadPK = :pk",
            "ExpressionAttributeNames": {
                "#UserThreadPK": "UserThreadPK",
            },
            "ExpressionAttributeValues": {
                ":pk": {"S": f"USER#{filters.userId}"},
            },
        }

        if pagination.cursor:
            query_args["ExclusiveStartKey"] = json.loads(pagination.cursor)

        if filters.search:
            query_args["FilterExpression"] = "contains(#name, :search)"
            query_args["ExpressionAttributeNames"]["#name"] = "name"
            query_args["ExpressionAttributeValues"][":search"] = {"S": filters.search}

        response = self.client.query(**query_args)  # type: ignore

        if "LastEvaluatedKey" in response:
            paginated_response.pageInfo.hasNextPage = True
            paginated_response.pageInfo.endCursor = json.dumps(
                response["LastEvaluatedKey"]
            )

        for item in response["Items"]:
            deserialized_item: Dict[str, Any] = self._deserialize_item(item)
            thread = ThreadDict(  # type: ignore
                id=self._strip_prefix(deserialized_item["PK"], "THREAD#"),
                createdAt=self._strip_prefix(deserialized_item["UserThreadSK"], "TS#"),
                # update_thread skips None names, so the attribute may be absent.
                name=deserialized_item.get("name"),
            )
            paginated_response.data.append(thread)

        return paginated_response

    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        """Load a thread item plus all of its steps and elements.

        Returns ``None`` when the partition is empty or has no ``THREAD``
        item. Steps are ordered by ``createdAt`` (missing values first).
        Returned items keep their raw ``PK``/``SK`` attributes.
        """
        _logger.info("DynamoDB: get_thread thread=%s", thread_id)

        # Get all thread records, following query pagination.
        thread_items: List[Any] = []

        cursor: Dict[str, Any] = {}
        while True:
            response = self.client.query(
                TableName=self.table_name,
                KeyConditionExpression="#pk = :pk",
                ExpressionAttributeNames={"#pk": "PK"},
                ExpressionAttributeValues={":pk": {"S": f"THREAD#{thread_id}"}},
                **cursor,
            )

            thread_items.extend(map(self._deserialize_item, response["Items"]))

            if "LastEvaluatedKey" not in response:
                break
            cursor["ExclusiveStartKey"] = response["LastEvaluatedKey"]

        if not thread_items:
            return None

        # Split the partition by sort-key kind.
        thread_dict: Optional[ThreadDict] = None
        steps = []
        elements = []

        for item in thread_items:
            if item["SK"] == "THREAD":
                thread_dict = item

            elif item["SK"].startswith("ELEMENT"):
                elements.append(item)

            elif item["SK"].startswith("STEP"):
                if "feedback" in item:
                    # Keep feedback value an int (JSON-serializable, 0/1).
                    item["feedback"]["value"] = int(item["feedback"]["value"])
                steps.append(item)

        if not thread_dict:
            # thread_items is non-empty here, so these are leftovers without a
            # THREAD item.
            _logger.warning("DynamoDB: found orphaned items for thread=%s", thread_id)
            return None

        # createdAt is optional on steps; avoid None/str comparison errors.
        steps.sort(key=lambda i: i.get("createdAt") or "")
        thread_dict.update(
            {
                "steps": steps,
                "elements": elements,
            }
        )

        return thread_dict

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        """Upsert the thread item, setting only the arguments that are not None.

        ``UserThreadSK`` is refreshed to the current timestamp on every call,
        so :meth:`list_threads` orders threads by last update.
        """
        _logger.info("DynamoDB: update_thread thread=%s userId=%s", thread_id, user_id)
        _logger.debug(
            "DynamoDB: update_thread name=%s tags=%s metadata=%s", name, tags, metadata
        )

        ts = self._get_current_timestamp()

        item = {
            # GSI: UserThread
            "UserThreadSK": f"TS#{ts}",
            #
            "id": thread_id,
            # NOTE(review): createdAt is overwritten on every update, so it
            # reflects the last update rather than creation time.
            "createdAt": ts,
            "name": name,
            "userId": user_id,
            "userIdentifier": user_id,
            "tags": tags,
            "metadata": metadata,
        }

        if user_id:
            # user_id may be None on subsequent calls, don't update UserThreadPK to "USER#{None}"
            item["UserThreadPK"] = f"USER#{user_id}"

        self._update_item(
            key={
                "PK": f"THREAD#{thread_id}",
                "SK": "THREAD",
            },
            updates=item,
        )

    async def delete_user_session(self, id: str) -> bool:
        """No session storage here; always report success."""
        return True  # Not sure why documentation wants this

    async def build_debug_url(self) -> str:
        """No external debug UI for DynamoDB; return an empty template."""
        return ""
|