agno-2.0.11-py3-none-any.whl → agno-2.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. agno/agent/agent.py +607 -176
  2. agno/db/in_memory/in_memory_db.py +42 -29
  3. agno/db/mongo/mongo.py +65 -66
  4. agno/db/postgres/postgres.py +6 -4
  5. agno/db/utils.py +50 -22
  6. agno/exceptions.py +62 -1
  7. agno/guardrails/__init__.py +6 -0
  8. agno/guardrails/base.py +19 -0
  9. agno/guardrails/openai.py +144 -0
  10. agno/guardrails/pii.py +94 -0
  11. agno/guardrails/prompt_injection.py +51 -0
  12. agno/knowledge/embedder/aws_bedrock.py +9 -4
  13. agno/knowledge/embedder/azure_openai.py +54 -0
  14. agno/knowledge/embedder/base.py +2 -0
  15. agno/knowledge/embedder/cohere.py +184 -5
  16. agno/knowledge/embedder/google.py +79 -1
  17. agno/knowledge/embedder/huggingface.py +9 -4
  18. agno/knowledge/embedder/jina.py +63 -0
  19. agno/knowledge/embedder/mistral.py +78 -11
  20. agno/knowledge/embedder/ollama.py +5 -0
  21. agno/knowledge/embedder/openai.py +18 -54
  22. agno/knowledge/embedder/voyageai.py +69 -16
  23. agno/knowledge/knowledge.py +11 -4
  24. agno/knowledge/reader/pdf_reader.py +4 -3
  25. agno/knowledge/reader/website_reader.py +3 -2
  26. agno/models/base.py +125 -32
  27. agno/models/cerebras/cerebras.py +1 -0
  28. agno/models/cerebras/cerebras_openai.py +1 -0
  29. agno/models/dashscope/dashscope.py +1 -0
  30. agno/models/google/gemini.py +27 -5
  31. agno/models/openai/chat.py +13 -4
  32. agno/models/openai/responses.py +1 -1
  33. agno/models/perplexity/perplexity.py +2 -3
  34. agno/models/requesty/__init__.py +5 -0
  35. agno/models/requesty/requesty.py +49 -0
  36. agno/models/vllm/vllm.py +1 -0
  37. agno/models/xai/xai.py +1 -0
  38. agno/os/app.py +98 -126
  39. agno/os/interfaces/__init__.py +1 -0
  40. agno/os/interfaces/agui/agui.py +21 -5
  41. agno/os/interfaces/base.py +4 -2
  42. agno/os/interfaces/slack/slack.py +13 -8
  43. agno/os/interfaces/whatsapp/router.py +2 -0
  44. agno/os/interfaces/whatsapp/whatsapp.py +12 -5
  45. agno/os/mcp.py +2 -2
  46. agno/os/middleware/__init__.py +7 -0
  47. agno/os/middleware/jwt.py +233 -0
  48. agno/os/router.py +182 -46
  49. agno/os/routers/home.py +2 -2
  50. agno/os/routers/memory/memory.py +23 -1
  51. agno/os/routers/memory/schemas.py +1 -1
  52. agno/os/routers/session/session.py +20 -3
  53. agno/os/utils.py +74 -8
  54. agno/run/agent.py +120 -77
  55. agno/run/base.py +2 -13
  56. agno/run/team.py +115 -72
  57. agno/run/workflow.py +5 -15
  58. agno/session/summary.py +9 -10
  59. agno/session/team.py +2 -1
  60. agno/team/team.py +721 -169
  61. agno/tools/firecrawl.py +4 -4
  62. agno/tools/function.py +42 -2
  63. agno/tools/knowledge.py +3 -3
  64. agno/tools/searxng.py +2 -2
  65. agno/tools/serper.py +2 -2
  66. agno/tools/spider.py +2 -2
  67. agno/tools/workflow.py +4 -5
  68. agno/utils/events.py +66 -1
  69. agno/utils/hooks.py +57 -0
  70. agno/utils/media.py +11 -9
  71. agno/utils/print_response/agent.py +43 -5
  72. agno/utils/print_response/team.py +48 -12
  73. agno/utils/serialize.py +32 -0
  74. agno/vectordb/cassandra/cassandra.py +44 -4
  75. agno/vectordb/chroma/chromadb.py +79 -8
  76. agno/vectordb/clickhouse/clickhousedb.py +43 -6
  77. agno/vectordb/couchbase/couchbase.py +76 -5
  78. agno/vectordb/lancedb/lance_db.py +38 -3
  79. agno/vectordb/milvus/milvus.py +76 -4
  80. agno/vectordb/mongodb/mongodb.py +76 -4
  81. agno/vectordb/pgvector/pgvector.py +50 -6
  82. agno/vectordb/pineconedb/pineconedb.py +39 -2
  83. agno/vectordb/qdrant/qdrant.py +76 -26
  84. agno/vectordb/singlestore/singlestore.py +77 -4
  85. agno/vectordb/upstashdb/upstashdb.py +42 -2
  86. agno/vectordb/weaviate/weaviate.py +39 -3
  87. agno/workflow/types.py +5 -6
  88. agno/workflow/workflow.py +58 -2
  89. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/METADATA +4 -3
  90. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/RECORD +93 -82
  91. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/WHEEL +0 -0
  92. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/licenses/LICENSE +0 -0
  93. {agno-2.0.11.dist-info → agno-2.1.1.dist-info}/top_level.txt +0 -0
