PraisonAI 2.0.61__cp313-cp313-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of PraisonAI might be problematic. Click here for more details.

Files changed (89)
  1. praisonai/__init__.py +6 -0
  2. praisonai/__main__.py +10 -0
  3. praisonai/agents_generator.py +648 -0
  4. praisonai/api/call.py +292 -0
  5. praisonai/auto.py +238 -0
  6. praisonai/chainlit_ui.py +304 -0
  7. praisonai/cli.py +518 -0
  8. praisonai/deploy.py +138 -0
  9. praisonai/inbuilt_tools/__init__.py +24 -0
  10. praisonai/inbuilt_tools/autogen_tools.py +117 -0
  11. praisonai/inc/__init__.py +2 -0
  12. praisonai/inc/config.py +96 -0
  13. praisonai/inc/models.py +128 -0
  14. praisonai/public/android-chrome-192x192.png +0 -0
  15. praisonai/public/android-chrome-512x512.png +0 -0
  16. praisonai/public/apple-touch-icon.png +0 -0
  17. praisonai/public/fantasy.svg +3 -0
  18. praisonai/public/favicon-16x16.png +0 -0
  19. praisonai/public/favicon-32x32.png +0 -0
  20. praisonai/public/favicon.ico +0 -0
  21. praisonai/public/game.svg +3 -0
  22. praisonai/public/logo_dark.png +0 -0
  23. praisonai/public/logo_light.png +0 -0
  24. praisonai/public/movie.svg +3 -0
  25. praisonai/public/praison-ai-agents-architecture-dark.png +0 -0
  26. praisonai/public/praison-ai-agents-architecture.png +0 -0
  27. praisonai/public/thriller.svg +3 -0
  28. praisonai/setup/__init__.py +1 -0
  29. praisonai/setup/build.py +21 -0
  30. praisonai/setup/config.yaml +60 -0
  31. praisonai/setup/post_install.py +23 -0
  32. praisonai/setup/setup_conda_env.py +25 -0
  33. praisonai/setup/setup_conda_env.sh +72 -0
  34. praisonai/setup.py +16 -0
  35. praisonai/test.py +105 -0
  36. praisonai/train.py +276 -0
  37. praisonai/ui/README.md +21 -0
  38. praisonai/ui/agents.py +822 -0
  39. praisonai/ui/callbacks.py +57 -0
  40. praisonai/ui/chat.py +387 -0
  41. praisonai/ui/code.py +440 -0
  42. praisonai/ui/colab.py +474 -0
  43. praisonai/ui/colab_chainlit.py +81 -0
  44. praisonai/ui/components/aicoder.py +269 -0
  45. praisonai/ui/config/.chainlit/config.toml +120 -0
  46. praisonai/ui/config/.chainlit/translations/bn.json +231 -0
  47. praisonai/ui/config/.chainlit/translations/en-US.json +229 -0
  48. praisonai/ui/config/.chainlit/translations/gu.json +231 -0
  49. praisonai/ui/config/.chainlit/translations/he-IL.json +231 -0
  50. praisonai/ui/config/.chainlit/translations/hi.json +231 -0
  51. praisonai/ui/config/.chainlit/translations/kn.json +231 -0
  52. praisonai/ui/config/.chainlit/translations/ml.json +231 -0
  53. praisonai/ui/config/.chainlit/translations/mr.json +231 -0
  54. praisonai/ui/config/.chainlit/translations/ta.json +231 -0
  55. praisonai/ui/config/.chainlit/translations/te.json +231 -0
  56. praisonai/ui/config/.chainlit/translations/zh-CN.json +229 -0
  57. praisonai/ui/config/chainlit.md +1 -0
  58. praisonai/ui/config/translations/bn.json +231 -0
  59. praisonai/ui/config/translations/en-US.json +229 -0
  60. praisonai/ui/config/translations/gu.json +231 -0
  61. praisonai/ui/config/translations/he-IL.json +231 -0
  62. praisonai/ui/config/translations/hi.json +231 -0
  63. praisonai/ui/config/translations/kn.json +231 -0
  64. praisonai/ui/config/translations/ml.json +231 -0
  65. praisonai/ui/config/translations/mr.json +231 -0
  66. praisonai/ui/config/translations/ta.json +231 -0
  67. praisonai/ui/config/translations/te.json +231 -0
  68. praisonai/ui/config/translations/zh-CN.json +229 -0
  69. praisonai/ui/context.py +283 -0
  70. praisonai/ui/db.py +291 -0
  71. praisonai/ui/public/fantasy.svg +3 -0
  72. praisonai/ui/public/game.svg +3 -0
  73. praisonai/ui/public/logo_dark.png +0 -0
  74. praisonai/ui/public/logo_light.png +0 -0
  75. praisonai/ui/public/movie.svg +3 -0
  76. praisonai/ui/public/praison.css +3 -0
  77. praisonai/ui/public/thriller.svg +3 -0
  78. praisonai/ui/realtime.py +476 -0
  79. praisonai/ui/realtimeclient/__init__.py +653 -0
  80. praisonai/ui/realtimeclient/realtimedocs.txt +1484 -0
  81. praisonai/ui/realtimeclient/tools.py +236 -0
  82. praisonai/ui/sql_alchemy.py +707 -0
  83. praisonai/ui/tools.md +133 -0
  84. praisonai/version.py +1 -0
  85. praisonai-2.0.61.dist-info/LICENSE +20 -0
  86. praisonai-2.0.61.dist-info/METADATA +679 -0
  87. praisonai-2.0.61.dist-info/RECORD +89 -0
  88. praisonai-2.0.61.dist-info/WHEEL +4 -0
  89. praisonai-2.0.61.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,707 @@
