nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from nat.builder.builder import Builder
|
|
19
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
20
|
+
from nat.cli.register_workflow import register_ttc_strategy
|
|
21
|
+
from nat.experimental.test_time_compute.models.editor_config import MotivationAwareSummarizationConfig
|
|
22
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
23
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
24
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
25
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
26
|
+
from nat.utils.io.model_processing import remove_r1_think_tags
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MotivationAwareSummarization(StrategyBase):
    """
    A TTC editing strategy that, for each incoming TTCItem, asks an editor LLM to
    summarize the item's output in light of its input (the task) and its metadata
    (the motivation).

    The prompt is built from ``config.editor_template``, which must use exactly the
    template variables ``task``, ``motivation`` and ``output``.
    """

    def __init__(self, config: MotivationAwareSummarizationConfig) -> None:
        super().__init__(config)
        self.config = config
        # Populated by build_components(); None until then.
        self.llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Resolve ``self.config.editor_llm`` into a LangChain-wrapped LLM client and store it
        on ``self.llm_bound``.

        Args:
            builder: The workflow builder used to look up the configured LLM.
        """
        self.llm_bound = await builder.get_llm(self.config.editor_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        return [PipelineTypeEnum.TOOL_USE]

    def stage_type(self) -> StageTypeEnum:
        return StageTypeEnum.EDITING

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Summarize the output of each TTCItem with the editor LLM.

        For every item, ``config.editor_template`` is rendered with ``task`` (the item's
        input), ``motivation`` (the item's metadata) and ``output`` (the item's output).
        The rendered prompt is sent to the editor LLM. A new TTCItem is produced whose
        ``output`` is the LLM response with R1-style think tags removed; ``input``,
        ``metadata`` and ``name`` are carried over unchanged.

        Args:
            items: The items whose outputs should be summarized.
            original_prompt: Unused by this strategy.
            agent_context: Unused by this strategy.
            **kwargs: Unused by this strategy.

        Returns:
            One new TTCItem per input item, in the same order.

        Raises:
            ImportError: If ``langchain-core`` is not installed.
            RuntimeError: If ``build_components`` has not been awaited yet.
        """
        try:
            from langchain_core.prompts import PromptTemplate
        except ImportError as err:
            raise ImportError("langchain-core is required for MotivationAwareSummarization. "
                              "Install nvidia-nat-langchain or similar.") from err

        if self.llm_bound is None:
            raise RuntimeError("MotivationAwareSummarization.build_components() must be awaited before ainvoke().")

        # Build the template once; validate_template checks it against the declared variables.
        query_template = PromptTemplate(template=self.config.editor_template,
                                        input_variables=["task", "motivation", "output"],
                                        validate_template=True)

        new_ttc_items: list[TTCItem] = []
        for item in items:
            # Guard against None explicitly: str(None) would inject the literal "None" into the prompt.
            original_task = str(item.input) if item.input is not None else ""
            motivation = str(item.metadata) if item.metadata else ""
            output = str(item.output) if item.output else ""

            prompt = await query_template.ainvoke(input={
                "task": original_task, "motivation": motivation, "output": output
            })

            llm_response = await self.llm_bound.ainvoke(prompt.to_string())
            summary = remove_r1_think_tags(llm_response.content)

            logger.info("LLM response from summarization: %s", summary)

            new_ttc_items.append(
                TTCItem(
                    input=item.input,
                    output=summary,
                    metadata=item.metadata,
                    name=item.name,  # keep the original tool name
                ))

        return new_ttc_items
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@register_ttc_strategy(config_type=MotivationAwareSummarizationConfig)
async def register_multi_query_retrieval_search(config: MotivationAwareSummarizationConfig, builder: Builder):
    """
    Build a MotivationAwareSummarization strategy from its config and yield it to the registry.

    NOTE(review): despite its name, this registers the motivation-aware summarization
    strategy, not a multi-query retrieval search.
    """
    summarizer = MotivationAwareSummarization(config)
    # Bind the editor LLM before handing the strategy out.
    await summarizer.build_components(builder=builder)
    yield summarizer
