nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import re
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
|
|
22
|
+
from nat.builder.builder import Builder
|
|
23
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
24
|
+
from nat.builder.function_info import FunctionInfo
|
|
25
|
+
from nat.cli.register_workflow import register_function
|
|
26
|
+
from nat.data_models.api_server import ChatRequest
|
|
27
|
+
from nat.data_models.component_ref import FunctionRef
|
|
28
|
+
from nat.data_models.component_ref import LLMRef
|
|
29
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ReasoningFunctionConfig(FunctionBaseConfig, name="reasoning_agent"):
    """
    Configuration for the NAT function registered under the name ``reasoning_agent``.

    The function performs a reasoning step on its input and passes the result on to the
    function named by ``augmented_fn``.

    Designed to be used with an InterceptingFunction.
    """

    # Reference to the LLM that produces the reasoning output.
    llm_name: LLMRef = Field(description="The name of the LLM to use for reasoning.")
    # Reference to the function whose input is augmented with the reasoning output.
    augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")
    # When true, the rendered instruction prompt is logged.
    verbose: bool = Field(description="Whether to log detailed information.", default=False)
    # Prompt sent to the reasoning LLM; placeholders: augmented_function_desc, input_text, tools.
    reasoning_prompt_template: str = Field(
        description="The reasoning model prompt template.",
        default=("You are an expert reasoning model task with creating a detailed execution plan "
                 "for a system that has the following description:\n\n"
                 "**Description:** \n{augmented_function_desc}\n\n"
                 "Given the following input and a list of available tools, please provide a detailed step-by-step plan "
                 "that an instruction following system can use to address the input. "
                 "Ensure the plan includes:\n\n"
                 "1. Identifying the key components of the input.\n"
                 "2. Determining the most suitable tools for each task.\n"
                 "3. Outlining the sequence of actions to be taken.\n\n"
                 "**Input:** \n{input_text}\n\n"
                 "**Tools and description of the tool:** \n{tools}\n\n"
                 "An example plan could look like this:\n\n"
                 "1. Call tool A with input X\n"
                 "2. Call tool B with input Y\n"
                 "3. Interpret the output of tool A and B\n"
                 "4. Return the final result\n\n"
                 " **PLAN:**\n"))

    # Prompt handed to the augmented function; placeholders: input_text, reasoning_output.
    instruction_prompt_template: str = Field(
        description="The instruction prompt template.",
        default=("Answer the following question based on message history: {input_text}\n\n"
                 "Here is a plan for execution that you could use to guide you if you wanted to:\n\n"
                 "{reasoning_output}\n\n"
                 "NOTE: Remember to follow your guidance on how to format output, etc.\n\n"
                 " You must respond with the answer to the original question directly to the user."))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@register_function(config_type=ReasoningFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Builder):
    """
    Build the reasoning function described by ``config``.

    The built function asks the configured LLM for an execution plan, drops any leading
    ``<think>...</think>`` section from the LLM output, renders the instruction prompt from the
    original input plus the plan, and forwards that prompt to the augmented function.

    Args:
        config (ReasoningFunctionConfig): The config for the reasoning function.
        builder (Builder): The Builder instance used to resolve the LLM, the augmented function
            and the augmented function's tool dependencies.

    Yields:
        The result of ``FunctionInfo.create`` wrapping the single-output and/or streaming
        variant, depending on what the augmented function supports.

    Raises:
        ValueError: If the augmented function has no description (or an empty one).
    """
    from langchain_core.language_models import BaseChatModel
    from langchain_core.prompts import PromptTemplate

    from nat.agent.base import AGENT_LOG_PREFIX

    # Compiled once; re.DOTALL makes `.` match newlines so multi-line think sections are covered.
    think_pattern = re.compile(r'(<think>)?.*?</think>\s*(.*)', re.DOTALL)

    def remove_r1_think_tags(text: str):
        """Return the text following the first ``</think>``; return ``text`` unchanged if there is none."""
        match = think_pattern.match(text)
        return match.group(2) if match else text

    # Get the LLM to use for reasoning
    llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    # Get the augmented function; for now, we rely on runtime checking for type conversion
    augmented_function = builder.get_function(config.augmented_fn)

    # The description is interpolated into the reasoning prompt, so it is mandatory
    if not augmented_function.description:
        raise ValueError(f"Function {config.augmented_fn} does not have a description. Cannot augment "
                         f"function without a description.")
    augmented_function_desc = augmented_function.description

    # Collect (name, description) for every function the augmented function depends on.
    # A missing or None description becomes "" so the prompt never shows the literal text "None".
    function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
    tool_names_with_desc: list[tuple[str, str]] = [
        (tool, getattr(builder.get_function(tool), "description", None) or "") for tool in function_used_tools
    ]

    # The tools listing does not depend on the request, so render it once
    tools_text = "\n".join([f"- {name}: {desc}" for name, desc in tool_names_with_desc])

    # Draft the reasoning prompt for the augmented function
    template = PromptTemplate(template=config.reasoning_prompt_template,
                              input_variables=["augmented_function_desc", "input_text", "tools"],
                              validate_template=True)

    downstream_template = PromptTemplate(template=config.instruction_prompt_template,
                                         input_variables=["input_text", "reasoning_output"],
                                         validate_template=True)

    async def build_agent_input(input_message: ChatRequest) -> str:
        """Run the reasoning LLM on ``input_message`` and return the rendered instruction prompt."""
        input_text = "".join([str(message.model_dump()) + "\n" for message in input_message.messages])

        prompt = await template.ainvoke(input={
            "augmented_function_desc": augmented_function_desc,
            "input_text": input_text,
            "tools": tools_text,
        })

        # Get the reasoning output from the LLM
        reasoning_output = "".join([chunk.content async for chunk in llm.astream(prompt.to_string())])
        reasoning_output = remove_r1_think_tags(reasoning_output)

        rendered = await downstream_template.ainvoke(input={
            "input_text": input_text, "reasoning_output": reasoning_output
        })
        output = rendered.to_string()

        if config.verbose:
            logger.info("%s Reasoning plan and input to agent: \n\n%s", AGENT_LOG_PREFIX, output)

        return output

    streaming_inner_fn = None
    single_inner_fn = None

    if augmented_function.has_streaming_output:

        async def streaming_inner(
                input_message: ChatRequest) -> AsyncGenerator[augmented_function.streaming_output_type]:
            """
            Perform reasoning on the input and stream the augmented function's output.

            Args:
                input_message (ChatRequest): The input text to reason on.
            """
            output = await build_agent_input(input_message)

            async for chunk in augmented_function.acall_stream(output):
                yield chunk

        streaming_inner_fn = streaming_inner

    if augmented_function.has_single_output:

        async def single_inner(input_message: ChatRequest) -> augmented_function.single_output_type:
            """
            Perform reasoning on the input and return the augmented function's single output.

            Args:
                input_message (ChatRequest): The input text to reason on.
            """
            output = await build_agent_input(input_message)

            return await augmented_function.acall_invoke(output)

        single_inner_fn = single_inner

    yield FunctionInfo.create(
        single_fn=single_inner_fn,
        stream_fn=streaming_inner_fn,
        description=("Reasoning function that generates a detailed execution plan for a system based on the input."),
        converters=augmented_function.converter_list)
