nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480)
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,593 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ import re
20
+ from json import JSONDecodeError
21
+ from typing import Any
22
+
23
+ from langchain_core.callbacks.base import AsyncCallbackHandler
24
+ from langchain_core.language_models import BaseChatModel
25
+ from langchain_core.messages.ai import AIMessage
26
+ from langchain_core.messages.base import BaseMessage
27
+ from langchain_core.messages.human import HumanMessage
28
+ from langchain_core.messages.tool import ToolMessage
29
+ from langchain_core.prompts.chat import ChatPromptTemplate
30
+ from langchain_core.runnables.config import RunnableConfig
31
+ from langchain_core.tools import BaseTool
32
+ from langgraph.graph import StateGraph
33
+ from langgraph.graph.state import CompiledStateGraph
34
+ from pydantic import BaseModel
35
+ from pydantic import Field
36
+
37
+ from nat.agent.base import AGENT_CALL_LOG_MESSAGE
38
+ from nat.agent.base import AGENT_LOG_PREFIX
39
+ from nat.agent.base import INPUT_SCHEMA_MESSAGE
40
+ from nat.agent.base import NO_INPUT_ERROR_MESSAGE
41
+ from nat.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
42
+ from nat.agent.base import AgentDecision
43
+ from nat.agent.base import BaseAgent
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
# One piece of evidence requested by a plan step: which tool to call, with what
# input, and the placeholder under which its result is recorded.
class ReWOOEvidence(BaseModel):
    # Evidence variable for this step (the dependency parser matches the pattern "#E<digits>").
    placeholder: str
    # Name of the tool for this step (presumably a key of the agent's tools_dict — verify in the executor).
    tool: str
    # Tool arguments; may textually reference earlier placeholders, which creates dependencies.
    tool_input: Any
52
+
53
+
54
# A single step of the planner output: a natural-language plan paired with the
# evidence (tool call) that step needs.
class ReWOOPlanStep(BaseModel):
    # Free-text description of what this step does.
    plan: str
    # The tool invocation backing this step.
    evidence: ReWOOEvidence
57
+
58
+
59
class ReWOOGraphState(BaseModel):
    """State schema for the ReWOO Agent Graph"""
    # Input and output messages of the ReWOO Agent.
    messages: list[BaseMessage] = Field(default_factory=list)
    # The task provided by the user.
    task: HumanMessage = Field(default_factory=lambda: HumanMessage(content=""))
    # The plan generated by the planner to solve the task.
    plan: AIMessage = Field(default_factory=lambda: AIMessage(content=""))
    # The steps to solve the task, parsed from the plan.
    steps: AIMessage = Field(default_factory=lambda: AIMessage(content=""))

    # ---- Parallel execution support ----
    # Placeholder -> plan step that produces it.
    evidence_map: dict[str, ReWOOPlanStep] = Field(default_factory=dict)
    # Ordered groups of placeholders; each inner list is one level that may run in parallel.
    execution_levels: list[list[str]] = Field(default_factory=list)
    # Index into execution_levels of the level currently being executed.
    current_level: int = Field(default=0)
    # Placeholder -> tool result for every step executed so far.
    intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict)
    # The final result of the task, generated by the solver.
    result: AIMessage = Field(default_factory=lambda: AIMessage(content=""))
