nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +27 -18
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +81 -50
- nat/agent/react_agent/register.py +59 -40
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +327 -149
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +64 -46
- nat/agent/tool_calling_agent/agent.py +152 -29
- nat/agent/tool_calling_agent/register.py +61 -38
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +10 -6
- nat/builder/context.py +70 -18
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/intermediate_step_manager.py +6 -2
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +327 -79
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +105 -19
- nat/cli/entrypoint.py +17 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +79 -10
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +196 -67
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +42 -18
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +2 -3
- nat/embedder/register.py +1 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +9 -6
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +74 -50
- nat/front_ends/fastapi/message_validator.py +20 -21
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +47 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +5 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +22 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +14 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +105 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +9 -5
- nvidia_nat-1.3.0.dist-info/METADATA +195 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/main.py
CHANGED
|
@@ -13,19 +13,24 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import importlib
|
|
17
16
|
import logging
|
|
18
17
|
import os
|
|
18
|
+
import typing
|
|
19
19
|
|
|
20
20
|
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
|
|
21
|
+
from nat.front_ends.fastapi.utils import get_config_file_path
|
|
22
|
+
from nat.front_ends.fastapi.utils import import_class_from_string
|
|
21
23
|
from nat.runtime.loader import load_config
|
|
22
24
|
|
|
25
|
+
if typing.TYPE_CHECKING:
|
|
26
|
+
from fastapi import FastAPI
|
|
27
|
+
|
|
23
28
|
logger = logging.getLogger(__name__)
|
|
24
29
|
|
|
25
30
|
|
|
26
|
-
def get_app():
|
|
31
|
+
def get_app() -> "FastAPI":
|
|
27
32
|
|
|
28
|
-
config_file_path =
|
|
33
|
+
config_file_path = get_config_file_path()
|
|
29
34
|
front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
|
|
30
35
|
|
|
31
36
|
if (not config_file_path):
|
|
@@ -36,28 +41,15 @@ def get_app():
|
|
|
36
41
|
|
|
37
42
|
# Try to import the front end worker class
|
|
38
43
|
try:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
|
|
43
|
-
front_end_worker_class_name = front_end_worker_parts[-1]
|
|
44
|
-
|
|
45
|
-
front_end_worker_module = importlib.import_module(front_end_worker_module_name)
|
|
46
|
-
|
|
47
|
-
if not hasattr(front_end_worker_module, front_end_worker_class_name):
|
|
48
|
-
raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
|
|
49
|
-
|
|
50
|
-
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
|
|
51
|
-
front_end_worker_class_name)
|
|
44
|
+
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
|
|
45
|
+
front_end_worker_full_name)
|
|
52
46
|
|
|
53
47
|
if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
|
|
54
48
|
raise ValueError(
|
|
55
49
|
f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
|
|
56
50
|
|
|
57
51
|
# Load the config
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
config = load_config(abs_config_file_path)
|
|
52
|
+
config = load_config(config_file_path)
|
|
61
53
|
|
|
62
54
|
# Create an instance of the front end worker class
|
|
63
55
|
front_end_worker = front_end_worker_class(config)
|
|
@@ -25,6 +25,7 @@ from pydantic import ValidationError
|
|
|
25
25
|
from starlette.websockets import WebSocketDisconnect
|
|
26
26
|
|
|
27
27
|
from nat.authentication.interfaces import FlowHandlerBase
|
|
28
|
+
from nat.data_models.api_server import ChatRequest
|
|
28
29
|
from nat.data_models.api_server import ChatResponse
|
|
29
30
|
from nat.data_models.api_server import ChatResponseChunk
|
|
30
31
|
from nat.data_models.api_server import Error
|
|
@@ -33,6 +34,8 @@ from nat.data_models.api_server import ResponsePayloadOutput
|
|
|
33
34
|
from nat.data_models.api_server import ResponseSerializable
|
|
34
35
|
from nat.data_models.api_server import SystemResponseContent
|
|
35
36
|
from nat.data_models.api_server import TextContent
|
|
37
|
+
from nat.data_models.api_server import UserMessageContentRoleType
|
|
38
|
+
from nat.data_models.api_server import UserMessages
|
|
36
39
|
from nat.data_models.api_server import WebSocketMessageStatus
|
|
37
40
|
from nat.data_models.api_server import WebSocketMessageType
|
|
38
41
|
from nat.data_models.api_server import WebSocketSystemInteractionMessage
|
|
@@ -64,12 +67,12 @@ class WebSocketMessageHandler:
|
|
|
64
67
|
self._running_workflow_task: asyncio.Task | None = None
|
|
65
68
|
self._message_parent_id: str = "default_id"
|
|
66
69
|
self._conversation_id: str | None = None
|
|
67
|
-
self._workflow_schema_type: str = None
|
|
68
|
-
self._user_interaction_response: asyncio.Future[
|
|
70
|
+
self._workflow_schema_type: str | None = None
|
|
71
|
+
self._user_interaction_response: asyncio.Future[TextContent] | None = None
|
|
69
72
|
|
|
70
73
|
self._flow_handler: FlowHandlerBase | None = None
|
|
71
74
|
|
|
72
|
-
self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
|
|
75
|
+
self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
|
|
73
76
|
WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
|
|
74
77
|
WorkflowSchemaType.CHAT: ChatResponse,
|
|
75
78
|
WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
|
|
@@ -86,7 +89,7 @@ class WebSocketMessageHandler:
|
|
|
86
89
|
|
|
87
90
|
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
|
88
91
|
|
|
89
|
-
# TODO: Handle the exit
|
|
92
|
+
# TODO: Handle the exit
|
|
90
93
|
pass
|
|
91
94
|
|
|
92
95
|
async def run(self) -> None:
|
|
@@ -107,47 +110,65 @@ class WebSocketMessageHandler:
|
|
|
107
110
|
|
|
108
111
|
elif isinstance(
|
|
109
112
|
validated_message,
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
WebSocketSystemIntermediateStepMessage,
|
|
113
|
-
WebSocketSystemInteractionMessage)):
|
|
113
|
+
WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
|
|
114
|
+
| WebSocketSystemInteractionMessage):
|
|
114
115
|
# These messages are already handled by self.create_websocket_message(data_model=value, …)
|
|
115
116
|
# No further processing is needed here.
|
|
116
117
|
pass
|
|
117
118
|
|
|
118
119
|
elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
|
|
119
|
-
user_content = await self.
|
|
120
|
+
user_content = await self._process_websocket_user_interaction_response_message(validated_message)
|
|
121
|
+
assert self._user_interaction_response is not None
|
|
120
122
|
self._user_interaction_response.set_result(user_content)
|
|
121
123
|
except (asyncio.CancelledError, WebSocketDisconnect):
|
|
122
|
-
# TODO: Handle the disconnect
|
|
124
|
+
# TODO: Handle the disconnect
|
|
123
125
|
break
|
|
124
126
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
async def process_user_message_content(
|
|
128
|
-
self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
|
|
127
|
+
def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
|
|
129
128
|
"""
|
|
130
|
-
|
|
129
|
+
Extracts the last user's TextContent from a list of messages.
|
|
131
130
|
|
|
132
|
-
:
|
|
133
|
-
|
|
134
|
-
"""
|
|
131
|
+
Args:
|
|
132
|
+
messages: List of UserMessages.
|
|
135
133
|
|
|
136
|
-
|
|
137
|
-
|
|
134
|
+
Returns:
|
|
135
|
+
TextContent object from the last user message.
|
|
138
136
|
|
|
137
|
+
Raises:
|
|
138
|
+
ValueError: If no user text content is found.
|
|
139
|
+
"""
|
|
140
|
+
for user_message in messages[::-1]:
|
|
141
|
+
if user_message.role == UserMessageContentRoleType.USER:
|
|
139
142
|
for attachment in user_message.content:
|
|
140
|
-
|
|
141
143
|
if isinstance(attachment, TextContent):
|
|
142
144
|
return attachment
|
|
145
|
+
raise ValueError("No user text content found in messages.")
|
|
143
146
|
|
|
144
|
-
|
|
147
|
+
async def _process_websocket_user_interaction_response_message(
|
|
148
|
+
self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
|
|
149
|
+
"""
|
|
150
|
+
Processes a WebSocketUserInteractionResponseMessage.
|
|
151
|
+
"""
|
|
152
|
+
return self._extract_last_user_message_content(user_content.content.messages)
|
|
153
|
+
|
|
154
|
+
async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
|
|
155
|
+
"""
|
|
156
|
+
Processes a WebSocketUserMessage based on schema type.
|
|
157
|
+
"""
|
|
158
|
+
if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
|
|
159
|
+
return ChatRequest(**user_content.content.model_dump(include={"messages"}))
|
|
160
|
+
|
|
161
|
+
elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
|
|
162
|
+
return self._extract_last_user_message_content(user_content.content.messages).text
|
|
163
|
+
|
|
164
|
+
raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")
|
|
145
165
|
|
|
146
166
|
async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
|
|
147
167
|
"""
|
|
148
168
|
Process user messages and routes them appropriately.
|
|
149
169
|
|
|
150
|
-
:
|
|
170
|
+
Args:
|
|
171
|
+
user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
|
|
151
172
|
"""
|
|
152
173
|
|
|
153
174
|
try:
|
|
@@ -155,25 +176,23 @@ class WebSocketMessageHandler:
|
|
|
155
176
|
self._workflow_schema_type = user_message_as_validated_type.schema_type
|
|
156
177
|
self._conversation_id = user_message_as_validated_type.conversation_id
|
|
157
178
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
if content is None:
|
|
161
|
-
raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
|
|
179
|
+
message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)
|
|
162
180
|
|
|
163
|
-
if
|
|
181
|
+
if (self._running_workflow_task is None):
|
|
164
182
|
|
|
165
|
-
def _done_callback(
|
|
183
|
+
def _done_callback(_task: asyncio.Task):
|
|
166
184
|
self._running_workflow_task = None
|
|
167
185
|
|
|
168
186
|
self._running_workflow_task = asyncio.create_task(
|
|
169
|
-
self._run_workflow(
|
|
170
|
-
self.
|
|
187
|
+
self._run_workflow(payload=message_content,
|
|
188
|
+
user_message_id=self._message_parent_id,
|
|
189
|
+
conversation_id=self._conversation_id,
|
|
171
190
|
result_type=self._schema_output_mapping[self._workflow_schema_type],
|
|
172
191
|
output_type=self._schema_output_mapping[
|
|
173
192
|
self._workflow_schema_type])).add_done_callback(_done_callback)
|
|
174
193
|
|
|
175
194
|
except ValueError as e:
|
|
176
|
-
logger.
|
|
195
|
+
logger.exception("User message content not found: %s", str(e))
|
|
177
196
|
await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
|
|
178
197
|
message="User message content could not be found",
|
|
179
198
|
details=str(e)),
|
|
@@ -183,13 +202,14 @@ class WebSocketMessageHandler:
|
|
|
183
202
|
async def create_websocket_message(self,
|
|
184
203
|
data_model: BaseModel,
|
|
185
204
|
message_type: str | None = None,
|
|
186
|
-
status:
|
|
205
|
+
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
|
|
187
206
|
"""
|
|
188
207
|
Creates a websocket message that will be ready for routing based on message type or data model.
|
|
189
208
|
|
|
190
|
-
:
|
|
191
|
-
|
|
192
|
-
|
|
209
|
+
Args:
|
|
210
|
+
data_model (BaseModel): Message content model.
|
|
211
|
+
message_type (str | None): Message content model.
|
|
212
|
+
status (WebSocketMessageStatus): Message content model.
|
|
193
213
|
"""
|
|
194
214
|
try:
|
|
195
215
|
message: BaseModel | None = None
|
|
@@ -199,8 +219,8 @@ class WebSocketMessageHandler:
|
|
|
199
219
|
|
|
200
220
|
message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)
|
|
201
221
|
|
|
202
|
-
if 'id'
|
|
203
|
-
message_id: str = data_model
|
|
222
|
+
if hasattr(data_model, 'id'):
|
|
223
|
+
message_id: str = str(getattr(data_model, 'id'))
|
|
204
224
|
else:
|
|
205
225
|
message_id = str(uuid.uuid4())
|
|
206
226
|
|
|
@@ -241,7 +261,7 @@ class WebSocketMessageHandler:
|
|
|
241
261
|
f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")
|
|
242
262
|
|
|
243
263
|
except (ValidationError, TypeError, ValueError) as e:
|
|
244
|
-
logger.
|
|
264
|
+
logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
|
|
245
265
|
message = await self._message_validator.create_system_response_token_message(
|
|
246
266
|
message_type=WebSocketMessageType.ERROR_MESSAGE,
|
|
247
267
|
conversation_id=self._conversation_id,
|
|
@@ -256,12 +276,15 @@ class WebSocketMessageHandler:
|
|
|
256
276
|
Registered human interaction callback that processes human interactions and returns
|
|
257
277
|
responses from websocket connection.
|
|
258
278
|
|
|
259
|
-
:
|
|
260
|
-
|
|
279
|
+
Args:
|
|
280
|
+
prompt: Incoming interaction content data model.
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
A Text Content Base Pydantic model.
|
|
261
284
|
"""
|
|
262
285
|
|
|
263
286
|
# First create a future from the loop for the human response
|
|
264
|
-
human_response_future: asyncio.Future[
|
|
287
|
+
human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()
|
|
265
288
|
|
|
266
289
|
# Then add the future to the outstanding human prompts dictionary
|
|
267
290
|
self._user_interaction_response = human_response_future
|
|
@@ -277,10 +300,10 @@ class WebSocketMessageHandler:
|
|
|
277
300
|
return HumanResponseNotification()
|
|
278
301
|
|
|
279
302
|
# Wait for the human response future to complete
|
|
280
|
-
|
|
303
|
+
text_content: TextContent = await human_response_future
|
|
281
304
|
|
|
282
305
|
interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
|
|
283
|
-
|
|
306
|
+
text_content, prompt.content)
|
|
284
307
|
|
|
285
308
|
return interaction_response
|
|
286
309
|
|
|
@@ -290,17 +313,18 @@ class WebSocketMessageHandler:
|
|
|
290
313
|
|
|
291
314
|
async def _run_workflow(self,
|
|
292
315
|
payload: typing.Any,
|
|
316
|
+
user_message_id: str | None = None,
|
|
293
317
|
conversation_id: str | None = None,
|
|
294
318
|
result_type: type | None = None,
|
|
295
319
|
output_type: type | None = None) -> None:
|
|
296
320
|
|
|
297
321
|
try:
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
322
|
+
auth_callback = self._flow_handler.authenticate if self._flow_handler else None
|
|
323
|
+
async with self._session_manager.session(user_message_id=user_message_id,
|
|
324
|
+
conversation_id=conversation_id,
|
|
325
|
+
http_connection=self._socket,
|
|
326
|
+
user_input_callback=self.human_interaction_callback,
|
|
327
|
+
user_authentication_callback=auth_callback) as session:
|
|
304
328
|
|
|
305
329
|
async for value in generate_streaming_response(payload,
|
|
306
330
|
session_manager=session,
|
|
@@ -97,7 +97,7 @@ class MessageValidator:
|
|
|
97
97
|
return validated_message
|
|
98
98
|
|
|
99
99
|
except (ValidationError, TypeError, ValueError) as e:
|
|
100
|
-
logger.
|
|
100
|
+
logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
|
|
101
101
|
return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
|
|
102
102
|
content=Error(code=ErrorTypes.INVALID_MESSAGE,
|
|
103
103
|
message="Error validating message.",
|
|
@@ -119,7 +119,7 @@ class MessageValidator:
|
|
|
119
119
|
return schema
|
|
120
120
|
|
|
121
121
|
except (TypeError, ValueError) as e:
|
|
122
|
-
logger.
|
|
122
|
+
logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
|
|
123
123
|
return Error
|
|
124
124
|
|
|
125
125
|
async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
|
|
@@ -139,8 +139,10 @@ class MessageValidator:
|
|
|
139
139
|
text_content: str = str(data_model.payload)
|
|
140
140
|
validated_message_content = SystemResponseContent(text=text_content)
|
|
141
141
|
|
|
142
|
-
elif
|
|
142
|
+
elif isinstance(data_model, ChatResponse):
|
|
143
143
|
validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
|
|
144
|
+
elif isinstance(data_model, ChatResponseChunk):
|
|
145
|
+
validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content)
|
|
144
146
|
|
|
145
147
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
146
148
|
validated_message_content = SystemIntermediateStepContent(name=data_model.name,
|
|
@@ -156,7 +158,7 @@ class MessageValidator:
|
|
|
156
158
|
return validated_message_content
|
|
157
159
|
|
|
158
160
|
except ValueError as e:
|
|
159
|
-
logger.
|
|
161
|
+
logger.exception("Input data could not be converted to validated message content: %s", str(e))
|
|
160
162
|
return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
|
|
161
163
|
|
|
162
164
|
async def convert_text_content_to_human_response(self, text_content: TextContent,
|
|
@@ -191,7 +193,7 @@ class MessageValidator:
|
|
|
191
193
|
return human_response
|
|
192
194
|
|
|
193
195
|
except ValueError as e:
|
|
194
|
-
logger.
|
|
196
|
+
logger.exception("Error human response content not found: %s", str(e))
|
|
195
197
|
return HumanResponseText(text=str(e))
|
|
196
198
|
|
|
197
199
|
async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
|
|
@@ -204,7 +206,7 @@ class MessageValidator:
|
|
|
204
206
|
|
|
205
207
|
validated_message_type: str = ""
|
|
206
208
|
try:
|
|
207
|
-
if (isinstance(data_model,
|
|
209
|
+
if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
|
|
208
210
|
validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
|
|
209
211
|
|
|
210
212
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
@@ -218,9 +220,7 @@ class MessageValidator:
|
|
|
218
220
|
return validated_message_type
|
|
219
221
|
|
|
220
222
|
except ValueError as e:
|
|
221
|
-
logger.
|
|
222
|
-
str(e),
|
|
223
|
-
exc_info=True)
|
|
223
|
+
logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
|
|
224
224
|
return WebSocketMessageType.ERROR_MESSAGE
|
|
225
225
|
|
|
226
226
|
async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
|
|
@@ -232,7 +232,7 @@ class MessageValidator:
|
|
|
232
232
|
"""
|
|
233
233
|
return data_model.parent_id or "root"
|
|
234
234
|
|
|
235
|
-
async def create_system_response_token_message(
|
|
235
|
+
async def create_system_response_token_message(
|
|
236
236
|
self,
|
|
237
237
|
message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
|
|
238
238
|
WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
|
|
@@ -240,10 +240,9 @@ class MessageValidator:
|
|
|
240
240
|
thread_id: str = "default",
|
|
241
241
|
parent_id: str = "default",
|
|
242
242
|
conversation_id: str | None = None,
|
|
243
|
-
content: SystemResponseContent
|
|
244
|
-
| Error = SystemResponseContent(),
|
|
243
|
+
content: SystemResponseContent | Error = SystemResponseContent(),
|
|
245
244
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
246
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
245
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
247
246
|
) -> WebSocketSystemResponseTokenMessage | None:
|
|
248
247
|
"""
|
|
249
248
|
Creates a system response token message with default values.
|
|
@@ -269,10 +268,10 @@ class MessageValidator:
|
|
|
269
268
|
timestamp=timestamp)
|
|
270
269
|
|
|
271
270
|
except Exception as e:
|
|
272
|
-
logger.
|
|
271
|
+
logger.exception("Error creating system response token message: %s", str(e))
|
|
273
272
|
return None
|
|
274
273
|
|
|
275
|
-
async def create_system_intermediate_step_message(
|
|
274
|
+
async def create_system_intermediate_step_message(
|
|
276
275
|
self,
|
|
277
276
|
message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
|
|
278
277
|
WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
|
|
@@ -282,7 +281,7 @@ class MessageValidator:
|
|
|
282
281
|
conversation_id: str | None = None,
|
|
283
282
|
content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
|
|
284
283
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
285
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
284
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
286
285
|
) -> WebSocketSystemIntermediateStepMessage | None:
|
|
287
286
|
"""
|
|
288
287
|
Creates a system intermediate step message with default values.
|
|
@@ -308,10 +307,10 @@ class MessageValidator:
|
|
|
308
307
|
timestamp=timestamp)
|
|
309
308
|
|
|
310
309
|
except Exception as e:
|
|
311
|
-
logger.
|
|
310
|
+
logger.exception("Error creating system intermediate step message: %s", str(e))
|
|
312
311
|
return None
|
|
313
312
|
|
|
314
|
-
async def create_system_interaction_message(
|
|
313
|
+
async def create_system_interaction_message(
|
|
315
314
|
self,
|
|
316
315
|
*,
|
|
317
316
|
message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
|
|
@@ -322,8 +321,8 @@ class MessageValidator:
|
|
|
322
321
|
conversation_id: str | None = None,
|
|
323
322
|
content: HumanPrompt,
|
|
324
323
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
325
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
326
|
-
) -> WebSocketSystemInteractionMessage | None:
|
|
324
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
325
|
+
) -> WebSocketSystemInteractionMessage | None:
|
|
327
326
|
"""
|
|
328
327
|
Creates a system interaction message with default values.
|
|
329
328
|
|
|
@@ -348,5 +347,5 @@ class MessageValidator:
|
|
|
348
347
|
timestamp=timestamp)
|
|
349
348
|
|
|
350
349
|
except Exception as e:
|
|
351
|
-
logger.
|
|
350
|
+
logger.exception("Error creating system interaction message: %s", str(e))
|
|
352
351
|
return None
|
|
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
|
|
|
98
98
|
yield item
|
|
99
99
|
else:
|
|
100
100
|
yield ResponsePayloadOutput(payload=item)
|
|
101
|
-
except Exception
|
|
101
|
+
except Exception:
|
|
102
102
|
# Handle exceptions here
|
|
103
|
-
raise
|
|
103
|
+
raise
|
|
104
104
|
finally:
|
|
105
105
|
await q.close()
|
|
106
106
|
|
|
@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
|
|
|
165
165
|
yield item
|
|
166
166
|
else:
|
|
167
167
|
yield ResponsePayloadOutput(payload=item)
|
|
168
|
-
except Exception
|
|
168
|
+
except Exception:
|
|
169
169
|
# Handle exceptions here
|
|
170
|
-
raise
|
|
170
|
+
raise
|
|
171
171
|
finally:
|
|
172
172
|
await q.close()
|
|
173
173
|
|
|
@@ -289,7 +289,7 @@ class StepAdaptor:
|
|
|
289
289
|
|
|
290
290
|
return event
|
|
291
291
|
|
|
292
|
-
def process(self, step: IntermediateStep) -> ResponseSerializable | None:
|
|
292
|
+
def process(self, step: IntermediateStep) -> ResponseSerializable | None:
|
|
293
293
|
|
|
294
294
|
# Track the chunk
|
|
295
295
|
self._history.append(step)
|
|
@@ -314,6 +314,6 @@ class StepAdaptor:
|
|
|
314
314
|
return self._handle_custom(payload, ancestry)
|
|
315
315
|
|
|
316
316
|
except Exception as e:
|
|
317
|
-
logger.
|
|
317
|
+
logger.exception("Error processing intermediate step: %s", e)
|
|
318
318
|
|
|
319
319
|
return None
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_config_file_path() -> str:
    """
    Return the absolute path of the NAT configuration file.

    The path is taken from the ``NAT_CONFIG_FILE`` environment variable and
    passed through ``os.path.abspath`` (so a relative value is resolved
    against the current working directory). The file's existence is not
    checked here.

    Raises:
        ValueError: if ``NAT_CONFIG_FILE`` is unset or set to an empty string.
    """
    raw_path = os.environ.get("NAT_CONFIG_FILE")
    if raw_path:
        return os.path.abspath(raw_path)

    raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def import_class_from_string(class_full_name: str) -> type:
|
|
33
|
+
"""
|
|
34
|
+
Import a class from a string in the format 'module.submodule.ClassName'.
|
|
35
|
+
Raises ImportError if the class cannot be imported.
|
|
36
|
+
"""
|
|
37
|
+
try:
|
|
38
|
+
class_name_parts = class_full_name.split(".")
|
|
39
|
+
|
|
40
|
+
module_name = ".".join(class_name_parts[:-1])
|
|
41
|
+
class_name = class_name_parts[-1]
|
|
42
|
+
|
|
43
|
+
module = importlib.import_module(module_name)
|
|
44
|
+
|
|
45
|
+
if not hasattr(module, class_name):
|
|
46
|
+
raise ValueError(f"Class '{class_full_name}' not found.")
|
|
47
|
+
|
|
48
|
+
return getattr(module, class_name)
|
|
49
|
+
except (ImportError, AttributeError) as e:
|
|
50
|
+
raise ImportError(f"Could not import {class_full_name}.") from e
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_class_name(cls: type) -> str:
    """
    Return the fully qualified dotted name of ``cls``.

    The result is ``<cls.__module__>.<cls.__qualname__>``; for a nested class
    the qualname includes the enclosing class names
    (e.g. ``pkg.mod.Outer.Inner``).
    """
    return "{}.{}".format(cls.__module__, cls.__qualname__)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from mcp.server.auth.provider import AccessToken
|
|
20
|
+
from mcp.server.auth.provider import TokenVerifier
|
|
21
|
+
|
|
22
|
+
from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class IntrospectionTokenVerifier(TokenVerifier):
    """MCP ``TokenVerifier`` whose checks are all performed by a ``BearerTokenValidator``.

    The validator is constructed once from an ``OAuth2ResourceServerConfig``
    and reused for every ``verify_token`` call.
    """

    def __init__(self, config: OAuth2ResourceServerConfig):
        """Build the underlying ``BearerTokenValidator`` from the resource-server config.

        Args:
            config: OAuth2ResourceServerConfig providing the issuer, audience,
                scopes, JWKS / introspection / discovery endpoints and the
                client credentials.
        """
        self._bearer_token_validator = BearerTokenValidator(
            issuer=config.issuer_url,
            audience=config.audience,
            # An unset scope list is forwarded as an empty list.
            scopes=config.scopes or [],
            jwks_uri=config.jwks_uri,
            introspection_endpoint=config.introspection_endpoint,
            discovery_url=config.discovery_url,
            client_id=config.client_id,
            client_secret=config.client_secret,
        )

    async def verify_token(self, token: str) -> AccessToken | None:
        """Check a Bearer token with the wrapped ``BearerTokenValidator``.

        Args:
            token: The Bearer token to verify

        Returns:
            AccessToken | None: an AccessToken built from the validation result
            when the validator reports the token as active, otherwise None.
        """
        result = await self._bearer_token_validator.verify(token)
        if not result.active:
            return None

        # Missing scopes / client id are normalised to empty values for AccessToken.
        return AccessToken(
            token=token,
            expires_at=result.expires_at,
            scopes=result.scopes or [],
            client_id=result.client_id or "",
        )
|