nvidia-nat 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +256 -0
  4. nat/agent/dual_node.py +67 -0
  5. nat/agent/react_agent/__init__.py +0 -0
  6. nat/agent/react_agent/agent.py +363 -0
  7. nat/agent/react_agent/output_parser.py +104 -0
  8. nat/agent/react_agent/prompt.py +44 -0
  9. nat/agent/react_agent/register.py +149 -0
  10. nat/agent/reasoning_agent/__init__.py +0 -0
  11. nat/agent/reasoning_agent/reasoning_agent.py +225 -0
  12. nat/agent/register.py +23 -0
  13. nat/agent/rewoo_agent/__init__.py +0 -0
  14. nat/agent/rewoo_agent/agent.py +415 -0
  15. nat/agent/rewoo_agent/prompt.py +110 -0
  16. nat/agent/rewoo_agent/register.py +157 -0
  17. nat/agent/tool_calling_agent/__init__.py +0 -0
  18. nat/agent/tool_calling_agent/agent.py +119 -0
  19. nat/agent/tool_calling_agent/register.py +106 -0
  20. nat/authentication/__init__.py +14 -0
  21. nat/authentication/api_key/__init__.py +14 -0
  22. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  23. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  24. nat/authentication/api_key/register.py +26 -0
  25. nat/authentication/exceptions/__init__.py +14 -0
  26. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  27. nat/authentication/http_basic_auth/__init__.py +0 -0
  28. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  29. nat/authentication/http_basic_auth/register.py +30 -0
  30. nat/authentication/interfaces.py +93 -0
  31. nat/authentication/oauth2/__init__.py +14 -0
  32. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  33. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  34. nat/authentication/oauth2/register.py +25 -0
  35. nat/authentication/register.py +21 -0
  36. nat/builder/__init__.py +0 -0
  37. nat/builder/builder.py +285 -0
  38. nat/builder/component_utils.py +316 -0
  39. nat/builder/context.py +270 -0
  40. nat/builder/embedder.py +24 -0
  41. nat/builder/eval_builder.py +161 -0
  42. nat/builder/evaluator.py +29 -0
  43. nat/builder/framework_enum.py +24 -0
  44. nat/builder/front_end.py +73 -0
  45. nat/builder/function.py +344 -0
  46. nat/builder/function_base.py +380 -0
  47. nat/builder/function_info.py +627 -0
  48. nat/builder/intermediate_step_manager.py +174 -0
  49. nat/builder/llm.py +25 -0
  50. nat/builder/retriever.py +25 -0
  51. nat/builder/user_interaction_manager.py +78 -0
  52. nat/builder/workflow.py +148 -0
  53. nat/builder/workflow_builder.py +1117 -0
  54. nat/cli/__init__.py +14 -0
  55. nat/cli/cli_utils/__init__.py +0 -0
  56. nat/cli/cli_utils/config_override.py +231 -0
  57. nat/cli/cli_utils/validation.py +37 -0
  58. nat/cli/commands/__init__.py +0 -0
  59. nat/cli/commands/configure/__init__.py +0 -0
  60. nat/cli/commands/configure/channel/__init__.py +0 -0
  61. nat/cli/commands/configure/channel/add.py +28 -0
  62. nat/cli/commands/configure/channel/channel.py +34 -0
  63. nat/cli/commands/configure/channel/remove.py +30 -0
  64. nat/cli/commands/configure/channel/update.py +30 -0
  65. nat/cli/commands/configure/configure.py +33 -0
  66. nat/cli/commands/evaluate.py +139 -0
  67. nat/cli/commands/info/__init__.py +14 -0
  68. nat/cli/commands/info/info.py +37 -0
  69. nat/cli/commands/info/list_channels.py +32 -0
  70. nat/cli/commands/info/list_components.py +129 -0
  71. nat/cli/commands/info/list_mcp.py +304 -0
  72. nat/cli/commands/registry/__init__.py +14 -0
  73. nat/cli/commands/registry/publish.py +88 -0
  74. nat/cli/commands/registry/pull.py +118 -0
  75. nat/cli/commands/registry/registry.py +36 -0
  76. nat/cli/commands/registry/remove.py +108 -0
  77. nat/cli/commands/registry/search.py +155 -0
  78. nat/cli/commands/sizing/__init__.py +14 -0
  79. nat/cli/commands/sizing/calc.py +297 -0
  80. nat/cli/commands/sizing/sizing.py +27 -0
  81. nat/cli/commands/start.py +246 -0
  82. nat/cli/commands/uninstall.py +81 -0
  83. nat/cli/commands/validate.py +47 -0
  84. nat/cli/commands/workflow/__init__.py +14 -0
  85. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  86. nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
  87. nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  88. nat/cli/commands/workflow/templates/register.py.j2 +5 -0
  89. nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  90. nat/cli/commands/workflow/workflow.py +37 -0
  91. nat/cli/commands/workflow/workflow_commands.py +317 -0
  92. nat/cli/entrypoint.py +135 -0
  93. nat/cli/main.py +57 -0
  94. nat/cli/register_workflow.py +488 -0
  95. nat/cli/type_registry.py +1000 -0
  96. nat/data_models/__init__.py +14 -0
  97. nat/data_models/api_server.py +716 -0
  98. nat/data_models/authentication.py +231 -0
  99. nat/data_models/common.py +171 -0
  100. nat/data_models/component.py +58 -0
  101. nat/data_models/component_ref.py +168 -0
  102. nat/data_models/config.py +410 -0
  103. nat/data_models/dataset_handler.py +169 -0
  104. nat/data_models/discovery_metadata.py +305 -0
  105. nat/data_models/embedder.py +27 -0
  106. nat/data_models/evaluate.py +127 -0
  107. nat/data_models/evaluator.py +26 -0
  108. nat/data_models/front_end.py +26 -0
  109. nat/data_models/function.py +30 -0
  110. nat/data_models/function_dependencies.py +72 -0
  111. nat/data_models/interactive.py +246 -0
  112. nat/data_models/intermediate_step.py +302 -0
  113. nat/data_models/invocation_node.py +38 -0
  114. nat/data_models/llm.py +27 -0
  115. nat/data_models/logging.py +26 -0
  116. nat/data_models/memory.py +27 -0
  117. nat/data_models/object_store.py +44 -0
  118. nat/data_models/profiler.py +54 -0
  119. nat/data_models/registry_handler.py +26 -0
  120. nat/data_models/retriever.py +30 -0
  121. nat/data_models/retry_mixin.py +35 -0
  122. nat/data_models/span.py +190 -0
  123. nat/data_models/step_adaptor.py +64 -0
  124. nat/data_models/streaming.py +33 -0
  125. nat/data_models/swe_bench_model.py +54 -0
  126. nat/data_models/telemetry_exporter.py +26 -0
  127. nat/data_models/ttc_strategy.py +30 -0
  128. nat/embedder/__init__.py +0 -0
  129. nat/embedder/nim_embedder.py +59 -0
  130. nat/embedder/openai_embedder.py +43 -0
  131. nat/embedder/register.py +22 -0
  132. nat/eval/__init__.py +14 -0
  133. nat/eval/config.py +60 -0
  134. nat/eval/dataset_handler/__init__.py +0 -0
  135. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  136. nat/eval/dataset_handler/dataset_filter.py +52 -0
  137. nat/eval/dataset_handler/dataset_handler.py +367 -0
  138. nat/eval/evaluate.py +510 -0
  139. nat/eval/evaluator/__init__.py +14 -0
  140. nat/eval/evaluator/base_evaluator.py +77 -0
  141. nat/eval/evaluator/evaluator_model.py +45 -0
  142. nat/eval/intermediate_step_adapter.py +99 -0
  143. nat/eval/rag_evaluator/__init__.py +0 -0
  144. nat/eval/rag_evaluator/evaluate.py +178 -0
  145. nat/eval/rag_evaluator/register.py +143 -0
  146. nat/eval/register.py +23 -0
  147. nat/eval/remote_workflow.py +133 -0
  148. nat/eval/runners/__init__.py +14 -0
  149. nat/eval/runners/config.py +39 -0
  150. nat/eval/runners/multi_eval_runner.py +54 -0
  151. nat/eval/runtime_event_subscriber.py +52 -0
  152. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  153. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. nat/eval/swe_bench_evaluator/register.py +36 -0
  155. nat/eval/trajectory_evaluator/__init__.py +0 -0
  156. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  157. nat/eval/trajectory_evaluator/register.py +40 -0
  158. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. nat/eval/tunable_rag_evaluator/register.py +52 -0
  161. nat/eval/usage_stats.py +41 -0
  162. nat/eval/utils/__init__.py +0 -0
  163. nat/eval/utils/output_uploader.py +140 -0
  164. nat/eval/utils/tqdm_position_registry.py +40 -0
  165. nat/eval/utils/weave_eval.py +184 -0
  166. nat/experimental/__init__.py +0 -0
  167. nat/experimental/decorators/__init__.py +0 -0
  168. nat/experimental/decorators/experimental_warning_decorator.py +134 -0
  169. nat/experimental/test_time_compute/__init__.py +0 -0
  170. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  171. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  175. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  177. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  178. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  179. nat/experimental/test_time_compute/models/__init__.py +0 -0
  180. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  181. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. nat/experimental/test_time_compute/models/search_config.py +120 -0
  183. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  184. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. nat/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. nat/experimental/test_time_compute/register.py +36 -0
  189. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. nat/experimental/test_time_compute/search/__init__.py +0 -0
  194. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  198. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. nat/front_ends/__init__.py +14 -0
  204. nat/front_ends/console/__init__.py +14 -0
  205. nat/front_ends/console/authentication_flow_handler.py +233 -0
  206. nat/front_ends/console/console_front_end_config.py +32 -0
  207. nat/front_ends/console/console_front_end_plugin.py +96 -0
  208. nat/front_ends/console/register.py +25 -0
  209. nat/front_ends/cron/__init__.py +14 -0
  210. nat/front_ends/fastapi/__init__.py +14 -0
  211. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
  215. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
  218. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. nat/front_ends/fastapi/job_store.py +183 -0
  222. nat/front_ends/fastapi/main.py +72 -0
  223. nat/front_ends/fastapi/message_handler.py +320 -0
  224. nat/front_ends/fastapi/message_validator.py +352 -0
  225. nat/front_ends/fastapi/register.py +25 -0
  226. nat/front_ends/fastapi/response_helpers.py +195 -0
  227. nat/front_ends/fastapi/step_adaptor.py +319 -0
  228. nat/front_ends/mcp/__init__.py +14 -0
  229. nat/front_ends/mcp/mcp_front_end_config.py +36 -0
  230. nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
  231. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
  232. nat/front_ends/mcp/register.py +27 -0
  233. nat/front_ends/mcp/tool_converter.py +241 -0
  234. nat/front_ends/register.py +22 -0
  235. nat/front_ends/simple_base/__init__.py +14 -0
  236. nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  237. nat/llm/__init__.py +0 -0
  238. nat/llm/aws_bedrock_llm.py +57 -0
  239. nat/llm/nim_llm.py +46 -0
  240. nat/llm/openai_llm.py +46 -0
  241. nat/llm/register.py +23 -0
  242. nat/llm/utils/__init__.py +14 -0
  243. nat/llm/utils/env_config_value.py +94 -0
  244. nat/llm/utils/error.py +17 -0
  245. nat/memory/__init__.py +20 -0
  246. nat/memory/interfaces.py +183 -0
  247. nat/memory/models.py +112 -0
  248. nat/meta/pypi.md +58 -0
  249. nat/object_store/__init__.py +20 -0
  250. nat/object_store/in_memory_object_store.py +76 -0
  251. nat/object_store/interfaces.py +84 -0
  252. nat/object_store/models.py +38 -0
  253. nat/object_store/register.py +20 -0
  254. nat/observability/__init__.py +14 -0
  255. nat/observability/exporter/__init__.py +14 -0
  256. nat/observability/exporter/base_exporter.py +449 -0
  257. nat/observability/exporter/exporter.py +78 -0
  258. nat/observability/exporter/file_exporter.py +33 -0
  259. nat/observability/exporter/processing_exporter.py +322 -0
  260. nat/observability/exporter/raw_exporter.py +52 -0
  261. nat/observability/exporter/span_exporter.py +288 -0
  262. nat/observability/exporter_manager.py +335 -0
  263. nat/observability/mixin/__init__.py +14 -0
  264. nat/observability/mixin/batch_config_mixin.py +26 -0
  265. nat/observability/mixin/collector_config_mixin.py +23 -0
  266. nat/observability/mixin/file_mixin.py +288 -0
  267. nat/observability/mixin/file_mode.py +23 -0
  268. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  269. nat/observability/mixin/serialize_mixin.py +61 -0
  270. nat/observability/mixin/type_introspection_mixin.py +183 -0
  271. nat/observability/processor/__init__.py +14 -0
  272. nat/observability/processor/batching_processor.py +310 -0
  273. nat/observability/processor/callback_processor.py +42 -0
  274. nat/observability/processor/intermediate_step_serializer.py +28 -0
  275. nat/observability/processor/processor.py +71 -0
  276. nat/observability/register.py +96 -0
  277. nat/observability/utils/__init__.py +14 -0
  278. nat/observability/utils/dict_utils.py +236 -0
  279. nat/observability/utils/time_utils.py +31 -0
  280. nat/plugins/.namespace +1 -0
  281. nat/profiler/__init__.py +0 -0
  282. nat/profiler/calc/__init__.py +14 -0
  283. nat/profiler/calc/calc_runner.py +627 -0
  284. nat/profiler/calc/calculations.py +288 -0
  285. nat/profiler/calc/data_models.py +188 -0
  286. nat/profiler/calc/plot.py +345 -0
  287. nat/profiler/callbacks/__init__.py +0 -0
  288. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  289. nat/profiler/callbacks/base_callback_class.py +20 -0
  290. nat/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  294. nat/profiler/data_frame_row.py +51 -0
  295. nat/profiler/data_models.py +24 -0
  296. nat/profiler/decorators/__init__.py +0 -0
  297. nat/profiler/decorators/framework_wrapper.py +131 -0
  298. nat/profiler/decorators/function_tracking.py +254 -0
  299. nat/profiler/forecasting/__init__.py +0 -0
  300. nat/profiler/forecasting/config.py +18 -0
  301. nat/profiler/forecasting/model_trainer.py +75 -0
  302. nat/profiler/forecasting/models/__init__.py +22 -0
  303. nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. nat/profiler/forecasting/models/linear_model.py +197 -0
  305. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  306. nat/profiler/inference_metrics_model.py +28 -0
  307. nat/profiler/inference_optimization/__init__.py +0 -0
  308. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. nat/profiler/inference_optimization/data_models.py +386 -0
  312. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  316. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  317. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. nat/profiler/intermediate_property_adapter.py +102 -0
  320. nat/profiler/profile_runner.py +473 -0
  321. nat/profiler/utils.py +184 -0
  322. nat/registry_handlers/__init__.py +0 -0
  323. nat/registry_handlers/local/__init__.py +0 -0
  324. nat/registry_handlers/local/local_handler.py +176 -0
  325. nat/registry_handlers/local/register_local.py +37 -0
  326. nat/registry_handlers/metadata_factory.py +60 -0
  327. nat/registry_handlers/package_utils.py +571 -0
  328. nat/registry_handlers/pypi/__init__.py +0 -0
  329. nat/registry_handlers/pypi/pypi_handler.py +251 -0
  330. nat/registry_handlers/pypi/register_pypi.py +40 -0
  331. nat/registry_handlers/register.py +21 -0
  332. nat/registry_handlers/registry_handler_base.py +157 -0
  333. nat/registry_handlers/rest/__init__.py +0 -0
  334. nat/registry_handlers/rest/register_rest.py +56 -0
  335. nat/registry_handlers/rest/rest_handler.py +237 -0
  336. nat/registry_handlers/schemas/__init__.py +0 -0
  337. nat/registry_handlers/schemas/headers.py +42 -0
  338. nat/registry_handlers/schemas/package.py +68 -0
  339. nat/registry_handlers/schemas/publish.py +68 -0
  340. nat/registry_handlers/schemas/pull.py +82 -0
  341. nat/registry_handlers/schemas/remove.py +36 -0
  342. nat/registry_handlers/schemas/search.py +91 -0
  343. nat/registry_handlers/schemas/status.py +47 -0
  344. nat/retriever/__init__.py +0 -0
  345. nat/retriever/interface.py +41 -0
  346. nat/retriever/milvus/__init__.py +14 -0
  347. nat/retriever/milvus/register.py +81 -0
  348. nat/retriever/milvus/retriever.py +228 -0
  349. nat/retriever/models.py +77 -0
  350. nat/retriever/nemo_retriever/__init__.py +14 -0
  351. nat/retriever/nemo_retriever/register.py +60 -0
  352. nat/retriever/nemo_retriever/retriever.py +190 -0
  353. nat/retriever/register.py +22 -0
  354. nat/runtime/__init__.py +14 -0
  355. nat/runtime/loader.py +220 -0
  356. nat/runtime/runner.py +195 -0
  357. nat/runtime/session.py +162 -0
  358. nat/runtime/user_metadata.py +130 -0
  359. nat/settings/__init__.py +0 -0
  360. nat/settings/global_settings.py +318 -0
  361. nat/test/.namespace +1 -0
  362. nat/tool/__init__.py +0 -0
  363. nat/tool/chat_completion.py +74 -0
  364. nat/tool/code_execution/README.md +151 -0
  365. nat/tool/code_execution/__init__.py +0 -0
  366. nat/tool/code_execution/code_sandbox.py +267 -0
  367. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. nat/tool/code_execution/register.py +74 -0
  374. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. nat/tool/code_execution/utils.py +100 -0
  376. nat/tool/datetime_tools.py +42 -0
  377. nat/tool/document_search.py +141 -0
  378. nat/tool/github_tools/__init__.py +0 -0
  379. nat/tool/github_tools/create_github_commit.py +133 -0
  380. nat/tool/github_tools/create_github_issue.py +87 -0
  381. nat/tool/github_tools/create_github_pr.py +106 -0
  382. nat/tool/github_tools/get_github_file.py +106 -0
  383. nat/tool/github_tools/get_github_issue.py +166 -0
  384. nat/tool/github_tools/get_github_pr.py +256 -0
  385. nat/tool/github_tools/update_github_issue.py +100 -0
  386. nat/tool/mcp/__init__.py +14 -0
  387. nat/tool/mcp/exceptions.py +142 -0
  388. nat/tool/mcp/mcp_client.py +255 -0
  389. nat/tool/mcp/mcp_tool.py +96 -0
  390. nat/tool/memory_tools/__init__.py +0 -0
  391. nat/tool/memory_tools/add_memory_tool.py +79 -0
  392. nat/tool/memory_tools/delete_memory_tool.py +67 -0
  393. nat/tool/memory_tools/get_memory_tool.py +72 -0
  394. nat/tool/nvidia_rag.py +95 -0
  395. nat/tool/register.py +38 -0
  396. nat/tool/retriever.py +94 -0
  397. nat/tool/server_tools.py +66 -0
  398. nat/utils/__init__.py +0 -0
  399. nat/utils/data_models/__init__.py +0 -0
  400. nat/utils/data_models/schema_validator.py +58 -0
  401. nat/utils/debugging_utils.py +43 -0
  402. nat/utils/dump_distro_mapping.py +32 -0
  403. nat/utils/exception_handlers/__init__.py +0 -0
  404. nat/utils/exception_handlers/automatic_retries.py +289 -0
  405. nat/utils/exception_handlers/mcp.py +211 -0
  406. nat/utils/exception_handlers/schemas.py +114 -0
  407. nat/utils/io/__init__.py +0 -0
  408. nat/utils/io/model_processing.py +28 -0
  409. nat/utils/io/yaml_tools.py +119 -0
  410. nat/utils/log_utils.py +37 -0
  411. nat/utils/metadata_utils.py +74 -0
  412. nat/utils/optional_imports.py +142 -0
  413. nat/utils/producer_consumer_queue.py +178 -0
  414. nat/utils/reactive/__init__.py +0 -0
  415. nat/utils/reactive/base/__init__.py +0 -0
  416. nat/utils/reactive/base/observable_base.py +65 -0
  417. nat/utils/reactive/base/observer_base.py +55 -0
  418. nat/utils/reactive/base/subject_base.py +79 -0
  419. nat/utils/reactive/observable.py +59 -0
  420. nat/utils/reactive/observer.py +76 -0
  421. nat/utils/reactive/subject.py +131 -0
  422. nat/utils/reactive/subscription.py +49 -0
  423. nat/utils/settings/__init__.py +0 -0
  424. nat/utils/settings/global_settings.py +197 -0
  425. nat/utils/string_utils.py +38 -0
  426. nat/utils/type_converter.py +290 -0
  427. nat/utils/type_utils.py +484 -0
  428. nat/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0.dist-info/METADATA +365 -0
  430. nvidia_nat-1.2.0.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
  433. nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  434. nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,99 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from langchain_core.agents import AgentAction
