agno 2.1.1__py3-none-any.whl → 2.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. agno/agent/agent.py +12 -0
  2. agno/db/base.py +8 -4
  3. agno/db/dynamo/dynamo.py +69 -17
  4. agno/db/firestore/firestore.py +65 -28
  5. agno/db/gcs_json/gcs_json_db.py +70 -17
  6. agno/db/in_memory/in_memory_db.py +85 -14
  7. agno/db/json/json_db.py +79 -15
  8. agno/db/mongo/mongo.py +27 -8
  9. agno/db/mysql/mysql.py +17 -3
  10. agno/db/postgres/postgres.py +21 -3
  11. agno/db/redis/redis.py +38 -11
  12. agno/db/singlestore/singlestore.py +14 -3
  13. agno/db/sqlite/sqlite.py +34 -46
  14. agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
  15. agno/knowledge/reader/pdf_reader.py +28 -52
  16. agno/knowledge/reader/reader_factory.py +12 -0
  17. agno/memory/manager.py +12 -4
  18. agno/models/anthropic/claude.py +4 -1
  19. agno/models/aws/bedrock.py +52 -112
  20. agno/models/openrouter/openrouter.py +39 -1
  21. agno/models/vertexai/__init__.py +0 -0
  22. agno/models/vertexai/claude.py +74 -0
  23. agno/os/app.py +76 -32
  24. agno/os/interfaces/a2a/__init__.py +3 -0
  25. agno/os/interfaces/a2a/a2a.py +42 -0
  26. agno/os/interfaces/a2a/router.py +252 -0
  27. agno/os/interfaces/a2a/utils.py +924 -0
  28. agno/os/interfaces/agui/router.py +12 -0
  29. agno/os/mcp.py +3 -3
  30. agno/os/router.py +38 -8
  31. agno/os/routers/memory/memory.py +5 -3
  32. agno/os/routers/memory/schemas.py +1 -0
  33. agno/os/utils.py +37 -10
  34. agno/team/team.py +12 -0
  35. agno/tools/file.py +4 -2
  36. agno/tools/mcp.py +46 -1
  37. agno/utils/merge_dict.py +22 -1
  38. agno/utils/streamlit.py +1 -1
  39. agno/workflow/parallel.py +90 -14
  40. agno/workflow/step.py +30 -27
  41. agno/workflow/workflow.py +12 -6
  42. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/METADATA +16 -14
  43. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/RECORD +46 -39
  44. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/WHEEL +0 -0
  45. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/licenses/LICENSE +0 -0
  46. {agno-2.1.1.dist-info → agno-2.1.3.dist-info}/top_level.txt +0 -0
agno/agent/agent.py CHANGED
@@ -6420,6 +6420,12 @@ class Agent:
6420
6420
  log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
6421
6421
  break
6422
6422
 
