chainlit 1.0.503__py3-none-any.whl → 1.0.505__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of chainlit might be problematic.
- chainlit/config.py +7 -3
- chainlit/copilot/dist/index.js +262 -262
- chainlit/data/__init__.py +37 -9
- chainlit/data/sql_alchemy.py +246 -144
- chainlit/emitter.py +19 -9
- chainlit/frontend/dist/assets/{index-a8e1b559.js → index-d200e7ad.js} +119 -119
- chainlit/frontend/dist/assets/{react-plotly-b225b63c.js → react-plotly-10f4012e.js} +1 -1
- chainlit/frontend/dist/index.html +1 -1
- chainlit/llama_index/callbacks.py +2 -3
- chainlit/openai/__init__.py +5 -9
- chainlit/server.py +6 -4
- chainlit/session.py +4 -0
- chainlit/socket.py +5 -1
- chainlit/types.py +19 -1
- chainlit/user_session.py +1 -0
- {chainlit-1.0.503.dist-info → chainlit-1.0.505.dist-info}/METADATA +2 -2
- {chainlit-1.0.503.dist-info → chainlit-1.0.505.dist-info}/RECORD +19 -19
- {chainlit-1.0.503.dist-info → chainlit-1.0.505.dist-info}/WHEEL +0 -0
- {chainlit-1.0.503.dist-info → chainlit-1.0.505.dist-info}/entry_points.txt +0 -0
chainlit/data/sql_alchemy.py
CHANGED
@@ -1,29 +1,45 @@
-import
+import json
 import ssl
+import uuid
+from dataclasses import asdict
 from datetime import datetime, timezone
-import
-
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
 import aiofiles
 import aiohttp
-from dataclasses import asdict
-from sqlalchemy import text
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
-from sqlalchemy.orm import sessionmaker
 from chainlit.context import context
-from chainlit.logger import logger
 from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
-from chainlit.
-from chainlit.
+from chainlit.element import Avatar, ElementDict
+from chainlit.logger import logger
 from chainlit.step import StepDict
-from chainlit.
+from chainlit.types import (
+    Feedback,
+    FeedbackDict,
+    PageInfo,
+    PaginatedResponse,
+    Pagination,
+    ThreadDict,
+    ThreadFilter,
+)
+from chainlit.user import PersistedUser, User
+from sqlalchemy import text
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy.orm import sessionmaker
 
 if TYPE_CHECKING:
     from chainlit.element import Element, ElementDict
     from chainlit.step import StepDict
 
+
 class SQLAlchemyDataLayer(BaseDataLayer):
