nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +27 -18
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +81 -50
- nat/agent/react_agent/register.py +59 -40
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +327 -149
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +64 -46
- nat/agent/tool_calling_agent/agent.py +152 -29
- nat/agent/tool_calling_agent/register.py +61 -38
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +10 -6
- nat/builder/context.py +70 -18
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/intermediate_step_manager.py +6 -2
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +327 -79
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +105 -19
- nat/cli/entrypoint.py +17 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +79 -10
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +196 -67
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +42 -18
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +2 -3
- nat/embedder/register.py +1 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +9 -6
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +74 -50
- nat/front_ends/fastapi/message_validator.py +20 -21
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +47 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +5 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +22 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +14 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +105 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +9 -5
- nvidia_nat-1.3.0.dist-info/METADATA +195 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -27,6 +27,77 @@ from jinja2 import FileSystemLoader
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
def _get_nat_version() -> str | None:
|
|
31
|
+
"""
|
|
32
|
+
Get the current NAT version.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
str: The NAT version intended for use in a dependency string.
|
|
36
|
+
None: If the NAT version is not found.
|
|
37
|
+
"""
|
|
38
|
+
from nat.cli.entrypoint import get_version
|
|
39
|
+
|
|
40
|
+
current_version = get_version()
|
|
41
|
+
if current_version == "unknown":
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
version_parts = current_version.split(".")
|
|
45
|
+
if len(version_parts) < 3:
|
|
46
|
+
# If the version somehow doesn't have three parts, return the full version
|
|
47
|
+
return current_version
|
|
48
|
+
|
|
49
|
+
patch = version_parts[2]
|
|
50
|
+
try:
|
|
51
|
+
# If the patch is a number, keep only the major and minor parts
|
|
52
|
+
# Useful for stable releases and adheres to semantic versioning
|
|
53
|
+
_ = int(patch)
|
|
54
|
+
digits_to_keep = 2
|
|
55
|
+
except ValueError:
|
|
56
|
+
# If the patch is not a number, keep all three digits
|
|
57
|
+
# Useful for pre-release versions (and nightly builds)
|
|
58
|
+
digits_to_keep = 3
|
|
59
|
+
|
|
60
|
+
return ".".join(version_parts[:digits_to_keep])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _is_nat_version_prerelease() -> bool:
|
|
64
|
+
"""
|
|
65
|
+
Check if the NAT version is a prerelease.
|
|
66
|
+
"""
|
|
67
|
+
version = _get_nat_version()
|
|
68
|
+
if version is None:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
return len(version.split(".")) >= 3
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _get_nat_dependency(versioned: bool = True) -> str:
|
|
75
|
+
"""
|
|
76
|
+
Get the NAT dependency string with version.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
versioned: Whether to include the version in the dependency string
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
str: The dependency string to use in pyproject.toml
|
|
83
|
+
"""
|
|
84
|
+
# Assume the default dependency is LangChain/LangGraph
|
|
85
|
+
dependency = "nvidia-nat[langchain]"
|
|
86
|
+
|
|
87
|
+
if not versioned:
|
|
88
|
+
logger.debug("Using unversioned NAT dependency: %s", dependency)
|
|
89
|
+
return dependency
|
|
90
|
+
|
|
91
|
+
version = _get_nat_version()
|
|
92
|
+
if version is None:
|
|
93
|
+
logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
|
|
94
|
+
return dependency
|
|
95
|
+
|
|
96
|
+
dependency += f"~={version}"
|
|
97
|
+
logger.debug("Using NAT dependency: %s", dependency)
|
|
98
|
+
return dependency
|
|
99
|
+
|
|
100
|
+
|
|
30
101
|
class PackageError(Exception):
|
|
31
102
|
pass
|
|
32
103
|
|
|
@@ -66,7 +137,7 @@ def find_package_root(package_name: str) -> Path | None:
|
|
|
66
137
|
try:
|
|
67
138
|
info = json.loads(direct_url)
|
|
68
139
|
except json.JSONDecodeError:
|
|
69
|
-
logger.
|
|
140
|
+
logger.exception("Malformed direct_url.json for package: %s", package_name)
|
|
70
141
|
return None
|
|
71
142
|
|
|
72
143
|
if not info.get("dir_info", {}).get("editable"):
|
|
@@ -130,7 +201,6 @@ def get_workflow_path_from_name(workflow_name: str):
|
|
|
130
201
|
default="NAT function template. Please update the description.",
|
|
131
202
|
help="""A description of the component being created. Will be used to populate the docstring and will describe the
|
|
132
203
|
component when inspecting installed components using 'nat info component'""")
|
|
133
|
-
# pylint: disable=missing-param-doc
|
|
134
204
|
def create_command(workflow_name: str, install: bool, workflow_dir: str, description: str):
|
|
135
205
|
"""
|
|
136
206
|
Create a new NAT workflow using templates.
|
|
@@ -141,6 +211,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
141
211
|
workflow_dir (str): The directory to create the workflow package.
|
|
142
212
|
description (str): Description to pre-popluate the workflow docstring.
|
|
143
213
|
"""
|
|
214
|
+
# Fail fast with Click's standard exit code (2) for bad params.
|
|
215
|
+
if not workflow_name or not workflow_name.strip():
|
|
216
|
+
raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
|
|
144
217
|
try:
|
|
145
218
|
# Get the repository root
|
|
146
219
|
try:
|
|
@@ -166,12 +239,17 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
166
239
|
click.echo(f"Workflow '{workflow_name}' already exists.")
|
|
167
240
|
return
|
|
168
241
|
|
|
242
|
+
base_dir = new_workflow_dir / 'src' / package_name
|
|
243
|
+
|
|
244
|
+
configs_dir = base_dir / 'configs'
|
|
245
|
+
data_dir = base_dir / 'data'
|
|
246
|
+
|
|
169
247
|
# Create directory structure
|
|
170
|
-
|
|
248
|
+
base_dir.mkdir(parents=True)
|
|
171
249
|
# Create config directory
|
|
172
|
-
|
|
173
|
-
# Create
|
|
174
|
-
|
|
250
|
+
configs_dir.mkdir(parents=True)
|
|
251
|
+
# Create data directory
|
|
252
|
+
data_dir.mkdir(parents=True)
|
|
175
253
|
|
|
176
254
|
# Initialize Jinja2 environment
|
|
177
255
|
env = Environment(loader=FileSystemLoader(str(template_dir)))
|
|
@@ -181,25 +259,30 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
181
259
|
install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
|
|
182
260
|
else:
|
|
183
261
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
262
|
+
if _is_nat_version_prerelease():
|
|
263
|
+
install_cmd.insert(2, "--pre")
|
|
264
|
+
|
|
265
|
+
python_safe_workflow_name = workflow_name.replace("-", "_")
|
|
184
266
|
|
|
185
267
|
# List of templates and their destinations
|
|
186
268
|
files_to_render = {
|
|
187
269
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
188
|
-
'register.py.j2':
|
|
189
|
-
'workflow.py.j2':
|
|
190
|
-
'__init__.py.j2':
|
|
191
|
-
'config.yml.j2':
|
|
270
|
+
'register.py.j2': base_dir / 'register.py',
|
|
271
|
+
'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
|
|
272
|
+
'__init__.py.j2': base_dir / '__init__.py',
|
|
273
|
+
'config.yml.j2': configs_dir / 'config.yml',
|
|
192
274
|
}
|
|
193
275
|
|
|
194
276
|
# Render templates
|
|
195
277
|
context = {
|
|
196
278
|
'editable': editable,
|
|
197
279
|
'workflow_name': workflow_name,
|
|
198
|
-
'python_safe_workflow_name':
|
|
280
|
+
'python_safe_workflow_name': python_safe_workflow_name,
|
|
199
281
|
'package_name': package_name,
|
|
200
282
|
'rel_path_to_repo_root': rel_path_to_repo_root,
|
|
201
283
|
'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
|
|
202
|
-
'workflow_description': description
|
|
284
|
+
'workflow_description': description,
|
|
285
|
+
'nat_dependency': _get_nat_dependency()
|
|
203
286
|
}
|
|
204
287
|
|
|
205
288
|
for template_name, output_path in files_to_render.items():
|
|
@@ -208,10 +291,13 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
208
291
|
with open(output_path, 'w', encoding="utf-8") as f:
|
|
209
292
|
f.write(content)
|
|
210
293
|
|
|
211
|
-
# Create
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
294
|
+
# Create symlinks for config and data directories
|
|
295
|
+
config_dir_source = configs_dir
|
|
296
|
+
config_dir_link = new_workflow_dir / 'configs'
|
|
297
|
+
data_dir_source = data_dir
|
|
298
|
+
data_dir_link = new_workflow_dir / 'data'
|
|
299
|
+
os.symlink(config_dir_source, config_dir_link)
|
|
300
|
+
os.symlink(data_dir_source, data_dir_link)
|
|
215
301
|
|
|
216
302
|
if install:
|
|
217
303
|
# Install the new package without changing directories
|
|
@@ -226,7 +312,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
226
312
|
|
|
227
313
|
click.echo(f"Workflow '{workflow_name}' created successfully in '{new_workflow_dir}'.")
|
|
228
314
|
except Exception as e:
|
|
229
|
-
logger.exception("An error occurred while creating the workflow: %s", e
|
|
315
|
+
logger.exception("An error occurred while creating the workflow: %s", e)
|
|
230
316
|
click.echo(f"An error occurred while creating the workflow: {e}")
|
|
231
317
|
|
|
232
318
|
|
|
@@ -262,7 +348,7 @@ def reinstall_command(workflow_name):
|
|
|
262
348
|
|
|
263
349
|
click.echo(f"Workflow '{workflow_name}' reinstalled successfully.")
|
|
264
350
|
except Exception as e:
|
|
265
|
-
logger.exception("An error occurred while reinstalling the workflow: %s", e
|
|
351
|
+
logger.exception("An error occurred while reinstalling the workflow: %s", e)
|
|
266
352
|
click.echo(f"An error occurred while reinstalling the workflow: {e}")
|
|
267
353
|
|
|
268
354
|
|
|
@@ -309,7 +395,7 @@ def delete_command(workflow_name: str):
|
|
|
309
395
|
|
|
310
396
|
click.echo(f"Workflow '{workflow_name}' deleted successfully.")
|
|
311
397
|
except Exception as e:
|
|
312
|
-
logger.exception("An error occurred while deleting the workflow: %s", e
|
|
398
|
+
logger.exception("An error occurred while deleting the workflow: %s", e)
|
|
313
399
|
click.echo(f"An error occurred while deleting the workflow: {e}")
|
|
314
400
|
|
|
315
401
|
|
nat/cli/entrypoint.py
CHANGED
|
@@ -29,10 +29,16 @@ import time
|
|
|
29
29
|
|
|
30
30
|
import click
|
|
31
31
|
import nest_asyncio
|
|
32
|
+
from dotenv import load_dotenv
|
|
33
|
+
|
|
34
|
+
from nat.utils.log_levels import LOG_LEVELS
|
|
32
35
|
|
|
33
36
|
from .commands.configure.configure import configure_command
|
|
34
37
|
from .commands.evaluate import eval_command
|
|
35
38
|
from .commands.info.info import info_command
|
|
39
|
+
from .commands.mcp.mcp import mcp_command
|
|
40
|
+
from .commands.object_store.object_store import object_store_command
|
|
41
|
+
from .commands.optimize import optimizer_command
|
|
36
42
|
from .commands.registry.registry import registry_command
|
|
37
43
|
from .commands.sizing.sizing import sizing
|
|
38
44
|
from .commands.start import start_command
|
|
@@ -40,23 +46,21 @@ from .commands.uninstall import uninstall_command
|
|
|
40
46
|
from .commands.validate import validate_command
|
|
41
47
|
from .commands.workflow.workflow import workflow_command
|
|
42
48
|
|
|
49
|
+
# Load environment variables from .env file, if it exists
|
|
50
|
+
load_dotenv()
|
|
51
|
+
|
|
43
52
|
# Apply at the beginning of the file to avoid issues with asyncio
|
|
44
53
|
nest_asyncio.apply()
|
|
45
54
|
|
|
46
|
-
# Define log level choices
|
|
47
|
-
LOG_LEVELS = {
|
|
48
|
-
'DEBUG': logging.DEBUG,
|
|
49
|
-
'INFO': logging.INFO,
|
|
50
|
-
'WARNING': logging.WARNING,
|
|
51
|
-
'ERROR': logging.ERROR,
|
|
52
|
-
'CRITICAL': logging.CRITICAL
|
|
53
|
-
}
|
|
54
|
-
|
|
55
55
|
|
|
56
56
|
def setup_logging(log_level: str):
|
|
57
57
|
"""Configure logging with the specified level"""
|
|
58
58
|
numeric_level = LOG_LEVELS.get(log_level.upper(), logging.INFO)
|
|
59
|
-
logging.basicConfig(
|
|
59
|
+
logging.basicConfig(
|
|
60
|
+
level=numeric_level,
|
|
61
|
+
format="%(asctime)s - %(levelname)-8s - %(name)s:%(lineno)d - %(message)s",
|
|
62
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
|
63
|
+
)
|
|
60
64
|
return numeric_level
|
|
61
65
|
|
|
62
66
|
|
|
@@ -107,11 +111,13 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
107
111
|
cli.add_command(validate_command, name="validate")
|
|
108
112
|
cli.add_command(workflow_command, name="workflow")
|
|
109
113
|
cli.add_command(sizing, name="sizing")
|
|
114
|
+
cli.add_command(optimizer_command, name="optimize")
|
|
115
|
+
cli.add_command(object_store_command, name="object-store")
|
|
116
|
+
cli.add_command(mcp_command, name="mcp")
|
|
110
117
|
|
|
111
118
|
# Aliases
|
|
112
119
|
cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
|
|
113
120
|
cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
|
|
114
|
-
cli.add_command(start_command.get_command(None, "mcp"), name="mcp")
|
|
115
121
|
|
|
116
122
|
|
|
117
123
|
@cli.result_callback()
|
nat/cli/main.py
CHANGED
nat/cli/register_workflow.py
CHANGED
|
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
|
|
|
27
27
|
from nat.cli.type_registry import FrontEndBuildCallableT
|
|
28
28
|
from nat.cli.type_registry import FrontEndRegisteredCallableT
|
|
29
29
|
from nat.cli.type_registry import FunctionBuildCallableT
|
|
30
|
+
from nat.cli.type_registry import FunctionGroupBuildCallableT
|
|
31
|
+
from nat.cli.type_registry import FunctionGroupRegisteredCallableT
|
|
30
32
|
from nat.cli.type_registry import FunctionRegisteredCallableT
|
|
31
33
|
from nat.cli.type_registry import LLMClientBuildCallableT
|
|
32
34
|
from nat.cli.type_registry import LLMClientRegisteredCallableT
|
|
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
|
|
|
60
62
|
from nat.data_models.evaluator import EvaluatorBaseConfigT
|
|
61
63
|
from nat.data_models.front_end import FrontEndConfigT
|
|
62
64
|
from nat.data_models.function import FunctionConfigT
|
|
65
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
63
66
|
from nat.data_models.llm import LLMBaseConfigT
|
|
64
67
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
65
68
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
155
158
|
|
|
156
159
|
context_manager_fn = asynccontextmanager(fn)
|
|
157
160
|
|
|
158
|
-
|
|
159
|
-
framework_wrappers_list: list[str] = []
|
|
160
|
-
else:
|
|
161
|
-
framework_wrappers_list = list(framework_wrappers)
|
|
161
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
162
162
|
|
|
163
163
|
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
164
164
|
component_type=ComponentEnum.FUNCTION)
|
|
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
177
177
|
return register_function_inner
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
181
|
+
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
182
|
+
"""
|
|
183
|
+
Register a function group with optional framework_wrappers for automatic profiler hooking.
|
|
184
|
+
Function groups share configuration/resources across multiple functions.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
def register_function_group_inner(
|
|
188
|
+
fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
|
|
189
|
+
) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
|
|
190
|
+
from .type_registry import GlobalTypeRegistry
|
|
191
|
+
from .type_registry import RegisteredFunctionGroupInfo
|
|
192
|
+
|
|
193
|
+
context_manager_fn = asynccontextmanager(fn)
|
|
194
|
+
|
|
195
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
196
|
+
|
|
197
|
+
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
198
|
+
component_type=ComponentEnum.FUNCTION_GROUP)
|
|
199
|
+
|
|
200
|
+
GlobalTypeRegistry.get().register_function_group(
|
|
201
|
+
RegisteredFunctionGroupInfo(
|
|
202
|
+
full_type=config_type.full_type,
|
|
203
|
+
config_type=config_type,
|
|
204
|
+
build_fn=context_manager_fn,
|
|
205
|
+
framework_wrappers=framework_wrappers_list,
|
|
206
|
+
discovery_metadata=discovery_metadata,
|
|
207
|
+
))
|
|
208
|
+
|
|
209
|
+
return context_manager_fn
|
|
210
|
+
|
|
211
|
+
return register_function_group_inner
|
|
212
|
+
|
|
213
|
+
|
|
180
214
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
181
215
|
|
|
182
216
|
def register_llm_provider_inner(
|
nat/cli/type_registry.py
CHANGED
|
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
|
|
|
37
37
|
from nat.builder.evaluator import EvaluatorInfo
|
|
38
38
|
from nat.builder.front_end import FrontEndBase
|
|
39
39
|
from nat.builder.function import Function
|
|
40
|
+
from nat.builder.function import FunctionGroup
|
|
40
41
|
from nat.builder.function_base import FunctionBase
|
|
41
42
|
from nat.builder.function_info import FunctionInfo
|
|
42
43
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
|
|
|
55
56
|
from nat.data_models.front_end import FrontEndConfigT
|
|
56
57
|
from nat.data_models.function import FunctionBaseConfig
|
|
57
58
|
from nat.data_models.function import FunctionConfigT
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
60
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
58
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
59
62
|
from nat.data_models.llm import LLMBaseConfigT
|
|
60
63
|
from nat.data_models.logging import LoggingBaseConfig
|
|
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
|
|
|
85
88
|
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
|
|
86
89
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
87
90
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
|
+
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
88
92
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
89
93
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
90
94
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
|
|
|
106
110
|
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
|
|
107
111
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
108
112
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
|
+
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
109
114
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
110
115
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
111
116
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
178
183
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
179
184
|
|
|
180
185
|
|
|
186
|
+
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
187
|
+
"""
|
|
188
|
+
Represents a registered function group. Function groups are collections of functions that share configuration
|
|
189
|
+
and resources.
|
|
190
|
+
"""
|
|
191
|
+
|
|
192
|
+
build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
|
|
193
|
+
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
|
+
|
|
195
|
+
|
|
181
196
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
182
197
|
"""
|
|
183
198
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -298,7 +313,7 @@ class RegisteredPackage(BaseModel):
|
|
|
298
313
|
discovery_metadata: DiscoveryMetadata
|
|
299
314
|
|
|
300
315
|
|
|
301
|
-
class TypeRegistry:
|
|
316
|
+
class TypeRegistry:
|
|
302
317
|
|
|
303
318
|
def __init__(self) -> None:
|
|
304
319
|
# Telemetry Exporters
|
|
@@ -313,6 +328,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
313
328
|
# Functions
|
|
314
329
|
self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
|
|
315
330
|
|
|
331
|
+
# Function Groups
|
|
332
|
+
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
|
+
|
|
316
334
|
# LLMs
|
|
317
335
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
318
336
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -478,6 +496,50 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
478
496
|
|
|
479
497
|
return list(self._registered_functions.values())
|
|
480
498
|
|
|
499
|
+
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
|
|
500
|
+
"""Register a function group with the type registry.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
registration: The function group registration information
|
|
504
|
+
|
|
505
|
+
Raises:
|
|
506
|
+
ValueError: If a function group with the same config type is already registered
|
|
507
|
+
"""
|
|
508
|
+
if (registration.config_type in self._registered_function_groups):
|
|
509
|
+
raise ValueError(
|
|
510
|
+
f"A function group with the same config type `{registration.config_type}` has already been "
|
|
511
|
+
"registered.")
|
|
512
|
+
|
|
513
|
+
self._registered_function_groups[registration.config_type] = registration
|
|
514
|
+
|
|
515
|
+
self._registration_changed()
|
|
516
|
+
|
|
517
|
+
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
|
|
518
|
+
"""Get a registered function group by its config type.
|
|
519
|
+
|
|
520
|
+
Args:
|
|
521
|
+
config_type: The function group configuration type
|
|
522
|
+
|
|
523
|
+
Returns:
|
|
524
|
+
RegisteredFunctionGroupInfo: The registered function group information
|
|
525
|
+
|
|
526
|
+
Raises:
|
|
527
|
+
KeyError: If no function group is registered for the given config type
|
|
528
|
+
"""
|
|
529
|
+
try:
|
|
530
|
+
return self._registered_function_groups[config_type]
|
|
531
|
+
except KeyError as err:
|
|
532
|
+
raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
|
|
533
|
+
f"Registered configs: {set(self._registered_function_groups.keys())}") from err
|
|
534
|
+
|
|
535
|
+
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
|
|
536
|
+
"""Get all registered function groups.
|
|
537
|
+
|
|
538
|
+
Returns:
|
|
539
|
+
list[RegisteredInfo[FunctionGroupBaseConfig]]: List of all registered function groups
|
|
540
|
+
"""
|
|
541
|
+
return list(self._registered_function_groups.values())
|
|
542
|
+
|
|
481
543
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
482
544
|
|
|
483
545
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -588,8 +650,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
588
650
|
except KeyError as err:
|
|
589
651
|
raise KeyError(
|
|
590
652
|
f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
|
|
591
|
-
"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
|
|
592
|
-
"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
|
|
653
|
+
f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
|
|
654
|
+
f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
|
|
593
655
|
"Please provide an Embedder configuration from one of the following providers: "
|
|
594
656
|
f"{set(self._embedder_client_provider_to_framework.keys())}") from err
|
|
595
657
|
|
|
@@ -703,8 +765,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
703
765
|
except KeyError as err:
|
|
704
766
|
raise KeyError(
|
|
705
767
|
f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
|
|
706
|
-
"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
|
|
707
|
-
"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
|
|
768
|
+
f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
|
|
769
|
+
f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
|
|
708
770
|
"Please provide a Retriever configuration from one of the following providers: "
|
|
709
771
|
f"{set(self._retriever_client_provider_to_framework.keys())}") from err
|
|
710
772
|
|
|
@@ -779,7 +841,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
779
841
|
|
|
780
842
|
self._registration_changed()
|
|
781
843
|
|
|
782
|
-
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
844
|
+
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
783
845
|
|
|
784
846
|
if component_type == ComponentEnum.FRONT_END:
|
|
785
847
|
return self._registered_front_end_infos
|
|
@@ -790,6 +852,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
790
852
|
if component_type == ComponentEnum.FUNCTION:
|
|
791
853
|
return self._registered_functions
|
|
792
854
|
|
|
855
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
856
|
+
return self._registered_function_groups
|
|
857
|
+
|
|
793
858
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
794
859
|
return self._registered_tool_wrappers
|
|
795
860
|
|
|
@@ -849,12 +914,14 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
849
914
|
|
|
850
915
|
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
851
916
|
|
|
852
|
-
def get_registered_types_by_component_type(
|
|
853
|
-
self, component_type: ComponentEnum) -> list[str]:
|
|
917
|
+
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
|
|
854
918
|
|
|
855
919
|
if component_type == ComponentEnum.FUNCTION:
|
|
856
920
|
return [i.static_type() for i in self._registered_functions]
|
|
857
921
|
|
|
922
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
923
|
+
return [i.static_type() for i in self._registered_function_groups]
|
|
924
|
+
|
|
858
925
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
859
926
|
return list(self._registered_tool_wrappers)
|
|
860
927
|
|
|
@@ -925,8 +992,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
925
992
|
if (short_names[key.local_name] == 1):
|
|
926
993
|
type_list.append((key.local_name, key.config_type))
|
|
927
994
|
|
|
928
|
-
|
|
929
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
995
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
930
996
|
|
|
931
997
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
|
932
998
|
|
|
@@ -945,6 +1011,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
945
1011
|
if issubclass(cls, FunctionBaseConfig):
|
|
946
1012
|
return self._do_compute_annotation(cls, self.get_registered_functions())
|
|
947
1013
|
|
|
1014
|
+
if issubclass(cls, FunctionGroupBaseConfig):
|
|
1015
|
+
return self._do_compute_annotation(cls, self.get_registered_function_groups())
|
|
1016
|
+
|
|
948
1017
|
if issubclass(cls, LLMBaseConfig):
|
|
949
1018
|
return self._do_compute_annotation(cls, self.get_registered_llm_providers())
|
|
950
1019
|
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
|
|
18
|
+
# Import any control flows which need to be automatically registered here
|
|
19
|
+
from . import sequential_executor
|
|
20
|
+
from .router_agent import register
|
|
File without changes
|