chainlit 1.0.504__py3-none-any.whl → 1.0.505__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

@@ -1,29 +1,45 @@
1
- import uuid
1
+ import json
2
2
  import ssl
3
+ import uuid
4
+ from dataclasses import asdict
3
5
  from datetime import datetime, timezone
4
- import json
5
- from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING, Any
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
7
+
6
8
  import aiofiles
7
9
  import aiohttp
8
- from dataclasses import asdict
9
- from sqlalchemy import text
10
- from sqlalchemy.exc import SQLAlchemyError
11
- from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
12
- from sqlalchemy.orm import sessionmaker
13
10
  from chainlit.context import context
14
- from chainlit.logger import logger
15
11
  from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
16
- from chainlit.user import User, PersistedUser
17
- from chainlit.types import Feedback, FeedbackDict, Pagination, ThreadDict, ThreadFilter, PageInfo, PaginatedResponse
12
+ from chainlit.element import Avatar, ElementDict
13
+ from chainlit.logger import logger
18
14
  from chainlit.step import StepDict
19
- from chainlit.element import ElementDict, Avatar
15
+ from chainlit.types import (
16
+ Feedback,
17
+ FeedbackDict,
18
+ PageInfo,
19
+ PaginatedResponse,
20
+ Pagination,
21
+ ThreadDict,
22
+ ThreadFilter,
23
+ )
24
+ from chainlit.user import PersistedUser, User
25
+ from sqlalchemy import text
26
+ from sqlalchemy.exc import SQLAlchemyError
27
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
28
+ from sqlalchemy.orm import sessionmaker
20
29
 
21
30
  if TYPE_CHECKING:
22
31
  from chainlit.element import Element, ElementDict
23
32
  from chainlit.step import StepDict
24
33
 
34
+
25
35
  class SQLAlchemyDataLayer(BaseDataLayer):
26
- def __init__(self, conninfo: str, ssl_require: bool = False, storage_provider: Optional[BaseStorageClient] = None, user_thread_limit: Optional[int] = 1000):
36
+ def __init__(
37
+ self,
38
+ conninfo: str,
39
+ ssl_require: bool = False,
40
+ storage_provider: Optional[BaseStorageClient] = None,
41
+ user_thread_limit: Optional[int] = 1000,
42
+ ):
27
43
  self._conninfo = conninfo
28
44
  self.user_thread_limit = user_thread_limit
29
45
  ssl_args = {}
@@ -32,17 +48,24 @@ class SQLAlchemyDataLayer(BaseDataLayer):
32
48
  ssl_context = ssl.create_default_context()
33
49
  ssl_context.check_hostname = False
34
50
  ssl_context.verify_mode = ssl.CERT_NONE
35
- ssl_args['ssl'] = ssl_context
36
- self.engine: AsyncEngine = create_async_engine(self._conninfo, connect_args=ssl_args)
51
+ ssl_args["ssl"] = ssl_context
52
+ self.engine: AsyncEngine = create_async_engine(
53
+ self._conninfo, connect_args=ssl_args
54
+ )
37
55
  self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
38
56
  if storage_provider:
39
- self.storage_provider = storage_provider
57
+ self.storage_provider: Optional[BaseStorageClient] = storage_provider
40
58
  logger.info("SQLAlchemyDataLayer storage client initialized")
41
59
  else:
42
- logger.warn("SQLAlchemyDataLayer storage client is not initialized and elements will not be persisted!")
60
+ self.storage_provider = None
61
+ logger.warn(
62
+ "SQLAlchemyDataLayer storage client is not initialized and elements will not be persisted!"
63
+ )
43
64
 
44
65
  ###### SQL Helpers ######
45
- async def execute_sql(self, query: str, parameters: dict) -> Union[List[Dict[str, Any]], int, None]:
66
+ async def execute_sql(
67
+ self, query: str, parameters: dict
68
+ ) -> Union[List[Dict[str, Any]], int, None]:
46
69
  parameterized_query = text(query)
47
70
  async with self.async_session() as session:
48
71
  try:
@@ -66,7 +89,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
66
89
 
67
90
  async def get_current_timestamp(self) -> str:
68
91
  return datetime.now().isoformat() + "Z"
