nvidia-nat 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435) hide show
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +256 -0
  4. nat/agent/dual_node.py +67 -0
  5. nat/agent/react_agent/__init__.py +0 -0
  6. nat/agent/react_agent/agent.py +363 -0
  7. nat/agent/react_agent/output_parser.py +104 -0
  8. nat/agent/react_agent/prompt.py +44 -0
  9. nat/agent/react_agent/register.py +149 -0
  10. nat/agent/reasoning_agent/__init__.py +0 -0
  11. nat/agent/reasoning_agent/reasoning_agent.py +225 -0
  12. nat/agent/register.py +23 -0
  13. nat/agent/rewoo_agent/__init__.py +0 -0
  14. nat/agent/rewoo_agent/agent.py +415 -0
  15. nat/agent/rewoo_agent/prompt.py +110 -0
  16. nat/agent/rewoo_agent/register.py +157 -0
  17. nat/agent/tool_calling_agent/__init__.py +0 -0
  18. nat/agent/tool_calling_agent/agent.py +119 -0
  19. nat/agent/tool_calling_agent/register.py +106 -0
  20. nat/authentication/__init__.py +14 -0
  21. nat/authentication/api_key/__init__.py +14 -0
  22. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  23. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  24. nat/authentication/api_key/register.py +26 -0
  25. nat/authentication/exceptions/__init__.py +14 -0
  26. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  27. nat/authentication/http_basic_auth/__init__.py +0 -0
  28. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  29. nat/authentication/http_basic_auth/register.py +30 -0
  30. nat/authentication/interfaces.py +93 -0
  31. nat/authentication/oauth2/__init__.py +14 -0
  32. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  33. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  34. nat/authentication/oauth2/register.py +25 -0
  35. nat/authentication/register.py +21 -0
  36. nat/builder/__init__.py +0 -0
  37. nat/builder/builder.py +285 -0
  38. nat/builder/component_utils.py +316 -0
  39. nat/builder/context.py +270 -0
  40. nat/builder/embedder.py +24 -0
  41. nat/builder/eval_builder.py +161 -0
  42. nat/builder/evaluator.py +29 -0
  43. nat/builder/framework_enum.py +24 -0
  44. nat/builder/front_end.py +73 -0
  45. nat/builder/function.py +344 -0
  46. nat/builder/function_base.py +380 -0
  47. nat/builder/function_info.py +627 -0
  48. nat/builder/intermediate_step_manager.py +174 -0
  49. nat/builder/llm.py +25 -0
  50. nat/builder/retriever.py +25 -0
  51. nat/builder/user_interaction_manager.py +78 -0
  52. nat/builder/workflow.py +148 -0
  53. nat/builder/workflow_builder.py +1117 -0
  54. nat/cli/__init__.py +14 -0
  55. nat/cli/cli_utils/__init__.py +0 -0
  56. nat/cli/cli_utils/config_override.py +231 -0
  57. nat/cli/cli_utils/validation.py +37 -0
  58. nat/cli/commands/__init__.py +0 -0
  59. nat/cli/commands/configure/__init__.py +0 -0
  60. nat/cli/commands/configure/channel/__init__.py +0 -0
  61. nat/cli/commands/configure/channel/add.py +28 -0
  62. nat/cli/commands/configure/channel/channel.py +34 -0
  63. nat/cli/commands/configure/channel/remove.py +30 -0
  64. nat/cli/commands/configure/channel/update.py +30 -0
  65. nat/cli/commands/configure/configure.py +33 -0
  66. nat/cli/commands/evaluate.py +139 -0
  67. nat/cli/commands/info/__init__.py +14 -0
  68. nat/cli/commands/info/info.py +37 -0
  69. nat/cli/commands/info/list_channels.py +32 -0
  70. nat/cli/commands/info/list_components.py +129 -0
  71. nat/cli/commands/info/list_mcp.py +304 -0
  72. nat/cli/commands/registry/__init__.py +14 -0
  73. nat/cli/commands/registry/publish.py +88 -0
  74. nat/cli/commands/registry/pull.py +118 -0
  75. nat/cli/commands/registry/registry.py +36 -0
  76. nat/cli/commands/registry/remove.py +108 -0
  77. nat/cli/commands/registry/search.py +155 -0
  78. nat/cli/commands/sizing/__init__.py +14 -0
  79. nat/cli/commands/sizing/calc.py +297 -0
  80. nat/cli/commands/sizing/sizing.py +27 -0
  81. nat/cli/commands/start.py +246 -0
  82. nat/cli/commands/uninstall.py +81 -0
  83. nat/cli/commands/validate.py +47 -0
  84. nat/cli/commands/workflow/__init__.py +14 -0
  85. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  86. nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
  87. nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  88. nat/cli/commands/workflow/templates/register.py.j2 +5 -0
  89. nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  90. nat/cli/commands/workflow/workflow.py +37 -0
  91. nat/cli/commands/workflow/workflow_commands.py +317 -0
  92. nat/cli/entrypoint.py +135 -0
  93. nat/cli/main.py +57 -0
  94. nat/cli/register_workflow.py +488 -0
  95. nat/cli/type_registry.py +1000 -0
  96. nat/data_models/__init__.py +14 -0
  97. nat/data_models/api_server.py +716 -0
  98. nat/data_models/authentication.py +231 -0
  99. nat/data_models/common.py +171 -0
  100. nat/data_models/component.py +58 -0
  101. nat/data_models/component_ref.py +168 -0
  102. nat/data_models/config.py +410 -0
  103. nat/data_models/dataset_handler.py +169 -0
  104. nat/data_models/discovery_metadata.py +305 -0
  105. nat/data_models/embedder.py +27 -0
  106. nat/data_models/evaluate.py +127 -0
  107. nat/data_models/evaluator.py +26 -0
  108. nat/data_models/front_end.py +26 -0
  109. nat/data_models/function.py +30 -0
  110. nat/data_models/function_dependencies.py +72 -0
  111. nat/data_models/interactive.py +246 -0
  112. nat/data_models/intermediate_step.py +302 -0
  113. nat/data_models/invocation_node.py +38 -0
  114. nat/data_models/llm.py +27 -0
  115. nat/data_models/logging.py +26 -0
  116. nat/data_models/memory.py +27 -0
  117. nat/data_models/object_store.py +44 -0
  118. nat/data_models/profiler.py +54 -0
  119. nat/data_models/registry_handler.py +26 -0
  120. nat/data_models/retriever.py +30 -0
  121. nat/data_models/retry_mixin.py +35 -0
  122. nat/data_models/span.py +190 -0
  123. nat/data_models/step_adaptor.py +64 -0
  124. nat/data_models/streaming.py +33 -0
  125. nat/data_models/swe_bench_model.py +54 -0
  126. nat/data_models/telemetry_exporter.py +26 -0
  127. nat/data_models/ttc_strategy.py +30 -0
  128. nat/embedder/__init__.py +0 -0
  129. nat/embedder/nim_embedder.py +59 -0
  130. nat/embedder/openai_embedder.py +43 -0
  131. nat/embedder/register.py +22 -0
  132. nat/eval/__init__.py +14 -0
  133. nat/eval/config.py +60 -0
  134. nat/eval/dataset_handler/__init__.py +0 -0
  135. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  136. nat/eval/dataset_handler/dataset_filter.py +52 -0
  137. nat/eval/dataset_handler/dataset_handler.py +367 -0
  138. nat/eval/evaluate.py +510 -0
  139. nat/eval/evaluator/__init__.py +14 -0
  140. nat/eval/evaluator/base_evaluator.py +77 -0
  141. nat/eval/evaluator/evaluator_model.py +45 -0
  142. nat/eval/intermediate_step_adapter.py +99 -0
  143. nat/eval/rag_evaluator/__init__.py +0 -0
  144. nat/eval/rag_evaluator/evaluate.py +178 -0
  145. nat/eval/rag_evaluator/register.py +143 -0
  146. nat/eval/register.py +23 -0
  147. nat/eval/remote_workflow.py +133 -0
  148. nat/eval/runners/__init__.py +14 -0
  149. nat/eval/runners/config.py +39 -0
  150. nat/eval/runners/multi_eval_runner.py +54 -0
  151. nat/eval/runtime_event_subscriber.py +52 -0
  152. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  153. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. nat/eval/swe_bench_evaluator/register.py +36 -0
  155. nat/eval/trajectory_evaluator/__init__.py +0 -0
  156. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  157. nat/eval/trajectory_evaluator/register.py +40 -0
  158. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. nat/eval/tunable_rag_evaluator/register.py +52 -0
  161. nat/eval/usage_stats.py +41 -0
  162. nat/eval/utils/__init__.py +0 -0
  163. nat/eval/utils/output_uploader.py +140 -0
  164. nat/eval/utils/tqdm_position_registry.py +40 -0
  165. nat/eval/utils/weave_eval.py +184 -0
  166. nat/experimental/__init__.py +0 -0
  167. nat/experimental/decorators/__init__.py +0 -0
  168. nat/experimental/decorators/experimental_warning_decorator.py +134 -0
  169. nat/experimental/test_time_compute/__init__.py +0 -0
  170. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  171. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  175. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  177. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  178. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  179. nat/experimental/test_time_compute/models/__init__.py +0 -0
  180. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  181. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. nat/experimental/test_time_compute/models/search_config.py +120 -0
  183. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  184. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. nat/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. nat/experimental/test_time_compute/register.py +36 -0
  189. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. nat/experimental/test_time_compute/search/__init__.py +0 -0
  194. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  198. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. nat/front_ends/__init__.py +14 -0
  204. nat/front_ends/console/__init__.py +14 -0
  205. nat/front_ends/console/authentication_flow_handler.py +233 -0
  206. nat/front_ends/console/console_front_end_config.py +32 -0
  207. nat/front_ends/console/console_front_end_plugin.py +96 -0
  208. nat/front_ends/console/register.py +25 -0
  209. nat/front_ends/cron/__init__.py +14 -0
  210. nat/front_ends/fastapi/__init__.py +14 -0
  211. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
  215. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
  218. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. nat/front_ends/fastapi/job_store.py +183 -0
  222. nat/front_ends/fastapi/main.py +72 -0
  223. nat/front_ends/fastapi/message_handler.py +320 -0
  224. nat/front_ends/fastapi/message_validator.py +352 -0
  225. nat/front_ends/fastapi/register.py +25 -0
  226. nat/front_ends/fastapi/response_helpers.py +195 -0
  227. nat/front_ends/fastapi/step_adaptor.py +319 -0
  228. nat/front_ends/mcp/__init__.py +14 -0
  229. nat/front_ends/mcp/mcp_front_end_config.py +36 -0
  230. nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
  231. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
  232. nat/front_ends/mcp/register.py +27 -0
  233. nat/front_ends/mcp/tool_converter.py +241 -0
  234. nat/front_ends/register.py +22 -0
  235. nat/front_ends/simple_base/__init__.py +14 -0
  236. nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  237. nat/llm/__init__.py +0 -0
  238. nat/llm/aws_bedrock_llm.py +57 -0
  239. nat/llm/nim_llm.py +46 -0
  240. nat/llm/openai_llm.py +46 -0
  241. nat/llm/register.py +23 -0
  242. nat/llm/utils/__init__.py +14 -0
  243. nat/llm/utils/env_config_value.py +94 -0
  244. nat/llm/utils/error.py +17 -0
  245. nat/memory/__init__.py +20 -0
  246. nat/memory/interfaces.py +183 -0
  247. nat/memory/models.py +112 -0
  248. nat/meta/pypi.md +58 -0
  249. nat/object_store/__init__.py +20 -0
  250. nat/object_store/in_memory_object_store.py +76 -0
  251. nat/object_store/interfaces.py +84 -0
  252. nat/object_store/models.py +38 -0
  253. nat/object_store/register.py +20 -0
  254. nat/observability/__init__.py +14 -0
  255. nat/observability/exporter/__init__.py +14 -0
  256. nat/observability/exporter/base_exporter.py +449 -0
  257. nat/observability/exporter/exporter.py +78 -0
  258. nat/observability/exporter/file_exporter.py +33 -0
  259. nat/observability/exporter/processing_exporter.py +322 -0
  260. nat/observability/exporter/raw_exporter.py +52 -0
  261. nat/observability/exporter/span_exporter.py +288 -0
  262. nat/observability/exporter_manager.py +335 -0
  263. nat/observability/mixin/__init__.py +14 -0
  264. nat/observability/mixin/batch_config_mixin.py +26 -0
  265. nat/observability/mixin/collector_config_mixin.py +23 -0
  266. nat/observability/mixin/file_mixin.py +288 -0
  267. nat/observability/mixin/file_mode.py +23 -0
  268. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  269. nat/observability/mixin/serialize_mixin.py +61 -0
  270. nat/observability/mixin/type_introspection_mixin.py +183 -0
  271. nat/observability/processor/__init__.py +14 -0
  272. nat/observability/processor/batching_processor.py +310 -0
  273. nat/observability/processor/callback_processor.py +42 -0
  274. nat/observability/processor/intermediate_step_serializer.py +28 -0
  275. nat/observability/processor/processor.py +71 -0
  276. nat/observability/register.py +96 -0
  277. nat/observability/utils/__init__.py +14 -0
  278. nat/observability/utils/dict_utils.py +236 -0
  279. nat/observability/utils/time_utils.py +31 -0
  280. nat/plugins/.namespace +1 -0
  281. nat/profiler/__init__.py +0 -0
  282. nat/profiler/calc/__init__.py +14 -0
  283. nat/profiler/calc/calc_runner.py +627 -0
  284. nat/profiler/calc/calculations.py +288 -0
  285. nat/profiler/calc/data_models.py +188 -0
  286. nat/profiler/calc/plot.py +345 -0
  287. nat/profiler/callbacks/__init__.py +0 -0
  288. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  289. nat/profiler/callbacks/base_callback_class.py +20 -0
  290. nat/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  294. nat/profiler/data_frame_row.py +51 -0
  295. nat/profiler/data_models.py +24 -0
  296. nat/profiler/decorators/__init__.py +0 -0
  297. nat/profiler/decorators/framework_wrapper.py +131 -0
  298. nat/profiler/decorators/function_tracking.py +254 -0
  299. nat/profiler/forecasting/__init__.py +0 -0
  300. nat/profiler/forecasting/config.py +18 -0
  301. nat/profiler/forecasting/model_trainer.py +75 -0
  302. nat/profiler/forecasting/models/__init__.py +22 -0
  303. nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. nat/profiler/forecasting/models/linear_model.py +197 -0
  305. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  306. nat/profiler/inference_metrics_model.py +28 -0
  307. nat/profiler/inference_optimization/__init__.py +0 -0
  308. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. nat/profiler/inference_optimization/data_models.py +386 -0
  312. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  316. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  317. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. nat/profiler/intermediate_property_adapter.py +102 -0
  320. nat/profiler/profile_runner.py +473 -0
  321. nat/profiler/utils.py +184 -0
  322. nat/registry_handlers/__init__.py +0 -0
  323. nat/registry_handlers/local/__init__.py +0 -0
  324. nat/registry_handlers/local/local_handler.py +176 -0
  325. nat/registry_handlers/local/register_local.py +37 -0
  326. nat/registry_handlers/metadata_factory.py +60 -0
  327. nat/registry_handlers/package_utils.py +571 -0
  328. nat/registry_handlers/pypi/__init__.py +0 -0
  329. nat/registry_handlers/pypi/pypi_handler.py +251 -0
  330. nat/registry_handlers/pypi/register_pypi.py +40 -0
  331. nat/registry_handlers/register.py +21 -0
  332. nat/registry_handlers/registry_handler_base.py +157 -0
  333. nat/registry_handlers/rest/__init__.py +0 -0
  334. nat/registry_handlers/rest/register_rest.py +56 -0
  335. nat/registry_handlers/rest/rest_handler.py +237 -0
  336. nat/registry_handlers/schemas/__init__.py +0 -0
  337. nat/registry_handlers/schemas/headers.py +42 -0
  338. nat/registry_handlers/schemas/package.py +68 -0
  339. nat/registry_handlers/schemas/publish.py +68 -0
  340. nat/registry_handlers/schemas/pull.py +82 -0
  341. nat/registry_handlers/schemas/remove.py +36 -0
  342. nat/registry_handlers/schemas/search.py +91 -0
  343. nat/registry_handlers/schemas/status.py +47 -0
  344. nat/retriever/__init__.py +0 -0
  345. nat/retriever/interface.py +41 -0
  346. nat/retriever/milvus/__init__.py +14 -0
  347. nat/retriever/milvus/register.py +81 -0
  348. nat/retriever/milvus/retriever.py +228 -0
  349. nat/retriever/models.py +77 -0
  350. nat/retriever/nemo_retriever/__init__.py +14 -0
  351. nat/retriever/nemo_retriever/register.py +60 -0
  352. nat/retriever/nemo_retriever/retriever.py +190 -0
  353. nat/retriever/register.py +22 -0
  354. nat/runtime/__init__.py +14 -0
  355. nat/runtime/loader.py +220 -0
  356. nat/runtime/runner.py +195 -0
  357. nat/runtime/session.py +162 -0
  358. nat/runtime/user_metadata.py +130 -0
  359. nat/settings/__init__.py +0 -0
  360. nat/settings/global_settings.py +318 -0
  361. nat/test/.namespace +1 -0
  362. nat/tool/__init__.py +0 -0
  363. nat/tool/chat_completion.py +74 -0
  364. nat/tool/code_execution/README.md +151 -0
  365. nat/tool/code_execution/__init__.py +0 -0
  366. nat/tool/code_execution/code_sandbox.py +267 -0
  367. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. nat/tool/code_execution/register.py +74 -0
  374. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. nat/tool/code_execution/utils.py +100 -0
  376. nat/tool/datetime_tools.py +42 -0
  377. nat/tool/document_search.py +141 -0
  378. nat/tool/github_tools/__init__.py +0 -0
  379. nat/tool/github_tools/create_github_commit.py +133 -0
  380. nat/tool/github_tools/create_github_issue.py +87 -0
  381. nat/tool/github_tools/create_github_pr.py +106 -0
  382. nat/tool/github_tools/get_github_file.py +106 -0
  383. nat/tool/github_tools/get_github_issue.py +166 -0
  384. nat/tool/github_tools/get_github_pr.py +256 -0
  385. nat/tool/github_tools/update_github_issue.py +100 -0
  386. nat/tool/mcp/__init__.py +14 -0
  387. nat/tool/mcp/exceptions.py +142 -0
  388. nat/tool/mcp/mcp_client.py +255 -0
  389. nat/tool/mcp/mcp_tool.py +96 -0
  390. nat/tool/memory_tools/__init__.py +0 -0
  391. nat/tool/memory_tools/add_memory_tool.py +79 -0
  392. nat/tool/memory_tools/delete_memory_tool.py +67 -0
  393. nat/tool/memory_tools/get_memory_tool.py +72 -0
  394. nat/tool/nvidia_rag.py +95 -0
  395. nat/tool/register.py +38 -0
  396. nat/tool/retriever.py +94 -0
  397. nat/tool/server_tools.py +66 -0
  398. nat/utils/__init__.py +0 -0
  399. nat/utils/data_models/__init__.py +0 -0
  400. nat/utils/data_models/schema_validator.py +58 -0
  401. nat/utils/debugging_utils.py +43 -0
  402. nat/utils/dump_distro_mapping.py +32 -0
  403. nat/utils/exception_handlers/__init__.py +0 -0
  404. nat/utils/exception_handlers/automatic_retries.py +289 -0
  405. nat/utils/exception_handlers/mcp.py +211 -0
  406. nat/utils/exception_handlers/schemas.py +114 -0
  407. nat/utils/io/__init__.py +0 -0
  408. nat/utils/io/model_processing.py +28 -0
  409. nat/utils/io/yaml_tools.py +119 -0
  410. nat/utils/log_utils.py +37 -0
  411. nat/utils/metadata_utils.py +74 -0
  412. nat/utils/optional_imports.py +142 -0
  413. nat/utils/producer_consumer_queue.py +178 -0
  414. nat/utils/reactive/__init__.py +0 -0
  415. nat/utils/reactive/base/__init__.py +0 -0
  416. nat/utils/reactive/base/observable_base.py +65 -0
  417. nat/utils/reactive/base/observer_base.py +55 -0
  418. nat/utils/reactive/base/subject_base.py +79 -0
  419. nat/utils/reactive/observable.py +59 -0
  420. nat/utils/reactive/observer.py +76 -0
  421. nat/utils/reactive/subject.py +131 -0
  422. nat/utils/reactive/subscription.py +49 -0
  423. nat/utils/settings/__init__.py +0 -0
  424. nat/utils/settings/global_settings.py +197 -0
  425. nat/utils/string_utils.py +38 -0
  426. nat/utils/type_converter.py +290 -0
  427. nat/utils/type_utils.py +484 -0
  428. nat/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0.dist-info/METADATA +365 -0
  430. nvidia_nat-1.2.0.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
  433. nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  434. nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,190 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import logging