74
+
75
+
76
+ class ReWOOAgentGraph(BaseAgent):
77
+ """Configurable ReWOO Agent.
78
+
79
+ Args:
80
+ detailed_logs: Toggles logging of inputs, outputs, and intermediate steps.
81
+ """
82
+
83
+ def __init__(self,
84
+ llm: BaseChatModel,
85
+ planner_prompt: ChatPromptTemplate,
86
+ solver_prompt: ChatPromptTemplate,
87
+ tools: list[BaseTool],
88
+ use_tool_schema: bool = True,
89
+ callbacks: list[AsyncCallbackHandler] | None = None,
90
+ detailed_logs: bool = False,
91
+ log_response_max_chars: int = 1000,
92
+ tool_call_max_retries: int = 3,
93
+ raise_tool_call_error: bool = True):
94
+ super().__init__(llm=llm,
95
+ tools=tools,
96
+ callbacks=callbacks,
97
+ detailed_logs=detailed_logs,
98
+ log_response_max_chars=log_response_max_chars)
99
+
100
+ logger.debug(
101
+ "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
102
+ AGENT_LOG_PREFIX)
103
+
104
+ def describe_tool(tool: BaseTool) -> str:
105
+ description = f"{tool.name}: {tool.description}"
106
+ if use_tool_schema:
107
+ description += f". {INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}"
108
+ return description
109
+
110
+ tool_names = ",".join(tool.name for tool in tools)
111
+ tool_names_and_descriptions = "\n".join(describe_tool(tool) for tool in tools)
112
+
113
+ self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
114
+ self.solver_prompt = solver_prompt
115
+ self.tools_dict = {tool.name: tool for tool in tools}
116
+ self.tool_call_max_retries = tool_call_max_retries
117
+ self.raise_tool_call_error = raise_tool_call_error
118
+
119
+ logger.debug("%s Initialized ReWOO Agent Graph", AGENT_LOG_PREFIX)
120
+
121
+ def _get_tool(self, tool_name: str):
122
+ try:
123
+ return self.tools_dict.get(tool_name)
124
+ except Exception as ex:
125
+ logger.error("%s Unable to find tool with the name %s\n%s", AGENT_LOG_PREFIX, tool_name, ex)
126
+ raise
127
+
128
@staticmethod
def _get_current_level_status(state: ReWOOGraphState) -> tuple[int, bool]:
    """
    Report which execution level the graph is on and whether that level has finished.

    Args:
        state: The ReWOO graph state holding execution levels, the current level and results so far.

    Returns:
        A ``(level, done)`` pair. ``level`` is ``-1`` (with ``done`` True) when there are no levels
        or the current level index is past the last level. Otherwise ``done`` tells whether every
        placeholder of that level already has an intermediate result.
    """
    levels = state.execution_levels

    # Nothing scheduled at all, or we have already moved past the final level.
    if not levels or state.current_level >= len(levels):
        return -1, True

    level = state.current_level
    finished = all(ph in state.intermediate_results for ph in levels[level])
    return level, finished
153
+
154
@staticmethod
def _parse_planner_output(planner_output: str) -> list[ReWOOPlanStep]:
    """
    Turn the planner LLM's raw text into plan steps.

    Args:
        planner_output: Text expected to hold a JSON array of step objects.

    Returns:
        One ReWOOPlanStep per element of the JSON array.

    Raises:
        ValueError: If the text is not valid JSON or any element cannot be built into a step.
    """
    try:
        raw_steps = json.loads(planner_output)
        parsed = []
        for raw in raw_steps:
            parsed.append(ReWOOPlanStep(**raw))
        return parsed
    except Exception as ex:
        raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex
161
+
162
@staticmethod
def _parse_planner_dependencies(steps: list[ReWOOPlanStep]) -> tuple[dict[str, ReWOOPlanStep], list[list[str]]]:
    """
    Group plan steps into execution levels whose members can run in parallel.

    A step depends on another when its tool input mentions that step's evidence placeholder
    (e.g. ``#E1``). Levels come from repeatedly peeling off the steps that have no unresolved
    dependencies: level 0 has no dependencies, and every later level only depends on earlier ones.

    Args:
        steps: Plan steps returned by the planner.

    Returns:
        A tuple ``(evidences, levels)``. ``evidences`` maps each placeholder to its step, and
        ``levels`` lists the placeholders to execute at each level, in order.

    Raises:
        ValueError: If the steps reference each other in a cycle.
    """
    # Only steps that declare an evidence placeholder take part in execution.
    with_evidence = [s for s in steps if s.evidence and s.evidence.placeholder]

    evidences: dict[str, ReWOOPlanStep] = {}
    for s in with_evidence:
        evidences[s.evidence.placeholder] = s

    # For each placeholder, the other known placeholders referenced in its tool input.
    pending: dict[str, list[str]] = {}
    for s in with_evidence:
        own = s.evidence.placeholder
        referenced = re.findall(r"#E\d+", str(s.evidence.tool_input))
        pending[own] = [ref for ref in referenced if ref in evidences and ref != own]

    levels: list[list[str]] = []
    while pending:
        ready = [ph for ph, deps in pending.items() if not deps]
        if not ready:
            raise ValueError("Circular dependency detected in planner output")
        levels.append(ready)

        done = set(ready)
        for ph in ready:
            del pending[ph]
        # Drop the now-satisfied dependencies from everything still waiting.
        for ph in list(pending):
            pending[ph] = list(set(pending[ph]) - done)

    return evidences, levels
