agno 2.1.0__py3-none-any.whl → 2.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53):
  1. agno/agent/agent.py +13 -1
  2. agno/db/base.py +8 -4
  3. agno/db/dynamo/dynamo.py +69 -17
  4. agno/db/firestore/firestore.py +68 -29
  5. agno/db/gcs_json/gcs_json_db.py +68 -17
  6. agno/db/in_memory/in_memory_db.py +83 -14
  7. agno/db/json/json_db.py +79 -15
  8. agno/db/mongo/mongo.py +92 -74
  9. agno/db/mysql/mysql.py +17 -3
  10. agno/db/postgres/postgres.py +21 -3
  11. agno/db/redis/redis.py +38 -11
  12. agno/db/singlestore/singlestore.py +14 -3
  13. agno/db/sqlite/sqlite.py +34 -46
  14. agno/db/utils.py +50 -22
  15. agno/knowledge/knowledge.py +6 -0
  16. agno/knowledge/reader/field_labeled_csv_reader.py +294 -0
  17. agno/knowledge/reader/pdf_reader.py +28 -52
  18. agno/knowledge/reader/reader_factory.py +12 -0
  19. agno/memory/manager.py +12 -4
  20. agno/models/anthropic/claude.py +4 -1
  21. agno/models/aws/bedrock.py +52 -112
  22. agno/models/openai/responses.py +1 -1
  23. agno/os/app.py +24 -30
  24. agno/os/interfaces/__init__.py +1 -0
  25. agno/os/interfaces/a2a/__init__.py +3 -0
  26. agno/os/interfaces/a2a/a2a.py +42 -0
  27. agno/os/interfaces/a2a/router.py +252 -0
  28. agno/os/interfaces/a2a/utils.py +924 -0
  29. agno/os/interfaces/agui/agui.py +21 -5
  30. agno/os/interfaces/agui/router.py +12 -0
  31. agno/os/interfaces/base.py +4 -2
  32. agno/os/interfaces/slack/slack.py +13 -8
  33. agno/os/interfaces/whatsapp/whatsapp.py +12 -5
  34. agno/os/mcp.py +1 -1
  35. agno/os/router.py +39 -9
  36. agno/os/routers/memory/memory.py +5 -3
  37. agno/os/routers/memory/schemas.py +1 -0
  38. agno/os/utils.py +36 -10
  39. agno/run/base.py +2 -13
  40. agno/team/team.py +13 -1
  41. agno/tools/mcp.py +46 -1
  42. agno/utils/merge_dict.py +22 -1
  43. agno/utils/serialize.py +32 -0
  44. agno/utils/streamlit.py +1 -1
  45. agno/workflow/parallel.py +90 -14
  46. agno/workflow/step.py +30 -27
  47. agno/workflow/types.py +4 -6
  48. agno/workflow/workflow.py +5 -3
  49. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/METADATA +16 -14
  50. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/RECORD +53 -47
  51. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/WHEEL +0 -0
  52. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/licenses/LICENSE +0 -0
  53. {agno-2.1.0.dist-info → agno-2.1.2.dist-info}/top_level.txt +0 -0
@@ -343,10 +343,27 @@ class InMemoryDb(BaseDb):
343
343
  return []
344
344
 
345
345
  # -- Memory methods --
346
- def delete_user_memory(self, memory_id: str):
346
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
347
+ """Delete a user memory from in-memory storage.
348
+
349
+ Args:
350
+ memory_id (str): The ID of the memory to delete.
351
+ user_id (Optional[str]): The ID of the user. If provided, verifies the memory belongs to this user before deletion.
352
+
353
+ Raises:
354
+ Exception: If an error occurs during deletion.
355
+ """
347
356
  try:
348
357
  original_count = len(self._memories)
