agno 2.1.0__py3-none-any.whl → 2.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +13 -1
- agno/db/base.py +8 -4
- agno/db/dynamo/dynamo.py +69 -17
- agno/db/firestore/firestore.py +68 -29
- agno/db/gcs_json/gcs_json_db.py +68 -17
- agno/db/in_memory/in_memory_db.py +83 -14
- agno/db/json/json_db.py +79 -15
- agno/db/mongo/mongo.py +92 -74
- agno/db/mysql/mysql.py +17 -3
- agno/db/postgres/postgres.py +21 -3
- agno/db/redis/redis.py +38 -11
- agno/db/singlestore/singlestore.py +14 -3
- agno/db/sqlite/sqlite.py +34 -46
- agno/db/utils.py +50 -22
- agno/knowledge/knowledge.py +6 -0
- agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
- agno/knowledge/reader/pdf_reader.py +28 -52
- agno/knowledge/reader/reader_factory.py +12 -0
- agno/memory/manager.py +12 -4
- agno/models/anthropic/claude.py +4 -1
- agno/models/aws/bedrock.py +52 -112
- agno/models/openai/responses.py +1 -1
- agno/os/app.py +24 -30
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +252 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/agui.py +21 -5
- agno/os/interfaces/agui/router.py +12 -0
- agno/os/interfaces/base.py +4 -2
- agno/os/interfaces/slack/slack.py +13 -8
- agno/os/interfaces/whatsapp/whatsapp.py +12 -5
- agno/os/mcp.py +1 -1
- agno/os/router.py +39 -9
- agno/os/routers/memory/memory.py +5 -3
- agno/os/routers/memory/schemas.py +1 -0
- agno/os/utils.py +36 -10
- agno/run/base.py +2 -13
- agno/team/team.py +13 -1
- agno/tools/mcp.py +46 -1
- agno/utils/merge_dict.py +22 -1
- agno/utils/serialize.py +32 -0
- agno/utils/streamlit.py +1 -1
- agno/workflow/parallel.py +90 -14
- agno/workflow/step.py +30 -27
- agno/workflow/types.py +4 -6
- agno/workflow/workflow.py +5 -3
- {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/METADATA +16 -14
- {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/RECORD +53 -47
- {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/WHEEL +0 -0
- {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/top_level.txt +0 -0
|
@@ -343,10 +343,27 @@ class InMemoryDb(BaseDb):
|
|
|
343
343
|
return []
|
|
344
344
|
|
|
345
345
|
# -- Memory methods --
|
|
346
|
-
def delete_user_memory(self, memory_id: str):
|
|
346
|
+
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
|
|
347
|
+
"""Delete a user memory from in-memory storage.
|
|
348
|
+
|
|
349
|
+
Args:
|
|
350
|
+
memory_id (str): The ID of the memory to delete.
|
|
351
|
+
user_id (Optional[str]): The ID of the user. If provided, verifies the memory belongs to this user before deletion.
|
|
352
|
+
|
|
353
|
+
Raises:
|
|
354
|
+
Exception: If an error occurs during deletion.
|
|
355
|
+
"""
|
|
347
356
|
try:
|
|
348
357
|
original_count = len(self._memories)
|
|
349
|
-
|
|
358
|
+
|
|
359
|
+
# If user_id is provided, verify ownership before deleting
|
|
360
|
+
if user_id is not None:
|
|
361
|
+
self._memories = [
|
|
362
|
+
m for m in self._memories
|
|
363
|
+
if not (m.get("memory_id") == memory_id and m.get("user_id") == user_id)
|
|
364
|
+
]
|
|
365
|
+
else:
|
|
366
|
+
self._memories = [m for m in self._memories if m.get("memory_id") != memory_id]
|
|
350
367
|
|
|
351
368
|
if len(self._memories) < original_count:
|
|
352
369
|
log_debug(f"Successfully deleted user memory id: {memory_id}")
|
|
@@ -357,10 +374,25 @@ class InMemoryDb(BaseDb):
|
|
|
357
374
|
log_error(f"Error deleting memory: {e}")
|
|
358
375
|
raise e
|
|
359
376
|
|
|
360
|
-
def delete_user_memories(self, memory_ids: List[str]) -> None:
|
|
361
|
-
"""Delete multiple user memories from in-memory storage.
|
|
377
|
+
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
|
|
378
|
+
"""Delete multiple user memories from in-memory storage.
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
memory_ids (List[str]): The IDs of the memories to delete.
|
|
382
|
+
user_id (Optional[str]): The ID of the user. If provided, only deletes memories belonging to this user.
|
|
383
|
+
|
|
384
|
+
Raises:
|
|
385
|
+
Exception: If an error occurs during deletion.
|
|
386
|
+
"""
|
|
362
387
|
try:
|
|
363
|
-
|
|
388
|
+
# If user_id is provided, verify ownership before deleting
|
|
389
|
+
if user_id is not None:
|
|
390
|
+
self._memories = [
|
|
391
|
+
m for m in self._memories
|
|
392
|
+
if not (m.get("memory_id") in memory_ids and m.get("user_id") == user_id)
|
|
393
|
+
]
|
|
394
|
+
else:
|
|
395
|
+
self._memories = [m for m in self._memories if m.get("memory_id") not in memory_ids]
|
|
364
396
|
log_debug(f"Successfully deleted {len(memory_ids)} user memories")
|
|
365
397
|
|
|
366
398
|
except Exception as e:
|
|
@@ -368,6 +400,14 @@ class InMemoryDb(BaseDb):
|
|
|
368
400
|
raise e
|
|
369
401
|
|
|
370
402
|
def get_all_memory_topics(self) -> List[str]:
|
|
403
|
+
"""Get all memory topics from in-memory storage.
|
|
404
|
+
|
|
405
|
+
Returns:
|
|
406
|
+
List[str]: List of unique topics.
|
|
407
|
+
|
|
408
|
+
Raises:
|
|
409
|
+
Exception: If an error occurs while reading topics.
|
|
410
|
+
"""
|
|
371
411
|
try:
|
|
372
412
|
topics = set()
|
|
373
413
|
for memory in self._memories:
|
|
@@ -381,11 +421,28 @@ class InMemoryDb(BaseDb):
|
|
|
381
421
|
raise e
|
|
382
422
|
|
|
383
423
|
def get_user_memory(
|
|
384
|
-
self, memory_id: str, deserialize: Optional[bool] = True
|
|
424
|
+
self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
|
|
385
425
|
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
|
|
426
|
+
"""Get a user memory from in-memory storage.
|
|
427
|
+
|
|
428
|
+
Args:
|
|
429
|
+
memory_id (str): The ID of the memory to retrieve.
|
|
430
|
+
deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
|
|
431
|
+
user_id (Optional[str]): The ID of the user. If provided, only returns the memory if it belongs to this user.
|
|
432
|
+
|
|
433
|
+
Returns:
|
|
434
|
+
Optional[Union[UserMemory, Dict[str, Any]]]: The memory object or dictionary, or None if not found.
|
|
435
|
+
|
|
436
|
+
Raises:
|
|
437
|
+
Exception: If an error occurs while reading the memory.
|
|
438
|
+
"""
|
|
386
439
|
try:
|
|
387
440
|
for memory_data in self._memories:
|
|
388
441
|
if memory_data.get("memory_id") == memory_id:
|
|
442
|
+
# Filter by user_id if provided
|
|
443
|
+
if user_id is not None and memory_data.get("user_id") != user_id:
|
|
444
|
+
continue
|
|
445
|
+
|
|
389
446
|
memory_data_copy = deepcopy(memory_data)
|
|
390
447
|
if not deserialize:
|
|
391
448
|
return memory_data_copy
|
|
@@ -455,19 +512,31 @@ class InMemoryDb(BaseDb):
|
|
|
455
512
|
def get_user_memory_stats(
|
|
456
513
|
self, limit: Optional[int] = None, page: Optional[int] = None
|
|
457
514
|
) -> Tuple[List[Dict[str, Any]], int]:
|
|
458
|
-
"""Get user memory statistics.
|
|
515
|
+
"""Get user memory statistics.
|
|
516
|
+
|
|
517
|
+
Args:
|
|
518
|
+
limit (Optional[int]): Maximum number of stats to return.
|
|
519
|
+
page (Optional[int]): Page number for pagination.
|
|
520
|
+
|
|
521
|
+
Returns:
|
|
522
|
+
Tuple[List[Dict[str, Any]], int]: List of user memory statistics and total count.
|
|
523
|
+
|
|
524
|
+
Raises:
|
|
525
|
+
Exception: If an error occurs while getting stats.
|
|
526
|
+
"""
|
|
459
527
|
try:
|
|
460
528
|
user_stats = {}
|
|
461
529
|
|
|
462
530
|
for memory in self._memories:
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
531
|
+
memory_user_id = memory.get("user_id")
|
|
532
|
+
|
|
533
|
+
if memory_user_id:
|
|
534
|
+
if memory_user_id not in user_stats:
|
|
535
|
+
user_stats[memory_user_id] = {"user_id": memory_user_id, "total_memories": 0, "last_memory_updated_at": 0}
|
|
536
|
+
user_stats[memory_user_id]["total_memories"] += 1
|
|
468
537
|
updated_at = memory.get("updated_at", 0)
|
|
469
|
-
if updated_at > user_stats[
|
|
470
|
-
user_stats[
|
|
538
|
+
if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
|
|
539
|
+
user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
|
|
471
540
|
|
|
472
541
|
stats_list = list(user_stats.values())
|
|
473
542
|
stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)
|
agno/db/json/json_db.py
CHANGED
|
@@ -445,11 +445,29 @@ class JsonDb(BaseDb):
|
|
|
445
445
|
return False
|
|
446
446
|
|
|
447
447
|
# -- Memory methods --
|
|
448
|
-
def delete_user_memory(self, memory_id: str):
|
|
449
|
-
"""Delete a user memory from the JSON file.
|
|
448
|
+
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
|
|
449
|
+
"""Delete a user memory from the JSON file.
|
|
450
|
+
|
|
451
|
+
Args:
|
|
452
|
+
memory_id (str): The ID of the memory to delete.
|
|
453
|
+
user_id (Optional[str]): The ID of the user (optional, for filtering).
|
|
454
|
+
"""
|
|
450
455
|
try:
|
|
451
456
|
memories = self._read_json_file(self.memory_table_name)
|
|
452
457
|
original_count = len(memories)
|
|
458
|
+
|
|
459
|
+
# If user_id is provided, verify the memory belongs to the user before deleting
|
|
460
|
+
if user_id:
|
|
461
|
+
memory_to_delete = None
|
|
462
|
+
for m in memories:
|
|
463
|
+
if m.get("memory_id") == memory_id:
|
|
464
|
+
memory_to_delete = m
|
|
465
|
+
break
|
|
466
|
+
|
|
467
|
+
if memory_to_delete and memory_to_delete.get("user_id") != user_id:
|
|
468
|
+
log_debug(f"Memory {memory_id} does not belong to user {user_id}")
|
|
469
|
+
return
|
|
470
|
+
|
|
453
471
|
memories = [m for m in memories if m.get("memory_id") != memory_id]
|
|
454
472
|
|
|
455
473
|
if len(memories) < original_count:
|
|
@@ -462,10 +480,24 @@ class JsonDb(BaseDb):
|
|
|
462
480
|
log_error(f"Error deleting memory: {e}")
|
|
463
481
|
raise e
|
|
464
482
|
|
|
465
|
-
def delete_user_memories(self, memory_ids: List[str]) -> None:
|
|
466
|
-
"""Delete multiple user memories from the JSON file.
|
|
483
|
+
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
|
|
484
|
+
"""Delete multiple user memories from the JSON file.
|
|
485
|
+
|
|
486
|
+
Args:
|
|
487
|
+
memory_ids (List[str]): List of memory IDs to delete.
|
|
488
|
+
user_id (Optional[str]): The ID of the user (optional, for filtering).
|
|
489
|
+
"""
|
|
467
490
|
try:
|
|
468
491
|
memories = self._read_json_file(self.memory_table_name)
|
|
492
|
+
|
|
493
|
+
# If user_id is provided, filter memory_ids to only those belonging to the user
|
|
494
|
+
if user_id:
|
|
495
|
+
filtered_memory_ids: List[str] = []
|
|
496
|
+
for memory in memories:
|
|
497
|
+
if memory.get("memory_id") in memory_ids and memory.get("user_id") == user_id:
|
|
498
|
+
filtered_memory_ids.append(memory.get("memory_id")) # type: ignore
|
|
499
|
+
memory_ids = filtered_memory_ids
|
|
500
|
+
|
|
469
501
|
memories = [m for m in memories if m.get("memory_id") not in memory_ids]
|
|
470
502
|
self._write_json_file(self.memory_table_name, memories)
|
|
471
503
|
|
|
@@ -476,7 +508,11 @@ class JsonDb(BaseDb):
|
|
|
476
508
|
raise e
|
|
477
509
|
|
|
478
510
|
def get_all_memory_topics(self) -> List[str]:
|
|
479
|
-
"""Get all memory topics from the JSON file.
|
|
511
|
+
"""Get all memory topics from the JSON file.
|
|
512
|
+
|
|
513
|
+
Returns:
|
|
514
|
+
List[str]: List of unique memory topics.
|
|
515
|
+
"""
|
|
480
516
|
try:
|
|
481
517
|
memories = self._read_json_file(self.memory_table_name)
|
|
482
518
|
|
|
@@ -492,14 +528,30 @@ class JsonDb(BaseDb):
|
|
|
492
528
|
raise e
|
|
493
529
|
|
|
494
530
|
def get_user_memory(
|
|
495
|
-
self,
|
|
531
|
+
self,
|
|
532
|
+
memory_id: str,
|
|
533
|
+
deserialize: Optional[bool] = True,
|
|
534
|
+
user_id: Optional[str] = None,
|
|
496
535
|
) -> Optional[Union[UserMemory, Dict[str, Any]]]:
|
|
497
|
-
"""Get a memory from the JSON file.
|
|
536
|
+
"""Get a memory from the JSON file.
|
|
537
|
+
|
|
538
|
+
Args:
|
|
539
|
+
memory_id (str): The ID of the memory to get.
|
|
540
|
+
deserialize (Optional[bool]): Whether to deserialize the memory.
|
|
541
|
+
user_id (Optional[str]): The ID of the user (optional, for filtering).
|
|
542
|
+
|
|
543
|
+
Returns:
|
|
544
|
+
Optional[Union[UserMemory, Dict[str, Any]]]: The user memory data if found, None otherwise.
|
|
545
|
+
"""
|
|
498
546
|
try:
|
|
499
547
|
memories = self._read_json_file(self.memory_table_name)
|
|
500
548
|
|
|
501
549
|
for memory_data in memories:
|
|
502
550
|
if memory_data.get("memory_id") == memory_id:
|
|
551
|
+
# Filter by user_id if provided
|
|
552
|
+
if user_id and memory_data.get("user_id") != user_id:
|
|
553
|
+
return None
|
|
554
|
+
|
|
503
555
|
if not deserialize:
|
|
504
556
|
return memory_data
|
|
505
557
|
return UserMemory.from_dict(memory_data)
|
|
@@ -571,20 +623,32 @@ class JsonDb(BaseDb):
|
|
|
571
623
|
def get_user_memory_stats(
|
|
572
624
|
self, limit: Optional[int] = None, page: Optional[int] = None
|
|
573
625
|
) -> Tuple[List[Dict[str, Any]], int]:
|
|
574
|
-
"""Get user memory statistics.
|
|
626
|
+
"""Get user memory statistics.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
limit (Optional[int]): The maximum number of user stats to return.
|
|
630
|
+
page (Optional[int]): The page number.
|
|
631
|
+
|
|
632
|
+
Returns:
|
|
633
|
+
Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
|
|
634
|
+
"""
|
|
575
635
|
try:
|
|
576
636
|
memories = self._read_json_file(self.memory_table_name)
|
|
577
637
|
user_stats = {}
|
|
578
638
|
|
|
579
639
|
for memory in memories:
|
|
580
|
-
|
|
581
|
-
if
|
|
582
|
-
if
|
|
583
|
-
user_stats[
|
|
584
|
-
|
|
640
|
+
memory_user_id = memory.get("user_id")
|
|
641
|
+
if memory_user_id:
|
|
642
|
+
if memory_user_id not in user_stats:
|
|
643
|
+
user_stats[memory_user_id] = {
|
|
644
|
+
"user_id": memory_user_id,
|
|
645
|
+
"total_memories": 0,
|
|
646
|
+
"last_memory_updated_at": 0,
|
|
647
|
+
}
|
|
648
|
+
user_stats[memory_user_id]["total_memories"] += 1
|
|
585
649
|
updated_at = memory.get("updated_at", 0)
|
|
586
|
-
if updated_at > user_stats[
|
|
587
|
-
user_stats[
|
|
650
|
+
if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
|
|
651
|
+
user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
|
|
588
652
|
|
|
589
653
|
stats_list = list(user_stats.values())
|
|
590
654
|
stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)
|
agno/db/mongo/mongo.py
CHANGED
|
@@ -16,7 +16,7 @@ from agno.db.mongo.utils import (
|
|
|
16
16
|
from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
|
|
17
17
|
from agno.db.schemas.knowledge import KnowledgeRow
|
|
18
18
|
from agno.db.schemas.memory import UserMemory
|
|
19
|
-
from agno.db.utils import deserialize_session_json_fields
|
|
19
|
+
from agno.db.utils import deserialize_session_json_fields
|
|
20
20
|
from agno.session import AgentSession, Session, TeamSession, WorkflowSession
|
|
21
21
|
from agno.utils.log import log_debug, log_error, log_info
|
|
22
22
|
from agno.utils.string import generate_id
|
|
@@ -282,7 +282,6 @@ class MongoDb(BaseDb):
|
|
|
282
282
|
return None
|
|
283
283
|
|
|
284
284
|
session = deserialize_session_json_fields(result)
|
|
285
|
-
|
|
286
285
|
if not deserialize:
|
|
287
286
|
return session
|
|
288
287
|
|
|
@@ -385,7 +384,6 @@ class MongoDb(BaseDb):
|
|
|
385
384
|
records = list(cursor)
|
|
386
385
|
if records is None:
|
|
387
386
|
return [] if deserialize else ([], 0)
|
|
388
|
-
|
|
389
387
|
sessions_raw = [deserialize_session_json_fields(record) for record in records]
|
|
390
388
|
|
|
391
389
|
if not deserialize:
|
|
@@ -489,25 +487,25 @@ class MongoDb(BaseDb):
|
|
|
489
487
|
if collection is None:
|
|
490
488
|
return None
|
|
491
489
|
|
|
492
|
-
|
|
490
|
+
session_dict = session.to_dict()
|
|
493
491
|
|
|
494
492
|
if isinstance(session, AgentSession):
|
|
495
493
|
record = {
|
|
496
|
-
"session_id":
|
|
494
|
+
"session_id": session_dict.get("session_id"),
|
|
497
495
|
"session_type": SessionType.AGENT.value,
|
|
498
|
-
"agent_id":
|
|
499
|
-
"user_id":
|
|
500
|
-
"runs":
|
|
501
|
-
"agent_data":
|
|
502
|
-
"session_data":
|
|
503
|
-
"summary":
|
|
504
|
-
"metadata":
|
|
505
|
-
"created_at":
|
|
496
|
+
"agent_id": session_dict.get("agent_id"),
|
|
497
|
+
"user_id": session_dict.get("user_id"),
|
|
498
|
+
"runs": session_dict.get("runs"),
|
|
499
|
+
"agent_data": session_dict.get("agent_data"),
|
|
500
|
+
"session_data": session_dict.get("session_data"),
|
|
501
|
+
"summary": session_dict.get("summary"),
|
|
502
|
+
"metadata": session_dict.get("metadata"),
|
|
503
|
+
"created_at": session_dict.get("created_at"),
|
|
506
504
|
"updated_at": int(time.time()),
|
|
507
505
|
}
|
|
508
506
|
|
|
509
507
|
result = collection.find_one_and_replace(
|
|
510
|
-
filter={"session_id":
|
|
508
|
+
filter={"session_id": session_dict.get("session_id")},
|
|
511
509
|
replacement=record,
|
|
512
510
|
upsert=True,
|
|
513
511
|
return_document=ReturnDocument.AFTER,
|
|
@@ -515,7 +513,7 @@ class MongoDb(BaseDb):
|
|
|
515
513
|
if not result:
|
|
516
514
|
return None
|
|
517
515
|
|
|
518
|
-
session =
|
|
516
|
+
session = result # type: ignore
|
|
519
517
|
|
|
520
518
|
if not deserialize:
|
|
521
519
|
return session
|
|
@@ -524,21 +522,21 @@ class MongoDb(BaseDb):
|
|
|
524
522
|
|
|
525
523
|
elif isinstance(session, TeamSession):
|
|
526
524
|
record = {
|
|
527
|
-
"session_id":
|
|
525
|
+
"session_id": session_dict.get("session_id"),
|
|
528
526
|
"session_type": SessionType.TEAM.value,
|
|
529
|
-
"team_id":
|
|
530
|
-
"user_id":
|
|
531
|
-
"runs":
|
|
532
|
-
"team_data":
|
|
533
|
-
"session_data":
|
|
534
|
-
"summary":
|
|
535
|
-
"metadata":
|
|
536
|
-
"created_at":
|
|
527
|
+
"team_id": session_dict.get("team_id"),
|
|
528
|
+
"user_id": session_dict.get("user_id"),
|
|
529
|
+
"runs": session_dict.get("runs"),
|
|
530
|
+
"team_data": session_dict.get("team_data"),
|
|
531
|
+
"session_data": session_dict.get("session_data"),
|
|
532
|
+
"summary": session_dict.get("summary"),
|
|
533
|
+
"metadata": session_dict.get("metadata"),
|
|
534
|
+
"created_at": session_dict.get("created_at"),
|
|
537
535
|
"updated_at": int(time.time()),
|
|
538
536
|
}
|
|
539
537
|
|
|
540
538
|
result = collection.find_one_and_replace(
|
|
541
|
-
filter={"session_id":
|
|
539
|
+
filter={"session_id": session_dict.get("session_id")},
|
|
542
540
|
replacement=record,
|
|
543
541
|
upsert=True,
|
|
544
542
|
return_document=ReturnDocument.AFTER,
|
|
@@ -546,7 +544,8 @@ class MongoDb(BaseDb):
|
|
|
546
544
|
if not result:
|
|
547
545
|
return None
|
|
548
546
|
|
|
549
|
-
|
|
547
|
+
# MongoDB stores native objects, no deserialization needed for document fields
|
|
548
|
+
session = result # type: ignore
|
|
550
549
|
|
|
551
550
|
if not deserialize:
|
|
552
551
|
return session
|
|
@@ -555,21 +554,21 @@ class MongoDb(BaseDb):
|
|
|
555
554
|
|
|
556
555
|
else:
|
|
557
556
|
record = {
|
|
558
|
-
"session_id":
|
|
557
|
+
"session_id": session_dict.get("session_id"),
|
|
559
558
|
"session_type": SessionType.WORKFLOW.value,
|
|
560
|
-
"workflow_id":
|
|
561
|
-
"user_id":
|
|
562
|
-
"runs":
|
|
563
|
-
"workflow_data":
|
|
564
|
-
"session_data":
|
|
565
|
-
"summary":
|
|
566
|
-
"metadata":
|
|
567
|
-
"created_at":
|
|
559
|
+
"workflow_id": session_dict.get("workflow_id"),
|
|
560
|
+
"user_id": session_dict.get("user_id"),
|
|
561
|
+
"runs": session_dict.get("runs"),
|
|
562
|
+
"workflow_data": session_dict.get("workflow_data"),
|
|
563
|
+
"session_data": session_dict.get("session_data"),
|
|
564
|
+
"summary": session_dict.get("summary"),
|
|
565
|
+
"metadata": session_dict.get("metadata"),
|
|
566
|
+
"created_at": session_dict.get("created_at"),
|
|
568
567
|
"updated_at": int(time.time()),
|
|
569
568
|
}
|
|
570
569
|
|
|
571
570
|
result = collection.find_one_and_replace(
|
|
572
|
-
filter={"session_id":
|
|
571
|
+
filter={"session_id": session_dict.get("session_id")},
|
|
573
572
|
replacement=record,
|
|
574
573
|
upsert=True,
|
|
575
574
|
return_document=ReturnDocument.AFTER,
|
|
@@ -577,7 +576,7 @@ class MongoDb(BaseDb):
|
|
|
577
576
|
if not result:
|
|
578
577
|
return None
|
|
579
578
|
|
|
580
|
-
session =
|
|
579
|
+
session = result # type: ignore
|
|
581
580
|
|
|
582
581
|
if not deserialize:
|
|
583
582
|
return session
|
|
@@ -628,48 +627,48 @@ class MongoDb(BaseDb):
|
|
|
628
627
|
if session is None:
|
|
629
628
|
continue
|
|
630
629
|
|
|
631
|
-
|
|
630
|
+
session_dict = session.to_dict()
|
|
632
631
|
|
|
633
632
|
if isinstance(session, AgentSession):
|
|
634
633
|
record = {
|
|
635
|
-
"session_id":
|
|
634
|
+
"session_id": session_dict.get("session_id"),
|
|
636
635
|
"session_type": SessionType.AGENT.value,
|
|
637
|
-
"agent_id":
|
|
638
|
-
"user_id":
|
|
639
|
-
"runs":
|
|
640
|
-
"agent_data":
|
|
641
|
-
"session_data":
|
|
642
|
-
"summary":
|
|
643
|
-
"metadata":
|
|
644
|
-
"created_at":
|
|
636
|
+
"agent_id": session_dict.get("agent_id"),
|
|
637
|
+
"user_id": session_dict.get("user_id"),
|
|
638
|
+
"runs": session_dict.get("runs"),
|
|
639
|
+
"agent_data": session_dict.get("agent_data"),
|
|
640
|
+
"session_data": session_dict.get("session_data"),
|
|
641
|
+
"summary": session_dict.get("summary"),
|
|
642
|
+
"metadata": session_dict.get("metadata"),
|
|
643
|
+
"created_at": session_dict.get("created_at"),
|
|
645
644
|
"updated_at": int(time.time()),
|
|
646
645
|
}
|
|
647
646
|
elif isinstance(session, TeamSession):
|
|
648
647
|
record = {
|
|
649
|
-
"session_id":
|
|
648
|
+
"session_id": session_dict.get("session_id"),
|
|
650
649
|
"session_type": SessionType.TEAM.value,
|
|
651
|
-
"team_id":
|
|
652
|
-
"user_id":
|
|
653
|
-
"runs":
|
|
654
|
-
"team_data":
|
|
655
|
-
"session_data":
|
|
656
|
-
"summary":
|
|
657
|
-
"metadata":
|
|
658
|
-
"created_at":
|
|
650
|
+
"team_id": session_dict.get("team_id"),
|
|
651
|
+
"user_id": session_dict.get("user_id"),
|
|
652
|
+
"runs": session_dict.get("runs"),
|
|
653
|
+
"team_data": session_dict.get("team_data"),
|
|
654
|
+
"session_data": session_dict.get("session_data"),
|
|
655
|
+
"summary": session_dict.get("summary"),
|
|
656
|
+
"metadata": session_dict.get("metadata"),
|
|
657
|
+
"created_at": session_dict.get("created_at"),
|
|
659
658
|
"updated_at": int(time.time()),
|
|
660
659
|
}
|
|
661
660
|
elif isinstance(session, WorkflowSession):
|
|
662
661
|
record = {
|
|
663
|
-
"session_id":
|
|
662
|
+
"session_id": session_dict.get("session_id"),
|
|
664
663
|
"session_type": SessionType.WORKFLOW.value,
|
|
665
|
-
"workflow_id":
|
|
666
|
-
"user_id":
|
|
667
|
-
"runs":
|
|
668
|
-
"workflow_data":
|
|
669
|
-
"session_data":
|
|
670
|
-
"summary":
|
|
671
|
-
"metadata":
|
|
672
|
-
"created_at":
|
|
664
|
+
"workflow_id": session_dict.get("workflow_id"),
|
|
665
|
+
"user_id": session_dict.get("user_id"),
|
|
666
|
+
"runs": session_dict.get("runs"),
|
|
667
|
+
"workflow_data": session_dict.get("workflow_data"),
|
|
668
|
+
"session_data": session_dict.get("session_data"),
|
|
669
|
+
"summary": session_dict.get("summary"),
|
|
670
|
+
"metadata": session_dict.get("metadata"),
|
|
671
|
+
"created_at": session_dict.get("created_at"),
|
|
673
672
|
"updated_at": int(time.time()),
|
|
674
673
|
}
|
|
675
674
|
else:
|
|
@@ -688,7 +687,7 @@ class MongoDb(BaseDb):
|
|
|
688
687
|
cursor = collection.find({"session_id": {"$in": session_ids}})
|
|
689
688
|
|
|
690
689
|
for doc in cursor:
|
|
691
|
-
session_dict =
|
|
690
|
+
session_dict = doc
|
|
692
691
|
|
|
693
692
|
if deserialize:
|
|
694
693
|
session_type = doc.get("session_type")
|
|
@@ -728,11 +727,12 @@ class MongoDb(BaseDb):
|
|
|
728
727
|
|
|
729
728
|
# -- Memory methods --
|
|
730
729
|
|
|
731
|
-
def delete_user_memory(self, memory_id: str):
|
|
730
|
+
def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
|
|
732
731
|
"""Delete a user memory from the database.
|
|
733
732
|
|
|
734
733
|
Args:
|
|
735
734
|
memory_id (str): The ID of the memory to delete.
|
|
735
|
+
user_id (Optional[str]): The ID of the user to verify ownership. If provided, only delete if the memory belongs to this user.
|
|
736
736
|
|
|
737
737
|
Returns:
|
|
738
738
|
bool: True if the memory was deleted, False otherwise.
|
|
@@ -745,7 +745,11 @@ class MongoDb(BaseDb):
|
|
|
745
745
|
if collection is None:
|
|
746
746
|
return
|
|
747
747
|
|
|
748
|
-
|
|
748
|
+
query = {"memory_id": memory_id}
|
|
749
|
+
if user_id is not None:
|
|
750
|
+
query["user_id"] = user_id
|
|
751
|
+
|
|
752
|
+
result = collection.delete_one(query)
|
|
749
753
|
|
|
750
754
|
success = result.deleted_count > 0
|
|
751
755
|
if success:
|
|
@@ -757,11 +761,12 @@ class MongoDb(BaseDb):
|
|
|
757
761
|
log_error(f"Error deleting memory: {e}")
|
|
758
762
|
raise e
|
|
759
763
|
|
|
760
|
-
def delete_user_memories(self, memory_ids: List[str]) -> None:
|
|
764
|
+
def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
|
|
761
765
|
"""Delete user memories from the database.
|
|
762
766
|
|
|
763
767
|
Args:
|
|
764
768
|
memory_ids (List[str]): The IDs of the memories to delete.
|
|
769
|
+
user_id (Optional[str]): The ID of the user to verify ownership. If provided, only delete memories that belong to this user.
|
|
765
770
|
|
|
766
771
|
Raises:
|
|
767
772
|
Exception: If there is an error deleting the memories.
|
|
@@ -771,7 +776,11 @@ class MongoDb(BaseDb):
|
|
|
771
776
|
if collection is None:
|
|
772
777
|
return
|
|
773
778
|
|
|
774
|
-
|
|
779
|
+
query: Dict[str, Any] = {"memory_id": {"$in": memory_ids}}
|
|
780
|
+
if user_id is not None:
|
|
781
|
+
query["user_id"] = user_id
|
|
782
|
+
|
|
783
|
+
result = collection.delete_many(query)
|
|
775
784
|
|
|
776
785
|
if result.deleted_count == 0:
|
|
777
786
|
log_debug(f"No memories found with ids: {memory_ids}")
|
|
@@ -794,19 +803,22 @@ class MongoDb(BaseDb):
|
|
|
794
803
|
if collection is None:
|
|
795
804
|
return []
|
|
796
805
|
|
|
797
|
-
topics = collection.distinct("topics")
|
|
806
|
+
topics = collection.distinct("topics", {})
|
|
798
807
|
return [topic for topic in topics if topic]
|
|
799
808
|
|
|
800
809
|
except Exception as e:
|
|
801
810
|
log_error(f"Exception reading from collection: {e}")
|
|
802
811
|
raise e
|
|
803
812
|
|
|
804
|
-
def get_user_memory(
|
|
813
|
+
def get_user_memory(
|
|
814
|
+
self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
|
|
815
|
+
) -> Optional[UserMemory]:
|
|
805
816
|
"""Get a memory from the database.
|
|
806
817
|
|
|
807
818
|
Args:
|
|
808
819
|
memory_id (str): The ID of the memory to get.
|
|
809
820
|
deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
|
|
821
|
+
user_id (Optional[str]): The ID of the user to verify ownership. If provided, only return the memory if it belongs to this user.
|
|
810
822
|
|
|
811
823
|
Returns:
|
|
812
824
|
Optional[UserMemory]:
|
|
@@ -821,7 +833,11 @@ class MongoDb(BaseDb):
|
|
|
821
833
|
if collection is None:
|
|
822
834
|
return None
|
|
823
835
|
|
|
824
|
-
|
|
836
|
+
query = {"memory_id": memory_id}
|
|
837
|
+
if user_id is not None:
|
|
838
|
+
query["user_id"] = user_id
|
|
839
|
+
|
|
840
|
+
result = collection.find_one(query)
|
|
825
841
|
if result is None or not deserialize:
|
|
826
842
|
return result
|
|
827
843
|
|
|
@@ -934,8 +950,10 @@ class MongoDb(BaseDb):
|
|
|
934
950
|
if collection is None:
|
|
935
951
|
return [], 0
|
|
936
952
|
|
|
953
|
+
match_stage = {"user_id": {"$ne": None}}
|
|
954
|
+
|
|
937
955
|
pipeline = [
|
|
938
|
-
{"$match":
|
|
956
|
+
{"$match": match_stage},
|
|
939
957
|
{
|
|
940
958
|
"$group": {
|
|
941
959
|
"_id": "$user_id",
|