|
|
File without changes
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from nat.builder.builder import Builder
|
|
21
|
+
from nat.builder.function import Function
|
|
22
|
+
from nat.builder.function_info import FunctionInfo
|
|
23
|
+
from nat.cli.register_workflow import register_function
|
|
24
|
+
from nat.data_models.component_ref import FunctionRef
|
|
25
|
+
from nat.data_models.component_ref import TTCStrategyRef
|
|
26
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
27
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
28
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
29
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ExecuteScoreSelectFunctionConfig(FunctionBaseConfig, name="execute_score_select_function"):
    """
    Configuration for a function that runs ``augmented_fn`` ``num_executions`` times,
    optionally scores the outputs with ``scorer``, and returns the output chosen by ``selector``.
    """
    scorer: TTCStrategyRef | None = Field(description="Strategy to score the output of the function", default=None)
    selector: TTCStrategyRef = Field(description="Strategy to select the best output of the function")
    augmented_fn: FunctionRef = Field(description="Function that will be executed")

    # At least one execution is required: with zero there is nothing to select from.
    num_executions: int = Field(default=3, ge=1, description="Number of times to execute the function")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@register_function(config_type=ExecuteScoreSelectFunctionConfig)
async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig, builder: Builder):
    """
    Register a function that executes ``config.augmented_fn`` several times concurrently,
    optionally scores the outputs, and returns the one chosen by the selector strategy.

    Raises:
        ValueError: If ``config.num_executions`` is less than 1.
    """
    import asyncio
    import warnings

    from pydantic import BaseModel

    if config.num_executions < 1:
        raise ValueError(f"num_executions must be at least 1, got {config.num_executions}")

    executable_fn: Function = builder.get_function(name=config.augmented_fn)

    if config.scorer:
        scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
                                                pipeline_type=PipelineTypeEnum.AGENT_EXECUTION,
                                                stage_type=StageTypeEnum.SCORING)
    else:
        scorer = None

    selector = await builder.get_ttc_strategy(strategy_name=config.selector,
                                              pipeline_type=PipelineTypeEnum.AGENT_EXECUTION,
                                              stage_type=StageTypeEnum.SELECTION)

    if executable_fn.has_streaming_output:
        warnings.warn("Streaming output is not supported for this function. "
                      "The function will be executed in non-streaming mode.")

    def convert_to_str(arg):
        """Render pydantic models via their dumped dict, everything else via str()."""
        if isinstance(arg, BaseModel):
            return str(arg.model_dump())
        return str(arg)

    async def execute_fn(input_msg: executable_fn.input_type) -> executable_fn.single_output_type:
        """
        Run the augmented function ``num_executions`` times and return the selected result.

        Returns the raw result of the execution whose string form matches the selector's
        choice. If the selector's output matches none of the executions, that output string
        is returned as-is.
        """
        logger.info("Executing function %d times", config.num_executions)
        tasks = [executable_fn.ainvoke(input_msg) for _ in range(config.num_executions)]
        results = await asyncio.gather(*tasks)

        input_str = convert_to_str(input_msg)
        # function_outputs stays index-aligned with results; its_items may be replaced by the scorer.
        function_outputs = [convert_to_str(out) for out in results]
        its_items = [TTCItem(input=input_str, output=out) for out in function_outputs]

        if scorer:
            logger.info("Beginning scoring")
            its_items = await scorer.ainvoke(items=its_items)

        logger.info("Beginning selection")
        selected_items = await selector.ainvoke(items=its_items, original_prompt=input_str)
        if not selected_items:
            raise RuntimeError("Selector returned no items; cannot choose an output.")
        selected_output = selected_items[0].output

        # Map back through function_outputs rather than its_items: a scorer may reorder the
        # items, which would otherwise make the index point at the wrong entry of results.
        try:
            selected_index = function_outputs.index(selected_output)
        except ValueError:
            # The selector produced an output that is not one of the executions (e.g. a merge).
            return selected_output
        return results[selected_index]

    yield FunctionInfo.from_fn(
        fn=execute_fn,
        description=("This function executes a given function multiple times, scores the outputs, "
                     "and selects the best output based on the specified scoring and selection strategies."),
    )
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import AsyncGenerator
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.builder.builder import Builder
|
|
22
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
|
+
from nat.builder.function_info import FunctionInfo
|
|
24
|
+
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.api_server import ChatRequest
|
|
26
|
+
from nat.data_models.component_ref import FunctionRef
|
|
27
|
+
from nat.data_models.component_ref import TTCStrategyRef
|
|
28
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
29
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
30
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
31
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PlanSelectExecuteFunctionConfig(FunctionBaseConfig, name="plan_select_execute_function"):
    """
    Configuration for a NAT function that plans before delegating to another function.

    A TTC planning pipeline (planner, optional editor, optional scorer, selector) produces a plan
    for the incoming request. The plan is rendered into ``downstream_template`` and passed on to
    ``augmented_fn``.

    Designed to be used with an InterceptingFunction.
    """

    # Function that ultimately receives the plan-augmented prompt.
    augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")

    # TTC strategies forming the planning pipeline; editor and scorer stages are optional.
    planner: TTCStrategyRef = Field(description="The configuration for the planner.")
    editor: TTCStrategyRef | None = Field(default=None, description="The configuration for the editor.")
    scorer: TTCStrategyRef | None = Field(default=None, description="The configuration for the scorer.")
    selector: TTCStrategyRef = Field(description="The configuration for the selector.")

    verbose: bool = Field(description="Whether to log detailed information.", default=False)

    # Rendered with the ``description`` and ``tools`` placeholders.
    agent_context_prompt_template: str = Field(
        default=("\nThe agent system has the following description:\n{description}\n"
                 "And has access to the following tools with functionality:\n{tools}\n\n"),
        description="The template for the agent context prompt. This prompt is used to provide context about the agent",
    )

    # Rendered with the ``input_text`` and ``reasoning_output`` placeholders.
    downstream_template: str = Field(
        default=("Answer the following question based on message history: {input_text}\n\n"
                 "Here is a plan for execution that you could use to guide you if you wanted to:\n\n"
                 "{reasoning_output}\n\n"
                 "NOTE: Remember to follow your guidance on how to format output, etc.\n\n"
                 " You must respond with the answer to the original question directly to the user."),
        description=("The template for the downstream prompt. "
                     "This prompt is used to provide the reasoning output to the executing agent"),
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@register_function(config_type=PlanSelectExecuteFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig, builder: Builder):
    """
    Build a plan-select-execute function from the provided config.

    For every incoming ``ChatRequest`` the built function does the following:
    - runs the TTC planning pipeline (planner, optional editor, optional scorer, selector);
    - renders the selected plan into ``config.downstream_template``;
    - forwards the rendered prompt to ``config.augmented_fn``.

    Streaming and single-output entry points are exposed only when the augmented function supports
    them.

    Args:
        config (PlanSelectExecuteFunctionConfig): The config for the function.
        builder (Builder): The Builder used to resolve the augmented function, its tool
            dependencies, and the TTC strategies.

    Yields:
        FunctionInfo: The built function, reusing the augmented function's converters.

    Raises:
        ImportError: If ``langchain-core`` is not installed.
        ValueError: If the augmented function has no description.
    """

    try:
        from langchain_core.prompts import PromptTemplate
    except ImportError as e:
        raise ImportError("langchain-core is not installed. Please install it to use plan_select_execute_function.\n"
                          "This error can be resolved by installing nvidia-nat-langchain.") from e

    # Get the augmented function's description
    augmented_function = builder.get_function(config.augmented_fn)

    # For now, we rely on runtime checking for type conversion

    # The description is embedded in the planning context, so it is mandatory.
    augmented_function_desc = augmented_function.description
    if not augmented_function_desc:
        raise ValueError(f"Function {config.augmented_fn} does not have a description. Cannot augment "
                         f"function without a description.")

    # Describe the tools the augmented function depends on, one "- name: description" line each.
    function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
    tool_lines = ["Tool: Description\n"]
    for tool in function_used_tools:
        tool_impl = builder.get_function(tool)
        # Render a missing/None description as empty rather than the literal string "None".
        tool_desc = getattr(tool_impl, "description", None) or ""
        tool_lines.append(f"- {tool}: {tool_desc}\n")
    tool_list = "".join(tool_lines)

    # Draft the reasoning prompt for the augmented function
    template = PromptTemplate(template=config.agent_context_prompt_template,
                              input_variables=["description", "tools"],
                              validate_template=True)

    downstream_template = PromptTemplate(template=config.downstream_template,
                                         input_variables=["input_text", "reasoning_output"],
                                         validate_template=True)

    # The agent context depends only on build-time values, so render it once instead of per call.
    context_prompt = (await template.ainvoke(input={
        "description": augmented_function_desc, "tools": tool_list
    })).to_string()

    planner = await builder.get_ttc_strategy(strategy_name=config.planner,
                                             pipeline_type=PipelineTypeEnum.PLANNING,
                                             stage_type=StageTypeEnum.SEARCH)

    selector = await builder.get_ttc_strategy(strategy_name=config.selector,
                                              pipeline_type=PipelineTypeEnum.PLANNING,
                                              stage_type=StageTypeEnum.SELECTION)

    if config.editor:
        editor = await builder.get_ttc_strategy(strategy_name=config.editor,
                                                pipeline_type=PipelineTypeEnum.PLANNING,
                                                stage_type=StageTypeEnum.EDITING)
    else:
        editor = None

    if config.scorer:
        scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
                                                pipeline_type=PipelineTypeEnum.PLANNING,
                                                stage_type=StageTypeEnum.SCORING)
    else:
        scorer = None

    async def planning_pipeline(prompt, context):
        """Run planner -> (editor) -> (scorer) -> selector and return the first selected item."""
        plans = await planner.ainvoke([TTCItem()], prompt, context)

        if editor:
            plans = await editor.ainvoke(plans, prompt, context)
        if scorer:
            plans = await scorer.ainvoke(plans, prompt, context)

        selected_plan = (await selector.ainvoke(plans, prompt, context))[0]

        return selected_plan

    async def _build_downstream_input(input_message: ChatRequest) -> str:
        """Plan for ``input_message`` and render the prompt forwarded to the augmented function."""
        input_text = "".join(str(message.model_dump()) + "\n" for message in input_message.messages)

        # Run the TTC pipeline
        planning_item: TTCItem = await planning_pipeline(prompt=input_text, context=context_prompt)

        output = await downstream_template.ainvoke(input={
            "input_text": input_text, "reasoning_output": planning_item.plan
        })
        output = output.to_string()

        if config.verbose:
            logger.info("Reasoning plan and input to agent: \n\n%s", output)

        return output

    streaming_inner_fn = None
    single_inner_fn = None

    if augmented_function.has_streaming_output:

        async def streaming_inner(
                input_message: ChatRequest) -> AsyncGenerator[augmented_function.streaming_output_type]:
            """
            Plan for the input and stream the augmented function's response.

            Args:
                input_message (ChatRequest): The input text to reason on.
            """
            downstream_input = await _build_downstream_input(input_message)
            async for chunk in augmented_function.acall_stream(downstream_input):
                yield chunk

        streaming_inner_fn = streaming_inner

    if augmented_function.has_single_output:

        async def single_inner(input_message: ChatRequest) -> augmented_function.single_output_type:
            """
            Plan for the input and return the augmented function's single response.

            Args:
                input_message (ChatRequest): The input text to reason on.
            """
            downstream_input = await _build_downstream_input(input_message)
            return await augmented_function.acall_invoke(downstream_input)

        single_inner_fn = single_inner

    yield FunctionInfo.create(
        single_fn=single_inner_fn,
        stream_fn=streaming_inner_fn,
        description=("Function that runs an TTC execution planner on input and sends plan downstream"),
        converters=augmented_function.converter_list)
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.builder.builder import Builder
|
|
22
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
|
+
from nat.builder.function_info import FunctionInfo
|
|
24
|
+
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.component_ref import FunctionRef
|
|
26
|
+
from nat.data_models.component_ref import TTCStrategyRef
|
|
27
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
28
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
29
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
30
|
+
from nat.experimental.test_time_compute.models.tool_use_config import ToolUseInputSchema
|
|
31
|
+
from nat.experimental.test_time_compute.models.tool_use_config import ToolUselist
|
|
32
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TTCToolOrchestrationFunctionConfig(FunctionBaseConfig, name="ttc_tool_orchestration"):
    """
    Configuration for the TTC tool-orchestration function.

    The function dispatches a list of tool requests to several wrapped functions and post-processes
    the results with TTC strategies. Every strategy except selection is optional; leaving one unset
    skips that stage.
    """

    # Functions that may be dispatched to, looked up by tool name.
    augmented_fns: list[FunctionRef] = Field(
        description=("list of FunctionRefs for the functions to be orchestrated. "
                     "Must be wrapped in `ttc_tool_wrapper`."))

    search_strategy: TTCStrategyRef | None = Field(
        default=None,
        description=("The TTC search strategy to use for orchestrating invocation of the functions. "
                     "If None, no search will be performed."),
    )

    editing_strategy: TTCStrategyRef | None = Field(
        description=("The TTC editing strategy to use for orchestrating invocation of the functions. "
                     "If None, no editing will be performed."),
        default=None,
    )

    scoring_strategy: TTCStrategyRef | None = Field(
        description=("The TTC scoring strategy to use for orchestrating invocation of the functions. "
                     "If None, no scoring will be performed."),
        default=None,
    )

    # Required: selection always runs.
    selection_strategy: TTCStrategyRef = Field(
        description="The TTC selection strategy to use for orchestrating invocation of the functions.", )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@register_function(config_type=TTCToolOrchestrationFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def register_ttc_tool_orchestration_function(
    config: TTCToolOrchestrationFunctionConfig,
    builder: Builder,
):
    """
    Registers a TTC-based orchestration function that:
      1. Instantiates all relevant strategies (search, editing, scoring, selection).
      2. Accepts a ToolUselist, converts each item to a TTCItem, and optionally runs search.
      3. Concurrently calls the matching augmented_fn for each item (looked up by tool name).
      4. Optionally runs editing and scoring on the results, then always runs selection.
      5. Returns a new ToolUselist with each output set.
    """

    # 1) Gather references to all augmented (wrapped) functions
    function_map = {fn_ref: builder.get_function(fn_ref) for fn_ref in config.augmented_fns}

    async def _get_strategy(strategy_ref, stage_type):
        """Resolve a TOOL_USE strategy for ``stage_type``, or return None when it is not configured."""
        if strategy_ref is None:
            return None
        return await builder.get_ttc_strategy(
            strategy_name=strategy_ref,
            pipeline_type=PipelineTypeEnum.TOOL_USE,
            stage_type=stage_type,
        )

    # 2) Instantiate search, editing, scoring, selection strategies (if any)
    search = await _get_strategy(config.search_strategy, StageTypeEnum.SEARCH)
    editing = await _get_strategy(config.editing_strategy, StageTypeEnum.EDITING)
    scoring = await _get_strategy(config.scoring_strategy, StageTypeEnum.SCORING)
    selection = await _get_strategy(config.selection_strategy, StageTypeEnum.SELECTION)

    fn_description = ("\n".join(f"- **{fn_ref}**: {function_map[fn_ref].description or 'No description provided.'}"
                                for fn_ref in config.augmented_fns))

    async def _invoke_tool(item: TTCItem, fn):
        """Invoke ``fn`` for ``item``; return ``(item, result, error_message_or_None)`` instead of raising."""
        # A search strategy may populate ``item.output`` with the query to send to the tool. Without
        # one, ``output`` is still None, so fall back to the task description rather than sending None.
        tool_input = item.output if item.output is not None else item.input
        try:
            result = await fn.acall_invoke(tool_input)
            return item, result, None
        except Exception as e:
            # Deliberate best-effort: one failing tool must not abort the whole orchestration.
            logger.exception("Error invoking function '%s'", item.name)
            return item, None, str(e)

    # 3) Create the inner function to handle single (non-streaming) calls.
    async def single_inner(tool_list: ToolUselist) -> ToolUselist:
        """
        Orchestrates multiple tool usages, optionally using search/editing/scoring/selection steps.
        """
        # Convert each ToolUseInputSchema to TTCItem
        ttc_items = [
            TTCItem(
                input=t.task_description,  # The user "task"
                output=None,
                name=t.tool_name,  # The "tool name"
                metadata=t.motivation,  # The "justification"
            ) for t in tool_list.tools
        ]

        # Run search strategy if present
        if search is not None:
            ttc_items = await search.ainvoke(ttc_items)

        logger.info("TTC orchestration function: %d items after search", len(ttc_items))

        # Invoke the correct augmented function for each item concurrently
        tasks = []
        for item in ttc_items:
            if item.name not in function_map:
                logger.error("Function '%s' not found in function map.", item.name)
                item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
            else:
                tasks.append(_invoke_tool(item, function_map[item.name]))

        # Await all tasks and assign outputs
        if tasks:
            for item, result, error in await asyncio.gather(*tasks):
                # Compare against None: an exception with an empty message still signals failure.
                if error is not None:
                    item.output = f"Error invoking function '{item.name}': {error}"
                else:
                    item.output = result

        if editing is not None:
            ttc_items = await editing.ainvoke(ttc_items)

        # Run scoring strategy if present
        if scoring is not None:
            ttc_items = await scoring.ainvoke(ttc_items)

        # Run selection strategy
        if selection is not None:
            ttc_items = await selection.ainvoke(ttc_items)

        logger.info("TTC orchestration function: %d items after selection", len(ttc_items))

        # Convert final results from TTCItems back to a ToolUselist
        final_list = ToolUselist(tools=[])
        for item in ttc_items:
            # Compose a new ToolUseInputSchema with final output
            new_tool = ToolUseInputSchema(
                tool_name=item.name,
                task_description=str(item.input),
                motivation=item.metadata if item.metadata else None,
                output=str(item.output) if item.output is not None else None,
            )
            final_list.tools.append(new_tool)

        return final_list

    # 4) Return the function info (only a single_fn is needed; no streaming)
    yield FunctionInfo.create(
        single_fn=single_inner,
        stream_fn=None,  # No streaming required
        input_schema=ToolUselist,
        single_output_schema=ToolUselist,
        description=fn_description)
|