nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/tool/mcp/exceptions.py DELETED
@@ -1,142 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from enum import Enum
17
-
18
-
19
- class MCPErrorCategory(str, Enum):
20
- """Categories of MCP errors for structured handling."""
21
- CONNECTION = "connection"
22
- TIMEOUT = "timeout"
23
- SSL = "ssl"
24
- AUTHENTICATION = "authentication"
25
- TOOL_NOT_FOUND = "tool_not_found"
26
- PROTOCOL = "protocol"
27
- UNKNOWN = "unknown"
28
-
29
-
30
- class MCPError(Exception):
31
- """Base exception for MCP-related errors."""
32
-
33
- def __init__(self,
34
- message: str,
35
- url: str,
36
- category: MCPErrorCategory = MCPErrorCategory.UNKNOWN,
37
- suggestions: list[str] | None = None,
38
- original_exception: Exception | None = None):
39
- super().__init__(message)
40
- self.url = url
41
- self.category = category
42
- self.suggestions = suggestions or []
43
- self.original_exception = original_exception
44
-
45
-
46
- class MCPConnectionError(MCPError):
47
- """Exception for MCP connection failures."""
48
-
49
- def __init__(self, url: str, original_exception: Exception | None = None):
50
- super().__init__(f"Unable to connect to MCP server at {url}",
51
- url=url,
52
- category=MCPErrorCategory.CONNECTION,
53
- suggestions=[
54
- "Please ensure the MCP server is running and accessible",
55
- "Check if the URL and port are correct"
56
- ],
57
- original_exception=original_exception)
58
-
59
-
60
- class MCPTimeoutError(MCPError):
61
- """Exception for MCP timeout errors."""
62
-
63
- def __init__(self, url: str, original_exception: Exception | None = None):
64
- super().__init__(f"Connection timed out to MCP server at {url}",
65
- url=url,
66
- category=MCPErrorCategory.TIMEOUT,
67
- suggestions=[
68
- "The server may be overloaded or network is slow",
69
- "Try again in a moment or check network connectivity"
70
- ],
71
- original_exception=original_exception)
72
-
73
-
74
- class MCPSSLError(MCPError):
75
- """Exception for MCP SSL/TLS errors."""
76
-
77
- def __init__(self, url: str, original_exception: Exception | None = None):
78
- super().__init__(f"SSL/TLS error connecting to {url}",
79
- url=url,
80
- category=MCPErrorCategory.SSL,
81
- suggestions=[
82
- "Check if the server requires HTTPS or has valid certificates",
83
- "Try using HTTP instead of HTTPS if appropriate"
84
- ],
85
- original_exception=original_exception)
86
-
87
-
88
- class MCPRequestError(MCPError):
89
- """Exception for MCP request errors."""
90
-
91
- def __init__(self, url: str, original_exception: Exception | None = None):
92
- message = f"Request failed to MCP server at {url}"
93
- if original_exception:
94
- message += f": {original_exception}"
95
-
96
- super().__init__(message,
97
- url=url,
98
- category=MCPErrorCategory.PROTOCOL,
99
- suggestions=["Check the server URL format and network settings"],
100
- original_exception=original_exception)
101
-
102
-
103
- class MCPToolNotFoundError(MCPError):
104
- """Exception for when a specific MCP tool is not found."""
105
-
106
- def __init__(self, tool_name: str, url: str, original_exception: Exception | None = None):
107
- super().__init__(f"Tool '{tool_name}' not available at {url}",
108
- url=url,
109
- category=MCPErrorCategory.TOOL_NOT_FOUND,
110
- suggestions=[
111
- "Use 'nat info mcp --detail' to see available tools",
112
- "Check that the tool name is spelled correctly"
113
- ],
114
- original_exception=original_exception)
115
-
116
-
117
- class MCPAuthenticationError(MCPError):
118
- """Exception for MCP authentication failures."""
119
-
120
- def __init__(self, url: str, original_exception: Exception | None = None):
121
- super().__init__(f"Authentication failed when connecting to MCP server at {url}",
122
- url=url,
123
- category=MCPErrorCategory.AUTHENTICATION,
124
- suggestions=[
125
- "Check if the server requires authentication credentials",
126
- "Verify that your credentials are correct and not expired"
127
- ],
128
- original_exception=original_exception)
129
-
130
-
131
- class MCPProtocolError(MCPError):
132
- """Exception for MCP protocol-related errors."""
133
-
134
- def __init__(self, url: str, message: str = "Protocol error", original_exception: Exception | None = None):
135
- super().__init__(f"{message} (MCP server at {url})",
136
- url=url,
137
- category=MCPErrorCategory.PROTOCOL,
138
- suggestions=[
139
- "Check that the MCP server is running and accessible at this URL",
140
- "Verify the server supports the expected MCP protocol version"
141
- ],
142
- original_exception=original_exception)
nat/tool/mcp/mcp_client.py DELETED
@@ -1,255 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import logging
19
- from contextlib import asynccontextmanager
20
- from enum import Enum
21
- from typing import Any
22
-
23
- from mcp import ClientSession
24
- from mcp.client.sse import sse_client
25
- from mcp.types import TextContent
26
- from pydantic import BaseModel
27
- from pydantic import Field
28
- from pydantic import create_model
29
-
30
- from nat.tool.mcp.exceptions import MCPToolNotFoundError
31
- from nat.utils.exception_handlers.mcp import mcp_exception_handler
32
-
33
- logger = logging.getLogger(__name__)
34
-
35
-
36
- def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
37
- """
38
- Create a pydantic model from the input schema of the MCP tool
39
- """
40
- _type_map = {
41
- "string": str,
42
- "number": float,
43
- "integer": int,
44
- "boolean": bool,
45
- "array": list,
46
- "null": None,
47
- "object": dict,
48
- }
49
-
50
- properties = mcp_input_schema.get("properties", {})
51
- required_fields = set(mcp_input_schema.get("required", []))
52
- schema_dict = {}
53
-
54
- def _generate_valid_classname(class_name: str):
55
- return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')
56
-
57
- def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
58
- json_type = field_properties.get("type", "string")
59
- enum_vals = field_properties.get("enum")
60
-
61
- if enum_vals:
62
- enum_name = f"{field_name.capitalize()}Enum"
63
- field_type = Enum(enum_name, {item: item for item in enum_vals})
64
-
65
- elif json_type == "object" and "properties" in field_properties:
66
- field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
67
- elif json_type == "array" and "items" in field_properties:
68
- item_properties = field_properties.get("items", {})
69
- if item_properties.get("type") == "object":
70
- item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
71
- else:
72
- item_type = _type_map.get(item_properties.get("type", "string"), Any)
73
- field_type = list[item_type]
74
- elif isinstance(json_type, list):
75
- field_type = None
76
- for t in json_type:
77
- mapped = _type_map.get(t, Any)
78
- field_type = mapped if field_type is None else field_type | mapped
79
-
80
- return field_type, Field(
81
- default=field_properties.get("default", None if "null" in json_type else ...),
82
- description=field_properties.get("description", "")
83
- )
84
- else:
85
- field_type = _type_map.get(json_type, Any)
86
-
87
- # Determine the default value based on whether the field is required
88
- if field_name in required_fields:
89
- # Field is required - use explicit default if provided, otherwise make it required
90
- default_value = field_properties.get("default", ...)
91
- else:
92
- # Field is optional - use explicit default if provided, otherwise None
93
- default_value = field_properties.get("default", None)
94
- # Make the type optional if no default was provided
95
- if "default" not in field_properties:
96
- field_type = field_type | None
97
-
98
- nullable = field_properties.get("nullable", False)
99
- description = field_properties.get("description", "")
100
-
101
- field_type = field_type | None if nullable else field_type
102
-
103
- return field_type, Field(default=default_value, description=description)
104
-
105
- for field_name, field_props in properties.items():
106
- schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
107
- return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
108
-
109
-
110
- class MCPSSEClient:
111
- """
112
- Client for creating a session and connecting to an MCP server using SSE
113
-
114
- Args:
115
- url (str): The url of the MCP server
116
- """
117
-
118
- def __init__(self, url: str):
119
- self.url = url
120
-
121
- @asynccontextmanager
122
- async def connect_to_sse_server(self):
123
- """
124
- Establish a session with an MCP SSE server within an aync context
125
- """
126
- async with sse_client(url=self.url) as (read, write):
127
- async with ClientSession(read, write) as session:
128
- await session.initialize()
129
- yield session
130
-
131
-
132
- class MCPBuilder(MCPSSEClient):
133
- """
134
- Builder class used to connect to an MCP Server and generate ToolClients
135
-
136
- Args:
137
- url (str): The url of the MCP server
138
- """
139
-
140
- def __init__(self, url):
141
- super().__init__(url)
142
- self._tools = None
143
-
144
- @mcp_exception_handler
145
- async def get_tools(self):
146
- """
147
- Retrieve a dictionary of all tools served by the MCP server.
148
-
149
- Returns:
150
- Dict of tool name to MCPToolClient
151
-
152
- Raises:
153
- MCPError: If connection or tool retrieval fails
154
- """
155
- async with self.connect_to_sse_server() as session:
156
- response = await session.list_tools()
157
-
158
- return {
159
- tool.name: MCPToolClient(self.url, tool.name, tool.description, tool_input_schema=tool.inputSchema)
160
- for tool in response.tools
161
- }
162
-
163
- @mcp_exception_handler
164
- async def get_tool(self, tool_name: str) -> MCPToolClient:
165
- """
166
- Get an MCP Tool by name.
167
-
168
- Args:
169
- tool_name (str): Name of the tool to load.
170
-
171
- Returns:
172
- MCPToolClient for the configured tool.
173
-
174
- Raises:
175
- MCPToolNotFoundError: If no tool is available with that name
176
- MCPError: If connection fails
177
- """
178
- if not self._tools:
179
- self._tools = await self.get_tools()
180
-
181
- tool = self._tools.get(tool_name)
182
- if not tool:
183
- raise MCPToolNotFoundError(tool_name, self.url)
184
- return tool
185
-
186
- @mcp_exception_handler
187
- async def call_tool(self, tool_name: str, tool_args: dict | None):
188
- async with self.connect_to_sse_server() as session:
189
- result = await session.call_tool(tool_name, tool_args)
190
- return result
191
-
192
-
193
- class MCPToolClient(MCPSSEClient):
194
- """
195
- Client wrapper used to call an MCP tool.
196
-
197
- Args:
198
- url (str): The url of the MCP server
199
- tool_name (str): The name of the tool to wrap
200
- tool_description (str): The description of the tool provided by the MCP server.
201
- tool_input_schema (dict): The input schema for the tool.
202
- """
203
-
204
- def __init__(self, url: str, tool_name: str, tool_description: str | None, tool_input_schema: dict | None = None):
205
- super().__init__(url)
206
- self._tool_name = tool_name
207
- self._tool_description = tool_description
208
- self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None
209
-
210
- @property
211
- def name(self):
212
- """Returns the name of the tool."""
213
- return self._tool_name
214
-
215
- @property
216
- def description(self):
217
- """
218
- Returns the tool's description. If none was provided. Provides a simple description using the tool's name
219
- """
220
- if not self._tool_description:
221
- return f"MCP Tool {self._tool_name}"
222
- return self._tool_description
223
-
224
- @property
225
- def input_schema(self):
226
- """
227
- Returns the tool's input_schema.
228
- """
229
- return self._input_schema
230
-
231
- def set_description(self, description: str):
232
- """
233
- Manually define the tool's description using the provided string.
234
- """
235
- self._tool_description = description
236
-
237
- @mcp_exception_handler
238
- async def acall(self, tool_args: dict) -> str:
239
- """
240
- Call the MCP tool with the provided arguments.
241
-
242
- Args:
243
- tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
244
- """
245
- async with self.connect_to_sse_server() as session:
246
- result = await session.call_tool(self._tool_name, tool_args)
247
-
248
- output = []
249
- for res in result.content:
250
- if isinstance(res, TextContent):
251
- output.append(res.text)
252
- else:
253
- # Log non-text content for now
254
- logger.warning("Got not-text output from %s of type %s", self.name, type(res))
255
- return "\n".join(output)
nat/tool/mcp/mcp_tool.py DELETED
@@ -1,96 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import logging
17
-
18
- from pydantic import BaseModel
19
- from pydantic import Field
20
- from pydantic import HttpUrl
21
-
22
- from nat.builder.builder import Builder
23
- from nat.builder.function_info import FunctionInfo
24
- from nat.cli.register_workflow import register_function
25
- from nat.data_models.function import FunctionBaseConfig
26
-
27
- logger = logging.getLogger(__name__)
28
-
29
-
30
- class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
31
- """
32
- Function which connects to a Model Context Protocol (MCP) server and wraps the selected tool as a NeMo Agent toolkit
33
- function.
34
- """
35
- # Add your custom configuration parameters here
36
- url: HttpUrl = Field(description="The URL of the MCP server")
37
- mcp_tool_name: str = Field(description="The name of the tool served by the MCP Server that you want to use")
38
- description: str | None = Field(default=None,
39
- description="""
40
- Description for the tool that will override the description provided by the MCP server. Should only be used if
41
- the description provided by the server is poor or nonexistent
42
- """)
43
- return_exception: bool = Field(default=True,
44
- description="""
45
- If true, the tool will return the exception message if the tool call fails.
46
- If false, raise the exception.
47
- """)
48
-
49
-
50
- @register_function(config_type=MCPToolConfig)
51
- async def mcp_tool(config: MCPToolConfig, builder: Builder): # pylint: disable=unused-argument
52
- """
53
- Generate a NAT Function that wraps a tool provided by the MCP server.
54
- """
55
-
56
- from nat.tool.mcp.mcp_client import MCPBuilder
57
- from nat.tool.mcp.mcp_client import MCPToolClient
58
-
59
- client = MCPBuilder(url=str(config.url))
60
-
61
- tool: MCPToolClient = await client.get_tool(config.mcp_tool_name)
62
- if config.description:
63
- tool.set_description(description=config.description)
64
-
65
- logger.info("Configured to use tool: %s from MCP server at %s", tool.name, str(config.url))
66
-
67
- def _convert_from_str(input_str: str) -> tool.input_schema:
68
- return tool.input_schema.model_validate_json(input_str)
69
-
70
- async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
71
- # Run the tool, catching any errors and sending to agent for correction
72
- try:
73
- if tool_input:
74
- args = tool_input.model_dump()
75
- return await tool.acall(args)
76
-
77
- _ = tool.input_schema.model_validate(kwargs)
78
- filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
79
- return await tool.acall(filtered_kwargs)
80
- except Exception as e:
81
- if config.return_exception:
82
- if tool_input:
83
- logger.warning("Error calling tool %s with serialized input: %s",
84
- tool.name,
85
- tool_input.model_dump(),
86
- exc_info=True)
87
- else:
88
- logger.warning("Error calling tool %s with input: %s", tool.name, kwargs, exc_info=True)
89
- return str(e)
90
- # If the tool call fails, raise the exception.
91
- raise
92
-
93
- yield FunctionInfo.create(single_fn=_response_fn,
94
- description=tool.description,
95
- input_schema=tool.input_schema,
96
- converters=[_convert_from_str])