18
+ import os
19
+ import typing
20
+ from functools import partial
21
+ from urllib.parse import urljoin
22
+
23
+ import httpx
24
+ from langchain_core.retrievers import BaseRetriever
25
+ from pydantic import BaseModel
26
+ from pydantic import Field
27
+ from pydantic import HttpUrl
28
+
29
+ from nat.retriever.interface import Retriever
30
+ from nat.retriever.models import Document
31
+ from nat.retriever.models import RetrieverError
32
+ from nat.retriever.models import RetrieverOutput
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
class Collection(BaseModel):
    """One entry of the ``collections`` list returned by the service's collections endpoint."""

    id: str
    name: str
    meta: typing.Any
    pipeline: str
    created_at: str


class RetrieverPayload(BaseModel):
    """JSON body posted to a collection's search endpoint."""

    query: str
    # Validated to be strictly positive and at most 50.
    top_k: int = Field(gt=0, le=50)


class CollectionUnavailableError(RetrieverError):
    """Raised when no (matching) collection exists or when a search against a collection fails."""
52
+
53
+
54
class NemoRetriever(Retriever):
    """
    Client for retrieving document chunks from a Nemo Retriever service.

    Every search opens a short-lived ``httpx.AsyncClient``, resolves the requested collection by name through
    ``/v1/collections`` and then posts the query to ``/v1/collections/{id}/search``.
    """

    def __init__(self, uri: str | HttpUrl, timeout: int = 60, nvidia_api_key: str | None = None, **kwargs):
        """
        Args:
            uri: Base URL of the Nemo Retriever service.
            timeout: Timeout passed to ``httpx.AsyncClient``.
            nvidia_api_key: Bearer token for the service. Falls back to the ``NVIDIA_API_KEY`` environment variable.
            kwargs: Accepted for interface compatibility; not used.
        """
        self.base_url = str(uri)
        self.timeout = timeout
        self._search_func = self._search
        self.api_key = nvidia_api_key if nvidia_api_key else os.getenv('NVIDIA_API_KEY')
        self._bound_params: list[str] = []
        if not self.api_key:
            logger.warning("No API key was specified as part of configuration or as an environment variable.")

    def _headers(self) -> dict[str, str]:
        """Build request headers, adding the bearer token only when an API key is configured."""
        # Without this guard the header would be the literal string "Bearer None".
        return {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}

    def bind(self, **kwargs) -> None:
        """
        Bind default values to the search method. Cannot bind the 'query' parameter.

        Repeated calls accumulate: parameters bound earlier stay bound.

        Args:
            kwargs (dict): Key value pairs corresponding to the default values of search parameters.
        """
        kwargs = {k: v for k, v in kwargs.items() if k != "query"}
        # Wrapping the previous partial keeps earlier bindings active, so the bookkeeping must keep them too.
        self._search_func = partial(self._search_func, **kwargs)
        self._bound_params = list(dict.fromkeys([*self._bound_params, *kwargs]))
        logger.debug("Binding parameters for search function: %s", kwargs)

    def get_unbound_params(self) -> list[str]:
        """
        Returns a list of unbound parameters which will need to be passed to the search function.
        """
        return [param for param in ["query", "collection_name", "top_k"] if param not in self._bound_params]

    async def get_collections(self, client) -> list[Collection]:
        """
        Get a list of all available collections as pydantic `Collection` objects.

        Raises:
            CollectionUnavailableError: If the service reports no collections.
        """
        collection_response = await client.get(urljoin(self.base_url, "/v1/collections"))
        collection_response.raise_for_status()
        # Parse once; treat a missing or null "collections" entry the same as an empty list.
        raw_collections = collection_response.json().get("collections") or []
        if not raw_collections:
            raise CollectionUnavailableError(f"No collections available at {self.base_url}")

        return [Collection.model_validate(collection) for collection in raw_collections]

    async def get_collection_by_name(self, collection_name: str, client) -> Collection:
        """
        Retrieve a collection using its name. Will return the first collection found if the name is ambiguous.

        Raises:
            CollectionUnavailableError: If no collection has the given name.
        """
        collections = await self.get_collections(client)
        if (collection := next((c for c in collections if c.name == collection_name), None)) is None:
            raise CollectionUnavailableError(f"Collection {collection_name} not found")
        return collection

    async def search(self, query: str, **kwargs):
        """Search with ``query`` plus any parameters not already fixed via :meth:`bind`."""
        return await self._search_func(query=query, **kwargs)

    async def _search(
        self,
        query: str,
        collection_name: str,
        top_k: int,
        output_fields: list[str] | None = None,
    ) -> RetrieverOutput:
        """
        Retrieve document chunks from the configured Nemo Retriever Service.

        Args:
            query: Text to search for.
            collection_name: Name of the collection to search.
            top_k: Number of chunks to request (validated by ``RetrieverPayload``).
            output_fields: Chunk fields to keep; see ``_flatten``. ``None`` keeps the defaults.

        Raises:
            CollectionUnavailableError: Wrapping any failure during lookup, request, or result conversion.
        """
        try:
            async with httpx.AsyncClient(headers=self._headers(), timeout=self.timeout) as client:
                collection = await self.get_collection_by_name(collection_name, client)
                url = urljoin(self.base_url, f"/v1/collections/{collection.id}/search")

                payload = RetrieverPayload(query=query, top_k=top_k)
                response = await client.post(url, content=json.dumps(payload.model_dump(mode="python")))

                logger.debug("response.status_code=%s", response.status_code)

                response.raise_for_status()
                # A missing or null "chunks" entry means "no results", not an error.
                chunks = response.json().get("chunks") or []

            # Handle output fields
            output = [_flatten(chunk, output_fields) for chunk in chunks]

            return _wrap_nemo_results(output=output, content_field="content")

        except Exception as e:
            logger.exception("Encountered an error when retrieving results from Nemo Retriever: %s", e)
            raise CollectionUnavailableError(
                f"Error when retrieving documents from {collection_name} for query '{query}'") from e
