nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,557 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import time
|
|
19
|
+
from typing import Any
|
|
20
|
+
from urllib.parse import urlparse
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
24
|
+
from authlib.jose import JsonWebKey
|
|
25
|
+
from authlib.jose import KeySet
|
|
26
|
+
from authlib.jose import jwt
|
|
27
|
+
|
|
28
|
+
from nat.data_models.authentication import TokenValidationResult
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BearerTokenValidator:
|
|
34
|
+
"""Bearer token validator supporting JWT and opaque tokens.
|
|
35
|
+
|
|
36
|
+
Implements RFC 7519 (JWT) and RFC 7662 (Token Introspection) standards.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
introspection_endpoint: str | None = None,
|
|
42
|
+
issuer: str | None = None,
|
|
43
|
+
audience: str | None = None,
|
|
44
|
+
jwks_uri: str | None = None,
|
|
45
|
+
client_id: str | None = None,
|
|
46
|
+
client_secret: str | None = None,
|
|
47
|
+
scopes: list[str] | None = None,
|
|
48
|
+
timeout: float = 10.0,
|
|
49
|
+
leeway: int = 60,
|
|
50
|
+
discovery_url: str | None = None,
|
|
51
|
+
):
|
|
52
|
+
"""
|
|
53
|
+
Args:
|
|
54
|
+
introspection_endpoint: OAuth 2.0 introspection URL (required to validate opaque tokens).
|
|
55
|
+
issuer: Expected token issuer (`iss`); recommended for policy, not required for JWT signature validity.
|
|
56
|
+
audience: Expected token audience (`aud`); recommended for policy, not required for JWT signature validity.
|
|
57
|
+
jwks_uri: JWKS URL with public keys to verify asymmetric JWTs; optional if using discovery.
|
|
58
|
+
client_id: OAuth 2.0 client ID for authenticating to the introspection endpoint.
|
|
59
|
+
client_secret: OAuth 2.0 client secret for authenticating to the introspection endpoint.
|
|
60
|
+
scopes: Optional authorization scopes to check after validation; not required for token validity.
|
|
61
|
+
timeout: HTTP request timeout for discovery/JWKS/introspection (default: 10.0s).
|
|
62
|
+
leeway: Clock-skew allowance for `exp`/`nbf`/`iat` checks (default: 60s).
|
|
63
|
+
discovery_url: OIDC/OAuth metadata URL to auto-discover `jwks_uri` and `introspection_endpoint`.
|
|
64
|
+
"""
|
|
65
|
+
# Configuration parameters
|
|
66
|
+
self.introspection_endpoint = introspection_endpoint
|
|
67
|
+
self.issuer = issuer
|
|
68
|
+
self.audience = audience
|
|
69
|
+
self.jwks_uri = jwks_uri
|
|
70
|
+
self.client_id = client_id
|
|
71
|
+
self.client_secret = client_secret
|
|
72
|
+
self.scopes = scopes
|
|
73
|
+
self.timeout = timeout
|
|
74
|
+
self.leeway = leeway
|
|
75
|
+
self.discovery_url = discovery_url
|
|
76
|
+
|
|
77
|
+
# Validate configuration
|
|
78
|
+
self._validate_configuration()
|
|
79
|
+
|
|
80
|
+
# HTTPS validation for configured URLs
|
|
81
|
+
if self.discovery_url:
|
|
82
|
+
self._require_https(self.discovery_url, "discovery_url")
|
|
83
|
+
if self.jwks_uri:
|
|
84
|
+
self._require_https(self.jwks_uri, "jwks_uri")
|
|
85
|
+
if self.introspection_endpoint:
|
|
86
|
+
self._require_https(self.introspection_endpoint, "introspection_endpoint")
|
|
87
|
+
|
|
88
|
+
# Caches for performance with TTL
|
|
89
|
+
# JWKS cache: uri -> {keyset, cache_expires_at}
|
|
90
|
+
self._jwks_cache: dict[str, dict[str, Any]] = {}
|
|
91
|
+
# OIDC config cache: url -> {config, cache_expires_at}
|
|
92
|
+
self._oidc_config_cache: dict[str, dict[str, Any]] = {}
|
|
93
|
+
# Positive introspection result cache: token_prefix -> {result, cache_expires_at}
|
|
94
|
+
self._introspection_cache: dict[str, dict[str, Any]] = {}
|
|
95
|
+
|
|
96
|
+
# Cache TTL settings
|
|
97
|
+
self._jwks_cache_ttl = 900 # 15 minutes
|
|
98
|
+
self._discovery_cache_ttl = 900 # 15 minutes
|
|
99
|
+
|
|
100
|
+
def _validate_configuration(self) -> None:
|
|
101
|
+
"""Validate that at least one token verification method is configured."""
|
|
102
|
+
|
|
103
|
+
jwt_possible = self.jwks_uri or self.discovery_url or self.issuer
|
|
104
|
+
introspection_possible = self.introspection_endpoint and self.client_id and self.client_secret
|
|
105
|
+
|
|
106
|
+
if not jwt_possible and not introspection_possible:
|
|
107
|
+
raise ValueError("No valid token verification method configured. "
|
|
108
|
+
"Either provide JWT verification (jwks_uri, discovery_url, or issuer for derived JWKS) "
|
|
109
|
+
"or introspection (introspection_endpoint with client_id and client_secret)")
|
|
110
|
+
|
|
111
|
+
async def verify(self, token: str) -> TokenValidationResult:
    """Validate a bearer token as a JWT or, failing that shape, via introspection.

    A token with exactly two dots is treated as a JWT and verified against
    the JWKS. Any other token is sent to the introspection endpoint when the
    endpoint and client credentials are configured. This method never raises
    for validation problems: every failure yields an inactive result.

    Args:
        token: Bearer token to validate, with or without a leading
            "Bearer " scheme prefix.

    Returns:
        TokenValidationResult describing the token; `active` is False when
        the token is missing, malformed, or fails validation.
    """
    if not token or not isinstance(token, str):
        return TokenValidationResult(client_id="", token_type="bearer", active=False)

    # Authentication scheme names are case-insensitive, so "bearer " and
    # "BEARER " are stripped the same way as "Bearer ".
    if token[:7].lower() == "bearer ":
        token = token[7:]

    if not token:
        return TokenValidationResult(client_id="", token_type="bearer", active=False)

    try:
        if self._is_jwt_token(token):
            return await self._verify_jwt_token(token)
        if self.introspection_endpoint and self.client_id and self.client_secret:
            return await self._verify_opaque_token(token)
        return TokenValidationResult(client_id="", token_type="bearer", active=False)
    except Exception as exc:
        # Deliberate best-effort boundary: callers only ever see an inactive
        # result. Record why, so rejected tokens can be diagnosed; the token
        # itself is never logged.
        logger.debug("Bearer token validation failed (%s): %s", type(exc).__name__, exc)
        return TokenValidationResult(client_id="", token_type="bearer", active=False)
|
|
138
|
+
|
|
139
|
+
def _is_jwt_token(self, token: str) -> bool:
|
|
140
|
+
"""Check if token has JWT structure."""
|
|
141
|
+
return token.count(".") == 2
|
|
142
|
+
|
|
143
|
+
async def _verify_jwt_token(self, token: str) -> TokenValidationResult:
    """Decode a JWT against the resolved JWKS and apply the configured policies.

    Args:
        token: JWT token to verify

    Returns:
        TokenValidationResult with `active=True` and `token_type="at+jwt"`.

    Raises:
        ValueError: If a configured issuer/audience/scope policy is not met
            (raised by `_check_jwt_policies`). Errors from JWKS resolution,
            key fetching, decoding and claim validation also propagate.
    """
    keyset = await self._fetch_jwks(await self._resolve_jwks_uri())

    def _timing_rule(essential: bool) -> dict[str, Any]:
        # Each time-based claim shares the configured clock-skew allowance.
        return {"essential": essential, "leeway": self.leeway}

    claims = jwt.decode(
        token,
        keyset,
        claims_options={
            "exp": _timing_rule(True),  # expiry is mandatory
            "nbf": _timing_rule(False),
            "iat": _timing_rule(False),
        },
    )
    claims.validate(leeway=self.leeway)

    issuer = claims.get("iss")
    subject = claims.get("sub")
    audience = self._extract_audience_from_claims(claims)

    # Scopes may arrive under "scope" or "scp"; a string is split on whitespace
    # and an empty value is normalised to None.
    raw_scopes = claims.get("scope") or claims.get("scp")
    if isinstance(raw_scopes, str):
        raw_scopes = raw_scopes.split()
    scopes = raw_scopes or None

    self._check_jwt_policies(issuer, audience, scopes)

    # Prefer "azp", then "client_id", then the subject as the client identity.
    resolved_client_id = claims.get("azp") or claims.get("client_id") or subject

    return TokenValidationResult(
        client_id=resolved_client_id,
        expires_at=claims.get("exp"),
        audience=audience,
        subject=subject,
        issuer=issuer,
        token_type="at+jwt",
        nbf=claims.get("nbf"),
        iat=claims.get("iat"),
        jti=claims.get("jti"),
        scopes=scopes,
        active=True,
    )
|
|
193
|
+
|
|
194
|
+
async def _verify_opaque_token(self, token: str) -> TokenValidationResult:
    """Verify an opaque token via RFC 7662 introspection, with a short-lived cache.

    Positive results are cached under a SHA-256 digest of the complete token.
    Keying on a token prefix would let a different token that merely shares
    its first characters with a cached one inherit that cached result.

    Args:
        token: Opaque token to verify

    Returns:
        TokenValidationResult with `active=True` for an accepted token.

    Raises:
        ValueError: If the token is inactive, expired, not yet valid, fails a
            configured policy, or the introspection request/response is unusable.
    """
    # Function-scope import so this method does not depend on a file-level change.
    import hashlib

    # The digest also keeps raw token material out of the cache keys.
    cache_key = hashlib.sha256(token.encode("utf-8")).hexdigest()

    # Serve from cache while the entry is still fresh; evict it otherwise.
    cache_entry = self._introspection_cache.get(cache_key)
    if cache_entry:
        if int(time.time()) < cache_entry["cache_expires_at"]:
            return cache_entry["result"]
        del self._introspection_cache[cache_key]

    try:
        async with AsyncOAuth2Client(
                client_id=self.client_id,
                client_secret=self.client_secret,
                timeout=httpx.Timeout(self.timeout),
        ) as oauth_client:
            introspection_response = await oauth_client.introspect_token(
                self.introspection_endpoint,
                token,
                token_type_hint="access_token",
            )

        # NOTE(review): Authlib's introspect_token appears to return the HTTP
        # response rather than parsed JSON - confirm against the Authlib docs.
        # Unwrap it when that is the case; a mapping result is used unchanged.
        if isinstance(introspection_response, httpx.Response):
            introspection_response.raise_for_status()
            payload = introspection_response.json()
            if not isinstance(payload, dict):
                raise ValueError("Introspection response is not a JSON object")
            introspection_response = payload

        # Check if token is active
        if not introspection_response.get("active", False):
            raise ValueError("Token is inactive")

        # Extract claims
        client_id = introspection_response.get("client_id")
        username = introspection_response.get("username")
        token_type = introspection_response.get("token_type", "opaque")
        expires_at = introspection_response.get("exp")
        not_before = introspection_response.get("nbf")
        issued_at = introspection_response.get("iat")
        subject = introspection_response.get("sub")
        audience = self._extract_audience_from_introspection(introspection_response)
        issuer = introspection_response.get("iss")
        jwt_id = introspection_response.get("jti")

        # Scopes arrive either as a space-delimited string or as a list.
        scope_value = introspection_response.get("scope")
        scopes = None
        if scope_value and isinstance(scope_value, str):
            scopes = scope_value.split()
        elif isinstance(scope_value, list):
            scopes = scope_value

        if self._is_expired(expires_at):
            raise ValueError("Token is expired")

        # `nbf` is only checked when the server supplied one.
        if not_before and self._is_not_yet_valid(not_before):
            raise ValueError("Token is not yet valid")

        # Apply issuer/audience/scope policy checks.
        self._check_opaque_policies(issuer, audience, scopes)

        result = TokenValidationResult(
            client_id=client_id,
            username=username,
            token_type=token_type,
            expires_at=expires_at,
            audience=audience,
            subject=subject,
            issuer=issuer,
            jti=jwt_id,
            scopes=scopes,
            active=True,
            nbf=not_before,
            iat=issued_at,
        )

        # Cache only tokens with a known expiry, for at most five minutes and
        # never beyond the token's own expiry.
        if expires_at:
            cache_expires_at = min(expires_at, int(time.time()) + 300)
            self._introspection_cache[cache_key] = {"result": result, "cache_expires_at": cache_expires_at}

        return result

    except (ValueError, TypeError, KeyError, httpx.HTTPError) as e:
        raise ValueError(f"Introspection failed: {e}") from e
|
|
290
|
+
|
|
291
|
+
async def _resolve_jwks_uri(self) -> str:
|
|
292
|
+
"""Resolve JWKS URI using configuration priority: jwks_uri → discovery → issuer.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
JWKS URI string
|
|
296
|
+
"""
|
|
297
|
+
|
|
298
|
+
if self.jwks_uri:
|
|
299
|
+
return self.jwks_uri
|
|
300
|
+
|
|
301
|
+
if self.discovery_url:
|
|
302
|
+
try:
|
|
303
|
+
config = await self._get_oidc_configuration(self.discovery_url)
|
|
304
|
+
jwks = config.get("jwks_uri")
|
|
305
|
+
if isinstance(jwks, str) and jwks:
|
|
306
|
+
self._require_https(jwks, "jwks_uri")
|
|
307
|
+
return jwks
|
|
308
|
+
except Exception as e:
|
|
309
|
+
raise ValueError(f"Failed to get JWKS URI from discovery: {e}") from e
|
|
310
|
+
|
|
311
|
+
if self.issuer:
|
|
312
|
+
jwks = f"{self.issuer.rstrip('/')}/.well-known/jwks.json"
|
|
313
|
+
self._require_https(jwks, "jwks_uri")
|
|
314
|
+
return jwks
|
|
315
|
+
|
|
316
|
+
raise ValueError("No JWKS URI available - no jwks_uri, discovery_url, or issuer configured")
|
|
317
|
+
|
|
318
|
+
async def _get_oidc_configuration(self, discovery_url: str) -> dict[str, Any]:
|
|
319
|
+
"""Get OIDC configuration.
|
|
320
|
+
|
|
321
|
+
Args:
|
|
322
|
+
discovery_url: OIDC discovery URL
|
|
323
|
+
|
|
324
|
+
Returns:
|
|
325
|
+
OIDC configuration dict
|
|
326
|
+
"""
|
|
327
|
+
|
|
328
|
+
# Check cache first
|
|
329
|
+
cache_entry = self._oidc_config_cache.get(discovery_url)
|
|
330
|
+
if cache_entry:
|
|
331
|
+
config = cache_entry["config"]
|
|
332
|
+
cache_expires_at = cache_entry["cache_expires_at"]
|
|
333
|
+
now = int(time.time())
|
|
334
|
+
|
|
335
|
+
if now < cache_expires_at:
|
|
336
|
+
return config
|
|
337
|
+
else:
|
|
338
|
+
# Remove expired entry
|
|
339
|
+
del self._oidc_config_cache[discovery_url]
|
|
340
|
+
|
|
341
|
+
try:
|
|
342
|
+
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
|
343
|
+
response = await client.get(discovery_url)
|
|
344
|
+
response.raise_for_status()
|
|
345
|
+
config = response.json()
|
|
346
|
+
|
|
347
|
+
if not isinstance(config, dict):
|
|
348
|
+
logger.warning("OIDC discovery returned non-dict; not caching")
|
|
349
|
+
return config
|
|
350
|
+
|
|
351
|
+
jwks_uri = config.get("jwks_uri")
|
|
352
|
+
if jwks_uri is not None and not isinstance(jwks_uri, str):
|
|
353
|
+
logger.warning("OIDC discovery jwks_uri is not a string; not caching")
|
|
354
|
+
return config
|
|
355
|
+
|
|
356
|
+
# Cache with TTL
|
|
357
|
+
cache_expires_at = int(time.time()) + self._discovery_cache_ttl
|
|
358
|
+
self._oidc_config_cache[discovery_url] = {"config": config, "cache_expires_at": cache_expires_at}
|
|
359
|
+
return config
|
|
360
|
+
|
|
361
|
+
except httpx.HTTPError as e:
|
|
362
|
+
raise ValueError(f"OIDC discovery failed: {e}") from e
|
|
363
|
+
except json.JSONDecodeError as e:
|
|
364
|
+
raise ValueError(f"Invalid OIDC discovery response: {e}") from e
|
|
365
|
+
|
|
366
|
+
async def _fetch_jwks(self, jwks_uri: str) -> KeySet:
    """Fetch the JWKS at `jwks_uri` and build a key set, using a TTL cache.

    Args:
        jwks_uri: JWKS endpoint URI

    Returns:
        KeySet for token verification

    Raises:
        ValueError: If the payload is not a JSON object, lists no keys, or
            none of its entries can be used as a key.
        httpx.HTTPError: If the request fails or returns an error status.
    """
    cache_entry = self._jwks_cache.get(jwks_uri)
    if cache_entry:
        if int(time.time()) < cache_entry["cache_expires_at"]:
            return cache_entry["keyset"]
        # Stale entry: drop it and fetch a fresh copy below.
        del self._jwks_cache[jwks_uri]

    async with httpx.AsyncClient(timeout=self.timeout) as client:
        response = await client.get(jwks_uri)
        response.raise_for_status()
        jwks_data = response.json()

    if not isinstance(jwks_data, dict):
        raise ValueError("JWKS response is not a JSON object")

    keys = jwks_data.get("keys", [])
    if not keys:
        raise ValueError("JWKS contains no keys")

    # Test the imported list rather than the KeySet object.
    # NOTE(review): KeySet appears to define no __bool__/__len__, so it is
    # always truthy and an empty key set would otherwise be accepted and
    # cached - confirm against the Authlib docs.
    imported_keys = [JsonWebKey.import_key(k) for k in keys if isinstance(k, dict)]
    if not imported_keys:
        raise ValueError("JWKS contains no valid keys")
    keyset = KeySet(imported_keys)

    # Cache keyset with TTL
    cache_expires_at = int(time.time()) + self._jwks_cache_ttl
    self._jwks_cache[jwks_uri] = {"keyset": keyset, "cache_expires_at": cache_expires_at}
    return keyset
|
|
406
|
+
|
|
407
|
+
def _extract_audience_from_claims(self, claims: dict[str, Any]) -> list[str] | None:
|
|
408
|
+
"""Extract audience from JWT claims.
|
|
409
|
+
|
|
410
|
+
Args:
|
|
411
|
+
claims: JWT claims dict
|
|
412
|
+
|
|
413
|
+
Returns:
|
|
414
|
+
List of audience values
|
|
415
|
+
"""
|
|
416
|
+
|
|
417
|
+
audience = claims.get("aud")
|
|
418
|
+
if isinstance(audience, str):
|
|
419
|
+
return [audience]
|
|
420
|
+
elif isinstance(audience, list):
|
|
421
|
+
filtered = [aud for aud in audience if isinstance(aud, str)]
|
|
422
|
+
return filtered if filtered else None
|
|
423
|
+
return None
|
|
424
|
+
|
|
425
|
+
def _extract_audience_from_introspection(self, response: dict[str, Any]) -> list[str] | None:
|
|
426
|
+
"""Extract audience from introspection response.
|
|
427
|
+
|
|
428
|
+
Args:
|
|
429
|
+
response: Introspection response dict
|
|
430
|
+
|
|
431
|
+
Returns:
|
|
432
|
+
List of audience values
|
|
433
|
+
"""
|
|
434
|
+
|
|
435
|
+
audience = response.get("aud")
|
|
436
|
+
if isinstance(audience, str):
|
|
437
|
+
return [audience]
|
|
438
|
+
elif isinstance(audience, list):
|
|
439
|
+
filtered = [aud for aud in audience if isinstance(aud, str)]
|
|
440
|
+
return filtered if filtered else None
|
|
441
|
+
return None
|
|
442
|
+
|
|
443
|
+
def _require_https(self, url: str, url_description: str) -> None:
|
|
444
|
+
"""Enforce HTTPS requirement.
|
|
445
|
+
|
|
446
|
+
Args:
|
|
447
|
+
url: URL to validate
|
|
448
|
+
url_description: Description for error messages
|
|
449
|
+
"""
|
|
450
|
+
|
|
451
|
+
if url.startswith("https://"):
|
|
452
|
+
return
|
|
453
|
+
parsed_url = urlparse(url)
|
|
454
|
+
if parsed_url.hostname in ("localhost", "127.0.0.1", "::1"):
|
|
455
|
+
return
|
|
456
|
+
raise ValueError(f"{url_description} must use HTTPS: {url}")
|
|
457
|
+
|
|
458
|
+
def _check_jwt_policies(self,
|
|
459
|
+
issuer_claim: str | None,
|
|
460
|
+
audience_claim: list[str] | None,
|
|
461
|
+
token_scopes: list[str] | None) -> None:
|
|
462
|
+
"""Check JWT token against configured policies.
|
|
463
|
+
|
|
464
|
+
Args:
|
|
465
|
+
issuer_claim: Issuer from JWT token
|
|
466
|
+
audience_claim: Audience list from JWT token
|
|
467
|
+
token_scopes: Scopes from JWT token
|
|
468
|
+
"""
|
|
469
|
+
# Check issuer policy
|
|
470
|
+
if self.issuer and issuer_claim != self.issuer:
|
|
471
|
+
raise ValueError(f"JWT issuer '{issuer_claim}' does not match expected issuer '{self.issuer}'")
|
|
472
|
+
|
|
473
|
+
# Check audience policy
|
|
474
|
+
if self.audience:
|
|
475
|
+
if not audience_claim or self.audience not in audience_claim:
|
|
476
|
+
raise ValueError(f"JWT audience {audience_claim} does not contain required audience '{self.audience}'")
|
|
477
|
+
|
|
478
|
+
# Check scope policy
|
|
479
|
+
if self.scopes:
|
|
480
|
+
if not token_scopes:
|
|
481
|
+
raise ValueError(f"JWT has no scopes but required scopes: {self.scopes}")
|
|
482
|
+
|
|
483
|
+
token_scope_set = set(token_scopes)
|
|
484
|
+
required_scope_set = set(self.scopes)
|
|
485
|
+
|
|
486
|
+
if not required_scope_set.issubset(token_scope_set):
|
|
487
|
+
missing_scopes = required_scope_set - token_scope_set
|
|
488
|
+
raise ValueError(
|
|
489
|
+
f"JWT missing required scopes: {sorted(missing_scopes)} (has: {sorted(token_scope_set)})")
|
|
490
|
+
|
|
491
|
+
def _check_opaque_policies(self,
|
|
492
|
+
issuer_claim: str | None,
|
|
493
|
+
audience_claim: list[str] | None,
|
|
494
|
+
token_scopes: list[str] | None) -> None:
|
|
495
|
+
"""Check opaque token against configured policies.
|
|
496
|
+
|
|
497
|
+
Args:
|
|
498
|
+
issuer_claim: Issuer from introspection response
|
|
499
|
+
audience_claim: Audience list from introspection response
|
|
500
|
+
token_scopes: Scopes from introspection response
|
|
501
|
+
"""
|
|
502
|
+
# Check issuer policy
|
|
503
|
+
if self.issuer and issuer_claim != self.issuer:
|
|
504
|
+
raise ValueError(f"Opaque token issuer '{issuer_claim}' does not match expected issuer '{self.issuer}'")
|
|
505
|
+
|
|
506
|
+
# Check audience policy
|
|
507
|
+
if self.audience:
|
|
508
|
+
if not audience_claim or self.audience not in audience_claim:
|
|
509
|
+
raise ValueError(
|
|
510
|
+
f"Opaque token audience {audience_claim} does not contain required audience '{self.audience}'")
|
|
511
|
+
|
|
512
|
+
# Check scope policy
|
|
513
|
+
if self.scopes:
|
|
514
|
+
if not token_scopes:
|
|
515
|
+
raise ValueError(f"Opaque token has no scopes but required scopes: {self.scopes}")
|
|
516
|
+
|
|
517
|
+
token_scope_set = set(token_scopes)
|
|
518
|
+
required_scope_set = set(self.scopes)
|
|
519
|
+
|
|
520
|
+
if not required_scope_set.issubset(token_scope_set):
|
|
521
|
+
missing_scopes = required_scope_set - token_scope_set
|
|
522
|
+
raise ValueError(
|
|
523
|
+
f"Opaque token missing required scopes: {sorted(missing_scopes)} (has: {sorted(token_scope_set)})")
|
|
524
|
+
|
|
525
|
+
def _is_expired(self, exp: int | None, leeway: int | None = None) -> bool:
|
|
526
|
+
"""Check if timestamp is expired considering leeway.
|
|
527
|
+
|
|
528
|
+
Args:
|
|
529
|
+
exp: Expiration timestamp
|
|
530
|
+
leeway: Clock skew allowance
|
|
531
|
+
|
|
532
|
+
Returns:
|
|
533
|
+
True if expired
|
|
534
|
+
"""
|
|
535
|
+
|
|
536
|
+
if exp is None:
|
|
537
|
+
return False
|
|
538
|
+
leeway = leeway or self.leeway
|
|
539
|
+
now = int(time.time())
|
|
540
|
+
return now > (exp + leeway)
|
|
541
|
+
|
|
542
|
+
def _is_not_yet_valid(self, nbf: int | None, leeway: int | None = None) -> bool:
|
|
543
|
+
"""Check if timestamp is not yet valid considering leeway.
|
|
544
|
+
|
|
545
|
+
Args:
|
|
546
|
+
nbf: Not-before timestamp
|
|
547
|
+
leeway: Clock skew allowance
|
|
548
|
+
|
|
549
|
+
Returns:
|
|
550
|
+
True if not yet valid
|
|
551
|
+
"""
|
|
552
|
+
|
|
553
|
+
if nbf is None:
|
|
554
|
+
return False
|
|
555
|
+
leeway = leeway or self.leeway
|
|
556
|
+
now = int(time.time())
|
|
557
|
+
return now < (nbf - leeway)
|
|
@@ -38,7 +38,7 @@ class HTTPBasicAuthProvider(AuthProviderBase):
|
|
|
38
38
|
|
|
39
39
|
self._authenticated_tokens: dict[str, AuthResult] = {}
|
|
40
40
|
|
|
41
|
-
async def authenticate(self, user_id: str | None = None) -> AuthResult:
|
|
41
|
+
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
|
42
42
|
"""
|
|
43
43
|
Performs simple HTTP Authentication using the provided user ID.
|
|
44
44
|
"""
|
nat/authentication/interfaces.py
CHANGED
|
@@ -54,7 +54,7 @@ class AuthProviderBase(typing.Generic[AuthProviderBaseConfigT], ABC):
|
|
|
54
54
|
return self._config
|
|
55
55
|
|
|
56
56
|
@abstractmethod
|
|
57
|
-
async def authenticate(self, user_id: str | None = None) -> AuthResult:
|
|
57
|
+
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
|
58
58
|
"""
|
|
59
59
|
Perform the authentication process for the client.
|
|
60
60
|
|
|
@@ -62,6 +62,9 @@ class AuthProviderBase(typing.Generic[AuthProviderBaseConfigT], ABC):
|
|
|
62
62
|
target API service, which may include obtaining tokens, refreshing credentials,
|
|
63
63
|
or completing multi-step authentication flows.
|
|
64
64
|
|
|
65
|
+
Args:
|
|
66
|
+
user_id: Optional user identifier for authentication
|
|
67
|
+
kwargs: Additional authentication parameters for example: http response (typically from a 401)
|
|
65
68
|
Raises:
|
|
66
69
|
NotImplementedError: Must be implemented by subclasses.
|
|
67
70
|
"""
|
|
@@ -71,7 +74,7 @@ class AuthProviderBase(typing.Generic[AuthProviderBaseConfigT], ABC):
|
|
|
71
74
|
|
|
72
75
|
class FlowHandlerBase(ABC):
|
|
73
76
|
"""
|
|
74
|
-
Handles front-end
|
|
77
|
+
Handles front-end specific flows for authentication clients.
|
|
75
78
|
|
|
76
79
|
Each front end will define a FlowHandler that will implement the authenticate method.
|
|
77
80
|
|