349
- self._memories = [m for m in self._memories if m.get("memory_id") != memory_id]
358
+
359
+ # If user_id is provided, verify ownership before deleting
360
+ if user_id is not None:
361
+ self._memories = [
362
+ m for m in self._memories
363
+ if not (m.get("memory_id") == memory_id and m.get("user_id") == user_id)
364
+ ]
365
+ else:
366
+ self._memories = [m for m in self._memories if m.get("memory_id") != memory_id]
350
367
 
351
368
  if len(self._memories) < original_count:
352
369
  log_debug(f"Successfully deleted user memory id: {memory_id}")
@@ -357,10 +374,25 @@ class InMemoryDb(BaseDb):
357
374
  log_error(f"Error deleting memory: {e}")
358
375
  raise e
359
376
 
360
- def delete_user_memories(self, memory_ids: List[str]) -> None:
361
- """Delete multiple user memories from in-memory storage."""
377
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
378
+ """Delete multiple user memories from in-memory storage.
379
+
380
+ Args:
381
+ memory_ids (List[str]): The IDs of the memories to delete.
382
+ user_id (Optional[str]): The ID of the user. If provided, only deletes memories belonging to this user.
383
+
384
+ Raises:
385
+ Exception: If an error occurs during deletion.
386
+ """
362
387
  try:
363
- self._memories = [m for m in self._memories if m.get("memory_id") not in memory_ids]
388
+ # If user_id is provided, verify ownership before deleting
389
+ if user_id is not None:
390
+ self._memories = [
391
+ m for m in self._memories
392
+ if not (m.get("memory_id") in memory_ids and m.get("user_id") == user_id)
393
+ ]
394
+ else:
395
+ self._memories = [m for m in self._memories if m.get("memory_id") not in memory_ids]
364
396
  log_debug(f"Successfully deleted {len(memory_ids)} user memories")
365
397
 
366
398
  except Exception as e:
@@ -368,6 +400,14 @@ class InMemoryDb(BaseDb):
368
400
  raise e
369
401
 
370
402
  def get_all_memory_topics(self) -> List[str]:
403
+ """Get all memory topics from in-memory storage.
404
+
405
+ Returns:
406
+ List[str]: List of unique topics.
407
+
408
+ Raises:
409
+ Exception: If an error occurs while reading topics.
410
+ """
371
411
  try:
372
412
  topics = set()
373
413
  for memory in self._memories:
@@ -381,11 +421,28 @@ class InMemoryDb(BaseDb):
381
421
  raise e
382
422
 
383
423
  def get_user_memory(
384
- self, memory_id: str, deserialize: Optional[bool] = True
424
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
385
425
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
426
+ """Get a user memory from in-memory storage.
427
+
428
+ Args:
429
+ memory_id (str): The ID of the memory to retrieve.
430
+ deserialize (Optional[bool]): Whether to deserialize the memory. Defaults to True.
431
+ user_id (Optional[str]): The ID of the user. If provided, only returns the memory if it belongs to this user.
432
+
433
+ Returns:
434
+ Optional[Union[UserMemory, Dict[str, Any]]]: The memory object or dictionary, or None if not found.
435
+
436
+ Raises:
437
+ Exception: If an error occurs while reading the memory.
438
+ """
386
439
  try:
387
440
  for memory_data in self._memories:
388
441
  if memory_data.get("memory_id") == memory_id:
442
+ # Filter by user_id if provided
443
+ if user_id is not None and memory_data.get("user_id") != user_id:
444
+ continue
445
+
389
446
  memory_data_copy = deepcopy(memory_data)
390
447
  if not deserialize:
391
448
  return memory_data_copy
@@ -455,19 +512,31 @@ class InMemoryDb(BaseDb):
455
512
  def get_user_memory_stats(
456
513
  self, limit: Optional[int] = None, page: Optional[int] = None
457
514
  ) -> Tuple[List[Dict[str, Any]], int]:
458
- """Get user memory statistics."""
515
+ """Get user memory statistics.
516
+
517
+ Args:
518
+ limit (Optional[int]): Maximum number of stats to return.
519
+ page (Optional[int]): Page number for pagination.
520
+
521
+ Returns:
522
+ Tuple[List[Dict[str, Any]], int]: List of user memory statistics and total count.
523
+
524
+ Raises:
525
+ Exception: If an error occurs while getting stats.
526
+ """
459
527
  try:
460
528
  user_stats = {}
461
529
 
462
530
  for memory in self._memories:
463
- user_id = memory.get("user_id")
464
- if user_id:
465
- if user_id not in user_stats:
466
- user_stats[user_id] = {"user_id": user_id, "total_memories": 0, "last_memory_updated_at": 0}
467
- user_stats[user_id]["total_memories"] += 1
531
+ memory_user_id = memory.get("user_id")
532
+
533
+ if memory_user_id:
534
+ if memory_user_id not in user_stats:
535
+ user_stats[memory_user_id] = {"user_id": memory_user_id, "total_memories": 0, "last_memory_updated_at": 0}
536
+ user_stats[memory_user_id]["total_memories"] += 1
468
537
  updated_at = memory.get("updated_at", 0)
469
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
470
- user_stats[user_id]["last_memory_updated_at"] = updated_at
538
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
539
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
471
540
 
472
541
  stats_list = list(user_stats.values())
473
542
  stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)
