nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,90 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ from pathlib import Path
19
+
20
+ import click
21
+
22
+ from nat.data_models.optimizer import OptimizerRunConfig
23
+ from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
@click.group(
    name=__name__,
    invoke_without_command=True,
    help="Optimize a workflow with the specified dataset.",
)
@click.option("--config_file",
              type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
              required=True,
              help="A JSON/YAML file that sets the parameters for the workflow and evaluation.")
@click.option("--dataset",
              type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
              required=False,
              help="A json file with questions and ground truth answers. "
              "This will override the dataset path in the config file.")
@click.option("--result_json_path",
              type=str,
              default="$",
              help="A JSON path to extract the result from the workflow. "
              "Use this when the workflow returns multiple objects or a dictionary. "
              "For example, '$.output' will extract the 'output' field from the result.")
@click.option("--endpoint",
              type=str,
              default=None,
              help="Use endpoint for running the workflow. Example: http://localhost:8000/generate")
@click.option("--endpoint_timeout",
              type=int,
              default=300,
              help="HTTP response timeout in seconds. Only relevant if endpoint is specified.")
@click.pass_context
def optimizer_command(ctx, **kwargs) -> None:
    """Click group that collects the options for an optimizer run.

    The group is declared with ``invoke_without_command=True`` so it can be
    invoked without a subcommand. The body is intentionally empty: the options
    parsed here are consumed by the result callback registered on this group,
    not by this function.

    Args:
        ctx: The Click context injected by ``click.pass_context`` (unused).
        **kwargs: The parsed option values (unused here).
    """
65
+
66
+
67
async def run_optimizer(config: OptimizerRunConfig):
    """Await ``optimize_config`` for a single optimizer run.

    Args:
        config: The run configuration, passed to ``optimize_config`` unchanged.
    """
    await optimize_config(config)
69
+
70
+
71
@optimizer_command.result_callback(replace=True)
def run_optimizer_callback(
    processors,  # pylint: disable=unused-argument
    *,
    config_file: Path,
    dataset: Path | None,
    result_json_path: str,
    endpoint: str | None,
    endpoint_timeout: int,
) -> None:
    """Build an ``OptimizerRunConfig`` from the CLI options and run the optimizer.

    Registered as the result callback of ``optimizer_command``; the keyword
    arguments mirror that group's options one-to-one.

    Args:
        processors: Positional value supplied by Click to result callbacks (unused).
        config_file: Path given via ``--config_file`` (required option).
        dataset: Path given via ``--dataset``; ``None`` when the option is
            omitted, since it is declared ``required=False`` with no default.
        result_json_path: Value of ``--result_json_path`` (defaults to ``"$"``).
        endpoint: Value of ``--endpoint``; ``None`` when omitted (the declared default).
        endpoint_timeout: Value of ``--endpoint_timeout`` in seconds (defaults to 300).
    """
    config = OptimizerRunConfig(
        config_file=config_file,
        dataset=dataset,
        result_json_path=result_json_path,
        endpoint=endpoint,
        endpoint_timeout=endpoint_timeout,
    )

    # The CLI entry point is synchronous, so drive the coroutine to completion here.
    asyncio.run(run_optimizer(config))
nat/cli/commands/start.py CHANGED
@@ -111,7 +111,7 @@ class StartCommandGroup(click.Group):
111
111
  elif (issubclass(decomposed_type.root, Path)):
112
112
  param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
113
113
 
114
- elif (issubclass(decomposed_type.root, (list, tuple, set))):
114
+ elif (issubclass(decomposed_type.root, list | tuple | set)):
115
115
  if (len(decomposed_type.args) == 1):
116
116
  inner = DecomposedType(decomposed_type.args[0])
117
117
  # Support containers of Literal values -> multiple Choice
@@ -1,16 +1,17 @@
1
- general:
2
- use_uvloop: true
3
- logging:
4
- console:
5
- _type: console
6
- level: WARN
1
+ functions:
2
+ current_datetime:
3
+ _type: current_datetime
4
+ {{python_safe_workflow_name}}:
5
+ _type: {{python_safe_workflow_name}}
6
+ prefix: "Hello:"
7
7
 
8
- front_end:
9
- _type: fastapi
10
-
11
- front_end:
12
- _type: console
8
+ llms:
9
+ nim_llm:
10
+ _type: nim
11
+ model_name: meta/llama-3.1-70b-instruct
12
+ temperature: 0.0
13
13
 
