agno 2.1.1__py3-none-any.whl → 2.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. agno/agent/agent.py +12 -0
  2. agno/db/base.py +8 -4
  3. agno/db/dynamo/dynamo.py +69 -17
  4. agno/db/firestore/firestore.py +65 -28
  5. agno/db/gcs_json/gcs_json_db.py +70 -17
  6. agno/db/in_memory/in_memory_db.py +85 -14
  7. agno/db/json/json_db.py +79 -15
  8. agno/db/mongo/mongo.py +27 -8
  9. agno/db/mysql/mysql.py +17 -3
  10. agno/db/postgres/postgres.py +21 -3
  11. agno/db/redis/redis.py +38 -11
  12. agno/db/singlestore/singlestore.py +14 -3
  13. agno/db/sqlite/sqlite.py +34 -46
  14. agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
  15. agno/knowledge/reader/pdf_reader.py +28 -52
  16. agno/knowledge/reader/reader_factory.py +12 -0
  17. agno/memory/manager.py +12 -4
  18. agno/models/anthropic/claude.py +4 -1
  19. agno/models/aws/bedrock.py +52 -112
  20. agno/models/openrouter/openrouter.py +39 -1
  21. agno/models/vertexai/__init__.py +0 -0
  22. agno/models/vertexai/claude.py +74 -0
  23. agno/os/app.py +76 -32
  24. agno/os/interfaces/a2a/__init__.py +3 -0
  25. agno/os/interfaces/a2a/a2a.py +42 -0
  26. agno/os/interfaces/a2a/router.py +252 -0
  27. agno/os/interfaces/a2a/utils.py +924 -0
  28. agno/os/interfaces/agui/router.py +12 -0
  29. agno/os/mcp.py +3 -3
  30. agno/os/router.py +38 -8
  31. agno/os/routers/memory/memory.py +5 -3
  32. agno/os/routers/memory/schemas.py +1 -0
  33. agno/os/utils.py +37 -10
  34. agno/team/team.py +12 -0
  35. agno/tools/file.py +4 -2
  36. agno/tools/mcp.py +46 -1
  37. agno/utils/merge_dict.py +22 -1
  38. agno/utils/streamlit.py +1 -1
  39. agno/workflow/parallel.py +90 -14
  40. agno/workflow/step.py +30 -27
  41. agno/workflow/workflow.py +12 -6
  42. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/METADATA +16 -14
  43. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/RECORD +46 -39
  44. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/WHEEL +0 -0
  45. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/licenses/LICENSE +0 -0
  46. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/top_level.txt +0 -0
agno/db/in_memory/in_memory_db.py CHANGED
@@ -343,10 +343,26 @@ class InMemoryDb(BaseDb):
343
343
  return []
344
344
 
345
345
  # -- Memory methods --
346
- def delete_user_memory(self, memory_id: str):
346
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
347
+ """Delete a user memory from in-memory storage.
348
+
349
+ Args:
350
+ memory_id (str): The ID of the memory to delete.
351
+ user_id (Optional[str]): The ID of the user. If provided, verifies the memory belongs to this user before deletion.
352
+
353
+ Raises:
354
+ Exception: If an error occurs during deletion.
355
+ """
347
356
  try:
348
357
  original_count = len(self._memories)
349
- self._memories = [m for m in self._memories if m.get("memory_id") != memory_id]
358
+
359
+ # If user_id is provided, verify ownership before deleting
360
+ if user_id is not None:
361
+ self._memories = [
362
+ m for m in self._memories if not (m.get("memory_id") == memory_id and m.get("user_id") == user_id)
363
+ ]
364
+ else:
365
+ self._memories = [m for m in self._memories if m.get("memory_id") != memory_id]
350
366
 
351
367
  if len(self._memories) < original_count:
352
368
  log_debug(f"Successfully deleted user memory id: {memory_id}")
@@ -357,10 +373,24 @@ class InMemoryDb(BaseDb):
357
373
  log_error(f"Error deleting memory: {e}")
358
374
  raise e
359
375
 
