agno 2.0.11__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. agno/agent/agent.py +606 -175
  2. agno/db/in_memory/in_memory_db.py +42 -29
  3. agno/db/postgres/postgres.py +6 -4
  4. agno/exceptions.py +62 -1
  5. agno/guardrails/__init__.py +6 -0
  6. agno/guardrails/base.py +19 -0
  7. agno/guardrails/openai.py +144 -0
  8. agno/guardrails/pii.py +94 -0
  9. agno/guardrails/prompt_injection.py +51 -0
  10. agno/knowledge/embedder/aws_bedrock.py +9 -4
  11. agno/knowledge/embedder/azure_openai.py +54 -0
  12. agno/knowledge/embedder/base.py +2 -0
  13. agno/knowledge/embedder/cohere.py +184 -5
  14. agno/knowledge/embedder/google.py +79 -1
  15. agno/knowledge/embedder/huggingface.py +9 -4
  16. agno/knowledge/embedder/jina.py +63 -0
  17. agno/knowledge/embedder/mistral.py +78 -11
  18. agno/knowledge/embedder/ollama.py +5 -0
  19. agno/knowledge/embedder/openai.py +18 -54
  20. agno/knowledge/embedder/voyageai.py +69 -16
  21. agno/knowledge/knowledge.py +5 -4
  22. agno/knowledge/reader/pdf_reader.py +4 -3
  23. agno/knowledge/reader/website_reader.py +3 -2
  24. agno/models/base.py +125 -32
  25. agno/models/cerebras/cerebras.py +1 -0
  26. agno/models/cerebras/cerebras_openai.py +1 -0
  27. agno/models/dashscope/dashscope.py +1 -0
  28. agno/models/google/gemini.py +27 -5
  29. agno/models/openai/chat.py +13 -4
  30. agno/models/perplexity/perplexity.py +2 -3
  31. agno/models/requesty/__init__.py +5 -0
  32. agno/models/requesty/requesty.py +49 -0
  33. agno/models/vllm/vllm.py +1 -0
  34. agno/models/xai/xai.py +1 -0
  35. agno/os/app.py +98 -126
  36. agno/os/interfaces/whatsapp/router.py +2 -0
  37. agno/os/mcp.py +1 -1
  38. agno/os/middleware/__init__.py +7 -0
  39. agno/os/middleware/jwt.py +233 -0
  40. agno/os/router.py +181 -45
  41. agno/os/routers/home.py +2 -2
  42. agno/os/routers/memory/memory.py +23 -1
  43. agno/os/routers/memory/schemas.py +1 -1
  44. agno/os/routers/session/session.py +20 -3
  45. agno/os/utils.py +74 -8
  46. agno/run/agent.py +120 -77
  47. agno/run/team.py +115 -72
  48. agno/run/workflow.py +5 -15
  49. agno/session/summary.py +9 -10
  50. agno/session/team.py +2 -1
  51. agno/team/team.py +720 -168
  52. agno/tools/firecrawl.py +4 -4
  53. agno/tools/function.py +42 -2
  54. agno/tools/knowledge.py +3 -3
  55. agno/tools/searxng.py +2 -2
  56. agno/tools/serper.py +2 -2
  57. agno/tools/spider.py +2 -2
  58. agno/tools/workflow.py +4 -5
  59. agno/utils/events.py +66 -1
  60. agno/utils/hooks.py +57 -0
  61. agno/utils/media.py +11 -9
  62. agno/utils/print_response/agent.py +43 -5
  63. agno/utils/print_response/team.py +48 -12
  64. agno/vectordb/cassandra/cassandra.py +44 -4
  65. agno/vectordb/chroma/chromadb.py +79 -8
  66. agno/vectordb/clickhouse/clickhousedb.py +43 -6
  67. agno/vectordb/couchbase/couchbase.py +76 -5
  68. agno/vectordb/lancedb/lance_db.py +38 -3
  69. agno/vectordb/milvus/milvus.py +76 -4
  70. agno/vectordb/mongodb/mongodb.py +76 -4
  71. agno/vectordb/pgvector/pgvector.py +50 -6
  72. agno/vectordb/pineconedb/pineconedb.py +39 -2
  73. agno/vectordb/qdrant/qdrant.py +76 -26
  74. agno/vectordb/singlestore/singlestore.py +77 -4
  75. agno/vectordb/upstashdb/upstashdb.py +42 -2
  76. agno/vectordb/weaviate/weaviate.py +39 -3
  77. agno/workflow/types.py +1 -0
  78. agno/workflow/workflow.py +58 -2
  79. {agno-2.0.11.dist-info → agno-2.1.0.dist-info}/METADATA +4 -3
  80. {agno-2.0.11.dist-info → agno-2.1.0.dist-info}/RECORD +83 -73
  81. {agno-2.0.11.dist-info → agno-2.1.0.dist-info}/WHEEL +0 -0
  82. {agno-2.0.11.dist-info → agno-2.1.0.dist-info}/licenses/LICENSE +0 -0
  83. {agno-2.0.11.dist-info → agno-2.1.0.dist-info}/top_level.txt +0 -0