210
+
211
+ @staticmethod
212
+ def _replace_placeholder(placeholder: str, tool_input: str | dict, tool_output: str | dict) -> str | dict:
213
+
214
+ # Replace the placeholders in the tool input with the previous tool output
215
+ if isinstance(tool_input, dict):
216
+ for key, value in tool_input.items():
217
+ if value is not None:
218
+ if value == placeholder:
219
+ tool_input[key] = tool_output
220
+ elif placeholder in value:
221
+ # If the placeholder is part of the value, replace it with the stringified output
222
+ tool_input[key] = value.replace(placeholder, str(tool_output))
223
+
224
+ elif isinstance(tool_input, str):
225
+ tool_input = tool_input.replace(placeholder, str(tool_output))
226
+
227
+ else:
228
+ assert False, f"Unexpected type for tool_input: {type(tool_input)}"
229
+
230
+ return tool_input
231
+
232
@staticmethod
def _parse_tool_input(tool_input: str | dict):
    """
    Convert a tool input into the value handed to the tool.

    A dict is returned unchanged. A string is stripped and parsed as JSON. If that fails, a
    second attempt is made after swapping single quotes for double quotes. If that also fails,
    the stripped string itself is returned.

    Args:
        tool_input: The tool input after placeholder substitution.

    Returns:
        The parsed JSON value, the original dict, or the stripped raw string.
    """
    if isinstance(tool_input, dict):
        logger.debug("%s Tool input is already a dictionary. Use the tool input as is.", AGENT_LOG_PREFIX)
        return tool_input

    stripped = tool_input.strip()

    try:
        parsed = json.loads(stripped)
    except JSONDecodeError:
        pass
    else:
        logger.debug("%s Successfully parsed structured tool input", AGENT_LOG_PREFIX)
        return parsed

    # Retry with single quotes turned into double quotes (handles Python-style quoting like {'a': 1}).
    try:
        parsed = json.loads(stripped.replace("'", '"'))
    except JSONDecodeError:
        logger.debug("%s Unable to parse structured tool input. Using raw tool input as is.", AGENT_LOG_PREFIX)
        return stripped

    logger.debug(
        "%s Successfully parsed structured tool input after replacing single quotes with double quotes",
        AGENT_LOG_PREFIX)
    return parsed