-    def __init__(
+    def __init__(
+        self,
+        conninfo: str,
+        ssl_require: bool = False,
+        storage_provider: Optional[BaseStorageClient] = None,
+        user_thread_limit: Optional[int] = 1000,
+    ):
         self._conninfo = conninfo
         self.user_thread_limit = user_thread_limit
         ssl_args = {}
@@ -32,17 +48,24 @@ class SQLAlchemyDataLayer(BaseDataLayer):
             ssl_context = ssl.create_default_context()
             ssl_context.check_hostname = False
             ssl_context.verify_mode = ssl.CERT_NONE
-            ssl_args[
-        self.engine: AsyncEngine = create_async_engine(
+            ssl_args["ssl"] = ssl_context
+        self.engine: AsyncEngine = create_async_engine(
+            self._conninfo, connect_args=ssl_args
+        )
         self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession)  # type: ignore
         if storage_provider:
-            self.storage_provider = storage_provider
+            self.storage_provider: Optional[BaseStorageClient] = storage_provider
             logger.info("SQLAlchemyDataLayer storage client initialized")
         else:
-
+            self.storage_provider = None
+            logger.warn(
+                "SQLAlchemyDataLayer storage client is not initialized and elements will not be persisted!"
+            )
 
     ###### SQL Helpers ######
-    async def execute_sql(
+    async def execute_sql(
+        self, query: str, parameters: dict
+    ) -> Union[List[Dict[str, Any]], int, None]:
         parameterized_query = text(query)
         async with self.async_session() as session:
             try:
@@ -66,7 +89,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
 
     async def get_current_timestamp(self) -> str:
         return datetime.now().isoformat() + "Z"
-
+
     def clean_result(self, obj):
         """Recursively change UUID -> str and serialize dictionaries"""
         if isinstance(obj, dict):
@@ -76,7 +99,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
         elif isinstance(obj, uuid.UUID):
             return str(obj)
         return obj
-
+
     ###### User ######
     async def get_user(self, identifier: str) -> Optional[PersistedUser]:
         logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
@@ -87,26 +110,28 @@ class SQLAlchemyDataLayer(BaseDataLayer):
             user_data = result[0]
             return PersistedUser(**user_data)
         return None
-
+
     async def create_user(self, user: User) -> Optional[PersistedUser]:
         logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
-        existing_user: Optional[
-        user_dict: Dict[str, Any]
+        existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
+        user_dict: Dict[str, Any] = {
             "identifier": str(user.identifier),
-            "metadata": json.dumps(user.metadata) or {}
-
-        if not existing_user:
+            "metadata": json.dumps(user.metadata) or {},
+        }
+        if not existing_user:  # create the user
             logger.info("SQLAlchemy: create_user, creating the user")
-            user_dict[
-            user_dict[
+            user_dict["id"] = str(uuid.uuid4())
+            user_dict["createdAt"] = await self.get_current_timestamp()
             query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
             await self.execute_sql(query=query, parameters=user_dict)
-        else:
+        else:  # update the user
             logger.info("SQLAlchemy: update user metadata")
             query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
-            await self.execute_sql(
+            await self.execute_sql(
+                query=query, parameters=user_dict
+            )  # We want to update the metadata
         return await self.get_user(user.identifier)
-
+
     ###### Threads ######
     async def get_thread_author(self, thread_id: str) -> str:
         logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
@@ -114,21 +139,30 @@ class SQLAlchemyDataLayer(BaseDataLayer):
         parameters = {"id": thread_id}
         result = await self.execute_sql(query=query, parameters=parameters)
         if isinstance(result, list) and result[0]:
-            author_identifier = result[0].get(
+            author_identifier = result[0].get("userIdentifier")
             if author_identifier is not None:
-                print(f
+                print(f"Author found: {author_identifier}")
                 return author_identifier
-        raise ValueError
-
+        raise ValueError(f"Author not found for thread_id {thread_id}")
+
     async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
         logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
-        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
+        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
+            thread_id=thread_id
+        )
         if user_threads:
             return user_threads[0]
         else:
             return None
 
-    async def update_thread(
+    async def update_thread(
+        self,
+        thread_id: str,
+        name: Optional[str] = None,
+        user_id: Optional[str] = None,
+        metadata: Optional[Dict] = None,
+        tags: Optional[List[str]] = None,
+    ):
         logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
         if context.session.user is not None:
             user_identifier = context.session.user.identifier
@@ -136,17 +170,25 @@ class SQLAlchemyDataLayer(BaseDataLayer):
             raise ValueError("User not found in session context")
         data = {
             "id": thread_id,
-            "createdAt": await self.get_current_timestamp()
-
+            "createdAt": await self.get_current_timestamp()
+            if metadata is None
+            else None,
+            "name": name
+            if name is not None
+            else (metadata.get("name") if metadata and "name" in metadata else None),
             "userId": user_id,
             "userIdentifier": user_identifier,
             "tags": tags,
             "metadata": json.dumps(metadata) if metadata else None,
         }
-        parameters = {
-
-
-
+        parameters = {
+            key: value for key, value in data.items() if value is not None
+        }  # Remove keys with None values
+        columns = ", ".join(f'"{key}"' for key in parameters.keys())
+        values = ", ".join(f":{key}" for key in parameters.keys())
+        updates = ", ".join(
+            f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != "id"
+        )
         query = f"""
             INSERT INTO threads ({columns})
             VALUES ({values})
@@ -167,12 +209,18 @@ class SQLAlchemyDataLayer(BaseDataLayer):
         await self.execute_sql(query=elements_query, parameters=parameters)
         await self.execute_sql(query=steps_query, parameters=parameters)
         await self.execute_sql(query=thread_query, parameters=parameters)
-
-    async def list_threads(
-
+
+    async def list_threads(
+        self, pagination: Pagination, filters: ThreadFilter
+    ) -> PaginatedResponse:
+        logger.info(
+            f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
+        )
         if not filters.userId:
             raise ValueError("userId is required")
-        all_user_threads: List[ThreadDict] =
+        all_user_threads: List[ThreadDict] = (
+            await self.get_all_user_threads(user_id=filters.userId) or []
+        )
 
         search_keyword = filters.search.lower() if filters.search else None
         feedback_value = int(filters.feedback) if filters.feedback else None
@@ -183,47 +231,68 @@ class SQLAlchemyDataLayer(BaseDataLayer):
             feedback_match = True
             if search_keyword or feedback_value is not None:
                 if search_keyword:
-                    keyword_match = any(
+                    keyword_match = any(
+                        search_keyword in step["output"].lower()
+                        for step in thread["steps"]
+                        if "output" in step
+                    )
                 if feedback_value is not None:
                     feedback_match = False  # Assume no match until found
-                    for step in thread[
-                        feedback = step.get(
-                        if feedback and feedback.get(
+                    for step in thread["steps"]:
+                        feedback = step.get("feedback")
+                        if feedback and feedback.get("value") == feedback_value:
                             feedback_match = True
                             break
             if keyword_match and feedback_match:
                 filtered_threads.append(thread)
-
+
         start = 0
         if pagination.cursor:
             for i, thread in enumerate(filtered_threads):
-                if
+                if (
+                    thread["id"] == pagination.cursor
+                ):  # Find the start index using pagination.cursor
                    start = i + 1
                    break
        end = start + pagination.first
        paginated_threads = filtered_threads[start:end] or []

        has_next_page = len(filtered_threads) > end
-        start_cursor = paginated_threads[0][
-        end_cursor = paginated_threads[-1][
+        start_cursor = paginated_threads[0]["id"] if paginated_threads else None
+        end_cursor = paginated_threads[-1]["id"] if paginated_threads else None

        return PaginatedResponse(
-            pageInfo=PageInfo(
-
+            pageInfo=PageInfo(
+                hasNextPage=has_next_page,
+                startCursor=start_cursor,
+                endCursor=end_cursor,
+            ),
+            data=paginated_threads,
        )
-
+
    ###### Steps ######
    @queue_until_user_message()
-    async def create_step(self, step_dict:
+    async def create_step(self, step_dict: "StepDict"):
        logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
-        if not getattr(context.session.user,
+        if not getattr(context.session.user, "id", None):
            raise ValueError("No authenticated user in context")
-        step_dict[
-
-
-
-
-
+        step_dict["showInput"] = (
+            str(step_dict.get("showInput", "")).lower()
+            if "showInput" in step_dict
+            else None
+        )
+        parameters = {
+            key: value
+            for key, value in step_dict.items()
+            if value is not None and not (isinstance(value, dict) and not value)
+        }
+        parameters["metadata"] = json.dumps(step_dict.get("metadata", {}))
+        parameters["generation"] = json.dumps(step_dict.get("generation", {}))
+        columns = ", ".join(f'"{key}"' for key in parameters.keys())
+        values = ", ".join(f":{key}" for key in parameters.keys())
+        updates = ", ".join(
+            f'"{key}" = :{key}' for key in parameters.keys() if key != "id"
+        )
        query = f"""
            INSERT INTO steps ({columns})
            VALUES ({values})
@@ -231,9 +300,9 @@ class SQLAlchemyDataLayer(BaseDataLayer):
            SET {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)
-
+
    @queue_until_user_message()
-    async def update_step(self, step_dict:
+    async def update_step(self, step_dict: "StepDict"):
        logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
        await self.create_step(step_dict)

@@ -248,17 +317,21 @@ class SQLAlchemyDataLayer(BaseDataLayer):
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)
-
+
    ###### Feedback ######
    async def upsert_feedback(self, feedback: Feedback) -> str:
        logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
        feedback.id = feedback.id or str(uuid.uuid4())
        feedback_dict = asdict(feedback)
-        parameters = {
+        parameters = {
+            key: value for key, value in feedback_dict.items() if value is not None
+        }

-        columns =
-        values =
-        updates =
+        columns = ", ".join(f'"{key}"' for key in parameters.keys())
+        values = ", ".join(f":{key}" for key in parameters.keys())
+        updates = ", ".join(
+            f'"{key}" = :{key}' for key in parameters.keys() if key != "id"
+        )
        query = f"""
            INSERT INTO feedbacks ({columns})
            VALUES ({values})
@@ -267,24 +340,26 @@ class SQLAlchemyDataLayer(BaseDataLayer):
        """
        await self.execute_sql(query=query, parameters=parameters)
        return feedback.id
-
+
    async def delete_feedback(self, feedback_id: str) -> bool:
        logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
        query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
        parameters = {"feedback_id": feedback_id}
        await self.execute_sql(query=query, parameters=parameters)
        return True
-
+
    ###### Elements ######
    @queue_until_user_message()
-    async def create_element(self, element:
+    async def create_element(self, element: "Element"):
        logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
-        if not getattr(context.session.user,
+        if not getattr(context.session.user, "id", None):
            raise ValueError("No authenticated user in context")
-        if isinstance(element, Avatar):
+        if isinstance(element, Avatar):  # Skip creating elements of type avatar
            return
        if not self.storage_provider:
-            logger.warn(
+            logger.warn(
+                f"SQLAlchemy: create_element error. No blob_storage_client is configured!"
+            )
            return
        if not element.for_id:
            return
@@ -310,24 +385,30 @@ class SQLAlchemyDataLayer(BaseDataLayer):

        context_user = context.session.user

-        user_folder = getattr(context_user,
-        file_object_key = f"{user_folder}/{element.id}" + (
+        user_folder = getattr(context_user, "id", "unknown")
+        file_object_key = f"{user_folder}/{element.id}" + (
+            f"/{element.name}" if element.name else ""
+        )

        if not element.mime:
            element.mime = "application/octet-stream"

-        uploaded_file = await self.storage_provider.upload_file(
+        uploaded_file = await self.storage_provider.upload_file(
+            object_key=file_object_key, data=content, mime=element.mime, overwrite=True
+        )
        if not uploaded_file:
-            raise ValueError(
+            raise ValueError(
+                "SQLAlchemy Error: create_element, Failed to persist data in storage_provider"
+            )

        element_dict: ElementDict = element.to_dict()

-        element_dict[
-        element_dict[
+        element_dict["url"] = uploaded_file.get("url")
+        element_dict["objectKey"] = uploaded_file.get("object_key")
        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}

-        columns =
-        placeholders =
+        columns = ", ".join(f'"{column}"' for column in element_dict_cleaned.keys())
+        placeholders = ", ".join(f":{column}" for column in element_dict_cleaned.keys())
        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
        await self.execute_sql(query=query, parameters=element_dict_cleaned)

@@ -339,9 +420,11 @@ class SQLAlchemyDataLayer(BaseDataLayer):
        await self.execute_sql(query=query, parameters=parameters)

    async def delete_user_session(self, id: str) -> bool:
-        return False
+        return False  # Not sure why documentation wants this

-    async def get_all_user_threads(
+    async def get_all_user_threads(
+        self, user_id: Optional[str] = None, thread_id: Optional[str] = None
+    ) -> Optional[List[ThreadDict]]:
        """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
        logger.info(f"SQLAlchemy: get_all_user_threads")
        user_threads_query = """
@@ -358,14 +441,25 @@ class SQLAlchemyDataLayer(BaseDataLayer):
            ORDER BY "createdAt" DESC
            LIMIT :limit
        """
-        user_threads = await self.execute_sql(
+        user_threads = await self.execute_sql(
+            query=user_threads_query,
+            parameters={
+                "user_id": user_id,
+                "limit": self.user_thread_limit,
+                "thread_id": thread_id,
+            },
+        )
        if not isinstance(user_threads, list):
            return None
        if not user_threads:
            return []
        else:
-            thread_ids =
-
+            thread_ids = (
+                "('"
+                + "','".join(map(str, [thread["thread_id"] for thread in user_threads]))
+                + "')"
+            )
+
        steps_feedbacks_query = f"""
            SELECT
                s."id" AS step_id,
@@ -394,8 +488,10 @@ class SQLAlchemyDataLayer(BaseDataLayer):
            WHERE s."threadId" IN {thread_ids}
            ORDER BY s."createdAt" ASC
        """
-        steps_feedbacks = await self.execute_sql(
-
+        steps_feedbacks = await self.execute_sql(
+            query=steps_feedbacks_query, parameters={}
+        )
+
        elements_query = f"""
            SELECT
                e."id" AS element_id,
@@ -418,77 +514,83 @@ class SQLAlchemyDataLayer(BaseDataLayer):

        thread_dicts = {}
        for thread in user_threads:
-            thread_id = thread[
+            thread_id = thread["thread_id"]
            if thread_id is not None:
                thread_dicts[thread_id] = ThreadDict(
                    id=thread_id,
-                    createdAt=thread[
-                    name=thread[
-                    userId=thread[
-                    userIdentifier=thread[
-                    tags=thread[
-                    metadata=thread[
+                    createdAt=thread["thread_createdat"],
+                    name=thread["thread_name"],
+                    userId=thread["user_id"],
+                    userIdentifier=thread["user_identifier"],
+                    tags=thread["thread_tags"],
+                    metadata=thread["thread_metadata"],
                    steps=[],
-                    elements=[]
+                    elements=[],
                )
        # Process steps_feedbacks to populate the steps in the corresponding ThreadDict
        if isinstance(steps_feedbacks, list):
            for step_feedback in steps_feedbacks:
-                thread_id = step_feedback[
+                thread_id = step_feedback["step_threadid"]
                if thread_id is not None:
                    feedback = None
-                    if step_feedback[
+                    if step_feedback["feedback_value"] is not None:
                        feedback = FeedbackDict(
-                            forId=step_feedback[
-                            id=step_feedback.get(
-                            value=step_feedback[
-                            comment=step_feedback.get(
+                            forId=step_feedback["step_id"],
+                            id=step_feedback.get("feedback_id"),
+                            value=step_feedback["feedback_value"],
+                            comment=step_feedback.get("feedback_comment"),
                        )
                    step_dict = StepDict(
-                        id=step_feedback[
-                        name=step_feedback[
-                        type=step_feedback[
+                        id=step_feedback["step_id"],
+                        name=step_feedback["step_name"],
+                        type=step_feedback["step_type"],
                        threadId=thread_id,
-                        parentId=step_feedback.get(
-                        disableFeedback=step_feedback.get(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        parentId=step_feedback.get("step_parentid"),
+                        disableFeedback=step_feedback.get(
+                            "step_disablefeedback", False
+                        ),
+                        streaming=step_feedback.get("step_streaming", False),
+                        waitForAnswer=step_feedback.get("step_waitforanswer"),
+                        isError=step_feedback.get("step_iserror"),
+                        metadata=step_feedback["step_metadata"]
+                        if step_feedback.get("step_metadata") is not None
+                        else {},
+                        tags=step_feedback.get("step_tags"),
+                        input=step_feedback.get("step_input", "")
+                        if step_feedback["step_showinput"]
+                        else "",
+                        output=step_feedback.get("step_output", ""),
+                        createdAt=step_feedback.get("step_createdat"),
+                        start=step_feedback.get("step_start"),
+                        end=step_feedback.get("step_end"),
+                        generation=step_feedback.get("step_generation"),
+                        showInput=step_feedback.get("step_showinput"),
+                        language=step_feedback.get("step_language"),
+                        indent=step_feedback.get("step_indent"),
+                        feedback=feedback,
                    )
                    # Append the step to the steps list of the corresponding ThreadDict
-                    thread_dicts[thread_id][
+                    thread_dicts[thread_id]["steps"].append(step_dict)

        if isinstance(elements, list):
            for element in elements:
-                thread_id = element[
+                thread_id = element["element_threadid"]
                if thread_id is not None:
                    element_dict = ElementDict(
-                        id=element[
+                        id=element["element_id"],
                        threadId=thread_id,
-                        type=element[
-                        chainlitKey=element.get(
-                        url=element.get(
-                        objectKey=element.get(
-                        name=element[
-                        display=element[
-                        size=element.get(
-                        language=element.get(
-                        page=element.get(
-                        forId=element.get(
-                        mime=element.get(
+                        type=element["element_type"],
+                        chainlitKey=element.get("element_chainlitkey"),
+                        url=element.get("element_url"),
+                        objectKey=element.get("element_objectkey"),
+                        name=element["element_name"],
+                        display=element["element_display"],
+                        size=element.get("element_size"),
+                        language=element.get("element_language"),
+                        page=element.get("element_page"),
+                        forId=element.get("element_forid"),
+                        mime=element.get("element_mime"),
                    )
-                    thread_dicts[thread_id][
+                    thread_dicts[thread_id]["elements"].append(element_dict)  # type: ignore

-        return list(thread_dicts.values())
+        return list(thread_dicts.values())