149
+
150
+
151
def _wrap_nemo_results(output: list[dict], content_field: str):
    """
    Convert a list of flattened chunk dicts into a ``RetrieverOutput``.

    Args:
        output: Flattened chunks, in the order they should appear in the result.
        content_field: Key in each chunk that holds the document text.
    """
    documents = []
    for chunk in output:
        documents.append(_wrap_nemo_single_results(chunk, content_field=content_field))
    return RetrieverOutput(results=documents)
153
+
154
+
155
def _wrap_nemo_single_results(output: dict, content_field: str):
    """
    Convert one flattened chunk dict into a ``Document``.

    The value under ``content_field`` becomes the page content; every other key goes into the metadata.
    """
    page_content = output[content_field]
    metadata = dict(output)
    del metadata[content_field]
    return Document(page_content=page_content, metadata=metadata)
161
+
162
+
163
+ def _flatten(obj: dict, output_fields: list[str]) -> list[str]:
164
+ base_fields = [
165
+ "format",
166
+ "id",
167
+ ]
168
+ if not output_fields:
169
+ output_fields = [
170
+ "format",
171
+ "id",
172
+ ]
173
+ output_fields.extend(list(obj["metadata"].keys()))
174
+ data = {"content": obj.get("content")}
175
+ for field in base_fields:
176
+ if field in output_fields:
177
+ data.update({field: obj[field]})
178
+
179
+ data.update({k: v for k, v in obj['metadata'].items() if k in output_fields})
180
+ return data
181
+
182
+
183
class NemoLangchainRetriever(BaseRetriever, BaseModel):
    """
    LangChain ``BaseRetriever`` adapter around a ``NemoRetriever``.

    Only asynchronous retrieval is supported, because ``NemoRetriever.search`` is a coroutine.
    """

    client: NemoRetriever

    def _get_relevant_documents(self, query, *, run_manager, **kwargs):
        # There is no synchronous search on NemoRetriever to delegate to.
        raise NotImplementedError(
            "NemoLangchainRetriever only supports asynchronous retrieval; use `ainvoke` instead.")

    async def _aget_relevant_documents(self, query, *, run_manager, **kwargs):
        # Extra kwargs are forwarded to NemoRetriever.search (e.g. collection_name, top_k, output_fields).
        return await self.client.search(query, **kwargs)
