chainlit 1.1.201__py3-none-any.whl → 1.1.300__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. See the registry's release page for more details.
- chainlit/__init__.py +22 -4
- chainlit/cli/__init__.py +53 -6
- chainlit/config.py +25 -18
- chainlit/context.py +9 -0
- chainlit/copilot/dist/index.js +440 -407
- chainlit/data/__init__.py +20 -5
- chainlit/data/dynamodb.py +586 -0
- chainlit/data/sql_alchemy.py +48 -28
- chainlit/discord/app.py +4 -2
- chainlit/element.py +41 -20
- chainlit/emitter.py +8 -7
- chainlit/frontend/dist/assets/DailyMotion-578b63e6.js +1 -0
- chainlit/frontend/dist/assets/Facebook-b825e5bb.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-bcba3b4e.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-fc1c9497.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-4cfb2724.js +1 -0
- chainlit/frontend/dist/assets/Mux-aa92055c.js +1 -0
- chainlit/frontend/dist/assets/Preview-9f55905a.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-f991fe03.js +1 -0
- chainlit/frontend/dist/assets/Streamable-53128f49.js +1 -0
- chainlit/frontend/dist/assets/Twitch-fce8b9f5.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-e35c6102.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-fff35f8e.js +1 -0
- chainlit/frontend/dist/assets/Wistia-ec07dc64.js +1 -0
- chainlit/frontend/dist/assets/YouTube-ad068e2a.js +1 -0
- chainlit/frontend/dist/assets/index-aaf974a9.css +1 -0
- chainlit/frontend/dist/assets/index-d40d41cc.js +727 -0
- chainlit/frontend/dist/assets/{react-plotly-1ca97c0e.js → react-plotly-b2c6442b.js} +1 -1
- chainlit/frontend/dist/index.html +2 -3
- chainlit/langchain/callbacks.py +4 -2
- chainlit/llama_index/callbacks.py +2 -2
- chainlit/message.py +30 -25
- chainlit/oauth_providers.py +118 -0
- chainlit/server.py +208 -83
- chainlit/slack/app.py +2 -3
- chainlit/socket.py +27 -23
- chainlit/step.py +44 -30
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +332 -0
- chainlit/translations/en-US.json +2 -4
- chainlit/types.py +17 -17
- chainlit/user.py +9 -1
- chainlit/utils.py +47 -3
- {chainlit-1.1.201.dist-info → chainlit-1.1.300.dist-info}/METADATA +22 -14
- chainlit-1.1.300.dist-info/RECORD +79 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-bf0451c6.js +0 -698
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -36
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -11
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -386
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit-1.1.201.dist-info/RECORD +0 -72
- {chainlit-1.1.201.dist-info → chainlit-1.1.300.dist-info}/WHEEL +0 -0
- {chainlit-1.1.201.dist-info → chainlit-1.1.300.dist-info}/entry_points.txt +0 -0
chainlit/data/__init__.py
CHANGED
|
@@ -94,7 +94,7 @@ class BaseDataLayer:
|
|
|
94
94
|
pass
|
|
95
95
|
|
|
96
96
|
@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
    """Hook for deleting element ``element_id``, optionally within ``thread_id``.

    The base implementation is a no-op; concrete data layers override it.
    """
|
|
99
99
|
|
|
100
100
|
@queue_until_user_message()
|
|
@@ -139,6 +139,9 @@ class BaseDataLayer:
|
|
|
139
139
|
async def delete_user_session(self, id: str) -> bool:
    """Hook for deleting the user session ``id``.

    The base implementation performs no deletion and always reports success.
    """
    return True
|
|
141
141
|
|
|
142
|
+
async def build_debug_url(self) -> str:
    """Return a debug URL template for this data layer.

    The base implementation provides no URL and returns an empty string.
    """
    no_url = ""
    return no_url
