nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -27,6 +27,50 @@ from jinja2 import FileSystemLoader
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
def _get_nat_version() -> str | None:
|
|
31
|
+
"""
|
|
32
|
+
Get the current NAT version.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
str: The NAT version intended for use in a dependency string.
|
|
36
|
+
None: If the NAT version is not found.
|
|
37
|
+
"""
|
|
38
|
+
from nat.cli.entrypoint import get_version
|
|
39
|
+
|
|
40
|
+
current_version = get_version()
|
|
41
|
+
if current_version == "unknown":
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
version_parts = current_version.split(".")
|
|
45
|
+
if len(version_parts) < 3:
|
|
46
|
+
# If the version somehow doesn't have three parts, return the full version
|
|
47
|
+
return current_version
|
|
48
|
+
|
|
49
|
+
patch = version_parts[2]
|
|
50
|
+
try:
|
|
51
|
+
# If the patch is a number, keep only the major and minor parts
|
|
52
|
+
# Useful for stable releases and adheres to semantic versioning
|
|
53
|
+
_ = int(patch)
|
|
54
|
+
digits_to_keep = 2
|
|
55
|
+
except ValueError:
|
|
56
|
+
# If the patch is not a number, keep all three digits
|
|
57
|
+
# Useful for pre-release versions (and nightly builds)
|
|
58
|
+
digits_to_keep = 3
|
|
59
|
+
|
|
60
|
+
return ".".join(version_parts[:digits_to_keep])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _is_nat_version_prerelease() -> bool:
|
|
64
|
+
"""
|
|
65
|
+
Check if the NAT version is a prerelease.
|
|
66
|
+
"""
|
|
67
|
+
version = _get_nat_version()
|
|
68
|
+
if version is None:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
return len(version.split(".")) >= 3
|
|
72
|
+
|
|
73
|
+
|
|
30
74
|
def _get_nat_dependency(versioned: bool = True) -> str:
|
|
31
75
|
"""
|
|
32
76
|
Get the NAT dependency string with version.
|
|
@@ -37,23 +81,19 @@ def _get_nat_dependency(versioned: bool = True) -> str:
|
|
|
37
81
|
Returns:
|
|
38
82
|
str: The dependency string to use in pyproject.toml
|
|
39
83
|
"""
|
|
40
|
-
# Assume the default dependency is
|
|
84
|
+
# Assume the default dependency is LangChain/LangGraph
|
|
41
85
|
dependency = "nvidia-nat[langchain]"
|
|
42
86
|
|
|
43
87
|
if not versioned:
|
|
44
88
|
logger.debug("Using unversioned NAT dependency: %s", dependency)
|
|
45
89
|
return dependency
|
|
46
90
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if current_version == "unknown":
|
|
51
|
-
logger.warning("Could not detect NAT version, using unversioned dependency")
|
|
91
|
+
version = _get_nat_version()
|
|
92
|
+
if version is None:
|
|
93
|
+
logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
|
|
52
94
|
return dependency
|
|
53
95
|
|
|
54
|
-
|
|
55
|
-
major_minor = ".".join(current_version.split(".")[:2])
|
|
56
|
-
dependency += f"~={major_minor}"
|
|
96
|
+
dependency += f"~={version}"
|
|
57
97
|
logger.debug("Using NAT dependency: %s", dependency)
|
|
58
98
|
return dependency
|
|
59
99
|
|
|
@@ -97,7 +137,7 @@ def find_package_root(package_name: str) -> Path | None:
|
|
|
97
137
|
try:
|
|
98
138
|
info = json.loads(direct_url)
|
|
99
139
|
except json.JSONDecodeError:
|
|
100
|
-
logger.
|
|
140
|
+
logger.exception("Malformed direct_url.json for package: %s", package_name)
|
|
101
141
|
return None
|
|
102
142
|
|
|
103
143
|
if not info.get("dir_info", {}).get("editable"):
|
|
@@ -161,7 +201,6 @@ def get_workflow_path_from_name(workflow_name: str):
|
|
|
161
201
|
default="NAT function template. Please update the description.",
|
|
162
202
|
help="""A description of the component being created. Will be used to populate the docstring and will describe the
|
|
163
203
|
component when inspecting installed components using 'nat info component'""")
|
|
164
|
-
# pylint: disable=missing-param-doc
|
|
165
204
|
def create_command(workflow_name: str, install: bool, workflow_dir: str, description: str):
|
|
166
205
|
"""
|
|
167
206
|
Create a new NAT workflow using templates.
|
|
@@ -172,6 +211,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
172
211
|
workflow_dir (str): The directory to create the workflow package.
|
|
173
212
|
description (str): Description to pre-popluate the workflow docstring.
|
|
174
213
|
"""
|
|
214
|
+
# Fail fast with Click's standard exit code (2) for bad params.
|
|
215
|
+
if not workflow_name or not workflow_name.strip():
|
|
216
|
+
raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
|
|
175
217
|
try:
|
|
176
218
|
# Get the repository root
|
|
177
219
|
try:
|
|
@@ -217,23 +259,25 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
217
259
|
install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
|
|
218
260
|
else:
|
|
219
261
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
262
|
+
if _is_nat_version_prerelease():
|
|
263
|
+
install_cmd.insert(2, "--pre")
|
|
220
264
|
|
|
221
|
-
|
|
265
|
+
python_safe_workflow_name = workflow_name.replace("-", "_")
|
|
222
266
|
|
|
223
267
|
# List of templates and their destinations
|
|
224
268
|
files_to_render = {
|
|
225
269
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
226
270
|
'register.py.j2': base_dir / 'register.py',
|
|
227
|
-
'workflow.py.j2': base_dir / f'{
|
|
271
|
+
'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
|
|
228
272
|
'__init__.py.j2': base_dir / '__init__.py',
|
|
229
|
-
'config.yml.j2':
|
|
273
|
+
'config.yml.j2': configs_dir / 'config.yml',
|
|
230
274
|
}
|
|
231
275
|
|
|
232
276
|
# Render templates
|
|
233
277
|
context = {
|
|
234
278
|
'editable': editable,
|
|
235
279
|
'workflow_name': workflow_name,
|
|
236
|
-
'python_safe_workflow_name':
|
|
280
|
+
'python_safe_workflow_name': python_safe_workflow_name,
|
|
237
281
|
'package_name': package_name,
|
|
238
282
|
'rel_path_to_repo_root': rel_path_to_repo_root,
|
|
239
283
|
'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
|
|
@@ -247,10 +291,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
247
291
|
with open(output_path, 'w', encoding="utf-8") as f:
|
|
248
292
|
f.write(content)
|
|
249
293
|
|
|
250
|
-
# Create symlink for config.yml
|
|
251
|
-
config_link = new_workflow_dir / 'configs' / 'config.yml'
|
|
252
|
-
os.symlink(config_source, config_link)
|
|
253
|
-
|
|
254
294
|
# Create symlinks for config and data directories
|
|
255
295
|
config_dir_source = configs_dir
|
|
256
296
|
config_dir_link = new_workflow_dir / 'configs'
|
|
@@ -272,7 +312,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
272
312
|
|
|
273
313
|
click.echo(f"Workflow '{workflow_name}' created successfully in '{new_workflow_dir}'.")
|
|
274
314
|
except Exception as e:
|
|
275
|
-
logger.exception("An error occurred while creating the workflow: %s", e
|
|
315
|
+
logger.exception("An error occurred while creating the workflow: %s", e)
|
|
276
316
|
click.echo(f"An error occurred while creating the workflow: {e}")
|
|
277
317
|
|
|
278
318
|
|
|
@@ -308,7 +348,7 @@ def reinstall_command(workflow_name):
|
|
|
308
348
|
|
|
309
349
|
click.echo(f"Workflow '{workflow_name}' reinstalled successfully.")
|
|
310
350
|
except Exception as e:
|
|
311
|
-
logger.exception("An error occurred while reinstalling the workflow: %s", e
|
|
351
|
+
logger.exception("An error occurred while reinstalling the workflow: %s", e)
|
|
312
352
|
click.echo(f"An error occurred while reinstalling the workflow: {e}")
|
|
313
353
|
|
|
314
354
|
|
|
@@ -355,7 +395,7 @@ def delete_command(workflow_name: str):
|
|
|
355
395
|
|
|
356
396
|
click.echo(f"Workflow '{workflow_name}' deleted successfully.")
|
|
357
397
|
except Exception as e:
|
|
358
|
-
logger.exception("An error occurred while deleting the workflow: %s", e
|
|
398
|
+
logger.exception("An error occurred while deleting the workflow: %s", e)
|
|
359
399
|
click.echo(f"An error occurred while deleting the workflow: {e}")
|
|
360
400
|
|
|
361
401
|
|
nat/cli/entrypoint.py
CHANGED
|
@@ -30,9 +30,14 @@ import time
|
|
|
30
30
|
import click
|
|
31
31
|
import nest_asyncio
|
|
32
32
|
|
|
33
|
+
from nat.utils.log_levels import LOG_LEVELS
|
|
34
|
+
|
|
33
35
|
from .commands.configure.configure import configure_command
|
|
34
36
|
from .commands.evaluate import eval_command
|
|
35
37
|
from .commands.info.info import info_command
|
|
38
|
+
from .commands.mcp.mcp import mcp_command
|
|
39
|
+
from .commands.object_store.object_store import object_store_command
|
|
40
|
+
from .commands.optimize import optimizer_command
|
|
36
41
|
from .commands.registry.registry import registry_command
|
|
37
42
|
from .commands.sizing.sizing import sizing
|
|
38
43
|
from .commands.start import start_command
|
|
@@ -43,15 +48,6 @@ from .commands.workflow.workflow import workflow_command
|
|
|
43
48
|
# Apply at the beginning of the file to avoid issues with asyncio
|
|
44
49
|
nest_asyncio.apply()
|
|
45
50
|
|
|
46
|
-
# Define log level choices
|
|
47
|
-
LOG_LEVELS = {
|
|
48
|
-
'DEBUG': logging.DEBUG,
|
|
49
|
-
'INFO': logging.INFO,
|
|
50
|
-
'WARNING': logging.WARNING,
|
|
51
|
-
'ERROR': logging.ERROR,
|
|
52
|
-
'CRITICAL': logging.CRITICAL
|
|
53
|
-
}
|
|
54
|
-
|
|
55
51
|
|
|
56
52
|
def setup_logging(log_level: str):
|
|
57
53
|
"""Configure logging with the specified level"""
|
|
@@ -107,11 +103,13 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
107
103
|
cli.add_command(validate_command, name="validate")
|
|
108
104
|
cli.add_command(workflow_command, name="workflow")
|
|
109
105
|
cli.add_command(sizing, name="sizing")
|
|
106
|
+
cli.add_command(optimizer_command, name="optimize")
|
|
107
|
+
cli.add_command(object_store_command, name="object-store")
|
|
108
|
+
cli.add_command(mcp_command, name="mcp")
|
|
110
109
|
|
|
111
110
|
# Aliases
|
|
112
111
|
cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
|
|
113
112
|
cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
|
|
114
|
-
cli.add_command(start_command.get_command(None, "mcp"), name="mcp")
|
|
115
113
|
|
|
116
114
|
|
|
117
115
|
@cli.result_callback()
|
nat/cli/main.py
CHANGED
nat/cli/register_workflow.py
CHANGED
|
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
|
|
|
27
27
|
from nat.cli.type_registry import FrontEndBuildCallableT
|
|
28
28
|
from nat.cli.type_registry import FrontEndRegisteredCallableT
|
|
29
29
|
from nat.cli.type_registry import FunctionBuildCallableT
|
|
30
|
+
from nat.cli.type_registry import FunctionGroupBuildCallableT
|
|
31
|
+
from nat.cli.type_registry import FunctionGroupRegisteredCallableT
|
|
30
32
|
from nat.cli.type_registry import FunctionRegisteredCallableT
|
|
31
33
|
from nat.cli.type_registry import LLMClientBuildCallableT
|
|
32
34
|
from nat.cli.type_registry import LLMClientRegisteredCallableT
|
|
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
|
|
|
60
62
|
from nat.data_models.evaluator import EvaluatorBaseConfigT
|
|
61
63
|
from nat.data_models.front_end import FrontEndConfigT
|
|
62
64
|
from nat.data_models.function import FunctionConfigT
|
|
65
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
63
66
|
from nat.data_models.llm import LLMBaseConfigT
|
|
64
67
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
65
68
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
155
158
|
|
|
156
159
|
context_manager_fn = asynccontextmanager(fn)
|
|
157
160
|
|
|
158
|
-
|
|
159
|
-
framework_wrappers_list: list[str] = []
|
|
160
|
-
else:
|
|
161
|
-
framework_wrappers_list = list(framework_wrappers)
|
|
161
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
162
162
|
|
|
163
163
|
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
164
164
|
component_type=ComponentEnum.FUNCTION)
|
|
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
177
177
|
return register_function_inner
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
181
|
+
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
182
|
+
"""
|
|
183
|
+
Register a function group with optional framework_wrappers for automatic profiler hooking.
|
|
184
|
+
Function groups share configuration/resources across multiple functions.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
def register_function_group_inner(
|
|
188
|
+
fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
|
|
189
|
+
) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
|
|
190
|
+
from .type_registry import GlobalTypeRegistry
|
|
191
|
+
from .type_registry import RegisteredFunctionGroupInfo
|
|
192
|
+
|
|
193
|
+
context_manager_fn = asynccontextmanager(fn)
|
|
194
|
+
|
|
195
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
196
|
+
|
|
197
|
+
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
198
|
+
component_type=ComponentEnum.FUNCTION_GROUP)
|
|
199
|
+
|
|
200
|
+
GlobalTypeRegistry.get().register_function_group(
|
|
201
|
+
RegisteredFunctionGroupInfo(
|
|
202
|
+
full_type=config_type.full_type,
|
|
203
|
+
config_type=config_type,
|
|
204
|
+
build_fn=context_manager_fn,
|
|
205
|
+
framework_wrappers=framework_wrappers_list,
|
|
206
|
+
discovery_metadata=discovery_metadata,
|
|
207
|
+
))
|
|
208
|
+
|
|
209
|
+
return context_manager_fn
|
|
210
|
+
|
|
211
|
+
return register_function_group_inner
|
|
212
|
+
|
|
213
|
+
|
|
180
214
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
181
215
|
|
|
182
216
|
def register_llm_provider_inner(
|
nat/cli/type_registry.py
CHANGED
|
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
|
|
|
37
37
|
from nat.builder.evaluator import EvaluatorInfo
|
|
38
38
|
from nat.builder.front_end import FrontEndBase
|
|
39
39
|
from nat.builder.function import Function
|
|
40
|
+
from nat.builder.function import FunctionGroup
|
|
40
41
|
from nat.builder.function_base import FunctionBase
|
|
41
42
|
from nat.builder.function_info import FunctionInfo
|
|
42
43
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
|
|
|
55
56
|
from nat.data_models.front_end import FrontEndConfigT
|
|
56
57
|
from nat.data_models.function import FunctionBaseConfig
|
|
57
58
|
from nat.data_models.function import FunctionConfigT
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
60
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
58
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
59
62
|
from nat.data_models.llm import LLMBaseConfigT
|
|
60
63
|
from nat.data_models.logging import LoggingBaseConfig
|
|
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
|
|
|
85
88
|
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
|
|
86
89
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
87
90
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
|
+
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
88
92
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
89
93
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
90
94
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
|
|
|
106
110
|
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
|
|
107
111
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
108
112
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
|
+
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
109
114
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
110
115
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
111
116
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
178
183
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
179
184
|
|
|
180
185
|
|
|
186
|
+
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
187
|
+
"""
|
|
188
|
+
Represents a registered function group. Function groups are collections of functions that share configuration
|
|
189
|
+
and resources.
|
|
190
|
+
"""
|
|
191
|
+
|
|
192
|
+
build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
|
|
193
|
+
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
|
+
|
|
195
|
+
|
|
181
196
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
182
197
|
"""
|
|
183
198
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -298,7 +313,7 @@ class RegisteredPackage(BaseModel):
|
|
|
298
313
|
discovery_metadata: DiscoveryMetadata
|
|
299
314
|
|
|
300
315
|
|
|
301
|
-
class TypeRegistry:
|
|
316
|
+
class TypeRegistry:
|
|
302
317
|
|
|
303
318
|
def __init__(self) -> None:
|
|
304
319
|
# Telemetry Exporters
|
|
@@ -313,6 +328,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
313
328
|
# Functions
|
|
314
329
|
self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
|
|
315
330
|
|
|
331
|
+
# Function Groups
|
|
332
|
+
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
|
+
|
|
316
334
|
# LLMs
|
|
317
335
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
318
336
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -478,6 +496,50 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
478
496
|
|
|
479
497
|
return list(self._registered_functions.values())
|
|
480
498
|
|
|
499
|
+
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
|
|
500
|
+
"""Register a function group with the type registry.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
registration: The function group registration information
|
|
504
|
+
|
|
505
|
+
Raises:
|
|
506
|
+
ValueError: If a function group with the same config type is already registered
|
|
507
|
+
"""
|
|
508
|
+
if (registration.config_type in self._registered_function_groups):
|
|
509
|
+
raise ValueError(
|
|
510
|
+
f"A function group with the same config type `{registration.config_type}` has already been "
|
|
511
|
+
"registered.")
|
|
512
|
+
|
|
513
|
+
self._registered_function_groups[registration.config_type] = registration
|
|
514
|
+
|
|
515
|
+
self._registration_changed()
|
|
516
|
+
|
|
517
|
+
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
|
|
518
|
+
"""Get a registered function group by its config type.
|
|
519
|
+
|
|
520
|
+
Args:
|
|
521
|
+
config_type: The function group configuration type
|
|
522
|
+
|
|
523
|
+
Returns:
|
|
524
|
+
RegisteredFunctionGroupInfo: The registered function group information
|
|
525
|
+
|
|
526
|
+
Raises:
|
|
527
|
+
KeyError: If no function group is registered for the given config type
|
|
528
|
+
"""
|
|
529
|
+
try:
|
|
530
|
+
return self._registered_function_groups[config_type]
|
|
531
|
+
except KeyError as err:
|
|
532
|
+
raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
|
|
533
|
+
f"Registered configs: {set(self._registered_function_groups.keys())}") from err
|
|
534
|
+
|
|
535
|
+
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
|
|
536
|
+
"""Get all registered function groups.
|
|
537
|
+
|
|
538
|
+
Returns:
|
|
539
|
+
list[RegisteredInfo[FunctionGroupBaseConfig]]: List of all registered function groups
|
|
540
|
+
"""
|
|
541
|
+
return list(self._registered_function_groups.values())
|
|
542
|
+
|
|
481
543
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
482
544
|
|
|
483
545
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -779,7 +841,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
779
841
|
|
|
780
842
|
self._registration_changed()
|
|
781
843
|
|
|
782
|
-
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
844
|
+
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
783
845
|
|
|
784
846
|
if component_type == ComponentEnum.FRONT_END:
|
|
785
847
|
return self._registered_front_end_infos
|
|
@@ -790,6 +852,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
790
852
|
if component_type == ComponentEnum.FUNCTION:
|
|
791
853
|
return self._registered_functions
|
|
792
854
|
|
|
855
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
856
|
+
return self._registered_function_groups
|
|
857
|
+
|
|
793
858
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
794
859
|
return self._registered_tool_wrappers
|
|
795
860
|
|
|
@@ -849,12 +914,14 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
849
914
|
|
|
850
915
|
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
851
916
|
|
|
852
|
-
def get_registered_types_by_component_type(
|
|
853
|
-
self, component_type: ComponentEnum) -> list[str]:
|
|
917
|
+
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
|
|
854
918
|
|
|
855
919
|
if component_type == ComponentEnum.FUNCTION:
|
|
856
920
|
return [i.static_type() for i in self._registered_functions]
|
|
857
921
|
|
|
922
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
923
|
+
return [i.static_type() for i in self._registered_function_groups]
|
|
924
|
+
|
|
858
925
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
859
926
|
return list(self._registered_tool_wrappers)
|
|
860
927
|
|
|
@@ -925,8 +992,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
925
992
|
if (short_names[key.local_name] == 1):
|
|
926
993
|
type_list.append((key.local_name, key.config_type))
|
|
927
994
|
|
|
928
|
-
|
|
929
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
995
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
930
996
|
|
|
931
997
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
|
932
998
|
|
|
@@ -945,6 +1011,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
945
1011
|
if issubclass(cls, FunctionBaseConfig):
|
|
946
1012
|
return self._do_compute_annotation(cls, self.get_registered_functions())
|
|
947
1013
|
|
|
1014
|
+
if issubclass(cls, FunctionGroupBaseConfig):
|
|
1015
|
+
return self._do_compute_annotation(cls, self.get_registered_function_groups())
|
|
1016
|
+
|
|
948
1017
|
if issubclass(cls, LLMBaseConfig):
|
|
949
1018
|
return self._do_compute_annotation(cls, self.get_registered_llm_providers())
|
|
950
1019
|
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
|
|
18
|
+
# Import any control flows which need to be automatically registered here
|
|
19
|
+
from . import sequential_executor
|
|
20
|
+
from .router_agent import register
|
|
File without changes
|