@@ -0,0 +1,22 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: disable=unused-import
17
+ # flake8: noqa
18
+ # isort:skip_file
19
+
20
+ # Import any providers which need to be automatically registered here
21
+ import nat.retriever.milvus.register
22
+ import nat.retriever.nemo_retriever.register
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
nat/runtime/loader.py ADDED
@@ -0,0 +1,220 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import importlib.metadata
19
+ import logging
20
+ import time
21
+ from contextlib import asynccontextmanager
22
+ from enum import IntFlag
23
+ from enum import auto
24
+ from functools import lru_cache
25
+ from functools import reduce
26
+
27
+ from nat.builder.workflow_builder import WorkflowBuilder
28
+ from nat.cli.type_registry import GlobalTypeRegistry
29
+ from nat.data_models.config import Config
30
+ from nat.runtime.session import SessionManager
31
+ from nat.utils.data_models.schema_validator import validate_schema
32
+ from nat.utils.debugging_utils import is_debugger_attached
33
+ from nat.utils.io.yaml_tools import yaml_load
34
+ from nat.utils.type_utils import StrPath
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
class PluginTypes(IntFlag):
    """
    Bit flags naming the kinds of plugins that can be discovered through entry points.

    Individual flags may be combined with ``|`` to select several kinds at once; ``CONFIG_OBJECT`` and ``ALL``
    are predefined combinations.
    """

    # A component of the workflow, such as tools, LLMs and retrievers.
    COMPONENT = auto()
    # A front end for the workflow, such as FastAPI or Gradio.
    FRONT_END = auto()
    # An evaluator for the workflow, such as RAGAS or SWE-bench evaluators.
    EVALUATOR = auto()
    # An API authentication provider for the workflow, such as OAuth2 or API key providers.
    AUTHENTICATION = auto()
    # A registry handler plugin.
    REGISTRY_HANDLER = auto()

    # Convenience groups of plugin types.
    # Any plugin that can be specified in the NAT configuration file (every kind except REGISTRY_HANDLER).
    CONFIG_OBJECT = COMPONENT | FRONT_END | EVALUATOR | AUTHENTICATION
    # All plugin types.
    ALL = CONFIG_OBJECT | REGISTRY_HANDLER