69
-
92
+
70
93
  def clean_result(self, obj):
71
94
  """Recursively change UUID -> str and serialize dictionaries"""
72
95
  if isinstance(obj, dict):
@@ -76,7 +99,7 @@ class SQLAlchemyDataLayer(BaseDataLayer):
76
99
  elif isinstance(obj, uuid.UUID):
77
100
  return str(obj)
78
101
  return obj
79
-
102
+
80
103
  ###### User ######
81
104
  async def get_user(self, identifier: str) -> Optional[PersistedUser]:
82
105
  logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
@@ -87,26 +110,28 @@ class SQLAlchemyDataLayer(BaseDataLayer):
87
110
  user_data = result[0]
88
111
  return PersistedUser(**user_data)
89
112
  return None
90
-
113
+
91
114
  async def create_user(self, user: User) -> Optional[PersistedUser]:
92
115
  logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
93
- existing_user: Optional['PersistedUser'] = await self.get_user(user.identifier)
94
- user_dict: Dict[str, Any] = {
116
+ existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
117
+ user_dict: Dict[str, Any] = {
95
118
  "identifier": str(user.identifier),
96
- "metadata": json.dumps(user.metadata) or {}
97
- }
98
- if not existing_user: # create the user
119
+ "metadata": json.dumps(user.metadata) or {},
120
+ }
121
+ if not existing_user: # create the user
99
122
  logger.info("SQLAlchemy: create_user, creating the user")
100
- user_dict['id'] = str(uuid.uuid4())
101
- user_dict['createdAt'] = await self.get_current_timestamp()
123
+ user_dict["id"] = str(uuid.uuid4())
124
+ user_dict["createdAt"] = await self.get_current_timestamp()
102
125
  query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
103
126
  await self.execute_sql(query=query, parameters=user_dict)
104
- else: # update the user
127
+ else: # update the user
105
128
  logger.info("SQLAlchemy: update user metadata")
106
129
  query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
107
- await self.execute_sql(query=query, parameters=user_dict) # We want to update the metadata
130
+ await self.execute_sql(
131
+ query=query, parameters=user_dict
132
+ ) # We want to update the metadata
108
133
  return await self.get_user(user.identifier)
109
-
134
+
110
135
  ###### Threads ######
111
136
  async def get_thread_author(self, thread_id: str) -> str:
112
137
  logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
@@ -114,21 +139,30 @@ class SQLAlchemyDataLayer(BaseDataLayer):
114
139
  parameters = {"id": thread_id}
115
140
  result = await self.execute_sql(query=query, parameters=parameters)
116
141
  if isinstance(result, list) and result[0]:
117
- author_identifier = result[0].get('userIdentifier')
142
+ author_identifier = result[0].get("userIdentifier")
118
143
  if author_identifier is not None:
119
- print(f'Author found: {author_identifier}')
144
+ print(f"Author found: {author_identifier}")
120
145
  return author_identifier
121
- raise ValueError (f"Author not found for thread_id {thread_id}")
122
-
146
+ raise ValueError(f"Author not found for thread_id {thread_id}")
147
+
123
148
  async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
124
149
  logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
125
- user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(thread_id=thread_id)
150
+ user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
151
+ thread_id=thread_id
152
+ )
126
153
  if user_threads:
127
154
  return user_threads[0]
128
155
  else:
129
156
  return None
130
157
 
131
- async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
158
+ async def update_thread(
159
+ self,
160
+ thread_id: str,
161
+ name: Optional[str] = None,
162
+ user_id: Optional[str] = None,
163
+ metadata: Optional[Dict] = None,
164
+ tags: Optional[List[str]] = None,
165
+ ):
132
166
  logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
133
167
  if context.session.user is not None:
134
168
  user_identifier = context.session.user.identifier
@@ -136,17 +170,25 @@ class SQLAlchemyDataLayer(BaseDataLayer):
136
170
  raise ValueError("User not found in session context")