agno/os/app.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from contextlib import asynccontextmanager
2
2
  from functools import partial
3
3
  from os import getenv
4
- from typing import Any, Dict, List, Optional, Set, Union
4
+ from typing import Any, Dict, List, Literal, Optional, Union
5
5
  from uuid import uuid4
6
6
 
7
7
  from fastapi import APIRouter, FastAPI, HTTPException
@@ -40,6 +40,8 @@ from agno.os.settings import AgnoAPISettings
40
40
  from agno.os.utils import (
41
41
  collect_mcp_tools_from_team,
42
42
  collect_mcp_tools_from_workflow,
43
+ find_conflicting_routes,
44
+ load_yaml_config,
43
45
  update_cors_middleware,
44
46
  )
45
47
  from agno.team.team import Team
@@ -49,7 +51,7 @@ from agno.workflow.workflow import Workflow
49
51
 
50
52
 
51
53
  @asynccontextmanager
52
- async def mcp_lifespan(app, mcp_tools):
54
+ async def mcp_lifespan(_, mcp_tools):
53
55
  """Manage MCP connection lifecycle inside a FastAPI app"""
54
56
  # Startup logic: connect to all contextual MCP servers
55
57
  for tool in mcp_tools:
@@ -65,7 +67,8 @@ async def mcp_lifespan(app, mcp_tools):
65
67
  class AgentOS:
66
68
  def __init__(
67
69
  self,
68
- os_id: Optional[str] = None,
70
+ id: Optional[str] = None,
71
+ os_id: Optional[str] = None, # Deprecated
69
72
  name: Optional[str] = None,
70
73
  description: Optional[str] = None,
71
74
  version: Optional[str] = None,
@@ -75,16 +78,19 @@ class AgentOS:
75
78
  interfaces: Optional[List[BaseInterface]] = None,
76
79
  config: Optional[Union[str, AgentOSConfig]] = None,
77
80
  settings: Optional[AgnoAPISettings] = None,
78
- fastapi_app: Optional[FastAPI] = None,
79
81
  lifespan: Optional[Any] = None,
80
- enable_mcp: bool = False,
81
- replace_routes: bool = True,
82
+ enable_mcp: bool = False, # Deprecated
83
+ enable_mcp_server: bool = False,
84
+ fastapi_app: Optional[FastAPI] = None, # Deprecated
85
+ base_app: Optional[FastAPI] = None,
86
+ replace_routes: Optional[bool] = None, # Deprecated
87
+ on_route_conflict: Literal["preserve_agentos", "preserve_base_app", "error"] = "preserve_agentos",
82
88
  telemetry: bool = True,
83
89
  ):
84
90
  """Initialize AgentOS.
85
91
 
86
92
  Args:
87
- os_id: Unique identifier for this AgentOS instance
93
+ id: Unique identifier for this AgentOS instance
88
94
  name: Name of the AgentOS instance
89
95
  description: Description of the AgentOS instance
90
96
  version: Version of the AgentOS instance
@@ -94,18 +100,16 @@ class AgentOS:
94
100
  interfaces: List of interfaces to include in the OS
95
101
  config: Configuration file path or AgentOSConfig instance
96
102
  settings: API settings for the OS
97
- fastapi_app: Optional custom FastAPI app to use instead of creating a new one
98
103
  lifespan: Optional lifespan context manager for the FastAPI app
99
- enable_mcp: Whether to enable MCP (Model Context Protocol)
100
- replace_routes: If False and using a custom fastapi_app, skip AgentOS routes that
101
- conflict with existing routes, preferring the user's custom routes.
102
- If True (default), AgentOS routes will override conflicting custom routes.
104
+ enable_mcp_server: Whether to enable MCP (Model Context Protocol)
105
+ base_app: Optional base FastAPI app to use for the AgentOS. All routes and middleware will be added to this app.
106
+ on_route_conflict: What to do when a route conflict is detected in case a custom base_app is provided.
103
107
  telemetry: Whether to enable telemetry
104
108
  """