|
|
144
|
+
|
|
142
145
|
|
|
143
146
|
# Module-level data layer instance; starts as None (any assignment happens
# outside the visible part of this file).
_data_layer: Optional[BaseDataLayer] = None
|
|
144
147
|
|
|
@@ -157,6 +160,7 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
157
160
|
"display": metadata.get("display", "side"),
|
|
158
161
|
"language": metadata.get("language"),
|
|
159
162
|
"autoPlay": metadata.get("autoPlay", None),
|
|
163
|
+
"playerConfig": metadata.get("playerConfig", None),
|
|
160
164
|
"page": metadata.get("page"),
|
|
161
165
|
"size": metadata.get("size"),
|
|
162
166
|
"type": metadata.get("type", "file"),
|
|
@@ -224,6 +228,14 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
224
228
|
"waitForAnswer": metadata.get("waitForAnswer", False),
|
|
225
229
|
}
|
|
226
230
|
|
|
231
|
+
async def build_debug_url(self) -> str:
    """Build a URL template pointing at this project's thread view.

    The returned string keeps the literal placeholders ``[thread_id]`` and
    ``[step_id]`` (presumably substituted by the caller — not shown here).

    Returns:
        The URL template, or ``""`` if the project id could not be fetched.
    """
    try:
        project_id = await self.client.api.get_my_project_id()
        # Query parameters are joined with a literal "&"; the text must read
        # "&currentStepId", not the HTML-entity-decoded "¤tStepId".
        return (
            f"{self.client.api.url}/projects/{project_id}"
            "/threads?threadId=[thread_id]&currentStepId=[step_id]"
        )
    except Exception as e:
        # Best effort: a missing debug link must not break the caller.
        logger.error("Error building debug url: %s", e)
        return ""
|
|
238
|
+
|
|
227
239
|
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
|
|
228
240
|
user = await self.client.api.get_user(identifier=identifier)
|
|
229
241
|
if not user:
|
|
@@ -340,7 +352,7 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
340
352
|
return self.attachment_to_element_dict(attachment)
|
|
341
353
|
|
|
342
354
|
@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
    """Delete attachment ``element_id`` through ``self.client.api``.

    ``thread_id`` is accepted for interface compatibility only; the attachment
    is addressed by its id alone.
    """
    api = self.client.api
    await api.delete_attachment(id=element_id)
