nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +13 -8
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +6 -5
- nat/agent/react_agent/register.py +49 -39
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +2 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/rewoo_agent/agent.py +304 -117
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +51 -38
- nat/agent/tool_calling_agent/agent.py +75 -17
- nat/agent/tool_calling_agent/register.py +46 -23
- nat/authentication/api_key/api_key_auth_provider.py +6 -11
- nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
- nat/builder/builder.py +55 -23
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +54 -15
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +370 -0
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +38 -2
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +306 -54
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +60 -18
- nat/cli/entrypoint.py +15 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +72 -1
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +199 -69
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +47 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +4 -3
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +208 -0
- nat/data_models/optimizer.py +161 -0
- nat/data_models/span.py +41 -3
- nat/data_models/thinking_mixin.py +2 -2
- nat/embedder/azure_openai_embedder.py +2 -1
- nat/embedder/nim_embedder.py +3 -2
- nat/embedder/openai_embedder.py +3 -2
- nat/eval/config.py +1 -1
- nat/eval/dataset_handler/dataset_downloader.py +3 -2
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +10 -3
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +7 -4
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/eval/usage_stats.py +2 -0
- nat/eval/utils/output_uploader.py +3 -2
- nat/eval/utils/weave_eval.py +17 -3
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +69 -44
- nat/front_ends/fastapi/message_validator.py +8 -7
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +71 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +78 -25
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +21 -8
- nat/llm/azure_openai_llm.py +14 -5
- nat/llm/litellm_llm.py +80 -0
- nat/llm/nim_llm.py +23 -9
- nat/llm/openai_llm.py +19 -7
- nat/llm/register.py +4 -0
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/exporter/span_exporter.py +43 -15
- nat/observability/exporter_manager.py +2 -2
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/observability/register.py +16 -0
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/profiler/decorators/framework_wrapper.py +61 -9
- nat/profiler/decorators/function_tracking.py +35 -3
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/registry_handlers/pypi/register_pypi.py +5 -3
- nat/registry_handlers/rest/register_rest.py +5 -3
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/register.py +2 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +111 -6
- nat/runtime/session.py +49 -3
- nat/settings/global_settings.py +2 -2
- nat/tool/chat_completion.py +4 -1
- nat/tool/code_execution/code_sandbox.py +3 -6
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/datetime_tools.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/register.py +2 -7
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +76 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +278 -72
- nat/utils/io/yaml_tools.py +73 -3
- nat/utils/log_levels.py +25 -0
- nat/utils/responses_api.py +26 -0
- nat/utils/string_utils.py +16 -0
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +6 -2
- nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -461
- nat/data_models/temperature_mixin.py +0 -43
- nat/data_models/top_p_mixin.py +0 -43
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/main.py
CHANGED
|
@@ -13,19 +13,24 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import importlib
|
|
17
16
|
import logging
|
|
18
17
|
import os
|
|
18
|
+
import typing
|
|
19
19
|
|
|
20
20
|
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
|
|
21
|
+
from nat.front_ends.fastapi.utils import get_config_file_path
|
|
22
|
+
from nat.front_ends.fastapi.utils import import_class_from_string
|
|
21
23
|
from nat.runtime.loader import load_config
|
|
22
24
|
|
|
25
|
+
if typing.TYPE_CHECKING:
|
|
26
|
+
from fastapi import FastAPI
|
|
27
|
+
|
|
23
28
|
logger = logging.getLogger(__name__)
|
|
24
29
|
|
|
25
30
|
|
|
26
|
-
def get_app():
|
|
31
|
+
def get_app() -> "FastAPI":
|
|
27
32
|
|
|
28
|
-
config_file_path =
|
|
33
|
+
config_file_path = get_config_file_path()
|
|
29
34
|
front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
|
|
30
35
|
|
|
31
36
|
if (not config_file_path):
|
|
@@ -36,28 +41,15 @@ def get_app():
|
|
|
36
41
|
|
|
37
42
|
# Try to import the front end worker class
|
|
38
43
|
try:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
|
|
43
|
-
front_end_worker_class_name = front_end_worker_parts[-1]
|
|
44
|
-
|
|
45
|
-
front_end_worker_module = importlib.import_module(front_end_worker_module_name)
|
|
46
|
-
|
|
47
|
-
if not hasattr(front_end_worker_module, front_end_worker_class_name):
|
|
48
|
-
raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
|
|
49
|
-
|
|
50
|
-
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
|
|
51
|
-
front_end_worker_class_name)
|
|
44
|
+
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
|
|
45
|
+
front_end_worker_full_name)
|
|
52
46
|
|
|
53
47
|
if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
|
|
54
48
|
raise ValueError(
|
|
55
49
|
f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
|
|
56
50
|
|
|
57
51
|
# Load the config
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
config = load_config(abs_config_file_path)
|
|
52
|
+
config = load_config(config_file_path)
|
|
61
53
|
|
|
62
54
|
# Create an instance of the front end worker class
|
|
63
55
|
front_end_worker = front_end_worker_class(config)
|
|
@@ -25,6 +25,7 @@ from pydantic import ValidationError
|
|
|
25
25
|
from starlette.websockets import WebSocketDisconnect
|
|
26
26
|
|
|
27
27
|
from nat.authentication.interfaces import FlowHandlerBase
|
|
28
|
+
from nat.data_models.api_server import ChatRequest
|
|
28
29
|
from nat.data_models.api_server import ChatResponse
|
|
29
30
|
from nat.data_models.api_server import ChatResponseChunk
|
|
30
31
|
from nat.data_models.api_server import Error
|
|
@@ -33,6 +34,8 @@ from nat.data_models.api_server import ResponsePayloadOutput
|
|
|
33
34
|
from nat.data_models.api_server import ResponseSerializable
|
|
34
35
|
from nat.data_models.api_server import SystemResponseContent
|
|
35
36
|
from nat.data_models.api_server import TextContent
|
|
37
|
+
from nat.data_models.api_server import UserMessageContentRoleType
|
|
38
|
+
from nat.data_models.api_server import UserMessages
|
|
36
39
|
from nat.data_models.api_server import WebSocketMessageStatus
|
|
37
40
|
from nat.data_models.api_server import WebSocketMessageType
|
|
38
41
|
from nat.data_models.api_server import WebSocketSystemInteractionMessage
|
|
@@ -64,12 +67,12 @@ class WebSocketMessageHandler:
|
|
|
64
67
|
self._running_workflow_task: asyncio.Task | None = None
|
|
65
68
|
self._message_parent_id: str = "default_id"
|
|
66
69
|
self._conversation_id: str | None = None
|
|
67
|
-
self._workflow_schema_type: str = None
|
|
68
|
-
self._user_interaction_response: asyncio.Future[
|
|
70
|
+
self._workflow_schema_type: str | None = None
|
|
71
|
+
self._user_interaction_response: asyncio.Future[TextContent] | None = None
|
|
69
72
|
|
|
70
73
|
self._flow_handler: FlowHandlerBase | None = None
|
|
71
74
|
|
|
72
|
-
self._schema_output_mapping: dict[str, type[BaseModel] | None] = {
|
|
75
|
+
self._schema_output_mapping: dict[str, type[BaseModel] | type[None]] = {
|
|
73
76
|
WorkflowSchemaType.GENERATE: self._session_manager.workflow.single_output_schema,
|
|
74
77
|
WorkflowSchemaType.CHAT: ChatResponse,
|
|
75
78
|
WorkflowSchemaType.CHAT_STREAM: ChatResponseChunk,
|
|
@@ -105,45 +108,67 @@ class WebSocketMessageHandler:
|
|
|
105
108
|
if (isinstance(validated_message, WebSocketUserMessage)):
|
|
106
109
|
await self.process_workflow_request(validated_message)
|
|
107
110
|
|
|
108
|
-
elif isinstance(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
111
|
+
elif isinstance(
|
|
112
|
+
validated_message,
|
|
113
|
+
WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
|
|
114
|
+
| WebSocketSystemInteractionMessage):
|
|
112
115
|
# These messages are already handled by self.create_websocket_message(data_model=value, …)
|
|
113
116
|
# No further processing is needed here.
|
|
114
117
|
pass
|
|
115
118
|
|
|
116
119
|
elif (isinstance(validated_message, WebSocketUserInteractionResponseMessage)):
|
|
117
|
-
user_content = await self.
|
|
120
|
+
user_content = await self._process_websocket_user_interaction_response_message(validated_message)
|
|
121
|
+
assert self._user_interaction_response is not None
|
|
118
122
|
self._user_interaction_response.set_result(user_content)
|
|
119
123
|
except (asyncio.CancelledError, WebSocketDisconnect):
|
|
120
124
|
# TODO: Handle the disconnect
|
|
121
125
|
break
|
|
122
126
|
|
|
123
|
-
|
|
124
|
-
self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
|
|
127
|
+
def _extract_last_user_message_content(self, messages: list[UserMessages]) -> TextContent:
|
|
125
128
|
"""
|
|
126
|
-
|
|
129
|
+
Extracts the last user's TextContent from a list of messages.
|
|
127
130
|
|
|
128
|
-
:
|
|
129
|
-
|
|
130
|
-
"""
|
|
131
|
+
Args:
|
|
132
|
+
messages: List of UserMessages.
|
|
131
133
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
+
Returns:
|
|
135
|
+
TextContent object from the last user message.
|
|
134
136
|
|
|
137
|
+
Raises:
|
|
138
|
+
ValueError: If no user text content is found.
|
|
139
|
+
"""
|
|
140
|
+
for user_message in messages[::-1]:
|
|
141
|
+
if user_message.role == UserMessageContentRoleType.USER:
|
|
135
142
|
for attachment in user_message.content:
|
|
136
|
-
|
|
137
143
|
if isinstance(attachment, TextContent):
|
|
138
144
|
return attachment
|
|
145
|
+
raise ValueError("No user text content found in messages.")
|
|
146
|
+
|
|
147
|
+
async def _process_websocket_user_interaction_response_message(
|
|
148
|
+
self, user_content: WebSocketUserInteractionResponseMessage) -> TextContent:
|
|
149
|
+
"""
|
|
150
|
+
Processes a WebSocketUserInteractionResponseMessage.
|
|
151
|
+
"""
|
|
152
|
+
return self._extract_last_user_message_content(user_content.content.messages)
|
|
139
153
|
|
|
140
|
-
|
|
154
|
+
async def _process_websocket_user_message(self, user_content: WebSocketUserMessage) -> ChatRequest | str:
|
|
155
|
+
"""
|
|
156
|
+
Processes a WebSocketUserMessage based on schema type.
|
|
157
|
+
"""
|
|
158
|
+
if self._workflow_schema_type in [WorkflowSchemaType.CHAT, WorkflowSchemaType.CHAT_STREAM]:
|
|
159
|
+
return ChatRequest(**user_content.content.model_dump(include={"messages"}))
|
|
160
|
+
|
|
161
|
+
elif self._workflow_schema_type in [WorkflowSchemaType.GENERATE, WorkflowSchemaType.GENERATE_STREAM]:
|
|
162
|
+
return self._extract_last_user_message_content(user_content.content.messages).text
|
|
163
|
+
|
|
164
|
+
raise ValueError("Unsupported workflow schema type for WebSocketUserMessage")
|
|
141
165
|
|
|
142
166
|
async def process_workflow_request(self, user_message_as_validated_type: WebSocketUserMessage) -> None:
|
|
143
167
|
"""
|
|
144
168
|
Process user messages and routes them appropriately.
|
|
145
169
|
|
|
146
|
-
:
|
|
170
|
+
Args:
|
|
171
|
+
user_message_as_validated_type (WebSocketUserMessage): The validated user message to process.
|
|
147
172
|
"""
|
|
148
173
|
|
|
149
174
|
try:
|
|
@@ -151,18 +176,15 @@ class WebSocketMessageHandler:
|
|
|
151
176
|
self._workflow_schema_type = user_message_as_validated_type.schema_type
|
|
152
177
|
self._conversation_id = user_message_as_validated_type.conversation_id
|
|
153
178
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
if content is None:
|
|
157
|
-
raise ValueError(f"User message content could not be found: {user_message_as_validated_type}")
|
|
179
|
+
message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)
|
|
158
180
|
|
|
159
|
-
if
|
|
181
|
+
if (self._running_workflow_task is None):
|
|
160
182
|
|
|
161
|
-
def _done_callback(
|
|
183
|
+
def _done_callback(_task: asyncio.Task):
|
|
162
184
|
self._running_workflow_task = None
|
|
163
185
|
|
|
164
186
|
self._running_workflow_task = asyncio.create_task(
|
|
165
|
-
self._run_workflow(payload=
|
|
187
|
+
self._run_workflow(payload=message_content,
|
|
166
188
|
user_message_id=self._message_parent_id,
|
|
167
189
|
conversation_id=self._conversation_id,
|
|
168
190
|
result_type=self._schema_output_mapping[self._workflow_schema_type],
|
|
@@ -180,13 +202,14 @@ class WebSocketMessageHandler:
|
|
|
180
202
|
async def create_websocket_message(self,
|
|
181
203
|
data_model: BaseModel,
|
|
182
204
|
message_type: str | None = None,
|
|
183
|
-
status:
|
|
205
|
+
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS) -> None:
|
|
184
206
|
"""
|
|
185
207
|
Creates a websocket message that will be ready for routing based on message type or data model.
|
|
186
208
|
|
|
187
|
-
:
|
|
188
|
-
|
|
189
|
-
|
|
209
|
+
Args:
|
|
210
|
+
data_model (BaseModel): Message content model.
|
|
211
|
+
message_type (str | None): Message content model.
|
|
212
|
+
status (WebSocketMessageStatus): Message content model.
|
|
190
213
|
"""
|
|
191
214
|
try:
|
|
192
215
|
message: BaseModel | None = None
|
|
@@ -196,8 +219,8 @@ class WebSocketMessageHandler:
|
|
|
196
219
|
|
|
197
220
|
message_schema: type[BaseModel] = await self._message_validator.get_message_schema_by_type(message_type)
|
|
198
221
|
|
|
199
|
-
if 'id'
|
|
200
|
-
message_id: str = data_model
|
|
222
|
+
if hasattr(data_model, 'id'):
|
|
223
|
+
message_id: str = str(getattr(data_model, 'id'))
|
|
201
224
|
else:
|
|
202
225
|
message_id = str(uuid.uuid4())
|
|
203
226
|
|
|
@@ -253,12 +276,15 @@ class WebSocketMessageHandler:
|
|
|
253
276
|
Registered human interaction callback that processes human interactions and returns
|
|
254
277
|
responses from websocket connection.
|
|
255
278
|
|
|
256
|
-
:
|
|
257
|
-
|
|
279
|
+
Args:
|
|
280
|
+
prompt: Incoming interaction content data model.
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
A Text Content Base Pydantic model.
|
|
258
284
|
"""
|
|
259
285
|
|
|
260
286
|
# First create a future from the loop for the human response
|
|
261
|
-
human_response_future: asyncio.Future[
|
|
287
|
+
human_response_future: asyncio.Future[TextContent] = asyncio.get_running_loop().create_future()
|
|
262
288
|
|
|
263
289
|
# Then add the future to the outstanding human prompts dictionary
|
|
264
290
|
self._user_interaction_response = human_response_future
|
|
@@ -274,10 +300,10 @@ class WebSocketMessageHandler:
|
|
|
274
300
|
return HumanResponseNotification()
|
|
275
301
|
|
|
276
302
|
# Wait for the human response future to complete
|
|
277
|
-
|
|
303
|
+
text_content: TextContent = await human_response_future
|
|
278
304
|
|
|
279
305
|
interaction_response: HumanResponse = await self._message_validator.convert_text_content_to_human_response(
|
|
280
|
-
|
|
306
|
+
text_content, prompt.content)
|
|
281
307
|
|
|
282
308
|
return interaction_response
|
|
283
309
|
|
|
@@ -293,13 +319,12 @@ class WebSocketMessageHandler:
|
|
|
293
319
|
output_type: type | None = None) -> None:
|
|
294
320
|
|
|
295
321
|
try:
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
if self._flow_handler else None)) as session:
|
|
322
|
+
auth_callback = self._flow_handler.authenticate if self._flow_handler else None
|
|
323
|
+
async with self._session_manager.session(user_message_id=user_message_id,
|
|
324
|
+
conversation_id=conversation_id,
|
|
325
|
+
http_connection=self._socket,
|
|
326
|
+
user_input_callback=self.human_interaction_callback,
|
|
327
|
+
user_authentication_callback=auth_callback) as session:
|
|
303
328
|
|
|
304
329
|
async for value in generate_streaming_response(payload,
|
|
305
330
|
session_manager=session,
|
|
@@ -139,8 +139,10 @@ class MessageValidator:
|
|
|
139
139
|
text_content: str = str(data_model.payload)
|
|
140
140
|
validated_message_content = SystemResponseContent(text=text_content)
|
|
141
141
|
|
|
142
|
-
elif
|
|
142
|
+
elif isinstance(data_model, ChatResponse):
|
|
143
143
|
validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
|
|
144
|
+
elif isinstance(data_model, ChatResponseChunk):
|
|
145
|
+
validated_message_content = SystemResponseContent(text=data_model.choices[0].delta.content)
|
|
144
146
|
|
|
145
147
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
146
148
|
validated_message_content = SystemIntermediateStepContent(name=data_model.name,
|
|
@@ -204,7 +206,7 @@ class MessageValidator:
|
|
|
204
206
|
|
|
205
207
|
validated_message_type: str = ""
|
|
206
208
|
try:
|
|
207
|
-
if (isinstance(data_model,
|
|
209
|
+
if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
|
|
208
210
|
validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
|
|
209
211
|
|
|
210
212
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
@@ -238,10 +240,9 @@ class MessageValidator:
|
|
|
238
240
|
thread_id: str = "default",
|
|
239
241
|
parent_id: str = "default",
|
|
240
242
|
conversation_id: str | None = None,
|
|
241
|
-
content: SystemResponseContent
|
|
242
|
-
| Error = SystemResponseContent(),
|
|
243
|
+
content: SystemResponseContent | Error = SystemResponseContent(),
|
|
243
244
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
244
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
245
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
245
246
|
) -> WebSocketSystemResponseTokenMessage | None:
|
|
246
247
|
"""
|
|
247
248
|
Creates a system response token message with default values.
|
|
@@ -280,7 +281,7 @@ class MessageValidator:
|
|
|
280
281
|
conversation_id: str | None = None,
|
|
281
282
|
content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
|
|
282
283
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
283
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
284
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
284
285
|
) -> WebSocketSystemIntermediateStepMessage | None:
|
|
285
286
|
"""
|
|
286
287
|
Creates a system intermediate step message with default values.
|
|
@@ -320,7 +321,7 @@ class MessageValidator:
|
|
|
320
321
|
conversation_id: str | None = None,
|
|
321
322
|
content: HumanPrompt,
|
|
322
323
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
323
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
324
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
324
325
|
) -> WebSocketSystemInteractionMessage | None:
|
|
325
326
|
"""
|
|
326
327
|
Creates a system interaction message with default values.
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_config_file_path() -> str:
|
|
21
|
+
"""
|
|
22
|
+
Get the path to the NAT configuration file from the environment variable NAT_CONFIG_FILE.
|
|
23
|
+
Raises ValueError if the environment variable is not set.
|
|
24
|
+
"""
|
|
25
|
+
config_file_path = os.getenv("NAT_CONFIG_FILE")
|
|
26
|
+
if (not config_file_path):
|
|
27
|
+
raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
|
|
28
|
+
|
|
29
|
+
return os.path.abspath(config_file_path)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def import_class_from_string(class_full_name: str) -> type:
|
|
33
|
+
"""
|
|
34
|
+
Import a class from a string in the format 'module.submodule.ClassName'.
|
|
35
|
+
Raises ImportError if the class cannot be imported.
|
|
36
|
+
"""
|
|
37
|
+
try:
|
|
38
|
+
class_name_parts = class_full_name.split(".")
|
|
39
|
+
|
|
40
|
+
module_name = ".".join(class_name_parts[:-1])
|
|
41
|
+
class_name = class_name_parts[-1]
|
|
42
|
+
|
|
43
|
+
module = importlib.import_module(module_name)
|
|
44
|
+
|
|
45
|
+
if not hasattr(module, class_name):
|
|
46
|
+
raise ValueError(f"Class '{class_full_name}' not found.")
|
|
47
|
+
|
|
48
|
+
return getattr(module, class_name)
|
|
49
|
+
except (ImportError, AttributeError) as e:
|
|
50
|
+
raise ImportError(f"Could not import {class_full_name}.") from e
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_class_name(cls: type) -> str:
|
|
54
|
+
"""
|
|
55
|
+
Get the full class name including the module.
|
|
56
|
+
"""
|
|
57
|
+
return f"{cls.__module__}.{cls.__qualname__}"
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from mcp.server.auth.provider import AccessToken
|
|
20
|
+
from mcp.server.auth.provider import TokenVerifier
|
|
21
|
+
|
|
22
|
+
from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class IntrospectionTokenVerifier(TokenVerifier):
    """MCP ``TokenVerifier`` whose checks are performed by a ``BearerTokenValidator``.

    This class performs no validation itself. It builds the validator from a
    resource-server config and converts the validator's result into an MCP
    ``AccessToken``.
    """

    def __init__(self, config: OAuth2ResourceServerConfig):
        """Construct the wrapped ``BearerTokenValidator`` from resource-server settings.

        Args:
            config: OAuth2ResourceServerConfig providing the issuer URL, audience,
                scopes, JWKS URI, introspection endpoint, discovery URL and
                client credentials forwarded to the validator.
        """
        self._bearer_token_validator = BearerTokenValidator(
            issuer=config.issuer_url,
            audience=config.audience,
            # A missing (None) or empty scope setting is normalised to an empty list.
            scopes=config.scopes or [],
            jwks_uri=config.jwks_uri,
            introspection_endpoint=config.introspection_endpoint,
            discovery_url=config.discovery_url,
            client_id=config.client_id,
            client_secret=config.client_secret,
        )

    async def verify_token(self, token: str) -> AccessToken | None:
        """Check ``token`` with the wrapped bearer-token validator.

        Args:
            token: The raw Bearer token string to verify.

        Returns:
            AccessToken | None: an ``AccessToken`` when the validator reports the
            token as active, otherwise ``None``.
        """
        result = await self._bearer_token_validator.verify(token)
        if not result.active:
            return None

        # Substitute empty defaults when the validator leaves scopes/client_id unset.
        return AccessToken(
            token=token,
            expires_at=result.expires_at,
            scopes=result.scopes or [],
            client_id=result.client_id or "",
        )
|
|
@@ -13,17 +13,23 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import logging
|
|
16
17
|
from typing import Literal
|
|
17
18
|
|
|
18
19
|
from pydantic import Field
|
|
20
|
+
from pydantic import field_validator
|
|
21
|
+
from pydantic import model_validator
|
|
19
22
|
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
20
24
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
21
25
|
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
22
28
|
|
|
23
29
|
class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
24
30
|
"""MCP front end configuration.
|
|
25
31
|
|
|
26
|
-
A simple MCP (
|
|
32
|
+
A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
|
|
27
33
|
"""
|
|
28
34
|
|
|
29
35
|
name: str = Field(default="NeMo Agent Toolkit MCP",
|
|
@@ -32,10 +38,72 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
32
38
|
port: int = Field(default=9901, description="Port to bind the server to (default: 9901)", ge=0, le=65535)
|
|
33
39
|
debug: bool = Field(default=False, description="Enable debug mode (default: False)")
|
|
34
40
|
log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
|
|
35
|
-
tool_names: list[str] = Field(
|
|
36
|
-
|
|
41
|
+
# Tools to expose over MCP; per the description, the default (empty list) exposes all tools.
tool_names: list[str] = Field(
    default_factory=list,
    # The trailing space matters: the two literals are concatenated implicitly,
    # and without it the sentences run together ("...tools).Tool names...").
    description="The list of tools MCP server will expose (default: all tools). "
    "Tool names can be functions or function groups",
)
transport: Literal["sse", "streamable-http"] = Field(
    default="streamable-http",
    description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
runner_class: str | None = Field(
    default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
# Optional mount prefix; its format is checked by the base_path field validator.
base_path: str | None = Field(default=None,
                              description="Base path to mount the MCP server at (e.g., '/api/v1'). "
                              "If specified, the server will be accessible at http://host:port{base_path}/mcp. "
                              "If None, server runs at root path /mcp.")

# When None, no token verification is configured for the server.
server_auth: OAuth2ResourceServerConfig | None = Field(
    default=None, description="OAuth 2.0 Resource Server configuration for token verification.")
|
|
58
|
+
|
|
59
|
+
@field_validator('base_path')
@classmethod
def validate_base_path(cls, v: str | None) -> str | None:
    """Ensure a configured base_path has a leading '/' and no trailing '/'.

    ``None`` (no base path) is accepted unchanged. A valid path is returned
    as-is; a malformed one raises ``ValueError``.
    """
    if v is None:
        return None

    # The leading-slash check runs first, so a value failing both checks
    # reports the "must start with" error.
    if not v.startswith('/'):
        raise ValueError("base_path must start with '/'")
    if v.endswith('/'):
        raise ValueError("base_path must not end with '/'")
    return v
|
|
69
|
+
|
|
70
|
+
# Memory profiling configuration
enable_memory_profiling: bool = Field(
    default=False,
    description="Enable memory profiling and diagnostics (default: False)",
)
# Requests between memory-stat log entries; constrained to be >= 1.
memory_profile_interval: int = Field(
    ge=1,
    default=50,
    description="Log memory stats every N requests (default: 50)",
)
# Number of top allocations reported; constrained to the range 1..50.
memory_profile_top_n: int = Field(
    ge=1,
    le=50,
    default=10,
    description="Number of top memory allocations to log (default: 10)",
)
# Logging level used for memory-profiling output.
memory_profile_log_level: str = Field(
    default="DEBUG",
    description="Log level for memory profiling output (default: DEBUG)",
)
|
|
82
|
+
|
|
83
|
+
@model_validator(mode="after")
def validate_security_configuration(self):
    """Warn about risky combinations of bind host, transport and authentication.

    Only logs warnings through the module logger; it never raises and returns
    ``self`` unchanged.
    """
    is_loopback_host = self.host in {"localhost", "127.0.0.1", "::1"}
    has_auth = self.server_auth is not None

    # A non-local bind without server_auth may be reachable by anyone on the network.
    if not is_loopback_host and not has_auth:
        logger.warning(
            "MCP server is configured to bind to '%s' without authentication. "
            "This may expose your server to unauthorized access. "
            "Consider either: (1) binding to localhost for local-only access, "
            "or (2) configuring server_auth for production deployments on public interfaces.",
            self.host)

    # The remaining checks only concern the SSE transport.
    if self.transport != "sse":
        return self

    if has_auth:
        # Per the warning text, SSE cannot honour server_auth.
        logger.warning("SSE transport does not support authentication. "
                       "The configured server_auth will be ignored. "
                       "For production use with authentication, use 'streamable-http' transport instead.")
    elif not is_loopback_host:
        logger.warning(
            "SSE transport does not support authentication and is bound to '%s'. "
            "This configuration is not recommended for production use. "
            "For production deployments, use 'streamable-http' transport with server_auth configured.",
            self.host)

    return self
|