360
- def delete_user_memories(self, memory_ids: List[str]) -> None:
361
- """Delete multiple user memories from in-memory storage."""
376
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
377
+ """Delete multiple user memories from in-memory storage.
378
+
379
+ Args:
380
+ memory_ids (List[str]): The IDs of the memories to delete.
381
+ user_id (Optional[str]): The ID of the user. If provided, only deletes memories belonging to this user.
382
+
383
+ Raises:
384
+ Exception: If an error occurs during deletion.
385
+ """
362
386
  try:
363
- self._memories = [m for m in self._memories if m.get("memory_id") not in memory_ids]
387
+ # If user_id is provided, verify ownership before deleting
388
+ if user_id is not None:
389
+ self._memories = [
390
+ m for m in self._memories if not (m.get("memory_id") in memory_ids and m.get("user_id") == user_id)
391
+ ]
392
+ else:
393
+ self._memories = [m for m in self._memories if m.get("memory_id") not in memory_ids]
364
394
  log_debug(f"Successfully deleted {len(memory_ids)} user memories")
365
395
 
366
396
  except Exception as e:
@@ -368,6 +398,14 @@ class InMemoryDb(BaseDb):
368
398
  raise e
369
399
 
370
400
  def get_all_memory_topics(self) -> List[str]:
401
+ """Get all memory topics from in-memory storage.
402
+
403
+ Returns:
404
+ List[str]: List of unique topics.
405
+
406
+ Raises:
407
+ Exception: If an error occurs while reading topics.
408
+ """
371
409
  try:
372
410
  topics = set()
373
411
  for memory in self._memories:
@@ -381,11 +419,28 @@ class InMemoryDb(BaseDb):
381
419
  raise e
382
420
 
383
421
  def get_user_memory(
384
- self, memory_id: str, deserialize: Optional[bool] = True
422
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
385
423
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
424
+ """Get a user memory from in-memory storage.
425
+
426
+ Args:
427
+ memory_id (str): The ID of the memory to retrieve.
428
+ deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
429
+ user_id (Optional[str]): The ID of the user. If provided, only returns the memory if it belongs to this user.
430
+
431
+ Returns:
432
+ Optional[Union[UserMemory, Dict[str, Any]]]: The memory object or dictionary, or None if not found.
433
+
434
+ Raises:
435
+ Exception: If an error occurs while reading the memory.
436
+ """
386
437
  try:
387
438
  for memory_data in self._memories:
388
439
  if memory_data.get("memory_id") == memory_id:
440
+ # Filter by user_id if provided
441
+ if user_id is not None and memory_data.get("user_id") != user_id:
442
+ continue
443
+
389
444
  memory_data_copy = deepcopy(memory_data)
390
445
  if not deserialize:
391
446
  return memory_data_copy
@@ -455,19 +510,35 @@ class InMemoryDb(BaseDb):
455
510
  def get_user_memory_stats(
456
511
  self, limit: Optional[int] = None, page: Optional[int] = None
457
512
  ) -> Tuple[List[Dict[str, Any]], int]:
458
- """Get user memory statistics."""
513
+ """Get user memory statistics.
514
+
515
+ Args:
516
+ limit (Optional[int]): Maximum number of stats to return.
517
+ page (Optional[int]): Page number for pagination.
518
+
519
+ Returns:
520
+ Tuple[List[Dict[str, Any]], int]: List of user memory statistics and total count.
521
+
522
+ Raises:
523
+ Exception: If an error occurs while getting stats.
524
+ """
459
525
  try:
460
526
  user_stats = {}
461
527
 
462
528
  for memory in self._memories:
463
- user_id = memory.get("user_id")
464
- if user_id:
465
- if user_id not in user_stats:
466
- user_stats[user_id] = {"user_id": user_id, "total_memories": 0, "last_memory_updated_at": 0}
467
- user_stats[user_id]["total_memories"] += 1
529
+ memory_user_id = memory.get("user_id")
530
+
531
+ if memory_user_id:
532
+ if memory_user_id not in user_stats:
533
+ user_stats[memory_user_id] = {
534
+ "user_id": memory_user_id,
535
+ "total_memories": 0,
536
+ "last_memory_updated_at": 0,
537
+ }
538
+ user_stats[memory_user_id]["total_memories"] += 1
468
539
  updated_at = memory.get("updated_at", 0)
469
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
470
- user_stats[user_id]["last_memory_updated_at"] = updated_at
540
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
541
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
471
542
 
472
543
  stats_list = list(user_stats.values())
473
544
  stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)