6423
+ if isinstance(reasoning_agent_response.content, str):
6424
+ log_warning(
6425
+ "Reasoning error. Content is a string, not structured output. Continuing regular session..."
6426
+ )
6427
+ break
6428
+
6423
6429
  if (
6424
6430
  reasoning_agent_response.content.reasoning_steps is None
6425
6431
  or len(reasoning_agent_response.content.reasoning_steps) == 0
@@ -6649,6 +6655,12 @@ class Agent:
6649
6655
  log_warning("Reasoning error. Reasoning response is empty, continuing regular session...")
6650
6656
  break
6651
6657
 
6658
+ if isinstance(reasoning_agent_response.content, str):
6659
+ log_warning(
6660
+ "Reasoning error. Content is a string, not structured output. Continuing regular session..."
6661
+ )
6662
+ break
6663
+
6652
6664
  if reasoning_agent_response.content.reasoning_steps is None:
6653
6665
  log_warning("Reasoning error. Reasoning steps are empty, continuing regular session...")
6654
6666
  break
agno/db/base.py CHANGED
@@ -95,20 +95,23 @@ class BaseDb(ABC):
95
95
  raise NotImplementedError
96
96
 
97
97
  @abstractmethod
98
- def delete_user_memory(self, memory_id: str) -> None:
98
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
99
99
  raise NotImplementedError
100
100
 
101
101
  @abstractmethod
102
- def delete_user_memories(self, memory_ids: List[str]) -> None:
102
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
103
103
  raise NotImplementedError
104
104
 
105
105
  @abstractmethod
106
- def get_all_memory_topics(self) -> List[str]:
106
+ def get_all_memory_topics(self, user_id: Optional[str] = None) -> List[str]:
107
107
  raise NotImplementedError
108
108
 
109
109
  @abstractmethod
110
110
  def get_user_memory(
111
- self, memory_id: str, deserialize: Optional[bool] = True
111
+ self,
112
+ memory_id: str,
113
+ deserialize: Optional[bool] = True,
114
+ user_id: Optional[str] = None,
112
115
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
113
116
  raise NotImplementedError
114
117
 
@@ -133,6 +136,7 @@ class BaseDb(ABC):
133
136
  self,
134
137
  limit: Optional[int] = None,
135
138
  page: Optional[int] = None,
139
+ user_id: Optional[str] = None,
136
140
  ) -> Tuple[List[Dict[str, Any]], int]:
137
141
  raise NotImplementedError
138
142
 
agno/db/dynamo/dynamo.py CHANGED
@@ -562,17 +562,31 @@ class DynamoDb(BaseDb):
562
562
 
563
563
  # --- User Memory ---
564
564
 
565
- def delete_user_memory(self, memory_id: str) -> None:
565
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
566
566
  """
567
567
  Delete a user memory from the database.
568
568
 
569
569
  Args:
570
570
  memory_id: The ID of the memory to delete.
571
+ user_id: The ID of the user (optional, for filtering).
571
572
 
572
573
  Raises:
573
574
  Exception: If any error occurs while deleting the user memory.
574
575
  """
575
576
  try:
577
+ # If user_id is provided, verify the memory belongs to the user before deleting
578
+ if user_id:
579
+ response = self.client.get_item(
580
+ TableName=self.memory_table_name,
581
+ Key={"memory_id": {"S": memory_id}},
582
+ )
583
+ item = response.get("Item")
584
+ if item:
585
+ memory_data = deserialize_from_dynamodb_item(item)
586
+ if memory_data.get("user_id") != user_id:
587
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
588
+ return
589
+
576
590
  self.client.delete_item(
577
591
  TableName=self.memory_table_name,
578
592
  Key={"memory_id": {"S": memory_id}},
@@ -583,18 +597,34 @@ class DynamoDb(BaseDb):
583
597
  log_error(f"Failed to delete user memory {memory_id}: {e}")
584
598
  raise e
585
599
 
586
- def delete_user_memories(self, memory_ids: List[str]) -> None:
600
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
587
601
  """
588
602
  Delete user memories from the database in batches.
589
603
 
590
604
  Args:
591
605
  memory_ids: List of memory IDs to delete
606
+ user_id: The ID of the user (optional, for filtering).
592
607
 
593
608
  Raises:
594
609
  Exception: If any error occurs while deleting the user memories.
595
610
  """
596
611
 
597
612
  try:
613
+ # If user_id is provided, filter memory_ids to only those belonging to the user
614
+ if user_id:
615
+ filtered_memory_ids = []
616
+ for memory_id in memory_ids:
617
+ response = self.client.get_item(
618
+ TableName=self.memory_table_name,
619
+ Key={"memory_id": {"S": memory_id}},
620
+ )
621
+ item = response.get("Item")
622
+ if item:
623
+ memory_data = deserialize_from_dynamodb_item(item)
624
+ if memory_data.get("user_id") == user_id:
625
+ filtered_memory_ids.append(memory_id)
626
+ memory_ids = filtered_memory_ids
627
+
598
628
  for i in range(0, len(memory_ids), DYNAMO_BATCH_SIZE_LIMIT):
599
629
  batch = memory_ids[i : i + DYNAMO_BATCH_SIZE_LIMIT]
600
630
 
@@ -611,6 +641,9 @@ class DynamoDb(BaseDb):
611
641
  def get_all_memory_topics(self) -> List[str]:
612
642
  """Get all memory topics from the database.
613
643
 
644
+ Args:
645
+ user_id: The ID of the user (optional, for filtering).
646
+
614
647
  Returns:
615
648
  List[str]: List of unique memory topics.
616
649
  """
@@ -619,13 +652,17 @@ class DynamoDb(BaseDb):
619
652
  if table_name is None:
620
653
  return []
621
654
 
622
- # Scan the entire table to get all memories
623
- response = self.client.scan(TableName=table_name)
655
+ # Build filter expression for user_id if provided
656
+ scan_kwargs = {"TableName": table_name}
657
+
658
+ # Scan the table to get memories
659
+ response = self.client.scan(**scan_kwargs)
624
660
  items = response.get("Items", [])
625
661
 
626
662
  # Handle pagination
627
663
  while "LastEvaluatedKey" in response:
628
- response = self.client.scan(TableName=table_name, ExclusiveStartKey=response["LastEvaluatedKey"])
664
+ scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
665
+ response = self.client.scan(**scan_kwargs)
629
666
  items.extend(response.get("Items", []))
630
667
 
631
668
  # Extract topics from all memories
@@ -642,13 +679,18 @@ class DynamoDb(BaseDb):
642
679
  raise e
643
680
 
644
681
  def get_user_memory(
645
- self, memory_id: str, deserialize: Optional[bool] = True
682
+ self,
683
+ memory_id: str,
684
+ deserialize: Optional[bool] = True,
685
+ user_id: Optional[str] = None,
646
686
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
647
687
  """
648
688
  Get a user memory from the database as a UserMemory object.
649
689
 
650
690
  Args:
651
691
  memory_id: The ID of the memory to get.
692
+ deserialize: Whether to deserialize the memory.
693
+ user_id: The ID of the user (optional, for filtering).
652
694
 
653
695
  Returns:
654
696
  Optional[UserMemory]: The user memory data if found, None otherwise.
@@ -665,6 +707,11 @@ class DynamoDb(BaseDb):
665
707
  return None
666
708
 
667
709
  item = deserialize_from_dynamodb_item(item)
710
+
711
+ # Filter by user_id if provided
712
+ if user_id and item.get("user_id") != user_id:
713
+ return None
714
+
668
715
  if not deserialize:
669
716
  return item
670
717
 
@@ -804,6 +851,7 @@ class DynamoDb(BaseDb):
804
851
  Args:
805
852
  limit (Optional[int]): The maximum number of user stats to return.
806
853
  page (Optional[int]): The page number.
854
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
807
855
 
808
856
  Returns:
809
857
  Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
@@ -823,29 +871,33 @@ class DynamoDb(BaseDb):
823
871
  try:
824
872
  table_name = self._get_table("memories")
825
873
 
826
- response = self.client.scan(TableName=table_name)
874
+ # Build filter expression for user_id if provided
875
+ scan_kwargs = {"TableName": table_name}
876
+
877
+ response = self.client.scan(**scan_kwargs)
827
878
  items = response.get("Items", [])
828
879
 
829
880
  # Handle pagination
830
881
  while "LastEvaluatedKey" in response:
831
- response = self.client.scan(TableName=table_name, ExclusiveStartKey=response["LastEvaluatedKey"])
882
+ scan_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
883
+ response = self.client.scan(**scan_kwargs)
832
884
  items.extend(response.get("Items", []))
833
885
 
834
886
  # Aggregate stats by user_id
835
887
  user_stats = {}
836
888
  for item in items:
837
889
  memory_data = deserialize_from_dynamodb_item(item)
838
- user_id = memory_data.get("user_id")
890
+ current_user_id = memory_data.get("user_id")
839
891
 
840
- if user_id:
841
- if user_id not in user_stats:
842
- user_stats[user_id] = {
843
- "user_id": user_id,
892
+ if current_user_id:
893
+ if current_user_id not in user_stats:
894
+ user_stats[current_user_id] = {
895
+ "user_id": current_user_id,
844
896
  "total_memories": 0,
845
897
  "last_memory_updated_at": None,
846
898
  }
847
899
 
848
- user_stats[user_id]["total_memories"] += 1
900
+ user_stats[current_user_id]["total_memories"] += 1
849
901
 
850
902
  updated_at = memory_data.get("updated_at")
851
903
  if updated_at:
@@ -853,10 +905,10 @@ class DynamoDb(BaseDb):
853
905
  updated_at_timestamp = int(updated_at_dt.timestamp())
854
906
 
855
907
  if updated_at_timestamp and (
856
- user_stats[user_id]["last_memory_updated_at"] is None
857
- or updated_at_timestamp > user_stats[user_id]["last_memory_updated_at"]
908
+ user_stats[current_user_id]["last_memory_updated_at"] is None
909
+ or updated_at_timestamp > user_stats[current_user_id]["last_memory_updated_at"]
858
910
  ):
859
- user_stats[user_id]["last_memory_updated_at"] = updated_at_timestamp
911
+ user_stats[current_user_id]["last_memory_updated_at"] = updated_at_timestamp
860
912
 
861
913
  # Convert to list and apply sorting
862
914
  stats_list = list(user_stats.values())
agno/db/firestore/firestore.py CHANGED
@@ -596,11 +596,12 @@ class FirestoreDb(BaseDb):
596
596
 
597
597
  # -- Memory methods --
598
598
 
599
- def delete_user_memory(self, memory_id: str):
599
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
600
600
  """Delete a user memory from the database.
601
601
 
602
602
  Args:
603
603
  memory_id (str): The ID of the memory to delete.
604
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
604
605
 
605
606
  Returns:
606
607
  bool: True if the memory was deleted, False otherwise.
@@ -610,28 +611,41 @@ class FirestoreDb(BaseDb):
610
611
  """
611
612
  try:
612
613
  collection_ref = self._get_collection(table_type="memories")
613
- docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
614
614
 
615
- deleted_count = 0
616
- for doc in docs:
617
- doc.reference.delete()
618
- deleted_count += 1
619
-
620
- success = deleted_count > 0
621
- if success:
622
- log_debug(f"Successfully deleted user memory id: {memory_id}")
615
+ # If user_id is provided, verify the memory belongs to the user before deleting
616
+ if user_id:
617
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
618
+ for doc in docs:
619
+ data = doc.to_dict()
620
+ if data.get("user_id") != user_id:
621
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
622
+ return
623
+ doc.reference.delete()
624
+ log_debug(f"Successfully deleted user memory id: {memory_id}")
625
+ return
623
626
  else:
624
- log_debug(f"No user memory found with id: {memory_id}")
627
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
628
+ deleted_count = 0
629
+ for doc in docs:
630
+ doc.reference.delete()
631
+ deleted_count += 1
632
+
633
+ success = deleted_count > 0
634
+ if success:
635
+ log_debug(f"Successfully deleted user memory id: {memory_id}")
636
+ else:
637
+ log_debug(f"No user memory found with id: {memory_id}")
625
638
 
626
639
  except Exception as e:
627
640
  log_error(f"Error deleting user memory: {e}")
628
641
  raise e
629
642
 
630
- def delete_user_memories(self, memory_ids: List[str]) -> None:
643
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
631
644
  """Delete user memories from the database.
632
645
 
633
646
  Args:
634
647
  memory_ids (List[str]): The IDs of the memories to delete.
648
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
635
649
 
636
650
  Raises:
637
651
  Exception: If there is an error deleting the memories.
@@ -641,11 +655,21 @@ class FirestoreDb(BaseDb):
641
655
  batch = self.db_client.batch()
642
656
  deleted_count = 0
643
657
 
644
- for memory_id in memory_ids:
645
- docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
646
- for doc in docs:
647
- batch.delete(doc.reference)
648
- deleted_count += 1
658
+ # If user_id is provided, filter memory_ids to only those belonging to the user
659
+ if user_id:
660
+ for memory_id in memory_ids:
661
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
662
+ for doc in docs:
663
+ data = doc.to_dict()
664
+ if data.get("user_id") == user_id:
665
+ batch.delete(doc.reference)
666
+ deleted_count += 1
667
+ else:
668
+ for memory_id in memory_ids:
669
+ docs = collection_ref.where(filter=FieldFilter("memory_id", "==", memory_id)).stream()
670
+ for doc in docs:
671
+ batch.delete(doc.reference)
672
+ deleted_count += 1
649
673
 
650
674
  batch.commit()
651
675
 
@@ -687,12 +711,15 @@ class FirestoreDb(BaseDb):
687
711
  log_error(f"Exception getting all memory topics: {e}")
688
712
  raise e
689
713
 
690
- def get_user_memory(self, memory_id: str, deserialize: Optional[bool] = True) -> Optional[UserMemory]:
714
+ def get_user_memory(
715
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
716
+ ) -> Optional[UserMemory]:
691
717
  """Get a memory from the database.
692
718
 
693
719
  Args:
694
720
  memory_id (str): The ID of the memory to get.
695
721
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
722
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
696
723
 
697
724
  Returns:
698
725
  Optional[UserMemory]:
@@ -711,7 +738,14 @@ class FirestoreDb(BaseDb):
711
738
  result = doc.to_dict()
712
739
  break
713
740
 
714
- if result is None or not deserialize:
741
+ if result is None:
742
+ return None
743
+
744
+ # Filter by user_id if provided
745
+ if user_id and result.get("user_id") != user_id:
746
+ return None
747
+
748
+ if not deserialize:
715
749
  return result
716
750
 
717
751
  return UserMemory.from_dict(result)
@@ -818,23 +852,26 @@ class FirestoreDb(BaseDb):
818
852
  """
819
853
  try:
820
854
  collection_ref = self._get_collection(table_type="memories")
821
- docs = collection_ref.where(filter=FieldFilter("user_id", "!=", None)).stream()
855
+
856
+ query = collection_ref.where(filter=FieldFilter("user_id", "!=", None))
857
+
858
+ docs = query.stream()
822
859
 
823
860
  user_stats = {}
824
861
  for doc in docs:
825
862
  data = doc.to_dict()
826
- user_id = data.get("user_id")
827
- if user_id:
828
- if user_id not in user_stats:
829
- user_stats[user_id] = {
830
- "user_id": user_id,
863
+ current_user_id = data.get("user_id")
864
+ if current_user_id:
865
+ if current_user_id not in user_stats:
866
+ user_stats[current_user_id] = {
867
+ "user_id": current_user_id,
831
868
  "total_memories": 0,
832
869
  "last_memory_updated_at": 0,
833
870
  }
834
- user_stats[user_id]["total_memories"] += 1
871
+ user_stats[current_user_id]["total_memories"] += 1
835
872
  updated_at = data.get("updated_at", 0)
836
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
837
- user_stats[user_id]["last_memory_updated_at"] = updated_at
873
+ if updated_at > user_stats[current_user_id]["last_memory_updated_at"]:
874
+ user_stats[current_user_id]["last_memory_updated_at"] = updated_at
838
875
 
839
876
  # Convert to list and sort
840
877
  formatted_results = list(user_stats.values())
agno/db/gcs_json/gcs_json_db.py CHANGED
@@ -459,12 +459,23 @@ class GcsJsonDb(BaseDb):
459
459
  return False
460
460
 
461
461
  # -- Memory methods --
462
- def delete_user_memory(self, memory_id: str) -> None:
463
- """Delete a user memory from the GCS JSON file."""
462
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None) -> None:
463
+ """Delete a user memory from the GCS JSON file.
464
+
465
+ Args:
466
+ memory_id (str): The ID of the memory to delete.
467
+ user_id (Optional[str]): The ID of the user. If provided, verifies ownership before deletion.
468
+ """
464
469
  try:
465
470
  memories = self._read_json_file(self.memory_table_name)
466
471
  original_count = len(memories)
467
- memories = [m for m in memories if m.get("memory_id") != memory_id]
472
+
473
+ # Filter out the memory, with optional user_id verification
474
+ memories = [
475
+ m
476
+ for m in memories
477
+ if not (m.get("memory_id") == memory_id and (user_id is None or m.get("user_id") == user_id))
478
+ ]
468
479
 
469
480
  if len(memories) < original_count:
470
481
  self._write_json_file(self.memory_table_name, memories)
@@ -477,11 +488,23 @@ class GcsJsonDb(BaseDb):
477
488
  log_warning(f"Error deleting user memory: {e}")
478
489
  raise e
479
490
 
480
- def delete_user_memories(self, memory_ids: List[str]) -> None:
481
- """Delete multiple user memories from the GCS JSON file."""
491
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
492
+ """Delete multiple user memories from the GCS JSON file.
493
+
494
+ Args:
495
+ memory_ids (List[str]): The IDs of the memories to delete.
496
+ user_id (Optional[str]): The ID of the user. If provided, verifies ownership before deletion.
497
+ """
482
498
  try:
483
499
  memories = self._read_json_file(self.memory_table_name)
484
- memories = [m for m in memories if m.get("memory_id") not in memory_ids]
500
+
501
+ # Filter out memories, with optional user_id verification
502
+ memories = [
503
+ m
504
+ for m in memories
505
+ if not (m.get("memory_id") in memory_ids and (user_id is None or m.get("user_id") == user_id))
506
+ ]
507
+
485
508
  self._write_json_file(self.memory_table_name, memories)
486
509
  log_debug(f"Successfully deleted user memories with ids: {memory_ids}")
487
510
  except Exception as e:
@@ -489,7 +512,11 @@ class GcsJsonDb(BaseDb):
489
512
  raise e
490
513
 
491
514
  def get_all_memory_topics(self) -> List[str]:
492
- """Get all memory topics from the GCS JSON file."""
515
+ """Get all memory topics from the GCS JSON file.
516
+
517
+ Returns:
518
+ List[str]: List of unique memory topics.
519
+ """
493
520
  try:
494
521
  memories = self._read_json_file(self.memory_table_name)
495
522
  topics = set()
@@ -504,14 +531,27 @@ class GcsJsonDb(BaseDb):
504
531
  raise e
505
532
 
506
533
  def get_user_memory(
507
- self, memory_id: str, deserialize: Optional[bool] = True
534
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
508
535
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
509
- """Get a memory from the GCS JSON file."""
536
+ """Get a memory from the GCS JSON file.
537
+
538
+ Args:
539
+ memory_id (str): The ID of the memory to retrieve.
540
+ deserialize (Optional[bool]): Whether to deserialize to UserMemory object. Defaults to True.
541
+ user_id (Optional[str]): The ID of the user. If provided, verifies ownership before returning.
542
+
543
+ Returns:
544
+ Optional[Union[UserMemory, Dict[str, Any]]]: The memory if found and ownership matches, None otherwise.
545
+ """
510
546
  try:
511
547
  memories = self._read_json_file(self.memory_table_name)
512
548
 
513
549
  for memory_data in memories:
514
550
  if memory_data.get("memory_id") == memory_id:
551
+ # Verify user ownership if user_id is provided
552
+ if user_id is not None and memory_data.get("user_id") != user_id:
553
+ continue
554
+
515
555
  if not deserialize:
516
556
  return memory_data
517
557
 
@@ -583,20 +623,33 @@ class GcsJsonDb(BaseDb):
583
623
  def get_user_memory_stats(
584
624
  self, limit: Optional[int] = None, page: Optional[int] = None
585
625
  ) -> Tuple[List[Dict[str, Any]], int]:
586
- """Get user memory statistics."""
626
+ """Get user memory statistics.
627
+
628
+ Args:
629
+ limit (Optional[int]): Maximum number of results to return.
630
+ page (Optional[int]): Page number for pagination.
631
+
632
+ Returns:
633
+ Tuple[List[Dict[str, Any]], int]: List of user memory statistics and total count.
634
+ """
587
635
  try:
588
636
  memories = self._read_json_file(self.memory_table_name)
589
637
  user_stats = {}
590
638
 
591
639
  for memory in memories:
592
- user_id = memory.get("user_id")
593
- if user_id:
594
- if user_id not in user_stats:
595
- user_stats[user_id] = {"user_id": user_id, "total_memories": 0, "last_memory_updated_at": 0}
596
- user_stats[user_id]["total_memories"] += 1
640
+ memory_user_id = memory.get("user_id")
641
+
642
+ if memory_user_id:
643
+ if memory_user_id not in user_stats:
644
+ user_stats[memory_user_id] = {
645
+ "user_id": memory_user_id,
646
+ "total_memories": 0,
647
+ "last_memory_updated_at": 0,
648
+ }
649
+ user_stats[memory_user_id]["total_memories"] += 1
597
650
  updated_at = memory.get("updated_at", 0)
598
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
599
- user_stats[user_id]["last_memory_updated_at"] = updated_at
651
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
652
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
600
653
 
601
654
  stats_list = list(user_stats.values())
602
655
  stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)