nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
# pylint: disable=R0917
|
|
18
|
+
import logging
|
|
19
|
+
from json import JSONDecodeError
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
21
|
+
|
|
22
|
+
from langchain_core.agents import AgentAction
|
|
23
|
+
from langchain_core.agents import AgentFinish
|
|
24
|
+
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
25
|
+
from langchain_core.language_models import BaseChatModel
|
|
26
|
+
from langchain_core.messages.ai import AIMessage
|
|
27
|
+
from langchain_core.messages.base import BaseMessage
|
|
28
|
+
from langchain_core.messages.human import HumanMessage
|
|
29
|
+
from langchain_core.messages.tool import ToolMessage
|
|
30
|
+
from langchain_core.prompts import ChatPromptTemplate
|
|
31
|
+
from langchain_core.prompts import MessagesPlaceholder
|
|
32
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
33
|
+
from langchain_core.tools import BaseTool
|
|
34
|
+
from pydantic import BaseModel
|
|
35
|
+
from pydantic import Field
|
|
36
|
+
|
|
37
|
+
from nat.agent.base import AGENT_CALL_LOG_MESSAGE
|
|
38
|
+
from nat.agent.base import AGENT_LOG_PREFIX
|
|
39
|
+
from nat.agent.base import INPUT_SCHEMA_MESSAGE
|
|
40
|
+
from nat.agent.base import NO_INPUT_ERROR_MESSAGE
|
|
41
|
+
from nat.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
|
|
42
|
+
from nat.agent.base import AgentDecision
|
|
43
|
+
from nat.agent.dual_node import DualNodeAgent
|
|
44
|
+
from nat.agent.react_agent.output_parser import ReActOutputParser
|
|
45
|
+
from nat.agent.react_agent.output_parser import ReActOutputParserException
|
|
46
|
+
from nat.agent.react_agent.prompt import SYSTEM_PROMPT
|
|
47
|
+
from nat.agent.react_agent.prompt import USER_PROMPT
|
|
48
|
+
|
|
49
|
+
# To avoid circular imports
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
from nat.agent.react_agent.register import ReActAgentWorkflowConfig
|
|
52
|
+
|
|
53
|
+
logger = logging.getLogger(__name__)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ReActGraphState(BaseModel):
    """State schema for the ReAct Agent Graph

    The state is made of three list-valued channels. Each channel defaults to
    its own fresh, empty list (``default_factory=list``), so instances never
    share list objects.
    """

    # input and output of the ReAct Agent
    messages: list[BaseMessage] = Field(default_factory=list)

    # agent thoughts / intermediate steps
    agent_scratchpad: list[AgentAction] = Field(default_factory=list)

    # the responses from any tool calls
    tool_responses: list[BaseMessage] = Field(default_factory=list)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ReActAgentGraph(DualNodeAgent):