agno/db/json/json_db.py CHANGED
@@ -445,11 +445,29 @@ class JsonDb(BaseDb):
445
445
  return False
446
446
 
447
447
  # -- Memory methods --
448
- def delete_user_memory(self, memory_id: str):
449
- """Delete a user memory from the JSON file."""
448
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
449
+ """Delete a user memory from the JSON file.
450
+
451
+ Args:
452
+ memory_id (str): The ID of the memory to delete.
453
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
454
+ """
450
455
  try:
451
456
  memories = self._read_json_file(self.memory_table_name)
452
457
  original_count = len(memories)
458
+
459
+ # If user_id is provided, verify the memory belongs to the user before deleting
460
+ if user_id:
461
+ memory_to_delete = None
462
+ for m in memories:
463
+ if m.get("memory_id") == memory_id:
464
+ memory_to_delete = m
465
+ break
466
+
467
+ if memory_to_delete and memory_to_delete.get("user_id") != user_id:
468
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
469
+ return
470
+
453
471
  memories = [m for m in memories if m.get("memory_id") != memory_id]
454
472
 
455
473
  if len(memories) < original_count:
@@ -462,10 +480,24 @@ class JsonDb(BaseDb):
462
480
  log_error(f"Error deleting memory: {e}")
463
481
  raise e
464
482
 
465
- def delete_user_memories(self, memory_ids: List[str]) -> None:
466
- """Delete multiple user memories from the JSON file."""
483
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
484
+ """Delete multiple user memories from the JSON file.
485
+
486
+ Args:
487
+ memory_ids (List[str]): List of memory IDs to delete.
488
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
489
+ """
467
490
  try:
468
491
  memories = self._read_json_file(self.memory_table_name)
492
+
493
+ # If user_id is provided, filter memory_ids to only those belonging to the user
494
+ if user_id:
495
+ filtered_memory_ids: List[str] = []
496
+ for memory in memories:
497
+ if memory.get("memory_id") in memory_ids and memory.get("user_id") == user_id:
498
+ filtered_memory_ids.append(memory.get("memory_id")) # type: ignore
499
+ memory_ids = filtered_memory_ids
500
+
469
501
  memories = [m for m in memories if m.get("memory_id") not in memory_ids]
470
502
  self._write_json_file(self.memory_table_name, memories)
471
503
 
@@ -476,7 +508,11 @@ class JsonDb(BaseDb):
476
508
  raise e
477
509
 
478
510
  def get_all_memory_topics(self) -> List[str]:
479
- """Get all memory topics from the JSON file."""
511
+ """Get all memory topics from the JSON file.
512
+
513
+ Returns:
514
+ List[str]: List of unique memory topics.
515
+ """
480
516
  try:
481
517
  memories = self._read_json_file(self.memory_table_name)
482
518
 
@@ -492,14 +528,30 @@ class JsonDb(BaseDb):
492
528
  raise e
493
529
 