262
+
263
async def planner_node(self, state: ReWOOGraphState):
    """
    Ask the planner LLM for a plan and derive the parallel execution schedule from it.

    Args:
        state: Current graph state; ``task`` and ``messages`` are read.

    Returns:
        A state update with ``plan``, ``evidence_map``, ``execution_levels`` and ``current_level``
        set to 0. When the task is empty, it returns ``{"result": NO_INPUT_ERROR_MESSAGE}`` instead.

    Raises:
        Any error from the LLM call or plan parsing is logged and re-raised.
    """
    try:
        logger.debug("%s Starting the ReWOO Planner Node", AGENT_LOG_PREFIX)

        planner_chain = self.planner_prompt | self.llm
        task_text = str(state.task.content)
        if not task_text:
            logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
            return {"result": NO_INPUT_ERROR_MESSAGE}

        prompt_inputs = {"task": task_text, "chat_history": self._get_chat_history(state.messages)}
        run_config = RunnableConfig(callbacks=self.callbacks)  # type: ignore
        plan = await self._stream_llm(planner_chain, prompt_inputs, run_config)
        plan_text = str(plan.content)

        steps = self._parse_planner_output(plan_text)
        # Work out which steps depend on which, so independent tools can run concurrently.
        evidence_map, execution_levels = self._parse_planner_dependencies(steps)

        if self.detailed_logs:
            logger.info("ReWOO agent planner output: %s", AGENT_CALL_LOG_MESSAGE % (task_text, plan_text))
            logger.info("ReWOO agent execution levels: %s", execution_levels)

        return {
            "plan": plan,
            "evidence_map": evidence_map,
            "execution_levels": execution_levels,
            "current_level": 0,
        }

    except Exception as ex:
        logger.error("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex)
        raise
301
+
302
async def executor_node(self, state: ReWOOGraphState):
    """
    Concurrently execute every pending tool call of the current dependency level.

    Tools within one level do not depend on each other, so they are awaited together with
    ``asyncio.gather``. Once a level has run, the level counter is advanced in the same update,
    so each level costs one graph step instead of two.

    Args:
        state: Current graph state (execution levels, evidence map, intermediate results).

    Returns:
        A state update:
          - after running a level: the new ``intermediate_results`` and the advanced ``current_level``;
          - if the current level was already complete: only the advanced ``current_level``;
          - if nothing is left to execute (e.g. a plan with no evidence steps): an empty update,
            so the conditional edge can route to the solver.

    Raises:
        When ``raise_tool_call_error`` is enabled: the tool's exception, or RuntimeError for a tool
        result whose status is "error".
    """
    try:
        logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)

        current_level, level_complete = self._get_current_level_status(state)

        if current_level < 0:
            # The planner always routes here, even when the plan produced no execution levels
            # (no step declared an evidence placeholder). Nothing to run: return no update and let
            # the conditional edge send the graph to the solver instead of failing the whole run.
            logger.debug("%s No pending execution levels; nothing to execute", AGENT_LOG_PREFIX)
            return {}

        next_level = current_level + 1

        if level_complete:
            logger.debug("%s Level %s complete, moving to level %s", AGENT_LOG_PREFIX, current_level, next_level)
            return {"current_level": next_level}

        # Keep the planner's order (not set order) so execution and logging are deterministic.
        # The level is incomplete here, so at least one placeholder is pending.
        pending_placeholders = [
            placeholder for placeholder in state.execution_levels[current_level]
            if placeholder not in state.intermediate_results
        ]

        logger.debug("%s Executing level %s with %s tools in parallel: %s",
                     AGENT_LOG_PREFIX,
                     current_level,
                     len(pending_placeholders),
                     pending_placeholders)

        # return_exceptions=True hands failures back as values so every outcome can be recorded below.
        results = await asyncio.gather(
            *(self._execute_single_tool(placeholder, state.evidence_map[placeholder], state.intermediate_results)
              for placeholder in pending_placeholders),
            return_exceptions=True)

        updated_intermediate_results = dict(state.intermediate_results)

        for placeholder, result in zip(pending_placeholders, results):
            if isinstance(result, BaseException):
                logger.error("%s Tool execution failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result)
                error_message = f"Tool execution failed: {str(result)}"
                updated_intermediate_results[placeholder] = ToolMessage(content=error_message,
                                                                        tool_call_id=placeholder)
                if self.raise_tool_call_error:
                    raise result
                continue

            updated_intermediate_results[placeholder] = result
            # A tool may also report failure through a ToolMessage with an error status.
            if (isinstance(result, ToolMessage) and hasattr(result, 'status') and result.status == "error"
                    and self.raise_tool_call_error):
                logger.error("%s Tool call failed for %s: %s", AGENT_LOG_PREFIX, placeholder, result.content)
                raise RuntimeError(f"Tool call failed: {result.content}")

        if self.detailed_logs:
            logger.info("%s Completed level %s with %s tools",
                        AGENT_LOG_PREFIX,
                        current_level,
                        len(pending_placeholders))

        # Every pending placeholder now has a result (failures are recorded as error ToolMessages
        # or re-raised above), so this level is finished: advance immediately.
        return {"intermediate_results": updated_intermediate_results, "current_level": next_level}

    except Exception as ex:
        logger.error("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex)
        raise
383
+
384
async def _execute_single_tool(self,
                               placeholder: str,
                               step_info: ReWOOPlanStep,
                               intermediate_results: dict[str, ToolMessage]) -> ToolMessage:
    """
    Run the tool for one plan step after substituting earlier evidence into its input.

    Args:
        placeholder: The evidence placeholder this step produces (e.g. "#E1"). It is also used as
            the ``tool_call_id`` of a not-found message.
        step_info: The plan step holding the tool name and tool input.
        intermediate_results: Results of already-executed steps, keyed by placeholder.

    Returns:
        The ToolMessage produced by the tool call, or a ToolMessage carrying
        TOOL_NOT_FOUND_ERROR_MESSAGE when the requested tool is not configured.
    """
    evidence_info = step_info.evidence
    tool_name = evidence_info.tool
    tool_input = evidence_info.tool_input
    if isinstance(tool_input, dict):
        # Work on a copy: this dict is part of the plan stored in the graph state, and substitution
        # must not rewrite it in place.
        tool_input = dict(tool_input)

    # Replace placeholders in the tool input with previous results.
    for ph_key, tool_output in intermediate_results.items():
        tool_output_content = tool_output.content
        if isinstance(tool_output_content, list):
            # Multi-part content: substitute its first part; an empty list contributes nothing.
            tool_output_content = tool_output_content[0] if tool_output_content else ""
        tool_input = self._replace_placeholder(ph_key, tool_input, tool_output_content)

    requested_tool = self._get_tool(tool_name)
    if not requested_tool:
        configured_tool_names = list(self.tools_dict.keys())
        logger.warning(
            "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config file, "
            "there is no tool with that name: %s",
            AGENT_LOG_PREFIX,
            tool_name,
            configured_tool_names)

        return ToolMessage(content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=tool_name,
                                                                       tools=configured_tool_names),
                           tool_call_id=placeholder)

    if self.detailed_logs:
        logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)

    # Parse and execute the tool.
    tool_input_parsed = self._parse_tool_input(tool_input)
    tool_response = await self._call_tool(
        requested_tool,
        tool_input_parsed,
        RunnableConfig(callbacks=self.callbacks),  # type: ignore
        max_retries=self.tool_call_max_retries)

    if self.detailed_logs:
        self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))

    return tool_response