|
|
64
|
+
"""Configurable LangGraph ReAct Agent. A ReAct Agent performs reasoning inbetween tool calls, and utilizes the tool
|
|
65
|
+
names and descriptions to select the optimal tool. Supports retrying on output parsing errors. Argument
|
|
66
|
+
"detailed_logs" toggles logging of inputs, outputs, and intermediate steps."""
|
|
67
|
+
|
|
68
|
+
def __init__(self,
             llm: BaseChatModel,
             prompt: ChatPromptTemplate,
             tools: list[BaseTool],
             use_tool_schema: bool = True,
             callbacks: list[AsyncCallbackHandler] | None = None,
             detailed_logs: bool = False,
             retry_agent_response_parsing_errors: bool = True,
             parse_agent_response_max_retries: int = 1,
             tool_call_max_retries: int = 1,
             pass_tool_call_errors_to_agent: bool = True):
    """Assemble the ReAct agent runnable from an LLM, a prompt and the configured tools.

    The prompt's ``tools`` and ``tool_names`` variables are filled in from ``tools``, the LLM is bound with the
    stop sequence ``"Observation:"`` and the two are piped together into ``self.agent``.

    Args:
        llm: Chat model that drives the agent.
        prompt: Prompt template; it is partially applied with the ``tools`` and ``tool_names`` variables.
        tools: Tools the agent may call. They are also indexed by name in ``self.tools_dict``.
        use_tool_schema: When true, each tool description is followed by ``INPUT_SCHEMA_MESSAGE`` rendered with
            the tool's ``input_schema.model_fields``.
        callbacks: Forwarded to the base class constructor.
        detailed_logs: Forwarded to the base class constructor.
        retry_agent_response_parsing_errors: When false, the number of parsing attempts is forced to 1.
        parse_agent_response_max_retries: Number of attempts at obtaining a parseable agent response.
        tool_call_max_retries: Stored on the instance; presumably consumed by the tool node - TODO confirm.
        pass_tool_call_errors_to_agent: Stored on the instance; presumably consumed by the tool node - TODO
            confirm.
    """
    super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
    self.parse_agent_response_max_retries = (parse_agent_response_max_retries
                                             if retry_agent_response_parsing_errors else 1)
    self.tool_call_max_retries = tool_call_max_retries
    self.pass_tool_call_errors_to_agent = pass_tool_call_errors_to_agent
    logger.debug(
        "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
        AGENT_LOG_PREFIX)

    # str.join only places the separator *between* items, so there is never a leading or trailing separator,
    # including when a single tool (or no tool at all) is configured.
    tool_names = ",".join(tool.name for tool in tools)
    if not use_tool_schema:
        tool_names_and_descriptions = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
    else:
        logger.debug("%s Adding the tools' input schema to the tools' description", AGENT_LOG_PREFIX)
        tool_names_and_descriptions = "\n".join(
            f"{tool.name}: {tool.description}. {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
            for tool in tools)
    prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)

    # construct the ReAct Agent
    bound_llm = llm.bind(stop=["Observation:"])  # type: ignore
    self.agent = prompt | bound_llm
    self.tools_dict = {tool.name: tool for tool in tools}
    logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
|
|
105
|
+
|
|
106
|
+
def _get_tool(self, tool_name: str):
    """Look up a configured tool by name.

    Returns whatever ``self.tools_dict.get(tool_name)`` yields, i.e. ``None`` when no tool is registered under
    that name. Any exception raised by the lookup is logged and then re-raised.
    """
    try:
        found = self.tools_dict.get(tool_name)
    except Exception as err:
        logger.exception("%s Unable to find tool with the name %s\n%s",
                         AGENT_LOG_PREFIX,
                         tool_name,
                         err,
                         exc_info=True)
        raise err
    else:
        return found