67
+
68
+
69
def load_config(config_file: StrPath) -> Config:
    """
    Load a NAT configuration file and return it validated against the ``Config`` schema.

    This is the primary entry point for loading a NAT configuration file. Every plugin type that may be
    specified in a configuration file is discovered and registered before the file is read and validated.

    Parameters
    ----------
    config_file : StrPath
        Path of the YAML configuration file to load.

    Returns
    -------
    Config
        The validated configuration object.
    """
    # Registration must happen before validation, so plugin-provided types are known by then.
    discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)

    return validate_schema(yaml_load(config_file), Config)
94
+
95
+
96
@asynccontextmanager
async def load_workflow(config_file: StrPath, max_concurrency: int = -1):
    """
    Load a NAT configuration file, build its workflow and yield a ``SessionManager`` for running it.

    This is the primary entry point for running NAT workflows. It is an async context manager; the built
    workflow stays alive only while the caller remains inside the ``async with`` block.

    Parameters
    ----------
    config_file : StrPath
        The path to the configuration file
    max_concurrency : int, optional
        The maximum number of parallel workflow invocations to support. Specifying 0 or -1 will allow an unlimited
        count, by default -1

    Yields
    ------
    SessionManager
        A session manager wrapping the built workflow.
    """
    config = load_config(config_file)

    # The builder cleans up when its context exits, so the yield must happen inside it.
    async with WorkflowBuilder.from_config(config=config) as builder:
        session_manager = SessionManager(builder.build(), max_concurrency=max_concurrency)
        yield session_manager