19
+
20
+ from nat.data_models.intermediate_step import IntermediateStep
21
+ from nat.data_models.intermediate_step import IntermediateStepType
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class IntermediateStepAdapter:
    """Helpers for turning workflow ``IntermediateStep`` records into evaluator inputs.

    Covers validation of raw step dicts, filtering by event type, serialization
    back to dicts, and conversion into LangChain ``(AgentAction, output)`` pairs
    or numbered context strings.
    """

    # A ready-made ``event_filter`` that keeps only LLM and tool completion steps.
    DEFAULT_EVENT_FILTER = [IntermediateStepType.LLM_END, IntermediateStepType.TOOL_END]

    def filter_intermediate_steps(self,
                                  intermediate_steps: list[IntermediateStep],
                                  event_filter: list[IntermediateStepType]) -> list[IntermediateStep]:
        """Return only the steps whose ``event_type`` appears in ``event_filter``.

        A falsy ``event_filter`` (empty list or ``None``) disables filtering and
        the input list itself is returned.
        """
        if not event_filter:
            return intermediate_steps
        return [step for step in intermediate_steps if step.event_type in event_filter]

    def validate_intermediate_steps(self, intermediate_steps: list[dict]) -> list[IntermediateStep]:
        """Validate raw step dicts into ``IntermediateStep`` models.

        Steps that fail validation are logged (with traceback) and dropped, so the
        returned list may be shorter than the input.
        """
        validated_steps = []
        for step_data in intermediate_steps:
            try:
                validated_steps.append(IntermediateStep.model_validate(step_data))
            except Exception as e:
                # Best-effort: one malformed step should not discard the rest of the trajectory.
                # logger.exception already records the traceback, so no explicit exc_info is needed.
                logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
        return validated_steps

    def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
        """Convert ``IntermediateStep`` objects to plain dicts via ``model_dump()``."""
        return [step.model_dump() for step in intermediate_steps]

    @staticmethod
    def agent_action_to_dict(action: AgentAction) -> dict:
        """Convert an ``AgentAction`` to a JSON-serializable dict of its tool, input, log and type."""
        return {
            "tool": action.tool,
            "tool_input": action.tool_input,
            "log": action.log,
            "type": action.type,
        }

    def get_agent_action_single(self, step: IntermediateStep,
                                last_llm_end_step: IntermediateStep | None) -> tuple[AgentAction, str]:
        """Convert one intermediate step into an ``(AgentAction, tool_output)`` pair.

        Args:
            step: The step to convert. Its name becomes the action's ``tool`` and its
                input becomes ``tool_input``.
            last_llm_end_step: The most recent LLM_END step before ``step``, or ``None``.
                Its output is used as the action's ``log``.

        Returns:
            The action and the step's output (``""`` when the step has no data).
        """
        # Use the preceding LLM output as the action's reasoning log.
        llm_output = getattr(last_llm_end_step.data, "output", None) if last_llm_end_step else None
        # langchain_core's AgentAction declares ``log: str``; a None output would fail validation.
        log = "" if llm_output is None else str(llm_output)

        tool_name = step.name or ""
        tool_input = getattr(step.data, "input", None) if step.data else None
        # AgentAction.tool_input accepts only str or dict, so normalize anything else.
        if tool_input is None:
            tool_input = ""
        elif not isinstance(tool_input, (str, dict)):
            tool_input = str(tool_input)
        tool_output = getattr(step.data, "output", "") if step.data else ""

        action = AgentAction(tool=tool_name, tool_input=tool_input, log=log)
        return action, tool_output

    def get_agent_actions(self, intermediate_steps: list[IntermediateStep],
                          event_filter: list[IntermediateStepType]) -> list[tuple[AgentAction, str]]:
        """Convert filtered intermediate steps into a list of ``(AgentAction, output)`` pairs.

        Every LLM_END step is itself converted (with an empty log). It then becomes the
        source of the ``log`` for the non-LLM steps that follow it.
        """
        steps = self.filter_intermediate_steps(intermediate_steps, event_filter)
        last_llm_end_step = None
        agent_actions = []
        for step in steps:
            if step.event_type == IntermediateStepType.LLM_END:
                last_llm_end_step = step
                agent_actions.append(self.get_agent_action_single(step, None))
            else:
                agent_actions.append(self.get_agent_action_single(step, last_llm_end_step))
        return agent_actions

    def get_context(self, intermediate_steps: list[IntermediateStep],
                    event_filter: list[IntermediateStepType]) -> list[str]:
        """Collect step outputs as retrieved-context strings.

        Only steps whose ``event_type`` is in ``event_filter`` and that have a truthy
        ``data.output`` are kept. Each string is formatted as ``"**Step {n}**\\n{output}"``,
        numbered from 0 over the kept steps.

        Unlike ``filter_intermediate_steps``, an empty ``event_filter`` matches nothing here.
        """
        count = 0
        agent_actions = []
        for step in intermediate_steps:
            if step.event_type in event_filter and step.data and step.data.output:
                agent_actions.append(f"**Step {count}**\n{str(step.data.output)}")
                count += 1
        return agent_actions