137
171
  data = {
138
172
  "id": thread_id,
139
- "createdAt": await self.get_current_timestamp() if metadata is None else None,
140
- "name": name if name is not None else (metadata.get('name') if metadata and 'name' in metadata else None),
173
+ "createdAt": await self.get_current_timestamp()
174
+ if metadata is None
175
+ else None,
176
+ "name": name
177
+ if name is not None
178
+ else (metadata.get("name") if metadata and "name" in metadata else None),
141
179
  "userId": user_id,
142
180
  "userIdentifier": user_identifier,
143
181
  "tags": tags,
144
182
  "metadata": json.dumps(metadata) if metadata else None,
145
183
  }
146
- parameters = {key: value for key, value in data.items() if value is not None} # Remove keys with None values
147
- columns = ', '.join(f'"{key}"' for key in parameters.keys())
148
- values = ', '.join(f':{key}' for key in parameters.keys())
149
- updates = ', '.join(f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != 'id')
184
+ parameters = {
185
+ key: value for key, value in data.items() if value is not None
186
+ } # Remove keys with None values
187
+ columns = ", ".join(f'"{key}"' for key in parameters.keys())
188
+ values = ", ".join(f":{key}" for key in parameters.keys())
189
+ updates = ", ".join(
190
+ f'"{key}" = EXCLUDED."{key}"' for key in parameters.keys() if key != "id"
191
+ )
150
192
  query = f"""
151
193
  INSERT INTO threads ({columns})
152
194
  VALUES ({values})
@@ -167,12 +209,18 @@ class SQLAlchemyDataLayer(BaseDataLayer):
167
209
  await self.execute_sql(query=elements_query, parameters=parameters)
168
210
  await self.execute_sql(query=steps_query, parameters=parameters)
169
211
  await self.execute_sql(query=thread_query, parameters=parameters)
170
-
171
- async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse:
172
- logger.info(f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}")
212
+
213
+ async def list_threads(
214
+ self, pagination: Pagination, filters: ThreadFilter
215
+ ) -> PaginatedResponse:
216
+ logger.info(
217
+ f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
218
+ )
173
219
  if not filters.userId:
174
220
  raise ValueError("userId is required")
175
- all_user_threads: List[ThreadDict] = await self.get_all_user_threads(user_id=filters.userId) or []
221
+ all_user_threads: List[ThreadDict] = (
222
+ await self.get_all_user_threads(user_id=filters.userId) or []
223
+ )
176
224
 
177
225
  search_keyword = filters.search.lower() if filters.search else None
178
226
  feedback_value = int(filters.feedback) if filters.feedback else None
@@ -183,47 +231,68 @@ class SQLAlchemyDataLayer(BaseDataLayer):
183
231
  feedback_match = True
184
232
  if search_keyword or feedback_value is not None:
185
233
  if search_keyword:
186
- keyword_match = any(search_keyword in step['output'].lower() for step in thread['steps'] if 'output' in step)
234
+ keyword_match = any(
235
+ search_keyword in step["output"].lower()
236
+ for step in thread["steps"]
237
+ if "output" in step
238
+ )
187
239
  if feedback_value is not None:
188
240
  feedback_match = False # Assume no match until found
189
- for step in thread['steps']:
190
- feedback = step.get('feedback')
191
- if feedback and feedback.get('value') == feedback_value:
241
+ for step in thread["steps"]:
242
+ feedback = step.get("feedback")
243
+ if feedback and feedback.get("value") == feedback_value:
192
244
  feedback_match = True
193
245
  break
194
246
  if keyword_match and feedback_match:
195
247
  filtered_threads.append(thread)
196
-
248
+
197
249
  start = 0
198
250
  if pagination.cursor:
199
251
  for i, thread in enumerate(filtered_threads):
200
- if thread['id'] == pagination.cursor: # Find the start index using pagination.cursor
252
+ if (
253
+ thread["id"] == pagination.cursor
254
+ ): # Find the start index using pagination.cursor
201
255
  start = i + 1
202
256
  break
203
257
  end = start + pagination.first
204
258
  paginated_threads = filtered_threads[start:end] or []
205
259
 
206
260
  has_next_page = len(filtered_threads) > end
207
- start_cursor = paginated_threads[0]['id'] if paginated_threads else None
208
- end_cursor = paginated_threads[-1]['id'] if paginated_threads else None
261
+ start_cursor = paginated_threads[0]["id"] if paginated_threads else None
262
+ end_cursor = paginated_threads[-1]["id"] if paginated_threads else None
209
263
 
210
264
  return PaginatedResponse(
211
- pageInfo=PageInfo(hasNextPage=has_next_page, startCursor=start_cursor, endCursor=end_cursor),
212
- data=paginated_threads
265
+ pageInfo=PageInfo(
266
+ hasNextPage=has_next_page,
267
+ startCursor=start_cursor,
268
+ endCursor=end_cursor,
269
+ ),
270
+ data=paginated_threads,
213
271
  )
214
-
272
+
215
273
  ###### Steps ######
216
274
  @queue_until_user_message()
217
- async def create_step(self, step_dict: 'StepDict'):
275
+ async def create_step(self, step_dict: "StepDict"):
218
276
  logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
219
- if not getattr(context.session.user, 'id', None):
277
+ if not getattr(context.session.user, "id", None):
220
278
  raise ValueError("No authenticated user in context")
221
- step_dict['showInput'] = str(step_dict.get('showInput', '')).lower() if 'showInput' in step_dict else None
222
- parameters = {key: value for key, value in step_dict.items() if value is not None and not (isinstance(value, dict) and not value)}
223
- parameters['metadata'] = json.dumps(step_dict.get('metadata', {}))
224
- columns = ', '.join(f'"{key}"' for key in parameters.keys())
225
- values = ', '.join(f':{key}' for key in parameters.keys())
226
- updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
279
+ step_dict["showInput"] = (
280
+ str(step_dict.get("showInput", "")).lower()
281
+ if "showInput" in step_dict
282
+ else None
283
+ )
284
+ parameters = {
285
+ key: value
286
+ for key, value in step_dict.items()
287
+ if value is not None and not (isinstance(value, dict) and not value)
288
+ }
289
+ parameters["metadata"] = json.dumps(step_dict.get("metadata", {}))
290
+ parameters["generation"] = json.dumps(step_dict.get("generation", {}))
291
+ columns = ", ".join(f'"{key}"' for key in parameters.keys())
292
+ values = ", ".join(f":{key}" for key in parameters.keys())
293
+ updates = ", ".join(
294
+ f'"{key}" = :{key}' for key in parameters.keys() if key != "id"
295
+ )
227
296
  query = f"""