118
+
119
+
120
@lru_cache
def discover_entrypoints(plugin_type: PluginTypes):
    """
    Discover all the requested plugin types which were registered via an entry point group and return them.

    Results are memoized per ``plugin_type`` by ``functools.lru_cache``. The same list object is therefore
    returned on every call and must not be mutated by callers, and entry points installed after the first call
    are not seen.

    Parameters
    ----------
    plugin_type : PluginTypes
        One or more plugin type flags. A value with no flags set (e.g. ``PluginTypes(0)``) yields an empty list.

    Returns
    -------
    list[importlib.metadata.EntryPoint]
        The matching entry points, ordered by group: components, front ends, registry handlers, evaluators,
        then authentication providers.
    """

    entry_points = importlib.metadata.entry_points()

    plugin_groups: list[str] = []

    # Add the specified plugin type to the list of groups to load
    # The aiq entrypoints are intentionally left in the list to maintain backwards compatibility.
    if (plugin_type & PluginTypes.COMPONENT):
        plugin_groups.extend(["aiq.plugins", "aiq.components", "nat.plugins", "nat.components"])
    if (plugin_type & PluginTypes.FRONT_END):
        plugin_groups.extend(["aiq.front_ends", "nat.front_ends"])
    if (plugin_type & PluginTypes.REGISTRY_HANDLER):
        plugin_groups.extend(["aiq.registry_handlers", "nat.registry_handlers"])
    if (plugin_type & PluginTypes.EVALUATOR):
        plugin_groups.extend(["aiq.evaluators", "nat.evaluators"])
    if (plugin_type & PluginTypes.AUTHENTICATION):
        plugin_groups.extend(["aiq.authentication_providers", "nat.authentication_providers"])

    # Flatten the per-group results in group order. Unlike reducing with repeated list concatenation, this is
    # linear, and it returns [] for an empty group list instead of raising TypeError.
    return [entry_point for group in plugin_groups for entry_point in entry_points.select(group=group)]