agno/db/json/json_db.py CHANGED
@@ -445,11 +445,29 @@ class JsonDb(BaseDb):
445
445
  return False
446
446
 
447
447
  # -- Memory methods --
448
- def delete_user_memory(self, memory_id: str):
449
- """Delete a user memory from the JSON file."""
448
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
449
+ """Delete a user memory from the JSON file.
450
+
451
+ Args:
452
+ memory_id (str): The ID of the memory to delete.
453
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
454
+ """
450
455
  try:
451
456
  memories = self._read_json_file(self.memory_table_name)
452
457
  original_count = len(memories)
458
+
459
+ # If user_id is provided, verify the memory belongs to the user before deleting
460
+ if user_id:
461
+ memory_to_delete = None
462
+ for m in memories:
463
+ if m.get("memory_id") == memory_id:
464
+ memory_to_delete = m
465
+ break
466
+
467
+ if memory_to_delete and memory_to_delete.get("user_id") != user_id:
468
+ log_debug(f"Memory {memory_id} does not belong to user {user_id}")
469
+ return
470
+
453
471
  memories = [m for m in memories if m.get("memory_id") != memory_id]
454
472
 
455
473
  if len(memories) < original_count:
@@ -462,10 +480,24 @@ class JsonDb(BaseDb):
462
480
  log_error(f"Error deleting memory: {e}")
463
481
  raise e
464
482
 
465
- def delete_user_memories(self, memory_ids: List[str]) -> None:
466
- """Delete multiple user memories from the JSON file."""
483
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
484
+ """Delete multiple user memories from the JSON file.
485
+
486
+ Args:
487
+ memory_ids (List[str]): List of memory IDs to delete.
488
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
489
+ """
467
490
  try:
468
491
  memories = self._read_json_file(self.memory_table_name)
492
+
493
+ # If user_id is provided, filter memory_ids to only those belonging to the user
494
+ if user_id:
495
+ filtered_memory_ids: List[str] = []
496
+ for memory in memories:
497
+ if memory.get("memory_id") in memory_ids and memory.get("user_id") == user_id:
498
+ filtered_memory_ids.append(memory.get("memory_id")) # type: ignore
499
+ memory_ids = filtered_memory_ids
500
+
469
501
  memories = [m for m in memories if m.get("memory_id") not in memory_ids]
470
502
  self._write_json_file(self.memory_table_name, memories)
471
503
 
@@ -476,7 +508,11 @@ class JsonDb(BaseDb):
476
508
  raise e
477
509
 
478
510
  def get_all_memory_topics(self) -> List[str]:
479
- """Get all memory topics from the JSON file."""
511
+ """Get all memory topics from the JSON file.
512
+
513
+ Returns:
514
+ List[str]: List of unique memory topics.
515
+ """
480
516
  try:
481
517
  memories = self._read_json_file(self.memory_table_name)
