nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/main.py
CHANGED
|
@@ -13,19 +13,24 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import importlib
|
|
17
16
|
import logging
|
|
18
17
|
import os
|
|
18
|
+
import typing
|
|
19
19
|
|
|
20
20
|
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
|
|
21
|
+
from nat.front_ends.fastapi.utils import get_config_file_path
|
|
22
|
+
from nat.front_ends.fastapi.utils import import_class_from_string
|
|
21
23
|
from nat.runtime.loader import load_config
|
|
22
24
|
|
|
25
|
+
if typing.TYPE_CHECKING:
|
|
26
|
+
from fastapi import FastAPI
|
|
27
|
+
|
|
23
28
|
logger = logging.getLogger(__name__)
|
|
24
29
|
|
|
25
30
|
|
|
26
|
-
def get_app():
|
|
31
|
+
def get_app() -> "FastAPI":
|
|
27
32
|
|
|
28
|
-
config_file_path =
|
|
33
|
+
config_file_path = get_config_file_path()
|
|
29
34
|
front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
|
|
30
35
|
|
|
31
36
|
if (not config_file_path):
|
|
@@ -36,28 +41,15 @@ def get_app():
|
|
|
36
41
|
|
|
37
42
|
# Try to import the front end worker class
|
|
38
43
|
try:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
|
|
43
|
-
front_end_worker_class_name = front_end_worker_parts[-1]
|
|
44
|
-
|
|
45
|
-
front_end_worker_module = importlib.import_module(front_end_worker_module_name)
|
|
46
|
-
|
|
47
|
-
if not hasattr(front_end_worker_module, front_end_worker_class_name):
|
|
48
|
-
raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
|
|
49
|
-
|
|
50
|
-
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
|
|
51
|
-
front_end_worker_class_name)
|
|
44
|
+
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
|
|
45
|
+
front_end_worker_full_name)
|
|
52
46
|
|
|
53
47
|
if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
|
|
54
48
|
raise ValueError(
|
|
55
49
|
f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
|
|
56
50
|
|
|
57
51
|
# Load the config
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
config = load_config(abs_config_file_path)
|
|
52
|
+
config = load_config(config_file_path)
|
|
61
53
|
|
|
62
54
|
# Create an instance of the front end worker class
|
|
63
55
|
front_end_worker = front_end_worker_class(config)
|
|
@@ -86,7 +86,7 @@ class WebSocketMessageHandler:
|
|
|
86
86
|
|
|
87
87
|
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
|
88
88
|
|
|
89
|
-
# TODO: Handle the exit
|
|
89
|
+
# TODO: Handle the exit
|
|
90
90
|
pass
|
|
91
91
|
|
|
92
92
|
async def run(self) -> None:
|
|
@@ -107,10 +107,8 @@ class WebSocketMessageHandler:
|
|
|
107
107
|
|
|
108
108
|
elif isinstance(
|
|
109
109
|
validated_message,
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
WebSocketSystemIntermediateStepMessage,
|
|
113
|
-
WebSocketSystemInteractionMessage)):
|
|
110
|
+
WebSocketSystemResponseTokenMessage | WebSocketSystemIntermediateStepMessage
|
|
111
|
+
| WebSocketSystemInteractionMessage):
|
|
114
112
|
# These messages are already handled by self.create_websocket_message(data_model=value, …)
|
|
115
113
|
# No further processing is needed here.
|
|
116
114
|
pass
|
|
@@ -119,11 +117,9 @@ class WebSocketMessageHandler:
|
|
|
119
117
|
user_content = await self.process_user_message_content(validated_message)
|
|
120
118
|
self._user_interaction_response.set_result(user_content)
|
|
121
119
|
except (asyncio.CancelledError, WebSocketDisconnect):
|
|
122
|
-
# TODO: Handle the disconnect
|
|
120
|
+
# TODO: Handle the disconnect
|
|
123
121
|
break
|
|
124
122
|
|
|
125
|
-
return None
|
|
126
|
-
|
|
127
123
|
async def process_user_message_content(
|
|
128
124
|
self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
|
|
129
125
|
"""
|
|
@@ -162,18 +158,19 @@ class WebSocketMessageHandler:
|
|
|
162
158
|
|
|
163
159
|
if isinstance(content, TextContent) and (self._running_workflow_task is None):
|
|
164
160
|
|
|
165
|
-
def _done_callback(task: asyncio.Task):
|
|
161
|
+
def _done_callback(task: asyncio.Task):
|
|
166
162
|
self._running_workflow_task = None
|
|
167
163
|
|
|
168
164
|
self._running_workflow_task = asyncio.create_task(
|
|
169
|
-
self._run_workflow(content.text,
|
|
170
|
-
self.
|
|
165
|
+
self._run_workflow(payload=content.text,
|
|
166
|
+
user_message_id=self._message_parent_id,
|
|
167
|
+
conversation_id=self._conversation_id,
|
|
171
168
|
result_type=self._schema_output_mapping[self._workflow_schema_type],
|
|
172
169
|
output_type=self._schema_output_mapping[
|
|
173
170
|
self._workflow_schema_type])).add_done_callback(_done_callback)
|
|
174
171
|
|
|
175
172
|
except ValueError as e:
|
|
176
|
-
logger.
|
|
173
|
+
logger.exception("User message content not found: %s", str(e))
|
|
177
174
|
await self.create_websocket_message(data_model=Error(code=ErrorTypes.INVALID_USER_MESSAGE_CONTENT,
|
|
178
175
|
message="User message content could not be found",
|
|
179
176
|
details=str(e)),
|
|
@@ -241,7 +238,7 @@ class WebSocketMessageHandler:
|
|
|
241
238
|
f"Message type could not be resolved by input data model: {data_model.model_dump_json()}")
|
|
242
239
|
|
|
243
240
|
except (ValidationError, TypeError, ValueError) as e:
|
|
244
|
-
logger.
|
|
241
|
+
logger.exception("A data vaidation error ocurred creating websocket message: %s", str(e))
|
|
245
242
|
message = await self._message_validator.create_system_response_token_message(
|
|
246
243
|
message_type=WebSocketMessageType.ERROR_MESSAGE,
|
|
247
244
|
conversation_id=self._conversation_id,
|
|
@@ -290,14 +287,16 @@ class WebSocketMessageHandler:
|
|
|
290
287
|
|
|
291
288
|
async def _run_workflow(self,
|
|
292
289
|
payload: typing.Any,
|
|
290
|
+
user_message_id: str | None = None,
|
|
293
291
|
conversation_id: str | None = None,
|
|
294
292
|
result_type: type | None = None,
|
|
295
293
|
output_type: type | None = None) -> None:
|
|
296
294
|
|
|
297
295
|
try:
|
|
298
296
|
async with self._session_manager.session(
|
|
297
|
+
user_message_id=user_message_id,
|
|
299
298
|
conversation_id=conversation_id,
|
|
300
|
-
|
|
299
|
+
http_connection=self._socket,
|
|
301
300
|
user_input_callback=self.human_interaction_callback,
|
|
302
301
|
user_authentication_callback=(self._flow_handler.authenticate
|
|
303
302
|
if self._flow_handler else None)) as session:
|
|
@@ -97,7 +97,7 @@ class MessageValidator:
|
|
|
97
97
|
return validated_message
|
|
98
98
|
|
|
99
99
|
except (ValidationError, TypeError, ValueError) as e:
|
|
100
|
-
logger.
|
|
100
|
+
logger.exception("A data validation error %s occurred for message: %s", str(e), str(message))
|
|
101
101
|
return await self.create_system_response_token_message(message_type=WebSocketMessageType.ERROR_MESSAGE,
|
|
102
102
|
content=Error(code=ErrorTypes.INVALID_MESSAGE,
|
|
103
103
|
message="Error validating message.",
|
|
@@ -119,7 +119,7 @@ class MessageValidator:
|
|
|
119
119
|
return schema
|
|
120
120
|
|
|
121
121
|
except (TypeError, ValueError) as e:
|
|
122
|
-
logger.
|
|
122
|
+
logger.exception("Error retrieving schema for message type '%s': %s", message_type, str(e))
|
|
123
123
|
return Error
|
|
124
124
|
|
|
125
125
|
async def convert_data_to_message_content(self, data_model: BaseModel) -> BaseModel:
|
|
@@ -139,7 +139,7 @@ class MessageValidator:
|
|
|
139
139
|
text_content: str = str(data_model.payload)
|
|
140
140
|
validated_message_content = SystemResponseContent(text=text_content)
|
|
141
141
|
|
|
142
|
-
elif (isinstance(data_model,
|
|
142
|
+
elif (isinstance(data_model, ChatResponse | ChatResponseChunk)):
|
|
143
143
|
validated_message_content = SystemResponseContent(text=data_model.choices[0].message.content)
|
|
144
144
|
|
|
145
145
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
@@ -156,7 +156,7 @@ class MessageValidator:
|
|
|
156
156
|
return validated_message_content
|
|
157
157
|
|
|
158
158
|
except ValueError as e:
|
|
159
|
-
logger.
|
|
159
|
+
logger.exception("Input data could not be converted to validated message content: %s", str(e))
|
|
160
160
|
return Error(code=ErrorTypes.INVALID_DATA_CONTENT, message="Input data not supported.", details=str(e))
|
|
161
161
|
|
|
162
162
|
async def convert_text_content_to_human_response(self, text_content: TextContent,
|
|
@@ -191,7 +191,7 @@ class MessageValidator:
|
|
|
191
191
|
return human_response
|
|
192
192
|
|
|
193
193
|
except ValueError as e:
|
|
194
|
-
logger.
|
|
194
|
+
logger.exception("Error human response content not found: %s", str(e))
|
|
195
195
|
return HumanResponseText(text=str(e))
|
|
196
196
|
|
|
197
197
|
async def resolve_message_type_by_data(self, data_model: BaseModel) -> str:
|
|
@@ -204,7 +204,7 @@ class MessageValidator:
|
|
|
204
204
|
|
|
205
205
|
validated_message_type: str = ""
|
|
206
206
|
try:
|
|
207
|
-
if (isinstance(data_model,
|
|
207
|
+
if (isinstance(data_model, ResponsePayloadOutput | ChatResponse | ChatResponseChunk)):
|
|
208
208
|
validated_message_type = WebSocketMessageType.RESPONSE_MESSAGE
|
|
209
209
|
|
|
210
210
|
elif (isinstance(data_model, ResponseIntermediateStep)):
|
|
@@ -218,9 +218,7 @@ class MessageValidator:
|
|
|
218
218
|
return validated_message_type
|
|
219
219
|
|
|
220
220
|
except ValueError as e:
|
|
221
|
-
logger.
|
|
222
|
-
str(e),
|
|
223
|
-
exc_info=True)
|
|
221
|
+
logger.exception("Error type not found converting data to validated websocket message content: %s", str(e))
|
|
224
222
|
return WebSocketMessageType.ERROR_MESSAGE
|
|
225
223
|
|
|
226
224
|
async def get_intermediate_step_parent_id(self, data_model: ResponseIntermediateStep) -> str:
|
|
@@ -232,7 +230,7 @@ class MessageValidator:
|
|
|
232
230
|
"""
|
|
233
231
|
return data_model.parent_id or "root"
|
|
234
232
|
|
|
235
|
-
async def create_system_response_token_message(
|
|
233
|
+
async def create_system_response_token_message(
|
|
236
234
|
self,
|
|
237
235
|
message_type: Literal[WebSocketMessageType.RESPONSE_MESSAGE,
|
|
238
236
|
WebSocketMessageType.ERROR_MESSAGE] = WebSocketMessageType.RESPONSE_MESSAGE,
|
|
@@ -243,7 +241,7 @@ class MessageValidator:
|
|
|
243
241
|
content: SystemResponseContent
|
|
244
242
|
| Error = SystemResponseContent(),
|
|
245
243
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
246
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
244
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
247
245
|
) -> WebSocketSystemResponseTokenMessage | None:
|
|
248
246
|
"""
|
|
249
247
|
Creates a system response token message with default values.
|
|
@@ -269,10 +267,10 @@ class MessageValidator:
|
|
|
269
267
|
timestamp=timestamp)
|
|
270
268
|
|
|
271
269
|
except Exception as e:
|
|
272
|
-
logger.
|
|
270
|
+
logger.exception("Error creating system response token message: %s", str(e))
|
|
273
271
|
return None
|
|
274
272
|
|
|
275
|
-
async def create_system_intermediate_step_message(
|
|
273
|
+
async def create_system_intermediate_step_message(
|
|
276
274
|
self,
|
|
277
275
|
message_type: Literal[WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE] = (
|
|
278
276
|
WebSocketMessageType.INTERMEDIATE_STEP_MESSAGE),
|
|
@@ -282,7 +280,7 @@ class MessageValidator:
|
|
|
282
280
|
conversation_id: str | None = None,
|
|
283
281
|
content: SystemIntermediateStepContent = SystemIntermediateStepContent(name="default", payload="default"),
|
|
284
282
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
285
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
283
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
286
284
|
) -> WebSocketSystemIntermediateStepMessage | None:
|
|
287
285
|
"""
|
|
288
286
|
Creates a system intermediate step message with default values.
|
|
@@ -308,10 +306,10 @@ class MessageValidator:
|
|
|
308
306
|
timestamp=timestamp)
|
|
309
307
|
|
|
310
308
|
except Exception as e:
|
|
311
|
-
logger.
|
|
309
|
+
logger.exception("Error creating system intermediate step message: %s", str(e))
|
|
312
310
|
return None
|
|
313
311
|
|
|
314
|
-
async def create_system_interaction_message(
|
|
312
|
+
async def create_system_interaction_message(
|
|
315
313
|
self,
|
|
316
314
|
*,
|
|
317
315
|
message_type: Literal[WebSocketMessageType.SYSTEM_INTERACTION_MESSAGE] = (
|
|
@@ -322,8 +320,8 @@ class MessageValidator:
|
|
|
322
320
|
conversation_id: str | None = None,
|
|
323
321
|
content: HumanPrompt,
|
|
324
322
|
status: WebSocketMessageStatus = WebSocketMessageStatus.IN_PROGRESS,
|
|
325
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
326
|
-
) -> WebSocketSystemInteractionMessage | None:
|
|
323
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
324
|
+
) -> WebSocketSystemInteractionMessage | None:
|
|
327
325
|
"""
|
|
328
326
|
Creates a system interaction message with default values.
|
|
329
327
|
|
|
@@ -348,5 +346,5 @@ class MessageValidator:
|
|
|
348
346
|
timestamp=timestamp)
|
|
349
347
|
|
|
350
348
|
except Exception as e:
|
|
351
|
-
logger.
|
|
349
|
+
logger.exception("Error creating system interaction message: %s", str(e))
|
|
352
350
|
return None
|
|
@@ -98,9 +98,9 @@ async def generate_streaming_response(payload: typing.Any,
|
|
|
98
98
|
yield item
|
|
99
99
|
else:
|
|
100
100
|
yield ResponsePayloadOutput(payload=item)
|
|
101
|
-
except Exception
|
|
101
|
+
except Exception:
|
|
102
102
|
# Handle exceptions here
|
|
103
|
-
raise
|
|
103
|
+
raise
|
|
104
104
|
finally:
|
|
105
105
|
await q.close()
|
|
106
106
|
|
|
@@ -165,9 +165,9 @@ async def generate_streaming_response_full(payload: typing.Any,
|
|
|
165
165
|
yield item
|
|
166
166
|
else:
|
|
167
167
|
yield ResponsePayloadOutput(payload=item)
|
|
168
|
-
except Exception
|
|
168
|
+
except Exception:
|
|
169
169
|
# Handle exceptions here
|
|
170
|
-
raise
|
|
170
|
+
raise
|
|
171
171
|
finally:
|
|
172
172
|
await q.close()
|
|
173
173
|
|
|
@@ -289,7 +289,7 @@ class StepAdaptor:
|
|
|
289
289
|
|
|
290
290
|
return event
|
|
291
291
|
|
|
292
|
-
def process(self, step: IntermediateStep) -> ResponseSerializable | None:
|
|
292
|
+
def process(self, step: IntermediateStep) -> ResponseSerializable | None:
|
|
293
293
|
|
|
294
294
|
# Track the chunk
|
|
295
295
|
self._history.append(step)
|
|
@@ -314,6 +314,6 @@ class StepAdaptor:
|
|
|
314
314
|
return self._handle_custom(payload, ancestry)
|
|
315
315
|
|
|
316
316
|
except Exception as e:
|
|
317
|
-
logger.
|
|
317
|
+
logger.exception("Error processing intermediate step: %s", e)
|
|
318
318
|
|
|
319
319
|
return None
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_config_file_path() -> str:
    """Return the absolute path of the NAT configuration file.

    The path is read from the ``NAT_CONFIG_FILE`` environment variable and
    resolved against the current working directory via ``os.path.abspath``.

    Returns:
        The absolute form of the configured path.

    Raises:
        ValueError: If ``NAT_CONFIG_FILE`` is unset or set to an empty string.
    """
    raw_path = os.getenv("NAT_CONFIG_FILE")
    if raw_path:
        return os.path.abspath(raw_path)

    raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def import_class_from_string(class_full_name: str) -> type:
    """Import and return a class given its fully qualified dotted path.

    Args:
        class_full_name: Path in the form ``'module.submodule.ClassName'``.

    Returns:
        The attribute named by the last path component, looked up on the module
        named by everything before it. The attribute is not checked to be a class.

    Raises:
        ValueError: If the path has no module component (e.g. ``'ClassName'``),
            or if the imported module has no attribute with the requested name.
        ImportError: If the module cannot be imported, or if ``class_full_name``
            has no ``rpartition`` method (e.g. it is ``None``).
    """
    try:
        module_name, _, class_name = class_full_name.rpartition(".")

        if not module_name:
            # Without this check importlib fails with an opaque "Empty module name" ValueError.
            raise ValueError(f"Invalid class path '{class_full_name}'; expected 'module.submodule.ClassName'.")

        module = importlib.import_module(module_name)

        # A local sentinel detects a missing attribute with a single lookup
        # (equivalent to hasattr + getattr, which both swallow only AttributeError).
        missing = object()
        found = getattr(module, class_name, missing)
        if found is missing:
            raise ValueError(f"Class '{class_full_name}' not found.")

        return found
    except (ImportError, AttributeError) as e:
        # AttributeError here comes from non-string input lacking ``rpartition``;
        # both cases are reported to callers as ImportError.
        raise ImportError(f"Could not import {class_full_name}.") from e
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_class_name(cls: type) -> str:
    """Return ``cls``'s fully qualified name as ``'<module>.<qualname>'``.

    Note: for nested classes ``__qualname__`` itself contains dots
    (e.g. ``'Outer.Inner'``), so the result is not always a plain
    ``module.ClassName`` path.
    """
    module_name = cls.__module__
    qualified_name = cls.__qualname__
    return "{}.{}".format(module_name, qualified_name)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from mcp.server.auth.provider import AccessToken
|
|
20
|
+
from mcp.server.auth.provider import TokenVerifier
|
|
21
|
+
|
|
22
|
+
from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class IntrospectionTokenVerifier(TokenVerifier):
    """MCP ``TokenVerifier`` that hands bearer-token checks to ``BearerTokenValidator``."""

    def __init__(self, config: OAuth2ResourceServerConfig):
        """Build the underlying validator from a resource-server configuration.

        Args:
            config: OAuth2ResourceServerConfig supplying the issuer URL, scopes,
                audience, JWKS URI, introspection endpoint, discovery URL and
                client credentials forwarded to ``BearerTokenValidator``.
        """
        self._bearer_token_validator = BearerTokenValidator(
            issuer=config.issuer_url,
            # An unset scope list is normalized to an empty list.
            scopes=config.scopes or [],
            audience=config.audience,
            jwks_uri=config.jwks_uri,
            introspection_endpoint=config.introspection_endpoint,
            discovery_url=config.discovery_url,
            client_id=config.client_id,
            client_secret=config.client_secret,
        )

    async def verify_token(self, token: str) -> AccessToken | None:
        """Validate ``token`` and translate the result into an MCP ``AccessToken``.

        Args:
            token: The raw bearer token presented by the client.

        Returns:
            AccessToken | None: An ``AccessToken`` when the validator reports the
            token as active, otherwise ``None``.
        """
        outcome = await self._bearer_token_validator.verify(token)

        if not outcome.active:
            return None

        # Missing scopes / client id are replaced with empty defaults.
        return AccessToken(
            token=token,
            expires_at=outcome.expires_at,
            scopes=outcome.scopes or [],
            client_id=outcome.client_id or "",
        )
|
|
@@ -13,15 +13,18 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
16
18
|
from pydantic import Field
|
|
17
19
|
|
|
20
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
18
21
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
19
22
|
|
|
20
23
|
|
|
21
24
|
class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
22
25
|
"""MCP front end configuration.
|
|
23
26
|
|
|
24
|
-
A simple MCP (
|
|
27
|
+
A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
|
|
25
28
|
"""
|
|
26
29
|
|
|
27
30
|
name: str = Field(default="NeMo Agent Toolkit MCP",
|
|
@@ -32,5 +35,11 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
32
35
|
log_level: str = Field(default="INFO", description="Log level for the MCP server (default: INFO)")
|
|
33
36
|
tool_names: list[str] = Field(default_factory=list,
|
|
34
37
|
description="The list of tools MCP server will expose (default: all tools)")
|
|
38
|
+
transport: Literal["sse", "streamable-http"] = Field(
|
|
39
|
+
default="streamable-http",
|
|
40
|
+
description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
|
|
35
41
|
runner_class: str | None = Field(
|
|
36
42
|
default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
|
|
43
|
+
|
|
44
|
+
server_auth: OAuth2ResourceServerConfig | None = Field(
|
|
45
|
+
default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import logging
|
|
17
17
|
import typing
|
|
18
18
|
|
|
19
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
19
20
|
from nat.builder.front_end import FrontEndBase
|
|
20
21
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
21
22
|
from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
|
|
@@ -55,27 +56,58 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
|
|
|
55
56
|
|
|
56
57
|
return worker_class(self.full_config)
|
|
57
58
|
|
|
59
|
+
async def _create_token_verifier(self, token_verifier_config: OAuth2ResourceServerConfig | None):
    """Create a token verifier for the given OAuth 2.0 resource-server configuration.

    Args:
        token_verifier_config: Resource-server settings used to build the verifier. When ``None``,
            no verifier is created.

    Returns:
        An ``IntrospectionTokenVerifier`` built from ``token_verifier_config``, or ``None`` when no
        configuration was supplied.
    """
    # Decide from the argument actually passed in (previously this ignored the parameter and
    # re-read self.front_end_config.server_auth).
    if token_verifier_config is None:
        return None

    # Kept as a local import, as in the original; it is now only executed when a verifier is needed.
    from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier

    return IntrospectionTokenVerifier(token_verifier_config)
|
|
67
|
+
|
|
58
68
|
async def run(self) -> None:
    """Run the MCP server.

    Builds the workflow and, when ``server_auth`` is configured, creates the auth settings and
    token verifier. It then creates the FastMCP server, registers routes through the configured
    worker, and serves with the configured transport (``sse`` or ``streamable-http``).
    """
    # Import FastMCP
    from mcp.server.fastmcp import FastMCP

    # Auth settings and token verifier stay None unless server_auth is configured.
    auth_settings = None
    token_verifier = None

    # Build the workflow and add routes using the worker
    async with WorkflowBuilder.from_config(config=self.full_config) as builder:

        server_auth = self.front_end_config.server_auth
        if server_auth:
            from mcp.server.auth.settings import AuthSettings
            from pydantic import AnyHttpUrl

            # NOTE(review): the scheme is hard-coded to http and the bind host is used verbatim as the
            # resource server URL. A wildcard bind (e.g. 0.0.0.0) or a TLS-terminating proxy may yield
            # a URL clients cannot use — confirm this is intended.
            server_url = f"http://{self.front_end_config.host}:{self.front_end_config.port}"

            auth_settings = AuthSettings(issuer_url=AnyHttpUrl(server_auth.issuer_url),
                                         required_scopes=server_auth.scopes,
                                         resource_server_url=AnyHttpUrl(server_url))

            token_verifier = await self._create_token_verifier(server_auth)

        # Create an MCP server with the configured parameters. log_level is forwarded again (the
        # previous version passed it and the config still declares it).
        mcp = FastMCP(name=self.front_end_config.name,
                      host=self.front_end_config.host,
                      port=self.front_end_config.port,
                      debug=self.front_end_config.debug,
                      log_level=self.front_end_config.log_level,
                      auth=auth_settings,
                      token_verifier=token_verifier)

        # Get the worker instance and set up routes
        worker = self._get_worker_instance()

        # Add routes through the worker (includes health endpoint and function registration)
        await worker.add_routes(mcp, builder)

        # Start the MCP server with configurable transport
        # streamable-http is the default, but users can choose sse if preferred
        if self.front_end_config.transport == "sse":
            logger.info("Starting MCP server with SSE endpoint at /sse")
            await mcp.run_sse_async()
        else:  # streamable-http
            logger.info("Starting MCP server with streamable-http endpoint at /mcp/")
            await mcp.run_streamable_http_async()
|