105
109
  if not agents and not workflows and not teams:
106
110
  raise ValueError("Either agents, teams or workflows must be provided.")
107
111
 
108
- self.config = self._load_yaml_config(config) if isinstance(config, str) else config
112
+ self.config = load_yaml_config(config) if isinstance(config, str) else config
109
113
 
110
114
  self.agents: Optional[List[Agent]] = agents
111
115
  self.workflows: Optional[List[Workflow]] = workflows
@@ -115,27 +119,42 @@ class AgentOS:
115
119
  self.settings: AgnoAPISettings = settings or AgnoAPISettings()
116
120
 
117
121
  self._app_set = False
118
- self.fastapi_app: Optional[FastAPI] = None
119
- if fastapi_app:
120
- self.fastapi_app = fastapi_app
122
+
123
+ if base_app:
124
+ self.base_app: Optional[FastAPI] = base_app
125
+ self._app_set = True
126
+ self.on_route_conflict = on_route_conflict
127
+ elif fastapi_app:
128
+ self.base_app = fastapi_app
121
129
  self._app_set = True
130
+ if replace_routes is not None:
131
+ self.on_route_conflict = "preserve_agentos" if replace_routes else "preserve_base_app"
132
+ else:
133
+ self.on_route_conflict = on_route_conflict
134
+ else:
135
+ self.base_app = None
136
+ self._app_set = False
137
+ self.on_route_conflict = on_route_conflict
122
138
 
123
139
  self.interfaces = interfaces or []
124
140
 
125
- self.os_id = os_id
126
141
  self.name = name
142
+
143
+ self.id = id or os_id
144
+ if not self.id:
145
+ self.id = generate_id(self.name) if self.name else str(uuid4())
146
+
127
147
  self.version = version
128
148
  self.description = description
129
149
 
130
- self.replace_routes = replace_routes
131
-
132
150
  self.telemetry = telemetry
133
151
 
134
- self.enable_mcp = enable_mcp
152
+ self.enable_mcp_server = enable_mcp or enable_mcp_server
135
153
  self.lifespan = lifespan
136
154
 
137
155
  # List of all MCP tools used inside the AgentOS
138
156
  self.mcp_tools: List[Any] = []
157
+ self._mcp_app: Optional[Any] = None
139
158
 
140
159
  if self.agents:
141
160
  for agent in self.agents:
@@ -177,13 +196,10 @@ class AgentOS:
177
196
  if not workflow.id:
178
197
  workflow.id = generate_id_from_name(workflow.name)
179
198
 
180
- if not self.os_id:
181
- self.os_id = generate_id(self.name) if self.name else str(uuid4())
182
-
183
199
  if self.telemetry:
184
200
  from agno.api.os import OSLaunch, log_os_telemetry
185
201
 
186
- log_os_telemetry(launch=OSLaunch(os_id=self.os_id, data=self._get_telemetry_data()))
202
+ log_os_telemetry(launch=OSLaunch(os_id=self.id, data=self._get_telemetry_data()))
187
203
 
188
204
  def _make_app(self, lifespan: Optional[Any] = None) -> FastAPI:
189
205
  # Adjust the FastAPI app lifespan to handle MCP connections if relevant
@@ -215,39 +231,41 @@ class AgentOS:
215
231
  )
216
232
 
217
233
  def get_app(self) -> FastAPI:
218
- if not self.fastapi_app:
219
- if self.enable_mcp:
234
+ if self.base_app:
235
+ fastapi_app = self.base_app
236
+ else:
237
+ if self.enable_mcp_server:
220
238
  from contextlib import asynccontextmanager
221
239
 
222
240
  from agno.os.mcp import get_mcp_server
223
241
 
224
- self.mcp_app = get_mcp_server(self)
242
+ self._mcp_app = get_mcp_server(self)
225
243
 
226
- final_lifespan = self.mcp_app.lifespan
244
+ final_lifespan = self._mcp_app.lifespan # type: ignore
227
245
  if self.lifespan is not None:
228
246
  # Combine both lifespans