1
+ import json
2
+ import ssl
3
+ import uuid
4
+ from dataclasses import asdict
5
+ from datetime import datetime
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
7
+ import os
8
+
9
+ import aiofiles
10
+ import aiohttp
11
+
12
+ from chainlit.data.base import BaseDataLayer
13
+ from chainlit.data.storage_clients.base import EXPIRY_TIME, BaseStorageClient
14
+ from chainlit.data.utils import queue_until_user_message
15
+ from chainlit.element import ElementDict
16
+ from chainlit.logger import logger
17
+ from chainlit.step import StepDict
18
+ from chainlit.types import (
19
+ Feedback,
20
+ FeedbackDict,
21
+ PageInfo,
22
+ PaginatedResponse,
23
+ Pagination,
24
+ ThreadDict,
25
+ ThreadFilter,
26
+ )
27
+ from chainlit.user import PersistedUser, User
28
+ from sqlalchemy import text
29
+ from sqlalchemy.exc import SQLAlchemyError
30
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
31
+ from sqlalchemy.orm import sessionmaker
32
+
33
+ if TYPE_CHECKING:
34
+ from chainlit.element import Element
35
+ from chainlit.step import StepDict
36
+
37
# Connection string used by the UI data layer. A non-empty
# SUPABASE_DATABASE_URL takes precedence; otherwise fall back to the generic
# DATABASE_URL (which may itself be unset, leaving None).
SUPABASE_DATABASE_URL = os.getenv("SUPABASE_DATABASE_URL")
DATABASE_URL = SUPABASE_DATABASE_URL or os.getenv("DATABASE_URL")
42
+
43
class SQLAlchemyDataLayer(BaseDataLayer):
    """Chainlit data layer that persists chat data via raw SQL on an async engine.

    Users, threads, steps, feedback and element metadata live in the
    ``users``, ``threads``, ``steps``, ``feedbacks`` and ``elements`` tables.
    Element content is uploaded through an optional storage client; without
    one, elements are not persisted.

    Upserts are written as ``INSERT ... ON CONFLICT ("id") ... EXCLUDED``, so
    the backing database must support that syntax.
    """

    def __init__(
        self,
        conninfo: str,
        ssl_require: bool = False,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: Optional[int] = 1000,
        show_logger: Optional[bool] = False,
    ):
        """Create the async engine and session factory, and attach storage.

        Args:
            conninfo: Async SQLAlchemy database URL for ``create_async_engine``.
            ssl_require: Connect through an SSL context. That context disables
                hostname checking and certificate verification, so traffic is
                encrypted but the server is not authenticated.
            storage_provider: Client used to upload element content; without
                it ``create_element`` only logs a warning and returns.
            user_thread_limit: Value bound to ``LIMIT`` when loading threads.
            show_logger: Log an info line for every data-layer call.
        """
        self._conninfo = conninfo
        self.user_thread_limit = user_thread_limit
        self.show_logger = show_logger

        ssl_args: Dict[str, Any] = {}
        if ssl_require:
            ssl_context = ssl.create_default_context()
            # NOTE(review): verification is switched off entirely, presumably
            # so servers whose certificates are not in the local trust store
            # can be reached. This permits MITM; prefer loading the provider's
            # CA bundle instead.
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            ssl_args["ssl"] = ssl_context

        self.engine: AsyncEngine = create_async_engine(
            self._conninfo, connect_args=ssl_args
        )
        self.async_session = sessionmaker(
            bind=self.engine, expire_on_commit=False, class_=AsyncSession
        )

        self.storage_provider: Optional[BaseStorageClient] = storage_provider or None
        if self.storage_provider:
            if self.show_logger:
                logger.info("SQLAlchemyDataLayer storage client initialized")
        else:
            logger.warning(
                "SQLAlchemyDataLayer storage client is not initialized and elements will not be persisted!"
            )

    async def build_debug_url(self) -> str:
        """No debug UI is available for this data layer; always returns ""."""
        return ""

    ###### SQL Helpers ######
    async def execute_sql(
        self, query: str, parameters: dict
    ) -> Union[List[Dict[str, Any]], int, None]:
        """Execute one parameterized statement in its own transaction.

        Args:
            query: SQL text using ``:name`` bind placeholders.
            parameters: Values for those placeholders.

        Returns:
            A list of row dicts (UUID values converted to ``str``) when the
            statement returns rows, otherwise the affected row count. On any
            error the transaction is rolled back, a warning is logged and
            ``None`` is returned, so callers must treat ``None`` as failure.
        """
        parameterized_query = text(query)
        async with self.async_session() as session:
            try:
                await session.begin()
                result = await session.execute(parameterized_query, parameters)
                await session.commit()
                if result.returns_rows:
                    rows = [dict(row._mapping) for row in result.fetchall()]
                    return self.clean_result(rows)
                return result.rowcount
            except SQLAlchemyError as e:
                await session.rollback()
                logger.warning(f"An error occurred: {e}")
                return None
            except Exception as e:
                # Deliberately best-effort: data-layer failures are logged
                # rather than propagated into the chat session.
                await session.rollback()
                logger.warning(f"An unexpected error occurred: {e}")
                return None

    @staticmethod
    def _build_upsert_query(table: str, parameters: Dict[str, Any]) -> str:
        """Build an ``INSERT ... ON CONFLICT ("id")`` statement for ``table``.

        Column names are taken from the keys of ``parameters``; callers only
        pass code-controlled keys and table names, while values are bound as
        ``:name`` placeholders. When ``id`` is the only column there is
        nothing to update, so the conflict action becomes ``DO NOTHING``
        instead of an invalid empty ``SET`` clause.
        """
        columns = ", ".join(f'"{key}"' for key in parameters)
        values = ", ".join(f":{key}" for key in parameters)
        updates = ", ".join(
            f'"{key}" = EXCLUDED."{key}"' for key in parameters if key != "id"
        )
        conflict_action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"
        return (
            f"INSERT INTO {table} ({columns}) VALUES ({values}) "
            f'ON CONFLICT ("id") {conflict_action};'
        )

    @staticmethod
    def _load_json(value: Any, default: Any) -> Any:
        """Decode ``value`` if it is a JSON string; otherwise return it unchanged.

        Non-string values (e.g. ``None``, or objects a driver may already have
        decoded) are passed through. Malformed JSON yields ``default`` so one
        bad row cannot fail a whole read.
        """
        if isinstance(value, str):
            try:
                return json.loads(value)
            except ValueError:
                return default
        return value

    async def get_current_timestamp(self) -> str:
        """Return the current UTC time as naive ISO-8601 text plus a "Z" suffix.

        "Z" designates UTC, so the wall-clock value must be UTC as well;
        local time here would mislabel every timestamp on non-UTC hosts.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    def clean_result(self, obj):
        """Recursively convert ``uuid.UUID`` values inside dicts/lists to ``str``."""
        if isinstance(obj, dict):
            return {k: self.clean_result(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self.clean_result(item) for item in obj]
        elif isinstance(obj, uuid.UUID):
            return str(obj)
        return obj

    ###### User ######
    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        """Fetch a user by identifier, or ``None`` if absent or the query failed."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
        query = 'SELECT * FROM users WHERE "identifier" = :identifier'
        parameters = {"identifier": identifier}
        result = await self.execute_sql(query=query, parameters=parameters)
        if not (result and isinstance(result, list)):
            return None

        user_data = result[0]
        # A NULL or malformed meta column is exposed as empty metadata.
        meta = self._load_json(user_data.get("meta"), {}) or {}
        return PersistedUser(
            id=user_data["id"],
            identifier=user_data["identifier"],
            createdAt=user_data["createdAt"],
            metadata=meta,
        )

    async def _get_user_identifer_by_id(self, user_id: str) -> str:
        """Return the identifier of the user with primary key ``user_id``.

        Raises:
            ValueError: if no such user exists or the lookup failed.
        """
        if self.show_logger:
            logger.info(f"SQLAlchemy: _get_user_identifer_by_id, user_id={user_id}")
        query = 'SELECT "identifier" FROM users WHERE "id" = :user_id'
        parameters = {"user_id": user_id}
        result = await self.execute_sql(query=query, parameters=parameters)
        # Explicit check rather than `assert`, which is stripped under -O.
        if not result or not isinstance(result, list):
            raise ValueError(f"User not found for user_id {user_id}")
        return result[0]["identifier"]

    async def _get_user_id_by_thread(self, thread_id: str) -> Optional[str]:
        """Return the ``userId`` owning ``thread_id``, or ``None`` if unknown."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: _get_user_id_by_thread, thread_id={thread_id}")
        query = 'SELECT "userId" FROM threads WHERE "id" = :thread_id'
        parameters = {"thread_id": thread_id}
        result = await self.execute_sql(query=query, parameters=parameters)
        if result and isinstance(result, list):
            return result[0]["userId"]
        return None

    async def create_user(self, user: User) -> Optional[PersistedUser]:
        """Insert ``user`` or, if the identifier exists, replace its metadata.

        Returns the persisted user as re-read from the database.
        """
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
        existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
        user_dict: Dict[str, Any] = {
            "identifier": str(user.identifier),
            # `json.dumps(None)` is "null", so default before serializing.
            "meta": json.dumps(user.metadata or {}),
        }
        if not existing_user:
            user_dict["id"] = str(uuid.uuid4())
            user_dict["createdAt"] = await self.get_current_timestamp()
            query = 'INSERT INTO users ("id", "identifier", "createdAt", "meta") VALUES (:id, :identifier, :createdAt, :meta)'
        else:
            query = 'UPDATE users SET "meta" = :meta WHERE "identifier" = :identifier'
        await self.execute_sql(query=query, parameters=user_dict)
        return await self.get_user(user.identifier)

    ###### Threads ######
    async def get_thread_author(self, thread_id: str) -> str:
        """Return the ``userIdentifier`` stored on a thread.

        Raises:
            ValueError: if the thread or its author identifier is missing.
        """
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
        query = 'SELECT "userIdentifier" FROM threads WHERE "id" = :id'
        parameters = {"id": thread_id}
        result = await self.execute_sql(query=query, parameters=parameters)
        if isinstance(result, list) and result:
            author_identifier = result[0].get("userIdentifier")
            if author_identifier is not None:
                return author_identifier
        raise ValueError(f"Author not found for thread_id {thread_id}")

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        """Return one fully populated thread, or ``None`` if it does not exist."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
            thread_id=thread_id
        )
        if user_threads:
            return user_threads[0]
        return None

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        """Create the thread or update the columns that were supplied.

        Arguments left as ``None`` (and empty ``tags``/``metadata``) are not
        written, so existing values are kept. When ``name`` is not given, a
        ``"name"`` key in ``metadata`` is used instead.
        """
        if self.show_logger:
            logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")

        user_identifier = None
        if user_id:
            user_identifier = await self._get_user_identifer_by_id(user_id)

        if name is None and metadata and "name" in metadata:
            name = metadata.get("name")

        data = {
            "id": thread_id,
            # Stamped only when no metadata is supplied. Because the upsert
            # overwrites supplied columns, this also refreshes the createdAt
            # of an existing thread.
            "createdAt": (
                await self.get_current_timestamp() if metadata is None else None
            ),
            "name": name,
            "userId": user_id,
            "userIdentifier": user_identifier,
            "tags": json.dumps(tags) if tags else None,
            "meta": json.dumps(metadata) if metadata else None,
        }
        parameters = {key: value for key, value in data.items() if value is not None}
        query = self._build_upsert_query("threads", parameters)
        await self.execute_sql(query=query, parameters=parameters)

    async def delete_thread(self, thread_id: str):
        """Delete a thread with its steps' feedback, its elements and its steps."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
        feedbacks_query = 'DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)'
        elements_query = 'DELETE FROM elements WHERE "threadId" = :id'
        steps_query = 'DELETE FROM steps WHERE "threadId" = :id'
        thread_query = 'DELETE FROM threads WHERE "id" = :id'
        parameters = {"id": thread_id}
        # Children first, so no statement references rows already removed.
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)
        await self.execute_sql(query=thread_query, parameters=parameters)

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse:
        """Return one page of a user's threads, filtered in memory.

        ``filters.search`` matches case-insensitively against step outputs;
        ``filters.feedback`` keeps threads where some step has that feedback
        value. Pagination is cursor-based on thread id.

        Raises:
            ValueError: if ``filters.userId`` is not set.
        """
        if self.show_logger:
            logger.info(
                f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
            )
        if not filters.userId:
            raise ValueError("userId is required")
        all_user_threads: List[ThreadDict] = (
            await self.get_all_user_threads(user_id=filters.userId) or []
        )

        search_keyword = filters.search.lower() if filters.search else None
        # 0 is a legitimate feedback value, so test against None, not truthiness.
        feedback_value = (
            int(filters.feedback) if filters.feedback is not None else None
        )

        def matches(thread: ThreadDict) -> bool:
            steps = thread["steps"]
            if search_keyword and not any(
                isinstance(step.get("output"), str)
                and search_keyword in step["output"].lower()
                for step in steps
            ):
                return False
            if feedback_value is not None and not any(
                (step.get("feedback") or {}).get("value") == feedback_value
                for step in steps
            ):
                return False
            return True

        filtered_threads = [thread for thread in all_user_threads if matches(thread)]

        start = 0
        if pagination.cursor:
            for i, thr in enumerate(filtered_threads):
                if thr["id"] == pagination.cursor:
                    start = i + 1
                    break
        end = start + pagination.first
        paginated_threads = filtered_threads[start:end]

        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=len(filtered_threads) > end,
                startCursor=paginated_threads[0]["id"] if paginated_threads else None,
                endCursor=paginated_threads[-1]["id"] if paginated_threads else None,
            ),
            data=paginated_threads,
        )

    ###### Steps ######
    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        """Insert a step or overwrite the supplied columns of an existing one."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")

        # Stored as lowercase text (e.g. "true"/"false"); computed locally so
        # the caller's dict is not mutated.
        show_input = (
            str(step_dict.get("showInput", "")).lower()
            if "showInput" in step_dict
            else None
        )
        tags = step_dict.get("tags") or []
        parameters = {
            "id": step_dict.get("id"),
            "name": step_dict.get("name"),
            "type": step_dict.get("type"),
            "threadId": step_dict.get("threadId"),
            "parentId": step_dict.get("parentId"),
            "disableFeedback": step_dict.get("disableFeedback", False),
            "streaming": step_dict.get("streaming", False),
            "waitForAnswer": step_dict.get("waitForAnswer", False),
            "isError": step_dict.get("isError", False),
            "meta": json.dumps(step_dict.get("metadata", {})),
            "tags": json.dumps(tags),
            "input": step_dict.get("input"),
            "output": step_dict.get("output"),
            "createdAt": step_dict.get("createdAt"),
            "startTime": step_dict.get("start"),
            "endTime": step_dict.get("end"),
            "generation": json.dumps(step_dict.get("generation", {})),
            "showInput": show_input,
            "language": step_dict.get("language"),
            "indent": step_dict.get("indent"),
        }
        parameters = {k: v for k, v in parameters.items() if v is not None}
        query = self._build_upsert_query("steps", parameters)
        await self.execute_sql(query=query, parameters=parameters)

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        """Update a step; identical to ``create_step`` because it upserts."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
        await self.create_step(step_dict)

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        """Delete a step together with its feedback and attached elements."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
        feedbacks_query = 'DELETE FROM feedbacks WHERE "forId" = :id'
        elements_query = 'DELETE FROM elements WHERE "forId" = :id'
        steps_query = 'DELETE FROM steps WHERE "id" = :id'
        parameters = {"id": step_id}
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)

    ###### Feedback ######
    async def upsert_feedback(self, feedback: Feedback) -> str:
        """Insert or update ``feedback``, assigning a UUID id if it has none.

        Returns the feedback id.
        """
        if self.show_logger:
            logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
        feedback.id = feedback.id or str(uuid.uuid4())
        parameters = {k: v for k, v in asdict(feedback).items() if v is not None}
        query = self._build_upsert_query("feedbacks", parameters)
        await self.execute_sql(query=query, parameters=parameters)
        return feedback.id

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Delete a feedback row; returns ``True`` whether or not a row existed."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
        query = 'DELETE FROM feedbacks WHERE "id" = :feedback_id'
        parameters = {"feedback_id": feedback_id}
        await self.execute_sql(query=query, parameters=parameters)
        return True

    ###### Elements ######
    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        """Return the element ``element_id`` in ``thread_id``, or ``None``."""
        if self.show_logger:
            logger.info(
                f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}"
            )
        query = 'SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id'
        parameters = {"thread_id": thread_id, "element_id": element_id}
        element = await self.execute_sql(query=query, parameters=parameters)
        if not (isinstance(element, list) and element):
            return None
        element_dict = element[0]
        return ElementDict(
            id=element_dict["id"],
            threadId=element_dict.get("threadId"),
            type=element_dict.get("type"),
            chainlitKey=element_dict.get("chainlitKey"),
            url=element_dict.get("url"),
            objectKey=element_dict.get("objectKey"),
            name=element_dict["name"],
            display=element_dict["display"],
            size=element_dict.get("size"),
            language=element_dict.get("language"),
            page=element_dict.get("page"),
            autoPlay=element_dict.get("autoPlay"),
            playerConfig=element_dict.get("playerConfig"),
            forId=element_dict.get("forId"),
            mime=element_dict.get("mime"),
        )

    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        """Upload an element's content to storage and record it in ``elements``.

        Content is read from ``element.path``, downloaded from ``element.url``,
        or taken from ``element.content``, in that order of preference. Does
        nothing without a storage client or without ``element.for_id``.

        Raises:
            ValueError: if no content source is set, the download did not
                return HTTP 200, or the storage upload returned nothing.
        """
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")

        if not self.storage_provider:
            logger.warning("SQLAlchemy: create_element error. No storage client!")
            return
        if not element.for_id:
            return

        content: Optional[Union[bytes, str]] = None
        if element.path:
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()
        elif element.url:
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
        elif element.content:
            content = element.content
        else:
            raise ValueError("Element url, path or content must be provided")
        if content is None:
            raise ValueError("Content is None, cannot upload file")

        user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown"
        file_object_key = f"{user_id}/{element.id}" + (
            f"/{element.name}" if element.name else ""
        )

        if not element.mime:
            element.mime = "application/octet-stream"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key, data=content, mime=element.mime, overwrite=True
        )
        if not uploaded_file:
            raise ValueError("Failed to persist data in storage_provider")

        element_dict: ElementDict = element.to_dict()
        element_dict["url"] = uploaded_file.get("url")
        element_dict["objectKey"] = uploaded_file.get("object_key")
        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}

        columns = ", ".join(f'"{column}"' for column in element_dict_cleaned)
        placeholders = ", ".join(f":{column}" for column in element_dict_cleaned)
        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
        await self.execute_sql(query=query, parameters=element_dict_cleaned)

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        """Delete an element row by id (``thread_id`` is accepted but unused)."""
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
        query = 'DELETE FROM elements WHERE "id" = :id'
        parameters = {"id": element_id}
        await self.execute_sql(query=query, parameters=parameters)

    async def get_all_user_threads(
        self, user_id: Optional[str] = None, thread_id: Optional[str] = None
    ) -> Optional[List[ThreadDict]]:
        """Load threads together with their steps (plus feedback) and elements.

        Args:
            user_id: Only threads with this ``userId``; ``None`` disables the filter.
            thread_id: Only this thread; ``None`` disables the filter.

        Returns:
            Up to ``user_thread_limit`` threads, newest first, each with
            ``steps`` (oldest first) and ``elements`` populated; ``[]`` when
            nothing matches; ``None`` when the thread query itself failed.
        """
        if self.show_logger:
            logger.info("SQLAlchemy: get_all_user_threads")
        user_threads_query = """
            SELECT
                "id" AS thread_id,
                "createdAt" AS thread_createdat,
                "name" AS thread_name,
                "userId" AS user_id,
                "userIdentifier" AS user_identifier,
                "tags" AS thread_tags,
                "meta" AS thread_meta
            FROM threads
            WHERE ("userId" = :user_id OR :user_id IS NULL)
            AND ("id" = :thread_id OR :thread_id IS NULL)
            ORDER BY "createdAt" DESC
            LIMIT :limit
        """
        params = {
            "user_id": user_id,
            "thread_id": thread_id,
            "limit": self.user_thread_limit,
        }
        user_threads = await self.execute_sql(
            query=user_threads_query,
            parameters=params,
        )
        if not isinstance(user_threads, list):
            return None
        if not user_threads:
            return []

        # Bind every thread id as its own parameter instead of splicing the
        # ids into the SQL text: they are stored values, and a quote inside
        # one would otherwise break (or inject into) the query.
        id_params = {
            f"thread_id_{index}": thread["thread_id"]
            for index, thread in enumerate(user_threads)
        }
        thread_ids = ", ".join(f":{name}" for name in id_params)

        steps_feedbacks_query = f"""
            SELECT
                s."id" AS step_id,
                s."name" AS step_name,
                s."type" AS step_type,
                s."threadId" AS step_threadid,
                s."parentId" AS step_parentid,
                s."streaming" AS step_streaming,
                s."waitForAnswer" AS step_waitforanswer,
                s."isError" AS step_iserror,
                s."meta" AS step_meta,
                s."tags" AS step_tags,
                s."input" AS step_input,
                s."output" AS step_output,
                s."createdAt" AS step_createdat,
                s."startTime" AS step_start,
                s."endTime" AS step_end,
                s."generation" AS step_generation,
                s."showInput" AS step_showinput,
                s."language" AS step_language,
                s."indent" AS step_indent,
                f."value" AS feedback_value,
                f."comment" AS feedback_comment,
                f."id" AS feedback_id
            FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId"
            WHERE s."threadId" IN ({thread_ids})
            ORDER BY s."createdAt" ASC
        """
        steps_feedbacks = await self.execute_sql(
            query=steps_feedbacks_query, parameters=id_params
        )

        elements_query = f"""
            SELECT
                e."id" AS element_id,
                e."threadId" as element_threadid,
                e."type" AS element_type,
                e."chainlitKey" AS element_chainlitkey,
                e."url" AS element_url,
                e."objectKey" as element_objectkey,
                e."name" AS element_name,
                e."display" AS element_display,
                e."size" AS element_size,
                e."language" AS element_language,
                e."page" AS element_page,
                e."forId" AS element_forid,
                e."mime" AS element_mime
            FROM elements e
            WHERE e."threadId" IN ({thread_ids})
        """
        elements = await self.execute_sql(query=elements_query, parameters=id_params)

        thread_dicts: Dict[str, ThreadDict] = {}
        for thread in user_threads:
            t_id = thread["thread_id"]
            thread_dicts[t_id] = ThreadDict(
                id=t_id,
                createdAt=thread["thread_createdat"],
                name=thread["thread_name"],
                userId=thread["user_id"],
                userIdentifier=thread["user_identifier"],
                tags=self._load_json(thread["thread_tags"], []),
                metadata=self._load_json(thread["thread_meta"], {}),
                steps=[],
                elements=[],
            )

        if isinstance(steps_feedbacks, list):
            for step_feedback in steps_feedbacks:
                t_id = step_feedback["step_threadid"]
                if t_id not in thread_dicts:
                    continue
                feedback = None
                if step_feedback["feedback_value"] is not None:
                    feedback = FeedbackDict(
                        forId=step_feedback["step_id"],
                        id=step_feedback.get("feedback_id"),
                        value=step_feedback["feedback_value"],
                        comment=step_feedback.get("feedback_comment"),
                    )
                input_val = step_feedback.get("step_input", "")
                # Steps stored with showInput "false" keep their input hidden.
                if step_feedback.get("step_showinput", "false") == "false":
                    input_val = ""
                thread_dicts[t_id]["steps"].append(
                    StepDict(
                        id=step_feedback["step_id"],
                        name=step_feedback["step_name"],
                        type=step_feedback["step_type"],
                        threadId=t_id,
                        parentId=step_feedback.get("step_parentid"),
                        streaming=step_feedback.get("step_streaming", False),
                        waitForAnswer=step_feedback.get("step_waitforanswer"),
                        isError=step_feedback.get("step_iserror"),
                        metadata=self._load_json(step_feedback["step_meta"], {}),
                        tags=self._load_json(step_feedback["step_tags"], []),
                        input=input_val,
                        output=step_feedback.get("step_output", ""),
                        createdAt=step_feedback.get("step_createdat"),
                        start=step_feedback.get("step_start"),
                        end=step_feedback.get("step_end"),
                        generation=step_feedback.get("step_generation"),
                        showInput=step_feedback.get("step_showinput"),
                        language=step_feedback.get("step_language"),
                        indent=step_feedback.get("step_indent"),
                        feedback=feedback,
                    )
                )

        if isinstance(elements, list):
            for element in elements:
                t_id = element["element_threadid"]
                if t_id not in thread_dicts:
                    continue
                # autoPlay/playerConfig are not selected above, so these
                # lookups currently always yield None.
                thread_dicts[t_id]["elements"].append(
                    ElementDict(
                        id=element["element_id"],
                        threadId=t_id,
                        type=element["element_type"],
                        chainlitKey=element.get("element_chainlitkey"),
                        url=element.get("element_url"),
                        objectKey=element.get("element_objectkey"),
                        name=element["element_name"],
                        display=element["element_display"],
                        size=element.get("element_size"),
                        language=element.get("element_language"),
                        autoPlay=element.get("element_autoPlay"),
                        playerConfig=element.get("element_playerconfig"),
                        page=element.get("element_page"),
                        forId=element.get("element_forid"),
                        mime=element.get("element_mime"),
                    )
                )

        return list(thread_dicts.values())