agno 2.1.0__py3-none-any.whl → 2.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. agno/agent/agent.py +13 -1
  2. agno/db/base.py +8 -4
  3. agno/db/dynamo/dynamo.py +69 -17
  4. agno/db/firestore/firestore.py +68 -29
  5. agno/db/gcs_json/gcs_json_db.py +68 -17
  6. agno/db/in_memory/in_memory_db.py +83 -14
  7. agno/db/json/json_db.py +79 -15
  8. agno/db/mongo/mongo.py +92 -74
  9. agno/db/mysql/mysql.py +17 -3
  10. agno/db/postgres/postgres.py +21 -3
  11. agno/db/redis/redis.py +38 -11
  12. agno/db/singlestore/singlestore.py +14 -3
  13. agno/db/sqlite/sqlite.py +34 -46
  14. agno/db/utils.py +50 -22
  15. agno/knowledge/knowledge.py +6 -0
  16. agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
  17. agno/knowledge/reader/pdf_reader.py +28 -52
  18. agno/knowledge/reader/reader_factory.py +12 -0
  19. agno/memory/manager.py +12 -4
  20. agno/models/anthropic/claude.py +4 -1
  21. agno/models/aws/bedrock.py +52 -112
  22. agno/models/openai/responses.py +1 -1
  23. agno/os/app.py +24 -30
  24. agno/os/interfaces/__init__.py +1 -0
  25. agno/os/interfaces/a2a/__init__.py +3 -0
  26. agno/os/interfaces/a2a/a2a.py +42 -0
  27. agno/os/interfaces/a2a/router.py +252 -0
  28. agno/os/interfaces/a2a/utils.py +924 -0
  29. agno/os/interfaces/agui/agui.py +21 -5
  30. agno/os/interfaces/agui/router.py +12 -0
  31. agno/os/interfaces/base.py +4 -2
  32. agno/os/interfaces/slack/slack.py +13 -8
  33. agno/os/interfaces/whatsapp/whatsapp.py +12 -5
  34. agno/os/mcp.py +1 -1
  35. agno/os/router.py +39 -9
  36. agno/os/routers/memory/memory.py +5 -3
  37. agno/os/routers/memory/schemas.py +1 -0
  38. agno/os/utils.py +36 -10
  39. agno/run/base.py +2 -13
  40. agno/team/team.py +13 -1
  41. agno/tools/mcp.py +46 -1
  42. agno/utils/merge_dict.py +22 -1
  43. agno/utils/serialize.py +32 -0
  44. agno/utils/streamlit.py +1 -1
  45. agno/workflow/parallel.py +90 -14
  46. agno/workflow/step.py +30 -27
  47. agno/workflow/types.py +4 -6
  48. agno/workflow/workflow.py +5 -3
  49. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/METADATA +16 -14
  50. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/RECORD +53 -47
  51. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/WHEEL +0 -0
  52. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/licenses/LICENSE +0 -0
  53. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/top_level.txt +0 -0
agno/agent/agent.py CHANGED
@@ -5244,7 +5244,7 @@ class Agent:
5244
5244
 
5245
5245
  # 3.2.5 Add information about agentic filters if enabled
5246
5246
  if self.knowledge is not None and self.enable_agentic_knowledge_filters:
5247
- valid_filters = getattr(self.knowledge, "valid_metadata_filters", None)
5247
+ valid_filters = self.knowledge.get_valid_filters()
5248
5248
  if valid_filters:
5249
5249
  valid_filters_str = ", ".join(valid_filters)