agno/os/router.py CHANGED
@@ -16,6 +16,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
16
16
  from pydantic import BaseModel
17
17
 
18
18
  from agno.agent.agent import Agent
19
+ from agno.exceptions import InputCheckError, OutputCheckError
19
20
  from agno.media import Audio, Image, Video
20
21
  from agno.media import File as FileMedia
21
22
  from agno.os.auth import get_authentication_dependency, validate_websocket_token
@@ -74,7 +75,6 @@ async def _get_request_kwargs(request: Request, endpoint_func: Callable) -> Dict
74
75
  kwargs = {key: value for key, value in form_data.items() if key not in known_fields}
75
76
 
76
77
  # Handle JSON parameters. They are passed as strings and need to be deserialized.
77
-
78
78
  if session_state := kwargs.get("session_state"):
79
79
  try:
80
80
  session_state_dict = json.loads(session_state) # type: ignore
@@ -82,6 +82,7 @@ async def _get_request_kwargs(request: Request, endpoint_func: Callable) -> Dict
82
82
  except json.JSONDecodeError:
83
83
  kwargs.pop("session_state")
84
84
  log_warning(f"Invalid session_state parameter couldn't be loaded: {session_state}")
85
+
85
86
  if dependencies := kwargs.get("dependencies"):
86
87
  try:
87
88
  dependencies_dict = json.loads(dependencies) # type: ignore
@@ -245,7 +246,14 @@ async def agent_response_streamer(
245
246
  )
246
247
  async for run_response_chunk in run_response:
247
248
  yield format_sse_event(run_response_chunk) # type: ignore
248
-
249
+ except (InputCheckError, OutputCheckError) as e:
250
+ error_response = RunErrorEvent(
251
+ content=str(e),
252
+ error_type=e.type,
253
+ error_id=e.error_id,
254
+ additional_data=e.additional_data,
255
+ )
256
+ yield format_sse_event(error_response)
249
257
  except Exception as e:
250
258
  import traceback
251
259
 
@@ -274,6 +282,14 @@ async def agent_continue_response_streamer(
274
282
  )
275
283
  async for run_response_chunk in continue_response:
276
284
  yield format_sse_event(run_response_chunk) # type: ignore
285
+ except (InputCheckError, OutputCheckError) as e:
286
+ error_response = RunErrorEvent(
287
+ content=str(e),
288
+ error_type=e.type,
289
+ error_id=e.error_id,
290
+ additional_data=e.additional_data,
291
+ )
292
+ yield format_sse_event(error_response)
277
293
 
278
294
  except Exception as e:
279
295
  import traceback