229
247
  @asynccontextmanager
230
248
  async def combined_lifespan(app: FastAPI):
231
249
  # Run both lifespans
232
250
  async with self.lifespan(app): # type: ignore
233
- async with self.mcp_app.lifespan(app): # type: ignore
251
+ async with self._mcp_app.lifespan(app): # type: ignore
234
252
  yield
235
253
 
236
254
  final_lifespan = combined_lifespan # type: ignore
237
255
 
238
- self.fastapi_app = self._make_app(lifespan=final_lifespan)
256
+ fastapi_app = self._make_app(lifespan=final_lifespan)
239
257
  else:
240
- self.fastapi_app = self._make_app(lifespan=self.lifespan)
258
+ fastapi_app = self._make_app(lifespan=self.lifespan)
241
259
 
242
- # Add routes with conflict detection
243
- self._add_router(get_base_router(self, settings=self.settings))
244
- self._add_router(get_websocket_router(self, settings=self.settings))
245
- self._add_router(get_health_router())
246
- self._add_router(get_home_router(self))
260
+ # Add routes
261
+ self._add_router(fastapi_app, get_base_router(self, settings=self.settings))
262
+ self._add_router(fastapi_app, get_websocket_router(self, settings=self.settings))
263
+ self._add_router(fastapi_app, get_health_router())
264
+ self._add_router(fastapi_app, get_home_router(self))
247
265
 
248
266
  for interface in self.interfaces:
249
267
  interface_router = interface.get_router()
250
- self._add_router(interface_router)
268
+ self._add_router(fastapi_app, interface_router)
251
269
 
252
270
  self._auto_discover_databases()
253
271
  self._auto_discover_knowledge_instances()
@@ -261,17 +279,19 @@ class AgentOS:
261
279
  ]
262
280
 
263
281
  for router in routers:
264
- self._add_router(router)
282
+ self._add_router(fastapi_app, router)
265
283
 
266
284
  # Mount MCP if needed
267
- if self.enable_mcp and self.mcp_app:
268
- self.fastapi_app.mount("/", self.mcp_app)
285
+ if self.enable_mcp_server and self._mcp_app:
286
+ fastapi_app.mount("/", self._mcp_app)
287
+ else:
288
+ # Add the home router
289
+ self._add_router(fastapi_app, get_home_router(self))
269
290
 
270
- # Add middleware (only if app is not set)
271
291
  if not self._app_set:
272
292
 
273
- @self.fastapi_app.exception_handler(HTTPException)
274
- async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
293
+ @fastapi_app.exception_handler(HTTPException)
294
+ async def http_exception_handler(_, exc: HTTPException) -> JSONResponse:
275
295
  return JSONResponse(
276
296
  status_code=exc.status_code,
277
297
  content={"detail": str(exc.detail)},
@@ -286,12 +306,12 @@ class AgentOS:
286
306
  content={"detail": str(e)},
287
307
  )
288
308
 
289
- self.fastapi_app.middleware("http")(general_exception_handler)
309
+ fastapi_app.middleware("http")(general_exception_handler)
290
310
 
291
311
  # Update CORS middleware
292
- update_cors_middleware(self.fastapi_app, self.settings.cors_origin_list) # type: ignore
312
+ update_cors_middleware(fastapi_app, self.settings.cors_origin_list) # type: ignore
293
313
 
294
- return self.fastapi_app
314
+ return fastapi_app
295
315
 
296
316
  def get_routes(self) -> List[Any]:
297
317
  """Retrieve all routes from the FastAPI app.
@@ -303,55 +323,37 @@ class AgentOS:
303
323
 
304
324
  return app.routes
305
325
 
306
- def _get_existing_route_paths(self) -> Dict[str, List[str]]:
307
- """Get all existing route paths and methods from the FastAPI app.
308
-
309
- Returns:
310
- Dict[str, List[str]]: Dictionary mapping paths to list of HTTP methods
311
- """
312
- if not self.fastapi_app:
313
- return {}
314
-
315
- existing_paths: Dict[str, Any] = {}
316
- for route in self.fastapi_app.routes:
317
- if isinstance(route, APIRoute):
318
- path = route.path
319
- methods = list(route.methods) if route.methods else []
320
- if path in existing_paths:
321
- existing_paths[path].extend(methods)
322
- else:
323
- existing_paths[path] = methods
324
- return existing_paths
325
-
326
- def _add_router(self, router: APIRouter) -> None:
326
+ def _add_router(self, fastapi_app: FastAPI, router: APIRouter) -> None:
327
327
  """Add a router to the FastAPI app, avoiding route conflicts.
