nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480) hide show
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,204 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ import re
19
+
20
+ from nat.builder.builder import Builder
21
+ from nat.builder.framework_enum import LLMFrameworkEnum
22
+ from nat.cli.register_workflow import register_ttc_strategy
23
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
24
+ from nat.experimental.test_time_compute.models.editor_config import LLMAsAJudgeEditorConfig
25
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
26
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
27
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
28
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
29
+ from nat.utils.io.model_processing import remove_r1_think_tags
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class LLMAsAJudgeEditor(StrategyBase):
    """
    Editing strategy that uses one LLM as a judge to critique plans and a second LLM to revise them.

    Given a list of planning ``TTCItem`` objects, the feedback LLM first generates feedback on each
    plan; the editing LLM then rewrites each plan using that feedback. Each stage runs concurrently
    across all items. Items are updated in place and also returned.
    """

    # Labels the LLMs may prefix their answers with. Only a label at the very start of the
    # (think-tag-stripped) response is removed; matching is case-insensitive.
    _FEEDBACK_LABEL = re.compile(r"^\s*FEEDBACK:\s*", re.IGNORECASE)
    _EDITED_PLAN_LABEL = re.compile(r"^\s*EDITED PLAN:\s*", re.IGNORECASE)

    def __init__(self, config: TTCStrategyBaseConfig) -> None:
        super().__init__(config)
        # Populated by build_components(); None until then.
        self.feedback_llm = None
        self.editing_llm = None

    async def build_components(self, builder: Builder) -> None:
        """
        Resolve the feedback and editing LLMs named in the config as LangChain clients.

        Args:
            builder: Workflow builder used to look up ``config.feedback_llm`` and
                ``config.editing_llm``.
        """
        self.feedback_llm = await builder.get_llm(self.config.feedback_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
        self.editing_llm = await builder.get_llm(self.config.editing_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        """Return the pipeline types this strategy supports (planning only)."""
        return [PipelineTypeEnum.PLANNING]

    def stage_type(self) -> StageTypeEnum:
        """Return the pipeline stage this strategy implements (editing)."""
        return StageTypeEnum.EDITING

    @staticmethod
    def _clean_response(result, label: re.Pattern[str]) -> str:
        """
        Turn a raw LLM response into plain text with think tags and a leading label removed.

        Args:
            result: LLM response; its ``content`` attribute is used when present, otherwise
                ``str(result)``.
            label: Anchored pattern for the label to strip from the start of the text.

        Returns:
            The cleaned, whitespace-stripped text (possibly empty).
        """
        raw_text = result.content if hasattr(result, "content") else str(result)
        return label.sub("", remove_r1_think_tags(raw_text), count=1).strip()

    async def generate_feedback(self,
                                llm,
                                template,
                                context: str | None,
                                prompt: str | None,
                                item: TTCItem) -> TTCItem:
        """
        Ask the feedback LLM to critique a single plan and store the critique on the item.

        Args:
            llm: Chat model used to generate feedback (checked to be a LangChain
                ``BaseChatModel`` by ``ainvoke``).
            template: Prompt template with ``context``, ``original_prompt``, ``plan`` and
                ``num_feedback`` input variables.
            context: Agent context substituted into the template; may be None.
            prompt: Original prompt that was used to generate the plans; may be None.
            item: Planning item to critique. Its ``feedback`` attribute is set in place.

        Returns:
            The same ``item``. ``item.feedback`` is left unchanged when the LLM returns nothing
            usable.
        """
        rendered = await template.ainvoke(
            input={
                "context": context,
                "original_prompt": prompt,  # Original prompt used to generate the plans
                "plan": item.plan,
                "num_feedback": self.config.num_feedback,
            })

        feedback_result = await llm.ainvoke(rendered.to_string())
        if not feedback_result:
            logger.warning("No feedback generated for plan: %s.", item.plan)
            return item

        feedback = self._clean_response(feedback_result, self._FEEDBACK_LABEL)
        if not feedback:
            logger.warning("Feedback was empty for plan: %s.", item.plan)
            return item

        item.feedback = feedback
        return item

    async def edit_plan(self,
                        llm,
                        template,
                        context: str | None,
                        prompt: str | None,
                        item: TTCItem) -> TTCItem:
        """
        Ask the editing LLM to rewrite a single plan using the feedback stored on the item.

        Args:
            llm: Chat model used to edit the plan.
            template: Prompt template with ``context``, ``original_prompt``, ``plan`` and
                ``feedback`` input variables.
            context: Agent context substituted into the template; may be None.
            prompt: Original prompt that was used to generate the plans; may be None.
            item: Planning item to edit. Its ``plan`` attribute is replaced in place.

        Returns:
            The same ``item``. The original plan is kept when there is no feedback or when the
            LLM returns nothing usable.
        """
        if not item.feedback:
            logger.warning("No feedback available for plan: %s. Cannot edit.", item.plan)
            return item

        rendered = await template.ainvoke(
            input={
                "context": context,
                "original_prompt": prompt,  # Original prompt used to generate the plans
                "plan": item.plan,
                "feedback": item.feedback,
            })

        editing_result = await llm.ainvoke(rendered.to_string())
        if not editing_result:
            logger.warning("No editing result generated for plan: %s.", item.plan)
            return item

        edited_plan = self._clean_response(editing_result, self._EDITED_PLAN_LABEL)
        if not edited_plan:
            logger.warning("Edited plan was empty for plan: %s. Returning original.", item.plan)
            return item

        item.plan = edited_plan
        return item

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Critique and then edit every planning item.

        Args:
            items: Planning items to edit; they are mutated in place.
            original_prompt: Prompt originally used to generate the plans.
            agent_context: Additional agent context substituted into both templates.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The edited items, in the same order as ``items``.

        Raises:
            ValueError: If either LLM is not a LangChain ``BaseChatModel``, for example because
                ``build_components`` has not been awaited. Also raised when ``items`` is empty.
        """
        # Imported lazily so LangChain is only required when this strategy actually runs.
        from langchain_core.language_models import BaseChatModel
        from langchain_core.prompts import PromptTemplate

        if not isinstance(self.feedback_llm, BaseChatModel):
            raise ValueError("The `feedback_llm` must be an instance of `BaseChatModel`.")
        if not isinstance(self.editing_llm, BaseChatModel):
            raise ValueError("The `editing_llm` must be an instance of `BaseChatModel`.")

        feedback_model: BaseChatModel = self.feedback_llm
        editing_model: BaseChatModel = self.editing_llm

        feedback_template = PromptTemplate(template=self.config.feedback_template,
                                           input_variables=["context", "original_prompt", "plan", "num_feedback"],
                                           validate_template=True)
        editing_template = PromptTemplate(template=self.config.editor_template,
                                          input_variables=["context", "original_prompt", "plan", "feedback"],
                                          validate_template=True)

        # Stage 1: generate feedback for every plan concurrently.
        planning_items_with_feedback = await asyncio.gather(*(self.generate_feedback(
            llm=feedback_model, template=feedback_template, context=agent_context, prompt=original_prompt, item=item)
                                                               for item in items))

        # gather() yields one result per task, so this only triggers when `items` is empty.
        if not planning_items_with_feedback:
            raise ValueError("No feedback was generated for the planning items. Please check the LLM response.")

        logger.info("Generated feedback for %d plans.", len(planning_items_with_feedback))

        # Stage 2: edit every plan using its feedback, concurrently.
        edited_planning_items = await asyncio.gather(*(self.edit_plan(
            llm=editing_model, template=editing_template, context=agent_context, prompt=original_prompt, item=item)
                                                        for item in planning_items_with_feedback))

        if not edited_planning_items:
            raise ValueError("No plans were edited. Please check the LLM response.")

        logger.info("Edited %d plans based on feedback.", len(edited_planning_items))
        return edited_planning_items
193
+
194
+
195
@register_ttc_strategy(config_type=LLMAsAJudgeEditorConfig)
async def register_llm_as_a_judge_editor(config: TTCStrategyBaseConfig, builder: Builder):
    """
    Create an ``LLMAsAJudgeEditor`` from ``config``, wire up its LLMs, and yield it to the registry.

    Args:
        config: Strategy configuration (expected to be an ``LLMAsAJudgeEditorConfig``).
        builder: Workflow builder used to resolve the feedback and editing LLMs.

    Yields:
        The fully initialised editor strategy.
    """
    strategy = LLMAsAJudgeEditor(config)
    # LLM clients must be resolved before the strategy is handed out.
    await strategy.build_components(builder=builder)
    yield strategy
@@ -0,0 +1,107 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from nat.builder.builder import Builder
19
+ from nat.builder.framework_enum import LLMFrameworkEnum
20
+ from nat.cli.register_workflow import register_ttc_strategy
21
+ from nat.experimental.test_time_compute.models.editor_config import MotivationAwareSummarizationConfig
22
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
23
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
24
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
25
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
26
+ from nat.utils.io.model_processing import remove_r1_think_tags
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class MotivationAwareSummarization(StrategyBase):
    """
    Editing strategy that summarizes each incoming TTCItem's output in light of
    the item's task (its ``input``) and motivation (its ``metadata``).

    It supports only the ``TOOL_USE`` pipeline and runs as an ``EDITING`` stage.
    """

    def __init__(self, config: MotivationAwareSummarizationConfig) -> None:
        super().__init__(config)
        self.config = config
        # Set by ``build_components``; ``ainvoke`` calls it, so it must be built first.
        self.llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Resolve ``self.config.editor_llm`` into a LangChain-wrapped LLM client.

        Args:
            builder: The builder used to look up the configured LLM.
        """
        self.llm_bound = await builder.get_llm(self.config.editor_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        return [PipelineTypeEnum.TOOL_USE]

    def stage_type(self) -> StageTypeEnum:
        return StageTypeEnum.EDITING

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Summarize the output of every TTCItem with the editor LLM.

        For each item, ``config.editor_template`` is filled with the item's task
        (``input``), motivation (``metadata``) and ``output``. The LLM response,
        passed through ``remove_r1_think_tags``, becomes the ``output`` of a new
        TTCItem that keeps the original ``input``, ``metadata`` and ``name``.

        Args:
            items: Items whose outputs should be summarized.
            original_prompt: Unused by this strategy.
            agent_context: Unused by this strategy.

        Returns:
            One new TTCItem per input item, in the same order.

        Raises:
            ImportError: If ``langchain-core`` is not installed.
        """
        try:
            from langchain_core.prompts import PromptTemplate
        except ImportError as e:
            raise ImportError("langchain-core is required for MotivationAwareSummarization. "
                              "Install nvidia-nat-langchain or similar.") from e

        new_ttc_items: list[TTCItem] = []

        # A single PromptTemplate is reused for every item.
        template_vars = ["task", "motivation", "output"]
        query_template = PromptTemplate(template=self.config.editor_template,
                                        input_variables=template_vars,
                                        validate_template=True)

        for item in items:
            # Map a missing input to "" rather than the literal string "None".
            original_task = str(item.input) if item.input is not None else ""
            motivation = str(item.metadata) if item.metadata else ""
            output = str(item.output) if item.output else ""

            prompt = await query_template.ainvoke(input={
                "task": original_task, "motivation": motivation, "output": output
            })

            llm_response = await self.llm_bound.ainvoke(prompt.to_string())
            summary = remove_r1_think_tags(llm_response.content)

            logger.info("LLM response from summarization: %s", summary)

            new_ttc_items.append(
                TTCItem(
                    input=item.input,
                    output=summary,
                    metadata=item.metadata,
                    name=item.name,  # keep the original tool name
                ))

        return new_ttc_items
101
+
102
+
103
@register_ttc_strategy(config_type=MotivationAwareSummarizationConfig)
async def register_multi_query_retrieval_search(config: MotivationAwareSummarizationConfig, builder: Builder):
    """
    Build a MotivationAwareSummarization strategy from ``config`` and yield it.

    Despite its name, this registration produces a MotivationAwareSummarization
    strategy, not a multi-query retrieval search.
    """
    summarizer = MotivationAwareSummarization(config)
    # Bind the editor LLM before the strategy is handed to the caller.
    await summarizer.build_components(builder)
    yield summarizer
@@ -0,0 +1,105 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from pydantic import Field
19
+
20
+ from nat.builder.builder import Builder
21
+ from nat.builder.function import Function
22
+ from nat.builder.function_info import FunctionInfo
23
+ from nat.cli.register_workflow import register_function
24
+ from nat.data_models.component_ref import FunctionRef
25
+ from nat.data_models.component_ref import TTCStrategyRef
26
+ from nat.data_models.function import FunctionBaseConfig
27
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
28
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
29
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class ExecuteScoreSelectFunctionConfig(FunctionBaseConfig, name="execute_score_select_function"):
    # Optional strategy applied to the candidate outputs before selection.
    scorer: TTCStrategyRef | None = Field(description="Strategy to score the output of the function", default=None)
    # Strategy that picks the final output among the candidates.
    selector: TTCStrategyRef = Field(description="Strategy to select the best output of the function")
    # Function executed repeatedly to produce the candidates.
    augmented_fn: FunctionRef = Field(description="Function that will be executed")

    # Must be at least 1: with no executions there are no candidates to select from,
    # so reject such configs at validation time instead of failing at invocation.
    num_executions: int = Field(3, ge=1, description="Number of times to execute the function")
40
+
41
+
42
@register_function(config_type=ExecuteScoreSelectFunctionConfig)
async def execute_score_select_function(config: ExecuteScoreSelectFunctionConfig, builder: Builder):
    """
    Build a function that runs ``config.augmented_fn`` several times and returns one selected result.

    Each invocation executes the augmented function ``config.num_executions`` times concurrently,
    wraps the stringified outputs in TTCItems, optionally scores them, lets the selector pick one,
    and returns the raw result that produced the selected output.

    Args:
        config: The function configuration.
        builder: The builder used to resolve the augmented function and TTC strategies.

    Yields:
        FunctionInfo wrapping the execute/score/select function.
    """
    import asyncio
    import warnings

    from pydantic import BaseModel

    executable_fn: Function = await builder.get_function(name=config.augmented_fn)

    scorer = None
    if config.scorer:
        scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
                                                pipeline_type=PipelineTypeEnum.AGENT_EXECUTION,
                                                stage_type=StageTypeEnum.SCORING)

    selector = await builder.get_ttc_strategy(strategy_name=config.selector,
                                              pipeline_type=PipelineTypeEnum.AGENT_EXECUTION,
                                              stage_type=StageTypeEnum.SELECTION)

    if executable_fn.has_streaming_output:
        warnings.warn("Streaming output is not supported for this function. "
                      "The function will be executed in non-streaming mode.")

    def convert_to_str(arg):
        """Stringify ``arg``, dumping pydantic models to a dict first."""
        if isinstance(arg, BaseModel):
            return str(arg.model_dump())
        return str(arg)

    async def execute_fn(input_msg: executable_fn.input_type) -> executable_fn.single_output_type:

        logger.info("Executing function %d times", config.num_executions)
        results = await asyncio.gather(*(executable_fn.ainvoke(input_msg) for _ in range(config.num_executions)))

        input_str = convert_to_str(input_msg)
        # function_outputs[i] is the string form of results[i]. This pairing, not the order of the
        # (possibly reordered) scored items, is what maps the selection back to a raw result.
        function_outputs = [convert_to_str(out) for out in results]
        its_items = [TTCItem(input=input_str, output=out) for out in function_outputs]

        if scorer:
            logger.info("Beginning scoring")
            its_items = await scorer.ainvoke(items=its_items)

        logger.info("Beginning selection")
        selected_item = (await selector.ainvoke(items=its_items, original_prompt=its_items[0].input))[0]

        # Look the selected output up in the unscored, index-aligned outputs so a scorer that
        # reorders items cannot cause the wrong raw result to be returned.
        selected_output = selected_item.output
        selected_index = next((i for i, out in enumerate(function_outputs) if out == selected_output), -1)

        return results[selected_index] if selected_index != -1 else selected_output

    yield FunctionInfo.from_fn(
        fn=execute_fn,
        description=("This function executes a given function multiple times, scores the outputs, "
                     "and selects the best output based on the specified scoring and selection strategies."),
    )
@@ -0,0 +1,228 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from collections.abc import AsyncGenerator
18
+
19
+ from pydantic import Field
20
+
21
+ from nat.builder.builder import Builder
22
+ from nat.builder.framework_enum import LLMFrameworkEnum
23
+ from nat.builder.function_info import FunctionInfo
24
+ from nat.cli.register_workflow import register_function
25
+ from nat.data_models.api_server import ChatRequest
26
+ from nat.data_models.component_ref import FunctionRef
27
+ from nat.data_models.component_ref import TTCStrategyRef
28
+ from nat.data_models.function import FunctionBaseConfig
29
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
30
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
31
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class PlanSelectExecuteFunctionConfig(FunctionBaseConfig, name="plan_select_execute_function"):
    """
    Defines a NAT function that performs reasoning on the input data.
    Output is passed to the next function in the workflow.

    Designed to be used with an InterceptingFunction.
    """

    # Function whose input is augmented with a generated plan.
    augmented_fn: FunctionRef = Field(description="The name of the function to reason on.")

    # TTC strategies forming the planning pipeline; editor and scorer stages are optional.
    planner: TTCStrategyRef = Field(description="The configuration for the planner.")
    editor: TTCStrategyRef | None = Field(default=None, description="The configuration for the editor.")
    scorer: TTCStrategyRef | None = Field(default=None, description="The configuration for the scorer.")
    selector: TTCStrategyRef = Field(description="The configuration for the selector.")

    verbose: bool = Field(default=False, description="Whether to log detailed information.")

    # Template placeholders: {description} and {tools}.
    agent_context_prompt_template: str = Field(
        default=("\nThe agent system has the following description:\n{description}\n"
                 "And has access to the following tools with functionality:\n{tools}\n\n"),
        description="The template for the agent context prompt. This prompt is used to provide context about the agent",
    )

    # Template placeholders: {input_text} and {reasoning_output}.
    downstream_template: str = Field(
        default=("Answer the following question based on message history: {input_text}\n\n"
                 "Here is a plan for execution that you could use to guide you if you wanted to:\n\n"
                 "{reasoning_output}\n\n"
                 "NOTE: Remember to follow your guidance on how to format output, etc.\n\n"
                 " You must respond with the answer to the original question directly to the user."),
        description=("The template for the downstream prompt. This prompt is used to provide the reasoning output "
                     "to the executing agent"),
    )
67
+
68
+
69
@register_function(config_type=PlanSelectExecuteFunctionConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig, builder: Builder):
    """
    Build a plan-select-execute wrapper around ``config.augmented_fn``.

    For every request, the wrapper does the following:
    1. Runs the TTC planning pipeline: the planner, then the optional editor and
       scorer, then the selector.
    2. Renders the selected plan into ``config.downstream_template``.
    3. Forwards the rendered prompt to the augmented function.

    A streaming entry point and/or a single-output entry point is exposed,
    matching those of the augmented function.

    Args:
        config (PlanSelectExecuteFunctionConfig): The config for this function.
        builder (Builder): The Builder used to resolve the augmented function, its tool
            dependencies and the TTC strategies.

    Yields:
        FunctionInfo: The wrapped function.

    Raises:
        ImportError: If ``langchain-core`` is not installed.
        ValueError: If the augmented function has no description.
    """

    try:
        from langchain_core.prompts import PromptTemplate
    except ImportError as e:
        raise ImportError("langchain-core is not installed. Please install it to use PlanSelectExecuteFunction.\n"
                          "This error can be resolved by installing nvidia-nat-langchain.") from e

    # Get the augmented function's description
    augmented_function = await builder.get_function(config.augmented_fn)

    # For now, we rely on runtime checking for type conversion
    augmented_function_desc = augmented_function.description
    if not augmented_function_desc:
        raise ValueError(f"Function {config.augmented_fn} does not have a description. Cannot augment "
                         f"function without a description.")

    # Collect the tools used by the augmented function, directly or through function groups.
    function_dependencies = builder.get_function_dependencies(config.augmented_fn)
    function_used_tools = set(function_dependencies.functions)
    for function_group in function_dependencies.function_groups:
        function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)

    # Iterate in sorted order so the tool listing, and therefore the planning prompt, is
    # identical across runs; plain set iteration order is not stable.
    tool_lines = ["Tool: Description\n"]
    for tool in sorted(function_used_tools, key=str):
        tool_impl = await builder.get_function(tool)
        # Render a missing or None description as empty instead of the text "None".
        tool_lines.append(f"- {tool}: {getattr(tool_impl, 'description', None) or ''}\n")
    tool_list = "".join(tool_lines)

    # Draft the reasoning prompt for the augmented function
    template = PromptTemplate(template=config.agent_context_prompt_template,
                              input_variables=["description", "tools"],
                              validate_template=True)

    downstream_template = PromptTemplate(template=config.downstream_template,
                                         input_variables=["input_text", "reasoning_output"],
                                         validate_template=True)

    # The agent context depends only on build-time data, so render it once instead of per request.
    context_prompt = (await template.ainvoke(input={
        "description": augmented_function_desc, "tools": tool_list
    })).to_string()

    planner = await builder.get_ttc_strategy(strategy_name=config.planner,
                                             pipeline_type=PipelineTypeEnum.PLANNING,
                                             stage_type=StageTypeEnum.SEARCH)

    selector = await builder.get_ttc_strategy(strategy_name=config.selector,
                                              pipeline_type=PipelineTypeEnum.PLANNING,
                                              stage_type=StageTypeEnum.SELECTION)

    editor = None
    if config.editor:
        editor = await builder.get_ttc_strategy(strategy_name=config.editor,
                                                pipeline_type=PipelineTypeEnum.PLANNING,
                                                stage_type=StageTypeEnum.EDITING)

    scorer = None
    if config.scorer:
        scorer = await builder.get_ttc_strategy(strategy_name=config.scorer,
                                                pipeline_type=PipelineTypeEnum.PLANNING,
                                                stage_type=StageTypeEnum.SCORING)

    async def planning_pipeline(prompt, context):
        """Run planner -> (editor) -> (scorer) -> selector and return the selected plan item."""

        plans = await planner.ainvoke([TTCItem()], prompt, context)

        if editor:
            plans = await editor.ainvoke(plans, prompt, context)
        if scorer:
            plans = await scorer.ainvoke(plans, prompt, context)

        selected_plan = (await selector.ainvoke(plans, prompt, context))[0]

        return selected_plan

    async def _build_downstream_input(input_message: ChatRequest) -> str:
        """Plan for ``input_message`` and render the prompt forwarded to the augmented function."""
        input_text = "".join([str(message.model_dump()) + "\n" for message in input_message.messages])

        # Run the TTC pipeline
        planning_item: TTCItem = await planning_pipeline(prompt=input_text, context=context_prompt)

        output = await downstream_template.ainvoke(input={
            "input_text": input_text, "reasoning_output": planning_item.plan
        })
        output = output.to_string()

        if config.verbose:
            logger.info("Reasoning plan and input to agent: \n\n%s", output)

        return output

    streaming_inner_fn = None
    single_inner_fn = None

    if augmented_function.has_streaming_output:

        async def streaming_inner(
                input_message: ChatRequest) -> AsyncGenerator[augmented_function.streaming_output_type]:
            """
            Perform reasoning on the input text.

            Args:
                input_message (ChatRequest): The input text to reason on.
            """
            output = await _build_downstream_input(input_message)

            async for chunk in augmented_function.acall_stream(output):
                yield chunk

        streaming_inner_fn = streaming_inner

    if augmented_function.has_single_output:

        async def single_inner(input_message: ChatRequest) -> augmented_function.single_output_type:
            """
            Perform reasoning on the input text.

            Args:
                input_message (ChatRequest): The input text to reason on.
            """
            output = await _build_downstream_input(input_message)

            return await augmented_function.acall_invoke(output)

        single_inner_fn = single_inner

    yield FunctionInfo.create(
        single_fn=single_inner_fn,
        stream_fn=streaming_inner_fn,
        description=("Function that runs an TTC execution planner on input and sends plan downstream"),
        converters=augmented_function.converter_list)