228
297
  INSERT INTO steps ({columns})
229
298
  VALUES ({values})
@@ -231,9 +300,9 @@ class SQLAlchemyDataLayer(BaseDataLayer):
231
300
  SET {updates};
232
301
  """
233
302
  await self.execute_sql(query=query, parameters=parameters)
234
-
303
+
235
304
  @queue_until_user_message()
236
- async def update_step(self, step_dict: 'StepDict'):
305
+ async def update_step(self, step_dict: "StepDict"):
237
306
  logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
238
307
  await self.create_step(step_dict)
239
308
 
@@ -248,17 +317,21 @@ class SQLAlchemyDataLayer(BaseDataLayer):
248
317
  await self.execute_sql(query=feedbacks_query, parameters=parameters)
249
318
  await self.execute_sql(query=elements_query, parameters=parameters)
250
319
  await self.execute_sql(query=steps_query, parameters=parameters)
251
-
320
+
252
321
  ###### Feedback ######
253
322
  async def upsert_feedback(self, feedback: Feedback) -> str:
254
323
  logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
255
324
  feedback.id = feedback.id or str(uuid.uuid4())
256
325
  feedback_dict = asdict(feedback)
257
- parameters = {key: value for key, value in feedback_dict.items() if value is not None}
326
+ parameters = {
327
+ key: value for key, value in feedback_dict.items() if value is not None
328
+ }
258
329
 
259
- columns = ', '.join(f'"{key}"' for key in parameters.keys())
260
- values = ', '.join(f':{key}' for key in parameters.keys())
261
- updates = ', '.join(f'"{key}" = :{key}' for key in parameters.keys() if key != 'id')
330
+ columns = ", ".join(f'"{key}"' for key in parameters.keys())
331
+ values = ", ".join(f":{key}" for key in parameters.keys())
332
+ updates = ", ".join(
333
+ f'"{key}" = :{key}' for key in parameters.keys() if key != "id"
334
+ )
262
335
  query = f"""