147
+
148
+
149
@lru_cache
def get_all_entrypoints_distro_mapping() -> dict[str, str]:
    """
    Map every dotted module prefix of every NAT entry point to the name of the distribution providing it.

    For an entry point whose module is ``a.b.c``, the keys ``a``, ``a.b`` and ``a.b.c`` all map to that entry
    point's distribution name. When several entry points share a prefix, the one processed last wins. The
    result is memoized by ``functools.lru_cache``.
    """
    prefix_to_distro: dict[str, str] = {}

    for entry_point in discover_entrypoints(PluginTypes.ALL):
        segments = entry_point.module.split(".")
        distro_name = entry_point.dist.name
        for depth in range(1, len(segments) + 1):
            prefix_to_distro[".".join(segments[:depth])] = distro_name

    return prefix_to_distro
166
+
167
+
168
def discover_and_register_plugins(plugin_type: PluginTypes):
    """
    Discover all the requested plugin types which were registered via an entry point group and register them into
    the GlobalTypeRegistry by loading (importing) each entry point.

    A plugin that fails to load is logged and skipped, so one broken plugin does not stop the remaining plugins
    from being registered.

    Parameters
    ----------
    plugin_type : PluginTypes
        The plugin type flag(s) whose entry point groups should be loaded.
    """

    # Get the entry points for the specified groups
    nat_plugins = discover_entrypoints(plugin_type)

    # A load is reported as slow above 300 ms for the first plugin and 150 ms for every later one. Both thresholds
    # are tripled when a debugger is attached.
    slow_load_factor = 3 if is_debugger_attached() else 1

    # Pause registration hooks for performance. This is useful when loading a large number of plugins.
    with GlobalTypeRegistry.get().pause_registration_changed_hooks():

        for count, entry_point in enumerate(nat_plugins):
            try:
                logger.debug("Loading module '%s' from entry point '%s'...", entry_point.module, entry_point.name)

                # perf_counter is monotonic and high resolution, unlike time.time(), which makes it the right
                # clock for measuring durations.
                start_time = time.perf_counter()

                entry_point.load()

                elapsed_time = (time.perf_counter() - start_time) * 1000

                logger.debug("Loading module '%s' from entry point '%s'...Complete (%f ms)",
                             entry_point.module,
                             entry_point.name,
                             elapsed_time)

                # Slow loads are reported at debug level only. This is a hint for diagnosing slow imports.
                threshold_ms = (300.0 if count == 0 else 150.0) * slow_load_factor
                if (elapsed_time > threshold_ms):
                    logger.debug(
                        "Loading module '%s' from entry point '%s' took a long time (%f ms). "
                        "Ensure all imports are inside your registered functions.",
                        entry_point.module,
                        entry_point.name,
                        elapsed_time)

            except ImportError:
                logger.warning("Failed to import plugin '%s'", entry_point.name, exc_info=True)

            except Exception:
                # logger.exception already attaches the traceback of the active exception.
                logger.exception("An error occurred while loading plugin '%s'", entry_point.name)
217
+
218
+
219
# Backwards-compatible alias that keeps the older ``aiq``-prefixed name available from this module.
get_all_aiq_entrypoints_distro_mapping = get_all_entrypoints_distro_mapping
nat/runtime/runner.py ADDED
@@ -0,0 +1,195 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+ from enum import Enum
19
+
20
+ from nat.builder.context import Context
21
+ from nat.builder.context import ContextState
22
+ from nat.builder.function import Function
23
+ from nat.data_models.invocation_node import InvocationNode
24
+ from nat.observability.exporter_manager import ExporterManager
25
+ from nat.utils.reactive.subject import Subject
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class UserManagerBase:
    """Empty placeholder base class; it currently defines no attributes or methods of its own."""
32
+
33
+
34
class RunnerState(Enum):
    """Lifecycle states a Runner moves through while executing a workflow."""

    # Constructed, but the async context has not been entered yet.
    UNINITIALIZED = 0
    # Context entered; ready for a single call to `result` or `result_stream`.
    INITIALIZED = 1
    # The workflow is currently executing.
    RUNNING = 2
    # The workflow finished successfully.
    COMPLETED = 3
    # The workflow raised an error.
    FAILED = 4
