nvidia-nat 1.3.0a20250917__py3-none-any.whl → 1.3.0a20250923__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. nat/agent/react_agent/register.py +3 -10
  2. nat/agent/reasoning_agent/reasoning_agent.py +3 -6
  3. nat/agent/register.py +0 -1
  4. nat/agent/rewoo_agent/agent.py +6 -1
  5. nat/agent/rewoo_agent/register.py +9 -10
  6. nat/agent/tool_calling_agent/register.py +3 -10
  7. nat/authentication/credential_validator/__init__.py +14 -0
  8. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  9. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  10. nat/builder/context.py +28 -6
  11. nat/builder/function.py +165 -19
  12. nat/builder/workflow_builder.py +2 -0
  13. nat/cli/entrypoint.py +2 -9
  14. nat/control_flow/register.py +20 -0
  15. nat/control_flow/router_agent/__init__.py +0 -0
  16. nat/{agent → control_flow}/router_agent/agent.py +3 -3
  17. nat/{agent → control_flow}/router_agent/register.py +8 -14
  18. nat/control_flow/sequential_executor.py +167 -0
  19. nat/data_models/agent.py +34 -0
  20. nat/data_models/authentication.py +38 -0
  21. nat/front_ends/fastapi/dask_client_mixin.py +26 -4
  22. nat/front_ends/fastapi/fastapi_front_end_config.py +4 -0
  23. nat/front_ends/fastapi/fastapi_front_end_plugin.py +30 -7
  24. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  25. nat/front_ends/mcp/mcp_front_end_config.py +5 -1
  26. nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
  27. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +108 -1
  28. nat/front_ends/mcp/tool_converter.py +3 -0
  29. nat/observability/mixin/type_introspection_mixin.py +19 -0
  30. nat/profiler/parameter_optimization/parameter_optimizer.py +5 -1
  31. nat/utils/log_levels.py +25 -0
  32. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/METADATA +3 -1
  33. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/RECORD +40 -31
  34. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/entry_points.txt +1 -0
  35. /nat/{agent/router_agent → control_flow}/__init__.py +0 -0
  36. /nat/{agent → control_flow}/router_agent/prompt.py +0 -0
  37. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/WHEEL +0 -0
  38. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  39. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/licenses/LICENSE.md +0 -0
  40. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import Field
+from pydantic import PositiveInt
+
+from nat.data_models.component_ref import LLMRef
+from nat.data_models.function import FunctionBaseConfig
+
+
+class AgentBaseConfig(FunctionBaseConfig):
+    """Base configuration class for all NAT agents with common fields."""
+
+    workflow_alias: str | None = Field(
+        default=None,
+        description=("The alias of the workflow. Useful when the agent is configured as a workflow "
+                     "and needs to expose a customized name as a tool."))
+    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")
+    verbose: bool = Field(default=False, description="Set the verbosity of the agent's logging.")
+    description: str = Field(description="The description of this function's use.")
+    log_response_max_chars: PositiveInt = Field(
+        default=1000, description="Maximum number of characters to display in logs when logging responses.")
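
For context, a minimal sketch (not from the diff) of how an agent package might build on the new AgentBaseConfig. The `name="my_agent"` class keyword follows the registration convention visible on other configs in this diff (e.g. `FrontEndBaseConfig, name="mcp"`); the extra field is hypothetical.

import pydantic

from nat.data_models.agent import AgentBaseConfig


class MyAgentConfig(AgentBaseConfig, name="my_agent"):
    """Hypothetical agent config inheriting workflow_alias, llm_name, verbose, etc."""
    max_iterations: int = pydantic.Field(default=5, description="Maximum reasoning iterations.")
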
@@ -177,6 +177,26 @@ Credential = typing.Annotated[
 ]
 
 
+class TokenValidationResult(BaseModel):
+    """
+    Standard result for Bearer Token Validation.
+    """
+    model_config = ConfigDict(extra="forbid")
+
+    client_id: str | None = Field(description="OAuth2 client identifier")
+    scopes: list[str] | None = Field(default=None, description="List of granted scopes (introspection only)")
+    expires_at: int | None = Field(default=None, description="Token expiration time (Unix timestamp)")
+    audience: list[str] | None = Field(default=None, description="Token audiences (aud claim)")
+    subject: str | None = Field(default=None, description="Token subject (sub claim)")
+    issuer: str | None = Field(default=None, description="Token issuer (iss claim)")
+    token_type: str = Field(description="Token type")
+    active: bool | None = Field(default=True, description="Token active status")
+    nbf: int | None = Field(default=None, description="Not before time (Unix timestamp)")
+    iat: int | None = Field(default=None, description="Issued at time (Unix timestamp)")
+    jti: str | None = Field(default=None, description="JWT ID")
+    username: str | None = Field(default=None, description="Username (introspection only)")
+
+
 class AuthResult(BaseModel):
     """
     Represents the result of an authentication process.
@@ -229,3 +249,21 @@ class AuthResult(BaseModel):
                 target_kwargs.setdefault(k, {}).update(v)
             else:
                 target_kwargs[k] = v
+
+
+class AuthReason(str, Enum):
+    """
+    Why the caller is asking for auth now.
+    """
+    NORMAL = "normal"
+    RETRY_AFTER_401 = "retry_after_401"
+
+
+class AuthRequest(BaseModel):
+    """
+    Authentication request payload for provider.authenticate(...).
+    """
+    model_config = ConfigDict(extra="forbid")
+
+    reason: AuthReason = Field(default=AuthReason.NORMAL, description="Purpose of this auth attempt.")
+    www_authenticate: str | None = Field(default=None, description="Raw WWW-Authenticate header from a 401 response.")
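
A hedged sketch of how a caller might drive the new AuthRequest/AuthReason models when retrying after a 401. The `provider` object and its `authenticate` coroutine are assumed from the AuthRequest docstring; neither appears in this diff.

from nat.data_models.authentication import AuthReason
from nat.data_models.authentication import AuthRequest

request = AuthRequest(
    reason=AuthReason.RETRY_AFTER_401,
    www_authenticate='Bearer realm="api", error="invalid_token"',  # header copied from the 401 response
)
# result = await provider.authenticate(request)  # provider interface assumed, not shown in this diff
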
@@ -16,7 +16,9 @@
 import typing
 from abc import ABC
 from collections.abc import AsyncGenerator
+from collections.abc import Generator
 from contextlib import asynccontextmanager
+from contextlib import contextmanager
 
 if typing.TYPE_CHECKING:
     from dask.distributed import Client
@@ -27,17 +29,37 @@ class DaskClientMixin(ABC):
     @asynccontextmanager
     async def client(self, address: str) -> AsyncGenerator["Client"]:
         """
-        Async context manager for obtaining a Dask client connection.
+        Async context manager for obtaining a Dask client.
 
         Yields
        ------
         Client
-            An active Dask client connected to the scheduler. The client is automatically closed when exiting the
+            An async Dask client connected to the scheduler. The client is automatically closed when exiting the
             context manager.
         """
         from dask.distributed import Client
         client = await Client(address=address, asynchronous=True)
 
-        yield client
+        try:
+            yield client
+        finally:
+            await client.close()
 
-        await client.close()
+    @contextmanager
+    def blocking_client(self, address: str) -> Generator["Client"]:
+        """
+        Context manager for obtaining a blocking Dask client.
+
+        Yields
+        ------
+        Client
+            A blocking Dask client connected to the scheduler. The client is automatically closed when exiting the
+            context manager.
+        """
+        from dask.distributed import Client
+        client = Client(address=address)
+
+        try:
+            yield client
+        finally:
+            client.close()
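
A usage sketch for the new blocking_client, assuming a class that mixes in DaskClientMixin and a reachable scheduler address; `scheduler_info()` is a standard dask.distributed.Client method.

from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin


class ClusterProbe(DaskClientMixin):

    def ping(self, address: str) -> dict:
        # The client is created on entry and closed on exit, even if the body raises
        with self.blocking_client(address) as client:
            return client.scheduler_info()
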
@@ -211,6 +211,10 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
             "Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
             "This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
         ge=1)
+    dask_log_level: str = Field(
+        default="WARNING",
+        description="Logging level for Dask.",
+    )
     step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
 
     workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
@@ -27,6 +27,7 @@ from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontE
 from nat.front_ends.fastapi.main import get_app
 from nat.front_ends.fastapi.utils import get_class_name
 from nat.utils.io.yaml_tools import yaml_dump
+from nat.utils.log_levels import LOG_LEVELS
 
 if (typing.TYPE_CHECKING):
     from nat.data_models.config import Config
@@ -79,16 +80,23 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
         except: # noqa: E722
             logger.exception("Error during job cleanup")
 
-    async def _submit_cleanup_task(self, scheduler_address: str, db_url: str):
+    async def _submit_cleanup_task(self, scheduler_address: str, db_url: str, log_level: int = logging.INFO):
         """Submit a cleanup task to the cluster to remove the job after expiry."""
-        logger.info("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
+        logger.debug("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
         async with self.client(self._scheduler_address) as client:
             self._periodic_cleanup_future = client.submit(self._periodic_cleanup,
                                                           scheduler_address=self._scheduler_address,
                                                           db_url=db_url,
-                                                          log_level=logger.getEffectiveLevel())
+                                                          log_level=log_level)
 
-        logger.info("Submitted periodic cleanup task to Dask cluster at %s", scheduler_address)
+    @staticmethod
+    def _setup_worker():
+        """
+        Setup function to be run in each worker process. This moves each worker into its own process group.
+        This fixes an issue where a Ctrl-C in the terminal sends a SIGINT to all workers, which then causes the
+        workers to exit before the main process can shut down the cluster gracefully.
+        """
+        os.setsid()
 
     async def run(self):
 
@@ -102,15 +110,27 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
         # 1. Dask is installed and scheduler_address is None, we create a LocalCluster
         # 2. Dask is installed and scheduler_address is set, we use the existing cluster
         # 3. Dask is not installed, we skip the cluster setup
+        dask_log_level = LOG_LEVELS.get(self.front_end_config.dask_log_level.upper(), logging.WARNING)
+        dask_logger = logging.getLogger("distributed")
+        dask_logger.setLevel(dask_log_level)
+
         self._scheduler_address = self.front_end_config.scheduler_address
         if self._scheduler_address is None:
             try:
+
                 from dask.distributed import LocalCluster
 
-                self._cluster = LocalCluster(n_workers=self.front_end_config.max_running_async_jobs,
+                self._cluster = LocalCluster(processes=True,
+                                             silence_logs=dask_log_level,
+                                             n_workers=self.front_end_config.max_running_async_jobs,
                                              threads_per_worker=1)
 
                 self._scheduler_address = self._cluster.scheduler.address
+
+                with self.blocking_client(self._scheduler_address) as client:
+                    # Client.run submits a function to be run on each worker
+                    client.run(self._setup_worker)
+
                 logger.info("Created local Dask cluster with scheduler at %s", self._scheduler_address)
 
             except ImportError:
@@ -128,7 +148,9 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
 
         # If self.front_end_config.db_url is None, then we need to get the actual url from the engine
         db_url = str(db_engine.url)
-        await self._submit_cleanup_task(scheduler_address=self._scheduler_address, db_url=db_url)
+        await self._submit_cleanup_task(scheduler_address=self._scheduler_address,
+                                        db_url=db_url,
+                                        log_level=dask_log_level)
 
         # Set environment variables such that the worker subprocesses will know how to connect to dask and to
         # the database
@@ -216,8 +238,9 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
 
         if self._cluster is not None:
             # Only shut down the cluster if we created it
-            logger.info("Closing Local Dask cluster.")
+            logger.debug("Closing Local Dask cluster.")
            self._cluster.close()
+
         try:
             os.remove(config_file_name)
         except OSError as e:
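
The `_setup_worker`/`os.setsid()` change above is about signal routing rather than Dask itself. A standalone, POSIX-only sketch of the effect (not toolkit code): a child that calls os.setsid() leaves the terminal's foreground process group, so the SIGINT generated by Ctrl-C no longer reaches it.

import os
import time
from multiprocessing import Process


def detached_child():
    os.setsid()  # new session and process group; terminal-generated SIGINT is no longer delivered here
    time.sleep(60)


if __name__ == "__main__":
    p = Process(target=detached_child)
    p.start()
    # Pressing Ctrl-C now interrupts this parent, while the child keeps running until the
    # parent stops it explicitly, mirroring the graceful shutdown described in the docstring.
    p.join()
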
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OAuth 2.0 Token Introspection verifier implementation for MCP servers."""
+
+import logging
+
+from mcp.server.auth.provider import AccessToken
+from mcp.server.auth.provider import TokenVerifier
+
+from nat.authentication.credential_validator.bearer_token_validator import BearerTokenValidator
+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
+
+logger = logging.getLogger(__name__)
+
+
+class IntrospectionTokenVerifier(TokenVerifier):
+    """Token verifier that delegates token verification to BearerTokenValidator."""
+
+    def __init__(self, config: OAuth2ResourceServerConfig):
+        """Create IntrospectionTokenVerifier from OAuth2ResourceServerConfig.
+
+        Args:
+            config: OAuth2ResourceServerConfig
+        """
+        issuer = config.issuer_url
+        scopes = config.scopes or []
+        audience = config.audience
+        jwks_uri = config.jwks_uri
+        introspection_endpoint = config.introspection_endpoint
+        discovery_url = config.discovery_url
+        client_id = config.client_id
+        client_secret = config.client_secret
+
+        self._bearer_token_validator = BearerTokenValidator(
+            issuer=issuer,
+            audience=audience,
+            scopes=scopes,
+            jwks_uri=jwks_uri,
+            introspection_endpoint=introspection_endpoint,
+            discovery_url=discovery_url,
+            client_id=client_id,
+            client_secret=client_secret,
+        )
+
+    async def verify_token(self, token: str) -> AccessToken | None:
+        """Verify token by delegating to BearerTokenValidator.
+
+        Args:
+            token: The Bearer token to verify
+
+        Returns:
+            AccessToken | None: AccessToken if valid, None if invalid
+        """
+        validation_result = await self._bearer_token_validator.verify(token)
+
+        if validation_result.active:
+            return AccessToken(token=token,
+                               expires_at=validation_result.expires_at,
+                               scopes=validation_result.scopes or [],
+                               client_id=validation_result.client_id or "")
+        return None
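
A hedged end-to-end sketch of the new verifier. The OAuth2ResourceServerConfig field names mirror the attributes read in __init__ above; the endpoint values are placeholders, and since the config's own 124 lines are not shown in this diff, any additional required fields would need to be supplied as well.

import asyncio

from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier

config = OAuth2ResourceServerConfig(
    issuer_url="https://auth.example.com",                                # placeholder
    introspection_endpoint="https://auth.example.com/oauth2/introspect",  # placeholder
    client_id="mcp-server",
    client_secret="s3cr3t",                                               # placeholder
    scopes=["mcp.read"],
)
verifier = IntrospectionTokenVerifier(config)

# Returns an AccessToken when the validator reports the token active, otherwise None
print(asyncio.run(verifier.verify_token("opaque-bearer-token")))
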
@@ -17,13 +17,14 @@ from typing import Literal
 
 from pydantic import Field
 
+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.data_models.front_end import FrontEndBaseConfig
 
 
 class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
     """MCP front end configuration.
 
-    A simple MCP (Modular Communication Protocol) front end for NeMo Agent toolkit.
+    A simple MCP (Model Context Protocol) front end for NeMo Agent toolkit.
     """
 
     name: str = Field(default="NeMo Agent Toolkit MCP",
@@ -39,3 +40,6 @@ class MCPFrontEndConfig(FrontEndBaseConfig, name="mcp"):
         description="Transport type for the MCP server (default: streamable-http, backwards compatible with sse)")
     runner_class: str | None = Field(
         default=None, description="Custom worker class for handling MCP routes (default: built-in worker)")
+
+    server_auth: OAuth2ResourceServerConfig | None = Field(
+        default=None, description=("OAuth 2.0 Resource Server configuration for token verification."))
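
A sketch of enabling the new server_auth field programmatically; in a deployed workflow this would normally live in the YAML config, and the values below are placeholders.

from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig

front_end = MCPFrontEndConfig(
    name="Secured MCP Server",
    server_auth=OAuth2ResourceServerConfig(
        issuer_url="https://auth.example.com",  # placeholder issuer
        scopes=["mcp.read"],
    ),
)
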
@@ -16,6 +16,7 @@
 import logging
 import typing
 
+from nat.authentication.oauth2.oauth2_resource_server_config import OAuth2ResourceServerConfig
 from nat.builder.front_end import FrontEndBase
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.front_ends.mcp.mcp_front_end_config import MCPFrontEndConfig
@@ -55,25 +56,50 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
 
         return worker_class(self.full_config)
 
+    async def _create_token_verifier(self, token_verifier_config: OAuth2ResourceServerConfig):
+        """Create a token verifier based on configuration."""
+        from nat.front_ends.mcp.introspection_token_verifier import IntrospectionTokenVerifier
+
+        if not self.front_end_config.server_auth:
+            return None
+
+        return IntrospectionTokenVerifier(token_verifier_config)
+
     async def run(self) -> None:
         """Run the MCP server."""
         # Import FastMCP
         from mcp.server.fastmcp import FastMCP
 
-        # Create an MCP server with the configured parameters
-        mcp = FastMCP(
-            self.front_end_config.name,
-            host=self.front_end_config.host,
-            port=self.front_end_config.port,
-            debug=self.front_end_config.debug,
-            log_level=self.front_end_config.log_level,
-        )
-
-        # Get the worker instance and set up routes
-        worker = self._get_worker_instance()
+        # Create auth settings and token verifier if auth is required
+        auth_settings = None
+        token_verifier = None
 
         # Build the workflow and add routes using the worker
         async with WorkflowBuilder.from_config(config=self.full_config) as builder:
+
+            if self.front_end_config.server_auth:
+                from mcp.server.auth.settings import AuthSettings
+                from pydantic import AnyHttpUrl
+
+                server_url = f"http://{self.front_end_config.host}:{self.front_end_config.port}"
+
+                auth_settings = AuthSettings(issuer_url=AnyHttpUrl(self.front_end_config.server_auth.issuer_url),
+                                             required_scopes=self.front_end_config.server_auth.scopes,
+                                             resource_server_url=AnyHttpUrl(server_url))
+
+                token_verifier = await self._create_token_verifier(self.front_end_config.server_auth)
+
+            # Create an MCP server with the configured parameters
+            mcp = FastMCP(name=self.front_end_config.name,
+                          host=self.front_end_config.host,
+                          port=self.front_end_config.port,
+                          debug=self.front_end_config.debug,
+                          auth=auth_settings,
+                          token_verifier=token_verifier)
+
+            # Get the worker instance and set up routes
+            worker = self._get_worker_instance()
+
             # Add routes through the worker (includes health endpoint and function registration)
             await worker.add_routes(mcp, builder)
 
@@ -16,11 +16,15 @@
 import logging
 from abc import ABC
 from abc import abstractmethod
+from collections.abc import Mapping
+from typing import Any
 
 from mcp.server.fastmcp import FastMCP
+from starlette.exceptions import HTTPException
 from starlette.requests import Request
 
 from nat.builder.function import Function
+from nat.builder.function_base import FunctionBase
 from nat.builder.workflow import Workflow
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.config import Config
@@ -98,10 +102,110 @@ class MCPFrontEndPluginWorkerBase(ABC):
         for function_group in workflow.function_groups.values():
             functions.update(function_group.get_accessible_functions())
 
-        functions[workflow.config.workflow.type] = workflow
+        if workflow.config.workflow.workflow_alias:
+            functions[workflow.config.workflow.workflow_alias] = workflow
+        else:
+            functions[workflow.config.workflow.type] = workflow
 
         return functions
 
+    def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
+        """Set up HTTP debug endpoints for introspecting tools and schemas.
+
+        Exposes:
+        - GET /debug/tools/list: List tools. Optional query param `name` (one or more, repeatable or comma separated)
+          selects a subset and returns details for those tools.
+        """
+
+        @mcp.custom_route("/debug/tools/list", methods=["GET"])
+        async def list_tools(request: Request):
+            """HTTP list tools endpoint."""
+
+            from starlette.responses import JSONResponse
+
+            from nat.front_ends.mcp.tool_converter import get_function_description
+
+            # Query params
+            # Support repeated names and comma-separated lists
+            names_param_list = set(request.query_params.getlist("name"))
+            names: list[str] = []
+            for raw in names_param_list:
+                # if p.strip() is empty, it won't be included in the list!
+                parts = [p.strip() for p in raw.split(",") if p.strip()]
+                names.extend(parts)
+            detail_raw = request.query_params.get("detail")
+
+            def _parse_detail_param(detail_param: str | None, has_names: bool) -> bool:
+                if detail_param is None:
+                    if has_names:
+                        return True
+                    return False
+                v = detail_param.strip().lower()
+                if v in ("0", "false", "no", "off"):
+                    return False
+                if v in ("1", "true", "yes", "on"):
+                    return True
+                # For invalid values, default based on whether names are present
+                return has_names
+
+            # Helper function to build the input schema info
+            def _build_schema_info(fn: FunctionBase) -> dict[str, Any] | None:
+                schema = getattr(fn, "input_schema", None)
+                if schema is None:
+                    return None
+
+                # check if schema is a ChatRequest
+                schema_name = getattr(schema, "__name__", "")
+                schema_qualname = getattr(schema, "__qualname__", "")
+                if "ChatRequest" in schema_name or "ChatRequest" in schema_qualname:
+                    # Simplified interface used by MCP wrapper for ChatRequest
+                    return {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string", "description": "User query string"
+                            }
+                        },
+                        "required": ["query"],
+                        "title": "ChatRequestQuery",
+                    }
+
+                # Pydantic models provide model_json_schema
+                if schema is not None and hasattr(schema, "model_json_schema"):
+                    return schema.model_json_schema()
+
+                return None
+
+            def _build_final_json(functions_to_include: Mapping[str, FunctionBase],
+                                  include_schemas: bool = False) -> dict[str, Any]:
+                tools = []
+                for name, fn in functions_to_include.items():
+                    list_entry: dict[str, Any] = {
+                        "name": name, "description": get_function_description(fn), "is_workflow": hasattr(fn, "run")
+                    }
+                    if include_schemas:
+                        list_entry["schema"] = _build_schema_info(fn)
+                    tools.append(list_entry)
+
+                return {
+                    "count": len(tools),
+                    "tools": tools,
+                    "server_name": mcp.name,
+                }
+
+            if names:
+                # Return selected tools
+                try:
+                    functions_to_include = {n: functions[n] for n in names}
+                except KeyError as e:
+                    raise HTTPException(status_code=404, detail=f"Tool \"{e.args[0]}\" not found.") from e
+            else:
+                functions_to_include = functions
+
+            # Default for listing all: detail defaults to False unless explicitly set true
+            return JSONResponse(
+                _build_final_json(functions_to_include, _parse_detail_param(detail_raw, has_names=bool(names))))
+
 
 class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
     """Default MCP front end plugin worker implementation."""
@@ -142,3 +246,6 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
         # Add a simple fallback function if no functions were found
         if not functions:
             raise RuntimeError("No functions found in workflow. Please check your configuration.")
+
+        # After registration, expose debug endpoints for tool/schema inspection
+        self._setup_debug_endpoints(mcp, functions)
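
A hedged way to exercise the new debug endpoint from a client; httpx is already a dependency of this package (see the METADATA diff below), while the host, port, and tool names are placeholders. Per `_parse_detail_param` above, `detail` defaults to true whenever `name` is supplied.

import httpx

# List every tool (summaries only; detail defaults to false with no names)
all_tools = httpx.get("http://localhost:9901/debug/tools/list").json()

# Select tools by name; schemas are included because names are present
selected = httpx.get(
    "http://localhost:9901/debug/tools/list",
    params=[("name", "tool_a,tool_b")],  # repeated params or comma-separated lists both work
).json()
print(all_tools["count"], selected["tools"])
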
@@ -229,6 +229,9 @@ def get_function_description(function: FunctionBase) -> str:
     # Try to get anything that might be a description
     elif hasattr(config, "topic") and config.topic:
         function_description = config.topic
+    # Try to get description from the workflow config
+    elif hasattr(config, "workflow") and hasattr(config.workflow, "description") and config.workflow.description:
+        function_description = config.workflow.description
 
     elif isinstance(function, Function):
         function_description = function.description
@@ -15,6 +15,7 @@
 
 import inspect
 import logging
+import types
 from functools import lru_cache
 from typing import Any
 from typing import TypeVar
@@ -475,3 +476,21 @@ class TypeIntrospectionMixin:
             except ValidationError:
                 logger.warning("Item %s is not compatible with output type %s", item, self.output_type)
                 return False
+
+    @lru_cache
+    def extract_non_optional_type(self, type_obj: type | types.UnionType) -> Any:
+        """Extract the non-None type from Optional[T] or Union[T, None] types.
+
+        This is useful when you need to pass a type to a system that doesn't
+        understand Optional types (like registries that expect concrete types).
+
+        Args:
+            type_obj (type | types.UnionType): The type to extract from (could be Optional[T] or Union[T, None])
+
+        Returns:
+            Any: The actual type without None, or the original type if not a union with None
+        """
+        decomposed = DecomposedType(type_obj) # type: ignore[arg-type]
+        if decomposed.is_optional:
+            return decomposed.get_optional_type().type
+        return type_obj
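
DecomposedType is the toolkit's internal helper; as a standalone illustration of what extract_non_optional_type computes, the same result can be obtained with only the standard library:

import types
import typing


def strip_optional(tp):
    """Return T for Optional[T] / T | None; otherwise return tp unchanged."""
    if typing.get_origin(tp) in (typing.Union, types.UnionType):
        args = [a for a in typing.get_args(tp) if a is not type(None)]
        if len(args) == 1:
            return args[0]
    return tp


assert strip_optional(int | None) is int
assert strip_optional(typing.Optional[str]) is str
assert strip_optional(bytes) is bytes
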
@@ -28,7 +28,6 @@ from nat.eval.evaluate import EvaluationRun
 from nat.eval.evaluate import EvaluationRunConfig
 from nat.experimental.decorators.experimental_warning_decorator import experimental
 from nat.profiler.parameter_optimization.parameter_selection import pick_trial
-from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
 from nat.profiler.parameter_optimization.update_helpers import apply_suggestions
 
 logger = logging.getLogger(__name__)
@@ -133,6 +132,7 @@ def optimize_parameters(
 
     # Generate Pareto front visualizations
     try:
+        from nat.profiler.parameter_optimization.pareto_visualizer import create_pareto_visualization
         logger.info("Generating Pareto front visualizations...")
         create_pareto_visualization(
             data_source=study,
@@ -143,6 +143,10 @@ def optimize_parameters(
             show_plots=False # Don't show plots in automated runs
         )
         logger.info("Pareto visualizations saved to: %s", out_dir / "plots")
+    except ImportError as ie:
+        logger.warning("Could not import visualization dependencies: %s. "
+                       "Have you installed nvidia-nat-profiling?",
+                       ie)
     except Exception as e:
         logger.warning("Failed to generate visualizations: %s", e)
 
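
The fix above is the standard optional-dependency pattern: move the import inside the try block so a missing extra degrades to a logged warning instead of an import-time crash. A generic standalone sketch (module and function names hypothetical):

import logging

logger = logging.getLogger(__name__)


def maybe_plot(study) -> None:
    try:
        # Hypothetical optional module standing in for pareto_visualizer
        from my_project.plotting import make_plots
    except ImportError as ie:
        logger.warning("Plotting dependencies missing (%s); skipping visualizations.", ie)
        return
    make_plots(study)
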
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+# Define log level choices
+LOG_LEVELS = {
+    'DEBUG': logging.DEBUG,
+    'INFO': logging.INFO,
+    'WARNING': logging.WARNING,
+    'ERROR': logging.ERROR,
+    'CRITICAL': logging.CRITICAL
+}
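
How the new mapping is consumed by the FastAPI front end above: a case-insensitive config string is normalized to an int level, with WARNING as the fallback for unknown values.

import logging

from nat.utils.log_levels import LOG_LEVELS

configured = "info"  # e.g. FastApiFrontEndConfig.dask_log_level
level = LOG_LEVELS.get(configured.upper(), logging.WARNING)
logging.getLogger("distributed").setLevel(level)
assert level == logging.INFO
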
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nvidia-nat
-Version: 1.3.0a20250917
+Version: 1.3.0a20250923
 Summary: NVIDIA NeMo Agent toolkit
 Author: NVIDIA Corporation
 Maintainer: NVIDIA Corporation
@@ -225,6 +225,7 @@ Requires-Dist: httpx~=0.27
 Requires-Dist: jinja2~=3.1
 Requires-Dist: jsonpath-ng~=1.7
 Requires-Dist: mcp~=1.13
+Requires-Dist: nest-asyncio~=1.6
 Requires-Dist: networkx~=3.4
 Requires-Dist: numpy~=1.26; python_version < "3.12"
 Requires-Dist: numpy~=2.3; python_version >= "3.12"
@@ -299,6 +300,7 @@ Requires-Dist: nat_redact_pii; extra == "examples"
 Requires-Dist: nat_retail_sales_agent; extra == "examples"
 Requires-Dist: nat_router_agent; extra == "examples"
 Requires-Dist: nat_semantic_kernel_demo; extra == "examples"
+Requires-Dist: nat_sequential_executor; extra == "examples"
 Requires-Dist: nat_simple_auth; extra == "examples"
 Requires-Dist: nat_simple_web_query; extra == "examples"
 Requires-Dist: nat_simple_web_query_eval; extra == "examples"