263
336
  INSERT INTO feedbacks ({columns})
264
337
  VALUES ({values})
@@ -267,24 +340,26 @@ class SQLAlchemyDataLayer(BaseDataLayer):
267
340
  """
268
341
  await self.execute_sql(query=query, parameters=parameters)
269
342
  return feedback.id
270
-
343
+
271
344
  async def delete_feedback(self, feedback_id: str) -> bool:
272
345
  logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
273
346
  query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
274
347
  parameters = {"feedback_id": feedback_id}
275
348
  await self.execute_sql(query=query, parameters=parameters)
276
349
  return True
277
-
350
+
278
351
  ###### Elements ######
279
352
  @queue_until_user_message()
280
- async def create_element(self, element: 'Element'):
353
+ async def create_element(self, element: "Element"):
281
354
  logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
282
- if not getattr(context.session.user, 'id', None):
355
+ if not getattr(context.session.user, "id", None):
283
356
  raise ValueError("No authenticated user in context")
284
- if isinstance(element, Avatar): # Skip creating elements of type avatar
357
+ if isinstance(element, Avatar): # Skip creating elements of type avatar
285
358
  return
286
359
  if not self.storage_provider:
287
- logger.warn(f"SQLAlchemy: create_element error. No blob_storage_client is configured!")
360
+ logger.warn(
361
+ f"SQLAlchemy: create_element error. No blob_storage_client is configured!"
362
+ )
288
363
  return
289
364
  if not element.for_id:
290
365
  return
@@ -310,24 +385,30 @@ class SQLAlchemyDataLayer(BaseDataLayer):
310
385
 
311
386
  context_user = context.session.user
312
387
 
313
- user_folder = getattr(context_user, 'id', 'unknown')
314
- file_object_key = f"{user_folder}/{element.id}" + (f"/{element.name}" if element.name else "")
388
+ user_folder = getattr(context_user, "id", "unknown")
389
+ file_object_key = f"{user_folder}/{element.id}" + (
390
+ f"/{element.name}" if element.name else ""
391
+ )
315
392
 
316
393
  if not element.mime:
317
394
  element.mime = "application/octet-stream"
318
395
 
319
- uploaded_file = await self.storage_provider.upload_file(object_key=file_object_key, data=content, mime=element.mime, overwrite=True)
396
+ uploaded_file = await self.storage_provider.upload_file(
397
+ object_key=file_object_key, data=content, mime=element.mime, overwrite=True
398
+ )
320
399
  if not uploaded_file:
321
- raise ValueError("SQLAlchemy Error: create_element, Failed to persist data in storage_provider")
400
+ raise ValueError(
401
+ "SQLAlchemy Error: create_element, Failed to persist data in storage_provider"
402
+ )
322
403
 
323
404
  element_dict: ElementDict = element.to_dict()
324
405
 
325
- element_dict['url'] = uploaded_file.get('url')
326
- element_dict['objectKey'] = uploaded_file.get('object_key')
406
+ element_dict["url"] = uploaded_file.get("url")
407
+ element_dict["objectKey"] = uploaded_file.get("object_key")
327
408
  element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
328
409
 
329
- columns = ', '.join(f'"{column}"' for column in element_dict_cleaned.keys())
330
- placeholders = ', '.join(f':{column}' for column in element_dict_cleaned.keys())
410
+ columns = ", ".join(f'"{column}"' for column in element_dict_cleaned.keys())
411
+ placeholders = ", ".join(f":{column}" for column in element_dict_cleaned.keys())
331
412
  query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
332
413
  await self.execute_sql(query=query, parameters=element_dict_cleaned)
333
414
 
@@ -339,9 +420,11 @@ class SQLAlchemyDataLayer(BaseDataLayer):
339
420
  await self.execute_sql(query=query, parameters=parameters)
340
421
 
341
422
  async def delete_user_session(self, id: str) -> bool:
342
- return False # Not sure why documentation wants this
423
+ return False # Not sure why documentation wants this
343
424
 
344
- async def get_all_user_threads(self, user_id: Optional[str] = None, thread_id: Optional[str] = None) -> Optional[List[ThreadDict]]:
425
+ async def get_all_user_threads(
426
+ self, user_id: Optional[str] = None, thread_id: Optional[str] = None
427
+ ) -> Optional[List[ThreadDict]]:
345
428
  """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