|
|
345
357
|
|
|
346
358
|
@queue_until_user_message()
|
|
@@ -455,12 +467,15 @@ class ChainlitDataLayer(BaseDataLayer):
|
|
|
455
467
|
steps = [] # List[StepDict]
|
|
456
468
|
if thread.steps:
|
|
457
469
|
for step in thread.steps:
|
|
458
|
-
if
|
|
470
|
+
if step.type == "system_message":
|
|
471
|
+
continue
|
|
472
|
+
if config.ui.hide_cot and step.type not in [
|
|
473
|
+
"user_message",
|
|
474
|
+
"assistant_message",
|
|
475
|
+
]:
|
|
459
476
|
continue
|
|
460
477
|
for attachment in step.attachments:
|
|
461
478
|
elements.append(self.attachment_to_element_dict(attachment))
|
|
462
|
-
if not config.features.prompt_playground and step.generation:
|
|
463
|
-
step.generation = None
|
|
464
479
|
steps.append(self.step_to_step_dict(step))
|
|
465
480
|
|
|
466
481
|
return {
|
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import random
|
|
6
|
+
from dataclasses import asdict
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
9
|
+
|
|
10
|
+
import aiofiles
|
|
11
|
+
import aiohttp
|
|
12
|
+
import boto3 # type: ignore
|
|
13
|
+
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
|
14
|
+
from chainlit.context import context
|
|
15
|
+
from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
|
|
16
|
+
from chainlit.element import ElementDict
|
|
17
|
+
from chainlit.logger import logger
|
|
18
|
+
from chainlit.step import StepDict
|
|
19
|
+
from chainlit.types import (
|
|
20
|
+
Feedback,
|
|
21
|
+
PageInfo,
|
|
22
|
+
PaginatedResponse,
|
|
23
|
+
Pagination,
|
|
24
|
+
ThreadDict,
|
|
25
|
+
ThreadFilter,
|
|
26
|
+
)
|
|
27
|
+
from chainlit.user import PersistedUser, User
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from chainlit.element import Element
|
|
31
|
+
from mypy_boto3_dynamodb import DynamoDBClient
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
_logger = logger.getChild("DynamoDB")
# Pinned to WARNING: the per-call info/debug traces in this module stay silent
# unless this level is lowered explicitly.
_logger.setLevel(logging.WARNING)


class DynamoDBDataLayer(BaseDataLayer):
    """Chainlit data layer backed by a single DynamoDB table.

    Items are addressed by a ``PK``/``SK`` pair:

    * user:    ``PK=USER#{identifier}``, ``SK=USER``
    * thread:  ``PK=THREAD#{thread_id}``, ``SK=THREAD``
    * step:    ``PK=THREAD#{thread_id}``, ``SK=STEP#{step_id}``
    * element: ``PK=THREAD#{thread_id}``, ``SK=ELEMENT#{element_id}``

    ``list_threads`` queries an index named ``UserThread`` whose partition key
    is ``UserThreadPK`` (``USER#{user_id}``); thread items also carry
    ``UserThreadSK`` (``TS#{timestamp}``), presumably the index sort key —
    confirm against the table definition.

    Element file contents go to ``storage_provider``; DynamoDB only stores
    their metadata. The boto3 client is synchronous, so each request below
    blocks the event loop while it runs.
    """

    def __init__(
        self,
        table_name: str,
        client: Optional["DynamoDBClient"] = None,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: int = 10,
    ):
        """Create the data layer.

        Args:
            table_name: Name of the DynamoDB table holding every record.
            client: Pre-configured boto3 DynamoDB client. When omitted, one is
                created for the region in ``AWS_REGION`` (default ``us-east-1``).
            storage_provider: Backend used to persist element contents. Without
                it, ``create_element`` logs a warning and stores nothing.
            user_thread_limit: Page size (``Limit``) used by ``list_threads``.
        """
        if client:
            self.client = client
        else:
            region_name = os.environ.get("AWS_REGION", "us-east-1")
            self.client = boto3.client("dynamodb", region_name=region_name)  # type: ignore

        self.table_name = table_name
        self.storage_provider = storage_provider
        self.user_thread_limit = user_thread_limit

        self._type_deserializer = TypeDeserializer()
        self._type_serializer = TypeSerializer()

    @staticmethod
    def _strip_prefix(value: str, prefix: str) -> str:
        """Return *value* without a leading *prefix* (unchanged if absent).

        ``str.strip(prefix)`` removes any of the prefix's *characters* from both
        ends, e.g. ``"THREAD#DATA".strip("THREAD#")`` yields ``""``. This helper
        removes the exact prefix only, like ``str.removeprefix`` (Python 3.9+).
        """
        if value.startswith(prefix):
            return value[len(prefix) :]  # noqa: E203
        return value

    def _get_current_timestamp(self) -> str:
        """Return the current UTC time as an ISO-8601 string ending in ``Z``.

        The clock is read in UTC explicitly. Appending ``Z`` to a local-time
        ``datetime.now()`` would mislabel timestamps on hosts not running in UTC.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a plain dict into DynamoDB attribute-value format.

        NOTE(review): boto3's TypeSerializer rejects Python floats (only
        Decimal is accepted for numbers), so a float anywhere in *item* raises.
        """
        return {
            key: self._type_serializer.serialize(value) for key, value in item.items()
        }

    def _deserialize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a DynamoDB attribute-value dict back into plain Python values."""
        return {
            key: self._type_deserializer.deserialize(value)
            for key, value in item.items()
        }

    def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
        """Apply ``SET`` updates to the item identified by *key*.

        Attributes whose value is ``None`` are skipped, so callers can send
        partial updates without erasing stored values. Other falsy values
        (``False``, ``0``, ``""``, empty containers) ARE written, so a step's
        ``streaming: False`` is persisted. If nothing remains to set, no request
        is sent, because an empty ``SET`` expression is invalid. DynamoDB
        ``UpdateItem`` creates the item if it does not exist yet.
        """
        update_expr: List[str] = []
        expression_attribute_names: Dict[str, str] = {}
        expression_attribute_values: Dict[str, Any] = {}

        for index, (attr, value) in enumerate(updates.items()):
            if value is None:
                continue

            # Positional placeholders avoid clashes with DynamoDB reserved words.
            k, v = f"#{index}", f":{index}"
            update_expr.append(f"{k} = {v}")
            expression_attribute_names[k] = attr
            expression_attribute_values[v] = value

        if not update_expr:
            return

        self.client.update_item(
            TableName=self.table_name,
            Key=self._serialize_item(key),
            UpdateExpression="SET " + ", ".join(update_expr),
            ExpressionAttributeNames=expression_attribute_names,
            ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
        )

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        """Return the user stored under ``USER#{identifier}``, or None if absent."""
        _logger.info("DynamoDB: get_user identifier=%s", identifier)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"USER#{identifier}"},
                "SK": {"S": "USER"},
            },
        )

        if "Item" not in response:
            return None

        user = self._deserialize_item(response["Item"])

        return PersistedUser(
            id=user["id"],
            identifier=user["identifier"],
            createdAt=user["createdAt"],
            metadata=user["metadata"],
        )

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        """Store *user* and return it as a PersistedUser.

        The identifier doubles as the user id. ``put_item`` replaces any
        existing record with the same key.
        """
        _logger.info("DynamoDB: create_user user.identifier=%s", user.identifier)

        ts = self._get_current_timestamp()
        metadata: Dict[Any, Any] = user.metadata  # type: ignore

        item = {
            "PK": f"USER#{user.identifier}",
            "SK": "USER",
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": metadata,
            "createdAt": ts,
        }

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Remove the ``feedback`` attribute from the step it belongs to.

        Args:
            feedback_id: Id produced by ``upsert_feedback``, of the form
                ``THREAD#{thread_id}::STEP#{step_id}``.

        Returns:
            Always True.
        """
        _logger.info("DynamoDB: delete_feedback feedback_id=%s", feedback_id)

        thread_part, step_part = feedback_id.split("::", 1)
        thread_id = self._strip_prefix(thread_part, "THREAD#")
        step_id = self._strip_prefix(step_part, "STEP#")

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
            UpdateExpression="REMOVE #feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
        )

        return True

    async def upsert_feedback(self, feedback: Feedback) -> str:
        """Store *feedback* as the ``feedback`` attribute of its step item.

        The feedback id is derived as ``THREAD#{threadId}::STEP#{forId}``, so a
        step holds at most one feedback and re-submitting overwrites it.

        Returns:
            The derived feedback id (also assigned to ``feedback.id``).

        Raises:
            ValueError: if ``feedback.threadId`` or ``feedback.forId`` is
                missing; both are needed to address the step item, and
                ``update_item`` would otherwise create a stray item.
        """
        _logger.info(
            "DynamoDB: upsert_feedback thread=%s step=%s value=%s",
            feedback.threadId,
            feedback.forId,
            feedback.value,
        )

        if not feedback.threadId:
            raise ValueError(
                "DynamoDB datalayer expects value for feedback.threadId got None"
            )
        if not feedback.forId:
            raise ValueError(
                "DynamoDB datalayer expects value for feedback.forId got None"
            )

        feedback.id = f"THREAD#{feedback.threadId}::STEP#{feedback.forId}"
        serialized_feedback = self._type_serializer.serialize(asdict(feedback))

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{feedback.threadId}"},
                "SK": {"S": f"STEP#{feedback.forId}"},
            },
            UpdateExpression="SET #feedback = :feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
            ExpressionAttributeValues={":feedback": serialized_feedback},
        )

        return feedback.id

    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        """Upload an element's content to ``storage_provider`` and record it.

        Content comes from ``element.content``, else the file at
        ``element.path``, else an HTTP GET of ``element.url``. It is uploaded
        under ``{user id}/{thread id}/{element id}``, and the element metadata,
        including the returned ``url`` and ``objectKey``, is stored as
        ``ELEMENT#{element.id}`` in the thread's partition. Elements without
        ``for_id`` are ignored. Without a storage provider a warning is logged
        and nothing is stored.

        Raises:
            ValueError: if no content source is set, the URL download does not
                return HTTP 200, or the storage provider returns nothing.
        """
        _logger.info(
            "DynamoDB: create_element thread=%s step=%s type=%s",
            element.thread_id,
            element.for_id,
            element.type,
        )
        _logger.debug("DynamoDB: create_element: %s", element.to_dict())

        if not element.for_id:
            return

        if not self.storage_provider:
            _logger.warning(
                "DynamoDB: create_element error. No storage_provider is configured!"
            )
            return

        content: Optional[Union[bytes, str]] = None

        if element.content:
            content = element.content

        elif element.path:
            _logger.debug("DynamoDB: create_element reading file %s", element.path)
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()

        elif element.url:
            _logger.debug("DynamoDB: create_element http %s", element.url)
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        raise ValueError(
                            f"Failed to read content from {element.url} status {response.status}",
                        )

        else:
            raise ValueError("Element url, path or content must be provided")

        if content is None:
            raise ValueError("Content is None, cannot upload file")

        if not element.mime:
            element.mime = "application/octet-stream"

        context_user = context.session.user
        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key,
            data=content,
            mime=element.mime,
            overwrite=True,
        )
        if not uploaded_file:
            raise ValueError(
                "DynamoDB Error: create_element, Failed to persist data in storage_provider",
            )

        element_dict: Dict[str, Any] = element.to_dict()  # type: ignore
        element_dict.update(
            {
                "PK": f"THREAD#{element.thread_id}",
                "SK": f"ELEMENT#{element.id}",
                "url": uploaded_file.get("url"),
                "objectKey": uploaded_file.get("object_key"),
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(element_dict),
        )

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        """Return the stored element record, or None if it does not exist."""
        _logger.info(
            "DynamoDB: get_element thread=%s element=%s", thread_id, element_id
        )

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

        if "Item" not in response:
            return None

        return self._deserialize_item(response["Item"])  # type: ignore

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        """Delete the element record ``element_id``.

        Args:
            element_id: Id of the element to delete.
            thread_id: Thread owning the element. Falls back to the current
                session's thread when not given.

        The uploaded file in ``storage_provider`` is not removed here.
        """
        thread_id = thread_id or context.session.thread_id
        _logger.info(
            "DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
        )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        """Store a step as ``STEP#{id}`` in its thread's partition (overwrites)."""
        _logger.info(
            "DynamoDB: create_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: create_step: %s", step_dict)

        item = dict(step_dict)
        item.update(
            {
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        """Update the stored step with every non-None field of *step_dict*."""
        _logger.info(
            "DynamoDB: update_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: update_step: %s", step_dict)

        self._update_item(
            key={
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            },
            updates=step_dict,  # type: ignore
        )

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        """Delete step ``step_id`` from the current session's thread."""
        thread_id = context.session.thread_id
        _logger.info("DynamoDB: delete_step thread=%s step=%s", thread_id, step_id)

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
        )

    async def get_thread_author(self, thread_id: str) -> str:
        """Return the ``userId`` stored on the thread item.

        Raises:
            ValueError: if the thread item does not exist.
        """
        _logger.info("DynamoDB: get_thread_author thread=%s", thread_id)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
            ProjectionExpression="userId",
        )

        if "Item" not in response:
            raise ValueError(f"Author not found for thread_id {thread_id}")

        item = self._deserialize_item(response["Item"])
        return item["userId"]

    async def delete_thread(self, thread_id: str):
        """Delete a thread together with its step and element items.

        Steps and elements are removed with ``BatchWriteItem`` in chunks of 25
        (the API's per-request maximum). Unprocessed items are retried with
        exponential backoff capped at 32 s plus jitter. The thread item is
        deleted last. Files held by ``storage_provider`` are not removed.
        """
        _logger.info("DynamoDB: delete_thread thread=%s", thread_id)

        thread = await self.get_thread(thread_id)
        if not thread:
            return

        # Copy so the fetched thread's "steps" list is not mutated.
        items: List[Any] = list(thread["steps"])
        if thread["elements"]:
            items.extend(thread["elements"])

        delete_requests = [
            {
                "DeleteRequest": {
                    "Key": self._serialize_item({"PK": item["PK"], "SK": item["SK"]})
                }
            }
            for item in items
        ]

        BATCH_ITEM_SIZE = 25  # pylint: disable=invalid-name
        for i in range(0, len(delete_requests), BATCH_ITEM_SIZE):
            chunk = delete_requests[i : i + BATCH_ITEM_SIZE]  # noqa: E203
            response = self.client.batch_write_item(
                RequestItems={
                    self.table_name: chunk,  # type: ignore
                }
            )

            backoff_time = 1
            while "UnprocessedItems" in response and response["UnprocessedItems"]:
                backoff_time *= 2
                # Cap the backoff time at 32 seconds & add jitter
                delay = min(backoff_time, 32) + random.uniform(0, 1)
                await asyncio.sleep(delay)

                response = self.client.batch_write_item(
                    RequestItems=response["UnprocessedItems"]
                )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
        )

    async def list_threads(
        self, pagination: "Pagination", filters: "ThreadFilter"
    ) -> "PaginatedResponse[ThreadDict]":
        """List a user's threads from the ``UserThread`` index.

        Results come in descending index order (``ScanIndexForward=False``),
        ``user_thread_limit`` items per query. ``pagination.cursor`` is the
        JSON-encoded ``LastEvaluatedKey`` of the previous page.
        ``filters.search`` is a case-sensitive substring match on the thread
        name. It is applied after ``Limit``, so a page may contain fewer items
        than the limit. Feedback filters are not supported and are ignored
        (with a warning).

        Each returned ThreadDict holds only ``id``, ``createdAt`` and ``name``;
        ``name`` is None for threads that were never named.
        """
        _logger.info("DynamoDB: list_threads filters.userId=%s", filters.userId)

        if filters.feedback:
            _logger.warning("DynamoDB: filters on feedback not supported")

        paginated_response: PaginatedResponse[ThreadDict] = PaginatedResponse(
            data=[],
            pageInfo=PageInfo(
                hasNextPage=False, startCursor=pagination.cursor, endCursor=None
            ),
        )

        query_args: Dict[str, Any] = {
            "TableName": self.table_name,
            "IndexName": "UserThread",
            "ScanIndexForward": False,
            "Limit": self.user_thread_limit,
            "KeyConditionExpression": "#UserThreadPK = :pk",
            "ExpressionAttributeNames": {
                "#UserThreadPK": "UserThreadPK",
            },
            "ExpressionAttributeValues": {
                ":pk": {"S": f"USER#{filters.userId}"},
            },
        }

        if pagination.cursor:
            query_args["ExclusiveStartKey"] = json.loads(pagination.cursor)

        if filters.search:
            query_args["FilterExpression"] = "contains(#name, :search)"
            query_args["ExpressionAttributeNames"]["#name"] = "name"
            query_args["ExpressionAttributeValues"][":search"] = {"S": filters.search}

        response = self.client.query(**query_args)  # type: ignore

        if "LastEvaluatedKey" in response:
            paginated_response.pageInfo.hasNextPage = True
            paginated_response.pageInfo.endCursor = json.dumps(
                response["LastEvaluatedKey"]
            )

        for item in response["Items"]:
            deserialized_item: Dict[str, Any] = self._deserialize_item(item)
            thread = ThreadDict(  # type: ignore
                id=self._strip_prefix(deserialized_item["PK"], "THREAD#"),
                createdAt=self._strip_prefix(deserialized_item["UserThreadSK"], "TS#"),
                name=deserialized_item.get("name"),
            )
            paginated_response.data.append(thread)

        return paginated_response

    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        """Load a thread with its steps and elements.

        All items in the ``THREAD#{thread_id}`` partition are read, following
        ``LastEvaluatedKey`` until exhausted. Steps are sorted by ``createdAt``;
        steps without one sort first.

        Returns:
            The thread item extended with ``steps`` and ``elements``, or None
            if the partition holds no ``THREAD`` item (orphaned steps or
            elements are logged as a warning).
        """
        _logger.info("DynamoDB: get_thread thread=%s", thread_id)

        # Get all thread records
        thread_items: List[Any] = []

        cursor: Dict[str, Any] = {}
        while True:
            response = self.client.query(
                TableName=self.table_name,
                KeyConditionExpression="#pk = :pk",
                ExpressionAttributeNames={"#pk": "PK"},
                ExpressionAttributeValues={":pk": {"S": f"THREAD#{thread_id}"}},
                **cursor,
            )

            thread_items.extend(map(self._deserialize_item, response["Items"]))

            if "LastEvaluatedKey" not in response:
                break
            cursor["ExclusiveStartKey"] = response["LastEvaluatedKey"]

        if not thread_items:
            return None

        thread_dict: Optional[ThreadDict] = None
        steps = []
        elements = []

        for item in thread_items:
            if item["SK"] == "THREAD":
                thread_dict = item

            elif item["SK"].startswith("ELEMENT"):
                elements.append(item)

            elif item["SK"].startswith("STEP"):
                if "feedback" in item:  # Decimal is not json serializable
                    item["feedback"]["value"] = int(item["feedback"]["value"])
                steps.append(item)

        if not thread_dict:
            # thread_items is non-empty here, so these items have no parent.
            _logger.warning("DynamoDB: found orphaned items for thread=%s", thread_id)
            return None

        steps.sort(key=lambda step: step.get("createdAt") or "")
        thread_dict.update(
            {
                "steps": steps,
                "elements": elements,
            }
        )

        return thread_dict

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        """Create or update the thread item; None arguments are left untouched.

        Every call refreshes ``UserThreadSK`` to the current timestamp. The
        ``UserThreadPK`` index key is only set when ``user_id`` is given.

        NOTE(review): ``createdAt`` is also overwritten on every call, so it
        tracks the last update rather than creation time — confirm intent.
        """
        _logger.info("DynamoDB: update_thread thread=%s userId=%s", thread_id, user_id)
        _logger.debug(
            "DynamoDB: update_thread name=%s tags=%s metadata=%s", name, tags, metadata
        )

        ts = self._get_current_timestamp()

        item = {
            # GSI: UserThread
            "UserThreadSK": f"TS#{ts}",
            #
            "id": thread_id,
            "createdAt": ts,
            "name": name,
            "userId": user_id,
            "userIdentifier": user_id,
            "tags": tags,
            "metadata": metadata,
        }

        if user_id:
            # user_id may be None on subsequent calls, don't update UserThreadPK to "USER#{None}"
            item["UserThreadPK"] = f"USER#{user_id}"

        self._update_item(
            key={
                "PK": f"THREAD#{thread_id}",
                "SK": "THREAD",
            },
            updates=item,
        )

    async def delete_user_session(self, id: str) -> bool:
        """No session state is stored in DynamoDB; always reports success."""
        return True  # Not sure why documentation wants this

    async def build_debug_url(self) -> str:
        """No debug URL is available for this data layer."""
        return ""