40
+
41
+
42
+ _T = typing.TypeVar("_T")
43
+
44
+
45
class Runner:
    """
    Run a single workflow invocation and manage the per-run context state around it.

    A Runner is single-use: enter it once with ``async with``, drive exactly one of `result` or `result_stream`
    to completion (or failure), then exit the context.
    """

    def __init__(self,
                 input_message: typing.Any,
                 entry_fn: Function,
                 context_state: ContextState,
                 exporter_manager: ExporterManager):
        """
        The Runner class is used to run a workflow. It handles converting input and output data types and running the
        workflow with the specified concurrency.

        Parameters
        ----------
        input_message : typing.Any
            The input message to the workflow
        entry_fn : Function
            The entry function to the workflow
        context_state : ContextState
            The context state to use
        exporter_manager : ExporterManager
            The exporter manager to use

        Raises
        ------
        ValueError
            If ``entry_fn`` is None.
        """

        if (entry_fn is None):
            raise ValueError("entry_fn cannot be None")

        self._entry_fn = entry_fn
        self._context_state = context_state
        self._context = Context(self._context_state)

        self._state = RunnerState.UNINITIALIZED

        # Token from publishing the input message on the context; used to restore the previous value on exit.
        self._input_message_token = None

        # Stored as-is; no conversion happens here. It is passed unchanged to the entry function.
        self._input_message = input_message

        self._exporter_manager = exporter_manager

    @property
    def context(self) -> Context:
        """The Context wrapping this runner's context state."""
        return self._context

    def convert(self, value: typing.Any, to_type: type[_T]) -> _T:
        """Delegate conversion of ``value`` to ``to_type`` to the entry function."""
        return self._entry_fn.convert(value, to_type)

    def _close_event_stream(self) -> None:
        """Signal completion on the context's intermediate event stream, if one is set."""
        event_stream = self._context_state.event_stream.get()
        if event_stream:
            event_stream.on_complete()

    async def __aenter__(self):
        """
        Publish the input message, a fresh event stream and a root invocation node on the context state.

        Raises
        ------
        ValueError
            If the runner has already been entered.
        """

        # Validate before touching the context so a second entry cannot overwrite the first entry's values or the
        # saved input-message token.
        if (self._state != RunnerState.UNINITIALIZED):
            raise ValueError("Cannot enter the context more than once")

        # Set the input message on the context
        self._input_message_token = self._context_state.input_message.set(self._input_message)

        # Create reactive event stream
        self._context_state.event_stream.set(Subject())
        self._context_state.active_function.set(InvocationNode(
            function_name="root",
            function_id="root",
        ))

        self._state = RunnerState.INITIALIZED

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Restore the input message that was on the context before `__aenter__`.

        Raises
        ------
        ValueError
            If the runner was never entered, or if the body exited normally without the workflow having completed
            or failed. The latter check is skipped while another exception is propagating, so the original error
            (or cancellation) is not masked.
        """

        if (self._input_message_token is None):
            raise ValueError("Cannot exit the context without entering it")

        self._context_state.input_message.reset(self._input_message_token)

        if (exc_type is None and self._state not in (RunnerState.COMPLETED, RunnerState.FAILED)):
            raise ValueError("Cannot exit the context without completing the workflow")

    @typing.overload
    async def result(self) -> typing.Any:
        ...

    @typing.overload
    async def result(self, to_type: type[_T]) -> _T:
        ...

    async def result(self, to_type: type | None = None):
        """
        Run the workflow and return its single (non-streaming) output.

        Parameters
        ----------
        to_type : type | None, optional
            Type to request the output as; forwarded to the entry function's ``ainvoke``, by default None

        Returns
        -------
        typing.Any
            The workflow output.

        Raises
        ------
        ValueError
            If the runner has not been entered (or has already run), or if the entry function does not support
            single output.
        """

        if (self._state != RunnerState.INITIALIZED):
            raise ValueError("Cannot run the workflow without entering the context")

        try:
            self._state = RunnerState.RUNNING

            if (not self._entry_fn.has_single_output):
                raise ValueError("Workflow does not support single output")

            async with self._exporter_manager.start(context_state=self._context_state):
                # Run the workflow
                result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)

                # Close the intermediate stream
                self._close_event_stream()

            self._state = RunnerState.COMPLETED

            return result
        except Exception as e:
            logger.exception("Error running workflow: %s", e)
            self._close_event_stream()
            self._state = RunnerState.FAILED

            raise
        except BaseException:
            # Cancellation and similar non-error exits are not logged as workflow errors. The event stream is
            # still closed and the state finalized, so subscribers are not left waiting.
            self._close_event_stream()
            self._state = RunnerState.FAILED

            raise

    async def result_stream(self, to_type: type | None = None):
        """
        Run the workflow and yield its streaming output items as they are produced.

        Parameters
        ----------
        to_type : type | None, optional
            Type to request each item as; forwarded to the entry function's ``astream``, by default None

        Yields
        ------
        typing.Any
            Each item produced by the workflow.

        Raises
        ------
        ValueError
            If the runner has not been entered (or has already run), or if the entry function does not support
            streaming output.
        """

        if (self._state != RunnerState.INITIALIZED):
            raise ValueError("Cannot run the workflow without entering the context")

        try:
            self._state = RunnerState.RUNNING

            if (not self._entry_fn.has_streaming_output):
                raise ValueError("Workflow does not support streaming output")

            # Run the workflow
            async with self._exporter_manager.start(context_state=self._context_state):
                async for m in self._entry_fn.astream(self._input_message, to_type=to_type):
                    yield m

                self._state = RunnerState.COMPLETED

                # Close the intermediate stream
                self._close_event_stream()

        except Exception as e:
            logger.exception("Error running workflow: %s", e)
            self._close_event_stream()
            self._state = RunnerState.FAILED

            raise
        except BaseException:
            # Covers cancellation and GeneratorExit (consumer closed the stream early). These are not logged as
            # workflow errors, but the event stream is closed and the state finalized. No await happens here.
            self._close_event_stream()
            self._state = RunnerState.FAILED

            raise
191
+
192
+
193
# Backwards-compatible aliases that keep the ``AIQ``-prefixed names from previous releases importable.
AIQRunnerState = RunnerState
AIQRunner = Runner