494
530
  def get_user_memory(
495
- self, memory_id: str, deserialize: Optional[bool] = True
531
+ self,
532
+ memory_id: str,
533
+ deserialize: Optional[bool] = True,
534
+ user_id: Optional[str] = None,
496
535
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
497
- """Get a memory from the JSON file."""
536
+ """Get a memory from the JSON file.
537
+
538
+ Args:
539
+ memory_id (str): The ID of the memory to get.
540
+ deserialize (Optional[bool]): Whether to deserialize the memory.
541
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
542
+
543
+ Returns:
544
+ Optional[Union[UserMemory, Dict[str, Any]]]: The user memory data if found, None otherwise.
545
+ """
498
546
  try:
499
547
  memories = self._read_json_file(self.memory_table_name)
500
548
 
501
549
  for memory_data in memories:
502
550
  if memory_data.get("memory_id") == memory_id:
551
+ # Filter by user_id if provided
552
+ if user_id and memory_data.get("user_id") != user_id:
553
+ return None
554
+
503
555
  if not deserialize:
504
556
  return memory_data
505
557
  return UserMemory.from_dict(memory_data)
@@ -571,20 +623,32 @@ class JsonDb(BaseDb):
571
623
  def get_user_memory_stats(
572
624
  self, limit: Optional[int] = None, page: Optional[int] = None
573
625
  ) -> Tuple[List[Dict[str, Any]], int]:
574
- """Get user memory statistics."""
626
+ """Get user memory statistics.
627
+
628
+ Args:
629
+ limit (Optional[int]): The maximum number of user stats to return.
630
+ page (Optional[int]): The page number.
631
+
632
+ Returns:
633
+ Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
634
+ """
575
635
  try:
576
636
  memories = self._read_json_file(self.memory_table_name)
577
637
  user_stats = {}
578
638
 
579
639
  for memory in memories:
580
- user_id = memory.get("user_id")
581
- if user_id:
582
- if user_id not in user_stats:
583
- user_stats[user_id] = {"user_id": user_id, "total_memories": 0, "last_memory_updated_at": 0}
584
- user_stats[user_id]["total_memories"] += 1
640
+ memory_user_id = memory.get("user_id")
641
+ if memory_user_id:
642
+ if memory_user_id not in user_stats:
643
+ user_stats[memory_user_id] = {
644
+ "user_id": memory_user_id,
645
+ "total_memories": 0,
646
+ "last_memory_updated_at": 0,
647
+ }
648
+ user_stats[memory_user_id]["total_memories"] += 1
585
649
  updated_at = memory.get("updated_at", 0)
586
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
587
- user_stats[user_id]["last_memory_updated_at"] = updated_at
650
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
651
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
588
652
 
589
653
  stats_list = list(user_stats.values())
590
654
  stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)
agno/db/mongo/mongo.py CHANGED
@@ -727,11 +727,12 @@ class MongoDb(BaseDb):
727
727
 
728
728
  # -- Memory methods --
729
729
 
730
- def delete_user_memory(self, memory_id: str):
730
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
731
731
  """Delete a user memory from the database.
732
732
 
733
733
  Args:
734
734
  memory_id (str): The ID of the memory to delete.
735
+ user_id (Optional[str]): The ID of the user to verify ownership. If provided, only delete if the memory belongs to this user.
735
736
 
736
737
  Returns:
737
738
  bool: True if the memory was deleted, False otherwise.
@@ -744,7 +745,11 @@ class MongoDb(BaseDb):
744
745
  if collection is None:
745
746
  return
746
747
 
747
- result = collection.delete_one({"memory_id": memory_id})
748
+ query = {"memory_id": memory_id}
749
+ if user_id is not None:
750
+ query["user_id"] = user_id
751
+
752
+ result = collection.delete_one(query)
748
753
 
749
754
  success = result.deleted_count > 0
750
755
  if success:
@@ -756,11 +761,12 @@ class MongoDb(BaseDb):
756
761
  log_error(f"Error deleting memory: {e}")
757
762
  raise e
758
763
 
759
- def delete_user_memories(self, memory_ids: List[str]) -> None:
764
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
760
765
  """Delete user memories from the database.
761
766
 
762
767
  Args:
763
768
  memory_ids (List[str]): The IDs of the memories to delete.
769
+ user_id (Optional[str]): The ID of the user to verify ownership. If provided, only delete memories that belong to this user.
764
770
 
765
771
  Raises:
766
772
  Exception: If there is an error deleting the memories.
@@ -770,7 +776,11 @@ class MongoDb(BaseDb):
770
776
  if collection is None:
771
777
  return
772
778
 
773
- result = collection.delete_many({"memory_id": {"$in": memory_ids}})
779
+ query: Dict[str, Any] = {"memory_id": {"$in": memory_ids}}
780
+ if user_id is not None:
781
+ query["user_id"] = user_id
782
+
783
+ result = collection.delete_many(query)
774
784
 
775
785
  if result.deleted_count == 0:
776
786
  log_debug(f"No memories found with ids: {memory_ids}")
@@ -793,19 +803,22 @@ class MongoDb(BaseDb):
793
803
  if collection is None:
794
804
  return []
795
805
 
796
- topics = collection.distinct("topics")
806
+ topics = collection.distinct("topics", {})
797
807
  return [topic for topic in topics if topic]
798
808
 
799
809
  except Exception as e:
800
810
  log_error(f"Exception reading from collection: {e}")
801
811
  raise e
802
812
 
