nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +13 -8
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +6 -5
- nat/agent/react_agent/register.py +49 -39
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +2 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/rewoo_agent/agent.py +304 -117
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +51 -38
- nat/agent/tool_calling_agent/agent.py +75 -17
- nat/agent/tool_calling_agent/register.py +46 -23
- nat/authentication/api_key/api_key_auth_provider.py +6 -11
- nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
- nat/builder/builder.py +55 -23
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +54 -15
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +370 -0
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +38 -2
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +306 -54
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +60 -18
- nat/cli/entrypoint.py +15 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +72 -1
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +199 -69
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +47 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +4 -3
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +208 -0
- nat/data_models/optimizer.py +161 -0
- nat/data_models/span.py +41 -3
- nat/data_models/thinking_mixin.py +2 -2
- nat/embedder/azure_openai_embedder.py +2 -1
- nat/embedder/nim_embedder.py +3 -2
- nat/embedder/openai_embedder.py +3 -2
- nat/eval/config.py +1 -1
- nat/eval/dataset_handler/dataset_downloader.py +3 -2
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +10 -3
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +7 -4
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/eval/usage_stats.py +2 -0
- nat/eval/utils/output_uploader.py +3 -2
- nat/eval/utils/weave_eval.py +17 -3
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +69 -44
- nat/front_ends/fastapi/message_validator.py +8 -7
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +71 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +78 -25
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +21 -8
- nat/llm/azure_openai_llm.py +14 -5
- nat/llm/litellm_llm.py +80 -0
- nat/llm/nim_llm.py +23 -9
- nat/llm/openai_llm.py +19 -7
- nat/llm/register.py +4 -0
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/exporter/span_exporter.py +43 -15
- nat/observability/exporter_manager.py +2 -2
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/observability/register.py +16 -0
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/profiler/decorators/framework_wrapper.py +61 -9
- nat/profiler/decorators/function_tracking.py +35 -3
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/registry_handlers/pypi/register_pypi.py +5 -3
- nat/registry_handlers/rest/register_rest.py +5 -3
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/register.py +2 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +111 -6
- nat/runtime/session.py +49 -3
- nat/settings/global_settings.py +2 -2
- nat/tool/chat_completion.py +4 -1
- nat/tool/code_execution/code_sandbox.py +3 -6
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/datetime_tools.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/register.py +2 -7
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +76 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +278 -72
- nat/utils/io/yaml_tools.py +73 -3
- nat/utils/log_levels.py +25 -0
- nat/utils/responses_api.py +26 -0
- nat/utils/string_utils.py +16 -0
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +6 -2
- nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -461
- nat/data_models/temperature_mixin.py +0 -43
- nat/data_models/top_p_mixin.py +0 -43
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
import click
|
|
21
|
+
|
|
22
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
23
|
+
from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@click.group(name=__name__, invoke_without_command=True, help="Optimize a workflow with the specified dataset.")
|
|
29
|
+
@click.option(
|
|
30
|
+
"--config_file",
|
|
31
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
|
|
32
|
+
required=True,
|
|
33
|
+
help="A JSON/YAML file that sets the parameters for the workflow and evaluation.",
|
|
34
|
+
)
|
|
35
|
+
@click.option(
|
|
36
|
+
"--dataset",
|
|
37
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
|
|
38
|
+
required=False,
|
|
39
|
+
help="A json file with questions and ground truth answers. This will override the dataset path in the config file.",
|
|
40
|
+
)
|
|
41
|
+
@click.option(
|
|
42
|
+
"--result_json_path",
|
|
43
|
+
type=str,
|
|
44
|
+
default="$",
|
|
45
|
+
help=("A JSON path to extract the result from the workflow. Use this when the workflow returns "
|
|
46
|
+
"multiple objects or a dictionary. For example, '$.output' will extract the 'output' field "
|
|
47
|
+
"from the result."),
|
|
48
|
+
)
|
|
49
|
+
@click.option(
|
|
50
|
+
"--endpoint",
|
|
51
|
+
type=str,
|
|
52
|
+
default=None,
|
|
53
|
+
help="Use endpoint for running the workflow. Example: http://localhost:8000/generate",
|
|
54
|
+
)
|
|
55
|
+
@click.option(
|
|
56
|
+
"--endpoint_timeout",
|
|
57
|
+
type=int,
|
|
58
|
+
default=300,
|
|
59
|
+
help="HTTP response timeout in seconds. Only relevant if endpoint is specified.",
|
|
60
|
+
)
|
|
61
|
+
@click.pass_context
|
|
62
|
+
def optimizer_command(ctx, **kwargs) -> None:
|
|
63
|
+
""" Optimize workflow with the specified dataset"""
|
|
64
|
+
pass
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def run_optimizer(config: OptimizerRunConfig):
|
|
68
|
+
await optimize_config(config)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@optimizer_command.result_callback(replace=True)
|
|
72
|
+
def run_optimizer_callback(
|
|
73
|
+
processors, # pylint: disable=unused-argument
|
|
74
|
+
*,
|
|
75
|
+
config_file: Path,
|
|
76
|
+
dataset: Path,
|
|
77
|
+
result_json_path: str,
|
|
78
|
+
endpoint: str,
|
|
79
|
+
endpoint_timeout: int,
|
|
80
|
+
):
|
|
81
|
+
"""Run the optimizer with the provided config file and dataset."""
|
|
82
|
+
config = OptimizerRunConfig(
|
|
83
|
+
config_file=config_file,
|
|
84
|
+
dataset=dataset,
|
|
85
|
+
result_json_path=result_json_path,
|
|
86
|
+
endpoint=endpoint,
|
|
87
|
+
endpoint_timeout=endpoint_timeout,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
asyncio.run(run_optimizer(config))
|
nat/cli/commands/start.py
CHANGED
|
@@ -111,7 +111,7 @@ class StartCommandGroup(click.Group):
|
|
|
111
111
|
elif (issubclass(decomposed_type.root, Path)):
|
|
112
112
|
param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
|
|
113
113
|
|
|
114
|
-
elif (issubclass(decomposed_type.root,
|
|
114
|
+
elif (issubclass(decomposed_type.root, list | tuple | set)):
|
|
115
115
|
if (len(decomposed_type.args) == 1):
|
|
116
116
|
inner = DecomposedType(decomposed_type.args[0])
|
|
117
117
|
# Support containers of Literal values -> multiple Choice
|
|
@@ -1,16 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
1
|
+
functions:
|
|
2
|
+
current_datetime:
|
|
3
|
+
_type: current_datetime
|
|
4
|
+
{{python_safe_workflow_name}}:
|
|
5
|
+
_type: {{python_safe_workflow_name}}
|
|
6
|
+
prefix: "Hello:"
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
llms:
|
|
9
|
+
nim_llm:
|
|
10
|
+
_type: nim
|
|
11
|
+
model_name: meta/llama-3.1-70b-instruct
|
|
12
|
+
temperature: 0.0
|
|
13
13
|
|
|
14
14
|
workflow:
|
|
15
|
-
_type:
|
|
16
|
-
|
|
15
|
+
_type: react_agent
|
|
16
|
+
llm_name: nim_llm
|
|
17
|
+
tool_names: [current_datetime, {{python_safe_workflow_name}}]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
2
|
|
|
3
|
-
# Import
|
|
4
|
-
from {{package_name}} import {{
|
|
3
|
+
# Import the generated workflow function to trigger registration
|
|
4
|
+
from .{{package_name}} import {{ python_safe_workflow_name }}_function
|
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
from pydantic import Field
|
|
4
4
|
|
|
5
5
|
from nat.builder.builder import Builder
|
|
6
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
6
7
|
from nat.builder.function_info import FunctionInfo
|
|
7
8
|
from nat.cli.register_workflow import register_function
|
|
8
9
|
from nat.data_models.function import FunctionBaseConfig
|
|
@@ -12,25 +13,38 @@ logger = logging.getLogger(__name__)
|
|
|
12
13
|
|
|
13
14
|
class {{ workflow_class_name }}(FunctionBaseConfig, name="{{ workflow_name }}"):
|
|
14
15
|
"""
|
|
15
|
-
{{workflow_description}}
|
|
16
|
+
{{ workflow_description }}
|
|
16
17
|
"""
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
18
|
+
prefix: str = Field(default="Echo:", description="Prefix to add before the echoed text.")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_function(config_type={{ workflow_class_name }}, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
22
|
+
async def {{ python_safe_workflow_name }}_function(config: {{ workflow_class_name }}, builder: Builder):
|
|
23
|
+
"""
|
|
24
|
+
Registers a function (addressable via `{{ workflow_name }}` in the configuration).
|
|
25
|
+
This registration ensures a static mapping of the function type, `{{ workflow_name }}`, to the `{{ workflow_class_name }}` configuration object.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
config ({{ workflow_class_name }}): The configuration for the function.
|
|
29
|
+
builder (Builder): The builder object.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
FunctionInfo: The function info object for the function.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
# Define the function that will be registered.
|
|
36
|
+
async def _echo(text: str) -> str:
|
|
37
|
+
"""
|
|
38
|
+
Takes a text input and echoes back with a pre-defined prefix.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
text (str): The text to echo back.
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
str: The text with the prefix.
|
|
45
|
+
"""
|
|
46
|
+
return f"{config.prefix} {text}"
|
|
47
|
+
|
|
48
|
+
# The callable is wrapped in a FunctionInfo object.
|
|
49
|
+
# The description parameter is used to describe the function.
|
|
50
|
+
yield FunctionInfo.from_fn(_echo, description=_echo.__doc__)
|
|
@@ -27,6 +27,50 @@ from jinja2 import FileSystemLoader
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
def _get_nat_version() -> str | None:
|
|
31
|
+
"""
|
|
32
|
+
Get the current NAT version.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
str: The NAT version intended for use in a dependency string.
|
|
36
|
+
None: If the NAT version is not found.
|
|
37
|
+
"""
|
|
38
|
+
from nat.cli.entrypoint import get_version
|
|
39
|
+
|
|
40
|
+
current_version = get_version()
|
|
41
|
+
if current_version == "unknown":
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
version_parts = current_version.split(".")
|
|
45
|
+
if len(version_parts) < 3:
|
|
46
|
+
# If the version somehow doesn't have three parts, return the full version
|
|
47
|
+
return current_version
|
|
48
|
+
|
|
49
|
+
patch = version_parts[2]
|
|
50
|
+
try:
|
|
51
|
+
# If the patch is a number, keep only the major and minor parts
|
|
52
|
+
# Useful for stable releases and adheres to semantic versioning
|
|
53
|
+
_ = int(patch)
|
|
54
|
+
digits_to_keep = 2
|
|
55
|
+
except ValueError:
|
|
56
|
+
# If the patch is not a number, keep all three digits
|
|
57
|
+
# Useful for pre-release versions (and nightly builds)
|
|
58
|
+
digits_to_keep = 3
|
|
59
|
+
|
|
60
|
+
return ".".join(version_parts[:digits_to_keep])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _is_nat_version_prerelease() -> bool:
|
|
64
|
+
"""
|
|
65
|
+
Check if the NAT version is a prerelease.
|
|
66
|
+
"""
|
|
67
|
+
version = _get_nat_version()
|
|
68
|
+
if version is None:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
return len(version.split(".")) >= 3
|
|
72
|
+
|
|
73
|
+
|
|
30
74
|
def _get_nat_dependency(versioned: bool = True) -> str:
|
|
31
75
|
"""
|
|
32
76
|
Get the NAT dependency string with version.
|
|
@@ -44,16 +88,12 @@ def _get_nat_dependency(versioned: bool = True) -> str:
|
|
|
44
88
|
logger.debug("Using unversioned NAT dependency: %s", dependency)
|
|
45
89
|
return dependency
|
|
46
90
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if current_version == "unknown":
|
|
51
|
-
logger.warning("Could not detect NAT version, using unversioned dependency")
|
|
91
|
+
version = _get_nat_version()
|
|
92
|
+
if version is None:
|
|
93
|
+
logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
|
|
52
94
|
return dependency
|
|
53
95
|
|
|
54
|
-
|
|
55
|
-
major_minor = ".".join(current_version.split(".")[:2])
|
|
56
|
-
dependency += f"~={major_minor}"
|
|
96
|
+
dependency += f"~={version}"
|
|
57
97
|
logger.debug("Using NAT dependency: %s", dependency)
|
|
58
98
|
return dependency
|
|
59
99
|
|
|
@@ -171,6 +211,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
171
211
|
workflow_dir (str): The directory to create the workflow package.
|
|
172
212
|
description (str): Description to pre-popluate the workflow docstring.
|
|
173
213
|
"""
|
|
214
|
+
# Fail fast with Click's standard exit code (2) for bad params.
|
|
215
|
+
if not workflow_name or not workflow_name.strip():
|
|
216
|
+
raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
|
|
174
217
|
try:
|
|
175
218
|
# Get the repository root
|
|
176
219
|
try:
|
|
@@ -216,23 +259,25 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
216
259
|
install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
|
|
217
260
|
else:
|
|
218
261
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
262
|
+
if _is_nat_version_prerelease():
|
|
263
|
+
install_cmd.insert(2, "--pre")
|
|
219
264
|
|
|
220
|
-
|
|
265
|
+
python_safe_workflow_name = workflow_name.replace("-", "_")
|
|
221
266
|
|
|
222
267
|
# List of templates and their destinations
|
|
223
268
|
files_to_render = {
|
|
224
269
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
225
270
|
'register.py.j2': base_dir / 'register.py',
|
|
226
|
-
'workflow.py.j2': base_dir / f'{
|
|
271
|
+
'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
|
|
227
272
|
'__init__.py.j2': base_dir / '__init__.py',
|
|
228
|
-
'config.yml.j2':
|
|
273
|
+
'config.yml.j2': configs_dir / 'config.yml',
|
|
229
274
|
}
|
|
230
275
|
|
|
231
276
|
# Render templates
|
|
232
277
|
context = {
|
|
233
278
|
'editable': editable,
|
|
234
279
|
'workflow_name': workflow_name,
|
|
235
|
-
'python_safe_workflow_name':
|
|
280
|
+
'python_safe_workflow_name': python_safe_workflow_name,
|
|
236
281
|
'package_name': package_name,
|
|
237
282
|
'rel_path_to_repo_root': rel_path_to_repo_root,
|
|
238
283
|
'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
|
|
@@ -246,10 +291,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
246
291
|
with open(output_path, 'w', encoding="utf-8") as f:
|
|
247
292
|
f.write(content)
|
|
248
293
|
|
|
249
|
-
# Create symlink for config.yml
|
|
250
|
-
config_link = new_workflow_dir / 'configs' / 'config.yml'
|
|
251
|
-
os.symlink(config_source, config_link)
|
|
252
|
-
|
|
253
294
|
# Create symlinks for config and data directories
|
|
254
295
|
config_dir_source = configs_dir
|
|
255
296
|
config_dir_link = new_workflow_dir / 'configs'
|
|
@@ -313,7 +354,8 @@ def reinstall_command(workflow_name):
|
|
|
313
354
|
|
|
314
355
|
@click.command()
|
|
315
356
|
@click.argument('workflow_name')
|
|
316
|
-
|
|
357
|
+
@click.option('-y', '--yes', "yes_flag", is_flag=True, default=False, help='Do not prompt for confirmation.')
|
|
358
|
+
def delete_command(workflow_name: str, yes_flag: bool):
|
|
317
359
|
"""
|
|
318
360
|
Delete a NAT workflow and uninstall its package.
|
|
319
361
|
|
|
@@ -321,7 +363,7 @@ def delete_command(workflow_name: str):
|
|
|
321
363
|
workflow_name (str): The name of the workflow to delete.
|
|
322
364
|
"""
|
|
323
365
|
try:
|
|
324
|
-
if not click.confirm(f"Are you sure you want to delete the workflow '{workflow_name}'?"):
|
|
366
|
+
if not yes_flag and not click.confirm(f"Are you sure you want to delete the workflow '{workflow_name}'?"):
|
|
325
367
|
click.echo("Workflow deletion cancelled.")
|
|
326
368
|
return
|
|
327
369
|
editable = get_repo_root() is not None
|
nat/cli/entrypoint.py
CHANGED
|
@@ -29,11 +29,16 @@ import time
|
|
|
29
29
|
|
|
30
30
|
import click
|
|
31
31
|
import nest_asyncio
|
|
32
|
+
from dotenv import load_dotenv
|
|
33
|
+
|
|
34
|
+
from nat.utils.log_levels import LOG_LEVELS
|
|
32
35
|
|
|
33
36
|
from .commands.configure.configure import configure_command
|
|
34
37
|
from .commands.evaluate import eval_command
|
|
35
38
|
from .commands.info.info import info_command
|
|
39
|
+
from .commands.mcp.mcp import mcp_command
|
|
36
40
|
from .commands.object_store.object_store import object_store_command
|
|
41
|
+
from .commands.optimize import optimizer_command
|
|
37
42
|
from .commands.registry.registry import registry_command
|
|
38
43
|
from .commands.sizing.sizing import sizing
|
|
39
44
|
from .commands.start import start_command
|
|
@@ -41,23 +46,21 @@ from .commands.uninstall import uninstall_command
|
|
|
41
46
|
from .commands.validate import validate_command
|
|
42
47
|
from .commands.workflow.workflow import workflow_command
|
|
43
48
|
|
|
49
|
+
# Load environment variables from .env file, if it exists
|
|
50
|
+
load_dotenv()
|
|
51
|
+
|
|
44
52
|
# Apply at the beginning of the file to avoid issues with asyncio
|
|
45
53
|
nest_asyncio.apply()
|
|
46
54
|
|
|
47
|
-
# Define log level choices
|
|
48
|
-
LOG_LEVELS = {
|
|
49
|
-
'DEBUG': logging.DEBUG,
|
|
50
|
-
'INFO': logging.INFO,
|
|
51
|
-
'WARNING': logging.WARNING,
|
|
52
|
-
'ERROR': logging.ERROR,
|
|
53
|
-
'CRITICAL': logging.CRITICAL
|
|
54
|
-
}
|
|
55
|
-
|
|
56
55
|
|
|
57
56
|
def setup_logging(log_level: str):
|
|
58
57
|
"""Configure logging with the specified level"""
|
|
59
58
|
numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
|
|
60
|
-
logging.basicConfig(
|
|
59
|
+
logging.basicConfig(
|
|
60
|
+
level=numeric_level,
|
|
61
|
+
format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
62
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
63
|
+
)
|
|
61
64
|
return numeric_level
|
|
62
65
|
|
|
63
66
|
|
|
@@ -108,12 +111,13 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
108
111
|
cli.add_command(validate_command, name="validate")
|
|
109
112
|
cli.add_command(workflow_command, name="workflow")
|
|
110
113
|
cli.add_command(sizing, name="sizing")
|
|
114
|
+
cli.add_command(optimizer_command, name="optimize")
|
|
111
115
|
cli.add_command(object_store_command, name="object-store")
|
|
116
|
+
cli.add_command(mcp_command, name="mcp")
|
|
112
117
|
|
|
113
118
|
# Aliases
|
|
114
119
|
cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
|
|
115
120
|
cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
|
|
116
|
-
cli.add_command(start_command.get_command(None, "mcp"), name="mcp") # type: ignore
|
|
117
121
|
|
|
118
122
|
|
|
119
123
|
@cli.result_callback()
|
nat/cli/main.py
CHANGED
nat/cli/register_workflow.py
CHANGED
|
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
|
|
|
27
27
|
from nat.cli.type_registry import FrontEndBuildCallableT
|
|
28
28
|
from nat.cli.type_registry import FrontEndRegisteredCallableT
|
|
29
29
|
from nat.cli.type_registry import FunctionBuildCallableT
|
|
30
|
+
from nat.cli.type_registry import FunctionGroupBuildCallableT
|
|
31
|
+
from nat.cli.type_registry import FunctionGroupRegisteredCallableT
|
|
30
32
|
from nat.cli.type_registry import FunctionRegisteredCallableT
|
|
31
33
|
from nat.cli.type_registry import LLMClientBuildCallableT
|
|
32
34
|
from nat.cli.type_registry import LLMClientRegisteredCallableT
|
|
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
|
|
|
60
62
|
from nat.data_models.evaluator import EvaluatorBaseConfigT
|
|
61
63
|
from nat.data_models.front_end import FrontEndConfigT
|
|
62
64
|
from nat.data_models.function import FunctionConfigT
|
|
65
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
63
66
|
from nat.data_models.llm import LLMBaseConfigT
|
|
64
67
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
65
68
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
155
158
|
|
|
156
159
|
context_manager_fn = asynccontextmanager(fn)
|
|
157
160
|
|
|
158
|
-
|
|
159
|
-
framework_wrappers_list: list[str] = []
|
|
160
|
-
else:
|
|
161
|
-
framework_wrappers_list = list(framework_wrappers)
|
|
161
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
162
162
|
|
|
163
163
|
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
164
164
|
component_type=ComponentEnum.FUNCTION)
|
|
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
177
177
|
return register_function_inner
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
181
|
+
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
182
|
+
"""
|
|
183
|
+
Register a function group with optional framework_wrappers for automatic profiler hooking.
|
|
184
|
+
Function groups share configuration/resources across multiple functions.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
def register_function_group_inner(
|
|
188
|
+
fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
|
|
189
|
+
) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
|
|
190
|
+
from .type_registry import GlobalTypeRegistry
|
|
191
|
+
from .type_registry import RegisteredFunctionGroupInfo
|
|
192
|
+
|
|
193
|
+
context_manager_fn = asynccontextmanager(fn)
|
|
194
|
+
|
|
195
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
196
|
+
|
|
197
|
+
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
198
|
+
component_type=ComponentEnum.FUNCTION_GROUP)
|
|
199
|
+
|
|
200
|
+
GlobalTypeRegistry.get().register_function_group(
|
|
201
|
+
RegisteredFunctionGroupInfo(
|
|
202
|
+
full_type=config_type.full_type,
|
|
203
|
+
config_type=config_type,
|
|
204
|
+
build_fn=context_manager_fn,
|
|
205
|
+
framework_wrappers=framework_wrappers_list,
|
|
206
|
+
discovery_metadata=discovery_metadata,
|
|
207
|
+
))
|
|
208
|
+
|
|
209
|
+
return context_manager_fn
|
|
210
|
+
|
|
211
|
+
return register_function_group_inner
|
|
212
|
+
|
|
213
|
+
|
|
180
214
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
181
215
|
|
|
182
216
|
def register_llm_provider_inner(
|
nat/cli/type_registry.py
CHANGED
|
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
|
|
|
37
37
|
from nat.builder.evaluator import EvaluatorInfo
|
|
38
38
|
from nat.builder.front_end import FrontEndBase
|
|
39
39
|
from nat.builder.function import Function
|
|
40
|
+
from nat.builder.function import FunctionGroup
|
|
40
41
|
from nat.builder.function_base import FunctionBase
|
|
41
42
|
from nat.builder.function_info import FunctionInfo
|
|
42
43
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
|
|
|
55
56
|
from nat.data_models.front_end import FrontEndConfigT
|
|
56
57
|
from nat.data_models.function import FunctionBaseConfig
|
|
57
58
|
from nat.data_models.function import FunctionConfigT
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
60
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
58
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
59
62
|
from nat.data_models.llm import LLMBaseConfigT
|
|
60
63
|
from nat.data_models.logging import LoggingBaseConfig
|
|
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
|
|
|
85
88
|
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
|
|
86
89
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
87
90
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
|
+
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
88
92
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
89
93
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
90
94
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
|
|
|
106
110
|
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
|
|
107
111
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
108
112
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
|
+
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
109
114
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
110
115
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
111
116
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
178
183
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
179
184
|
|
|
180
185
|
|
|
186
|
+
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
187
|
+
"""
|
|
188
|
+
Represents a registered function group. Function groups are collections of functions that share configuration
|
|
189
|
+
and resources.
|
|
190
|
+
"""
|
|
191
|
+
|
|
192
|
+
build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
|
|
193
|
+
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
|
+
|
|
195
|
+
|
|
181
196
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
182
197
|
"""
|
|
183
198
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -313,6 +328,9 @@ class TypeRegistry:
|
|
|
313
328
|
# Functions
|
|
314
329
|
self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
|
|
315
330
|
|
|
331
|
+
# Function Groups
|
|
332
|
+
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
|
+
|
|
316
334
|
# LLMs
|
|
317
335
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
318
336
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -478,6 +496,50 @@ class TypeRegistry:
|
|
|
478
496
|
|
|
479
497
|
return list(self._registered_functions.values())
|
|
480
498
|
|
|
499
|
+
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
|
|
500
|
+
"""Register a function group with the type registry.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
registration: The function group registration information
|
|
504
|
+
|
|
505
|
+
Raises:
|
|
506
|
+
ValueError: If a function group with the same config type is already registered
|
|
507
|
+
"""
|
|
508
|
+
if (registration.config_type in self._registered_function_groups):
|
|
509
|
+
raise ValueError(
|
|
510
|
+
f"A function group with the same config type `{registration.config_type}` has already been "
|
|
511
|
+
"registered.")
|
|
512
|
+
|
|
513
|
+
self._registered_function_groups[registration.config_type] = registration
|
|
514
|
+
|
|
515
|
+
self._registration_changed()
|
|
516
|
+
|
|
517
|
+
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
|
|
518
|
+
"""Get a registered function group by its config type.
|
|
519
|
+
|
|
520
|
+
Args:
|
|
521
|
+
config_type: The function group configuration type
|
|
522
|
+
|
|
523
|
+
Returns:
|
|
524
|
+
RegisteredFunctionGroupInfo: The registered function group information
|
|
525
|
+
|
|
526
|
+
Raises:
|
|
527
|
+
KeyError: If no function group is registered for the given config type
|
|
528
|
+
"""
|
|
529
|
+
try:
|
|
530
|
+
return self._registered_function_groups[config_type]
|
|
531
|
+
except KeyError as err:
|
|
532
|
+
raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
|
|
533
|
+
f"Registered configs: {set(self._registered_function_groups.keys())}") from err
|
|
534
|
+
|
|
535
|
+
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
|
|
536
|
+
"""Get all registered function groups.
|
|
537
|
+
|
|
538
|
+
Returns:
|
|
539
|
+
list[RegisteredInfo[FunctionGroupBaseConfig]]: List of all registered function groups
|
|
540
|
+
"""
|
|
541
|
+
return list(self._registered_function_groups.values())
|
|
542
|
+
|
|
481
543
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
482
544
|
|
|
483
545
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -790,6 +852,9 @@ class TypeRegistry:
|
|
|
790
852
|
if component_type == ComponentEnum.FUNCTION:
|
|
791
853
|
return self._registered_functions
|
|
792
854
|
|
|
855
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
856
|
+
return self._registered_function_groups
|
|
857
|
+
|
|
793
858
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
794
859
|
return self._registered_tool_wrappers
|
|
795
860
|
|
|
@@ -854,6 +919,9 @@ class TypeRegistry:
|
|
|
854
919
|
if component_type == ComponentEnum.FUNCTION:
|
|
855
920
|
return [i.static_type() for i in self._registered_functions]
|
|
856
921
|
|
|
922
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
923
|
+
return [i.static_type() for i in self._registered_function_groups]
|
|
924
|
+
|
|
857
925
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
858
926
|
return list(self._registered_tool_wrappers)
|
|
859
927
|
|
|
@@ -924,7 +992,7 @@ class TypeRegistry:
|
|
|
924
992
|
if (short_names[key.local_name] == 1):
|
|
925
993
|
type_list.append((key.local_name, key.config_type))
|
|
926
994
|
|
|
927
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
995
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
928
996
|
|
|
929
997
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
|
930
998
|
|
|
@@ -943,6 +1011,9 @@ class TypeRegistry:
|
|
|
943
1011
|
if issubclass(cls, FunctionBaseConfig):
|
|
944
1012
|
return self._do_compute_annotation(cls, self.get_registered_functions())
|
|
945
1013
|
|
|
1014
|
+
if issubclass(cls, FunctionGroupBaseConfig):
|
|
1015
|
+
return self._do_compute_annotation(cls, self.get_registered_function_groups())
|
|
1016
|
+
|
|
946
1017
|
if issubclass(cls, LLMBaseConfig):
|
|
947
1018
|
return self._do_compute_annotation(cls, self.get_registered_llm_providers())
|
|
948
1019
|
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
|
|
18
|
+
# Import any control flows which need to be automatically registered here
|
|
19
|
+
from . import sequential_executor
|
|
20
|
+
from .router_agent import register
|
|
File without changes
|