5250
5250
  additional_information.append(
@@ -6420,6 +6420,12 @@ class Agent:
6420
6420
  log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
6421
6421
  break
6422
6422
 
6423
+ if isinstance(reasoning_agent_response.content, str):
6424
+ log_warning(
6425
+ "Reasoning error. Content is a string, not structured output. Continuing regular session..."
6426
+ )
6427
+ break
6428
+
6423
6429
  if (
6424
6430
  reasoning_agent_response.content.reasoning_steps is None
6425
6431
  or len(reasoning_agent_response.content.reasoning_steps) == 0
@@ -6649,6 +6655,12 @@ class Agent:
6649
6655
  log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
6650
6656
  break
6651
6657
 
6658
+ if isinstance(reasoning_agent_response.content, str):
6659
+ log_warning(
6660
+ "Reasoning error. Content is a string, not structured output. Continuing regular session..."
6661
+ )
6662
+ break
6663
+
6652
6664
  if reasoning_agent_response.content.reasoning_steps is None:
6653
6665
  log_warning("Reasoning error. Reasoning steps are empty, continuing regular session...")
6654
6666
  break
agno/db/base.py CHANGED
@@ -95,20 +95,23 @@ class BaseDb(ABC):
95
95
  raise NotImplementedError
96
96
 
97
97
  @abstractmethod
98
- def delete_user_memory(self, memory_id: str) -> None:
98
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
99
99
  raise NotImplementedError
100
100
 
101
101
  @abstractmethod
102
- def delete_user_memories(self, memory_ids: List[str]) -> None:
102
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
103
103
  raise NotImplementedError
104
104
 
105
105
  @abstractmethod
106
- def get_all_memory_topics(self) -> List[str]:
106
+ def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
107
107
  raise NotImplementedError
108
108
 
109
109
  @abstractmethod
110
110
  def get_user_memory(
111
- self, memory_id: str, deserialize: Optional[bool] = True
111
+ self,
112
+ memory_id: str,
113
+ deserialize: Optional[bool] = True,
114
+ user_id: Optional[str] = None,
112
115
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
113
116
  raise NotImplementedError
114
117
 
@@ -133,6 +136,7 @@ class BaseDb(ABC):
133
136
  self,
134
137
  limit: Optional[int] = None,
135
138
  page: Optional[int] = None,
139
+ user_id: Optional[str] = None,
136
140
  ) -> Tuple[List[Dict[str, Any]], int]:
137
141
  raise NotImplementedError
138
142
 
agno/db/dynamo/dynamo.py CHANGED
@@ -562,17 +562,31 @@ class DynamoDb(BaseDb):
562
562
 
563
563
  # --- User Memory ---
564
564
 
565
- def delete_user_memory(self, memory_id: str) -> None:
565
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
566
566
  """
567
567
  Delete a user memory from the database.
568
568
 
569
569
  Args:
570
570
  memory_id: The ID of the memory to delete.
571
+ user_id: The ID of the user (optional, for filtering).
571
572
 
572
573
  Raises:
573
574
  Exception: If any error occurs while deleting the user memory.
574
575
  """
575
576
  try:
577
+ # If user_id is provided, verify the memory belongs to the user before deleting
578
+ if user_id:
579
+ response = self.client.get_item(
580
+ TableName=self.memory_table_name,
581
+ Key={"memory_id": {"S": memory_id}},
582
+ )
583
+ item = response.get("Item")
584
+ if item:
585
+ memory_data = deserialize_from_dynamodb_item(item)
586
+ if memory_data.get("user_id") != user_id:
587
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
588
+ return
589
+
576
590
  self.client.delete_item(
577
591
  TableName=self.memory_table_name,
578
592
  Key={"memory_id": {"S": memory_id}},
@@ -583,18 +597,34 @@ class DynamoDb(BaseDb):
583
597
  log_error(f"Failed to delete user memory {memory_id}: {e}")
584
598
  raise e
585
599
 
586
- def delete_user_memories(self, memory_ids: List[str]) -> None:
600
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
587
601
  """
588
602
  Delete user memories from the database in batches.
589
603
 
590
604
  Args:
591
605
  memory_ids: List of memory IDs to delete
606
+ user_id: The ID of the user (optional, for filtering).
592
607
 
593
608
  Raises:
594
609
  Exception: If any error occurs while deleting the user memories.
595
610
  """
596
611
 
597
612
  try:
613
+ # If user_id is provided, filter memory_ids to only those belonging to the user
614
+ if user_id:
615
+ filtered_memory_ids = []
616
+ for memory_id in memory_ids:
617
+ response = self.client.get_item(
618
+ TableName=self.memory_table_name,
619
+ Key={"memory_id": {"S": memory_id}},
620
+ )
621
+ item = response.get("Item")
622
+ if item:
623
+ memory_data = deserialize_from_dynamodb_item(item)
624
+ if memory_data.get("user_id") == user_id:
625
+ filtered_memory_ids.append(memory_id)
626
+ memory_ids = filtered_memory_ids
627
+
598
628
  for i in range(0, len(memory_ids), DYNAMO_BATCH_SIZE_LIMIT):
599
629
  batch = memory_ids[i : i + DYNAMO_BATCH_SIZE_LIMIT]
600
630
 
@@ -611,6 +641,9 @@ class DynamoDb(BaseDb):
611
641
  def get_all_memory_topics(self) -> List[str]:
612
642
  """Get all memory topics from the database.
613
643
 
644
+ Args:
645
+ user_id: The ID of the user (optional, for filtering).
646
+
614
647
  Returns:
615
648
  List[str]: List of unique memory topics.
616
649
  """
@@ -619,13 +652,17 @@ class DynamoDb(BaseDb):
619
652
  if table_name is None:
620
653
  return []
621
654
 
622
- # Scan the entire table to get all memories
623
- response = self.client.scan(TableName=table_name)
655
+ # Build filter expression for user_id if provided
656
+ scan_kwargs = {"TableName": table_name}
657
+
658
+ # Scan the table to get memories
659
+ response = self.client.scan(**scan_kwargs)
624
660
  items = response.get("Items", [])
625
661
 
626
662
  # Handle pagination
627
663
  while "LastEvaluatedKey" in response:
628
- response = self.client.scan(TableName=table_name, ExclusiveStartKey=response["LastEvaluatedKey"])
664
+ scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
665
+ response = self.client.scan(**scan_kwargs)
629
666
  items.extend(response.get("Items", []))
630
667
 
631
668
  # Extract topics from all memories
@@ -642,13 +679,18 @@ class DynamoDb(BaseDb):
642
679
  raise e
643
680
 
644
681
  def get_user_memory(
645
- self, memory_id: str, deserialize: Optional[bool] = True
682
+ self,
683
+ memory_id: str,
684
+ deserialize: Optional[bool] = True,
685
+ user_id: Optional[str] = None,
646
686
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
647
687
  """
648
688
  Get a user memory from the database as a UserMemory object.
649
689
 
650
690
  Args:
651
691
  memory_id: The ID of the memory to get.
692
+ deserialize: Whether to deserialize the memory.
693
+ user_id: The ID of the user (optional, for filtering).
652
694
 
653
695
  Returns:
654
696
  Optional[UserMemory]: The user memory data if found, None otherwise.
@@ -665,6 +707,11 @@ class DynamoDb(BaseDb):
665
707
  return None
666
708
 
667
709
  item = deserialize_from_dynamodb_item(item)
710
+
711
+ # Filter by user_id if provided
712
+ if user_id and item.get("user_id") != user_id:
713
+ return None
714
+
668
715
  if not deserialize:
669
716
  return item
670
717
 
@@ -804,6 +851,7 @@ class DynamoDb(BaseDb):
804
851
  Args:
805
852
  limit (Optional[int]): The maximum number of user stats to return.
806
853
  page (Optional[int]): The page number.
854
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
807
855
 
808
856
  Returns:
809
857
  Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
@@ -823,29 +871,33 @@ class DynamoDb(BaseDb):
823
871
  try:
824
872
  table_name = self._get_table("memories")
825
873
 
826
- response = self.client.scan(TableName=table_name)
874
+ # Build filter expression for user_id if provided
875
+ scan_kwargs = {"TableName": table_name}
876
+
877
+ response = self.client.scan(**scan_kwargs)
827
878
  items = response.get("Items", [])
828
879
 
829
880
  # Handle pagination
830
881
  while "LastEvaluatedKey" in response:
831
- response = self.client.scan(TableName=table_name, ExclusiveStartKey=response["LastEvaluatedKey"])
882
+ scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
883
+ response = self.client.scan(**scan_kwargs)
832
884
  items.extend(response.get("Items", []))
833
885
 
834
886
  # Aggregate stats by user_id
835
887
  user_stats = {}
836
888
  for item in items:
837
889
  memory_data = deserialize_from_dynamodb_item(item)
838
- user_id = memory_data.get("user_id")
890
+ current_user_id = memory_data.get("user_id")
839
891
 
840
- if user_id:
841
- if user_id not in user_stats:
842
- user_stats[user_id] = {
843
- "user_id": user_id,
892
+ if current_user_id:
893
+ if current_user_id not in user_stats:
894
+ user_stats[current_user_id] = {
895
+ "user_id": current_user_id,
844
896
  "total_memories": 0,
845
897
  "last_memory_updated_at": None,
846
898
  }
847
899
 
848
- user_stats[user_id]["total_memories"] += 1
900
+ user_stats[current_user_id]["total_memories"] += 1
849
901
 
850
902
  updated_at = memory_data.get("updated_at")
851
903
  if updated_at:
@@ -853,10 +905,10 @@ class DynamoDb(BaseDb):
853
905
  updated_at_timestamp = int(updated_at_dt.timestamp())
854
906
 
855
907
  if updated_at_timestamp and (
856
- user_stats[user_id]["last_memory_updated_at"] is None
857
- or updated_at_timestamp > user_stats[user_id]["last_memory_updated_at"]
908
+ user_stats[current_user_id]["last_memory_updated_at"] is None
909
+ or updated_at_timestamp > user_stats[current_user_id]["last_memory_updated_at"]
858
910
  ):
859
- user_stats[user_id]["last_memory_updated_at"] = updated_at_timestamp
911
+ user_stats[current_user_id]["last_memory_updated_at"] = updated_at_timestamp
860
912
 
861
913
  # Convert to list and apply sorting
862
914
  stats_list = list(user_stats.values())
agno/db/firestore/firestore.py CHANGED
@@ -596,11 +596,12 @@ class FirestoreDb(BaseDb):
596
596
 
597
597
  # -- Memory methods --
598
598
 
599
- def delete_user_memory(self, memory_id: str):
599
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
600
600
  """Delete a user memory from the database.
601
601
 
602
602
  Args:
603
603
  memory_id (str): The ID of the memory to delete.
604
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
604
605
 
605
606
  Returns:
606
607
  bool: True if the memory was deleted, False otherwise.
@@ -610,28 +611,41 @@ class FirestoreDb(BaseDb):
610
611
  """
611
612
  try:
612
613
  collection_ref = self._get_collection(table_type="memories")
613
- docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
614
614
 
615
- deleted_count = 0
616
- for doc in docs:
617
- doc.reference.delete()
618
- deleted_count += 1
619
-
620
- success = deleted_count > 0
621
- if success:
622
- log_debug(f"Successfully deleted user memory id: {memory_id}")
615
+ # If user_id is provided, verify the memory belongs to the user before deleting
616
+ if user_id:
617
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
618
+ for doc in docs:
619
+ data = doc.to_dict()
620
+ if data.get("user_id") != user_id:
621
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
622
+ return
623
+ doc.reference.delete()
624
+ log_debug(f"Successfully deleted user memory id: {memory_id}")
625
+ return
623
626
  else:
624
- log_debug(f"No user memory found with id: {memory_id}")
627
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
628
+ deleted_count = 0
629
+ for doc in docs:
630
+ doc.reference.delete()
631
+ deleted_count += 1
632
+
633
+ success = deleted_count > 0
634
+ if success:
635
+ log_debug(f"Successfully deleted user memory id: {memory_id}")
636
+ else:
637
+ log_debug(f"No user memory found with id: {memory_id}")
625
638
 
626
639
  except Exception as e:
627
640
  log_error(f"Error deleting user memory: {e}")
628
641
  raise e
629
642
 
630
- def delete_user_memories(self, memory_ids: List[str]) -> None:
643
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
631
644
  """Delete user memories from the database.
632
645
 
633
646
  Args:
634
647
  memory_ids (List[str]): The IDs of the memories to delete.
648
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
635
649
 
636
650
  Raises:
637
651
  Exception: If there is an error deleting the memories.
@@ -641,11 +655,21 @@ class FirestoreDb(BaseDb):
641
655
  batch = self.db_client.batch()
642
656
  deleted_count = 0
643
657
 
644
- for memory_id in memory_ids:
645
- docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
646
- for doc in docs:
647
- batch.delete(doc.reference)
648
- deleted_count += 1
658
+ # If user_id is provided, filter memory_ids to only those belonging to the user
659
+ if user_id:
660
+ for memory_id in memory_ids:
661
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
662
+ for doc in docs:
663
+ data = doc.to_dict()
664
+ if data.get("user_id") == user_id:
665
+ batch.delete(doc.reference)
666
+ deleted_count += 1
667
+ else:
668
+ for memory_id in memory_ids:
669
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
670
+ for doc in docs:
671
+ batch.delete(doc.reference)
672
+ deleted_count += 1
649
673
 
650
674
  batch.commit()
651
675
 
@@ -658,7 +682,9 @@ class FirestoreDb(BaseDb):
658
682
  log_error(f"Error deleting memories: {e}")
659
683
  raise e
660
684
 
661
- def get_all_memory_topics(self, create_collection_if_not_found: Optional[bool] = True) -> List[str]:
685
+ def get_all_memory_topics(
686
+ self, create_collection_if_not_found: Optional[bool] = True
687
+ ) -> List[str]:
662
688
  """Get all memory topics from the database.
663
689
 
664
690
  Returns:
@@ -687,12 +713,15 @@ class FirestoreDb(BaseDb):
687
713
  log_error(f"Exception getting all memory topics: {e}")
688
714
  raise e
689
715
 
690
- def get_user_memory(self, memory_id: str, deserialize: Optional[bool] = True) -> Optional[UserMemory]:
716
+ def get_user_memory(
717
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
718
+ ) -> Optional[UserMemory]:
691
719
  """Get a memory from the database.
692
720
 
693
721
  Args:
694
722
  memory_id (str): The ID of the memory to get.
695
723
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
724
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
696
725
 
697
726
  Returns:
698
727
  Optional[UserMemory]:
@@ -711,7 +740,14 @@ class FirestoreDb(BaseDb):
711
740
  result = doc.to_dict()
712
741
  break
713
742
 
714
- if result is None or not deserialize:
743
+ if result is None:
744
+ return None
745
+
746
+ # Filter by user_id if provided
747
+ if user_id and result.get("user_id") != user_id:
748
+ return None
749
+
750
+ if not deserialize:
715
751
  return result
716
752
 
717
753
  return UserMemory.from_dict(result)
@@ -818,23 +854,26 @@ class FirestoreDb(BaseDb):
818
854
  """
819
855
  try:
820
856
  collection_ref = self._get_collection(table_type="memories")
821
- docs = collection_ref.where(filter=FieldFilter("user_id", "!=", None)).stream()
857
+
858
+ query = collection_ref.where(filter=FieldFilter("user_id", "!=", None))
859
+
860
+ docs = query.stream()
822
861
 
823
862
  user_stats = {}
824
863
  for doc in docs:
825
864
  data = doc.to_dict()
826
- user_id = data.get("user_id")
827
- if user_id:
828
- if user_id not in user_stats:
829
- user_stats[user_id] = {
830
- "user_id": user_id,
865
+ current_user_id = data.get("user_id")
866
+ if current_user_id:
867
+ if current_user_id not in user_stats:
868
+ user_stats[current_user_id] = {
869
+ "user_id": current_user_id,
831
870
  "total_memories": 0,
832
871
  "last_memory_updated_at": 0,
833
872
  }
834
- user_stats[user_id]["total_memories"] += 1
873
+ user_stats[current_user_id]["total_memories"] += 1
835
874
  updated_at = data.get("updated_at", 0)
836
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
837
- user_stats[user_id]["last_memory_updated_at"] = updated_at
875
+ if updated_at > user_stats[current_user_id]["last_memory_updated_at"]:
876
+ user_stats[current_user_id]["last_memory_updated_at"] = updated_at
838
877
 
839
878
  # Convert to list and sort
840
879
  formatted_results = list(user_stats.values())
agno/db/gcs_json/gcs_json_db.py CHANGED
@@ -459,12 +459,22 @@ class GcsJsonDb(BaseDb):
459
459
  return False
460
460
 
461
461
  # -- Memory methods --
462
- def delete_user_memory(self, memory_id: str) -> None:
463
- """Delete a user memory from the GCS JSON file."""
462
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
463
+ """Delete a user memory from the GCS JSON file.
464
+
465
+ Args:
466
+ memory_id (str): The ID of the memory to delete.
467
+ user_id (Optional[str]): The ID of the user. If provided, verifies ownership before deletion.
468
+ """
464
469
  try:
465
470
  memories = self._read_json_file(self.memory_table_name)
466
471
  original_count = len(memories)
467
- memories = [m for m in memories if m.get("memory_id") != memory_id]
472
+
473
+ # Filter out the memory, with optional user_id verification
474
+ memories = [
475
+ m for m in memories
476
+ if not (m.get("memory_id") == memory_id and (user_id is None or m.get("user_id") == user_id))
477
+ ]
468
478
 
469
479
  if len(memories) < original_count:
470
480
  self._write_json_file(self.memory_table_name, memories)
@@ -477,11 +487,22 @@ class GcsJsonDb(BaseDb):
477
487
  log_warning(f"Error deleting user memory: {e}")
478
488
  raise e
479
489
 
480
- def delete_user_memories(self, memory_ids: List[str]) -> None:
481
- """Delete multiple user memories from the GCS JSON file."""
490
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
491
+ """Delete multiple user memories from the GCS JSON file.
492
+
493
+ Args:
494
+ memory_ids (List[str]): The IDs of the memories to delete.
495
+ user_id (Optional[str]): The ID of the user. If provided, verifies ownership before deletion.
496
+ """
482
497
  try:
483
498
  memories = self._read_json_file(self.memory_table_name)
484
- memories = [m for m in memories if m.get("memory_id") not in memory_ids]
499
+
500
+ # Filter out memories, with optional user_id verification
501
+ memories = [
502
+ m for m in memories
503
+ if not (m.get("memory_id") in memory_ids and (user_id is None or m.get("user_id") == user_id))
504
+ ]
505
+
485
506
  self._write_json_file(self.memory_table_name, memories)
486
507
  log_debug(f"Successfully deleted user memories with ids: {memory_ids}")
487
508
  except Exception as e:
@@ -489,7 +510,11 @@ class GcsJsonDb(BaseDb):
489
510
  raise e
490
511
 
491
512
  def get_all_memory_topics(self) -> List[str]:
492
- """Get all memory topics from the GCS JSON file."""
513
+ """Get all memory topics from the GCS JSON file.
514
+
515
+ Returns:
516
+ List[str]: List of unique memory topics.
517
+ """
493
518
  try:
494
519
  memories = self._read_json_file(self.memory_table_name)
495
520
  topics = set()
@@ -504,14 +529,27 @@ class GcsJsonDb(BaseDb):
504
529
  raise e
505
530
 
506
531
  def get_user_memory(
507
- self, memory_id: str, deserialize: Optional[bool] = True
532
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
508
533
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
509
- """Get a memory from the GCS JSON file."""
534
+ """Get a memory from the GCS JSON file.
535
+
536
+ Args:
537
+ memory_id (str): The ID of the memory to retrieve.
538
+ deserialize (Optional[bool]): Whether to deserialize to UserMemory object. Defaults to True.
539
+ user_id (Optional[str]): The ID of the user. If provided, verifies ownership before returning.
540
+
541
+ Returns:
542
+ Optional[Union[UserMemory, Dict[str, Any]]]: The memory if found and ownership matches, None otherwise.
543
+ """
510
544
  try:
511
545
  memories = self._read_json_file(self.memory_table_name)
512
546
 
513
547
  for memory_data in memories:
514
548
  if memory_data.get("memory_id") == memory_id:
549
+ # Verify user ownership if user_id is provided
550
+ if user_id is not None and memory_data.get("user_id") != user_id:
551
+ continue
552
+
515
553
  if not deserialize:
516
554
  return memory_data
517
555
 
@@ -583,20 +621,33 @@ class GcsJsonDb(BaseDb):
583
621
  def get_user_memory_stats(
584
622
  self, limit: Optional[int] = None, page: Optional[int] = None
585
623
  ) -> Tuple[List[Dict[str, Any]], int]:
586
- """Get user memory statistics."""
624
+ """Get user memory statistics.
625
+
626
+ Args:
627
+ limit (Optional[int]): Maximum number of results to return.
628
+ page (Optional[int]): Page number for pagination.
629
+
630
+ Returns:
631
+ Tuple[List[Dict[str, Any]], int]: List of user memory statistics and total count.
632
+ """
587
633
  try:
588
634
  memories = self._read_json_file(self.memory_table_name)
589
635
  user_stats = {}
590
636
 
591
637
  for memory in memories:
592
- user_id = memory.get("user_id")
593
- if user_id:
594
- if user_id not in user_stats:
595
- user_stats[user_id] = {"user_id": user_id, "total_memories": 0, "last_memory_updated_at": 0}
596
- user_stats[user_id]["total_memories"] += 1
638
+ memory_user_id = memory.get("user_id")
639
+
640
+ if memory_user_id:
641
+ if memory_user_id not in user_stats:
642
+ user_stats[memory_user_id] = {
643
+ "user_id": memory_user_id,
644
+ "total_memories": 0,
645
+ "last_memory_updated_at": 0
646
+ }
647
+ user_stats[memory_user_id]["total_memories"] += 1
597
648
  updated_at = memory.get("updated_at", 0)
598
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
599
- user_stats[user_id]["last_memory_updated_at"] = updated_at
649
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
650
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
600
651
 
601
652
  stats_list = list(user_stats.values())
602
653
  stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)