803
- def get_user_memory(self, memory_id: str, deserialize: Optional[bool] = True) -> Optional[UserMemory]:
813
+ def get_user_memory(
814
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
815
+ ) -> Optional[UserMemory]:
804
816
  """Get a memory from the database.
805
817
 
806
818
  Args:
807
819
  memory_id (str): The ID of the memory to get.
808
820
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
821
+ user_id (Optional[str]): The ID of the user to verify ownership. If provided, only return the memory if it belongs to this user.
809
822
 
810
823
  Returns:
811
824
  Optional[UserMemory]:
@@ -820,7 +833,11 @@ class MongoDb(BaseDb):
820
833
  if collection is None:
821
834
  return None
822
835
 
823
- result = collection.find_one({"memory_id": memory_id})
836
+ query = {"memory_id": memory_id}
837
+ if user_id is not None:
838
+ query["user_id"] = user_id
839
+
840
+ result = collection.find_one(query)
824
841
  if result is None or not deserialize:
825
842
  return result
826
843
 
@@ -933,8 +950,10 @@ class MongoDb(BaseDb):
933
950
  if collection is None:
934
951
  return [], 0
935
952
 
953
+ match_stage = {"user_id": {"$ne": None}}
954
+
936
955
  pipeline = [
937
- {"$match": {"user_id": {"$ne": None}}},
956
+ {"$match": match_stage},
938
957
  {
939
958
  "$group": {
940
959
  "_id": "$user_id",
agno/db/mysql/mysql.py CHANGED
@@ -917,9 +917,13 @@ class MySQLDb(BaseDb):
917
917
  ]
918
918
 
919
919
  # -- Memory methods --
920
- def delete_user_memory(self, memory_id: str):
920
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
921
921
  """Delete a user memory from the database.
922
922
 
923
+ Args:
924
+ memory_id (str): The ID of the memory to delete.
925
+ user_id (Optional[str]): The user ID to filter by. Defaults to None.
926
+
923
927
  Returns:
924
928
  bool: True if deletion was successful, False otherwise.
925
929
 
@@ -933,6 +937,8 @@ class MySQLDb(BaseDb):
933
937
 
934
938
  with self.Session() as sess, sess.begin():
935
939
  delete_stmt = table.delete().where(table.c.memory_id == memory_id)
940
+ if user_id is not None:
941
+ delete_stmt = delete_stmt.where(table.c.user_id == user_id)
936
942
  result = sess.execute(delete_stmt)
937
943
 
938
944
  success = result.rowcount > 0
@@ -944,11 +950,12 @@ class MySQLDb(BaseDb):
944
950
  except Exception as e:
945
951
  log_error(f"Error deleting user memory: {e}")
946
952
 
947
- def delete_user_memories(self, memory_ids: List[str]) -> None:
953
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
948
954
  """Delete user memories from the database.
949
955
 
950
956
  Args:
951
957
  memory_ids (List[str]): The IDs of the memories to delete.
958
+ user_id (Optional[str]): The user ID to filter by. Defaults to None.
952
959
 
953
960
  Raises:
954
961
  Exception: If an error occurs during deletion.
@@ -960,6 +967,8 @@ class MySQLDb(BaseDb):
960
967
 
961
968
  with self.Session() as sess, sess.begin():
962
969
  delete_stmt = table.delete().where(table.c.memory_id.in_(memory_ids))
970
+ if user_id is not None:
971
+ delete_stmt = delete_stmt.where(table.c.user_id == user_id)
963
972
  result = sess.execute(delete_stmt)
964
973
  if result.rowcount == 0:
965
974
  log_debug(f"No user memories found with ids: {memory_ids}")
@@ -1002,12 +1011,15 @@ class MySQLDb(BaseDb):
1002
1011
  log_error(f"Exception reading from memory table: {e}")
1003
1012
  raise e
1004
1013
 
1005
- def get_user_memory(self, memory_id: str, deserialize: Optional[bool] = True) -> Optional[UserMemory]:
1014
+ def get_user_memory(
1015
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
1016
+ ) -> Optional[UserMemory]:
1006
1017
  """Get a memory from the database.
1007
1018
 
1008
1019
  Args:
1009
1020
  memory_id (str): The ID of the memory to get.