14
14
  workflow:
15
- _type: {{workflow_name}}
16
- parameter: default_value
15
+ _type: react_agent
16
+ llm_name: nim_llm
17
+ tool_names: [current_datetime, {{python_safe_workflow_name}}]
@@ -1,4 +1,4 @@
1
1
  # flake8: noqa
2
2
 
3
- # Import any tools which need to be automatically registered here
4
- from {{package_name}} import {{workflow_name}}_function
3
+ # Import the generated workflow function to trigger registration
4
+ from .{{package_name}} import {{ python_safe_workflow_name }}_function
@@ -3,6 +3,7 @@ import logging
3
3
  from pydantic import Field
4
4
 
5
5
  from nat.builder.builder import Builder
6
+ from nat.builder.framework_enum import LLMFrameworkEnum
6
7
  from nat.builder.function_info import FunctionInfo
7
8
  from nat.cli.register_workflow import register_function
8
9
  from nat.data_models.function import FunctionBaseConfig
@@ -12,25 +13,38 @@ logger = logging.getLogger(__name__)
12
13
 
13
14
  class {{ workflow_class_name }}(FunctionBaseConfig, name="{{ workflow_name }}"):
14
15
  """
15
- {{workflow_description}}
16
+ {{ workflow_description }}
16
17
  """