346
429
  logger.info(f"SQLAlchemy: get_all_user_threads")
347
430
  user_threads_query = """
@@ -358,14 +441,25 @@ class SQLAlchemyDataLayer(BaseDataLayer):
358
441
  ORDER BY "createdAt" DESC
359
442
  LIMIT :limit
360
443
  """
361
- user_threads = await self.execute_sql(query=user_threads_query, parameters={"user_id": user_id, "limit": self.user_thread_limit, "thread_id": thread_id})
444
+ user_threads = await self.execute_sql(
445
+ query=user_threads_query,
446
+ parameters={
447
+ "user_id": user_id,
448
+ "limit": self.user_thread_limit,
449
+ "thread_id": thread_id,
450
+ },
451
+ )
362
452
  if not isinstance(user_threads, list):
363
453
  return None
364
454
  if not user_threads:
365
455
  return []
366
456
  else:
367
- thread_ids = "('" + "','".join(map(str, [thread['thread_id'] for thread in user_threads])) + "')"
368
-
457
+ thread_ids = (
458
+ "('"
459
+ + "','".join(map(str, [thread["thread_id"] for thread in user_threads]))
460
+ + "')"
461
+ )
462
+
369
463
  steps_feedbacks_query = f"""
370
464
  SELECT
371
465
  s."id" AS step_id,
@@ -394,8 +488,10 @@ class SQLAlchemyDataLayer(BaseDataLayer):
394
488
  WHERE s."threadId" IN {thread_ids}
395
489
  ORDER BY s."createdAt" ASC
396
490
  """