443
+
444
async def solver_node(self, state: ReWOOGraphState):
    """
    Build the evidence-annotated plan and ask the solver LLM for the final answer.

    For every planned step, the plan text lists three lines:
      - the step description;
      - the tool call written as ``#En = Tool[input]``, with earlier evidence substituted into the input;
      - the step's result.

    Args:
        state: Current graph state; ``evidence_map``, ``intermediate_results`` and ``task`` are read.

    Returns:
        ``{"result": <solver LLM message>}``.

    Raises:
        Any error while building the plan or calling the LLM is logged and re-raised.
    """
    try:
        logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)

        # Reduce each tool result to the value used for substitution and display: multi-part
        # content uses its first part; an empty list contributes nothing.
        evidence_values = {}
        for ph_key, tool_output in state.intermediate_results.items():
            content = tool_output.content
            if isinstance(content, list):
                content = content[0] if content else ""
            evidence_values[ph_key] = content

        plan_sections = []
        for placeholder, step_info in state.evidence_map.items():
            evidence_info = step_info.evidence
            tool_name = evidence_info.tool

            final_tool_input = evidence_info.tool_input
            if isinstance(final_tool_input, dict):
                # Copy so substitution does not rewrite the plan stored in the graph state.
                final_tool_input = dict(final_tool_input)
            for ph_key, value in evidence_values.items():
                final_tool_input = self._replace_placeholder(ph_key, final_tool_input, value)

            final_result = str(evidence_values[placeholder]) if placeholder in evidence_values else ""

            plan_sections.append('\n'.join([
                f"Plan: {step_info.plan}",
                f"{placeholder} = {tool_name}[{final_tool_input}]",
                f"Result: {final_result}\n\n"
            ]))
        plan = "".join(plan_sections)

        task = str(state.task.content)
        solver_prompt = self.solver_prompt.partial(plan=plan)
        solver = solver_prompt | self.llm

        output_message = await self._stream_llm(solver, {"task": task},
                                                RunnableConfig(callbacks=self.callbacks))  # type: ignore

        if self.detailed_logs:
            solver_output_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(output_message.content))
            logger.info("ReWOO agent solver output: %s", solver_output_log_message)

        return {"result": output_message}

    except Exception as ex:
        logger.error("%s Failed to call solver_node: %s", AGENT_LOG_PREFIX, ex)
        raise