482
518
 
@@ -492,14 +528,30 @@ class JsonDb(BaseDb):
492
528
  raise e
493
529
 
494
530
  def get_user_memory(
495
- self, memory_id: str, deserialize: Optional[bool] = True
531
+ self,
532
+ memory_id: str,
533
+ deserialize: Optional[bool] = True,
534
+ user_id: Optional[str] = None,
496
535
  ) -> Optional[Union[UserMemory, Dict[str, Any]]]:
497
- """Get a memory from the JSON file."""
536
+ """Get a memory from the JSON file.
537
+
538
+ Args:
539
+ memory_id (str): The ID of the memory to get.
540
+ deserialize (Optional[bool]): Whether to deserialize the memory.
541
+ user_id (Optional[str]): The ID of the user (optional, for filtering).
542
+
543
+ Returns:
544
+ Optional[Union[UserMemory, Dict[str, Any]]]: The user memory data if found, None otherwise.
545
+ """
498
546
  try:
499
547
  memories = self._read_json_file(self.memory_table_name)
500
548
 
501
549
  for memory_data in memories:
502
550
  if memory_data.get("memory_id") == memory_id:
551
+ # Filter by user_id if provided
552
+ if user_id and memory_data.get("user_id") != user_id:
553
+ return None
554
+
503
555
  if not deserialize:
504
556
  return memory_data
505
557
  return UserMemory.from_dict(memory_data)
@@ -571,20 +623,32 @@ class JsonDb(BaseDb):
571
623
  def get_user_memory_stats(
572
624
  self, limit: Optional[int] = None, page: Optional[int] = None
573
625
  ) -> Tuple[List[Dict[str, Any]], int]:
574
- """Get user memory statistics."""
626
+ """Get user memory statistics.
627
+
628
+ Args:
629
+ limit (Optional[int]): The maximum number of user stats to return.
630
+ page (Optional[int]): The page number.
631
+
632
+ Returns:
633
+ Tuple[List[Dict[str, Any]], int]: A list of dictionaries containing user stats and total count.
634
+ """
575
635
  try:
576
636
  memories = self._read_json_file(self.memory_table_name)
577
637
  user_stats = {}
578
638
 
579
639
  for memory in memories:
580
- user_id = memory.get("user_id")
581
- if user_id:
582
- if user_id not in user_stats:
583
- user_stats[user_id] = {"user_id": user_id, "total_memories": 0, "last_memory_updated_at": 0}
584
- user_stats[user_id]["total_memories"] += 1
640
+ memory_user_id = memory.get("user_id")
641
+ if memory_user_id:
642
+ if memory_user_id not in user_stats:
643
+ user_stats[memory_user_id] = {
644
+ "user_id": memory_user_id,
645
+ "total_memories": 0,
646
+ "last_memory_updated_at": 0,
647
+ }
648
+ user_stats[memory_user_id]["total_memories"] += 1
585
649
  updated_at = memory.get("updated_at", 0)
586
- if updated_at > user_stats[user_id]["last_memory_updated_at"]:
587
- user_stats[user_id]["last_memory_updated_at"] = updated_at
650
+ if updated_at > user_stats[memory_user_id]["last_memory_updated_at"]:
651
+ user_stats[memory_user_id]["last_memory_updated_at"] = updated_at
588
652
 
589
653
  stats_list = list(user_stats.values())
590
654
  stats_list.sort(key=lambda x: x["last_memory_updated_at"], reverse=True)