File without changes
@@ -0,0 +1,178 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import math
18
+ from collections.abc import Sequence
19
+
20
+ from pydantic import BaseModel
21
+ from ragas import EvaluationDataset
22
+ from ragas import SingleTurnSample
23
+ from ragas.dataset_schema import EvaluationResult
24
+ from ragas.llms import LangchainLLMWrapper
25
+ from ragas.metrics import Metric
26
+ from tqdm import tqdm
27
+
28
+ from nat.data_models.intermediate_step import IntermediateStepType
29
+ from nat.eval.evaluator.evaluator_model import EvalInput
30
+ from nat.eval.evaluator.evaluator_model import EvalInputItem
31
+ from nat.eval.evaluator.evaluator_model import EvalOutput
32
+ from nat.eval.evaluator.evaluator_model import EvalOutputItem
33
+ from nat.eval.utils.tqdm_position_registry import TqdmPositionRegistry
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
class RAGEvaluator:
    """Evaluate NAT workflow outputs with Ragas metrics.

    Converts an ``EvalInput`` into a Ragas ``EvaluationDataset`` and runs
    ``ragas.evaluate`` with the configured metrics and judge LLM. The result is
    mapped back to a NAT ``EvalOutput``. Per-item and average scores are taken
    from the first metric in the Ragas scores.
    """

    def __init__(self,
                 evaluator_llm: LangchainLLMWrapper,
                 metrics: Sequence[Metric],
                 max_concurrency=8,
                 input_obj_field: str | None = None):
        """Store the evaluator configuration.

        Args:
            evaluator_llm: LLM wrapper passed to ``ragas.evaluate`` as the judge.
            metrics: Ragas metrics to run. ``evaluate`` needs at least one.
            max_concurrency: Passed to Ragas as ``RunConfig(max_workers=...)``.
            input_obj_field: Optional attribute/key of ``item.input_obj`` to use as the
                Ragas ``user_input``. When unset or absent, the whole input object is
                stringified.
        """
        self.evaluator_llm = evaluator_llm
        self.metrics = metrics
        self.max_concurrency = max_concurrency
        self.input_obj_field = input_obj_field

    def extract_input_obj(self, item: EvalInputItem) -> str:
        """Return ``item.input_obj`` as the string used for the Ragas ``user_input``.

        - Pydantic model: the configured ``input_obj_field`` attribute (stringified) when
          present, otherwise ``model_dump_json()``.
        - dict: the configured ``input_obj_field`` value (stringified) when present,
          otherwise ``str(dict)``. Note this is a Python repr, not JSON.
        - anything else: ``str(input_obj)``.
        """
        input_obj = item.input_obj
        field = self.input_obj_field

        if isinstance(input_obj, BaseModel):
            if field and hasattr(input_obj, field):
                return str(getattr(input_obj, field, ""))
            return input_obj.model_dump_json()

        if isinstance(input_obj, dict) and field and field in input_obj:
            return str(input_obj[field])

        # Fallback for dicts without the field and for all other types.
        return str(input_obj)

    def eval_input_to_ragas(self, eval_input: EvalInput) -> EvaluationDataset:
        """Build a Ragas ``EvaluationDataset`` with one ``SingleTurnSample`` per eval item.

        ``retrieved_contexts`` is built from the outputs of TOOL_END, LLM_END and
        CUSTOM_END steps in each item's trajectory. ``reference_contexts`` is a single
        empty string because reference-context extraction is not implemented.
        """
        # Function-local import kept as in the original; presumably avoids an import cycle (TODO confirm).
        from nat.eval.intermediate_step_adapter import IntermediateStepAdapter

        event_filter = [IntermediateStepType.TOOL_END, IntermediateStepType.LLM_END, IntermediateStepType.CUSTOM_END]
        intermediate_step_adapter = IntermediateStepAdapter()

        samples = []
        for item in eval_input.eval_input_items:
            samples.append(
                SingleTurnSample(
                    user_input=self.extract_input_obj(item),
                    reference=item.expected_output_obj,  # reference (correct) answer
                    response=item.output_obj,  # the workflow's generated answer
                    # TODO: extract reference contexts from the expected trajectory.
                    reference_contexts=[""],
                    retrieved_contexts=intermediate_step_adapter.get_context(item.trajectory, event_filter),
                ))
        return EvaluationDataset(samples=samples)

    def ragas_to_eval_output(self, eval_input: EvalInput, results_dataset: EvaluationResult | None) -> EvalOutput:
        """Convert a Ragas ``EvaluationResult`` into a NAT ``EvalOutput``.

        The first metric in the score rows drives both each item's ``score`` and the
        ``average_score``. NaN/None scores count as 0.0.

        Item ids come from ``eval_input`` when the result has exactly one row per input
        item. Otherwise each row's ``user_input`` is used as its id.

        Returns an ``EvalOutput`` with ``average_score=0.0`` and no items when there is
        no result or no score rows.
        """
        if not results_dataset:
            logger.error("Ragas evaluation failed with no results")
            return EvalOutput(average_score=0.0, eval_output_items=[])

        scores: list[dict[str, float]] = results_dataset.scores

        # If Ragas returned no scores, return empty output to avoid downstream errors.
        if not scores:
            logger.warning("Ragas returned empty score list")
            return EvalOutput(average_score=0.0, eval_output_items=[])

        def _nan_to_zero(v: float | None) -> float:
            """Convert NaN or None to 0.0 for safe arithmetic/serialization."""
            return 0.0 if v is None or (isinstance(v, float) and math.isnan(v)) else v

        # Metric names are taken from the first score row. A row with no metrics
        # leaves first_metric_name as None instead of raising IndexError.
        first_metric_name = next(iter(scores[0]), None)

        first_avg_score = 0.0
        if first_metric_name is not None:
            first_values = [_nan_to_zero(row.get(first_metric_name)) for row in scores]
            first_avg_score = _nan_to_zero(sum(first_values) / len(first_values))

        df = results_dataset.to_pandas()
        if len(eval_input.eval_input_items) == len(df):
            ids = [item.id for item in eval_input.eval_input_items]
        else:
            # Rows do not line up one-to-one with inputs, so positional ids would be wrong.
            ids = df["user_input"].tolist()

        eval_output_items = [
            EvalOutputItem(
                id=item_id,
                score=_nan_to_zero(getattr(row, first_metric_name, 0.0) if first_metric_name else 0.0),
                # getattr with a default tolerates columns missing from the Ragas frame.
                reasoning={
                    key: getattr(row, key, None)
                    for key in ["user_input", "reference", "response", "retrieved_contexts"]
                },
            ) for item_id, row in zip(ids, df.itertuples(index=False))
        ]
        return EvalOutput(average_score=first_avg_score, eval_output_items=eval_output_items)

    async def evaluate(self, eval_input: EvalInput) -> EvalOutput:
        """Run the configured Ragas metrics on ``eval_input``.

        Exceptions from ``ragas.evaluate`` are logged and produce an ``EvalOutput`` with
        ``average_score=0.0``, so other evaluators can keep running.

        Raises:
            IndexError: If ``self.metrics`` is empty. The first metric name is needed for
                the progress-bar label.
        """
        from ragas import evaluate as ragas_evaluate
        from ragas.run_config import RunConfig

        ragas_dataset = self.eval_input_to_ragas(eval_input)
        # Resolve the label before claiming a tqdm slot. An empty metric list used to
        # raise here after claim() and leak the slot.
        first_metric_name = self.metrics[0].name

        tqdm_position = TqdmPositionRegistry.claim()
        pbar = tqdm(total=len(ragas_dataset), desc=f"Evaluating Ragas {first_metric_name}", position=tqdm_position)
        try:
            # NOTE(review): ragas_evaluate is called synchronously and blocks the event loop while it runs.
            results_dataset = ragas_evaluate(dataset=ragas_dataset,
                                             metrics=self.metrics,
                                             show_progress=True,
                                             llm=self.evaluator_llm,
                                             run_config=RunConfig(max_workers=self.max_concurrency),
                                             _pbar=pbar)
        except Exception as e:
            # On exception we still continue with other evaluators. Log and return an avg_score of 0.0.
            logger.exception("Error evaluating ragas metric, Error: %s", e)
            results_dataset = None
        finally:
            pbar.close()
            TqdmPositionRegistry.release(tqdm_position)

        return self.ragas_to_eval_output(eval_input, results_dataset)