500
+
501
async def conditional_edge(self, state: ReWOOGraphState):
    """
    Decide whether the executor should run again or the graph should move on to the solver.

    Args:
        state: Current graph state.

    Returns:
        AgentDecision.TOOL while levels remain to be executed, and AgentDecision.END once every
        level is done. An error while deciding is logged and also yields AgentDecision.END.
    """
    try:
        logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)

        level, complete = self._get_current_level_status(state)

        # Done when nothing is left, or when the current level finished and it was the last one.
        no_more_work = level == -1 or (complete and level + 1 >= len(state.execution_levels))
        if no_more_work:
            logger.debug("%s All execution levels complete, moving to solver", AGENT_LOG_PREFIX)
            return AgentDecision.END

        logger.debug("%s Continuing with executor (level %s, complete: %s)", AGENT_LOG_PREFIX, level, complete)
        return AgentDecision.TOOL

    except Exception as ex:
        logger.exception("%s Failed to determine whether agent is calling a tool: %s", AGENT_LOG_PREFIX, ex)
        logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
        return AgentDecision.END
529
+
530
async def _build_graph(self, state_schema: type) -> CompiledStateGraph:
    """
    Assemble and compile the planner -> executor (looping) -> solver graph.

    Args:
        state_schema: The state class used for the StateGraph.

    Returns:
        The compiled graph, which is also stored on ``self.graph``.
    """
    try:
        logger.debug("%s Building and compiling the ReWOO Graph", AGENT_LOG_PREFIX)

        workflow = StateGraph(state_schema)
        nodes = (("planner", self.planner_node), ("executor", self.executor_node), ("solver", self.solver_node))
        for node_name, node_fn in nodes:
            workflow.add_node(node_name, node_fn)

        workflow.add_edge("planner", "executor")
        # The executor loops on itself until the conditional edge says every level has run.
        routes = {AgentDecision.TOOL: "executor", AgentDecision.END: "solver"}
        workflow.add_conditional_edges("executor", self.conditional_edge, routes)

        workflow.set_entry_point("planner")
        workflow.set_finish_point("solver")

        self.graph = workflow.compile()
        logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)
        return self.graph

    except Exception as ex:
        logger.error("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
        raise
556
+
557
async def build_graph(self):
    """
    Build and compile the ReWOO graph, using ``ReWOOGraphState`` as its state schema.

    Returns:
        The compiled graph stored on ``self.graph``.

    Raises:
        Any build error is logged and re-raised.
    """
    try:
        schema = ReWOOGraphState
        await self._build_graph(state_schema=schema)
        logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)
        compiled = self.graph
        return compiled
    except Exception as ex:
        logger.error("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex)
        raise
565
+
566
+ @staticmethod
567
+ def validate_planner_prompt(planner_prompt: str) -> bool:
568
+ errors = []
569
+ if not planner_prompt:
570
+ errors.append("The planner prompt cannot be empty.")
571
+ required_prompt_variables = {
572
+ "{tools}": "The planner prompt must contain {tools} so the planner agent knows about configured tools.",
573
+ "{tool_names}": "The planner prompt must contain {tool_names} so the planner agent knows tool names."
574
+ }
575
+ for variable_name, error_message in required_prompt_variables.items():
576
+ if variable_name not in planner_prompt:
577
+ errors.append(error_message)
578
+ if errors:
579
+ error_text = "\n".join(errors)
580
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
581
+ raise ValueError(error_text)
582
+ return True
583
+
584
+ @staticmethod
585
+ def validate_solver_prompt(solver_prompt: str) -> bool:
586
+ errors = []
587
+ if not solver_prompt:
588
+ errors.append("The solver prompt cannot be empty.")
589
+ if errors:
590
+ error_text = "\n".join(errors)
591
+ logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
592
+ raise ValueError(error_text)
593
+ return True