328
328
 
329
329
  Args:
330
330
  router: The APIRouter to add
331
331
  """
332
- if not self.fastapi_app:
333
- return
334
332
 
335
- # Get existing routes
336
- existing_paths = self._get_existing_route_paths()
333
+ conflicts = find_conflicting_routes(fastapi_app, router)
334
+ conflicting_routes = [conflict["route"] for conflict in conflicts]
337
335
 
338
- # Check for conflicts
339
- conflicts = []
340
- conflicting_routes = []
336
+ if conflicts and self._app_set:
337
+ if self.on_route_conflict == "preserve_base_app":
338
+ # Skip conflicting AgentOS routes, prefer user's existing routes
339
+ for conflict in conflicts:
340
+ methods_str = ", ".join(conflict["methods"]) # type: ignore
341
+ logger.debug(
342
+ f"Skipping conflicting AgentOS route: {methods_str} {conflict['path']} - "
343
+ f"Using existing custom route instead"
344
+ )
341
345
 
342
- for route in router.routes:
343
- if isinstance(route, APIRoute):
344
- full_path = route.path
345
- route_methods = list(route.methods) if route.methods else []
346
+ # Create a new router without the conflicting routes
347
+ filtered_router = APIRouter()
348
+ for route in router.routes:
349
+ if route not in conflicting_routes:
350
+ filtered_router.routes.append(route)
346
351
 
347
- if full_path in existing_paths:
348
- conflicting_methods: Set[str] = set(route_methods) & set(existing_paths[full_path])
349
- if conflicting_methods:
350
- conflicts.append({"path": full_path, "methods": list(conflicting_methods), "route": route})
351
- conflicting_routes.append(route)
352
+ # Use the filtered router if it has any routes left
353
+ if filtered_router.routes:
354
+ fastapi_app.include_router(filtered_router)
352
355
 
353
- if conflicts and self._app_set:
354
- if self.replace_routes:
356
+ elif self.on_route_conflict == "preserve_agentos":
355
357
  # Log warnings but still add all routes (AgentOS routes will override)
356
358
  for conflict in conflicts:
357
359
  methods_str = ", ".join(conflict["methods"]) # type: ignore
@@ -361,35 +363,21 @@ class AgentOS:
361
363
  )
362
364
 
363
365
  # Remove conflicting routes
364
- for route in self.fastapi_app.routes:
366
+ for route in fastapi_app.routes:
365
367
  for conflict in conflicts:
366
368
  if isinstance(route, APIRoute):
367
369
  if route.path == conflict["path"] and list(route.methods) == list(conflict["methods"]): # type: ignore
368
- self.fastapi_app.routes.pop(self.fastapi_app.routes.index(route))
370
+ fastapi_app.routes.pop(fastapi_app.routes.index(route))
369
371
 
370
- self.fastapi_app.include_router(router)
372
+ fastapi_app.include_router(router)
371
373
 
372
- else:
373
- # Skip conflicting AgentOS routes, prefer user's existing routes
374
- for conflict in conflicts:
375
- methods_str = ", ".join(conflict["methods"]) # type: ignore
376
- logger.debug(
377
- f"Skipping conflicting AgentOS route: {methods_str} {conflict['path']} - "
378
- f"Using existing custom route instead"
379
- )
374
+ elif self.on_route_conflict == "error":
375
+ conflicting_paths = [conflict["path"] for conflict in conflicts]
376
+ raise ValueError(f"Route conflict detected: {conflicting_paths}")
380
377
 
381
- # Create a new router without the conflicting routes
382
- filtered_router = APIRouter()
383
- for route in router.routes:
384
- if route not in conflicting_routes:
385
- filtered_router.routes.append(route)
386
-
387
- # Use the filtered router if it has any routes left
388
- if filtered_router.routes:
389
- self.fastapi_app.include_router(filtered_router)
390
378
  else:
391
379
  # No conflicts, add router normally
392
- self.fastapi_app.include_router(router)
380
+ fastapi_app.include_router(router)
393
381
 
394
382
  def _get_telemetry_data(self) -> Dict[str, Any]:
395
383
  """Get the telemetry data for the OS"""
@@ -400,21 +388,6 @@ class AgentOS:
400
388
  "interfaces": [interface.type for interface in self.interfaces] if self.interfaces else None,
401
389
  }
402
390
 
403
- def _load_yaml_config(self, config_file_path: str) -> AgentOSConfig:
404
- """Load a YAML config file and return the configuration as an AgentOSConfig instance."""
405
- from pathlib import Path
406
-
407
- import yaml
408
-
409
- # Validate that the path points to a YAML file
410
- path = Path(config_file_path)
411
- if path.suffix.lower() not in [".yaml", ".yml"]:
412
- raise ValueError(f"Config file must have a .yaml or .yml extension, got: {config_file_path}")
413
-
414
- # Load the YAML file
415
- with open(config_file_path, "r") as f:
416
- return AgentOSConfig.model_validate(yaml.safe_load(f))
417
-
418
391
  def _auto_discover_databases(self) -> None:
419
392
  """Auto-discover the databases used by all contextual agents, teams and workflows."""
420
393
  from agno.db.base import BaseDb
@@ -641,11 +614,10 @@ class AgentOS:
641
614
  from rich.align import Align
642
615
  from rich.console import Console, Group
643
616
 
644
- panel_group = []
645
- panel_group.append(Align.center(f"[bold cyan]{public_endpoint}[/bold cyan]"))
646
- panel_group.append(
647
- Align.center(f"\n\n[bold dark_orange]OS running on:[/bold dark_orange] http://{host}:{port}")
648
- )
617
+ panel_group = [
618
+ Align.center(f"[bold cyan]{public_endpoint}[/bold cyan]"),
619
+ Align.center(f"\n\n[bold dark_orange]OS running on:[/bold dark_orange] http://{host}:{port}"),
620
+ ]
649
621
  if bool(self.settings.os_security_key):
650
622
  panel_group.append(Align.center("\n\n[bold chartreuse3]:lock: Security Enabled[/bold chartreuse3]"))
651
623
 
agno/os/interfaces/whatsapp/router.py CHANGED
@@ -123,6 +123,7 @@ def attach_routes(router: APIRouter, agent: Optional[Agent] = None, team: Option
123
123
  response = await agent.arun(
124
124
  message_text,
125
125
  user_id=phone_number,
126
+ session_id=f"wa:{phone_number}",
126
127
  images=[Image(content=await get_media_async(message_image))] if message_image else None,
127
128
  files=[File(content=await get_media_async(message_doc))] if message_doc else None,
128
129
  videos=[Video(content=await get_media_async(message_video))] if message_video else None,
@@ -132,6 +133,7 @@ def attach_routes(router: APIRouter, agent: Optional[Agent] = None, team: Option
132
133
  response = await team.arun( # type: ignore
133
134
  message_text,
134
135
  user_id=phone_number,
136
+ session_id=f"wa:{phone_number}",
135
137
  files=[File(content=await get_media_async(message_doc))] if message_doc else None,
136
138
  images=[Image(content=await get_media_async(message_image))] if message_image else None,
137
139
  videos=[Video(content=await get_media_async(message_video))] if message_video else None,
agno/os/mcp.py CHANGED
@@ -54,7 +54,7 @@ def get_mcp_server(
54
54
  ) # type: ignore
55
55
  async def config() -> ConfigResponse:
56
56
  return ConfigResponse(
57
- os_id=os.os_id or "AgentOS",
57
+ os_id=os.id or "AgentOS",
58
58
  description=os.description,
59
59
  available_models=os.config.available_models if os.config else [],
60
60
  databases=[db.id for db in os.dbs.values()],
agno/os/middleware/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from agno.os.middleware.jwt import (
2
+ JWTMiddleware,
3
+ )
4
+
5
+ __all__ = [
6
+ "JWTMiddleware",
7
+ ]
agno/os/middleware/jwt.py ADDED
@@ -0,0 +1,233 @@
1
+ import fnmatch
2
+ from enum import Enum
3
+ from os import getenv
4
+ from typing import List, Optional
5
+
6
+ import jwt
7
+ from fastapi import Request, Response
8
+ from fastapi.responses import JSONResponse
9
+ from starlette.middleware.base import BaseHTTPMiddleware
10
+
11
+ from agno.utils.log import log_debug
12
+
13
+
class TokenSource(str, Enum):
    """Where JWTMiddleware looks for the JWT in an incoming request.

    Members subclass ``str``, so each one compares equal to its plain string value.
    """

    # Read the token from the configured request header
    HEADER = "header"
    # Read the token from the configured cookie
    COOKIE = "cookie"
    # Look in the header first, then fall back to the cookie
    BOTH = "both"
20
+
21
+
class JWTMiddleware(BaseHTTPMiddleware):
    """
    JWT Middleware for validating tokens and storing JWT claims in request state.

    This middleware:
    1. Extracts the JWT token from the configured header, a cookie, or both
    2. Decodes the token (rejecting the request with 401 on failure when ``validate`` is True)
    3. Stores JWT claims in request.state for easy access in endpoints

    Token Sources:
    - "header": Extract from the Authorization header (default)
    - "cookie": Extract from an HTTP cookie
    - "both": Try the header first, then the cookie as a fallback

    Values stored on request.state (for requests that are passed on to the app):
    - request.state.user_id: User ID from the configured claim (only if the claim is present)
    - request.state.session_id: Session ID from the configured claim (only if the claim is present)
    - request.state.scopes: Scopes from the configured scopes claim (only if non-empty)
    - request.state.dependencies: Dictionary of dependency claims (only if any were found)
    - request.state.session_state: Dictionary of session state claims (only if any were found)
    - request.state.token: The raw token string, whenever a token was found
    - request.state.authenticated: True if the token decoded successfully, otherwise False
      (not set for excluded routes)
    """

    def __init__(
        self,
        app,
        secret_key: Optional[str] = None,
        algorithm: str = "HS256",
        token_source: TokenSource = TokenSource.HEADER,
        token_header_key: str = "Authorization",
        cookie_name: str = "access_token",
        validate: bool = True,
        excluded_route_paths: Optional[List[str]] = None,
        scopes_claim: Optional[str] = None,
        user_id_claim: str = "sub",
        session_id_claim: str = "session_id",
        dependencies_claims: Optional[List[str]] = None,
        session_state_claims: Optional[List[str]] = None,
    ):
        """
        Initialize the JWT middleware.

        Args:
            app: The FastAPI app instance
            secret_key: The secret key to use for JWT validation (optional, will use JWT_SECRET_KEY environment variable if not provided)
            algorithm: The algorithm to use for JWT validation
            token_source: Where to extract the JWT token from (header, cookie, or both)
            token_header_key: The name of the header carrying the token (used when token_source is header/both)
            cookie_name: The name of the cookie containing the JWT token (used when token_source is cookie/both)
            validate: Whether to reject requests with a missing or undecodable token (401)
            excluded_route_paths: Route paths (exact or shell-style wildcard patterns) that skip this middleware
            scopes_claim: The claim to use for scopes extraction
            user_id_claim: The claim to use for user ID extraction
            session_id_claim: The claim to use for session ID extraction
            dependencies_claims: A list of claims to extract from the JWT token for dependencies
            session_state_claims: A list of claims to extract from the JWT token for session state

        Raises:
            ValueError: If no secret key is given and JWT_SECRET_KEY is not set.
        """
        super().__init__(app)
        self.secret_key = secret_key or getenv("JWT_SECRET_KEY")
        if not self.secret_key:
            raise ValueError("Secret key is required")
        self.algorithm = algorithm
        self.token_header_key = token_header_key
        self.token_source = token_source
        self.cookie_name = cookie_name
        self.validate = validate
        self.excluded_route_paths = excluded_route_paths
        self.scopes_claim = scopes_claim
        self.user_id_claim = user_id_claim
        self.session_id_claim = session_id_claim
        self.dependencies_claims = dependencies_claims or []
        self.session_state_claims = session_state_claims or []

    def _extract_token_from_header(self, request: Request) -> Optional[str]:
        """Extract the JWT token from the configured header.

        The header value is expected to be ``"<scheme> <token>"`` (e.g. ``"Bearer eyJ..."``).
        The scheme itself is not checked: everything after the first space is returned.
        A missing header, or a value without a space, yields ``None``.
        """
        authorization = request.headers.get(self.token_header_key, "")
        if not authorization:
            return None

        try:
            _, token = authorization.split(" ", 1)
            return token
        except ValueError:
            # No space in the value, so there is no "<scheme> <token>" pair to unpack
            return None

    def _extract_token_from_cookie(self, request: Request) -> Optional[str]:
        """Extract the JWT token from the configured cookie, or None if it is absent."""
        return request.cookies.get(self.cookie_name)

    def _extract_token(self, request: Request) -> Optional[str]:
        """Extract the JWT token based on the configured token source."""
        if self.token_source == TokenSource.HEADER:
            return self._extract_token_from_header(request)
        elif self.token_source == TokenSource.COOKIE:
            return self._extract_token_from_cookie(request)
        elif self.token_source == TokenSource.BOTH:
            # Try header first, then cookie
            token = self._extract_token_from_header(request)
            if token is None:
                token = self._extract_token_from_cookie(request)
            return token
        else:
            log_debug(f"Unknown token source: {self.token_source}")
            return None

    def _get_missing_token_error_message(self) -> str:
        """Get the 401 error message for a missing token, based on the token source."""
        if self.token_source == TokenSource.HEADER:
            return "Authorization header missing"
        elif self.token_source == TokenSource.COOKIE:
            return f"JWT cookie '{self.cookie_name}' missing"
        elif self.token_source == TokenSource.BOTH:
            return f"JWT token missing from both Authorization header and '{self.cookie_name}' cookie"
        else:
            return "JWT token missing"

    def _is_route_excluded(self, path: str) -> bool:
        """Return True if ``path`` matches any of the excluded route patterns.

        Patterns support exact matches and shell-style wildcards (``*``, ``?``, ``[...]``).
        ``fnmatchcase`` is used so matching is case-sensitive on every platform, as URL
        paths are; ``fnmatch.fnmatch`` would normalize case (and slashes) on Windows.
        """
        if not self.excluded_route_paths:
            return False
        return any(fnmatch.fnmatchcase(path, pattern) for pattern in self.excluded_route_paths)

    def _store_claims(self, request: Request, payload: dict) -> None:
        """Copy the configured claims from a decoded token payload onto request.state."""
        # Scopes may be a space-delimited string (OAuth 2.0 style) or an already-split value;
        # split() without an argument avoids empty entries from repeated whitespace.
        if self.scopes_claim and self.scopes_claim in payload:
            raw_scopes = payload[self.scopes_claim]
            scopes = raw_scopes.split() if isinstance(raw_scopes, str) else raw_scopes
            if scopes:
                request.state.scopes = scopes

        # The user id claim is optional: a validly signed token without it must not fail here.
        user_id = payload.get(self.user_id_claim)
        if self.user_id_claim in payload:
            request.state.user_id = user_id
        if self.session_id_claim in payload:
            request.state.session_id = payload[self.session_id_claim]

        dependencies = {claim: payload[claim] for claim in self.dependencies_claims if claim in payload}
        if dependencies:
            request.state.dependencies = dependencies

        session_state = {claim: payload[claim] for claim in self.session_state_claims if claim in payload}
        if session_state:
            request.state.session_state = session_state

        log_debug(f"JWT decoded successfully for user: {user_id}")
        if dependencies:
            log_debug(f"Extracted dependencies: {dependencies}")
        if session_state:
            log_debug(f"Extracted session state: {session_state}")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Authenticate the request, populate request.state, then call the next handler.

        Returns a 401 JSONResponse when ``validate`` is True and the token is missing or
        cannot be decoded; otherwise returns the downstream response.
        """
        if self._is_route_excluded(request.url.path):
            return await call_next(request)

        # Extract JWT token from configured source (header, cookie, or both)
        token = self._extract_token(request)

        if not token:
            if self.validate:
                error_msg = self._get_missing_token_error_message()
                return JSONResponse(status_code=401, content={"detail": error_msg})
            # Let the request through, but keep the documented flag available to endpoints
            request.state.authenticated = False
            return await call_next(request)

        # Only the decode step is guarded, so bugs in claim handling are not reported as auth failures
        payload: dict = {}
        error_detail: Optional[str] = None
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])  # type: ignore
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError, of which it is a subclass
            error_detail = "Token has expired"
        except jwt.InvalidTokenError as e:
            error_detail = f"Invalid token: {str(e)}"
        except Exception as e:
            error_detail = f"Error decoding token: {str(e)}"

        if error_detail is not None:
            if self.validate:
                return JSONResponse(status_code=401, content={"detail": error_detail})
            request.state.authenticated = False
            request.state.token = token
            return await call_next(request)

        self._store_claims(request, payload)
        request.state.token = token
        request.state.authenticated = True

        return await call_next(request)