@@ -281,6 +297,8 @@ async def agent_continue_response_streamer(
281
297
  traceback.print_exc(limit=3)
282
298
  error_response = RunErrorEvent(
283
299
  content=str(e),
300
+ error_type=e.type if hasattr(e, "type") else None,
301
+ error_id=e.error_id if hasattr(e, "error_id") else None,
284
302
  )
285
303
  yield format_sse_event(error_response)
286
304
  return
@@ -313,6 +331,14 @@ async def team_response_streamer(
313
331
  )
314
332
  async for run_response_chunk in run_response:
315
333
  yield format_sse_event(run_response_chunk) # type: ignore
334
+ except (InputCheckError, OutputCheckError) as e:
335
+ error_response = TeamRunErrorEvent(
336
+ content=str(e),
337
+ error_type=e.type,
338
+ error_id=e.error_id,
339
+ additional_data=e.additional_data,
340
+ )
341
+ yield format_sse_event(error_response)
316
342
 
317
343
  except Exception as e:
318
344
  import traceback
@@ -320,6 +346,8 @@ async def team_response_streamer(
320
346
  traceback.print_exc()
321
347
  error_response = TeamRunErrorEvent(
322
348
  content=str(e),
349
+ error_type=e.type if hasattr(e, "type") else None,
350
+ error_id=e.error_id if hasattr(e, "error_id") else None,
323
351
  )
324
352
  yield format_sse_event(error_response)
325
353
  return
@@ -366,9 +394,28 @@ async def handle_workflow_via_websocket(websocket: WebSocket, message: dict, os:
366
394
 
367
395
  await websocket_manager.register_workflow_websocket(workflow_run_output.run_id, websocket) # type: ignore
368
396
 
397
+ except (InputCheckError, OutputCheckError) as e:
398
+ await websocket.send_text(
399
+ json.dumps(
400
+ {
401
+ "event": "error",
402
+ "error": str(e),
403
+ "error_type": e.type,
404
+ "error_id": e.error_id,
405
+ "additional_data": e.additional_data,
406
+ }
407
+ )
408
+ )
369
409
  except Exception as e:
370
410
  logger.error(f"Error executing workflow via WebSocket: {e}")
371
- await websocket.send_text(json.dumps({"event": "error", "error": str(e)}))
411
+ error_payload = {
412
+ "event": "error",
413
+ "error": str(e),
414
+ "error_type": e.type if hasattr(e, "type") else None,
415
+ "error_id": e.error_id if hasattr(e, "error_id") else None,
416
+ }
417
+ error_payload = {k: v for k, v in error_payload.items() if v is not None}
418
+ await websocket.send_text(json.dumps(error_payload))
372
419
 
373
420
 
374
421
  async def workflow_response_streamer(
@@ -391,12 +438,23 @@ async def workflow_response_streamer(
391
438
  async for run_response_chunk in run_response:
392
439
  yield format_sse_event(run_response_chunk) # type: ignore
393
440
 
441
+ except (InputCheckError, OutputCheckError) as e:
442
+ error_response = WorkflowErrorEvent(
443
+ error=str(e),
444
+ error_type=e.type,
445
+ error_id=e.error_id,
446
+ additional_data=e.additional_data,
447
+ )
448
+ yield format_sse_event(error_response)
449
+
394
450
  except Exception as e:
395
451
  import traceback
396
452
 
397
453
  traceback.print_exc()
398
454
  error_response = WorkflowErrorEvent(
399
455
  error=str(e),
456
+ error_type=e.type if hasattr(e, "type") else None,
457
+ error_id=e.error_id if hasattr(e, "error_id") else None,
400
458
  )
401
459
  yield format_sse_event(error_response)
402
460
  return
@@ -464,7 +522,7 @@ def get_websocket_router(
464
522
  await websocket.send_text(json.dumps({"event": "error", "error": f"Unknown action: {action}"}))
465
523
 
466
524
  except Exception as e:
467
- if "1012" not in str(e):
525
+ if "1012" not in str(e) and "1001" not in str(e):
468
526
  logger.error(f"WebSocket error: {e}")
469
527
  finally:
470
528
  # Clean up the websocket connection
@@ -520,7 +578,7 @@ def get_base_router(
520
578
  "content": {
521
579
  "application/json": {
522
580
  "example": {
523
- "os_id": "demo",
581
+ "id": "demo",
524
582
  "description": "Example AgentOS configuration",
525
583
  "available_models": [],
526
584
  "databases": ["9c884dc4-9066-448c-9074-ef49ec7eb73c"],
@@ -582,7 +640,7 @@ def get_base_router(
582
640
  )
583
641
  async def config() -> ConfigResponse:
584
642
  return ConfigResponse(
585
- os_id=os.os_id or "Unnamed OS",
643
+ os_id=os.id or "Unnamed OS",
586
644
  description=os.description,
587
645
  available_models=os.config.available_models if os.config else [],
588
646
  databases=[db.id for db in os.dbs.values()],
@@ -596,7 +654,7 @@ def get_base_router(
596
654
  teams=[TeamSummaryResponse.from_team(team) for team in os.teams] if os.teams else [],
597
655
  workflows=[WorkflowSummaryResponse.from_workflow(w) for w in os.workflows] if os.workflows else [],
598
656
  interfaces=[
599
- InterfaceResponse(type=interface.type, version=interface.version, route=interface.router_prefix)
657
+ InterfaceResponse(type=interface.type, version=interface.version, route=interface.prefix)
600
658
  for interface in os.interfaces
601
659
  ],
602
660
  )
@@ -669,7 +727,7 @@ def get_base_router(
669
727
  "content": {
670
728
  "text/event-stream": {
671
729
  "examples": {
672
- "event_strea": {
730
+ "event_stream": {
673
731
  "summary": "Example event stream response",
674
732
  "value": 'event: RunStarted\ndata: {"content": "Hello!", "run_id": "123..."}\n\n',
675
733
  }
@@ -692,6 +750,25 @@ def get_base_router(
692
750
  ):
693
751
  kwargs = await _get_request_kwargs(request, create_agent_run)
694
752
 
753
+ if hasattr(request.state, "user_id"):
754
+ if user_id:
755
+ log_warning("User ID parameter passed in both request state and kwargs, using request state")
756
+ user_id = request.state.user_id
757
+ if hasattr(request.state, "session_id"):
758
+ if session_id:
759
+ log_warning("Session ID parameter passed in both request state and kwargs, using request state")
760
+ session_id = request.state.session_id
761
+ if hasattr(request.state, "session_state"):
762
+ session_state = request.state.session_state
763
+ if "session_state" in kwargs:
764
+ log_warning("Session state parameter passed in both request state and kwargs, using request state")
765
+ kwargs["session_state"] = session_state
766
+ if hasattr(request.state, "dependencies"):
767
+ dependencies = request.state.dependencies
768
+ if "dependencies" in kwargs:
769
+ log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
770
+ kwargs["dependencies"] = dependencies
771
+
695
772
  agent = get_agent_by_id(agent_id, os.agents)
696
773
  if agent is None:
697
774
  raise HTTPException(status_code=404, detail="Agent not found")
@@ -774,21 +851,25 @@ def get_base_router(
774
851
  media_type="text/event-stream",
775
852
  )
776
853
  else:
777
- run_response = cast(
778
- RunOutput,
779
- await agent.arun(
780
- input=message,
781
- session_id=session_id,
782
- user_id=user_id,
783
- images=base64_images if base64_images else None,
784
- audio=base64_audios if base64_audios else None,
785
- videos=base64_videos if base64_videos else None,
786
- files=input_files if input_files else None,
787
- stream=False,
788
- **kwargs,
789
- ),
790
- )
791
- return run_response.to_dict()
854
+ try:
855
+ run_response = cast(
856
+ RunOutput,
857
+ await agent.arun(
858
+ input=message,
859
+ session_id=session_id,
860
+ user_id=user_id,
861
+ images=base64_images if base64_images else None,
862
+ audio=base64_audios if base64_audios else None,
863
+ videos=base64_videos if base64_videos else None,
864
+ files=input_files if input_files else None,
865
+ stream=False,
866
+ **kwargs,
867
+ ),
868
+ )
869
+ return run_response.to_dict()
870
+
871
+ except InputCheckError as e:
872
+ raise HTTPException(status_code=400, detail=str(e))
792
873
 
793
874
  @router.post(
794
875
  "/agents/{agent_id}/runs/{run_id}/cancel",
@@ -849,11 +930,17 @@ def get_base_router(
849
930
  async def continue_agent_run(
850
931
  agent_id: str,
851
932
  run_id: str,
933
+ request: Request,
852
934
  tools: str = Form(...), # JSON string of tools
853
935
  session_id: Optional[str] = Form(None),
854
936
  user_id: Optional[str] = Form(None),
855
937
  stream: bool = Form(True),
856
938
  ):
939
+ if hasattr(request.state, "user_id"):
940
+ user_id = request.state.user_id
941
+ if hasattr(request.state, "session_id"):
942
+ session_id = request.state.session_id
943
+
857
944
  # Parse the JSON string manually
858
945
  try:
859
946
  tools_data = json.loads(tools) if tools else None
@@ -891,17 +978,21 @@ def get_base_router(
891
978
  media_type="text/event-stream",
892
979
  )
893
980
  else:
894
- run_response_obj = cast(
895
- RunOutput,
896
- await agent.acontinue_run(
897
- run_id=run_id, # run_id from path
898
- updated_tools=updated_tools,
899
- session_id=session_id,
900
- user_id=user_id,
901
- stream=False,
902
- ),
903
- )
904
- return run_response_obj.to_dict()
981
+ try:
982
+ run_response_obj = cast(
983
+ RunOutput,
984
+ await agent.acontinue_run(
985
+ run_id=run_id, # run_id from path
986
+ updated_tools=updated_tools,
987
+ session_id=session_id,
988
+ user_id=user_id,
989
+ stream=False,
990
+ ),
991
+ )
992
+ return run_response_obj.to_dict()
993
+
994
+ except InputCheckError as e:
995
+ raise HTTPException(status_code=400, detail=str(e))
905
996
 
906
997
  @router.get(
907
998
  "/agents",
@@ -1041,6 +1132,25 @@ def get_base_router(
1041
1132
  ):
1042
1133
  kwargs = await _get_request_kwargs(request, create_team_run)
1043
1134
 
1135
+ if hasattr(request.state, "user_id"):
1136
+ if user_id:
1137
+ log_warning("User ID parameter passed in both request state and kwargs, using request state")
1138
+ user_id = request.state.user_id
1139
+ if hasattr(request.state, "session_id"):
1140
+ if session_id:
1141
+ log_warning("Session ID parameter passed in both request state and kwargs, using request state")
1142
+ session_id = request.state.session_id
1143
+ if hasattr(request.state, "session_state"):
1144
+ session_state = request.state.session_state
1145
+ if "session_state" in kwargs:
1146
+ log_warning("Session state parameter passed in both request state and kwargs, using request state")
1147
+ kwargs["session_state"] = session_state
1148
+ if hasattr(request.state, "dependencies"):
1149
+ dependencies = request.state.dependencies
1150
+ if "dependencies" in kwargs:
1151
+ log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
1152
+ kwargs["dependencies"] = dependencies
1153
+
1044
1154
  logger.debug(f"Creating team run: {message=} {session_id=} {monitor=} {user_id=} {team_id=} {files=} {kwargs=}")
1045
1155
 
1046
1156
  team = get_team_by_id(team_id, os.teams)
@@ -1122,18 +1232,22 @@ def get_base_router(
1122
1232
  media_type="text/event-stream",
1123
1233
  )
1124
1234
  else:
1125
- run_response = await team.arun(
1126
- input=message,
1127
- session_id=session_id,
1128
- user_id=user_id,
1129
- images=base64_images if base64_images else None,
1130
- audio=base64_audios if base64_audios else None,
1131
- videos=base64_videos if base64_videos else None,
1132
- files=document_files if document_files else None,
1133
- stream=False,
1134
- **kwargs,
1135
- )
1136
- return run_response.to_dict()
1235
+ try:
1236
+ run_response = await team.arun(
1237
+ input=message,
1238
+ session_id=session_id,
1239
+ user_id=user_id,
1240
+ images=base64_images if base64_images else None,
1241
+ audio=base64_audios if base64_audios else None,
1242
+ videos=base64_videos if base64_videos else None,
1243
+ files=document_files if document_files else None,
1244
+ stream=False,
1245
+ **kwargs,
1246
+ )
1247
+ return run_response.to_dict()
1248
+
1249
+ except InputCheckError as e:
1250
+ raise HTTPException(status_code=400, detail=str(e))
1137
1251
 
1138
1252
  @router.post(
1139
1253
  "/teams/{team_id}/runs/{run_id}/cancel",
@@ -1464,6 +1578,25 @@ def get_base_router(
1464
1578
  ):
1465
1579
  kwargs = await _get_request_kwargs(request, create_workflow_run)
1466
1580
 
1581
+ if hasattr(request.state, "user_id"):
1582
+ if user_id:
1583
+ log_warning("User ID parameter passed in both request state and kwargs, using request state")
1584
+ user_id = request.state.user_id
1585
+ if hasattr(request.state, "session_id"):
1586
+ if session_id:
1587
+ log_warning("Session ID parameter passed in both request state and kwargs, using request state")
1588
+ session_id = request.state.session_id
1589
+ if hasattr(request.state, "session_state"):
1590
+ session_state = request.state.session_state
1591
+ if "session_state" in kwargs:
1592
+ log_warning("Session state parameter passed in both request state and kwargs, using request state")
1593
+ kwargs["session_state"] = session_state
1594
+ if hasattr(request.state, "dependencies"):
1595
+ dependencies = request.state.dependencies
1596
+ if "dependencies" in kwargs:
1597
+ log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
1598
+ kwargs["dependencies"] = dependencies
1599
+
1467
1600
  # Retrieve the workflow by ID
1468
1601
  workflow = get_workflow_by_id(workflow_id, os.workflows)
1469
1602
  if workflow is None:
@@ -1497,6 +1630,9 @@ def get_base_router(
1497
1630
  **kwargs,
1498
1631
  )
1499
1632
  return run_response.to_dict()
1633
+
1634
+ except InputCheckError as e:
1635
+ raise HTTPException(status_code=400, detail=str(e))
1500
1636
  except Exception as e:
1501
1637
  # Handle unexpected runtime errors
1502
1638
  raise HTTPException(status_code=500, detail=f"Error running workflow: {str(e)}")
agno/os/routers/home.py CHANGED
@@ -30,7 +30,7 @@ def get_home_router(os: "AgentOS") -> APIRouter:
30
30
  "value": {
31
31
  "name": "AgentOS API",
32
32
  "description": "AI Agent Operating System API",
33
- "os_id": "demo-os",
33
+ "id": "demo-os",
34
34
  "version": "1.0.0",
35
35
  },
36
36
  }
@@ -45,7 +45,7 @@ def get_home_router(os: "AgentOS") -> APIRouter:
45
45
  return {
46
46
  "name": "AgentOS API",
47
47
  "description": os.description or "AI Agent Operating System API",
48
- "os_id": os.os_id or "agno-agentos",
48
+ "id": os.id or "agno-agentos",
49
49
  "version": os.version or "1.0.0",
50
50
  }
51
51
 
agno/os/routers/memory/memory.py CHANGED
@@ -3,7 +3,7 @@ import math
3
3
  from typing import List, Optional
4
4
  from uuid import uuid4
5
5
 
6
- from fastapi import Depends, HTTPException, Path, Query
6
+ from fastapi import Depends, HTTPException, Path, Query, Request
7
7
  from fastapi.routing import APIRouter
8
8
 
9
9
  from agno.db.base import BaseDb
@@ -80,9 +80,17 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
80
80
  },
81
81
  )
82
82
  async def create_memory(
83
+ request: Request,
83
84
  payload: UserMemoryCreateSchema,
84
85
  db_id: Optional[str] = Query(default=None, description="Database ID to use for memory storage"),
85
86
  ) -> UserMemorySchema:
87
+ if hasattr(request.state, "user_id"):
88
+ user_id = request.state.user_id
89
+ payload.user_id = user_id
90
+
91
+ if payload.user_id is None:
92
+ raise HTTPException(status_code=400, detail="User ID is required")
93
+
86
94
  db = get_db(dbs, db_id)
87
95
  user_memory = db.upsert_user_memory(
88
96
  memory=UserMemory(
@@ -173,6 +181,7 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
173
181
  },
174
182
  )
175
183
  async def get_memories(
184
+ request: Request,
176
185
  user_id: Optional[str] = Query(default=None, description="Filter memories by user ID"),
177
186
  agent_id: Optional[str] = Query(default=None, description="Filter memories by agent ID"),
178
187
  team_id: Optional[str] = Query(default=None, description="Filter memories by team ID"),
@@ -185,6 +194,10 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
185
194
  db_id: Optional[str] = Query(default=None, description="Database ID to query memories from"),
186
195
  ) -> PaginatedResponse[UserMemorySchema]:
187
196
  db = get_db(dbs, db_id)
197
+
198
+ if hasattr(request.state, "user_id"):
199
+ user_id = request.state.user_id
200
+
188
201
  user_memories, total_count = db.get_user_memories(
189
202
  limit=limit,
190
203
  page=page,
@@ -316,11 +329,20 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
316
329
  },
317
330
  )
318
331
  async def update_memory(
332
+ request: Request,
319
333
  payload: UserMemoryCreateSchema,
320
334
  memory_id: str = Path(description="Memory ID to update"),
321
335
  db_id: Optional[str] = Query(default=None, description="Database ID to use for update"),
322
336
  ) -> UserMemorySchema:
323
337
  db = get_db(dbs, db_id)
338
+
339
+ if hasattr(request.state, "user_id"):
340
+ user_id = request.state.user_id
341
+ payload.user_id = user_id
342
+
343
+ if payload.user_id is None:
344
+ raise HTTPException(status_code=400, detail="User ID is required")
345
+
324
346
  user_memory = db.upsert_user_memory(
325
347
  memory=UserMemory(
326
348
  memory_id=memory_id,
agno/os/routers/memory/schemas.py CHANGED
@@ -36,7 +36,7 @@ class UserMemoryCreateSchema(BaseModel):
36
36
  """Define the payload expected for creating a new user memory"""
37
37
 
38
38
  memory: str
39
- user_id: str
39
+ user_id: Optional[str] = None
40
40
  topics: Optional[List[str]] = None
41
41
 
42
42
 
agno/os/routers/session/session.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from typing import List, Optional, Union
3
3
 
4
- from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
4
+ from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
5
5
 
6
6
  from agno.db.base import BaseDb, SessionType
7
7
  from agno.os.auth import get_authentication_dependency
@@ -86,6 +86,7 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
86
86
  },
87
87
  )
88
88
  async def get_sessions(
89
+ request: Request,
89
90
  session_type: SessionType = Query(
90
91
  default=SessionType.AGENT,
91
92
  alias="type",
@@ -103,6 +104,10 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
103
104
  db_id: Optional[str] = Query(default=None, description="Database ID to query sessions from"),
104
105
  ) -> PaginatedResponse[SessionSchema]:
105
106
  db = get_db(dbs, db_id)
107
+
108
+ if hasattr(request.state, "user_id"):
109
+ user_id = request.state.user_id
110
+
106
111
  sessions, total_count = db.get_sessions(
107
112
  session_type=session_type,
108
113
  component_id=component_id,
@@ -213,14 +218,20 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
213
218
  },
214
219
  )
215
220
  async def get_session_by_id(
221
+ request: Request,
216
222
  session_id: str = Path(description="Session ID to retrieve"),
217
223
  session_type: SessionType = Query(
218
224
  default=SessionType.AGENT, description="Session type (agent, team, or workflow)", alias="type"
219
225
  ),
226
+ user_id: Optional[str] = Query(default=None, description="User ID to query session from"),
220
227
  db_id: Optional[str] = Query(default=None, description="Database ID to query session from"),
221
228
  ) -> Union[AgentSessionDetailSchema, TeamSessionDetailSchema, WorkflowSessionDetailSchema]:
222
229
  db = get_db(dbs, db_id)
223
- session = db.get_session(session_id=session_id, session_type=session_type)
230
+
231
+ if hasattr(request.state, "user_id"):
232
+ user_id = request.state.user_id
233
+
234
+ session = db.get_session(session_id=session_id, session_type=session_type, user_id=user_id)
224
235
  if not session:
225
236
  raise HTTPException(
226
237
  status_code=404, detail=f"{session_type.value.title()} Session with id '{session_id}' not found"
@@ -349,14 +360,20 @@ def attach_routes(router: APIRouter, dbs: dict[str, BaseDb]) -> APIRouter:
349
360
  },
350
361
  )
351
362
  async def get_session_runs(
363
+ request: Request,
352
364
  session_id: str = Path(description="Session ID to get runs from"),
353
365
  session_type: SessionType = Query(
354
366
  default=SessionType.AGENT, description="Session type (agent, team, or workflow)", alias="type"
355
367
  ),
368
+ user_id: Optional[str] = Query(default=None, description="User ID to query runs from"),
356
369
  db_id: Optional[str] = Query(default=None, description="Database ID to query runs from"),
357
370
  ) -> List[Union[RunSchema, TeamRunSchema, WorkflowRunSchema]]:
358
371
  db = get_db(dbs, db_id)
359
- session = db.get_session(session_id=session_id, session_type=session_type, deserialize=False)
372
+
373
+ if hasattr(request.state, "user_id"):
374
+ user_id = request.state.user_id
375
+
376
+ session = db.get_session(session_id=session_id, session_type=session_type, user_id=user_id, deserialize=False)
360
377
  if not session:
361
378
  raise HTTPException(status_code=404, detail=f"Session with ID {session_id} not found")
362
379