@@ -0,0 +1,143 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+ from pydantic import model_validator
21
+
22
+ from nat.builder.builder import EvalBuilder
23
+ from nat.builder.evaluator import EvaluatorInfo
24
+ from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.cli.register_workflow import register_evaluator
26
+ from nat.data_models.evaluator import EvaluatorBaseConfig
27
+ from nat.eval.evaluator.evaluator_model import EvalInput
28
+ from nat.eval.evaluator.evaluator_model import EvalOutput
29
+
30
# Module-level logger; used by the metric lookup and evaluate_fn inside register_ragas_evaluator.
logger = logging.getLogger(__name__)
31
+
32
+
33
class RagasMetricConfig(BaseModel):
    """RAGAS metric configuration.

    skip: Allows the metric config to be present but not used.
    kwargs: Additional keyword arguments passed to the metric's callable when it is instantiated.
    """
    # When True, register_ragas_evaluator does not instantiate this metric.
    skip: bool = False
    # kwargs specific to the metric's callable; None is treated as "no extra arguments".
    kwargs: dict | None = None
41
+
42
+
43
class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
    """Evaluation using RAGAS metrics."""

    # Name of the LLM (resolved through the builder) used as the judge.
    llm_name: str = Field(description="LLM as a judge.")
    # Ragas metric: either a bare metric name, or a single-item {name: RagasMetricConfig} mapping.
    metric: str | dict[str, RagasMetricConfig] = Field(default="AnswerAccuracy",
                                                       description="RAGAS metric callable with optional 'kwargs:'")
    input_obj_field: str | None = Field(
        default=None, description="The field in the input object that contains the content to evaluate.")

    @model_validator(mode="before")
    @classmethod
    def validate_metric(cls, values):
        """Ensures metric is either a string or a single-item dictionary.

        Args:
            values: Raw input handed to the model before field validation.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If ``metric`` is a dict with other than exactly one entry,
                if that entry's value is neither a mapping nor a ``RagasMetricConfig``,
                or if ``metric`` is neither a string nor a dict.
        """
        # A "before" validator sees raw input, which is not guaranteed to be a dict
        # (e.g. an existing model instance); leave anything else to pydantic.
        if not isinstance(values, dict):
            return values
        # An omitted metric must fall back to the field default instead of being
        # rejected as "not a string".
        if "metric" not in values:
            return values

        metric = values["metric"]
        if isinstance(metric, dict):
            if len(metric) != 1:
                raise ValueError("Only one metric is allowed in the configuration.")
            _, value = next(iter(metric.items()))
            # Accept raw mappings (YAML/JSON input) as well as already-built config objects.
            if not isinstance(value, (dict, RagasMetricConfig)):
                raise ValueError("Metric value must be a RagasMetricConfig object.")
        elif not isinstance(metric, str):
            raise ValueError("Metric must be either a string or a single-item dictionary.")

        return values

    @property
    def metric_name(self) -> str:
        """Returns the single metric name ("" if ``metric`` is an empty dict)."""
        if isinstance(self.metric, str):
            return self.metric
        if isinstance(self.metric, dict) and self.metric:
            return next(iter(self.metric.keys()))  # pylint: disable=no-member
        return ""

    @property
    def metric_config(self) -> RagasMetricConfig:
        """Returns the metric configuration (or a default if only a string is provided)."""
        if isinstance(self.metric, str):
            return RagasMetricConfig()  # Default config when only a metric name is given
        if isinstance(self.metric, dict) and self.metric:
            return next(iter(self.metric.values()))  # pylint: disable=no-member
        return RagasMetricConfig()  # Default config when an invalid type is provided