1010
1021
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
1022
+ user_id (Optional[str]): The user ID to filter by. Defaults to None.
1011
1023
 
1012
1024
  Returns:
1013
1025
  Union[UserMemory, Dict[str, Any], None]:
@@ -1024,6 +1036,8 @@ class MySQLDb(BaseDb):
1024
1036
 
1025
1037
  with self.Session() as sess, sess.begin():
1026
1038
  stmt = select(table).where(table.c.memory_id == memory_id)
1039
+ if user_id is not None:
1040
+ stmt = stmt.where(table.c.user_id == user_id)
1027
1041
 
1028
1042
  result = sess.execute(stmt).fetchone()
1029
1043
  if not result:
agno/db/postgres/postgres.py CHANGED
@@ -870,9 +870,13 @@ class PostgresDb(BaseDb):
870
870
  return []
871
871
 
872
872
  # -- Memory methods --
873
- def delete_user_memory(self, memory_id: str):
873
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
874
874
  """Delete a user memory from the database.
875
875
 
876
+ Args:
877
+ memory_id (str): The ID of the memory to delete.
878
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
879
+
876
880
  Returns:
877
881
  bool: True if deletion was successful, False otherwise.
878
882
 
@@ -886,6 +890,10 @@ class PostgresDb(BaseDb):
886
890
 
887
891
  with self.Session() as sess, sess.begin():
888
892
  delete_stmt = table.delete().where(table.c.memory_id == memory_id)
893
+
894
+ if user_id is not None:
895
+ delete_stmt = delete_stmt.where(table.c.user_id == user_id)
896
+
889
897
  result = sess.execute(delete_stmt)
890
898
 
891
899
  success = result.rowcount > 0
@@ -898,11 +906,12 @@ class PostgresDb(BaseDb):
898
906
  log_error(f"Error deleting user memory: {e}")
899
907
  raise e
900
908
 
901
- def delete_user_memories(self, memory_ids: List[str]) -> None:
909
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
902
910
  """Delete user memories from the database.
903
911
 
904
912
  Args:
905
913
  memory_ids (List[str]): The IDs of the memories to delete.
914
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
906
915
 
907
916
  Raises:
908
917
  Exception: If an error occurs during deletion.
@@ -914,6 +923,10 @@ class PostgresDb(BaseDb):
914
923
 
915
924
  with self.Session() as sess, sess.begin():
916
925
  delete_stmt = table.delete().where(table.c.memory_id.in_(memory_ids))
926
+
927
+ if user_id is not None:
928
+ delete_stmt = delete_stmt.where(table.c.user_id == user_id)
929
+
917
930
  result = sess.execute(delete_stmt)
918
931
 
919
932
  if result.rowcount == 0:
@@ -938,6 +951,7 @@ class PostgresDb(BaseDb):
938
951
 
939
952
  with self.Session() as sess, sess.begin():
940
953
  stmt = select(func.json_array_elements_text(table.c.topics))
954
+
941
955
  result = sess.execute(stmt).fetchall()
942
956
 
943
957
  return list(set([record[0] for record in result]))
@@ -947,13 +961,14 @@ class PostgresDb(BaseDb):
947
961
  return []
948
962
 