agno/db/mongo/mongo.py CHANGED
@@ -16,7 +16,7 @@ from agno.db.mongo.utils import (
16
16
  from agno.db.schemas.evals import EvalFilterType, EvalRunRecord, EvalType
17
17
  from agno.db.schemas.knowledge import KnowledgeRow
18
18
  from agno.db.schemas.memory import UserMemory
19
- from agno.db.utils import deserialize_session_json_fields, serialize_session_json_fields
19
+ from agno.db.utils import deserialize_session_json_fields
20
20
  from agno.session import AgentSession, Session, TeamSession, WorkflowSession
21
21
  from agno.utils.log import log_debug, log_error, log_info
22
22
  from agno.utils.string import generate_id
@@ -282,7 +282,6 @@ class MongoDb(BaseDb):
282
282
  return None
283
283
 
284
284
  session = deserialize_session_json_fields(result)
285
-
286
285
  if not deserialize:
287
286
  return session
288
287
 
@@ -385,7 +384,6 @@ class MongoDb(BaseDb):
385
384
  records = list(cursor)
386
385
  if records is None:
387
386
  return [] if deserialize else ([], 0)
388
-
389
387
  sessions_raw = [deserialize_session_json_fields(record) for record in records]
390
388
 
391
389
  if not deserialize:
@@ -489,25 +487,25 @@ class MongoDb(BaseDb):
489
487
  if collection is None:
490
488
  return None
491
489
 
492
- serialized_session_dict = serialize_session_json_fields(session.to_dict())
490
+ session_dict = session.to_dict()
493
491
 
494
492
  if isinstance(session, AgentSession):
495
493
  record = {
496
- "session_id": serialized_session_dict.get("session_id"),
494
+ "session_id": session_dict.get("session_id"),
497
495
  "session_type": SessionType.AGENT.value,
498
- "agent_id": serialized_session_dict.get("agent_id"),
499
- "user_id": serialized_session_dict.get("user_id"),
500
- "runs": serialized_session_dict.get("runs"),
501
- "agent_data": serialized_session_dict.get("agent_data"),
502
- "session_data": serialized_session_dict.get("session_data"),
503
- "summary": serialized_session_dict.get("summary"),
504
- "metadata": serialized_session_dict.get("metadata"),
505
- "created_at": serialized_session_dict.get("created_at"),
496
+ "agent_id": session_dict.get("agent_id"),
497
+ "user_id": session_dict.get("user_id"),
498
+ "runs": session_dict.get("runs"),
499
+ "agent_data": session_dict.get("agent_data"),
500
+ "session_data": session_dict.get("session_data"),
501
+ "summary": session_dict.get("summary"),
502
+ "metadata": session_dict.get("metadata"),
503
+ "created_at": session_dict.get("created_at"),
506
504
  "updated_at": int(time.time()),
507
505
  }
508
506
 
509
507
  result = collection.find_one_and_replace(
510
- filter={"session_id": serialized_session_dict.get("session_id")},
508
+ filter={"session_id": session_dict.get("session_id")},
511
509
  replacement=record,
512
510
  upsert=True,
513
511
  return_document=ReturnDocument.AFTER,
@@ -515,7 +513,7 @@ class MongoDb(BaseDb):
515
513
  if not result:
516
514
  return None
517
515
 
518
- session = deserialize_session_json_fields(result) # type: ignore
516
+ session = result # type: ignore
519
517
 
520
518
  if not deserialize:
521
519
  return session
@@ -524,21 +522,21 @@ class MongoDb(BaseDb):
524
522
 
525
523
  elif isinstance(session, TeamSession):
526
524
  record = {
527
- "session_id": serialized_session_dict.get("session_id"),
525
+ "session_id": session_dict.get("session_id"),
528
526
  "session_type": SessionType.TEAM.value,
529
- "team_id": serialized_session_dict.get("team_id"),
530
- "user_id": serialized_session_dict.get("user_id"),
531
- "runs": serialized_session_dict.get("runs"),
532
- "team_data": serialized_session_dict.get("team_data"),
533
- "session_data": serialized_session_dict.get("session_data"),
534
- "summary": serialized_session_dict.get("summary"),
535
- "metadata": serialized_session_dict.get("metadata"),
536
- "created_at": serialized_session_dict.get("created_at"),
527
+ "team_id": session_dict.get("team_id"),
528
+ "user_id": session_dict.get("user_id"),
529
+ "runs": session_dict.get("runs"),
530
+ "team_data": session_dict.get("team_data"),
531
+ "session_data": session_dict.get("session_data"),
532
+ "summary": session_dict.get("summary"),
533
+ "metadata": session_dict.get("metadata"),
534
+ "created_at": session_dict.get("created_at"),
537
535
  "updated_at": int(time.time()),
538
536
  }
539
537
 
540
538
  result = collection.find_one_and_replace(
541
- filter={"session_id": serialized_session_dict.get("session_id")},
539
+ filter={"session_id": session_dict.get("session_id")},
542
540
  replacement=record,
543
541
  upsert=True,
544
542
  return_document=ReturnDocument.AFTER,
@@ -546,7 +544,8 @@ class MongoDb(BaseDb):
546
544
  if not result:
547
545
  return None
548
546
 
549
- session = deserialize_session_json_fields(result) # type: ignore
547
+ # MongoDB stores native objects, no deserialization needed for document fields
548
+ session = result # type: ignore
550
549
 
551
550
  if not deserialize:
552
551
  return session
@@ -555,21 +554,21 @@ class MongoDb(BaseDb):
555
554
 
556
555
  else:
557
556
  record = {
558
- "session_id": serialized_session_dict.get("session_id"),
557
+ "session_id": session_dict.get("session_id"),
559
558
  "session_type": SessionType.WORKFLOW.value,
560
- "workflow_id": serialized_session_dict.get("workflow_id"),
561
- "user_id": serialized_session_dict.get("user_id"),
562
- "runs": serialized_session_dict.get("runs"),
563
- "workflow_data": serialized_session_dict.get("workflow_data"),
564
- "session_data": serialized_session_dict.get("session_data"),
565
- "summary": serialized_session_dict.get("summary"),
566
- "metadata": serialized_session_dict.get("metadata"),
567
- "created_at": serialized_session_dict.get("created_at"),
559
+ "workflow_id": session_dict.get("workflow_id"),
560
+ "user_id": session_dict.get("user_id"),
561
+ "runs": session_dict.get("runs"),
562
+ "workflow_data": session_dict.get("workflow_data"),
563
+ "session_data": session_dict.get("session_data"),
564
+ "summary": session_dict.get("summary"),
565
+ "metadata": session_dict.get("metadata"),
566
+ "created_at": session_dict.get("created_at"),
568
567
  "updated_at": int(time.time()),
569
568
  }
570
569
 
571
570
  result = collection.find_one_and_replace(
572
- filter={"session_id": serialized_session_dict.get("session_id")},
571
+ filter={"session_id": session_dict.get("session_id")},
573
572
  replacement=record,
574
573
  upsert=True,
575
574
  return_document=ReturnDocument.AFTER,
@@ -577,7 +576,7 @@ class MongoDb(BaseDb):
577
576
  if not result:
578
577
  return None
579
578
 
580
- session = deserialize_session_json_fields(result) # type: ignore
579
+ session = result # type: ignore
581
580
 
582
581
  if not deserialize:
583
582
  return session
@@ -628,48 +627,48 @@ class MongoDb(BaseDb):
628
627
  if session is None:
629
628
  continue
630
629
 
631
- serialized_session_dict = serialize_session_json_fields(session.to_dict())
630
+ session_dict = session.to_dict()
632
631
 
633
632
  if isinstance(session, AgentSession):
634
633
  record = {
635
- "session_id": serialized_session_dict.get("session_id"),
634
+ "session_id": session_dict.get("session_id"),
636
635
  "session_type": SessionType.AGENT.value,
637
- "agent_id": serialized_session_dict.get("agent_id"),
638
- "user_id": serialized_session_dict.get("user_id"),
639
- "runs": serialized_session_dict.get("runs"),
640
- "agent_data": serialized_session_dict.get("agent_data"),
641
- "session_data": serialized_session_dict.get("session_data"),
642
- "summary": serialized_session_dict.get("summary"),
643
- "metadata": serialized_session_dict.get("metadata"),
644
- "created_at": serialized_session_dict.get("created_at"),
636
+ "agent_id": session_dict.get("agent_id"),
637
+ "user_id": session_dict.get("user_id"),
638
+ "runs": session_dict.get("runs"),
639
+ "agent_data": session_dict.get("agent_data"),
640
+ "session_data": session_dict.get("session_data"),
641
+ "summary": session_dict.get("summary"),
642
+ "metadata": session_dict.get("metadata"),
643
+ "created_at": session_dict.get("created_at"),
645
644
  "updated_at": int(time.time()),
646
645
  }
647
646
  elif isinstance(session, TeamSession):
648
647
  record = {
649
- "session_id": serialized_session_dict.get("session_id"),
648
+ "session_id": session_dict.get("session_id"),
650
649
  "session_type": SessionType.TEAM.value,
651
- "team_id": serialized_session_dict.get("team_id"),
652
- "user_id": serialized_session_dict.get("user_id"),
653
- "runs": serialized_session_dict.get("runs"),
654
- "team_data": serialized_session_dict.get("team_data"),
655
- "session_data": serialized_session_dict.get("session_data"),
656
- "summary": serialized_session_dict.get("summary"),
657
- "metadata": serialized_session_dict.get("metadata"),
658
- "created_at": serialized_session_dict.get("created_at"),
650
+ "team_id": session_dict.get("team_id"),
651
+ "user_id": session_dict.get("user_id"),
652
+ "runs": session_dict.get("runs"),
653
+ "team_data": session_dict.get("team_data"),
654
+ "session_data": session_dict.get("session_data"),
655
+ "summary": session_dict.get("summary"),
656
+ "metadata": session_dict.get("metadata"),
657
+ "created_at": session_dict.get("created_at"),
659
658
  "updated_at": int(time.time()),
660
659
  }
661
660
  elif isinstance(session, WorkflowSession):
662
661
  record = {
663
- "session_id": serialized_session_dict.get("session_id"),
662
+ "session_id": session_dict.get("session_id"),
664
663
  "session_type": SessionType.WORKFLOW.value,
665
- "workflow_id": serialized_session_dict.get("workflow_id"),
666
- "user_id": serialized_session_dict.get("user_id"),
667
- "runs": serialized_session_dict.get("runs"),
668
- "workflow_data": serialized_session_dict.get("workflow_data"),
669
- "session_data": serialized_session_dict.get("session_data"),
670
- "summary": serialized_session_dict.get("summary"),
671
- "metadata": serialized_session_dict.get("metadata"),
672
- "created_at": serialized_session_dict.get("created_at"),
664
+ "workflow_id": session_dict.get("workflow_id"),
665
+ "user_id": session_dict.get("user_id"),
666
+ "runs": session_dict.get("runs"),
667
+ "workflow_data": session_dict.get("workflow_data"),
668
+ "session_data": session_dict.get("session_data"),
669
+ "summary": session_dict.get("summary"),
670
+ "metadata": session_dict.get("metadata"),
671
+ "created_at": session_dict.get("created_at"),
673
672
  "updated_at": int(time.time()),
674
673
  }
675
674
  else:
@@ -688,7 +687,7 @@ class MongoDb(BaseDb):
688
687
  cursor = collection.find({"session_id": {"$in": session_ids}})
689
688
 
690
689
  for doc in cursor:
691
- session_dict = deserialize_session_json_fields(doc)
690
+ session_dict = doc
692
691
 
693
692
  if deserialize:
694
693
  session_type = doc.get("session_type")
@@ -728,11 +727,12 @@ class MongoDb(BaseDb):
728
727
 
729
728
  # -- Memory methods --
730
729
 
731
- def delete_user_memory(self, memory_id: str):
730
+ def delete_user_memory(self, memory_id: str, user_id: Optional[str] = None):
732
731
  """Delete a user memory from the database.
733
732
 
734
733
  Args:
735
734
  memory_id (str): The ID of the memory to delete.
735
+ user_id (Optional[str]): The ID of the user to verify ownership. If provided, only delete if the memory belongs to this user.
736
736
 
737
737
  Returns:
738
738
  bool: True if the memory was deleted, False otherwise.
@@ -745,7 +745,11 @@ class MongoDb(BaseDb):
745
745
  if collection is None:
746
746
  return
747
747
 
748
- result = collection.delete_one({"memory_id": memory_id})
748
+ query = {"memory_id": memory_id}
749
+ if user_id is not None:
750
+ query["user_id"] = user_id
751
+
752
+ result = collection.delete_one(query)
749
753
 
750
754
  success = result.deleted_count > 0
751
755
  if success:
@@ -757,11 +761,12 @@ class MongoDb(BaseDb):
757
761
  log_error(f"Error deleting memory: {e}")
758
762
  raise e
759
763
 
760
- def delete_user_memories(self, memory_ids: List[str]) -> None:
764
+ def delete_user_memories(self, memory_ids: List[str], user_id: Optional[str] = None) -> None:
761
765
  """Delete user memories from the database.
762
766
 
763
767
  Args:
764
768
  memory_ids (List[str]): The IDs of the memories to delete.
769
+ user_id (Optional[str]): The ID of the user to verify ownership. If provided, only delete memories that belong to this user.
765
770
 
766
771
  Raises:
767
772
  Exception: If there is an error deleting the memories.
@@ -771,7 +776,11 @@ class MongoDb(BaseDb):
771
776
  if collection is None:
772
777
  return
773
778
 
774
- result = collection.delete_many({"memory_id": {"$in": memory_ids}})
779
+ query: Dict[str, Any] = {"memory_id": {"$in": memory_ids}}
780
+ if user_id is not None:
781
+ query["user_id"] = user_id
782
+
783
+ result = collection.delete_many(query)
775
784
 
776
785
  if result.deleted_count == 0:
777
786
  log_debug(f"No memories found with ids: {memory_ids}")
@@ -794,19 +803,22 @@ class MongoDb(BaseDb):
794
803
  if collection is None:
795
804
  return []
796
805
 
797
- topics = collection.distinct("topics")
806
+ topics = collection.distinct("topics", {})
798
807
  return [topic for topic in topics if topic]
799
808
 
800
809
  except Exception as e:
801
810
  log_error(f"Exception reading from collection: {e}")
802
811
  raise e
803
812
 
804
- def get_user_memory(self, memory_id: str, deserialize: Optional[bool] = True) -> Optional[UserMemory]:
813
+ def get_user_memory(
814
+ self, memory_id: str, deserialize: Optional[bool] = True, user_id: Optional[str] = None
815
+ ) -> Optional[UserMemory]:
805
816
  """Get a memory from the database.
806
817
 
807
818
  Args:
808
819
  memory_id (str): The ID of the memory to get.
809
820
  deserialize (Optional[bool]): Whether to serialize the memory. Defaults to True.
821
+ user_id (Optional[str]): The ID of the user to verify ownership. If provided, only return the memory if it belongs to this user.
810
822
 
811
823
  Returns:
812
824
  Optional[UserMemory]:
@@ -821,7 +833,11 @@ class MongoDb(BaseDb):
821
833
  if collection is None:
822
834
  return None
823
835
 
824
- result = collection.find_one({"memory_id": memory_id})
836
+ query = {"memory_id": memory_id}
837
+ if user_id is not None:
838
+ query["user_id"] = user_id
839
+
840
+ result = collection.find_one(query)
825
841
  if result is None or not deserialize:
826
842
  return result
827
843
 
@@ -934,8 +950,10 @@ class MongoDb(BaseDb):
934
950
  if collection is None:
935
951
  return [], 0
936
952
 
953
+ match_stage = {"user_id": {"$ne": None}}
954
+
937
955
  pipeline = [
938
- {"$match": {"user_id": {"$ne": None}}},
956
+ {"$match": match_stage},
939
957
  {
940
958
  "$group": {
941
959
  "_id": "$user_id",