87
+
88
+
89
@register_evaluator(config_type=RagasEvaluatorConfig)
async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBuilder):
    """Build the RAGAS evaluator described by ``config`` and yield its ``EvaluatorInfo``.

    The judge LLM is fetched from ``builder`` as a LangChain wrapper.
    The configured metric is looked up by name in ``ragas.metrics`` and instantiated with its configured ``kwargs``.
    If the metric is skipped or cannot be found, no ``RAGEvaluator`` is created.
    The yielded ``evaluate_fn`` then logs a warning and returns an empty, zero-score ``EvalOutput``.
    """
    from ragas.metrics import Metric

    def get_ragas_metric(metric_name: str) -> Metric | None:
        """Resolve ``metric_name`` to an attribute of ``ragas.metrics``.

        Returns:
            The attribute, or None (after logging an error) when ``ragas.metrics`` has no such name.

        Raises:
            ValueError: If ``ragas.metrics`` cannot be imported.
        """
        try:
            import ragas.metrics as ragas_metrics

            return getattr(ragas_metrics, metric_name)
        except ImportError as err:
            error_text = f"Ragas metrics not found {err}."
            logger.error(error_text)
            raise ValueError(error_text) from err
        except AttributeError as err:
            error_text = f"Ragas metric {metric_name} not found {err}."
            logger.error(error_text)
            return None

    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
        """Run the RAGAS evaluation, or return an empty zero-score output when no metric was built."""
        if rag_evaluator:
            return await rag_evaluator.evaluate(eval_input)
        logger.warning("No evaluator found for RAGAS metrics.")
        return EvalOutput(average_score=0.0, eval_output_items=[])

    from .evaluate import RAGEvaluator

    # Judge LLM, wrapped for LangChain.
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    # Only one metric is supported; the config properties normalise the str / dict forms.
    requested_name = config.metric_name
    requested_config = config.metric_config

    metrics = []
    if not requested_config.skip:
        metric_factory = get_ragas_metric(requested_name)
        if metric_factory:
            metrics.append(metric_factory(**(requested_config.kwargs or {})))

    # Referenced lazily by evaluate_fn; stays None when there is nothing to evaluate.
    rag_evaluator = None
    if metrics:
        rag_evaluator = RAGEvaluator(evaluator_llm=llm,
                                     metrics=metrics,
                                     max_concurrency=builder.get_max_concurrency(),
                                     input_obj_field=config.input_obj_field)

    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Evaluator for RAGAS metrics")
