nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -14
- nat/agent/reasoning_agent/reasoning_agent.py +9 -7
- nat/agent/register.py +1 -0
- nat/agent/rewoo_agent/agent.py +9 -2
- nat/agent/rewoo_agent/register.py +16 -12
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +14 -13
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/context.py +28 -6
- nat/builder/function.py +313 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +215 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +4 -9
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/log_levels.py +25 -0
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/main.py
CHANGED
|
@@ -13,19 +13,24 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import importlib
|
|
17
16
|
import logging
|
|
18
17
|
import os
|
|
18
|
+
import typing
|
|
19
19
|
|
|
20
20
|
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
|
|
21
|
+
from nat.front_ends.fastapi.utils import get_config_file_path
|
|
22
|
+
from nat.front_ends.fastapi.utils import import_class_from_string
|
|
21
23
|
from nat.runtime.loader import load_config
|
|
22
24
|
|
|
25
|
+
if typing.TYPE_CHECKING:
|
|
26
|
+
from fastapi import FastAPI
|
|
27
|
+
|
|
23
28
|
logger = logging.getLogger(__name__)
|
|
24
29
|
|
|
25
30
|
|
|
26
|
-
def get_app():
|
|
31
|
+
def get_app() -> "FastAPI":
|
|
27
32
|
|
|
28
|
-
config_file_path =
|
|
33
|
+
config_file_path = get_config_file_path()
|
|
29
34
|
front_end_worker_full_name = os.getenv("NAT_FRONT_END_WORKER")
|
|
30
35
|
|
|
31
36
|
if (not config_file_path):
|
|
@@ -36,28 +41,15 @@ def get_app():
|
|
|
36
41
|
|
|
37
42
|
# Try to import the front end worker class
|
|
38
43
|
try:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
front_end_worker_module_name = ".".join(front_end_worker_parts[:-1])
|
|
43
|
-
front_end_worker_class_name = front_end_worker_parts[-1]
|
|
44
|
-
|
|
45
|
-
front_end_worker_module = importlib.import_module(front_end_worker_module_name)
|
|
46
|
-
|
|
47
|
-
if not hasattr(front_end_worker_module, front_end_worker_class_name):
|
|
48
|
-
raise ValueError(f"Front end worker {front_end_worker_full_name} not found.")
|
|
49
|
-
|
|
50
|
-
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = getattr(front_end_worker_module,
|
|
51
|
-
front_end_worker_class_name)
|
|
44
|
+
front_end_worker_class: type[FastApiFrontEndPluginWorkerBase] = import_class_from_string(
|
|
45
|
+
front_end_worker_full_name)
|
|
52
46
|
|
|
53
47
|
if (not issubclass(front_end_worker_class, FastApiFrontEndPluginWorkerBase)):
|
|
54
48
|
raise ValueError(
|
|
55
49
|
f"Front end worker {front_end_worker_full_name} is not a subclass of FastApiFrontEndPluginWorker.")
|
|
56
50
|
|
|
57
51
|
# Load the config
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
config = load_config(abs_config_file_path)
|
|
52
|
+
config = load_config(config_file_path)
|
|
61
53
|
|
|
62
54
|
# Create an instance of the front end worker class
|
|
63
55
|
front_end_worker = front_end_worker_class(config)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_config_file_path() -> str:
|
|
21
|
+
"""
|
|
22
|
+
Get the path to the NAT configuration file from the environment variable NAT_CONFIG_FILE.
|
|
23
|
+
Raises ValueError if the environment variable is not set.
|
|
24
|
+
"""
|
|
25
|
+
config_file_path = os.getenv("NAT_CONFIG_FILE")
|
|
26
|
+
if (not config_file_path):
|
|
27
|
+
raise ValueError("Config file not found in environment variable NAT_CONFIG_FILE.")
|
|
28
|
+
|
|
29
|
+
return os.path.abspath(config_file_path)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def import_class_from_string(class_full_name: str) -> type:
|
|
33
|
+
"""
|
|
34
|
+
Import a class from a string in the format 'module.submodule.ClassName'.
|
|
35
|
+
Raises ImportError if the class cannot be imported.
|
|
36
|
+
"""
|
|
37
|
+
try:
|
|
38
|
+
class_name_parts = class_full_name.split(".")
|
|
39
|
+
|
|
40
|
+
module_name = ".".join(class_name_parts[:-1])
|
|
41
|
+
class_name = class_name_parts[-1]
|
|
42
|
+
|
|
43
|
+
module = importlib.import_module(module_name)
|
|
44
|
+
|
|
45
|
+
if not hasattr(module, class_name):
|
|
46
|
+
raise ValueError(f"Class '{class_full_name}' not found.")
|
|
47
|
+
|
|
48
|
+
return getattr(module, class_name)
|
|
49
|
+
except (ImportError, AttributeError) as e:
|
|
50
|
+
raise ImportError(f"Could not import {class_full_name}.") from e
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_class_name(cls: type) -> str:
|
|
54
|
+
"""
|
|
55
|
+
Get the full class name including the module.
|
|
56
|
+
"""
|
|
57
|
+
return f"{cls.__module__}.{cls.__qualname__}"
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from mcp.server.auth.provider import AccessToken
|
|
20
|
+
from mcp.server.auth.provider import TokenVerifier
|
|
21
|
+
|
|
22
|
+
from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
|
|
23
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class IntrospectionTokenVerifier(TokenVerifier):
|
|
29
|
+
"""Token verifier that delegates token verification to BearerTokenValidator."""
|
|
30
|
+
|
|
31
|
+
def __init__(self, config: OAuth2ResourceServerConfig):
|
|
32
|
+
"""Create IntrospectionTokenVerifier from OAuth2ResourceServerConfig.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
config: OAuth2ResourceServerConfig
|
|
36
|
+
"""
|
|
37
|
+
issuer = config.issuer_url
|
|
38
|
+
scopes = config.scopes or []
|
|
39
|
+
audience = config.audience
|
|
40
|
+
jwks_uri = config.jwks_uri
|
|
41
|
+
introspection_endpoint = config.introspection_endpoint
|
|
42
|
+
discovery_url = config.discovery_url
|
|
43
|
+
client_id = config.client_id
|
|
44
|
+
client_secret = config.client_secret
|
|
45
|
+
|
|
46
|
+
self._bearer_token_validator = BearerTokenValidator(
|
|
47
|
+
issuer=issuer,
|
|
48
|
+
audience=audience,
|
|
49
|
+
scopes=scopes,
|
|
50
|
+
jwks_uri=jwks_uri,
|
|
51
|
+
introspection_endpoint=introspection_endpoint,
|
|
52
|
+
discovery_url=discovery_url,
|
|
53
|
+
client_id=client_id,
|
|
54
|
+
client_secret=client_secret,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
async def verify_token(self, token: str) -> AccessToken | None:
|
|
58
|
+
"""Verify token by delegating to BearerTokenValidator.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
token: The Bearer token to verify
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
AccessToken | None: AccessToken if valid, None if invalid
|
|
65
|
+
"""
|
|
66
|
+
validation_result = await self._bearer_token_validator.verify(token)
|
|
67
|
+
|
|
68
|
+
if validation_result.active:
|
|
69
|
+
return AccessToken(token=token,
|
|
70
|
+
expires_at=validation_result.expires_at,
|
|
71
|
+
scopes=validation_result.scopes or [],
|
|
72
|
+
client_id=validation_result.client_id or "")
|
|
73
|
+
return None
|
|
@@ -17,13 +17,14 @@ from typing import Literal
|
|
|
17
17
|
|
|
18
18
|
from pydantic import Field
|
|
19
19
|
|
|
20
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
20
21
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
24
25
|
"""MCP front end configuration.
|
|
25
26
|
|
|
26
|
-
A simple MCP (
|
|
27
|
+
A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
|
|
27
28
|
"""
|
|
28
29
|
|
|
29
30
|
name: str = Field(default="NeMo Agent Toolkit MCP",
|
|
@@ -39,3 +40,6 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
|
|
|
39
40
|
description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
|
|
40
41
|
runner_class: str | None = Field(
|
|
41
42
|
default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
|
|
43
|
+
|
|
44
|
+
server_auth: OAuth2ResourceServerConfig | None = Field(
|
|
45
|
+
default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import logging
|
|
17
17
|
import typing
|
|
18
18
|
|
|
19
|
+
from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
|
|
19
20
|
from nat.builder.front_end import FrontEndBase
|
|
20
21
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
21
22
|
from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
|
|
@@ -55,25 +56,50 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
|
|
|
55
56
|
|
|
56
57
|
return worker_class(self.full_config)
|
|
57
58
|
|
|
59
|
+
async def _create_token_verifier(self, token_verifier_config: OAuth2ResourceServerConfig):
|
|
60
|
+
"""Create a token verifier based on configuration."""
|
|
61
|
+
from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier
|
|
62
|
+
|
|
63
|
+
if not self.front_end_config.server_auth:
|
|
64
|
+
return None
|
|
65
|
+
|
|
66
|
+
return IntrospectionTokenVerifier(token_verifier_config)
|
|
67
|
+
|
|
58
68
|
async def run(self) -> None:
|
|
59
69
|
"""Run the MCP server."""
|
|
60
70
|
# Import FastMCP
|
|
61
71
|
from mcp.server.fastmcp import FastMCP
|
|
62
72
|
|
|
63
|
-
# Create
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
host=self.front_end_config.host,
|
|
67
|
-
port=self.front_end_config.port,
|
|
68
|
-
debug=self.front_end_config.debug,
|
|
69
|
-
log_level=self.front_end_config.log_level,
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
# Get the worker instance and set up routes
|
|
73
|
-
worker = self._get_worker_instance()
|
|
73
|
+
# Create auth settings and token verifier if auth is required
|
|
74
|
+
auth_settings = None
|
|
75
|
+
token_verifier = None
|
|
74
76
|
|
|
75
77
|
# Build the workflow and add routes using the worker
|
|
76
78
|
async with WorkflowBuilder.from_config(config=self.full_config) as builder:
|
|
79
|
+
|
|
80
|
+
if self.front_end_config.server_auth:
|
|
81
|
+
from mcp.server.auth.settings import AuthSettings
|
|
82
|
+
from pydantic import AnyHttpUrl
|
|
83
|
+
|
|
84
|
+
server_url = f"http://{self.front_end_config.host}:{self.front_end_config.port}"
|
|
85
|
+
|
|
86
|
+
auth_settings = AuthSettings(issuer_url=AnyHttpUrl(self.front_end_config.server_auth.issuer_url),
|
|
87
|
+
required_scopes=self.front_end_config.server_auth.scopes,
|
|
88
|
+
resource_server_url=AnyHttpUrl(server_url))
|
|
89
|
+
|
|
90
|
+
token_verifier = await self._create_token_verifier(self.front_end_config.server_auth)
|
|
91
|
+
|
|
92
|
+
# Create an MCP server with the configured parameters
|
|
93
|
+
mcp = FastMCP(name=self.front_end_config.name,
|
|
94
|
+
host=self.front_end_config.host,
|
|
95
|
+
port=self.front_end_config.port,
|
|
96
|
+
debug=self.front_end_config.debug,
|
|
97
|
+
auth=auth_settings,
|
|
98
|
+
token_verifier=token_verifier)
|
|
99
|
+
|
|
100
|
+
# Get the worker instance and set up routes
|
|
101
|
+
worker = self._get_worker_instance()
|
|
102
|
+
|
|
77
103
|
# Add routes through the worker (includes health endpoint and function registration)
|
|
78
104
|
await worker.add_routes(mcp, builder)
|
|
79
105
|
|
|
@@ -16,11 +16,15 @@
|
|
|
16
16
|
import logging
|
|
17
17
|
from abc import ABC
|
|
18
18
|
from abc import abstractmethod
|
|
19
|
+
from collections.abc import Mapping
|
|
20
|
+
from typing import Any
|
|
19
21
|
|
|
20
22
|
from mcp.server.fastmcp import FastMCP
|
|
23
|
+
from starlette.exceptions import HTTPException
|
|
21
24
|
from starlette.requests import Request
|
|
22
25
|
|
|
23
26
|
from nat.builder.function import Function
|
|
27
|
+
from nat.builder.function_base import FunctionBase
|
|
24
28
|
from nat.builder.workflow import Workflow
|
|
25
29
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
26
30
|
from nat.data_models.config import Config
|
|
@@ -94,13 +98,114 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
94
98
|
functions: dict[str, Function] = {}
|
|
95
99
|
|
|
96
100
|
# Extract all functions from the workflow
|
|
97
|
-
|
|
98
|
-
|
|
101
|
+
functions.update(workflow.functions)
|
|
102
|
+
for function_group in workflow.function_groups.values():
|
|
103
|
+
functions.update(function_group.get_accessible_functions())
|
|
99
104
|
|
|
100
|
-
|
|
105
|
+
if workflow.config.workflow.workflow_alias:
|
|
106
|
+
functions[workflow.config.workflow.workflow_alias] = workflow
|
|
107
|
+
else:
|
|
108
|
+
functions[workflow.config.workflow.type] = workflow
|
|
101
109
|
|
|
102
110
|
return functions
|
|
103
111
|
|
|
112
|
+
def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
|
|
113
|
+
"""Set up HTTP debug endpoints for introspecting tools and schemas.
|
|
114
|
+
|
|
115
|
+
Exposes:
|
|
116
|
+
- GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
|
|
117
|
+
selects a subset and returns details for those tools.
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
@mcp.custom_route("/debug/tools/list", methods=["GET"])
|
|
121
|
+
async def list_tools(request: Request):
|
|
122
|
+
"""HTTP list tools endpoint."""
|
|
123
|
+
|
|
124
|
+
from starlette.responses import JSONResponse
|
|
125
|
+
|
|
126
|
+
from nat.front_ends.mcp.tool_converter import get_function_description
|
|
127
|
+
|
|
128
|
+
# Query params
|
|
129
|
+
# Support repeated names and comma-separated lists
|
|
130
|
+
names_param_list = set(request.query_params.getlist("name"))
|
|
131
|
+
names: list[str] = []
|
|
132
|
+
for raw in names_param_list:
|
|
133
|
+
# if p.strip() is empty, it won't be included in the list!
|
|
134
|
+
parts = [p.strip() for p in raw.split(",") if p.strip()]
|
|
135
|
+
names.extend(parts)
|
|
136
|
+
detail_raw = request.query_params.get("detail")
|
|
137
|
+
|
|
138
|
+
def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
|
|
139
|
+
if detail_param is None:
|
|
140
|
+
if has_names:
|
|
141
|
+
return True
|
|
142
|
+
return False
|
|
143
|
+
v = detail_param.strip().lower()
|
|
144
|
+
if v in ("0", "false", "no", "off"):
|
|
145
|
+
return False
|
|
146
|
+
if v in ("1", "true", "yes", "on"):
|
|
147
|
+
return True
|
|
148
|
+
# For invalid values, default based on whether names are present
|
|
149
|
+
return has_names
|
|
150
|
+
|
|
151
|
+
# Helper function to build the input schema info
|
|
152
|
+
def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
|
|
153
|
+
schema = getattr(fn, "input_schema", None)
|
|
154
|
+
if schema is None:
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
# check if schema is a ChatRequest
|
|
158
|
+
schema_name = getattr(schema, "__name__", "")
|
|
159
|
+
schema_qualname = getattr(schema, "__qualname__", "")
|
|
160
|
+
if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
|
|
161
|
+
# Simplified interface used by MCP wrapper for ChatRequest
|
|
162
|
+
return {
|
|
163
|
+
"type": "object",
|
|
164
|
+
"properties": {
|
|
165
|
+
"query": {
|
|
166
|
+
"type": "string", "description": "User query string"
|
|
167
|
+
}
|
|
168
|
+
},
|
|
169
|
+
"required": ["query"],
|
|
170
|
+
"title": "ChatRequestQuery",
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
# Pydantic models provide model_json_schema
|
|
174
|
+
if schema is not None and hasattr(schema, "model_json_schema"):
|
|
175
|
+
return schema.model_json_schema()
|
|
176
|
+
|
|
177
|
+
return None
|
|
178
|
+
|
|
179
|
+
def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
|
|
180
|
+
include_schemas: bool = False) -> dict[str, Any]:
|
|
181
|
+
tools = []
|
|
182
|
+
for name, fn in functions_to_include.items():
|
|
183
|
+
list_entry: dict[str, Any] = {
|
|
184
|
+
"name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
|
|
185
|
+
}
|
|
186
|
+
if include_schemas:
|
|
187
|
+
list_entry["schema"] = _build_schema_info(fn)
|
|
188
|
+
tools.append(list_entry)
|
|
189
|
+
|
|
190
|
+
return {
|
|
191
|
+
"count": len(tools),
|
|
192
|
+
"tools": tools,
|
|
193
|
+
"server_name": mcp.name,
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
if names:
|
|
197
|
+
# Return selected tools
|
|
198
|
+
try:
|
|
199
|
+
functions_to_include = {n: functions[n] for n in names}
|
|
200
|
+
except KeyError as e:
|
|
201
|
+
raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
|
|
202
|
+
else:
|
|
203
|
+
functions_to_include = functions
|
|
204
|
+
|
|
205
|
+
# Default for listing all: detail defaults to False unless explicitly set true
|
|
206
|
+
return JSONResponse(
|
|
207
|
+
_build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
|
|
208
|
+
|
|
104
209
|
|
|
105
210
|
class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
106
211
|
"""Default MCP front end plugin worker implementation."""
|
|
@@ -141,3 +246,6 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
141
246
|
# Add a simple fallback function if no functions were found
|
|
142
247
|
if not functions:
|
|
143
248
|
raise RuntimeError("No functions found in workflow. Please check your configuration.")
|
|
249
|
+
|
|
250
|
+
# After registration, expose debug endpoints for tool/schema inspection
|
|
251
|
+
self._setup_debug_endpoints(mcp, functions)
|
|
@@ -229,6 +229,9 @@ def get_function_description(function: FunctionBase) -> str:
|
|
|
229
229
|
# Try to get anything that might be a description
|
|
230
230
|
elif hasattr(config, "topic") and config.topic:
|
|
231
231
|
function_description = config.topic
|
|
232
|
+
# Try to get description from the workflow config
|
|
233
|
+
elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
|
|
234
|
+
function_description = config.workflow.description
|
|
232
235
|
|
|
233
236
|
elif isinstance(function, Function):
|
|
234
237
|
function_description = function.description
|
nat/llm/aws_bedrock_llm.py
CHANGED
|
@@ -21,22 +21,33 @@ from nat.builder.builder import Builder
|
|
|
21
21
|
from nat.builder.llm import LLMProviderInfo
|
|
22
22
|
from nat.cli.register_workflow import register_llm_provider
|
|
23
23
|
from nat.data_models.llm import LLMBaseConfig
|
|
24
|
+
from nat.data_models.optimizable import OptimizableField
|
|
25
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
26
|
+
from nat.data_models.optimizable import SearchSpace
|
|
24
27
|
from nat.data_models.retry_mixin import RetryMixin
|
|
25
28
|
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
26
29
|
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
27
30
|
from nat.data_models.top_p_mixin import TopPMixin
|
|
28
31
|
|
|
29
32
|
|
|
30
|
-
class AWSBedrockModelConfig(LLMBaseConfig,
|
|
33
|
+
class AWSBedrockModelConfig(LLMBaseConfig,
|
|
34
|
+
RetryMixin,
|
|
35
|
+
OptimizableMixin,
|
|
36
|
+
TemperatureMixin,
|
|
37
|
+
TopPMixin,
|
|
38
|
+
ThinkingMixin,
|
|
39
|
+
name="aws_bedrock"):
|
|
31
40
|
"""An AWS Bedrock llm provider to be used with an LLM client."""
|
|
32
41
|
|
|
33
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
42
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
34
43
|
|
|
35
44
|
# Completion parameters
|
|
36
45
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
37
46
|
serialization_alias="model",
|
|
38
47
|
description="The model name for the hosted AWS Bedrock.")
|
|
39
|
-
max_tokens: int
|
|
48
|
+
max_tokens: int = OptimizableField(default=300,
|
|
49
|
+
description="Maximum number of tokens to generate.",
|
|
50
|
+
space=SearchSpace(high=2176, low=128, step=512))
|
|
40
51
|
context_size: int | None = Field(
|
|
41
52
|
default=1024,
|
|
42
53
|
gt=0,
|
nat/llm/nim_llm.py
CHANGED
|
@@ -22,23 +22,34 @@ from nat.builder.builder import Builder
|
|
|
22
22
|
from nat.builder.llm import LLMProviderInfo
|
|
23
23
|
from nat.cli.register_workflow import register_llm_provider
|
|
24
24
|
from nat.data_models.llm import LLMBaseConfig
|
|
25
|
+
from nat.data_models.optimizable import OptimizableField
|
|
26
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
27
|
+
from nat.data_models.optimizable import SearchSpace
|
|
25
28
|
from nat.data_models.retry_mixin import RetryMixin
|
|
26
29
|
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
27
30
|
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
28
31
|
from nat.data_models.top_p_mixin import TopPMixin
|
|
29
32
|
|
|
30
33
|
|
|
31
|
-
class NIMModelConfig(LLMBaseConfig,
|
|
34
|
+
class NIMModelConfig(LLMBaseConfig,
|
|
35
|
+
RetryMixin,
|
|
36
|
+
OptimizableMixin,
|
|
37
|
+
TemperatureMixin,
|
|
38
|
+
TopPMixin,
|
|
39
|
+
ThinkingMixin,
|
|
40
|
+
name="nim"):
|
|
32
41
|
"""An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
|
|
33
42
|
|
|
34
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
43
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
35
44
|
|
|
36
45
|
api_key: str | None = Field(default=None, description="NVIDIA API key to interact with hosted NIM.")
|
|
37
46
|
base_url: str | None = Field(default=None, description="Base url to the hosted NIM.")
|
|
38
47
|
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
39
48
|
serialization_alias="model",
|
|
40
49
|
description="The model name for the hosted NIM.")
|
|
41
|
-
max_tokens: PositiveInt =
|
|
50
|
+
max_tokens: PositiveInt = OptimizableField(default=300,
|
|
51
|
+
description="Maximum number of tokens to generate.",
|
|
52
|
+
space=SearchSpace(high=2176, low=128, step=512))
|
|
42
53
|
|
|
43
54
|
|
|
44
55
|
@register_llm_provider(config_type=NIMModelConfig)
|
nat/llm/openai_llm.py
CHANGED
|
@@ -21,13 +21,20 @@ from nat.builder.builder import Builder
|
|
|
21
21
|
from nat.builder.llm import LLMProviderInfo
|
|
22
22
|
from nat.cli.register_workflow import register_llm_provider
|
|
23
23
|
from nat.data_models.llm import LLMBaseConfig
|
|
24
|
+
from nat.data_models.optimizable import OptimizableMixin
|
|
24
25
|
from nat.data_models.retry_mixin import RetryMixin
|
|
25
26
|
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
26
27
|
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
27
28
|
from nat.data_models.top_p_mixin import TopPMixin
|
|
28
29
|
|
|
29
30
|
|
|
30
|
-
class OpenAIModelConfig(LLMBaseConfig,
|
|
31
|
+
class OpenAIModelConfig(LLMBaseConfig,
|
|
32
|
+
RetryMixin,
|
|
33
|
+
OptimizableMixin,
|
|
34
|
+
TemperatureMixin,
|
|
35
|
+
TopPMixin,
|
|
36
|
+
ThinkingMixin,
|
|
37
|
+
name="openai"):
|
|
31
38
|
"""An OpenAI LLM provider to be used with an LLM client."""
|
|
32
39
|
|
|
33
40
|
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|