|
nat/agent/register.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# pylint: disable=unused-import
|
|
17
|
+
# flake8: noqa
|
|
18
|
+
|
|
19
|
+
# Import any workflows which need to be automatically registered here
|
|
20
|
+
from .react_agent import register as react_agent
|
|
21
|
+
from .reasoning_agent import reasoning_agent
|
|
22
|
+
from .rewoo_agent import register as rewoo_agent
|
|
23
|
+
from .tool_calling_agent import register as tool_calling_agent
|
|
File without changes
|
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
# pylint: disable=R0917
|
|
18
|
+
import logging
|
|
19
|
+
from json import JSONDecodeError
|
|
20
|
+
|
|
21
|
+
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
22
|
+
from langchain_core.language_models import BaseChatModel
|
|
23
|
+
from langchain_core.messages.ai import AIMessage
|
|
24
|
+
from langchain_core.messages.base import BaseMessage
|
|
25
|
+
from langchain_core.messages.human import HumanMessage
|
|
26
|
+
from langchain_core.messages.tool import ToolMessage
|
|
27
|
+
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
28
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
29
|
+
from langchain_core.tools import BaseTool
|
|
30
|
+
from langgraph.graph import StateGraph
|
|
31
|
+
from pydantic import BaseModel
|
|
32
|
+
from pydantic import Field
|
|
33
|
+
|
|
34
|
+
from nat.agent.base import AGENT_CALL_LOG_MESSAGE
|
|
35
|
+
from nat.agent.base import AGENT_LOG_PREFIX
|
|
36
|
+
from nat.agent.base import INPUT_SCHEMA_MESSAGE
|
|
37
|
+
from nat.agent.base import NO_INPUT_ERROR_MESSAGE
|
|
38
|
+
from nat.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
|
|
39
|
+
from nat.agent.base import AgentDecision
|
|
40
|
+
from nat.agent.base import BaseAgent
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ReWOOGraphState(BaseModel):
    """State schema for the ReWOO Agent Graph.

    Every field has a default, so a state can be created with only the fields a caller wants to set.
    The message-typed fields default to a message with empty content rather than ``None``.
    """
    # input and output of the ReWOO Agent
    messages: list[BaseMessage] = Field(default_factory=list)
    # the task provided by user
    task: HumanMessage = Field(default_factory=lambda: HumanMessage(content=""))
    # the plan generated by the planner to solve the task
    plan: AIMessage = Field(default_factory=lambda: AIMessage(content=""))
    # the steps to solve the task, parsed from the plan
    steps: AIMessage = Field(default_factory=lambda: AIMessage(content=""))
    # the intermediate results of each step, keyed by the step's placeholder
    intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict)
    # the final result of the task, generated by the solver
    result: AIMessage = Field(default_factory=lambda: AIMessage(content=""))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ReWOOAgentGraph(BaseAgent):
    """Configurable LangGraph ReWOO Agent.

    The compiled graph has three nodes: a *planner* that asks the LLM for a JSON plan, an *executor* that
    runs one planned tool call per invocation and stores its output under the step's placeholder, and a
    *solver* that asks the LLM for the final answer given the plan annotated with the tool outputs.
    Tool calls are issued with ``max_retries=3``. Argument "detailed_logs" toggles logging of inputs,
    outputs, and intermediate steps.
    """

    def __init__(self,
                 llm: BaseChatModel,
                 planner_prompt: ChatPromptTemplate,
                 solver_prompt: ChatPromptTemplate,
                 tools: list[BaseTool],
                 use_tool_schema: bool = True,
                 callbacks: list[AsyncCallbackHandler] | None = None,
                 detailed_logs: bool = False):
        """Store the prompts and tools and pre-fill the planner prompt's tool variables.

        Args:
            llm: Chat model used by both the planner and the solver.
            planner_prompt: Prompt for the planner; its ``tools`` and ``tool_names`` variables are filled here.
            solver_prompt: Prompt for the solver; its ``plan`` variable is filled on every solver call.
            tools: Tools the executor may call, looked up by ``tool.name``.
            use_tool_schema: When true, each tool description is extended with the tool's input schema.
            callbacks: Callback handlers passed to every LLM and tool invocation.
            detailed_logs: When true, inputs, outputs and intermediate steps are logged.
        """
        super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)

        logger.debug(
            "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
            AGENT_LOG_PREFIX)

        # A plain join never emits a leading or trailing separator, including for a single tool or no tools.
        tool_names = ",".join(tool.name for tool in tools)
        if not use_tool_schema:
            tool_descriptions = [f"{tool.name}: {tool.description}" for tool in tools]
        else:
            logger.debug("%s Adding the tools' input schema to the tools' description", AGENT_LOG_PREFIX)
            tool_descriptions = [
                f"{tool.name}: {tool.description}. {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
                for tool in tools
            ]
        tool_names_and_descriptions = "\n".join(tool_descriptions)

        self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
        self.solver_prompt = solver_prompt
        self.tools_dict = {tool.name: tool for tool in tools}

        logger.debug("%s Initialized ReWOO Agent Graph", AGENT_LOG_PREFIX)

    def _get_tool(self, tool_name: str):
        """Return the configured tool named ``tool_name``, or ``None`` when no such tool is configured."""
        try:
            return self.tools_dict.get(tool_name)
        except Exception as ex:
            # Reached only when the lookup itself fails (for example an unhashable name).
            logger.exception("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
            raise

    @staticmethod
    def _get_current_step(state: ReWOOGraphState) -> int:
        """Return the index of the next step to execute, or ``-1`` when every step has a result.

        Progress is measured by the number of entries in ``state.intermediate_results``.

        Raises:
            RuntimeError: If the state holds no steps.
        """
        steps = state.steps.content
        if len(steps) == 0:
            raise RuntimeError('No steps received in ReWOOGraphState')

        completed = len(state.intermediate_results)
        if completed == len(steps):
            # all steps are done
            return -1

        return completed

    @staticmethod
    def _parse_planner_output(planner_output: str) -> AIMessage:
        """Parse the planner's JSON text and wrap the decoded value as the content of an ``AIMessage``.

        Raises:
            ValueError: If ``planner_output`` is not valid JSON.
        """
        try:
            steps = json.loads(planner_output)
        except json.JSONDecodeError as ex:
            raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex

        return AIMessage(content=steps)

    @staticmethod
    def _unwrap_tool_output(tool_message: ToolMessage) -> str | dict:
        """Return the usable payload of a stored tool result.

        String content is returned unchanged. List content is reduced to its first element, which must be
        a dict (the executor stores dict responses wrapped in a one-element list).

        Raises:
            TypeError: If the content is a list whose first element is not a dict.
        """
        tool_output = tool_message.content
        if isinstance(tool_output, list):
            tool_output = tool_output[0]
            if not isinstance(tool_output, dict):
                raise TypeError(f"Unexpected type for tool output: {type(tool_output)}")
        return tool_output

    @staticmethod
    def _replace_placeholder(placeholder: str, tool_input: str | dict, tool_output: str | dict) -> str | dict:
        """Substitute ``placeholder`` in ``tool_input`` with the output of a previous tool call.

        For a dict input, a value equal to the placeholder is replaced by ``tool_output`` itself, and a
        string value containing the placeholder has that substring replaced by ``str(tool_output)``.
        Non-string values are left untouched. The dict is updated in place and returned. For a string
        input a new string is returned.

        Raises:
            TypeError: If ``tool_input`` is neither a dict nor a string.
        """
        if not isinstance(tool_input, (dict, str)):
            raise TypeError(f"Unexpected type for tool_input: {type(tool_input)}")

        # An empty placeholder matches everywhere; substituting it would corrupt the input.
        if not placeholder:
            return tool_input

        if isinstance(tool_input, str):
            return tool_input.replace(placeholder, str(tool_output))

        # Only values are reassigned, never keys, so updating while iterating is safe.
        for key, value in tool_input.items():
            if not isinstance(value, str):
                continue
            if value == placeholder:
                tool_input[key] = tool_output
            elif placeholder in value:
                # If the placeholder is part of the value, replace it with the stringified output
                tool_input[key] = value.replace(placeholder, str(tool_output))
        return tool_input

    @staticmethod
    def _parse_tool_input(tool_input: str | dict):
        """Turn a planned tool input into structured input where possible.

        A dict is returned as is. A string is stripped and parsed as JSON; if that fails, parsing is
        retried after replacing single quotes with double quotes; if that fails too, the stripped string
        is returned.
        """
        # If the input is already a dictionary, return it as is
        if isinstance(tool_input, dict):
            logger.debug("%s Tool input is already a dictionary. Use the tool input as is.", AGENT_LOG_PREFIX)
            return tool_input

        # If the input is a string, attempt to parse it as JSON
        tool_input = tool_input.strip()
        try:
            # If the input is already a valid JSON string, load it
            tool_input_parsed = json.loads(tool_input)
            logger.debug("%s Successfully parsed structured tool input", AGENT_LOG_PREFIX)
            return tool_input_parsed
        except JSONDecodeError:
            pass

        try:
            # Replace single quotes with double quotes and attempt parsing again
            tool_input_parsed = json.loads(tool_input.replace("'", '"'))
            logger.debug(
                "%s Successfully parsed structured tool input after replacing single quotes with double quotes",
                AGENT_LOG_PREFIX)
            return tool_input_parsed
        except JSONDecodeError:
            # If it still fails, fall back to using the input as a raw string
            logger.debug("%s Unable to parse structured tool input. Using raw tool input as is.", AGENT_LOG_PREFIX)
            return tool_input

    async def planner_node(self, state: ReWOOGraphState):
        """Ask the LLM for a plan for ``state.task`` and parse it into steps.

        Returns:
            ``{"plan": ..., "steps": ...}`` on success, or ``{"result": NO_INPUT_ERROR_MESSAGE}`` when the
            task is empty.
        """
        try:
            logger.debug("%s Starting the ReWOO Planner Node", AGENT_LOG_PREFIX)

            planner = self.planner_prompt | self.llm
            task = str(state.task.content)
            if not task:
                logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
                return {"result": NO_INPUT_ERROR_MESSAGE}

            chat_history = self._get_chat_history(state.messages)
            plan = await self._stream_llm(
                planner,
                {
                    "task": task, "chat_history": chat_history
                },
                RunnableConfig(callbacks=self.callbacks)  # type: ignore
            )

            steps = self._parse_planner_output(str(plan.content))

            if self.detailed_logs:
                agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
                logger.info("ReWOO agent planner output: %s", agent_response_log_message)

            return {"plan": plan, "steps": steps}

        except Exception as ex:
            logger.exception("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def executor_node(self, state: ReWOOGraphState):
        """Execute the next pending step of the plan and record its output.

        The step's tool input has earlier placeholders substituted with their recorded outputs, the named
        tool is called, and the response is stored in ``intermediate_results`` under the step's placeholder.
        An unknown tool name is recorded as a ``TOOL_NOT_FOUND_ERROR_MESSAGE`` result instead of raising.

        Returns:
            ``{"intermediate_results": ...}`` with the updated mapping.

        Raises:
            RuntimeError: If invoked after all steps are finished, or if the state holds no steps.
        """
        try:
            logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)

            current_step = self._get_current_step(state)
            # The executor node should not be invoked after all steps are finished
            if current_step < 0:
                logger.error("%s ReWOO Executor is invoked with an invalid step number: %s",
                             AGENT_LOG_PREFIX,
                             current_step)
                raise RuntimeError(f"ReWOO Executor is invoked with an invalid step number: {current_step}")

            # NOTE(review): both early returns below leave intermediate_results unchanged, so the step
            # counter does not advance and conditional_edge will route back here for the same step.
            # Confirm that an upstream recursion limit bounds this.
            steps_content = state.steps.content
            if not (isinstance(steps_content, list) and current_step < len(steps_content)):
                logger.error("%s Invalid steps content or index %s", AGENT_LOG_PREFIX, current_step)
                return {"intermediate_results": state.intermediate_results}

            step = steps_content[current_step]
            if not (isinstance(step, dict) and "evidence" in step):
                logger.error("%s Invalid step format at index %s", AGENT_LOG_PREFIX, current_step)
                return {"intermediate_results": state.intermediate_results}

            step_info = step["evidence"]
            placeholder = step_info.get("placeholder", "")
            tool = step_info.get("tool", "")
            tool_input = step_info.get("tool_input", "")

            intermediate_results = state.intermediate_results

            # Replace the placeholder in the tool input with the previous tool output
            for _placeholder, _tool_message in intermediate_results.items():
                _tool_output = self._unwrap_tool_output(_tool_message)
                tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)

            requested_tool = self._get_tool(tool)
            if not requested_tool:
                configured_tool_names = list(self.tools_dict.keys())
                logger.warning(
                    "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file, "
                    "there is no tool with that name: %s",
                    AGENT_LOG_PREFIX,
                    tool,
                    configured_tool_names)

                # Record the failure as this step's result so the plan can still advance.
                intermediate_results[placeholder] = ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(
                    tool_name=tool, tools=configured_tool_names),
                                                                tool_call_id=tool)
                return {"intermediate_results": intermediate_results}

            if self.detailed_logs:
                logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)

            # Run the tool. Try to use structured input, if possible
            tool_input_parsed = self._parse_tool_input(tool_input)
            tool_response = await self._call_tool(requested_tool,
                                                  tool_input_parsed,
                                                  RunnableConfig(callbacks=self.callbacks),
                                                  max_retries=3)

            # ToolMessage only accepts str or list[str | dict] as content.
            # Convert into list if the response is a dict.
            if isinstance(tool_response, dict):
                tool_response = [tool_response]

            tool_response_message = ToolMessage(name=tool, tool_call_id=tool, content=tool_response)

            if self.detailed_logs:
                self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))

            intermediate_results[placeholder] = tool_response_message
            return {"intermediate_results": intermediate_results}

        except Exception as ex:
            logger.exception("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def solver_node(self, state: ReWOOGraphState):
        """Ask the LLM for the final answer, given the plan annotated with the recorded tool outputs.

        Returns:
            ``{"result": ...}`` holding the LLM's output message.
        """
        try:
            logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)

            intermediate_results = state.intermediate_results

            plan = ""
            # Add the tool outputs of each step to the plan
            for step in state.steps.content:
                step_info = step["evidence"]
                placeholder = step_info.get("placeholder", "")
                tool_input = step_info.get("tool_input", "")

                for _placeholder, _tool_message in intermediate_results.items():
                    _tool_output = self._unwrap_tool_output(_tool_message)
                    tool_input = self._replace_placeholder(_placeholder, tool_input, _tool_output)
                    # The step's own placeholder name is replaced by its recorded output as well.
                    placeholder = placeholder.replace(_placeholder, str(_tool_output))

                _plan = step.get("plan")
                tool = step_info.get("tool")
                # NOTE(review): consecutive steps are appended without a separator - confirm this is intended.
                plan += f"Plan: {_plan}\n{placeholder} = {tool}[{tool_input}]"

            task = str(state.task.content)
            solver_prompt = self.solver_prompt.partial(plan=plan)
            solver = solver_prompt | self.llm

            output_message = await self._stream_llm(solver, {"task": task},
                                                    RunnableConfig(callbacks=self.callbacks))  # type: ignore

            if self.detailed_logs:
                solver_output_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(output_message.content))
                logger.info("ReWOO agent solver output: %s", solver_output_log_message)

            return {"result": output_message}

        except Exception as ex:
            logger.exception("%s Failed to call solver_node: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def conditional_edge(self, state: ReWOOGraphState):
        """Decide where to go after the executor.

        Returns:
            ``AgentDecision.TOOL`` while steps remain, ``AgentDecision.END`` once every step has a result.
            Any error while deciding is logged and also yields ``AgentDecision.END``.
        """
        try:
            logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)

            current_step = self._get_current_step(state)
            if current_step == -1:
                logger.debug("%s The ReWOO Executor has finished its task", AGENT_LOG_PREFIX)
                return AgentDecision.END

            logger.debug("%s The ReWOO Executor is still working on the task", AGENT_LOG_PREFIX)
            return AgentDecision.TOOL

        except Exception as ex:
            logger.exception("%s Failed to determine whether agent is calling a tool: %s", AGENT_LOG_PREFIX, ex)
            logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
            return AgentDecision.END

    async def _build_graph(self, state_schema):
        """Assemble and compile the planner -> executor -> solver graph and store it on ``self.graph``.

        The executor loops back to itself on ``AgentDecision.TOOL`` and proceeds to the solver on
        ``AgentDecision.END``.
        """
        try:
            logger.debug("%s Building and compiling the ReWOO Graph", AGENT_LOG_PREFIX)

            graph = StateGraph(state_schema)
            graph.add_node("planner", self.planner_node)
            graph.add_node("executor", self.executor_node)
            graph.add_node("solver", self.solver_node)

            graph.add_edge("planner", "executor")
            conditional_edge_possible_outputs = {AgentDecision.TOOL: "executor", AgentDecision.END: "solver"}
            graph.add_conditional_edges("executor", self.conditional_edge, conditional_edge_possible_outputs)

            graph.set_entry_point("planner")
            graph.set_finish_point("solver")

            self.graph = graph.compile()
            logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)

            return self.graph

        except Exception as ex:
            logger.exception("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
            raise

    async def build_graph(self):
        """Build the graph using ``ReWOOGraphState`` as the state schema and return the compiled graph."""
        try:
            await self._build_graph(state_schema=ReWOOGraphState)
            logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)
            return self.graph
        except Exception as ex:
            logger.exception("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
            raise

    @staticmethod
    def validate_planner_prompt(planner_prompt: str) -> bool:
        """Check that the planner prompt is non-empty and contains ``{tools}`` and ``{tool_names}``.

        Returns:
            ``True`` when the prompt is valid.

        Raises:
            ValueError: Listing every problem found, one per line.
        """
        errors = []
        if not planner_prompt:
            errors.append("The planner prompt cannot be empty.")

        # Treat a missing prompt as empty text so the variable checks below still report cleanly.
        prompt_text = planner_prompt or ""
        required_prompt_variables = {
            "{tools}": "The planner prompt must contain {tools} so the planner agent knows about configured tools.",
            "{tool_names}": "The planner prompt must contain {tool_names} so the planner agent knows tool names."
        }
        errors.extend(error_message for variable_name, error_message in required_prompt_variables.items()
                      if variable_name not in prompt_text)

        if errors:
            error_text = "\n".join(errors)
            # No exception is being handled here, so log at error level without a traceback.
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            raise ValueError(error_text)
        return True

    @staticmethod
    def validate_solver_prompt(solver_prompt: str) -> bool:
        """Check that the solver prompt is non-empty.

        Returns:
            ``True`` when the prompt is valid.

        Raises:
            ValueError: If the prompt is empty.
        """
        errors = []
        if not solver_prompt:
            errors.append("The solver prompt cannot be empty.")
        if errors:
            error_text = "\n".join(errors)
            # No exception is being handled here, so log at error level without a traceback.
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            raise ValueError(error_text)
        return True
|