17
- # Add your custom configuration parameters here
18
- parameter: str = Field(default="default_value", description="Notional description for this parameter")
19
-
20
-
21
- @register_function(config_type={{ workflow_class_name }})
22
- async def {{ python_safe_workflow_name }}_function(
23
- config: {{ workflow_class_name }}, builder: Builder
24
- ):
25
- # Implement your function logic here
26
- async def _response_fn(input_message: str) -> str:
27
- # Process the input_message and generate output
28
- output_message = f"Hello from {{ workflow_name }} workflow! You said: {input_message}"
29
- return output_message
30
-
31
- try:
32
- yield FunctionInfo.create(single_fn=_response_fn)
33
- except GeneratorExit:
34
- logger.warning("Function exited early!")
35
- finally:
36
- logger.info("Cleaning up {{ workflow_name }} workflow.")
18
+ prefix: str = Field(default="Echo:", description="Prefix to add before the echoed text.")
19
+
20
+
21
+ @register_function(config_type={{ workflow_class_name }}, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
22
+ async def {{ python_safe_workflow_name }}_function(config: {{ workflow_class_name }}, builder: Builder):
23
+ """
24
+ Registers a function (addressable via `{{ workflow_name }}` in the configuration).
25
+ This registration ensures a static mapping of the function type, `{{ workflow_name }}`, to the `{{ workflow_class_name }}` configuration object.
26
+
27
+ Args:
28
+ config ({{ workflow_class_name }}): The configuration for the function.
29
+ builder (Builder): The builder object.
30
+
31
+ Returns:
32
+ FunctionInfo: The function info object for the function.
33
+ """
34
+
35
+ # Define the function that will be registered.
36
+ async def _echo(text: str) -> str:
37
+ """
38
+ Takes a text input and echoes back with a pre-defined prefix.
39
+
40
+ Args:
41
+ text (str): The text to echo back.
42
+
43
+ Returns:
44
+ str: The text with the prefix.
45
+ """
46
+ return f"{config.prefix} {text}"
47
+
48
+ # The callable is wrapped in a FunctionInfo object.
49
+ # The description parameter is used to describe the function.
50
+ yield FunctionInfo.from_fn(_echo, description=_echo.__doc__)
@@ -27,6 +27,50 @@ from jinja2 import FileSystemLoader
27
27
  logger = logging.getLogger(__name__)
28
28
 
29
29
 
30
+ def _get_nat_version() -> str | None:
31
+ """
32
+ Get the current NAT version.
33
+
34
+ Returns:
35
+ str: The NAT version intended for use in a dependency string.
36
+ None: If the NAT version is not found.
37
+ """
38
+ from nat.cli.entrypoint import get_version
39
+
40
+ current_version = get_version()
41
+ if current_version == "unknown":
42
+ return None
43
+
44
+ version_parts = current_version.split(".")
45
+ if len(version_parts) < 3:
46
+ # If the version somehow doesn't have three parts, return the full version
47
+ return current_version
48
+
49
+ patch = version_parts[2]
50
+ try:
51
+ # If the patch is a number, keep only the major and minor parts
52
+ # Useful for stable releases and adheres to semantic versioning
53
+ _ = int(patch)
54
+ digits_to_keep = 2
55
+ except ValueError:
56
+ # If the patch is not a number, keep all three digits
57
+ # Useful for pre-release versions (and nightly builds)
58
+ digits_to_keep = 3
59
+
60
+ return ".".join(version_parts[:digits_to_keep])
61
+
62
+
63
def _is_nat_version_prerelease() -> bool:
    """
    Check if the NAT version is a prerelease.

    Returns:
        bool: ``True`` when ``_get_nat_version()`` yields a version with at least three
            dot-separated components, ``False`` when it yields fewer or ``None``.
    """
    nat_version = _get_nat_version()
    # Two or more dots means three or more components. _get_nat_version() trims a
    # version with a numeric patch down to major.minor, so a third component is
    # taken as the sign of a pre-release.
    return nat_version is not None and nat_version.count(".") >= 2
72
+
73
+
30
74
  def _get_nat_dependency(versioned: bool = True) -> str:
31
75
  """
32
76
  Get the NAT dependency string with version.
@@ -44,16 +88,12 @@ def _get_nat_dependency(versioned: bool = True) -> str:
44
88
  logger.debug("Using unversioned NAT dependency: %s", dependency)
45
89
  return dependency
46
90
 
47
- # Get the current NAT version
48
- from nat.cli.entrypoint import get_version
49
- current_version = get_version()
50
- if current_version == "unknown":
51
- logger.warning("Could not detect NAT version, using unversioned dependency")
91
+ version = _get_nat_version()
92
+ if version is None:
93
+ logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
52
94
  return dependency
53
95
 
54
- # Extract major.minor (e.g., "1.2.3" -> "1.2")
55
- major_minor = ".".join(current_version.split(".")[:2])
56
- dependency += f"~={major_minor}"
96
+ dependency += f"~={version}"
57
97
  logger.debug("Using NAT dependency: %s", dependency)
58
98
  return dependency
59
99
 
@@ -171,6 +211,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
171
211
  workflow_dir (str): The directory to create the workflow package.
172
212
  description (str): Description to pre-popluate the workflow docstring.
173
213
  """
214
+ # Fail fast with Click's standard exit code (2) for bad params.
215
+ if not workflow_name or not workflow_name.strip():
216
+ raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
174
217
  try:
175
218
  # Get the repository root
176
219
  try:
@@ -216,23 +259,25 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
216
259
  install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
217
260
  else:
218
261
  install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
262
+ if _is_nat_version_prerelease():
263
+ install_cmd.insert(2, "--pre")
219
264
 
220
- config_source = configs_dir / 'config.yml'
265
+ python_safe_workflow_name = workflow_name.replace("-", "_")
221
266
 
222
267
  # List of templates and their destinations
223
268
  files_to_render = {
224
269
  'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
225
270
  'register.py.j2': base_dir / 'register.py',
226
- 'workflow.py.j2': base_dir / f'{workflow_name}_function.py',
271
+ 'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
227
272
  '__init__.py.j2': base_dir / '__init__.py',
228
- 'config.yml.j2': config_source,
273
+ 'config.yml.j2': configs_dir / 'config.yml',
229
274
  }
230
275
 
231
276
  # Render templates
232
277
  context = {
233
278
  'editable': editable,
234
279
  'workflow_name': workflow_name,
235
- 'python_safe_workflow_name': workflow_name.replace("-", "_"),
280
+ 'python_safe_workflow_name': python_safe_workflow_name,
236
281
  'package_name': package_name,
237
282
  'rel_path_to_repo_root': rel_path_to_repo_root,
238
283
  'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
@@ -246,10 +291,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
246
291
  with open(output_path, 'w', encoding="utf-8") as f:
247
292
  f.write(content)
248
293
 
249
- # Create symlink for config.yml
250
- config_link = new_workflow_dir / 'configs' / 'config.yml'
251
- os.symlink(config_source, config_link)
252
-
253
294
  # Create symlinks for config and data directories
254
295
  config_dir_source = configs_dir
255
296
  config_dir_link = new_workflow_dir / 'configs'
@@ -313,7 +354,8 @@ def reinstall_command(workflow_name):
313
354
 
314
355
  @click.command()
315
356
  @click.argument('workflow_name')
316
- def delete_command(workflow_name: str):
357
+ @click.option('-y', '--yes', "yes_flag", is_flag=True, default=False, help='Do not prompt for confirmation.')
358
+ def delete_command(workflow_name: str, yes_flag: bool):
317
359
  """
318
360
  Delete a NAT workflow and uninstall its package.
319
361
 
@@ -321,7 +363,7 @@ def delete_command(workflow_name: str):
321
363
  workflow_name (str): The name of the workflow to delete.
322
364
  """
323
365
  try:
324
- if not click.confirm(f"Are you sure you want to delete the workflow '{workflow_name}'?"):
366
+ if not yes_flag and not click.confirm(f"Are you sure you want to delete the workflow '{workflow_name}'?"):
325
367
  click.echo("Workflow deletion cancelled.")
326
368
  return
327
369
  editable = get_repo_root() is not None
nat/cli/entrypoint.py CHANGED
@@ -29,11 +29,16 @@ import time
29
29
 
30
30
  import click
31
31
  import nest_asyncio
32
+ from dotenv import load_dotenv
33
+
34
+ from nat.utils.log_levels import LOG_LEVELS
32
35
 
33
36
  from .commands.configure.configure import configure_command
34
37
  from .commands.evaluate import eval_command
35
38
  from .commands.info.info import info_command
39
+ from .commands.mcp.mcp import mcp_command
36
40
  from .commands.object_store.object_store import object_store_command
41
+ from .commands.optimize import optimizer_command
37
42
  from .commands.registry.registry import registry_command
38
43
  from .commands.sizing.sizing import sizing
39
44
  from .commands.start import start_command
@@ -41,23 +46,21 @@ from .commands.uninstall import uninstall_command
41
46
  from .commands.validate import validate_command
42
47
  from .commands.workflow.workflow import workflow_command
43
48
 
49
+ # Load environment variables from .env file, if it exists
50
+ load_dotenv()
51
+
44
52
  # Apply at the beginning of the file to avoid issues with asyncio
45
53
  nest_asyncio.apply()
46
54
 
47
- # Define log level choices
48
- LOG_LEVELS = {
49
- 'DEBUG': logging.DEBUG,
50
- 'INFO': logging.INFO,
51
- 'WARNING': logging.WARNING,
52
- 'ERROR': logging.ERROR,
53
- 'CRITICAL': logging.CRITICAL
54
- }
55
-
56
55
 
57
56
  def setup_logging(log_level: str):
58
57
  """Configure logging with the specified level"""
59
58
  numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
60
- logging.basicConfig(level=numeric_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
59
+ logging.basicConfig(
60
+ level=numeric_level,
61
+ format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
62
+ datefmt="%Y-%m-%d %H:%M:%S",
63
+ )
61
64
  return numeric_level
62
65
 
63
66
 
@@ -108,12 +111,13 @@ cli.add_command(uninstall_command, name="uninstall")
108
111
  cli.add_command(validate_command, name="validate")
109
112
  cli.add_command(workflow_command, name="workflow")
110
113
  cli.add_command(sizing, name="sizing")
114
+ cli.add_command(optimizer_command, name="optimize")
111
115
  cli.add_command(object_store_command, name="object-store")
116
+ cli.add_command(mcp_command, name="mcp")
112
117
 
113
118
  # Aliases
114
119
  cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
115
120
  cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
116
- cli.add_command(start_command.get_command(None, "mcp"), name="mcp") # type: ignore
117
121
 
118
122
 
119
123
  @cli.result_callback()
nat/cli/main.py CHANGED
@@ -30,6 +30,9 @@ def run_cli():
30
30
  import os
31
31
  import sys
32
32
 
33
+ # Suppress warnings from transformers
34
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
35
+
33
36
  parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
34
37
 
35
38
  if (parent_dir not in sys.path):
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
27
27
  from nat.cli.type_registry import FrontEndBuildCallableT
28
28
  from nat.cli.type_registry import FrontEndRegisteredCallableT
29
29
  from nat.cli.type_registry import FunctionBuildCallableT
30
+ from nat.cli.type_registry import FunctionGroupBuildCallableT
31
+ from nat.cli.type_registry import FunctionGroupRegisteredCallableT
30
32
  from nat.cli.type_registry import FunctionRegisteredCallableT
31
33
  from nat.cli.type_registry import LLMClientBuildCallableT
32
34
  from nat.cli.type_registry import LLMClientRegisteredCallableT
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
60
62
  from nat.data_models.evaluator import EvaluatorBaseConfigT
61
63
  from nat.data_models.front_end import FrontEndConfigT
62
64
  from nat.data_models.function import FunctionConfigT
65
+ from nat.data_models.function import FunctionGroupConfigT
63
66
  from nat.data_models.llm import LLMBaseConfigT
64
67
  from nat.data_models.memory import MemoryBaseConfigT
65
68
  from nat.data_models.object_store import ObjectStoreBaseConfigT
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
155
158
 
156
159
  context_manager_fn = asynccontextmanager(fn)
157
160
 
158
- if framework_wrappers is None:
159
- framework_wrappers_list: list[str] = []
160
- else:
161
- framework_wrappers_list = list(framework_wrappers)
161
+ framework_wrappers_list = list(framework_wrappers or [])
162
162
 
163
163
  discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
164
164
  component_type=ComponentEnum.FUNCTION)
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
177
177
  return register_function_inner
178
178
 
179
179
 
180
def register_function_group(config_type: type[FunctionGroupConfigT],
                            framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
    """
    Decorator factory that registers a function-group build callable for ``config_type``.

    Function groups share configuration/resources across multiple functions.

    Args:
        config_type: The function-group configuration class the build callable handles.
        framework_wrappers: Optional frameworks to record with the registration for
            automatic profiler hooking. ``None`` is stored as an empty list.

    Returns:
        A decorator that wraps the build callable with ``asynccontextmanager``, records it
        in the global type registry, and returns the wrapped context-manager factory.
    """

    def register_function_group_inner(
        fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
    ) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
        # Imported here rather than at module level, as in the original code.
        from .type_registry import GlobalTypeRegistry
        from .type_registry import RegisteredFunctionGroupInfo

        build_cm = asynccontextmanager(fn)

        # Copy into a fresh list so the registration does not alias the caller's list.
        wrappers = list(framework_wrappers or [])

        metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
                                                      component_type=ComponentEnum.FUNCTION_GROUP)

        registry = GlobalTypeRegistry.get()
        group_info = RegisteredFunctionGroupInfo(
            full_type=config_type.full_type,
            config_type=config_type,
            build_fn=build_cm,
            framework_wrappers=wrappers,
            discovery_metadata=metadata,
        )
        registry.register_function_group(group_info)

        return build_cm

    return register_function_group_inner
212
+
213
+
180
214
  def register_llm_provider(config_type: type[LLMBaseConfigT]):
181
215
 
182
216
  def register_llm_provider_inner(
nat/cli/type_registry.py CHANGED
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
37
37
  from nat.builder.evaluator import EvaluatorInfo
38
38
  from nat.builder.front_end import FrontEndBase
39
39
  from nat.builder.function import Function
40
+ from nat.builder.function import FunctionGroup
40
41
  from nat.builder.function_base import FunctionBase
41
42
  from nat.builder.function_info import FunctionInfo
42
43
  from nat.builder.llm import LLMProviderInfo
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
55
56
  from nat.data_models.front_end import FrontEndConfigT
56
57
  from nat.data_models.function import FunctionBaseConfig
57
58
  from nat.data_models.function import FunctionConfigT
59
+ from nat.data_models.function import FunctionGroupBaseConfig
60
+ from nat.data_models.function import FunctionGroupConfigT
58
61
  from nat.data_models.llm import LLMBaseConfig
59
62
  from nat.data_models.llm import LLMBaseConfigT
60
63
  from nat.data_models.logging import LoggingBaseConfig
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
85
88
  EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
86
89
  FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
87
90
  FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
91
+ FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
88
92
  TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
89
93
  LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
90
94
  LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
106
110
  FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
107
111
  FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
108
112
  AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
113
+ FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
109
114
  TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
110
115
  LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
111
116
  LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
178
183
  framework_wrappers: list[str] = Field(default_factory=list)
179
184
 
180
185
 
186
+ class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
187
+ """
188
+ Represents a registered function group. Function groups are collections of functions that share configuration
189
+ and resources.
190
+ """
191
+
192
+ build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
193
+ framework_wrappers: list[str] = Field(default_factory=list)
194
+
195
+
181
196
  class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
182
197
  """
183
198
  Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
@@ -313,6 +328,9 @@ class TypeRegistry:
313
328
  # Functions
314
329
  self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
315
330
 
331
+ # Function Groups
332
+ self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
333
+
316
334
  # LLMs
317
335
  self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
318
336
  self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
@@ -478,6 +496,50 @@ class TypeRegistry:
478
496
 
479
497
  return list(self._registered_functions.values())
480
498
 
499
+ def register_function_group(self, registration: RegisteredFunctionGroupInfo):
500
+ """Register a function group with the type registry.
501
+
502
+ Args:
503
+ registration: The function group registration information
504
+
505
+ Raises:
506
+ ValueError: If a function group with the same config type is already registered
507
+ """
508
+ if (registration.config_type in self._registered_function_groups):
509
+ raise ValueError(
510
+ f"A function group with the same config type `{registration.config_type}` has already been "
511
+ "registered.")
512
+
513
+ self._registered_function_groups[registration.config_type] = registration
514
+
515
+ self._registration_changed()
516
+
517
+ def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
518
+ """Get a registered function group by its config type.
519
+
520
+ Args:
521
+ config_type: The function group configuration type
522
+
523
+ Returns:
524
+ RegisteredFunctionGroupInfo: The registered function group information
525
+
526
+ Raises:
527
+ KeyError: If no function group is registered for the given config type
528
+ """
529
+ try:
530
+ return self._registered_function_groups[config_type]
531
+ except KeyError as err:
532
+ raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
533
+ f"Registered configs: {set(self._registered_function_groups.keys())}") from err
534
+
535
+ def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
536
+ """Get all registered function groups.
537
+
538
+ Returns:
539
+ list[RegisteredInfo[FunctionGroupBaseConfig]]: List of all registered function groups
540
+ """
541
+ return list(self._registered_function_groups.values())
542
+
481
543
  def register_llm_provider(self, info: RegisteredLLMProviderInfo):
482
544
 
483
545
  if (info.config_type in self._registered_llm_provider_infos):
@@ -790,6 +852,9 @@ class TypeRegistry:
790
852
  if component_type == ComponentEnum.FUNCTION:
791
853
  return self._registered_functions
792
854
 
855
+ if component_type == ComponentEnum.FUNCTION_GROUP:
856
+ return self._registered_function_groups
857
+
793
858
  if component_type == ComponentEnum.TOOL_WRAPPER:
794
859
  return self._registered_tool_wrappers
795
860
 
@@ -854,6 +919,9 @@ class TypeRegistry:
854
919
  if component_type == ComponentEnum.FUNCTION:
855
920
  return [i.static_type() for i in self._registered_functions]
856
921
 
922
+ if component_type == ComponentEnum.FUNCTION_GROUP:
923
+ return [i.static_type() for i in self._registered_function_groups]
924
+
857
925
  if component_type == ComponentEnum.TOOL_WRAPPER:
858
926
  return list(self._registered_tool_wrappers)
859
927
 
@@ -924,7 +992,7 @@ class TypeRegistry:
924
992
  if (short_names[key.local_name] == 1):
925
993
  type_list.append((key.local_name, key.config_type))
926
994
 
927
- return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
995
+ return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
928
996
 
929
997
  def compute_annotation(self, cls: type[TypedBaseModelT]):
930
998
 
@@ -943,6 +1011,9 @@ class TypeRegistry:
943
1011
  if issubclass(cls, FunctionBaseConfig):
944
1012
  return self._do_compute_annotation(cls, self.get_registered_functions())
945
1013
 
1014
+ if issubclass(cls, FunctionGroupBaseConfig):
1015
+ return self._do_compute_annotation(cls, self.get_registered_function_groups())
1016
+
946
1017
  if issubclass(cls, LLMBaseConfig):
947
1018
  return self._do_compute_annotation(cls, self.get_registered_llm_providers())
948
1019
 
File without changes
@@ -0,0 +1,20 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # flake8: noqa
17
+
18
+ # Import any control flows which need to be automatically registered here
19
+ from . import sequential_executor
20
+ from .router_agent import register
File without changes