|
|
116
|
+
|
|
117
|
+
async def agent_node(self, state: ReActGraphState):
    """Query the LLM and record either its final answer or the tool call it requests.

    On the first cycle (empty scratchpad, no pending parsing retries) the last entry of ``state.messages`` is
    used as the question. On later cycles the previous agent thoughts and the matching tool responses are
    replayed to the LLM as the ``agent_scratchpad`` input.

    The LLM output is parsed with ``ReActOutputParser``:

    * ``AgentFinish``: the final answer is appended to ``state.messages``.
    * otherwise: the parsed action is appended to ``state.agent_scratchpad``.
    * ``ReActOutputParserException``: the raw output and the parser's observation are kept in a local working
      list and the LLM is queried again, until the configured number of attempts is exhausted. On the last
      attempt the observation and the raw output are combined and appended to ``state.messages``.

    Args:
        state: Current graph state; it is mutated in place.

    Returns:
        The same ``state`` object.

    Raises:
        RuntimeError: If ``state.messages`` is empty on the first cycle.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        logger.debug("%s Starting the ReAct Agent Node", AGENT_LOG_PREFIX)
        # Always query the agent at least once: with a configured value below 1 the loop body would never run
        # and the node would silently return None instead of a state.
        max_attempts = max(1, self.parse_agent_response_max_retries)
        # keeping a working state allows us to resolve parsing errors without polluting the agent scratchpad
        # the agent "forgets" about the parsing error after solving it - prevents hallucinations in next cycles
        working_state = []
        # Starting from attempt 1 instead of 0 for logging
        for attempt in range(1, max_attempts + 1):
            # the first time we are invoking the ReAct Agent, it won't have any intermediate steps / agent thoughts
            first_cycle = len(state.agent_scratchpad) == 0 and len(working_state) == 0
            agent_scratchpad = []
            if first_cycle:
                # the user input comes from the "messages" state channel
                if len(state.messages) == 0:
                    raise RuntimeError('No input received in state: "messages"')
                # if no human input was passed, answer with the canned error message and stop here
                question = str(state.messages[-1].content)
                if question.strip() == "":
                    logger.error("%s No human input passed to the agent.", AGENT_LOG_PREFIX)
                    state.messages += [AIMessage(content=NO_INPUT_ERROR_MESSAGE)]
                    return state
                logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
                chat_history = self._get_chat_history(state.messages)
                llm_input = {"question": question, "chat_history": chat_history}
            else:
                # ReAct Agents require agentic cycles
                # in an agentic cycle, preserve the agent's thoughts from the previous cycles,
                # and give the agent the response from the tool it called
                for index, intermediate_step in enumerate(state.agent_scratchpad):
                    agent_scratchpad.append(AIMessage(content=intermediate_step.log))
                    tool_response_content = str(state.tool_responses[index].content)
                    agent_scratchpad.append(HumanMessage(content=tool_response_content))
                # unparseable outputs of earlier attempts (and the parser feedback) go last
                agent_scratchpad += working_state
                chat_history = self._get_chat_history(state.messages)
                question = str(state.messages[-1].content)
                logger.debug("%s Querying agent, attempt: %s", AGENT_LOG_PREFIX, attempt)
                llm_input = {
                    "question": question, "agent_scratchpad": agent_scratchpad, "chat_history": chat_history
                }

            output_message = await self._stream_llm(
                self.agent,
                llm_input,
                RunnableConfig(callbacks=self.callbacks)  # type: ignore
            )

            if self.detailed_logs:
                logger.info(AGENT_CALL_LOG_MESSAGE, question, output_message.content)
                if not first_cycle:
                    logger.debug("%s The agent's scratchpad (with tool result) was:\n%s",
                                 AGENT_LOG_PREFIX,
                                 agent_scratchpad)
            try:
                # check if the agent has the final answer yet
                logger.debug("%s Successfully obtained agent response. Parsing agent's response", AGENT_LOG_PREFIX)
                agent_output = await ReActOutputParser().aparse(output_message.content)
                logger.debug("%s Successfully parsed agent response after %s attempts", AGENT_LOG_PREFIX, attempt)
                if isinstance(agent_output, AgentFinish):
                    final_answer = agent_output.return_values.get('output', output_message.content)
                    logger.debug("%s The agent has finished, and has the final answer", AGENT_LOG_PREFIX)
                    # this is where we handle the final output of the Agent, we can clean-up/format/postprocess here
                    # the final answer goes in the "messages" state channel
                    state.messages += [AIMessage(content=final_answer)]
                else:
                    # the agent wants to call a tool, ensure the thoughts are preserved for the next agentic cycle
                    agent_output.log = output_message.content
                    logger.debug("%s The agent wants to call a tool: %s", AGENT_LOG_PREFIX, agent_output.tool)
                    state.agent_scratchpad += [agent_output]

                return state
            except ReActOutputParserException as ex:
                # the agent output did not meet the expected ReAct output format. This can happen for a few reasons:
                # the agent mentioned a tool, but already has the final answer, this can happen with Llama models
                #   - the ReAct Agent already has the answer, and is reflecting on how it obtained the answer
                # the agent might have also missed Action or Action Input in its output
                logger.debug("%s Error parsing agent output\nObservation:%s\nAgent Output:\n%s",
                             AGENT_LOG_PREFIX,
                             ex.observation,
                             output_message.content)
                if attempt == max_attempts:
                    logger.warning(
                        "%s Failed to parse agent output after %d attempts, consider enabling or "
                        "increasing parse_agent_response_max_retries",
                        AGENT_LOG_PREFIX,
                        attempt)
                    # the final answer goes in the "messages" state channel
                    combined_content = str(ex.observation) + '\n' + str(output_message.content)
                    output_message.content = combined_content
                    state.messages += [output_message]
                    return state
                # retry parsing errors, if configured
                logger.info("%s Retrying ReAct Agent, including output parsing Observation", AGENT_LOG_PREFIX)
                working_state.append(output_message)
                working_state.append(HumanMessage(content=str(ex.observation)))
    except Exception as ex:
        # logger.exception already attaches the active traceback; a bare raise re-raises it untouched
        logger.exception("%s Failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
        raise
|
|
221
|
+
|
|
222
|
+
async def conditional_edge(self, state: ReActGraphState):
    """
    Decide whether the graph should stop or route to the tool node.

    Args:
        state (ReActGraphState): The current graph state.

    Returns:
        AgentDecision: ``AgentDecision.END`` when ``state.messages`` holds more than one message (the last
        one is logged as the final answer) or when any exception occurs while deciding;
        ``AgentDecision.TOOL`` otherwise.
    """
    # NOTE(review): any state with more than one message is treated as finished. Confirm callers never
    # seed state.messages with several history messages before the agent has produced a final answer.
    try:
        logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
        if len(state.messages) <= 1:
            # nothing has been added to the "messages" channel yet, so the agent wants to call a tool
            pending_action = state.agent_scratchpad[-1]
            logger.debug("%s The agent wants to call: %s with input: %s",
                         AGENT_LOG_PREFIX,
                         pending_action.tool,
                         pending_action.tool_input)
            return AgentDecision.TOOL
        # the ReAct Agent has finished executing, the last message carries the final answer
        final_answer_text = str(state.messages[-1].content)
        logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, final_answer_text)
        return AgentDecision.END
    except Exception as err:
        logger.exception("Failed to determine whether agent is calling a tool: %s", err, exc_info=True)
        logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
        return AgentDecision.END
|
|
241
|
+
|
|
242
|
+
async def tool_node(self, state: ReActGraphState):
    """
    Run the tool requested by the most recent agent step and append its response to the state.

    The Action Input is parsed as JSON when possible so the tool receives structured input. Strict JSON
    is tried first; only if that fails are single quotes swapped for double quotes (to accept
    Python-dict-style input). If neither parses, the Action Input is passed to the tool unchanged.

    Args:
        state (ReActGraphState): The current graph state; the last entry of ``agent_scratchpad`` names
            the tool and its input.

    Returns:
        ReActGraphState: The same state with one more entry in ``tool_responses``.

    Raises:
        RuntimeError: If ``agent_scratchpad`` is empty, or if the tool response has status ``"error"``
            while ``pass_tool_call_errors_to_agent`` is False.
    """
    logger.debug("%s Starting the Tool Call Node", AGENT_LOG_PREFIX)
    if not state.agent_scratchpad:
        raise RuntimeError('No tool input received in state: "agent_scratchpad"')
    agent_thoughts = state.agent_scratchpad[-1]
    # the agent can run any installed tool, simply install the tool and add it to the config file
    requested_tool = self._get_tool(agent_thoughts.tool)
    if not requested_tool:
        configured_tool_names = list(self.tools_dict.keys())
        logger.warning(
            "%s ReAct Agent wants to call tool %s. In the ReAct Agent's configuration within the config file, "
            "there is no tool with that name: %s",
            AGENT_LOG_PREFIX,
            agent_thoughts.tool,
            configured_tool_names)
        # report the problem back to the agent as a tool response instead of raising
        tool_response = ToolMessage(name='agent_error',
                                    tool_call_id='agent_error',
                                    content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=agent_thoughts.tool,
                                                                                tools=configured_tool_names))
        state.tool_responses += [tool_response]
        return state

    logger.debug("%s Calling tool %s with input: %s",
                 AGENT_LOG_PREFIX,
                 requested_tool.name,
                 agent_thoughts.tool_input)

    # Resolve the tool input before calling the tool, so that a JSONDecodeError raised while the tool
    # runs is never mistaken for an Action Input parsing failure (which would call the tool twice).
    raw_tool_input = str(agent_thoughts.tool_input)
    stripped_tool_input = raw_tool_input.strip()
    tool_input = raw_tool_input
    if stripped_tool_input == 'None':
        # "Action Input: None" means the tool takes no input; pass the literal marker through
        tool_input = stripped_tool_input
    else:
        parse_error = None
        parsed = False
        # strict JSON first: a blanket quote replacement would corrupt values containing apostrophes
        for candidate in (stripped_tool_input, stripped_tool_input.replace("'", '"')):
            try:
                tool_input = json.loads(candidate)
            except JSONDecodeError as ex:
                parse_error = ex
            else:
                parsed = True
                break
        if parsed:
            logger.debug("%s Successfully parsed structured tool input from Action Input", AGENT_LOG_PREFIX)
        else:
            logger.debug(
                "%s Unable to parse structured tool input from Action Input. Using Action Input as is."
                "\nParsing error: %s",
                AGENT_LOG_PREFIX,
                parse_error,
                exc_info=parse_error)
            tool_input = raw_tool_input

    tool_response = await self._call_tool(requested_tool,
                                          tool_input,
                                          RunnableConfig(callbacks=self.callbacks),
                                          max_retries=self.tool_call_max_retries)

    if self.detailed_logs:
        self._log_tool_response(requested_tool.name, tool_input, str(tool_response.content))

    if not self.pass_tool_call_errors_to_agent and tool_response.status == "error":
        logger.error("%s Tool %s failed: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_response.content)
        raise RuntimeError("Tool call failed: " + str(tool_response.content))

    state.tool_responses += [tool_response]
    return state
|
|
308
|
+
|
|
309
|
+
async def build_graph(self):
    """
    Build the ReAct graph via the parent class and return it.

    Returns:
        The object stored in ``self.graph`` after the parent ``_build_graph`` call completes.

    Raises:
        Exception: Any failure is logged and then re-raised.
    """
    try:
        # the parent builds the graph using the ReAct-specific state schema
        await super()._build_graph(state_schema=ReActGraphState)
        logger.debug("%s ReAct Graph built and compiled successfully", AGENT_LOG_PREFIX)
        compiled_graph = self.graph
        return compiled_graph
    except Exception as err:
        logger.exception("%s Failed to build ReAct Graph: %s", AGENT_LOG_PREFIX, err, exc_info=err)
        raise err
|
|
317
|
+
|
|
318
|
+
@staticmethod
def validate_system_prompt(system_prompt: str) -> bool:
    """
    Check that a system prompt is non-empty and contains the placeholders the ReAct Agent needs.

    Args:
        system_prompt (str): The prompt text to validate. ``None`` is treated like an empty prompt.

    Returns:
        bool: True when the prompt is valid.

    Raises:
        ValueError: If the prompt is empty or lacks ``{tools}`` or ``{tool_names}``; the message lists
            every problem found, one per line.
    """
    # normalise None to "" so the membership checks below cannot raise TypeError
    prompt_text = system_prompt or ""
    errors = []
    if not prompt_text:
        errors.append("The system prompt cannot be empty.")
    required_prompt_variables = {
        "{tools}": "The system prompt must contain {tools} so the agent knows about configured tools.",
        "{tool_names}": "The system prompt must contain {tool_names} so the agent knows tool names."
    }
    errors.extend(error_message for variable_name, error_message in required_prompt_variables.items()
                  if variable_name not in prompt_text)
    if errors:
        error_text = "\n".join(errors)
        # logger.error, not logger.exception: there is no active exception to attach a traceback from
        logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
        raise ValueError(error_text)
    return True
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def create_react_agent_prompt(config: "ReActAgentWorkflowConfig") -> ChatPromptTemplate:
    """
    Create a ReAct Agent prompt from the config.

    The system prompt is ``config.system_prompt`` when set, otherwise ``SYSTEM_PROMPT``;
    ``config.additional_instructions`` (if any) is appended after a single space. The result is
    validated before the template is built.

    Args:
        config (ReActAgentWorkflowConfig): The config to use for the prompt.

    Returns:
        ChatPromptTemplate: The ReAct Agent prompt, made of the system prompt, ``USER_PROMPT`` and an
        optional ``agent_scratchpad`` messages placeholder.

    Raises:
        ValueError: If the assembled system prompt fails validation.
    """
    # the ReAct Agent prompt can be customized via config option system_prompt and additional_instructions.
    prompt_str = config.system_prompt or SYSTEM_PROMPT

    if config.additional_instructions:
        prompt_str += f" {config.additional_instructions}"

    # validate_system_prompt raises ValueError itself on an invalid prompt; the check on its return
    # value is kept as a defensive guard and logs with logger.error because no exception is active here
    if not ReActAgentGraph.validate_system_prompt(prompt_str):
        logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid system_prompt")

    return ChatPromptTemplate([("system", prompt_str), ("user", USER_PROMPT),
                               MessagesPlaceholder(variable_name='agent_scratchpad', optional=True)])
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from langchain.agents.agent import AgentOutputParser
|
|
19
|
+
from langchain_core.agents import AgentAction
|
|
20
|
+
from langchain_core.agents import AgentFinish
|
|
21
|
+
from langchain_core.exceptions import LangChainException
|
|
22
|
+
|
|
23
|
+
from .prompt import SYSTEM_PROMPT
|
|
24
|
+
|
|
25
|
+
# Marker that identifies a final answer in the LLM output.
FINAL_ANSWER_ACTION = "Final Answer:"

# Observations attached to ReActOutputParserException for each kind of formatting problem.
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
    "Invalid Format: Missing 'Action:' after 'Thought:'")
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
    "Invalid Format: Missing 'Action Input:' after 'Action:'")
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
    "Parsing LLM output produced both a final answer and a parse-able action:")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ReActOutputParserException(ValueError, LangChainException):
    """
    Raised when LLM output does not match the expected ReAct format.

    Attributes:
        observation: Text describing the problem (``None`` when not supplied).
        missing_action: True when the output has no ``Action:`` section.
        missing_action_input: True when the output has no ``Action Input:`` section.
        final_answer_and_action: True when the output has both a final answer and a parse-able action.
    """

    def __init__(self,
                 observation=None,
                 missing_action=False,
                 missing_action_input=False,
                 final_answer_and_action=False):
        # initialise the base exception so str(ex) and ex.args carry the observation instead of being empty
        if observation is None:
            super().__init__()
        else:
            super().__init__(observation)
        self.observation = observation
        self.missing_action = missing_action
        self.missing_action_input = missing_action_input
        self.final_answer_and_action = final_answer_and_action
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ReActOutputParser(AgentOutputParser):
    """Parses ReAct-style LLM calls that have a single tool input.

    Expects output to be in one of two formats.

    If the output signals that an action should be taken, it
    should be in the below format. This will result in an AgentAction
    being returned.

    ```
    Thought: agent thought here
    Action: search
    Action Input: what is the temperature in SF?
    Observation: Waiting for the tool response...
    ```

    If the output signals that a final answer should be given, it
    should be in the below format. This will result in an AgentFinish
    being returned.

    ```
    Thought: agent thought here
    Final Answer: The temperature is 100 degrees
    ```

    """

    def get_format_instructions(self) -> str:
        """Return the default ReAct system prompt, which describes the expected output format."""
        return SYSTEM_PROMPT

    def parse(self, text: str) -> AgentAction | AgentFinish:
        """
        Parse raw LLM output into an agent step.

        Args:
            text (str): The raw LLM output.

        Returns:
            AgentAction | AgentFinish: An AgentAction when an ``Action`` / ``Action Input`` pair is found,
            or an AgentFinish holding the text after the last ``Final Answer:`` marker.

        Raises:
            ReActOutputParserException: If the output contains both a final answer and a parse-able
                action, lacks ``Action:``, lacks ``Action Input:``, or cannot be parsed at all.
        """
        includes_answer = FINAL_ANSWER_ACTION in text
        # group 1 is the tool name, group 2 the tool input; the input ends at "Observation" or end of text
        regex = r"Action\s*\d*\s*:[\s]*(.*?)\s*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)(?=\s*[\n|\s]\s*Observation\b|$)"
        action_match = re.search(regex, text, re.DOTALL)
        if action_match:
            if includes_answer:
                # the message constant already ends with a colon, so only a space separates it from the text
                raise ReActOutputParserException(
                    final_answer_and_action=True,
                    observation=f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE} {text}")
            action = action_match.group(1).strip()
            action_input = action_match.group(2)
            # drop surrounding spaces first, then surrounding double quotes
            tool_input = action_input.strip(" ").strip('"')

            return AgentAction(action, tool_input, text)

        if includes_answer:
            return AgentFinish({"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text)

        # no complete action and no final answer: report which part of the format is missing
        if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
            raise ReActOutputParserException(observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
                                             missing_action=True)
        if not re.search(r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL):
            raise ReActOutputParserException(observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
                                             missing_action_input=True)
        raise ReActOutputParserException(f"Could not parse LLM output: `{text}`")

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "react-input"
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
|
|
18
|
+
# Default system prompt for the ReAct Agent. It carries the {tools} and {tool_names} placeholders and
# spells out the Thought / Action / Action Input / Observation and Final Answer formats.
SYSTEM_PROMPT = """
Answer the following questions as best you can. You may ask the human to use the following tools:

{tools}

You may respond in one of two formats.
Use the following format exactly to ask the human to use a tool:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action (if there is no required input, include "Action Input: None")
Observation: wait for the human to respond with the result from the tool, do not assume the response

... (this Thought/Action/Action Input/Observation can repeat N times. If you do not need to use a tool, or after asking the human to use any tools and waiting for the human to respond, you might know the final answer.)
Use the following format once you have the final answer:

Thought: I now know the final answer
Final Answer: the final answer to the original input question
"""

# User-turn prompt template; it carries the {chat_history} and {question} placeholders.
USER_PROMPT = """
Previous conversation history:
{chat_history}

Question: {question}
"""
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import AliasChoices
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.builder.builder import Builder
|
|
22
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
|
+
from nat.builder.function_info import FunctionInfo
|
|
24
|
+
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.api_server import ChatRequest
|
|
26
|
+
from nat.data_models.api_server import ChatResponse
|
|
27
|
+
from nat.data_models.component_ref import FunctionRef
|
|
28
|
+
from nat.data_models.component_ref import LLMRef
|
|
29
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
30
|
+
from nat.utils.type_converter import GlobalTypeConverter
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
    """
    Defines a NAT function that uses a ReAct Agent, which performs reasoning in between tool calls and
    utilizes the tool names and descriptions to select the optimal tool.
    """

    # tools and LLM the agent is built from
    tool_names: list[FunctionRef] = Field(default_factory=list,
                                          description="The list of tools to provide to the react agent.")
    llm_name: LLMRef = Field(description="The LLM model to use with the react agent.")
    verbose: bool = Field(default=False, description="Set the verbosity of the react agent's logging.")
    # parsing-retry options; each also accepts a second alias on input ("retry_parsing_errors", "max_retries")
    retry_agent_response_parsing_errors: bool = Field(
        default=True,
        validation_alias=AliasChoices("retry_agent_response_parsing_errors", "retry_parsing_errors"),
        description="Whether to retry when encountering parsing errors in the agent's response.")
    parse_agent_response_max_retries: int = Field(
        default=1,
        validation_alias=AliasChoices("parse_agent_response_max_retries", "max_retries"),
        description="Maximum number of times the Agent may retry parsing errors. "
        "Prevents the Agent from getting into infinite hallucination loops.")
    # tool-call limits and error handling; max_tool_calls also accepts the alias "max_iterations"
    tool_call_max_retries: int = Field(default=1, description="The number of retries before raising a tool call error.")
    max_tool_calls: int = Field(default=15,
                                validation_alias=AliasChoices("max_tool_calls", "max_iterations"),
                                description="Maximum number of tool calls before stopping the agent.")
    pass_tool_call_errors_to_agent: bool = Field(
        default=True,
        description="Whether to pass tool call errors to agent. If False, failed tool calls will raise an exception.")
    # prompt construction options
    include_tool_input_schema_in_tool_description: bool = Field(
        default=True, description="Specify inclusion of tool input schemas in the prompt.")
    description: str = Field(default="ReAct Agent Workflow", description="The description of this functions use.")
    system_prompt: str | None = Field(
        default=None,
        description="Provides the SYSTEM_PROMPT to use with the agent")  # defaults to SYSTEM_PROMPT in prompt.py
    max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
    # selects whether the registered function takes ChatRequest/ChatResponse or plain strings
    use_openai_api: bool = Field(default=False,
                                 description=("Use OpenAI API for the input/output types to the function. "
                                              "If False, strings will be used."))
    additional_instructions: str | None = Field(
        default=None, description="Additional instructions to provide to the agent in addition to the base prompt.")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@register_function(config_type=ReActAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builder):
    """
    Build a ReAct Agent graph from the config and yield it wrapped as a function.

    Args:
        config (ReActAgentWorkflowConfig): The ReAct Agent settings.
        builder (Builder): Supplies the configured LLM and tools.

    Yields:
        FunctionInfo: Wraps a ChatRequest -> ChatResponse function when ``config.use_openai_api`` is set,
        otherwise a str -> str function that converts to and from that chat function.

    Raises:
        ValueError: If no tools are returned for ``config.tool_names``.
    """
    from langchain.schema import BaseMessage
    from langchain_core.messages import trim_messages
    from langgraph.graph.graph import CompiledGraph

    from nat.agent.base import AGENT_LOG_PREFIX
    from nat.agent.react_agent.agent import ReActAgentGraph
    from nat.agent.react_agent.agent import ReActGraphState
    from nat.agent.react_agent.agent import create_react_agent_prompt

    prompt = create_react_agent_prompt(config)

    # the LLM and the tools both come from the config file
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    if not tools:
        raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")

    # assemble the ReAct Agent from the configured llm, prompt, and tools, then build its graph
    agent = ReActAgentGraph(llm=llm,
                            prompt=prompt,
                            tools=tools,
                            use_tool_schema=config.include_tool_input_schema_in_tool_description,
                            detailed_logs=config.verbose,
                            retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
                            parse_agent_response_max_retries=config.parse_agent_response_max_retries,
                            tool_call_max_retries=config.tool_call_max_retries,
                            pass_tool_call_errors_to_agent=config.pass_tool_call_errors_to_agent)
    graph: CompiledGraph = await agent.build_graph()

    async def _response_fn(input_message: ChatRequest) -> ChatResponse:
        try:
            # seed the state with the most recent messages; token_counter=len makes max_tokens a message count
            raw_history = [m.model_dump() for m in input_message.messages]
            trimmed_history: list[BaseMessage] = trim_messages(messages=raw_history,
                                                               max_tokens=config.max_history,
                                                               strategy="last",
                                                               token_counter=len,
                                                               start_on="human",
                                                               include_system=True)
            initial_state = ReActGraphState(messages=trimmed_history)

            # run the ReAct Agent Graph
            # a recursion_limit of 4 allows exactly 1 tool call: one full cycle, then the agent is
            # stopped when it tries to call a tool a second time
            graph_result = await graph.ainvoke(initial_state,
                                               config={'recursion_limit': (config.max_tool_calls + 1) * 2})

            # the last message of the resulting state is the agent's output
            final_state = ReActGraphState(**graph_result)
            answer_message = final_state.messages[-1]  # pylint: disable=E1136
            return ChatResponse.from_string(str(answer_message.content))

        except Exception as err:
            logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, err, exc_info=err)
            # only expose the exception text to the caller when verbose is enabled
            error_text = str(err) if config.verbose else "I seem to be having a problem."
            return ChatResponse.from_string(error_text)

    if not config.use_openai_api:

        async def _str_api_fn(input_message: str) -> str:
            chat_request = GlobalTypeConverter.get().try_convert(input_message, to_type=ChatRequest)
            chat_response = await _response_fn(chat_request)
            return GlobalTypeConverter.get().try_convert(chat_response, to_type=str)

        yield FunctionInfo.from_fn(_str_api_fn, description=config.description)
    else:
        yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
|
File without changes
|