agno 2.2.13__py3-none-any.whl → 2.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/__init__.py +6 -0
- agno/agent/agent.py +5252 -3145
- agno/agent/remote.py +525 -0
- agno/api/api.py +2 -0
- agno/client/__init__.py +3 -0
- agno/client/a2a/__init__.py +10 -0
- agno/client/a2a/client.py +554 -0
- agno/client/a2a/schemas.py +112 -0
- agno/client/a2a/utils.py +369 -0
- agno/client/os.py +2669 -0
- agno/compression/__init__.py +3 -0
- agno/compression/manager.py +247 -0
- agno/culture/manager.py +2 -2
- agno/db/base.py +927 -6
- agno/db/dynamo/dynamo.py +788 -2
- agno/db/dynamo/schemas.py +128 -0
- agno/db/dynamo/utils.py +26 -3
- agno/db/firestore/firestore.py +674 -50
- agno/db/firestore/schemas.py +41 -0
- agno/db/firestore/utils.py +25 -10
- agno/db/gcs_json/gcs_json_db.py +506 -3
- agno/db/gcs_json/utils.py +14 -2
- agno/db/in_memory/in_memory_db.py +203 -4
- agno/db/in_memory/utils.py +14 -2
- agno/db/json/json_db.py +498 -2
- agno/db/json/utils.py +14 -2
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/utils.py +19 -0
- agno/db/migrations/v1_to_v2.py +54 -16
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +977 -0
- agno/db/mongo/async_mongo.py +1013 -39
- agno/db/mongo/mongo.py +684 -4
- agno/db/mongo/schemas.py +48 -0
- agno/db/mongo/utils.py +17 -0
- agno/db/mysql/__init__.py +2 -1
- agno/db/mysql/async_mysql.py +2958 -0
- agno/db/mysql/mysql.py +722 -53
- agno/db/mysql/schemas.py +77 -11
- agno/db/mysql/utils.py +151 -8
- agno/db/postgres/async_postgres.py +1254 -137
- agno/db/postgres/postgres.py +2316 -93
- agno/db/postgres/schemas.py +153 -21
- agno/db/postgres/utils.py +22 -7
- agno/db/redis/redis.py +531 -3
- agno/db/redis/schemas.py +36 -0
- agno/db/redis/utils.py +31 -15
- agno/db/schemas/evals.py +1 -0
- agno/db/schemas/memory.py +20 -9
- agno/db/singlestore/schemas.py +70 -1
- agno/db/singlestore/singlestore.py +737 -74
- agno/db/singlestore/utils.py +13 -3
- agno/db/sqlite/async_sqlite.py +1069 -89
- agno/db/sqlite/schemas.py +133 -1
- agno/db/sqlite/sqlite.py +2203 -165
- agno/db/sqlite/utils.py +21 -11
- agno/db/surrealdb/models.py +25 -0
- agno/db/surrealdb/surrealdb.py +603 -1
- agno/db/utils.py +60 -0
- agno/eval/__init__.py +26 -3
- agno/eval/accuracy.py +25 -12
- agno/eval/agent_as_judge.py +871 -0
- agno/eval/base.py +29 -0
- agno/eval/performance.py +10 -4
- agno/eval/reliability.py +22 -13
- agno/eval/utils.py +2 -1
- agno/exceptions.py +42 -0
- agno/hooks/__init__.py +3 -0
- agno/hooks/decorator.py +164 -0
- agno/integrations/discord/client.py +13 -2
- agno/knowledge/__init__.py +4 -0
- agno/knowledge/chunking/code.py +90 -0
- agno/knowledge/chunking/document.py +65 -4
- agno/knowledge/chunking/fixed.py +4 -1
- agno/knowledge/chunking/markdown.py +102 -11
- agno/knowledge/chunking/recursive.py +2 -2
- agno/knowledge/chunking/semantic.py +130 -48
- agno/knowledge/chunking/strategy.py +18 -0
- agno/knowledge/embedder/azure_openai.py +0 -1
- agno/knowledge/embedder/google.py +1 -1
- agno/knowledge/embedder/mistral.py +1 -1
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/openai.py +16 -12
- agno/knowledge/filesystem.py +412 -0
- agno/knowledge/knowledge.py +4261 -1199
- agno/knowledge/protocol.py +134 -0
- agno/knowledge/reader/arxiv_reader.py +3 -2
- agno/knowledge/reader/base.py +9 -7
- agno/knowledge/reader/csv_reader.py +91 -42
- agno/knowledge/reader/docx_reader.py +9 -10
- agno/knowledge/reader/excel_reader.py +225 -0
- agno/knowledge/reader/field_labeled_csv_reader.py +38 -48
- agno/knowledge/reader/firecrawl_reader.py +3 -2
- agno/knowledge/reader/json_reader.py +16 -22
- agno/knowledge/reader/markdown_reader.py +15 -14
- agno/knowledge/reader/pdf_reader.py +33 -28
- agno/knowledge/reader/pptx_reader.py +9 -10
- agno/knowledge/reader/reader_factory.py +135 -1
- agno/knowledge/reader/s3_reader.py +8 -16
- agno/knowledge/reader/tavily_reader.py +3 -3
- agno/knowledge/reader/text_reader.py +15 -14
- agno/knowledge/reader/utils/__init__.py +17 -0
- agno/knowledge/reader/utils/spreadsheet.py +114 -0
- agno/knowledge/reader/web_search_reader.py +8 -65
- agno/knowledge/reader/website_reader.py +16 -13
- agno/knowledge/reader/wikipedia_reader.py +36 -3
- agno/knowledge/reader/youtube_reader.py +3 -2
- agno/knowledge/remote_content/__init__.py +33 -0
- agno/knowledge/remote_content/config.py +266 -0
- agno/knowledge/remote_content/remote_content.py +105 -17
- agno/knowledge/utils.py +76 -22
- agno/learn/__init__.py +71 -0
- agno/learn/config.py +463 -0
- agno/learn/curate.py +185 -0
- agno/learn/machine.py +725 -0
- agno/learn/schemas.py +1114 -0
- agno/learn/stores/__init__.py +38 -0
- agno/learn/stores/decision_log.py +1156 -0
- agno/learn/stores/entity_memory.py +3275 -0
- agno/learn/stores/learned_knowledge.py +1583 -0
- agno/learn/stores/protocol.py +117 -0
- agno/learn/stores/session_context.py +1217 -0
- agno/learn/stores/user_memory.py +1495 -0
- agno/learn/stores/user_profile.py +1220 -0
- agno/learn/utils.py +209 -0
- agno/media.py +22 -6
- agno/memory/__init__.py +14 -1
- agno/memory/manager.py +223 -8
- agno/memory/strategies/__init__.py +15 -0
- agno/memory/strategies/base.py +66 -0
- agno/memory/strategies/summarize.py +196 -0
- agno/memory/strategies/types.py +37 -0
- agno/models/aimlapi/aimlapi.py +17 -0
- agno/models/anthropic/claude.py +434 -59
- agno/models/aws/bedrock.py +121 -20
- agno/models/aws/claude.py +131 -274
- agno/models/azure/ai_foundry.py +10 -6
- agno/models/azure/openai_chat.py +33 -10
- agno/models/base.py +1162 -561
- agno/models/cerebras/cerebras.py +120 -24
- agno/models/cerebras/cerebras_openai.py +21 -2
- agno/models/cohere/chat.py +65 -6
- agno/models/cometapi/cometapi.py +18 -1
- agno/models/dashscope/dashscope.py +2 -3
- agno/models/deepinfra/deepinfra.py +18 -1
- agno/models/deepseek/deepseek.py +69 -3
- agno/models/fireworks/fireworks.py +18 -1
- agno/models/google/gemini.py +959 -89
- agno/models/google/utils.py +22 -0
- agno/models/groq/groq.py +48 -18
- agno/models/huggingface/huggingface.py +17 -6
- agno/models/ibm/watsonx.py +16 -6
- agno/models/internlm/internlm.py +18 -1
- agno/models/langdb/langdb.py +13 -1
- agno/models/litellm/chat.py +88 -9
- agno/models/litellm/litellm_openai.py +18 -1
- agno/models/message.py +24 -5
- agno/models/meta/llama.py +40 -13
- agno/models/meta/llama_openai.py +22 -21
- agno/models/metrics.py +12 -0
- agno/models/mistral/mistral.py +8 -4
- agno/models/n1n/__init__.py +3 -0
- agno/models/n1n/n1n.py +57 -0
- agno/models/nebius/nebius.py +6 -7
- agno/models/nvidia/nvidia.py +20 -3
- agno/models/ollama/__init__.py +2 -0
- agno/models/ollama/chat.py +17 -6
- agno/models/ollama/responses.py +100 -0
- agno/models/openai/__init__.py +2 -0
- agno/models/openai/chat.py +117 -26
- agno/models/openai/open_responses.py +46 -0
- agno/models/openai/responses.py +110 -32
- agno/models/openrouter/__init__.py +2 -0
- agno/models/openrouter/openrouter.py +67 -2
- agno/models/openrouter/responses.py +146 -0
- agno/models/perplexity/perplexity.py +19 -1
- agno/models/portkey/portkey.py +7 -6
- agno/models/requesty/requesty.py +19 -2
- agno/models/response.py +20 -2
- agno/models/sambanova/sambanova.py +20 -3
- agno/models/siliconflow/siliconflow.py +19 -2
- agno/models/together/together.py +20 -3
- agno/models/vercel/v0.py +20 -3
- agno/models/vertexai/claude.py +124 -4
- agno/models/vllm/vllm.py +19 -14
- agno/models/xai/xai.py +19 -2
- agno/os/app.py +467 -137
- agno/os/auth.py +253 -5
- agno/os/config.py +22 -0
- agno/os/interfaces/a2a/a2a.py +7 -6
- agno/os/interfaces/a2a/router.py +635 -26
- agno/os/interfaces/a2a/utils.py +32 -33
- agno/os/interfaces/agui/agui.py +5 -3
- agno/os/interfaces/agui/router.py +26 -16
- agno/os/interfaces/agui/utils.py +97 -57
- agno/os/interfaces/base.py +7 -7
- agno/os/interfaces/slack/router.py +16 -7
- agno/os/interfaces/slack/slack.py +7 -7
- agno/os/interfaces/whatsapp/router.py +35 -7
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/interfaces/whatsapp/whatsapp.py +11 -8
- agno/os/managers.py +326 -0
- agno/os/mcp.py +652 -79
- agno/os/middleware/__init__.py +4 -0
- agno/os/middleware/jwt.py +718 -115
- agno/os/middleware/trailing_slash.py +27 -0
- agno/os/router.py +105 -1558
- agno/os/routers/agents/__init__.py +3 -0
- agno/os/routers/agents/router.py +655 -0
- agno/os/routers/agents/schema.py +288 -0
- agno/os/routers/components/__init__.py +3 -0
- agno/os/routers/components/components.py +475 -0
- agno/os/routers/database.py +155 -0
- agno/os/routers/evals/evals.py +111 -18
- agno/os/routers/evals/schemas.py +38 -5
- agno/os/routers/evals/utils.py +80 -11
- agno/os/routers/health.py +3 -3
- agno/os/routers/knowledge/knowledge.py +284 -35
- agno/os/routers/knowledge/schemas.py +14 -2
- agno/os/routers/memory/memory.py +274 -11
- agno/os/routers/memory/schemas.py +44 -3
- agno/os/routers/metrics/metrics.py +30 -15
- agno/os/routers/metrics/schemas.py +10 -6
- agno/os/routers/registry/__init__.py +3 -0
- agno/os/routers/registry/registry.py +337 -0
- agno/os/routers/session/session.py +143 -14
- agno/os/routers/teams/__init__.py +3 -0
- agno/os/routers/teams/router.py +550 -0
- agno/os/routers/teams/schema.py +280 -0
- agno/os/routers/traces/__init__.py +3 -0
- agno/os/routers/traces/schemas.py +414 -0
- agno/os/routers/traces/traces.py +549 -0
- agno/os/routers/workflows/__init__.py +3 -0
- agno/os/routers/workflows/router.py +757 -0
- agno/os/routers/workflows/schema.py +139 -0
- agno/os/schema.py +157 -584
- agno/os/scopes.py +469 -0
- agno/os/settings.py +3 -0
- agno/os/utils.py +574 -185
- agno/reasoning/anthropic.py +85 -1
- agno/reasoning/azure_ai_foundry.py +93 -1
- agno/reasoning/deepseek.py +102 -2
- agno/reasoning/default.py +6 -7
- agno/reasoning/gemini.py +87 -3
- agno/reasoning/groq.py +109 -2
- agno/reasoning/helpers.py +6 -7
- agno/reasoning/manager.py +1238 -0
- agno/reasoning/ollama.py +93 -1
- agno/reasoning/openai.py +115 -1
- agno/reasoning/vertexai.py +85 -1
- agno/registry/__init__.py +3 -0
- agno/registry/registry.py +68 -0
- agno/remote/__init__.py +3 -0
- agno/remote/base.py +581 -0
- agno/run/__init__.py +2 -4
- agno/run/agent.py +134 -19
- agno/run/base.py +49 -1
- agno/run/cancel.py +65 -52
- agno/run/cancellation_management/__init__.py +9 -0
- agno/run/cancellation_management/base.py +78 -0
- agno/run/cancellation_management/in_memory_cancellation_manager.py +100 -0
- agno/run/cancellation_management/redis_cancellation_manager.py +236 -0
- agno/run/requirement.py +181 -0
- agno/run/team.py +111 -19
- agno/run/workflow.py +2 -1
- agno/session/agent.py +57 -92
- agno/session/summary.py +1 -1
- agno/session/team.py +62 -115
- agno/session/workflow.py +353 -57
- agno/skills/__init__.py +17 -0
- agno/skills/agent_skills.py +377 -0
- agno/skills/errors.py +32 -0
- agno/skills/loaders/__init__.py +4 -0
- agno/skills/loaders/base.py +27 -0
- agno/skills/loaders/local.py +216 -0
- agno/skills/skill.py +65 -0
- agno/skills/utils.py +107 -0
- agno/skills/validator.py +277 -0
- agno/table.py +10 -0
- agno/team/__init__.py +5 -1
- agno/team/remote.py +447 -0
- agno/team/team.py +3769 -2202
- agno/tools/brandfetch.py +27 -18
- agno/tools/browserbase.py +225 -16
- agno/tools/crawl4ai.py +3 -0
- agno/tools/duckduckgo.py +25 -71
- agno/tools/exa.py +0 -21
- agno/tools/file.py +14 -13
- agno/tools/file_generation.py +12 -6
- agno/tools/firecrawl.py +15 -7
- agno/tools/function.py +94 -113
- agno/tools/google_bigquery.py +11 -2
- agno/tools/google_drive.py +4 -3
- agno/tools/knowledge.py +9 -4
- agno/tools/mcp/mcp.py +301 -18
- agno/tools/mcp/multi_mcp.py +269 -14
- agno/tools/mem0.py +11 -10
- agno/tools/memory.py +47 -46
- agno/tools/mlx_transcribe.py +10 -7
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/nano_banana.py +151 -0
- agno/tools/parallel.py +0 -7
- agno/tools/postgres.py +76 -36
- agno/tools/python.py +14 -6
- agno/tools/reasoning.py +30 -23
- agno/tools/redshift.py +406 -0
- agno/tools/shopify.py +1519 -0
- agno/tools/spotify.py +919 -0
- agno/tools/tavily.py +4 -1
- agno/tools/toolkit.py +253 -18
- agno/tools/websearch.py +93 -0
- agno/tools/website.py +1 -1
- agno/tools/wikipedia.py +1 -1
- agno/tools/workflow.py +56 -48
- agno/tools/yfinance.py +12 -11
- agno/tracing/__init__.py +12 -0
- agno/tracing/exporter.py +161 -0
- agno/tracing/schemas.py +276 -0
- agno/tracing/setup.py +112 -0
- agno/utils/agent.py +251 -10
- agno/utils/cryptography.py +22 -0
- agno/utils/dttm.py +33 -0
- agno/utils/events.py +264 -7
- agno/utils/hooks.py +111 -3
- agno/utils/http.py +161 -2
- agno/utils/mcp.py +49 -8
- agno/utils/media.py +22 -1
- agno/utils/models/ai_foundry.py +9 -2
- agno/utils/models/claude.py +20 -5
- agno/utils/models/cohere.py +9 -2
- agno/utils/models/llama.py +9 -2
- agno/utils/models/mistral.py +4 -2
- agno/utils/os.py +0 -0
- agno/utils/print_response/agent.py +99 -16
- agno/utils/print_response/team.py +223 -24
- agno/utils/print_response/workflow.py +0 -2
- agno/utils/prompts.py +8 -6
- agno/utils/remote.py +23 -0
- agno/utils/response.py +1 -13
- agno/utils/string.py +91 -2
- agno/utils/team.py +62 -12
- agno/utils/tokens.py +657 -0
- agno/vectordb/base.py +15 -2
- agno/vectordb/cassandra/cassandra.py +1 -1
- agno/vectordb/chroma/__init__.py +2 -1
- agno/vectordb/chroma/chromadb.py +468 -23
- agno/vectordb/clickhouse/clickhousedb.py +1 -1
- agno/vectordb/couchbase/couchbase.py +6 -2
- agno/vectordb/lancedb/lance_db.py +7 -38
- agno/vectordb/lightrag/lightrag.py +7 -6
- agno/vectordb/milvus/milvus.py +118 -84
- agno/vectordb/mongodb/__init__.py +2 -1
- agno/vectordb/mongodb/mongodb.py +14 -31
- agno/vectordb/pgvector/pgvector.py +120 -66
- agno/vectordb/pineconedb/pineconedb.py +2 -19
- agno/vectordb/qdrant/__init__.py +2 -1
- agno/vectordb/qdrant/qdrant.py +33 -56
- agno/vectordb/redis/__init__.py +2 -1
- agno/vectordb/redis/redisdb.py +19 -31
- agno/vectordb/singlestore/singlestore.py +17 -9
- agno/vectordb/surrealdb/surrealdb.py +2 -38
- agno/vectordb/weaviate/__init__.py +2 -1
- agno/vectordb/weaviate/weaviate.py +7 -3
- agno/workflow/__init__.py +5 -1
- agno/workflow/agent.py +2 -2
- agno/workflow/condition.py +12 -10
- agno/workflow/loop.py +28 -9
- agno/workflow/parallel.py +21 -13
- agno/workflow/remote.py +362 -0
- agno/workflow/router.py +12 -9
- agno/workflow/step.py +261 -36
- agno/workflow/steps.py +12 -8
- agno/workflow/types.py +40 -77
- agno/workflow/workflow.py +939 -213
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/METADATA +134 -181
- agno-2.4.3.dist-info/RECORD +677 -0
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/WHEEL +1 -1
- agno/tools/googlesearch.py +0 -98
- agno/tools/memori.py +0 -339
- agno-2.2.13.dist-info/RECORD +0 -575
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/licenses/LICENSE +0 -0
- {agno-2.2.13.dist-info → agno-2.4.3.dist-info}/top_level.txt +0 -0
agno/os/middleware/jwt.py
CHANGED
|
@@ -1,14 +1,25 @@
|
|
|
1
|
+
"""JWT Middleware for AgentOS - JWT Authentication with optional RBAC."""
|
|
2
|
+
|
|
1
3
|
import fnmatch
|
|
4
|
+
import json
|
|
5
|
+
import re
|
|
2
6
|
from enum import Enum
|
|
3
7
|
from os import getenv
|
|
4
|
-
from typing import List, Optional
|
|
8
|
+
from typing import Any, Dict, Iterable, List, Optional, Union
|
|
5
9
|
|
|
6
10
|
import jwt
|
|
7
11
|
from fastapi import Request, Response
|
|
8
12
|
from fastapi.responses import JSONResponse
|
|
13
|
+
from jwt import PyJWK
|
|
9
14
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
10
15
|
|
|
11
|
-
from agno.
|
|
16
|
+
from agno.os.scopes import (
|
|
17
|
+
AgentOSScope,
|
|
18
|
+
get_accessible_resource_ids,
|
|
19
|
+
get_default_scope_mappings,
|
|
20
|
+
has_required_scopes,
|
|
21
|
+
)
|
|
22
|
+
from agno.utils.log import log_debug, log_warning
|
|
12
23
|
|
|
13
24
|
|
|
14
25
|
class TokenSource(str, Enum):
|
|
@@ -19,78 +30,575 @@ class TokenSource(str, Enum):
|
|
|
19
30
|
BOTH = "both" # Try header first, then cookie
|
|
20
31
|
|
|
21
32
|
|
|
33
|
+
class JWTValidator:
    """
    JWT token validator that can be used standalone or within JWTMiddleware.

    This class handles:
    - Loading verification keys (static keys or JWKS files)
    - Validating JWT signatures
    - Extracting claims from tokens

    It can be stored on app.state for use by WebSocket handlers or other
    components that need JWT validation outside of the HTTP middleware chain.

    Example:
        # Create validator
        validator = JWTValidator(
            verification_keys=["your-public-key"],
            algorithm="RS256",
        )

        # Validate a token. Note: ``validator.validate`` is the boolean
        # "validation enabled" flag; the validation method is ``validate_token``.
        try:
            payload = validator.validate_token(token)
            claims = validator.extract_claims(payload)
            user_id = claims["user_id"]
            scopes = claims["scopes"]
        except jwt.InvalidTokenError as e:
            print(f"Invalid token: {e}")

        # Store on app.state for WebSocket access
        app.state.jwt_validator = validator
    """

    def __init__(
        self,
        verification_keys: Optional[List[str]] = None,
        jwks_file: Optional[str] = None,
        algorithm: str = "RS256",
        validate: bool = True,
        scopes_claim: str = "scopes",
        user_id_claim: str = "sub",
        session_id_claim: str = "session_id",
        audience_claim: str = "aud",
        leeway: int = 10,
    ):
        """
        Initialize the JWT validator.

        Args:
            verification_keys: List of keys for verifying JWT signatures.
                For asymmetric algorithms (RS256, ES256), these should be public keys.
                For symmetric algorithms (HS256), these are shared secrets.
                The JWT_VERIFICATION_KEY environment variable, if set, is appended
                unless it is already in the list.
            jwks_file: Path to a static JWKS (JSON Web Key Set) file containing public keys.
                If omitted, the JWT_JWKS_FILE environment variable is used when set.
            algorithm: JWT algorithm (default: RS256).
            validate: Whether to validate the JWT token (default: True).
            scopes_claim: JWT claim name for scopes (default: "scopes").
            user_id_claim: JWT claim name for user ID (default: "sub").
            session_id_claim: JWT claim name for session ID (default: "session_id").
            audience_claim: JWT claim name for audience (default: "aud").
            leeway: Seconds of leeway for clock skew tolerance (default: 10).

        Raises:
            ValueError: If validate=True and no key source is available, or if a
                configured JWKS file cannot be read or parsed.
        """
        self.algorithm = algorithm
        self.validate = validate
        self.scopes_claim = scopes_claim
        self.user_id_claim = user_id_claim
        self.session_id_claim = session_id_claim
        self.audience_claim = audience_claim
        self.leeway = leeway

        # Static verification keys: explicit ones first, then the env var (deduplicated).
        self.verification_keys: List[str] = list(verification_keys or [])
        env_key = getenv("JWT_VERIFICATION_KEY", "")
        if env_key and env_key not in self.verification_keys:
            self.verification_keys.append(env_key)

        # JWKS keys, indexed by "kid"; keys without a kid are stored under "_default".
        self.jwks_keys: Dict[str, PyJWK] = {}

        # The explicit jwks_file parameter takes precedence over the JWT_JWKS_FILE env var.
        jwks_source = jwks_file or getenv("JWT_JWKS_FILE", "")
        if jwks_source:
            self._load_jwks_file(jwks_source)

        # Validate that at least one key source is provided if validate=True
        if self.validate and not self.verification_keys and not self.jwks_keys:
            raise ValueError(
                "At least one JWT verification key or JWKS file is required when validate=True. "
                "Set via verification_keys parameter, JWT_VERIFICATION_KEY environment variable, "
                "jwks_file parameter or JWT_JWKS_FILE environment variable."
            )

    def _load_jwks_file(self, file_path: str) -> None:
        """
        Load keys from a static JWKS file into ``self.jwks_keys``.

        Args:
            file_path: Path to the JWKS JSON file (read as UTF-8).

        Raises:
            ValueError: If the file does not exist, cannot be read or decoded,
                contains invalid JSON, or is not a JSON object.
        """
        try:
            with open(file_path, encoding="utf-8") as f:
                jwks_data = json.load(f)
        except FileNotFoundError as e:
            raise ValueError(f"JWKS file not found: {file_path}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in JWKS file {file_path}: {e}") from e
        except (OSError, UnicodeDecodeError) as e:
            # Permission errors, directories, non-UTF-8 content, etc.
            raise ValueError(f"Could not read JWKS file {file_path}: {e}") from e

        # Parsing happens outside the try so its own errors are not mislabelled as I/O errors.
        self._parse_jwks_data(jwks_data)
        log_debug(f"Loaded {len(self.jwks_keys)} key(s) from JWKS file: {file_path}")

    def _parse_jwks_data(self, jwks_data: Dict[str, Any]) -> None:
        """
        Parse JWKS data and populate self.jwks_keys.

        Individual keys that fail to parse are logged and skipped (best effort),
        so one bad entry does not discard the rest of the key set.

        Args:
            jwks_data: Parsed JWKS dictionary with a "keys" array

        Raises:
            ValueError: If ``jwks_data`` is not a JSON object.
        """
        if not isinstance(jwks_data, dict):
            raise ValueError('JWKS data must be a JSON object containing a "keys" array')

        keys = jwks_data.get("keys", [])
        if not keys:
            log_warning("JWKS contains no keys")
            return
        if not isinstance(keys, list):
            log_warning('JWKS "keys" entry is not a list; ignoring it')
            return

        for key_data in keys:
            try:
                kid = key_data.get("kid")
                jwk = PyJWK.from_dict(key_data)
            except Exception as e:
                log_warning(f"Failed to parse JWKS key: {e}")
                continue
            # A key without a kid becomes the fallback key (for single-key JWKS).
            self.jwks_keys[kid or "_default"] = jwk

    def validate_token(
        self, token: str, expected_audience: Optional[Union[str, Iterable[str]]] = None
    ) -> Dict[str, Any]:
        """
        Validate JWT token and extract claims.

        JWKS keys are tried first (selected by the token header's "kid", with the
        "_default" key as fallback), then each static verification key in order.
        When ``self.validate`` is False the signature is not verified and no
        audience check is performed.

        Args:
            token: The JWT token to validate
            expected_audience: The expected audience(s) to verify (optional)

        Returns:
            Dictionary of claims if valid

        Raises:
            jwt.InvalidAudienceError: If audience claim doesn't match expected
            jwt.ExpiredSignatureError: If token has expired
            jwt.InvalidTokenError: If token is invalid, no key matches, or no
                verification keys are configured
        """
        # PyJWT's built-in audience check is disabled; the audience is verified
        # manually below so the error can list expected and actual audiences.
        decode_options: Dict[str, Any] = {"verify_aud": False}
        decode_kwargs: Dict[str, Any] = {
            "algorithms": [self.algorithm],
            "leeway": self.leeway,
            "options": decode_options,
        }

        # If validation is disabled, decode without signature verification
        if not self.validate:
            decode_options["verify_signature"] = False
            return jwt.decode(token, **decode_kwargs)

        last_exception: Optional[Exception] = None
        payload: Optional[Dict[str, Any]] = None

        # Try JWKS keys first if configured
        if self.jwks_keys:
            try:
                kid = jwt.get_unverified_header(token).get("kid")
                # Only string kids can index the key map; anything else falls back to _default.
                jwk = self.jwks_keys.get(kid) if isinstance(kid, str) and kid else None
                if jwk is None:
                    jwk = self.jwks_keys.get("_default")
                if jwk is None:
                    # Keys are configured, just none for this token: say so explicitly
                    # rather than falling through to "No verification keys configured".
                    raise jwt.InvalidTokenError(f"No JWKS key found matching token kid: {kid!r}")
                payload = jwt.decode(token, jwk.key, **decode_kwargs)
            except jwt.ExpiredSignatureError:
                raise
            except jwt.InvalidTokenError as e:
                if not self.verification_keys:
                    raise
                last_exception = e

        # Try each static verification key until one succeeds
        if payload is None:
            for key in self.verification_keys:
                try:
                    payload = jwt.decode(token, key, **decode_kwargs)
                    break
                except jwt.ExpiredSignatureError:
                    raise
                except jwt.InvalidTokenError as e:
                    last_exception = e

        if payload is None:
            if last_exception is not None:
                raise last_exception
            raise jwt.InvalidTokenError("No verification keys configured")

        # Manually verify audience if expected_audience was provided
        if expected_audience:
            token_audience = payload.get(self.audience_claim)
            if token_audience is None:
                raise jwt.InvalidTokenError(
                    f'Token is missing the "{self.audience_claim}" claim. '
                    f"Audience verification requires this claim to be present in the token."
                )

            # Normalize expected_audience to a list
            if isinstance(expected_audience, str):
                expected_audiences = [expected_audience]
            elif isinstance(expected_audience, Iterable):
                expected_audiences = list(expected_audience)
            else:
                expected_audiences = []

            # Normalize token_audience to a list ("aud" may be a string or an array)
            if isinstance(token_audience, str):
                token_audiences = [token_audience]
            elif isinstance(token_audience, list):
                token_audiences = token_audience
            else:
                token_audiences = [token_audience] if token_audience else []

            # Check if any token audience matches any expected audience
            if not any(aud in expected_audiences for aud in token_audiences):
                raise jwt.InvalidAudienceError(
                    f"Invalid audience. Expected one of: {expected_audiences}, got: {token_audiences}"
                )

        return payload

    def extract_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract standard claims from a JWT payload.

        A string scopes claim is wrapped in a single-item list; any other
        non-list value yields an empty scope list.

        Args:
            payload: The decoded JWT payload

        Returns:
            Dictionary with user_id, session_id, scopes, and audience
        """
        scopes = payload.get(self.scopes_claim, [])
        if isinstance(scopes, str):
            scopes = [scopes]
        elif not isinstance(scopes, list):
            scopes = []

        return {
            "user_id": payload.get(self.user_id_claim),
            "session_id": payload.get(self.session_id_claim),
            "scopes": scopes,
            "audience": payload.get(self.audience_claim),
        }
|
|
307
|
+
|
|
308
|
+
|
|
22
309
|
class JWTMiddleware(BaseHTTPMiddleware):
|
|
23
310
|
"""
|
|
24
|
-
JWT Middleware
|
|
311
|
+
JWT Authentication Middleware with optional RBAC (Role-Based Access Control).
|
|
25
312
|
|
|
26
313
|
This middleware:
|
|
27
|
-
1. Extracts JWT token from Authorization header
|
|
314
|
+
1. Extracts JWT token from Authorization header or cookies
|
|
28
315
|
2. Decodes and validates the token
|
|
29
|
-
3.
|
|
316
|
+
3. Validates the `aud` (audience) claim matches the AgentOS ID (if configured)
|
|
317
|
+
4. Stores JWT claims (user_id, session_id, scopes) in request.state
|
|
318
|
+
5. Optionally checks if the request path requires specific scopes (if scope_mappings provided)
|
|
319
|
+
6. Validates that the authenticated user has the required scopes
|
|
320
|
+
7. Returns 401 for invalid tokens, 403 for insufficient scopes
|
|
321
|
+
|
|
322
|
+
RBAC is opt-in: Only enabled when authorization=True or scope_mappings are provided.
|
|
323
|
+
Without authorization enabled, the middleware only extracts and validates JWT tokens.
|
|
324
|
+
|
|
325
|
+
Audience Verification:
|
|
326
|
+
- The `aud` claim in JWT tokens should contain the AgentOS ID
|
|
327
|
+
- This is verified against the AgentOS instance ID from app.state.agent_os_id
|
|
328
|
+
- Tokens with mismatched audience will be rejected with 401
|
|
329
|
+
|
|
330
|
+
Scope Format (simplified):
|
|
331
|
+
- Global resource scopes: `resource:action` (e.g., "agents:read")
|
|
332
|
+
- Per-resource scopes: `resource:<resource-id>:action` (e.g., "agents:web-agent:run")
|
|
333
|
+
- Wildcards: `resource:*:action` (e.g., "agents:*:run")
|
|
334
|
+
- Admin scope: `admin` (grants all permissions)
|
|
30
335
|
|
|
31
336
|
Token Sources:
|
|
32
337
|
- "header": Extract from Authorization header (default)
|
|
33
338
|
- "cookie": Extract from HTTP cookie
|
|
34
339
|
- "both": Try header first, then cookie as fallback
|
|
35
340
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
341
|
+
Example:
|
|
342
|
+
from agno.os.middleware import JWTMiddleware
|
|
343
|
+
from agno.os.scopes import AgentOSScope
|
|
344
|
+
|
|
345
|
+
# Single verification key
|
|
346
|
+
app.add_middleware(
|
|
347
|
+
JWTMiddleware,
|
|
348
|
+
verification_keys=["your-public-key"],
|
|
349
|
+
authorization=True,
|
|
350
|
+
verify_audience=True, # Verify aud claim matches AgentOS ID
|
|
351
|
+
scope_mappings={
|
|
352
|
+
# Override default scope for this endpoint
|
|
353
|
+
"GET /agents": ["agents:read"],
|
|
354
|
+
# Add new endpoint mapping
|
|
355
|
+
"POST /custom/endpoint": ["agents:run"],
|
|
356
|
+
# Allow access without scopes
|
|
357
|
+
"GET /public/stats": [],
|
|
358
|
+
}
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
# Multiple verification keys (accept tokens from multiple issuers)
|
|
362
|
+
app.add_middleware(
|
|
363
|
+
JWTMiddleware,
|
|
364
|
+
verification_keys=[
|
|
365
|
+
"public-key-from-issuer-1",
|
|
366
|
+
"public-key-from-issuer-2",
|
|
367
|
+
],
|
|
368
|
+
authorization=True,
|
|
369
|
+
)
|
|
370
|
+
|
|
371
|
+
# Using a static JWKS file
|
|
372
|
+
app.add_middleware(
|
|
373
|
+
JWTMiddleware,
|
|
374
|
+
jwks_file="/path/to/jwks.json",
|
|
375
|
+
authorization=True,
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
# No validation (extract claims only, useful for development)
|
|
379
|
+
app.add_middleware(
|
|
380
|
+
JWTMiddleware,
|
|
381
|
+
validate=False, # No verification key needed
|
|
382
|
+
)
|
|
43
383
|
"""
|
|
44
384
|
|
|
45
385
|
def __init__(
    self,
    app,
    verification_keys: Optional[List[str]] = None,
    jwks_file: Optional[str] = None,
    secret_key: Optional[str] = None,  # Deprecated: Use verification_keys instead
    algorithm: str = "RS256",
    validate: bool = True,
    authorization: Optional[bool] = None,
    token_source: TokenSource = TokenSource.HEADER,
    token_header_key: str = "Authorization",
    cookie_name: str = "access_token",
    scopes_claim: str = "scopes",
    user_id_claim: str = "sub",
    session_id_claim: str = "session_id",
    audience_claim: str = "aud",
    audience: Optional[Union[str, Iterable[str]]] = None,
    verify_audience: bool = False,
    dependencies_claims: Optional[List[str]] = None,
    session_state_claims: Optional[List[str]] = None,
    scope_mappings: Optional[Dict[str, List[str]]] = None,
    excluded_route_paths: Optional[List[str]] = None,
    admin_scope: Optional[str] = None,
):
    """
    Initialize the JWT middleware.

    Args:
        app: The FastAPI app instance
        verification_keys: List of keys for verifying JWT signatures.
            For asymmetric algorithms (RS256, ES256), these should be public keys.
            For symmetric algorithms (HS256), these are shared secrets.
            Each key will be tried in order until one successfully validates the token.
            Useful when accepting tokens signed by different private keys.
            If not provided, will use JWT_VERIFICATION_KEY env var (as a single-item list).
        jwks_file: Path to a static JWKS (JSON Web Key Set) file containing public keys.
            The file should contain a JSON object with a "keys" array.
            Keys are looked up by the "kid" (key ID) claim in the JWT header.
            If not provided, will check JWT_JWKS_FILE env var for a file path,
            or JWT_JWKS env var for inline JWKS JSON content.
        secret_key: (deprecated) Use verification_keys instead. If provided, it is appended to
            verification_keys (unless already present) and a deprecation warning is logged.
        algorithm: JWT algorithm (default: RS256). Common options: RS256 (asymmetric), HS256 (symmetric).
        validate: Whether to validate the JWT signature (default: True). If False, tokens are decoded
            without signature verification and no verification key is required. Useful when
            JWT verification is handled upstream (API Gateway, etc.).
        authorization: Whether to enforce RBAC scope checks.
            - None (default): enabled automatically when scope_mappings is provided, otherwise disabled.
            - True: enabled, using the default scope mappings merged with scope_mappings (if any).
            - False: disabled, even if scope_mappings is provided.
        token_source: Where to extract JWT token from (header, cookie, or both)
        token_header_key: Header key for Authorization (default: "Authorization")
        cookie_name: Cookie name for JWT token (default: "access_token")
        scopes_claim: JWT claim name for scopes (default: "scopes")
        user_id_claim: JWT claim name for user ID (default: "sub")
        session_id_claim: JWT claim name for session ID (default: "session_id")
        audience_claim: JWT claim name for audience/OS ID (default: "aud")
        audience: Optional expected audience to validate the token's audience claim against.
            When verify_audience is True and this is not set, the AgentOS ID
            (app.state.agent_os_id) is used instead.
        verify_audience: Whether to verify the token's audience claim matches the expected audience (default: False)
        dependencies_claims: A list of claims to extract from the JWT token for dependencies
        session_state_claims: A list of claims to extract from the JWT token for session state
        scope_mappings: Optional dictionary mapping "METHOD /path" patterns to required scopes.
            Providing it (while authorization is None) enables RBAC. When RBAC is enabled, these
            mappings are merged over the default scope mappings; user entries win on conflict.
            Use empty list [] to explicitly allow access without scopes for a route.
            Format: {"POST /agents/*/runs": ["agents:run"], "GET /public": []}
        excluded_route_paths: List of route paths (shell-style wildcards allowed) to exclude from
            JWT/RBAC checks. Defaults to the docs/health routes from _get_default_excluded_routes().
        admin_scope: The scope that grants admin access (default: AgentOSScope.ADMIN.value)

    Note:
        - At least one verification key or JWKS file must be provided if validate=True
        - If validate=False, no verification key is needed (claims are extracted without verification)
        - JWKS keys are tried first (by kid), then static verification_keys as fallback
        - CORS allowed origins are read from app.state.cors_allowed_origins (set by AgentOS).
          This allows error responses to include proper CORS headers.
    """
    super().__init__(app)

    # Handle deprecated secret_key parameter: fold it into the key list without duplicating it.
    all_verification_keys = list(verification_keys) if verification_keys else []
    if secret_key:
        log_warning("secret_key is deprecated. Use verification_keys instead.")
        if secret_key not in all_verification_keys:
            all_verification_keys.append(secret_key)

    # Create the JWT validator (handles key loading and token validation)
    self.validator = JWTValidator(
        verification_keys=all_verification_keys if all_verification_keys else None,
        jwks_file=jwks_file,
        algorithm=algorithm,
        validate=validate,
        scopes_claim=scopes_claim,
        user_id_claim=user_id_claim,
        session_id_claim=session_id_claim,
        audience_claim=audience_claim,
    )

    # Store config for easy access
    self.validate = validate
    self.algorithm = algorithm
    self.token_source = token_source
    self.token_header_key = token_header_key
    self.cookie_name = cookie_name
    self.scopes_claim = scopes_claim
    self.user_id_claim = user_id_claim
    self.session_id_claim = session_id_claim
    self.audience_claim = audience_claim
    self.verify_audience = verify_audience
    # Copy caller-owned lists so later mutation by the caller cannot change the middleware config.
    self.dependencies_claims: List[str] = list(dependencies_claims or [])
    self.session_state_claims: List[str] = list(session_state_claims or [])

    self.audience = audience

    # RBAC configuration: an explicit `authorization` value wins; otherwise supplying
    # scope_mappings opts in.
    self.authorization = authorization
    if scope_mappings is not None and self.authorization is None:
        self.authorization = True

    if self.authorization:
        # Start from a private copy of the defaults so that merging user overrides can never
        # mutate a dict that get_default_scope_mappings() might share between instances.
        self.scope_mappings = dict(get_default_scope_mappings())
        if scope_mappings is not None:
            # User-provided mappings override defaults on conflict (additive merge).
            self.scope_mappings.update(scope_mappings)
    else:
        # RBAC disabled: keep the user's mappings (if any) as a private copy.
        self.scope_mappings = dict(scope_mappings) if scope_mappings else {}

    self.excluded_route_paths = (
        list(excluded_route_paths)
        if excluded_route_paths is not None
        else self._get_default_excluded_routes()
    )
    self.admin_scope = admin_scope or AgentOSScope.ADMIN.value
|
|
516
|
+
|
|
517
|
+
def _get_default_excluded_routes(self) -> List[str]:
|
|
518
|
+
"""Get default routes that should be excluded from RBAC checks."""
|
|
519
|
+
return [
|
|
520
|
+
"/",
|
|
521
|
+
"/health",
|
|
522
|
+
"/docs",
|
|
523
|
+
"/redoc",
|
|
524
|
+
"/openapi.json",
|
|
525
|
+
"/docs/oauth2-redirect",
|
|
526
|
+
]
|
|
527
|
+
|
|
528
|
+
def _extract_resource_id_from_path(self, path: str, resource_type: str) -> Optional[str]:
|
|
529
|
+
"""
|
|
530
|
+
Extract resource ID from a path.
|
|
531
|
+
|
|
532
|
+
Args:
|
|
533
|
+
path: The request path
|
|
534
|
+
resource_type: Type of resource ("agents", "teams", "workflows")
|
|
535
|
+
|
|
536
|
+
Returns:
|
|
537
|
+
The resource ID if found, None otherwise
|
|
538
|
+
|
|
539
|
+
Examples:
|
|
540
|
+
>>> _extract_resource_id_from_path("/agents/my-agent/runs", "agents")
|
|
541
|
+
"my-agent"
|
|
542
|
+
"""
|
|
543
|
+
# Pattern: /{resource_type}/{resource_id}/...
|
|
544
|
+
pattern = f"^/{resource_type}/([^/]+)"
|
|
545
|
+
match = re.search(pattern, path)
|
|
546
|
+
if match:
|
|
547
|
+
return match.group(1)
|
|
548
|
+
return None
|
|
549
|
+
|
|
550
|
+
def _is_route_excluded(self, path: str) -> bool:
|
|
551
|
+
"""Check if a route path matches any of the excluded patterns."""
|
|
552
|
+
if not self.excluded_route_paths:
|
|
553
|
+
return False
|
|
554
|
+
|
|
555
|
+
for excluded_path in self.excluded_route_paths:
|
|
556
|
+
# Support both exact matches and wildcard patterns
|
|
557
|
+
if fnmatch.fnmatch(path, excluded_path):
|
|
558
|
+
return True
|
|
559
|
+
# Also check without trailing slash
|
|
560
|
+
if fnmatch.fnmatch(path.rstrip("/"), excluded_path):
|
|
561
|
+
return True
|
|
562
|
+
|
|
563
|
+
return False
|
|
564
|
+
|
|
565
|
+
def _get_required_scopes(self, method: str, path: str) -> List[str]:
|
|
566
|
+
"""
|
|
567
|
+
Get required scopes for a given method and path.
|
|
568
|
+
|
|
569
|
+
Args:
|
|
570
|
+
method: HTTP method (GET, POST, etc.)
|
|
571
|
+
path: Request path
|
|
572
|
+
|
|
573
|
+
Returns:
|
|
574
|
+
List of required scopes. Empty list [] means no scopes required (allow access).
|
|
575
|
+
Routes not in scope_mappings also return [], allowing access.
|
|
576
|
+
"""
|
|
577
|
+
route_key = f"{method} {path}"
|
|
578
|
+
|
|
579
|
+
# First, try exact match
|
|
580
|
+
if route_key in self.scope_mappings:
|
|
581
|
+
return self.scope_mappings[route_key]
|
|
582
|
+
|
|
583
|
+
# Then try pattern matching
|
|
584
|
+
for pattern, scopes in self.scope_mappings.items():
|
|
585
|
+
pattern_method, pattern_path = pattern.split(" ", 1)
|
|
586
|
+
|
|
587
|
+
# Check if method matches
|
|
588
|
+
if pattern_method != method:
|
|
589
|
+
continue
|
|
590
|
+
|
|
591
|
+
# Convert pattern to fnmatch pattern (replace {param} with *)
|
|
592
|
+
# This handles both /agents/* and /agents/{agent_id} style patterns
|
|
593
|
+
normalized_pattern = pattern_path
|
|
594
|
+
if "{" in normalized_pattern:
|
|
595
|
+
# Replace {param} with * for pattern matching
|
|
596
|
+
normalized_pattern = re.sub(r"\{[^}]+\}", "*", normalized_pattern)
|
|
597
|
+
|
|
598
|
+
if fnmatch.fnmatch(path, normalized_pattern):
|
|
599
|
+
return scopes
|
|
600
|
+
|
|
601
|
+
return []
|
|
94
602
|
|
|
95
603
|
def _extract_token_from_header(self, request: Request) -> Optional[str]:
|
|
96
604
|
"""Extract JWT token from Authorization header."""
|
|
@@ -98,32 +606,17 @@ class JWTMiddleware(BaseHTTPMiddleware):
|
|
|
98
606
|
if not authorization:
|
|
99
607
|
return None
|
|
100
608
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
except ValueError:
|
|
106
|
-
return None
|
|
609
|
+
# Support both "Bearer <token>" and just "<token>"
|
|
610
|
+
if authorization.lower().startswith("bearer "):
|
|
611
|
+
return authorization[7:].strip()
|
|
612
|
+
return authorization.strip()
|
|
107
613
|
|
|
108
614
|
def _extract_token_from_cookie(self, request: Request) -> Optional[str]:
|
|
109
615
|
"""Extract JWT token from cookie."""
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
if self.token_source == TokenSource.HEADER:
|
|
115
|
-
return self._extract_token_from_header(request)
|
|
116
|
-
elif self.token_source == TokenSource.COOKIE:
|
|
117
|
-
return self._extract_token_from_cookie(request)
|
|
118
|
-
elif self.token_source == TokenSource.BOTH:
|
|
119
|
-
# Try header first, then cookie
|
|
120
|
-
token = self._extract_token_from_header(request)
|
|
121
|
-
if token is None:
|
|
122
|
-
token = self._extract_token_from_cookie(request)
|
|
123
|
-
return token
|
|
124
|
-
else:
|
|
125
|
-
log_debug(f"Unknown token source: {self.token_source}")
|
|
126
|
-
return None
|
|
616
|
+
cookie_value = request.cookies.get(self.cookie_name)
|
|
617
|
+
if cookie_value:
|
|
618
|
+
return cookie_value.strip()
|
|
619
|
+
return None
|
|
127
620
|
|
|
128
621
|
def _get_missing_token_error_message(self) -> str:
|
|
129
622
|
"""Get appropriate error message for missing token based on token source."""
|
|
@@ -136,98 +629,208 @@ class JWTMiddleware(BaseHTTPMiddleware):
|
|
|
136
629
|
else:
|
|
137
630
|
return "JWT token missing"
|
|
138
631
|
|
|
139
|
-
def
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
632
|
+
def _create_error_response(
    self,
    status_code: int,
    detail: str,
    origin: Optional[str] = None,
    cors_allowed_origins: Optional[List[str]] = None,
) -> JSONResponse:
    """Create a JSON error response (``{"detail": detail}``) with CORS headers.

    When the request's ``origin`` passes ``self._is_origin_allowed``, CORS headers that
    reflect that origin are attached, so browsers let the frontend read the error body.

    Args:
        status_code: HTTP status code of the response (e.g. 401, 403).
        detail: Error message placed under the "detail" key.
        origin: Value of the request's Origin header, if any.
        cors_allowed_origins: Allowed origins, as read from app.state by the caller.

    Returns:
        The JSONResponse to send back instead of calling the downstream app.
    """
    response = JSONResponse(status_code=status_code, content={"detail": detail})

    # Add CORS headers to the error response
    if origin and self._is_origin_allowed(origin, cors_allowed_origins):
        headers = response.headers
        headers["Access-Control-Allow-Origin"] = origin
        headers["Access-Control-Allow-Credentials"] = "true"
        headers["Access-Control-Allow-Methods"] = "*"
        headers["Access-Control-Allow-Headers"] = "*"
        headers["Access-Control-Expose-Headers"] = "*"
        # The Allow-Origin value is derived from the request's Origin header, so shared
        # caches must key on it; otherwise one origin's response could be served to another.
        headers["Vary"] = "Origin"

    return response
|
|
651
|
+
|
|
652
|
+
def _is_origin_allowed(self, origin: str, cors_allowed_origins: Optional[List[str]] = None) -> bool:
|
|
653
|
+
"""Check if the origin is in the allowed origins list."""
|
|
654
|
+
if not cors_allowed_origins:
|
|
655
|
+
# If no allowed origins configured, allow all (fallback to default behavior)
|
|
656
|
+
return True
|
|
657
|
+
|
|
658
|
+
# Check if origin is in the allowed list
|
|
659
|
+
return origin in cors_allowed_origins
|
|
143
660
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
661
|
+
async def dispatch(self, request: Request, call_next) -> Response:
    """Process the request: extract JWT, validate, and check RBAC scopes.

    Flow:
        1. OPTIONS (CORS preflight) requests and excluded routes are passed straight through.
        2. The token is extracted according to ``token_source``; a missing token yields a 401.
        3. The token is validated/decoded by ``self.validator``. When ``verify_audience`` is set,
           the expected audience is ``self.audience`` or, failing that, ``app.state.agent_os_id``.
        4. Claims are stored on ``request.state`` (authenticated, user_id, session_id, scopes,
           audience, authorization_enabled, and optionally dependencies / session_state).
        5. If authorization is enabled, the scopes required for ``METHOD path`` are enforced.
           A GET/HEAD on a collection path (e.g. ``/agents``) is always allowed; the IDs the
           caller may see are stored in ``request.state.accessible_resource_ids`` for filtering.
           Other requests without a resource ID are no longer let through by that fallback.

    Token errors produce 401 responses when ``validate`` is True. When it is False, the request
    continues with ``request.state.authenticated = False`` and the raw token on request.state.
    """
    path = request.url.path
    method = request.method

    # Skip OPTIONS requests (CORS preflight)
    if method == "OPTIONS":
        return await call_next(request)

    # Skip excluded routes
    if self._is_route_excluded(path):
        return await call_next(request)

    # Get origin and CORS allowed origins for error responses
    origin = request.headers.get("origin")
    cors_allowed_origins = getattr(request.app.state, "cors_allowed_origins", None)

    # Get agent_os_id from app state for audience verification
    agent_os_id = getattr(request.app.state, "agent_os_id", None)

    # Extract JWT token
    token = self._extract_token(request)
    if not token:
        error_msg = self._get_missing_token_error_message()
        return self._create_error_response(401, error_msg, origin, cors_allowed_origins)

    try:
        # Validate token and extract claims (with audience verification if configured)
        expected_audience = None
        if self.verify_audience:
            expected_audience = self.audience or agent_os_id
        payload: Dict[str, Any] = self.validator.validate_token(token, expected_audience)  # type: ignore

        # Extract standard claims
        user_id = payload.get(self.user_id_claim)
        session_id = payload.get(self.session_id_claim)
        scopes = payload.get(self.scopes_claim, [])
        audience = payload.get(self.audience_claim)

        # Ensure scopes is a list: a single string becomes one scope, other types are dropped
        if isinstance(scopes, str):
            scopes = [scopes]
        elif not isinstance(scopes, list):
            scopes = []

        # Store claims in request.state
        request.state.authenticated = True
        request.state.user_id = user_id
        request.state.session_id = session_id
        request.state.scopes = scopes
        request.state.audience = audience
        request.state.authorization_enabled = self.authorization or False

        # Extract dependencies claims
        dependencies = {}
        if self.dependencies_claims:
            for claim in self.dependencies_claims:
                if claim in payload:
                    dependencies[claim] = payload[claim]

        if dependencies:
            log_debug(f"Extracted dependencies: {dependencies}")
            request.state.dependencies = dependencies

        # Extract session state claims
        session_state = {}
        if self.session_state_claims:
            for claim in self.session_state_claims:
                if claim in payload:
                    session_state[claim] = payload[claim]

        if session_state:
            log_debug(f"Extracted session state: {session_state}")
            request.state.session_state = session_state

        # RBAC scope checking (only if enabled)
        if self.authorization:
            # Extract resource type and ID from path
            resource_type = None
            resource_id = None

            if "/agents" in path:
                resource_type = "agents"
            elif "/teams" in path:
                resource_type = "teams"
            elif "/workflows" in path:
                resource_type = "workflows"

            if resource_type:
                resource_id = self._extract_resource_id_from_path(path, resource_type)

            required_scopes = self._get_required_scopes(method, path)

            # Empty list [] means no scopes required (allow access)
            if required_scopes:
                # Use the scope validation system
                has_access = has_required_scopes(
                    scopes,
                    required_scopes,
                    resource_type=resource_type,
                    resource_id=resource_id,
                    admin_scope=self.admin_scope,
                )

                # Listing fallback: only read requests on the collection path itself qualify.
                # Previously any method on any path merely containing "/agents", "/teams" or
                # "/workflows" without an extractable ID was let through here.
                is_listing_request = (
                    method in ("GET", "HEAD")
                    and bool(resource_type)
                    and not resource_id
                    and path.rstrip("/").endswith(f"/{resource_type}")
                )
                if not has_access and is_listing_request:
                    # Allow the listing but store accessible IDs so the endpoint can filter
                    # results (possibly to an empty list) instead of returning 403.
                    accessible_ids = get_accessible_resource_ids(
                        scopes, resource_type, admin_scope=self.admin_scope
                    )
                    has_access = True
                    request.state.accessible_resource_ids = accessible_ids

                    if accessible_ids:
                        log_debug(f"User has specific {resource_type} scopes. Accessible IDs: {accessible_ids}")
                    else:
                        log_debug(f"User has no {resource_type} scopes. Will return empty list.")

                if not has_access:
                    log_warning(
                        f"Insufficient scopes for {method} {path}. Required: {required_scopes}, User has: {scopes}"
                    )
                    return self._create_error_response(
                        403, "Insufficient permissions", origin, cors_allowed_origins
                    )

                log_debug(f"Scope check passed for {method} {path}. User scopes: {scopes}")
            else:
                log_debug(f"No scopes required for {method} {path}")

        log_debug(f"JWT decoded successfully for user: {user_id}")

        request.state.token = token

    except jwt.InvalidAudienceError:
        # Must precede InvalidTokenError, of which it is a subclass.
        log_warning(f"Invalid token audience - expected: {expected_audience}")
        return self._create_error_response(
            401, "Invalid token audience - token not valid for this AgentOS instance", origin, cors_allowed_origins
        )
    except jwt.ExpiredSignatureError as e:
        if self.validate:
            log_warning(f"Token has expired: {str(e)}")
            return self._create_error_response(401, "Token has expired", origin, cors_allowed_origins)
        request.state.authenticated = False
        request.state.token = token

    except jwt.InvalidTokenError as e:
        if self.validate:
            log_warning(f"Invalid token: {str(e)}")
            return self._create_error_response(401, f"Invalid token: {str(e)}", origin, cors_allowed_origins)
        request.state.authenticated = False
        request.state.token = token
    except Exception as e:
        if self.validate:
            log_warning(f"Error decoding token: {str(e)}")
            return self._create_error_response(401, f"Error decoding token: {str(e)}", origin, cors_allowed_origins)
        request.state.authenticated = False
        request.state.token = token

    return await call_next(request)
|
|
823
|
+
|
|
824
|
+
def _extract_token(self, request: Request) -> Optional[str]:
    """Fetch the raw JWT from wherever ``self.token_source`` says to look.

    HEADER and COOKIE consult a single location. BOTH prefers the header and falls back
    to the cookie only when the header yields nothing usable. Any other source value
    yields None.
    """
    source = self.token_source
    if source == TokenSource.HEADER:
        return self._extract_token_from_header(request)
    if source == TokenSource.COOKIE:
        return self._extract_token_from_cookie(request)
    if source == TokenSource.BOTH:
        # An empty/missing header token defers to whatever the cookie provides.
        return self._extract_token_from_header(request) or self._extract_token_from_cookie(request)
    return None
|