397
- steps_feedbacks = await self.execute_sql(query=steps_feedbacks_query, parameters={})
398
-
491
+ steps_feedbacks = await self.execute_sql(
492
+ query=steps_feedbacks_query, parameters={}
493
+ )
494
+
399
495
  elements_query = f"""
400
496
  SELECT
401
497
  e."id" AS element_id,
@@ -418,77 +514,83 @@ class SQLAlchemyDataLayer(BaseDataLayer):
418
514
 
419
515
  thread_dicts = {}
420
516
  for thread in user_threads:
421
- thread_id = thread['thread_id']
517
+ thread_id = thread["thread_id"]
422
518
  if thread_id is not None:
423
519
  thread_dicts[thread_id] = ThreadDict(
424
520
  id=thread_id,
425
- createdAt=thread['thread_createdat'],
426
- name=thread['thread_name'],
427
- userId=thread['user_id'],
428
- userIdentifier=thread['user_identifier'],
429
- tags=thread['thread_tags'],
430
- metadata=thread['thread_metadata'],
521
+ createdAt=thread["thread_createdat"],
522
+ name=thread["thread_name"],
523
+ userId=thread["user_id"],
524
+ userIdentifier=thread["user_identifier"],
525
+ tags=thread["thread_tags"],
526
+ metadata=thread["thread_metadata"],
431
527
  steps=[],
432
- elements=[]
528
+ elements=[],
433
529
  )
434
530
  # Process steps_feedbacks to populate the steps in the corresponding ThreadDict
435
531
  if isinstance(steps_feedbacks, list):
436
532
  for step_feedback in steps_feedbacks:
437
- thread_id = step_feedback['step_threadid']
533
+ thread_id = step_feedback["step_threadid"]
438
534
  if thread_id is not None:
439
535
  feedback = None
440
- if step_feedback['feedback_value'] is not None:
536
+ if step_feedback["feedback_value"] is not None:
441
537
  feedback = FeedbackDict(
442
- forId=step_feedback['step_id'],
443
- id=step_feedback.get('feedback_id'),
444
- value=step_feedback['feedback_value'],
445
- comment=step_feedback.get('feedback_comment')
538
+ forId=step_feedback["step_id"],
539
+ id=step_feedback.get("feedback_id"),
540
+ value=step_feedback["feedback_value"],
541
+ comment=step_feedback.get("feedback_comment"),
446
542
  )
447
543
  step_dict = StepDict(
448
- id=step_feedback['step_id'],
449
- name=step_feedback['step_name'],
450
- type=step_feedback['step_type'],
544
+ id=step_feedback["step_id"],
545
+ name=step_feedback["step_name"],
546
+ type=step_feedback["step_type"],
451
547
  threadId=thread_id,
452
- parentId=step_feedback.get('step_parentid'),
453
- disableFeedback=step_feedback.get('step_disablefeedback', False),
454
- streaming=step_feedback.get('step_streaming', False),
455
- waitForAnswer=step_feedback.get('step_waitforanswer'),
456
- isError=step_feedback.get('step_iserror'),
457
- metadata=step_feedback['step_metadata'] if step_feedback.get('step_metadata') is not None else {},
458
- tags=step_feedback.get('step_tags'),
459
- input=step_feedback.get('step_input', '') if step_feedback['step_showinput'] else '',
460
- output=step_feedback.get('step_output', ''),
461
- createdAt=step_feedback.get('step_createdat'),
462
- start=step_feedback.get('step_start'),
463
- end=step_feedback.get('step_end'),
464
- generation=step_feedback.get('step_generation'),
465
- showInput=step_feedback.get('step_showinput'),
466
- language=step_feedback.get('step_language'),
467
- indent=step_feedback.get('step_indent'),
468
- feedback=feedback
548
+ parentId=step_feedback.get("step_parentid"),
549
+ disableFeedback=step_feedback.get(
550
+ "step_disablefeedback", False
551
+ ),
552
+ streaming=step_feedback.get("step_streaming", False),
553
+ waitForAnswer=step_feedback.get("step_waitforanswer"),
554
+ isError=step_feedback.get("step_iserror"),
555
+ metadata=step_feedback["step_metadata"]
556
+ if step_feedback.get("step_metadata") is not None
557
+ else {},
558
+ tags=step_feedback.get("step_tags"),
559
+ input=step_feedback.get("step_input", "")
560
+ if step_feedback["step_showinput"]
561
+ else "",
562
+ output=step_feedback.get("step_output", ""),
563
+ createdAt=step_feedback.get("step_createdat"),
564
+ start=step_feedback.get("step_start"),
565
+ end=step_feedback.get("step_end"),
566
+ generation=step_feedback.get("step_generation"),
567
+ showInput=step_feedback.get("step_showinput"),
568
+ language=step_feedback.get("step_language"),
569
+ indent=step_feedback.get("step_indent"),
570
+ feedback=feedback,
469
571
  )
470
572
  # Append the step to the steps list of the corresponding ThreadDict
471
- thread_dicts[thread_id]['steps'].append(step_dict)
573
+ thread_dicts[thread_id]["steps"].append(step_dict)
472
574
 
473
575
  if isinstance(elements, list):
474
576
  for element in elements:
475
- thread_id = element['element_threadid']
577
+ thread_id = element["element_threadid"]
476
578
  if thread_id is not None:
477
579
  element_dict = ElementDict(
478
- id=element['element_id'],
580
+ id=element["element_id"],
479
581
  threadId=thread_id,
480
- type=element['element_type'],
481
- chainlitKey=element.get('element_chainlitkey'),
482
- url=element.get('element_url'),
483
- objectKey=element.get('element_objectkey'),
484
- name=element['element_name'],
485
- display=element['element_display'],
486
- size=element.get('element_size'),
487
- language=element.get('element_language'),
488
- page=element.get('element_page'),
489
- forId=element.get('element_forid'),
490
- mime=element.get('element_mime'),
582
+ type=element["element_type"],
583
+ chainlitKey=element.get("element_chainlitkey"),
584
+ url=element.get("element_url"),
585
+ objectKey=element.get("element_objectkey"),
586
+ name=element["element_name"],
587
+ display=element["element_display"],
588
+ size=element.get("element_size"),
589
+ language=element.get("element_language"),
590
+ page=element.get("element_page"),
591
+ forId=element.get("element_forid"),
592
+ mime=element.get("element_mime"),
491
593
  )
492
- thread_dicts[thread_id]['elements'].append(element_dict) # type: ignore
594
+ thread_dicts[thread_id]["elements"].append(element_dict) # type: ignore
493
595
 
494
- return list(thread_dicts.values())
596
+ return list(thread_dicts.values())