949
963
  def get_user_memory(
950
- self, memory_id: str, deserialize: Optional[bool] = True
964
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
951
965
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
952
966
  """Get a memory from the database.
953
967
 
954
968
  Args:
955
969
  memory_id (str): The ID of the memory to get.
956
970
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
971
+ user_id (Optional[str]): The ID of the user to filter by. Defaults to None.
957
972
 
958
973
  Returns:
959
974
  Union[UserMemory, Dict[str, Any], None]:
@@ -971,6 +986,9 @@ class PostgresDb(BaseDb):
971
986
  with self.Session() as sess, sess.begin():
972
987
  stmt = select(table).where(table.c.memory_id == memory_id)
973
988
 
989
+ if user_id is not None:
990
+ stmt = stmt.where(table.c.user_id == user_id)
991
+
974
992
  result = sess.execute(stmt).fetchone()
975
993
  if not result:
976
994
  return None
agno/db/redis/redis.py CHANGED
@@ -627,11 +627,12 @@ class RedisDb(BaseDb):
627
627
 
628
628
  # -- Memory methods --
629
629
 
630
- def delete_user_memory(self, memory_id: str):
630
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
631
631
  """Delete a user memory from Redis.
632
632
 
633
633
  Args:
634
634
  memory_id (str): The ID of the memory to delete.
635
+ user_id (Optional[str]): The ID of the user. If provided, verifies the memory belongs to this user before deleting.
635
636
 
636
637
  Returns:
637
638
  bool: True if the memory was deleted, False otherwise.
@@ -640,6 +641,16 @@ class RedisDb(BaseDb):
640
641
  Exception: If any error occurs while deleting the memory.
641
642
  """
642
643
  try:
644
+ # If user_id is provided, verify ownership before deleting
645
+ if user_id is not None:
646
+ memory = self._get_record("memories", memory_id)
647
+ if memory is None:
648
+ log_debug(f"No user memory found with id: {memory_id}")
649
+ return
650
+ if memory.get("user_id") != user_id:
651
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
652
+ return
653
+
643
654
  if self._delete_record(
644
655
  "memories", memory_id, index_fields=["user_id", "agent_id", "team_id", "workflow_id"]
645
656
  ):
@@ -651,15 +662,25 @@ class RedisDb(BaseDb):
651
662
  log_error(f"Error deleting user memory: {e}")
652
663
  raise e
653
664
 
654
- def delete_user_memories(self, memory_ids: List[str]) -> None:
665
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
655
666
  """Delete user memories from Redis.
656
667
 
657
668
  Args:
658
669
  memory_ids (List[str]): The IDs of the memories to delete.
670
+ user_id (Optional[str]): The ID of the user. If provided, only deletes memories belonging to this user.
659
671
  """
660
672
  try:
661
673
  # TODO: cant we optimize this?
662
674
  for memory_id in memory_ids:
675
+ # If user_id is provided, verify ownership before deleting
676
+ if user_id is not None:
677
+ memory = self._get_record("memories", memory_id)
678
+ if memory is None:
679
+ continue
680
+ if memory.get("user_id") != user_id:
681
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}, skipping deletion")
682
+ continue
683
+
663
684
  self._delete_record(
664
685
  "memories",
665
686
  memory_id,
@@ -692,12 +713,14 @@ class RedisDb(BaseDb):
692
713
  raise e
693
714
 
694
715
  def get_user_memory(
695
- self, memory_id: str, deserialize: Optional[bool] = True
716
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
696
717
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
697
718
  """Get a memory from Redis.
698
719
 
699
720
  Args:
700
721
  memory_id (str): The ID of the memory to get.
722
+ deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
723
+ user_id (Optional[str]): The ID of the user. If provided, only returns the memory if it belongs to this user.
701
724
 
702
725
  Returns:
703
726
  Optional[UserMemory]: The memory data if found, None otherwise.
@@ -707,6 +730,10 @@ class RedisDb(BaseDb):
707
730
  if memory_raw is None:
708
731
  return None
709
732
 
733
+ # Filter by user_id if provided
734
+ if user_id is not None and memory_raw.get("user_id") != user_id:
735
+ return None
736
+
710
737
  if not deserialize:
711
738
  return memory_raw
712
739
 
@@ -812,21 +839,21 @@ class RedisDb(BaseDb):
812
839
  # Group by user_id
813
840
  user_stats = {}
814
841
  for memory in all_memories:
815
- user_id = memory.get("user_id")
816
- if user_id is None:
842
+ memory_user_id = memory.get("user_id")
843
+ if memory_user_id is None:
817
844
  continue
818
845
 
819
- if user_id not in user_stats:
820
- user_stats[user_id] = {
821
- "user_id": user_id,
846
+ if memory_user_id not in user_stats:
847
+ user_stats[memory_user_id] = {
848
+ "user_id": memory_user_id,
822
849
  "total_memories": 0,
823
850
  "last_memory_updated_at": 0,
824
851
  }
825
852
 
826
- user_stats[user_id]["total_memories"] += 1
853
+ user_stats[memory_user_id]["total_memories"] += 1
827
854
  updated_at = memory.get("updated_at", 0)
828
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
829
- user_stats[user_id]["last_memory_updated_at"] = updated_at
855
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
856
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
830
857
 
831
858
  stats_list = list(user_stats.values())
832
859