nvidia-nat 1.3.0a20250917__py3-none-any.whl → 1.3.0a20250923__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +3 -10
- nat/agent/reasoning_agent/reasoning_agent.py +3 -6
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +6 -1
- nat/agent/rewoo_agent/register.py +9 -10
- nat/agent/tool_calling_agent/register.py +3 -10
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/context.py +28 -6
- nat/builder/function.py +165 -19
- nat/builder/workflow_builder.py +2 -0
- nat/cli/entrypoint.py +2 -9
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/{agent → control_flow}/router_agent/agent.py +3 -3
- nat/{agent → control_flow}/router_agent/register.py +8 -14
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/front_ends/fastapi/dask_client_mixin.py +26 -4
- nat/front_ends/fastapi/fastapi_front_end_config.py +4 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +30 -7
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +108 -1
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/observability/mixin/type_introspection_mixin.py +19 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +5 -1
- nat/utils/log_levels.py +25 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/METADATA +3 -1
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/RECORD +40 -31
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/entry_points.txt +1 -0
- /nat/{agent/router_agent → control_flow}/__init__.py +0 -0
- /nat/{agent → control_flow}/router_agent/prompt.py +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/top_level.txt +0 -0
nat/data_models/agent.py
ADDED
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import Field
+from pydantic import PositiveInt
+
+from nat.data_models.component_ref import LLMRef
+from nat.data_models.function import FunctionBaseConfig
+
+
+class AgentBaseConfig(FunctionBaseConfig):
+    """Base configuration class for all NAT agents with common fields."""
+
+    workflow_alias: str | None = Field(
+        default=None,
+        description=("The alias of the workflow. Useful when the agent is configured as a workflow "
+                     "and needs to expose a customized name as a tool."))
+    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")
+    verbose: bool = Field(default=False, description="Set the verbosity of the agent's logging.")
+    description: str = Field(description="The description of this function's use.")
+    log_response_max_chars: PositiveInt = Field(
+        default=1000, description="Maximum number of characters to display in logs when logging responses.")
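
A hedged usage sketch of the new base class: the registry name "my_agent" and the extra max_iterations field are invented for illustration, and the name= class keyword is assumed to behave as it does for the front-end configs elsewhere in this diff.

from pydantic import Field

from nat.data_models.agent import AgentBaseConfig


class MyAgentConfig(AgentBaseConfig, name="my_agent"):  # "my_agent" is a hypothetical registry name
    max_iterations: int = Field(default=5, description="Maximum reasoning steps.")


config = MyAgentConfig(llm_name="my_llm",  # hypothetical LLMRef value
                       description="Example agent built on AgentBaseConfig.")
print(config.log_response_max_chars)  # 1000, the inherited default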

nat/data_models/authentication.py
@@ -177,6 +177,26 @@ Credential = typing.Annotated[
 ]
 
 
+class TokenValidationResult(BaseModel):
+    """
+    Standard result for Bearer Token Validation.
+    """
+    model_config = ConfigDict(extra="forbid")
+
+    client_id: str | None = Field(description="OAuth2 client identifier")
+    scopes: list[str] | None = Field(default=None, description="List of granted scopes (introspection only)")
+    expires_at: int | None = Field(default=None, description="Token expiration time (Unix timestamp)")
+    audience: list[str] | None = Field(default=None, description="Token audiences (aud claim)")
+    subject: str | None = Field(default=None, description="Token subject (sub claim)")
+    issuer: str | None = Field(default=None, description="Token issuer (iss claim)")
+    token_type: str = Field(description="Token type")
+    active: bool | None = Field(default=True, description="Token active status")
+    nbf: int | None = Field(default=None, description="Not before time (Unix timestamp)")
+    iat: int | None = Field(default=None, description="Issued at time (Unix timestamp)")
+    jti: str | None = Field(default=None, description="JWT ID")
+    username: str | None = Field(default=None, description="Username (introspection only)")
+
+
 class AuthResult(BaseModel):
     """
     Represents the result of an authentication process.
@@ -229,3 +249,21 @@ class AuthResult(BaseModel):
             target_kwargs.setdefault(k, {}).update(v)
         else:
             target_kwargs[k] = v
+
+
+class AuthReason(str, Enum):
+    """
+    Why the caller is asking for auth now.
+    """
+    NORMAL = "normal"
+    RETRY_AFTER_401 = "retry_after_401"
+
+
+class AuthRequest(BaseModel):
+    """
+    Authentication request payload for provider.authenticate(...).
+    """
+    model_config = ConfigDict(extra="forbid")
+
+    reason: AuthReason = Field(default=AuthReason.NORMAL, description="Purpose of this auth attempt.")
+    www_authenticate: str | None = Field(default=None, description="Raw WWW-Authenticate header from a 401 response.")
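
A minimal sketch of how a caller might populate the new request model when retrying after a 401, assuming the models live in nat/data_models/authentication.py as the hunks above suggest; the header value is invented.

from nat.data_models.authentication import AuthReason
from nat.data_models.authentication import AuthRequest

request = AuthRequest(reason=AuthReason.RETRY_AFTER_401,
                      www_authenticate='Bearer realm="api", error="invalid_token"')  # example header
assert request.reason is AuthReason.RETRY_AFTER_401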

nat/front_ends/fastapi/dask_client_mixin.py
@@ -16,7 +16,9 @@
 import typing
 from abc import ABC
 from collections.abc import AsyncGenerator
+from collections.abc import Generator
 from contextlib import asynccontextmanager
+from contextlib import contextmanager
 
 if typing.TYPE_CHECKING:
     from dask.distributed import Client
@@ -27,17 +29,37 @@ class DaskClientMixin(ABC):
     @asynccontextmanager
     async def client(self, address: str) -> AsyncGenerator["Client"]:
         """
-        Async context manager for obtaining a Dask client
+        Async context manager for obtaining a Dask client.
 
         Yields
         ------
         Client
-            An
+            An async Dask client connected to the scheduler. The client is automatically closed when exiting the
             context manager.
         """
         from dask.distributed import Client
         client = await Client(address=address, asynchronous=True)
 
-
+        try:
+            yield client
+        finally:
+            await client.close()
 
-
+    @contextmanager
+    def blocking_client(self, address: str) -> Generator["Client"]:
+        """
+        context manager for obtaining a blocking Dask client.
+
+        Yields
+        ------
+        Client
+            A blocking Dask client connected to the scheduler. The client is automatically closed when exiting the
+            context manager.
+        """
+        from dask.distributed import Client
+        client = Client(address=address)
+
+        try:
+            yield client
+        finally:
+            client.close()
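
Both context managers close their client on exit. A usage sketch, assuming a consumer class that mixes in DaskClientMixin, a running Dask scheduler at the example address, and dask.distributed installed.

import asyncio

from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin


class SchedulerUser(DaskClientMixin):  # hypothetical consumer of the mixin
    pass


async def main(address: str) -> None:
    user = SchedulerUser()
    async with user.client(address) as client:  # async client, closed automatically
        print(await client.submit(sum, [1, 2, 3]))

    with user.blocking_client(address) as client:  # blocking client, closed automatically
        print(client.submit(sum, [1, 2, 3]).result())


# asyncio.run(main("tcp://127.0.0.1:8786"))  # example scheduler address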

nat/front_ends/fastapi/fastapi_front_end_config.py
@@ -211,6 +211,10 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
             "Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
             "This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
         ge=1)
+    dask_log_level: str = Field(
+        default="WARNING",
+        description="Logging level for Dask.",
+    )
     step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
 
     workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
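
A sketch of setting the new field programmatically, assuming the remaining FastApiFrontEndConfig fields keep their defaults; in practice it would normally appear in the fastapi front-end section of a workflow config file. The string is uppercased and resolved against the LOG_LEVELS table added in nat/utils/log_levels.py, with unknown names falling back to WARNING.

from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig

fe_config = FastApiFrontEndConfig(dask_log_level="debug")  # resolved as LOG_LEVELS["DEBUG"]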

nat/front_ends/fastapi/fastapi_front_end_plugin.py
@@ -27,6 +27,7 @@ from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontE
 from nat.front_ends.fastapi.main import get_app
 from nat.front_ends.fastapi.utils import get_class_name
 from nat.utils.io.yaml_tools import yaml_dump
+from nat.utils.log_levels import LOG_LEVELS
 
 if (typing.TYPE_CHECKING):
     from nat.data_models.config import Config
@@ -79,16 +80,23 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
         except: # noqa: E722
             logger.exception("Error during job cleanup")
 
-    async def _submit_cleanup_task(self, scheduler_address: str, db_url: str):
+    async def _submit_cleanup_task(self, scheduler_address: str, db_url: str, log_level: int = logging.INFO):
         """Submit a cleanup task to the cluster to remove the job after expiry."""
-        logger.
+        logger.debug("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
         async with self.client(self._scheduler_address) as client:
             self._periodic_cleanup_future = client.submit(self._periodic_cleanup,
                                                           scheduler_address=self._scheduler_address,
                                                           db_url=db_url,
-                                                          log_level=
+                                                          log_level=log_level)
 
-
+    @staticmethod
+    def _setup_worker():
+        """
+        Setup function to be run in each worker process. This moves each worker into it's own process group.
+        This fixes an issue where a Ctrl-C in the terminal sends a SIGINT to all workers, which then causes the
+        workers to exit before the main process can shutdown the cluster gracefully.
+        """
+        os.setsid()
 
     async def run(self):
 
@@ -102,15 +110,27 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
         # 1. Dask is installed and scheduler_address is None, we create a LocalCluster
         # 2. Dask is installed and scheduler_address is set, we use the existing cluster
         # 3. Dask is not installed, we skip the cluster setup
+        dask_log_level = LOG_LEVELS.get(self.front_end_config.dask_log_level.upper(), logging.WARNING)
+        dask_logger = logging.getLogger("distributed")
+        dask_logger.setLevel(dask_log_level)
+
         self._scheduler_address = self.front_end_config.scheduler_address
         if self._scheduler_address is None:
             try:
+
                 from dask.distributed import LocalCluster
 
-                self._cluster = LocalCluster(
+                self._cluster = LocalCluster(processes=True,
+                                             silence_logs=dask_log_level,
+                                             n_workers=self.front_end_config.max_running_async_jobs,
                                              threads_per_worker=1)
 
                 self._scheduler_address = self._cluster.scheduler.address
+
+                with self.blocking_client(self._scheduler_address) as client:
+                    # Client.run submits a function to be run on each worker
+                    client.run(self._setup_worker)
+
                 logger.info("Created local Dask cluster with scheduler at %s", self._scheduler_address)
 
             except ImportError:
@@ -128,7 +148,9 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
 
         # If self.front_end_config.db_url is None, then we need to get the actual url from the engine
         db_url = str(db_engine.url)
-        await self._submit_cleanup_task(scheduler_address=self._scheduler_address,
+        await self._submit_cleanup_task(scheduler_address=self._scheduler_address,
+                                        db_url=db_url,
+                                        log_level=dask_log_level)
 
         # Set environment variabls such that the worker subprocesses will know how to connect to dask and to
         # the database
@@ -216,8 +238,9 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
 
         if self._cluster is not None:
             # Only shut down the cluster if we created it
-            logger.
+            logger.debug("Closing Local Dask cluster.")
             self._cluster.close()
+
         try:
             os.remove(config_file_name)
         except OSError as e:

nat/front_ends/mcp/introspection_token_verifier.py
ADDED
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
+
+import logging
+
+from mcp.server.auth.provider import AccessToken
+from mcp.server.auth.provider import TokenVerifier
+
+from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
+
+logger = logging.getLogger(__name__)
+
+
+class IntrospectionTokenVerifier(TokenVerifier):
+    """Token verifier that delegates token verification to BearerTokenValidator."""
+
+    def __init__(self, config: OAuth2ResourceServerConfig):
+        """Create IntrospectionTokenVerifier from OAuth2ResourceServerConfig.
+
+        Args:
+            config: OAuth2ResourceServerConfig
+        """
+        issuer = config.issuer_url
+        scopes = config.scopes or []
+        audience = config.audience
+        jwks_uri = config.jwks_uri
+        introspection_endpoint = config.introspection_endpoint
+        discovery_url = config.discovery_url
+        client_id = config.client_id
+        client_secret = config.client_secret
+
+        self._bearer_token_validator = BearerTokenValidator(
+            issuer=issuer,
+            audience=audience,
+            scopes=scopes,
+            jwks_uri=jwks_uri,
+            introspection_endpoint=introspection_endpoint,
+            discovery_url=discovery_url,
+            client_id=client_id,
+            client_secret=client_secret,
+        )
+
+    async def verify_token(self, token: str) -> AccessToken | None:
+        """Verify token by delegating to BearerTokenValidator.
+
+        Args:
+            token: The Bearer token to verify
+
+        Returns:
+            AccessToken | None: AccessToken if valid, None if invalid
+        """
+        validation_result = await self._bearer_token_validator.verify(token)
+
+        if validation_result.active:
+            return AccessToken(token=token,
+                               expires_at=validation_result.expires_at,
+                               scopes=validation_result.scopes or [],
+                               client_id=validation_result.client_id or "")
+        return None
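
A hedged construction sketch: the keyword arguments follow the attributes read in __init__ above, but the required fields and validation rules live in oauth2_resource_server_config.py, which this diff only references, so treat them as assumptions.

from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier

config = OAuth2ResourceServerConfig(
    issuer_url="https://auth.example.com",  # example issuer
    scopes=["mcp:read"],  # example scope
    jwks_uri="https://auth.example.com/.well-known/jwks.json",
)
verifier = IntrospectionTokenVerifier(config)
# token = await verifier.verify_token(raw_bearer_token)  # AccessToken if active, else None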

nat/front_ends/mcp/mcp_front_end_config.py
@@ -17,13 +17,14 @@ from typing import Literal
 
 from pydantic import Field
 
+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.data_models.front_end import FrontEndBaseConfig
 
 
 class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
     """MCP front end configuration.
 
-    A simple MCP (
+    A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
     """
 
     name: str = Field(default="NeMo Agent Toolkit MCP",
@@ -39,3 +40,6 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
         description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
     runner_class: str | None = Field(
         default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
+
+    server_auth: OAuth2ResourceServerConfig | None = Field(
+        default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))

nat/front_ends/mcp/mcp_front_end_plugin.py
@@ -16,6 +16,7 @@
 import logging
 import typing
 
+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.builder.front_end import FrontEndBase
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
@@ -55,25 +56,50 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
 
         return worker_class(self.full_config)
 
+    async def _create_token_verifier(self, token_verifier_config: OAuth2ResourceServerConfig):
+        """Create a token verifier based on configuration."""
+        from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier
+
+        if not self.front_end_config.server_auth:
+            return None
+
+        return IntrospectionTokenVerifier(token_verifier_config)
+
     async def run(self) -> None:
         """Run the MCP server."""
         # Import FastMCP
         from mcp.server.fastmcp import FastMCP
 
-        # Create
-
-
-                      host=self.front_end_config.host,
-                      port=self.front_end_config.port,
-                      debug=self.front_end_config.debug,
-                      log_level=self.front_end_config.log_level,
-                      )
-
-        # Get the worker instance and set up routes
-        worker = self._get_worker_instance()
+        # Create auth settings and token verifier if auth is required
+        auth_settings = None
+        token_verifier = None
 
         # Build the workflow and add routes using the worker
         async with WorkflowBuilder.from_config(config=self.full_config) as builder:
+
+            if self.front_end_config.server_auth:
+                from mcp.server.auth.settings import AuthSettings
+                from pydantic import AnyHttpUrl
+
+                server_url = f"http://{self.front_end_config.host}:{self.front_end_config.port}"
+
+                auth_settings = AuthSettings(issuer_url=AnyHttpUrl(self.front_end_config.server_auth.issuer_url),
+                                             required_scopes=self.front_end_config.server_auth.scopes,
+                                             resource_server_url=AnyHttpUrl(server_url))
+
+                token_verifier = await self._create_token_verifier(self.front_end_config.server_auth)
+
+            # Create an MCP server with the configured parameters
+            mcp = FastMCP(name=self.front_end_config.name,
+                          host=self.front_end_config.host,
+                          port=self.front_end_config.port,
+                          debug=self.front_end_config.debug,
+                          auth=auth_settings,
+                          token_verifier=token_verifier)
+
+            # Get the worker instance and set up routes
+            worker = self._get_worker_instance()
+
             # Add routes through the worker (includes health endpoint and function registration)
             await worker.add_routes(mcp, builder)
 

nat/front_ends/mcp/mcp_front_end_plugin_worker.py
@@ -16,11 +16,15 @@
 import logging
 from abc import ABC
 from abc import abstractmethod
+from collections.abc import Mapping
+from typing import Any
 
 from mcp.server.fastmcp import FastMCP
+from starlette.exceptions import HTTPException
 from starlette.requests import Request
 
 from nat.builder.function import Function
+from nat.builder.function_base import FunctionBase
 from nat.builder.workflow import Workflow
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.config import Config
@@ -98,10 +102,110 @@ class MCPFrontEndPluginWorkerBase(ABC):
         for function_group in workflow.function_groups.values():
             functions.update(function_group.get_accessible_functions())
 
-
+        if workflow.config.workflow.workflow_alias:
+            functions[workflow.config.workflow.workflow_alias] = workflow
+        else:
+            functions[workflow.config.workflow.type] = workflow
 
         return functions
 
+    def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
+        """Set up HTTP debug endpoints for introspecting tools and schemas.
+
+        Exposes:
+        - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
+          selects a subset and returns details for those tools.
+        """
+
+        @mcp.custom_route("/debug/tools/list", methods=["GET"])
+        async def list_tools(request: Request):
+            """HTTP list tools endpoint."""
+
+            from starlette.responses import JSONResponse
+
+            from nat.front_ends.mcp.tool_converter import get_function_description
+
+            # Query params
+            # Support repeated names and comma-separated lists
+            names_param_list = set(request.query_params.getlist("name"))
+            names: list[str] = []
+            for raw in names_param_list:
+                # if p.strip() is empty, it won't be included in the list!
+                parts = [p.strip() for p in raw.split(",") if p.strip()]
+                names.extend(parts)
+            detail_raw = request.query_params.get("detail")
+
+            def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
+                if detail_param is None:
+                    if has_names:
+                        return True
+                    return False
+                v = detail_param.strip().lower()
+                if v in ("0", "false", "no", "off"):
+                    return False
+                if v in ("1", "true", "yes", "on"):
+                    return True
+                # For invalid values, default based on whether names are present
+                return has_names
+
+            # Helper function to build the input schema info
+            def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
+                schema = getattr(fn, "input_schema", None)
+                if schema is None:
+                    return None
+
+                # check if schema is a ChatRequest
+                schema_name = getattr(schema, "__name__", "")
+                schema_qualname = getattr(schema, "__qualname__", "")
+                if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
+                    # Simplified interface used by MCP wrapper for ChatRequest
+                    return {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string", "description": "User query string"
+                            }
+                        },
+                        "required": ["query"],
+                        "title": "ChatRequestQuery",
+                    }
+
+                # Pydantic models provide model_json_schema
+                if schema is not None and hasattr(schema, "model_json_schema"):
+                    return schema.model_json_schema()
+
+                return None
+
+            def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
+                                  include_schemas: bool = False) -> dict[str, Any]:
+                tools = []
+                for name, fn in functions_to_include.items():
+                    list_entry: dict[str, Any] = {
+                        "name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
+                    }
+                    if include_schemas:
+                        list_entry["schema"] = _build_schema_info(fn)
+                    tools.append(list_entry)
+
+                return {
+                    "count": len(tools),
+                    "tools": tools,
+                    "server_name": mcp.name,
+                }
+
+            if names:
+                # Return selected tools
+                try:
+                    functions_to_include = {n: functions[n] for n in names}
+                except KeyError as e:
+                    raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
+            else:
+                functions_to_include = functions
+
+            # Default for listing all: detail defaults to False unless explicitly set true
+            return JSONResponse(
+                _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
+
 
 class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
     """Default MCP front end plugin worker implementation."""
@@ -142,3 +246,6 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
         # Add a simple fallback function if no functions were found
         if not functions:
             raise RuntimeError("No functions found in workflow. Please check your configuration.")
+
+        # After registration, expose debug endpoints for tool/schema inspection
+        self._setup_debug_endpoints(mcp, functions)
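
Once registered, the endpoint is reachable over plain HTTP. A hedged example of calling it; the host and port are assumptions about a local deployment, and httpx is already a dependency of nvidia-nat.

import httpx

base = "http://localhost:9901"  # assumed host/port of a local MCP server

# List all tools (detail defaults to False when no names are given)
resp = httpx.get(f"{base}/debug/tools/list")
print(resp.json()["count"])

# Select specific tools; detail then defaults to True and includes schemas
resp = httpx.get(f"{base}/debug/tools/list", params={"name": "tool_a,tool_b"})  # hypothetical tool names
for tool in resp.json()["tools"]:
    print(tool["name"], tool.get("schema"))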

nat/front_ends/mcp/tool_converter.py
@@ -229,6 +229,9 @@ def get_function_description(function: FunctionBase) -> str:
     # Try to get anything that might be a description
     elif hasattr(config, "topic") and config.topic:
         function_description = config.topic
+    # Try to get description from the workflow config
+    elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
+        function_description = config.workflow.description
 
     elif isinstance(function, Function):
         function_description = function.description

nat/observability/mixin/type_introspection_mixin.py
@@ -15,6 +15,7 @@
 
 import inspect
 import logging
+import types
 from functools import lru_cache
 from typing import Any
 from typing import TypeVar
@@ -475,3 +476,21 @@ class TypeIntrospectionMixin:
         except ValidationError:
             logger.warning("Item %s is not compatible with output type %s", item, self.output_type)
             return False
+
+    @lru_cache
+    def extract_non_optional_type(self, type_obj: type | types.UnionType) -> Any:
+        """Extract the non-None type from Optional[T] or Union[T, None] types.
+
+        This is useful when you need to pass a type to a system that doesn't
+        understand Optional types (like registries that expect concrete types).
+
+        Args:
+            type_obj (type | types.UnionType): The type to extract from (could be Optional[T] or Union[T, None])
+
+        Returns:
+            Any: The actual type without None, or the original type if not a union with None
+        """
+        decomposed = DecomposedType(type_obj)  # type: ignore[arg-type]
+        if decomposed.is_optional:
+            return decomposed.get_optional_type().type
+        return type_obj
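
Equivalent standalone logic, for illustration: the mixin delegates to DecomposedType, while this sketch uses the standard typing helpers directly.

import types
import typing


def extract_non_optional(tp):
    """Return T for Optional[T] / T | None; otherwise return tp unchanged."""
    if typing.get_origin(tp) in (typing.Union, types.UnionType):
        non_none = [arg for arg in typing.get_args(tp) if arg is not type(None)]
        if len(non_none) == 1:
            return non_none[0]
    return tp


assert extract_non_optional(int | None) is int
assert extract_non_optional(typing.Optional[str]) is str
assert extract_non_optional(bytes) is bytes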

nat/profiler/parameter_optimization/parameter_optimizer.py
@@ -28,7 +28,6 @@ from nat.eval.evaluate import EvaluationRun
 from nat.eval.evaluate import EvaluationRunConfig
 from nat.experimental.decorators.experimental_warning_decorator import experimental
 from nat.profiler.parameter_optimization.parameter_selection import pick_trial
-from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
 from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
 
 logger = logging.getLogger(__name__)
@@ -133,6 +132,7 @@ def optimize_parameters(
 
     # Generate Pareto front visualizations
     try:
+        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
         logger.info("Generating Pareto front visualizations...")
         create_pareto_visualization(
             data_source=study,
@@ -143,6 +143,10 @@ def optimize_parameters(
             show_plots=False # Don't show plots in automated runs
         )
         logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
+    except ImportError as ie:
+        logger.warning("Could not import visualization dependencies: %s. "
+                       "Have you installed nvidia-nat-profiling?",
+                       ie)
     except Exception as e:
         logger.warning("Failed to generate visualizations: %s", e)
 
nat/utils/log_levels.py
ADDED
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+# Define log level choices
+LOG_LEVELS = {
+    'DEBUG': logging.DEBUG,
+    'INFO': logging.INFO,
+    'WARNING': logging.WARNING,
+    'ERROR': logging.ERROR,
+    'CRITICAL': logging.CRITICAL
+}
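
Usage mirrors the lookup in FastApiFrontEndPlugin.run shown earlier: unknown names fall back to whatever default the caller supplies.

import logging

from nat.utils.log_levels import LOG_LEVELS

level = LOG_LEVELS.get("TRACE", logging.WARNING)  # "TRACE" is not defined, so this yields WARNING
logging.getLogger("distributed").setLevel(level)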

{nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nvidia-nat
-Version: 1.3.0a20250917
+Version: 1.3.0a20250923
 Summary: NVIDIA NeMo Agent toolkit
 Author: NVIDIA Corporation
 Maintainer: NVIDIA Corporation
@@ -225,6 +225,7 @@ Requires-Dist: httpx~=0.27
 Requires-Dist: jinja2~=3.1
 Requires-Dist: jsonpath-ng~=1.7
 Requires-Dist: mcp~=1.13
+Requires-Dist: nest-asyncio~=1.6
 Requires-Dist: networkx~=3.4
 Requires-Dist: numpy~=1.26; python_version < "3.12"
 Requires-Dist: numpy~=2.3; python_version >= "3.12"
@@ -299,6 +300,7 @@ Requires-Dist: nat_redact_pii; extra == "examples"
 Requires-Dist: nat_retail_sales_agent; extra == "examples"
 Requires-Dist: nat_router_agent; extra == "examples"
 Requires-Dist: nat_semantic_kernel_demo; extra == "examples"
+Requires-Dist: nat_sequential_executor; extra == "examples"
 Requires-Dist: nat_simple_auth; extra == "examples"
 Requires-Dist: nat_simple_web_query; extra == "examples"
 Requires-Dist: nat_simple_web_query_eval; extra == "examples"