nat/eval/register.py ADDED
@@ -0,0 +1,23 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # flake8: noqa
17
+ # pylint: disable=unused-import
18
+
19
+ # Import evaluators which need to be automatically registered here
20
+ from .rag_evaluator.register import register_ragas_evaluator
21
+ from .swe_bench_evaluator.register import register_swe_bench_evaluator
22
+ from .trajectory_evaluator.register import register_trajectory_evaluator
23
+ from .tunable_rag_evaluator.register import register_tunable_rag_evaluator
@@ -0,0 +1,133 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+
20
+ import aiohttp
21
+ from pydantic import ValidationError
22
+ from tqdm import tqdm
23
+
24
+ from nat.data_models.api_server import ResponseIntermediateStep
25
+ from nat.data_models.intermediate_step import IntermediateStep
26
+ from nat.data_models.intermediate_step import IntermediateStepPayload
27
+ from nat.data_models.invocation_node import InvocationNode
28
+ from nat.eval.config import EvaluationRunConfig
29
+ from nat.eval.evaluator.evaluator_model import EvalInput
30
+ from nat.eval.evaluator.evaluator_model import EvalInputItem
31
+
32
logger = logging.getLogger(__name__)

# Constants for streaming response prefixes.
# Stream lines from the remote `/generate/full` endpoint that start with DATA_PREFIX carry a
# JSON chunk whose "value" field holds workflow output. Lines that start with
# INTERMEDIATE_DATA_PREFIX carry a JSON-encoded ResponseIntermediateStep.
DATA_PREFIX = "data: "
INTERMEDIATE_DATA_PREFIX = "intermediate_data: "
37
+
38
+
39
class EvaluationRemoteWorkflowHandler:
    """Run evaluation inputs against a workflow served at a remote HTTP endpoint.

    Each ``EvalInputItem`` is POSTed to ``{config.endpoint}/generate/full``.
    The streamed response is parsed into the item's final output (``item.output_obj``) and a list of
    ``IntermediateStep`` objects (``item.trajectory``).
    At most ``max_concurrency`` requests are in flight at once.
    """

    def __init__(self, config: EvaluationRunConfig, max_concurrency: int):
        """
        Args:
            config: Evaluation run configuration; ``endpoint`` and ``endpoint_timeout`` are read from it.
            max_concurrency: Upper bound on concurrent requests to the remote endpoint.
        """
        self.config = config
        # Bounds the number of in-flight requests (see run_workflow_remote_with_limits).
        self.semaphore = asyncio.Semaphore(max_concurrency)

    async def run_workflow_remote_single(self, session: aiohttp.ClientSession, item: EvalInputItem):
        """
        Sends a single input to the endpoint hosting the workflow and retrieves the response.

        ``item`` is updated in place. ``output_obj`` receives the last non-empty ``value`` from a
        ``data:`` chunk, and ``trajectory`` receives the parsed ``intermediate_data:`` steps.
        Malformed chunks are logged and skipped.
        On connection/HTTP errors or a request timeout, the failure is logged and the item gets
        ``output_obj = None`` and ``trajectory = []``. This way one failing item does not abort
        the whole evaluation run.
        """
        question = item.input_obj
        # generate request format
        request_payload = {"input_message": question}

        try:
            # Use the streaming endpoint
            endpoint = f"{self.config.endpoint}/generate/full"
            async with session.post(endpoint, json=request_payload) as response:
                response.raise_for_status()  # Raise an exception for HTTP errors

                # Initialize variables to store the response
                final_response = None
                intermediate_steps = []

                # Process the streaming response line by line
                async for raw_line in response.content:
                    line = raw_line.decode('utf-8').strip()
                    if not line:
                        continue

                    if line.startswith(DATA_PREFIX):
                        # This is a generate response chunk
                        try:
                            chunk_data = json.loads(line[len(DATA_PREFIX):])
                            if chunk_data.get("value"):
                                final_response = chunk_data.get("value")
                        except json.JSONDecodeError as e:
                            logger.error("Failed to parse generate response chunk: %s", e)
                            continue
                    elif line.startswith(INTERMEDIATE_DATA_PREFIX):
                        # This is an intermediate step
                        try:
                            step_data = json.loads(line[len(INTERMEDIATE_DATA_PREFIX):])
                            response_intermediate = ResponseIntermediateStep.model_validate(step_data)
                            # The payload is expected to be IntermediateStepPayload
                            step_payload = IntermediateStepPayload.model_validate_json(response_intermediate.payload)
                            intermediate_step = IntermediateStep(parent_id="remote",
                                                                 function_ancestry=InvocationNode(
                                                                     function_name=step_payload.name
                                                                     or "remote_function",
                                                                     function_id=step_payload.UUID
                                                                     or "remote_function_id"),
                                                                 payload=step_payload)
                            intermediate_steps.append(intermediate_step)
                        except (json.JSONDecodeError, ValidationError) as e:
                            logger.error("Failed to parse intermediate step: %s", e)
                            continue

        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # Handle connection/HTTP-related errors and request timeouts. The total timeout from
            # ClientTimeout raises asyncio.TimeoutError, which is not an aiohttp.ClientError.
            logger.error("Request failed for question %s: %s", question, e)
            item.output_obj = None
            item.trajectory = []
            return

        # Extract and fill the item with the response and intermediate steps
        item.output_obj = final_response
        item.trajectory = intermediate_steps

    async def run_workflow_remote_with_limits(self, session: aiohttp.ClientSession, item: EvalInputItem, pbar: tqdm):
        """
        Sends limited number of concurrent requests to a remote workflow and retrieves responses.

        The shared semaphore caps concurrency, and ``pbar`` is advanced once the item finishes.
        """
        async with self.semaphore:
            await self.run_workflow_remote_single(session=session, item=item)
            pbar.update(1)

    async def run_workflow_remote(self, eval_input: EvalInput) -> EvalInput:
        """
        Sends inputs to a workflow hosted on a remote endpoint.

        All items are dispatched concurrently, subject to the semaphore, over one shared session
        whose per-request total timeout is ``config.endpoint_timeout``.
        Returns ``eval_input`` with every item filled in place.
        """
        timeout = aiohttp.ClientTimeout(total=self.config.endpoint_timeout)
        # Create the bar before the try so the finally clause never sees an unbound name.
        pbar = tqdm(total=len(eval_input.eval_input_items), desc="Running workflow", unit="item")
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                # get the questions from the eval_input
                tasks = [
                    self.run_workflow_remote_with_limits(session, item, pbar) for item in eval_input.eval_input_items
                ]
                await asyncio.gather(*tasks)
        finally:
            pbar.close()

        return eval_input
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,39 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+
18
+ from pydantic import BaseModel
19
+
20
+ from nat.eval.config import EvaluationRunConfig
21
+ from nat.eval.config import EvaluationRunOutput
22
+
23
+
24
class MultiEvaluationRunConfig(BaseModel):
    """
    Parameters used for a multi-evaluation run.
    This includes a dict of configs. The key is an id of any type.
    Each pass loads the config, applies the overrides and runs to completion
    before the next pass starts.
    """
    # One entry per pass: the key identifies the pass (any type); the value is that pass's EvaluationRunConfig.
    configs: dict[typing.Any, EvaluationRunConfig]
32
+
33
+
34
class MultiEvaluationRunOutput(BaseModel):
    """
    Output of a multi-evaluation run.
    The results per-pass are accumulated in the evaluation_run_outputs dict.
    """
    # Per-pass results. Presumably keyed by the same ids as MultiEvaluationRunConfig.configs;
    # TODO confirm against the multi-eval runner.
    evaluation_run_outputs: dict[typing.Any, EvaluationRunOutput]