nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480)
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1365 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import dataclasses
18
+ import inspect
19
+ import logging
20
+ import typing
21
+ import warnings
22
+ from collections.abc import Sequence
23
+ from contextlib import AbstractAsyncContextManager
24
+ from contextlib import AsyncExitStack
25
+ from contextlib import asynccontextmanager
26
+ from typing import cast
27
+
28
+ from nat.authentication.interfaces import AuthProviderBase
29
+ from nat.builder.builder import Builder
30
+ from nat.builder.builder import UserManagerHolder
31
+ from nat.builder.component_utils import ComponentInstanceData
32
+ from nat.builder.component_utils import build_dependency_sequence
33
+ from nat.builder.context import Context
34
+ from nat.builder.context import ContextState
35
+ from nat.builder.embedder import EmbedderProviderInfo
36
+ from nat.builder.framework_enum import LLMFrameworkEnum
37
+ from nat.builder.function import Function
38
+ from nat.builder.function import FunctionGroup
39
+ from nat.builder.function import LambdaFunction
40
+ from nat.builder.function_info import FunctionInfo
41
+ from nat.builder.llm import LLMProviderInfo
42
+ from nat.builder.retriever import RetrieverProviderInfo
43
+ from nat.builder.workflow import Workflow
44
+ from nat.cli.type_registry import GlobalTypeRegistry
45
+ from nat.cli.type_registry import TypeRegistry
46
+ from nat.data_models.authentication import AuthProviderBaseConfig
47
+ from nat.data_models.component import ComponentGroup
48
+ from nat.data_models.component_ref import AuthenticationRef
49
+ from nat.data_models.component_ref import EmbedderRef
50
+ from nat.data_models.component_ref import FunctionGroupRef
51
+ from nat.data_models.component_ref import FunctionRef
52
+ from nat.data_models.component_ref import LLMRef
53
+ from nat.data_models.component_ref import MemoryRef
54
+ from nat.data_models.component_ref import ObjectStoreRef
55
+ from nat.data_models.component_ref import RetrieverRef
56
+ from nat.data_models.component_ref import TTCStrategyRef
57
+ from nat.data_models.config import Config
58
+ from nat.data_models.config import GeneralConfig
59
+ from nat.data_models.embedder import EmbedderBaseConfig
60
+ from nat.data_models.function import FunctionBaseConfig
61
+ from nat.data_models.function import FunctionGroupBaseConfig
62
+ from nat.data_models.function_dependencies import FunctionDependencies
63
+ from nat.data_models.llm import LLMBaseConfig
64
+ from nat.data_models.memory import MemoryBaseConfig
65
+ from nat.data_models.object_store import ObjectStoreBaseConfig
66
+ from nat.data_models.retriever import RetrieverBaseConfig
67
+ from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
68
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
69
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
70
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
71
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
72
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
73
+ from nat.memory.interfaces import MemoryEditor
74
+ from nat.object_store.interfaces import ObjectStore
75
+ from nat.observability.exporter.base_exporter import BaseExporter
76
+ from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
77
+ from nat.profiler.utils import detect_llm_frameworks_in_build_fn
78
+ from nat.retriever.interface import Retriever
79
+ from nat.utils.type_utils import override
80
+
81
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(name=__name__)
82
+
83
+
84
@dataclasses.dataclass
class ConfiguredTelemetryExporter:
    """Record pairing a telemetry exporter config with a ``BaseExporter`` instance.

    ``WorkflowBuilder`` keeps these in its ``_telemetry_exporters`` dict (keyed by string).
    """

    # Configuration object for the exporter.
    config: TelemetryExporterBaseConfig
    # Exporter object associated with ``config``.
    instance: BaseExporter
88
+
89
+
90
@dataclasses.dataclass
class ConfiguredFunction:
    """Record pairing a function config with its ``Function`` instance.

    ``WorkflowBuilder`` uses this type both for its ``_functions`` dict and for its
    single ``_workflow`` slot.
    """

    # Configuration object for the function.
    config: FunctionBaseConfig
    # Function object associated with ``config``.
    instance: Function
94
+
95
+
96
@dataclasses.dataclass
class ConfiguredFunctionGroup:
    """Record pairing a function-group config with its ``FunctionGroup`` instance.

    ``WorkflowBuilder`` keeps these in its ``_function_groups`` dict (keyed by string).
    """

    # Configuration object for the function group.
    config: FunctionGroupBaseConfig
    # FunctionGroup object associated with ``config``.
    instance: FunctionGroup
100
+
101
+
102
@dataclasses.dataclass
class ConfiguredLLM:
    """Record pairing an LLM config with its ``LLMProviderInfo``.

    ``WorkflowBuilder`` keeps these in its ``_llms`` dict (keyed by string).
    """

    # Configuration object for the LLM.
    config: LLMBaseConfig
    # Provider-info object associated with ``config``.
    instance: LLMProviderInfo
106
+
107
+
108
@dataclasses.dataclass
class ConfiguredEmbedder:
    """Record pairing an embedder config with its ``EmbedderProviderInfo``.

    ``WorkflowBuilder`` keeps these in its ``_embedders`` dict (keyed by string).
    """

    # Configuration object for the embedder.
    config: EmbedderBaseConfig
    # Provider-info object associated with ``config``.
    instance: EmbedderProviderInfo
112
+
113
+
114
@dataclasses.dataclass
class ConfiguredMemory:
    """Record pairing a memory config with its ``MemoryEditor`` instance.

    ``WorkflowBuilder`` keeps these in its ``_memory_clients`` dict (keyed by string).
    """

    # Configuration object for the memory backend.
    config: MemoryBaseConfig
    # MemoryEditor object associated with ``config``.
    instance: MemoryEditor
118
+
119
+
120
@dataclasses.dataclass
class ConfiguredObjectStore:
    """Record pairing an object-store config with its ``ObjectStore`` instance.

    ``WorkflowBuilder`` keeps these in its ``_object_stores`` dict (keyed by string).
    """

    # Configuration object for the object store.
    config: ObjectStoreBaseConfig
    # ObjectStore object associated with ``config``.
    instance: ObjectStore
124
+
125
+
126
@dataclasses.dataclass
class ConfiguredRetriever:
    """Record pairing a retriever config with its ``RetrieverProviderInfo``.

    ``WorkflowBuilder`` keeps these in its ``_retrievers`` dict (keyed by string).
    """

    # Configuration object for the retriever.
    config: RetrieverBaseConfig
    # Provider-info object associated with ``config``.
    instance: RetrieverProviderInfo
130
+
131
+
132
@dataclasses.dataclass
class ConfiguredAuthProvider:
    """Record pairing an auth-provider config with its ``AuthProviderBase`` instance.

    ``WorkflowBuilder`` keeps these in its ``_auth_providers`` dict (keyed by string).
    """

    # Configuration object for the authentication provider.
    config: AuthProviderBaseConfig
    # AuthProviderBase object associated with ``config``.
    instance: AuthProviderBase
136
+
137
+
138
@dataclasses.dataclass
class ConfiguredTTCStrategy:
    """Record pairing a test-time-compute strategy config with its ``StrategyBase`` instance.

    ``WorkflowBuilder`` keeps these in its ``_ttc_strategies`` dict (keyed by string).
    """

    # Configuration object for the TTC strategy.
    config: TTCStrategyBaseConfig
    # StrategyBase object associated with ``config``.
    instance: StrategyBase
142
+
143
+
144
+ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
145
+
146
+ def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
147
+
148
+ if general_config is None:
149
+ general_config = GeneralConfig()
150
+
151
+ if registry is None:
152
+ registry = GlobalTypeRegistry.get()
153
+
154
+ self.general_config = general_config
155
+
156
+ self._registry = registry
157
+
158
+ self._logging_handlers: dict[str, logging.Handler] = {}
159
+ self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
160
+ self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
161
+
162
+ self._functions: dict[str, ConfiguredFunction] = {}
163
+ self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
164
+ self._workflow: ConfiguredFunction | None = None
165
+
166
+ self._llms: dict[str, ConfiguredLLM] = {}
167
+ self._auth_providers: dict[str, ConfiguredAuthProvider] = {}
168
+ self._embedders: dict[str, ConfiguredEmbedder] = {}
169
+ self._memory_clients: dict[str, ConfiguredMemory] = {}
170
+ self._object_stores: dict[str, ConfiguredObjectStore] = {}
171
+ self._retrievers: dict[str, ConfiguredRetriever] = {}
172
+ self._ttc_strategies: dict[str, ConfiguredTTCStrategy] = {}
173
+
174
+ self._context_state = ContextState.get()
175
+
176
+ self._exit_stack: AsyncExitStack | None = None
177
+
178
+ # Create a mapping to track function name -> other function names it depends on
179
+ self.function_dependencies: dict[str, FunctionDependencies] = {}
180
+ self.function_group_dependencies: dict[str, FunctionDependencies] = {}
181
+ self.current_function_building: str | None = None
182
+ self.current_function_group_building: str | None = None
183
+
184
    async def __aenter__(self):
        """
        Enter the builder context.

        Creates the exit stack that owns every component built afterwards, attaches the configured logging
        handlers to the Python root logger (adjusting or removing pre-existing plain ``StreamHandler``s), and
        adds every configured tracing exporter. The handler changes are undone in ``__aexit__``.
        """

        self._exit_stack = AsyncExitStack()

        # Get the telemetry info from the config
        telemetry_config = self.general_config.telemetry

        # If we have logging configuration, we need to manage the root logger properly
        root_logger = logging.getLogger()

        # Collect configured handler types to determine if we need to adjust existing handlers
        # This is somewhat of a hack by inspecting the class name of the config object
        # (every Python object has `__class__`, so the hasattr() check is always true; the class-name test decides)
        has_console_handler = any(
            hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
            for config in telemetry_config.logging.values())

        for key, logging_config in telemetry_config.logging.items():
            # Use the same pattern as tracing, but for logging
            logging_info = self._registry.get_logging_method(type(logging_config))
            handler = await self._exit_stack.enter_async_context(logging_info.build_fn(logging_config, self))

            # Type check
            if not isinstance(handler, logging.Handler):
                raise TypeError(f"Expected a logging.Handler from {key}, got {type(handler)}")

            # Store them in a dict so we can un-register them if needed
            self._logging_handlers[key] = handler

            # Now attach to the Python root logger (logging.getLogger() above)
            root_logger.addHandler(handler)

        # If we added logging handlers, manage existing handlers appropriately
        if self._logging_handlers:
            # Lowest level among ALL root handlers, both pre-existing ones and the ones just added
            min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)

            # Ensure the root logger level allows messages through
            # NOTE(review): max() can only raise the root level, never lower it, so a handler configured below the
            # current root level will not receive those records; confirm whether min() was intended. The previous
            # root level is also not saved, so __aexit__ cannot restore it.
            root_logger.level = max(root_logger.level, min_handler_level)

            # If a console handler is configured, adjust or remove default CLI handlers
            # to avoid duplicate output while preserving workflow visibility
            if has_console_handler:
                # Remove existing StreamHandlers that are not the newly configured ones.
                # `type(...) is` (not isinstance) deliberately skips subclasses such as logging.FileHandler.
                for handler in root_logger.handlers[:]:
                    if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
                        # Remember the handler and its level so __aexit__ can re-attach it
                        self._removed_root_handlers.append((handler, handler.level))
                        root_logger.removeHandler(handler)
            else:
                # No console handler configured, but adjust existing handler levels
                # to respect the minimum configured level for file/other handlers
                # NOTE(review): this loop does not exclude self._logging_handlers, so a newly configured handler whose
                # exact type is StreamHandler would also be re-levelled here — confirm that is intended.
                for handler in root_logger.handlers[:]:
                    if type(handler) is logging.StreamHandler:
                        old_level = handler.level
                        handler.setLevel(min_handler_level)
                        # Recorded so __aexit__ restores the original level
                        self._removed_root_handlers.append((handler, old_level))

        # Add the telemetry exporters
        for key, telemetry_exporter_config in telemetry_config.tracing.items():
            await self.add_telemetry_exporter(key, telemetry_exporter_config)

        return self
244
+
245
+ async def __aexit__(self, *exc_details):
246
+
247
+ assert self._exit_stack is not None, "Exit stack not initialized"
248
+
249
+ root_logger = logging.getLogger()
250
+
251
+ # Remove custom logging handlers
252
+ for handler in self._logging_handlers.values():
253
+ root_logger.removeHandler(handler)
254
+
255
+ # Restore original handlers and their levels
256
+ for handler, old_level in self._removed_root_handlers:
257
+ if handler not in root_logger.handlers:
258
+ root_logger.addHandler(handler)
259
+ handler.setLevel(old_level)
260
+
261
+ await self._exit_stack.__aexit__(*exc_details)
262
+
263
+ async def build(self, entry_function: str | None = None) -> Workflow:
264
+ """
265
+ Creates an instance of a workflow object using the added components and the desired entry function.
266
+
267
+ Parameters
268
+ ----------
269
+ entry_function : str | None, optional
270
+ The function name to use as the entry point for the created workflow. If None, the entry point will be the
271
+ specified workflow function. By default None
272
+
273
+ Returns
274
+ -------
275
+ Workflow
276
+ A created workflow.
277
+
278
+ Raises
279
+ ------
280
+ ValueError
281
+ If the workflow has not been set before building.
282
+ """
283
+
284
+ if (self._workflow is None):
285
+ raise ValueError("Must set a workflow before building")
286
+
287
+ # Set of all functions which are "included" by function groups
288
+ included_functions = set()
289
+ # Dictionary of function configs
290
+ function_configs = dict()
291
+ # Dictionary of function group configs
292
+ function_group_configs = dict()
293
+ # Dictionary of function instances
294
+ function_instances = dict()
295
+ # Dictionary of function group instances
296
+ function_group_instances = dict()
297
+
298
+ for k, v in self._function_groups.items():
299
+ included_functions.update((await v.instance.get_included_functions()).keys())
300
+ function_group_configs[k] = v.config
301
+ function_group_instances[k] = v.instance
302
+
303
+ # Function configs need to be restricted to only the functions that are not in a function group
304
+ for k, v in self._functions.items():
305
+ if k not in included_functions:
306
+ function_configs[k] = v.config
307
+ function_instances[k] = v.instance
308
+
309
+ # Build the config from the added objects
310
+ config = Config(general=self.general_config,
311
+ functions=function_configs,
312
+ function_groups=function_group_configs,
313
+ workflow=self._workflow.config,
314
+ llms={
315
+ k: v.config
316
+ for k, v in self._llms.items()
317
+ },
318
+ embedders={
319
+ k: v.config
320
+ for k, v in self._embedders.items()
321
+ },
322
+ memory={
323
+ k: v.config
324
+ for k, v in self._memory_clients.items()
325
+ },
326
+ object_stores={
327
+ k: v.config
328
+ for k, v in self._object_stores.items()
329
+ },
330
+ retrievers={
331
+ k: v.config
332
+ for k, v in self._retrievers.items()
333
+ },
334
+ ttc_strategies={
335
+ k: v.config
336
+ for k, v in self._ttc_strategies.items()
337
+ })
338
+
339
+ if (entry_function is None):
340
+ entry_fn_obj = self.get_workflow()
341
+ else:
342
+ entry_fn_obj = await self.get_function(entry_function)
343
+
344
+ workflow = Workflow.from_entry_fn(config=config,
345
+ entry_fn=entry_fn_obj,
346
+ functions=function_instances,
347
+ function_groups=function_group_instances,
348
+ llms={
349
+ k: v.instance
350
+ for k, v in self._llms.items()
351
+ },
352
+ embeddings={
353
+ k: v.instance
354
+ for k, v in self._embedders.items()
355
+ },
356
+ memory={
357
+ k: v.instance
358
+ for k, v in self._memory_clients.items()
359
+ },
360
+ object_stores={
361
+ k: v.instance
362
+ for k, v in self._object_stores.items()
363
+ },
364
+ telemetry_exporters={
365
+ k: v.instance
366
+ for k, v in self._telemetry_exporters.items()
367
+ },
368
+ retrievers={
369
+ k: v.instance
370
+ for k, v in self._retrievers.items()
371
+ },
372
+ ttc_strategies={
373
+ k: v.instance
374
+ for k, v in self._ttc_strategies.items()
375
+ },
376
+ context_state=self._context_state)
377
+
378
+ return workflow
379
+
380
+ def _get_exit_stack(self) -> AsyncExitStack:
381
+
382
+ if self._exit_stack is None:
383
+ raise ValueError(
384
+ "Exit stack not initialized. Did you forget to call `async with WorkflowBuilder() as builder`?")
385
+
386
+ return self._exit_stack
387
+
388
+ async def _build_function(self, name: str, config: FunctionBaseConfig) -> ConfiguredFunction:
389
+ registration = self._registry.get_function(type(config))
390
+
391
+ inner_builder = ChildBuilder(self)
392
+
393
+ # We need to do this for every function because we don't know
394
+ # Where LLama Index Agents are Instantiated and Settings need to
395
+ # be set before the function is built
396
+ # It's only slower the first time because of the import
397
+ # So we can afford to do this for every function
398
+
399
+ llms = {k: v.instance for k, v in self._llms.items()}
400
+ function_frameworks = detect_llm_frameworks_in_build_fn(registration)
401
+
402
+ build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
403
+
404
+ # Set the currently building function so the ChildBuilder can track dependencies
405
+ self.current_function_building = config.type
406
+ # Empty set of dependencies for the current function
407
+ self.function_dependencies[config.type] = FunctionDependencies()
408
+
409
+ build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
410
+
411
+ self.function_dependencies[name] = inner_builder.dependencies
412
+
413
+ # If the build result is a function, wrap it in a FunctionInfo
414
+ if inspect.isfunction(build_result):
415
+
416
+ build_result = FunctionInfo.from_fn(build_result)
417
+
418
+ if (isinstance(build_result, FunctionInfo)):
419
+ # Create the function object
420
+ build_result = LambdaFunction.from_info(config=config, info=build_result, instance_name=name)
421
+
422
+ if (not isinstance(build_result, Function)):
423
+ raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
424
+ f"returned from the function builder. Got {type(build_result)}")
425
+
426
+ return ConfiguredFunction(config=config, instance=build_result)
427
+
428
+ async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
429
+ """Build a function group from the provided configuration.
430
+
431
+ Args:
432
+ name: The name of the function group
433
+ config: The function group configuration
434
+
435
+ Returns:
436
+ ConfiguredFunctionGroup: The built function group
437
+
438
+ Raises:
439
+ ValueError: If the function group builder returns invalid results
440
+ """
441
+ registration = self._registry.get_function_group(type(config))
442
+
443
+ inner_builder = ChildBuilder(self)
444
+
445
+ # Build the function group - use the same wrapping pattern as _build_function
446
+ llms = {k: v.instance for k, v in self._llms.items()}
447
+ function_frameworks = detect_llm_frameworks_in_build_fn(registration)
448
+
449
+ build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
450
+
451
+ # Set the currently building function group so the ChildBuilder can track dependencies
452
+ self.current_function_group_building = config.type
453
+ # Empty set of dependencies for the current function group
454
+ self.function_group_dependencies[config.type] = FunctionDependencies()
455
+
456
+ build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
457
+
458
+ self.function_group_dependencies[name] = inner_builder.dependencies
459
+
460
+ if not isinstance(build_result, FunctionGroup):
461
+ raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
462
+ f"Got {type(build_result)}")
463
+
464
+ # set the instance name for the function group based on the workflow-provided name
465
+ build_result.set_instance_name(name)
466
+ return ConfiguredFunctionGroup(config=config, instance=build_result)
467
+
468
+ @override
469
+ async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
470
+ if isinstance(name, FunctionRef):
471
+ name = str(name)
472
+
473
+ if (name in self._functions or name in self._function_groups):
474
+ raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
475
+
476
+ build_result = await self._build_function(name=name, config=config)
477
+
478
+ self._functions[name] = build_result
479
+
480
+ return build_result.instance
481
+
482
+ @override
483
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
484
+ if isinstance(name, FunctionGroupRef):
485
+ name = str(name)
486
+
487
+ if (name in self._function_groups or name in self._functions):
488
+ raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
489
+
490
+ # Build the function group
491
+ build_result = await self._build_function_group(name=name, config=config)
492
+
493
+ self._function_groups[name] = build_result
494
+
495
+ # If the function group exposes functions, add them to the global function registry
496
+ # If the function group exposes functions, record and add them to the registry
497
+ included_functions = await build_result.instance.get_included_functions()
498
+ for k in included_functions:
499
+ if k in self._functions:
500
+ raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
501
+ self._functions.update({
502
+ k: ConfiguredFunction(config=v.config, instance=v)
503
+ for k, v in included_functions.items()
504
+ })
505
+
506
+ return build_result.instance
507
+
508
+ @override
509
+ async def get_function(self, name: str | FunctionRef) -> Function:
510
+ if isinstance(name, FunctionRef):
511
+ name = str(name)
512
+ if name not in self._functions:
513
+ raise ValueError(f"Function `{name}` not found")
514
+
515
+ return self._functions[name].instance
516
+
517
+ @override
518
+ async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
519
+ if isinstance(name, FunctionGroupRef):
520
+ name = str(name)
521
+ if name not in self._function_groups:
522
+ raise ValueError(f"Function group `{name}` not found")
523
+
524
+ return self._function_groups[name].instance
525
+
526
+ @override
527
+ def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
528
+ if isinstance(name, FunctionRef):
529
+ name = str(name)
530
+ if name not in self._functions:
531
+ raise ValueError(f"Function `{name}` not found")
532
+
533
+ return self._functions[name].config
534
+
535
+ @override
536
+ def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
537
+ if isinstance(name, FunctionGroupRef):
538
+ name = str(name)
539
+ if name not in self._function_groups:
540
+ raise ValueError(f"Function group `{name}` not found")
541
+
542
+ return self._function_groups[name].config
543
+
544
+ @override
545
+ async def set_workflow(self, config: FunctionBaseConfig) -> Function:
546
+
547
+ if self._workflow is not None:
548
+ warnings.warn("Overwriting existing workflow")
549
+
550
+ build_result = await self._build_function(name="<workflow>", config=config)
551
+
552
+ self._workflow = build_result
553
+
554
+ return build_result.instance
555
+
556
+ @override
557
+ def get_workflow(self) -> Function:
558
+
559
+ if self._workflow is None:
560
+ raise ValueError("No workflow set")
561
+
562
+ return self._workflow.instance
563
+
564
+ @override
565
+ def get_workflow_config(self) -> FunctionBaseConfig:
566
+ if self._workflow is None:
567
+ raise ValueError("No workflow set")
568
+
569
+ return self._workflow.config
570
+
571
+ @override
572
+ def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
573
+ if isinstance(fn_name, FunctionRef):
574
+ fn_name = str(fn_name)
575
+ return self.function_dependencies[fn_name]
576
+
577
+ @override
578
+ def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
579
+ if isinstance(fn_name, FunctionGroupRef):
580
+ fn_name = str(fn_name)
581
+ return self.function_group_dependencies[fn_name]
582
+
583
+ @override
584
+ async def get_tools(self,
585
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
586
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
587
+
588
+ unique = set(tool_names)
589
+ if len(unique) != len(tool_names):
590
+ raise ValueError("Tool names must be unique")
591
+
592
+ async def _get_tools(n: str | FunctionRef | FunctionGroupRef):
593
+ tools = []
594
+ is_function_group_ref = isinstance(n, FunctionGroupRef)
595
+ if isinstance(n, FunctionRef) or is_function_group_ref:
596
+ n = str(n)
597
+ if n not in self._function_groups:
598
+ # the passed tool name is probably a function, but first check if it's a function group
599
+ if is_function_group_ref:
600
+ raise ValueError(f"Function group `{n}` not found in the list of function groups")
601
+ tools.append(await self.get_tool(n, wrapper_type))
602
+ else:
603
+ tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
604
+ current_function_group = self._function_groups[n]
605
+ for fn_name, fn_instance in (await current_function_group.instance.get_accessible_functions()).items():
606
+ try:
607
+ tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
608
+ except Exception:
609
+ logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
610
+ raise
611
+ return tools
612
+
613
+ tool_lists = await asyncio.gather(*[_get_tools(n) for n in tool_names])
614
+ # Flatten the list of lists into a single list
615
+ return [tool for tools in tool_lists for tool in tools]
616
+
617
+ @override
618
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
619
+ if isinstance(fn_name, FunctionRef):
620
+ fn_name = str(fn_name)
621
+ if fn_name not in self._functions:
622
+ raise ValueError(f"Function `{fn_name}` not found in list of functions")
623
+ fn = self._functions[fn_name]
624
+ try:
625
+ # Using the registry, get the tool wrapper for the requested framework
626
+ tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
627
+
628
+ # Wrap in the correct wrapper
629
+ return tool_wrapper_reg.build_fn(fn_name, fn.instance, self)
630
+ except Exception as e:
631
+ logger.error("Error fetching tool `%s`: %s", fn_name, e)
632
+ raise
633
+
634
+ @override
635
+ async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> None:
636
+
637
+ if (name in self._llms):
638
+ raise ValueError(f"LLM `{name}` already exists in the list of LLMs")
639
+
640
+ try:
641
+ llm_info = self._registry.get_llm_provider(type(config))
642
+
643
+ info_obj = await self._get_exit_stack().enter_async_context(llm_info.build_fn(config, self))
644
+
645
+ self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
646
+ except Exception as e:
647
+ logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
648
+ raise
649
+
650
+ @override
651
+ async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
652
+
653
+ if (llm_name not in self._llms):
654
+ raise ValueError(f"LLM `{llm_name}` not found")
655
+
656
+ try:
657
+ # Get llm info
658
+ llm_info = self._llms[llm_name]
659
+
660
+ # Generate wrapped client from registered client info
661
+ client_info = self._registry.get_llm_client(config_type=type(llm_info.config), wrapper_type=wrapper_type)
662
+
663
+ client = await self._get_exit_stack().enter_async_context(client_info.build_fn(llm_info.config, self))
664
+
665
+ # Return a frameworks specific client
666
+ return client
667
+ except Exception as e:
668
+ logger.error("Error getting llm `%s` with wrapper `%s`: %s", llm_name, wrapper_type, e)
669
+ raise
670
+
671
+ @override
672
+ def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
673
+
674
+ if llm_name not in self._llms:
675
+ raise ValueError(f"LLM `{llm_name}` not found")
676
+
677
+ # Return the tool configuration object
678
+ return self._llms[llm_name].config
679
+
680
+ @experimental(feature_name="Authentication")
681
+ @override
682
+ async def add_auth_provider(self, name: str | AuthenticationRef,
683
+ config: AuthProviderBaseConfig) -> AuthProviderBase:
684
+ """
685
+ Add an authentication provider to the workflow by constructing it from a configuration object.
686
+
687
+ Note: The Authentication Provider API is experimental and the API may change in future releases.
688
+
689
+ Parameters
690
+ ----------
691
+ name : str | AuthenticationRef
692
+ The name of the authentication provider to add.
693
+ config : AuthProviderBaseConfig
694
+ The configuration for the authentication provider.
695
+
696
+ Returns
697
+ -------
698
+ AuthProviderBase
699
+ The authentication provider instance.
700
+
701
+ Raises
702
+ ------
703
+ ValueError
704
+ If the authentication provider is already in the list of authentication providers.
705
+ """
706
+
707
+ if (name in self._auth_providers):
708
+ raise ValueError(f"Authentication `{name}` already exists in the list of Authentication Providers")
709
+
710
+ try:
711
+ authentication_info = self._registry.get_auth_provider(type(config))
712
+
713
+ info_obj = await self._get_exit_stack().enter_async_context(authentication_info.build_fn(config, self))
714
+
715
+ self._auth_providers[name] = ConfiguredAuthProvider(config=config, instance=info_obj)
716
+
717
+ return info_obj
718
+ except Exception as e:
719
+ logger.error("Error adding authentication `%s` with config `%s`: %s", name, config, e)
720
+ raise
721
+
722
+ @override
723
+ async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
724
+ """
725
+ Get the authentication provider instance for the given name.
726
+
727
+ Note: The Authentication Provider API is experimental and the API may change in future releases.
728
+
729
+ Parameters
730
+ ----------
731
+ auth_provider_name : str
732
+ The name of the authentication provider to get.
733
+
734
+ Returns
735
+ -------
736
+ AuthProviderBase
737
+ The authentication provider instance.
738
+
739
+ Raises
740
+ ------
741
+ ValueError
742
+ If the authentication provider is not found.
743
+ """
744
+
745
+ if auth_provider_name not in self._auth_providers:
746
+ raise ValueError(f"Authentication `{auth_provider_name}` not found")
747
+
748
+ return self._auth_providers[auth_provider_name].instance
749
+
750
+ @override
751
+ async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
752
+
753
+ if (name in self._embedders):
754
+ raise ValueError(f"Embedder `{name}` already exists in the list of embedders")
755
+
756
+ try:
757
+ embedder_info = self._registry.get_embedder_provider(type(config))
758
+
759
+ info_obj = await self._get_exit_stack().enter_async_context(embedder_info.build_fn(config, self))
760
+
761
+ self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
762
+ except Exception as e:
763
+ logger.error("Error adding embedder `%s` with config `%s`: %s", name, config, e)
764
+ raise
765
+
766
+ @override
767
+ async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str):
768
+
769
+ if (embedder_name not in self._embedders):
770
+ raise ValueError(f"Embedder `{embedder_name}` not found")
771
+
772
+ try:
773
+ # Get embedder info
774
+ embedder_info = self._embedders[embedder_name]
775
+
776
+ # Generate wrapped client from registered client info
777
+ client_info = self._registry.get_embedder_client(config_type=type(embedder_info.config),
778
+ wrapper_type=wrapper_type)
779
+ client = await self._get_exit_stack().enter_async_context(client_info.build_fn(embedder_info.config, self))
780
+
781
+ # Return a frameworks specific client
782
+ return client
783
+ except Exception as e:
784
+ logger.error("Error getting embedder `%s` with wrapper `%s`: %s", embedder_name, wrapper_type, e)
785
+ raise
786
+
787
+ @override
788
+ def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
789
+
790
+ if embedder_name not in self._embedders:
791
+ raise ValueError(f"Tool `{embedder_name}` not found")
792
+
793
+ # Return the tool configuration object
794
+ return self._embedders[embedder_name].config
795
+
796
+ @override
797
+ async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
798
+
799
+ if (name in self._memory_clients):
800
+ raise ValueError(f"Memory `{name}` already exists in the list of memories")
801
+
802
+ memory_info = self._registry.get_memory(type(config))
803
+
804
+ info_obj = await self._get_exit_stack().enter_async_context(memory_info.build_fn(config, self))
805
+
806
+ self._memory_clients[name] = ConfiguredMemory(config=config, instance=info_obj)
807
+
808
+ return info_obj
809
+
810
+ @override
811
+ async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
812
+ """
813
+ Return the instantiated memory client for the given name.
814
+ """
815
+ if memory_name not in self._memory_clients:
816
+ raise ValueError(f"Memory `{memory_name}` not found")
817
+
818
+ return self._memory_clients[memory_name].instance
819
+
820
+ @override
821
+ def get_memory_client_config(self, memory_name: str | MemoryRef) -> MemoryBaseConfig:
822
+
823
+ if memory_name not in self._memory_clients:
824
+ raise ValueError(f"Memory `{memory_name}` not found")
825
+
826
+ # Return the tool configuration object
827
+ return self._memory_clients[memory_name].config
828
+
829
+ @override
830
+ async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
831
+ if name in self._object_stores:
832
+ raise ValueError(f"Object store `{name}` already exists in the list of object stores")
833
+
834
+ object_store_info = self._registry.get_object_store(type(config))
835
+
836
+ info_obj = await self._get_exit_stack().enter_async_context(object_store_info.build_fn(config, self))
837
+
838
+ self._object_stores[name] = ConfiguredObjectStore(config=config, instance=info_obj)
839
+
840
+ return info_obj
841
+
842
+ @override
843
+ async def get_object_store_client(self, object_store_name: str | ObjectStoreRef) -> ObjectStore:
844
+ if object_store_name not in self._object_stores:
845
+ raise ValueError(f"Object store `{object_store_name}` not found")
846
+
847
+ return self._object_stores[object_store_name].instance
848
+
849
+ @override
850
+ def get_object_store_config(self, object_store_name: str | ObjectStoreRef) -> ObjectStoreBaseConfig:
851
+ if object_store_name not in self._object_stores:
852
+ raise ValueError(f"Object store `{object_store_name}` not found")
853
+
854
+ return self._object_stores[object_store_name].config
855
+
856
+ @override
857
+ async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
858
+
859
+ if (name in self._retrievers):
860
+ raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")
861
+
862
+ try:
863
+ retriever_info = self._registry.get_retriever_provider(type(config))
864
+
865
+ info_obj = await self._get_exit_stack().enter_async_context(retriever_info.build_fn(config, self))
866
+
867
+ self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
868
+
869
+ except Exception as e:
870
+ logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
871
+ raise
872
+
873
+ @override
874
+ async def get_retriever(self,
875
+ retriever_name: str | RetrieverRef,
876
+ wrapper_type: LLMFrameworkEnum | str | None = None):
877
+
878
+ if retriever_name not in self._retrievers:
879
+ raise ValueError(f"Retriever '{retriever_name}' not found")
880
+
881
+ try:
882
+ # Get retriever info
883
+ retriever_info = self._retrievers[retriever_name]
884
+
885
+ # Generate wrapped client from registered client info
886
+ client_info = self._registry.get_retriever_client(config_type=type(retriever_info.config),
887
+ wrapper_type=wrapper_type)
888
+
889
+ client = await self._get_exit_stack().enter_async_context(client_info.build_fn(retriever_info.config, self))
890
+
891
+ # Return a frameworks specific client
892
+ return client
893
+ except Exception as e:
894
+ logger.error("Error getting retriever `%s` with wrapper `%s`: %s", retriever_name, wrapper_type, e)
895
+ raise
896
+
897
+ @override
898
+ async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
899
+
900
+ if retriever_name not in self._retrievers:
901
+ raise ValueError(f"Retriever `{retriever_name}` not found")
902
+
903
+ return self._retrievers[retriever_name].config
904
+
905
+ @override
906
+ @experimental(feature_name="TTC")
907
+ async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
908
+ if (name in self._ttc_strategies):
909
+ raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")
910
+
911
+ try:
912
+ ttc_strategy_info = self._registry.get_ttc_strategy(type(config))
913
+
914
+ info_obj = await self._get_exit_stack().enter_async_context(ttc_strategy_info.build_fn(config, self))
915
+
916
+ self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
917
+
918
+ except Exception as e:
919
+ logger.error("Error adding TTC strategy `%s` with config `%s`: %s", name, config, e)
920
+ raise
921
+
922
+ @override
923
+ async def get_ttc_strategy(self,
924
+ strategy_name: str | TTCStrategyRef,
925
+ pipeline_type: PipelineTypeEnum,
926
+ stage_type: StageTypeEnum) -> StrategyBase:
927
+
928
+ if strategy_name not in self._ttc_strategies:
929
+ raise ValueError(f"TTC strategy '{strategy_name}' not found")
930
+
931
+ try:
932
+ # Get strategy info
933
+ ttc_strategy_info = self._ttc_strategies[strategy_name]
934
+
935
+ instance = ttc_strategy_info.instance
936
+
937
+ if not stage_type == instance.stage_type():
938
+ raise ValueError(f"TTC strategy '{strategy_name}' is not compatible with stage type '{stage_type}'")
939
+
940
+ if pipeline_type not in instance.supported_pipeline_types():
941
+ raise ValueError(
942
+ f"TTC strategy '{strategy_name}' is not compatible with pipeline type '{pipeline_type}'")
943
+
944
+ instance.set_pipeline_type(pipeline_type)
945
+
946
+ return instance
947
+ except Exception as e:
948
+ logger.error("Error getting TTC strategy `%s`: %s", strategy_name, e)
949
+ raise
950
+
951
@override
async def get_ttc_strategy_config(self,
                                  strategy_name: str | TTCStrategyRef,
                                  pipeline_type: PipelineTypeEnum,
                                  stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
    """Return the configuration a registered TTC strategy was built from.

    The same compatibility checks as ``get_ttc_strategy`` are applied, but the strategy instance is
    not modified (``set_pipeline_type`` is not called).

    Args:
        strategy_name (str | TTCStrategyRef): Name of a strategy previously added with ``add_ttc_strategy``.
        pipeline_type (PipelineTypeEnum): Must be one of the strategy's ``supported_pipeline_types()``.
        stage_type (StageTypeEnum): Must equal the strategy's ``stage_type()``.

    Returns:
        TTCStrategyBaseConfig: The config stored alongside the strategy instance.

    Raises:
        ValueError: If no strategy is registered under ``strategy_name``, or if it is not compatible
            with ``stage_type`` or ``pipeline_type``.
    """
    if strategy_name not in self._ttc_strategies:
        raise ValueError(f"TTC strategy '{strategy_name}' not found")

    strategy_info = self._ttc_strategies[strategy_name]
    instance = strategy_info.instance

    if stage_type != instance.stage_type():
        raise ValueError(f"TTC strategy '{strategy_name}' is not compatible with stage type '{stage_type}'")

    if pipeline_type not in instance.supported_pipeline_types():
        raise ValueError(f"TTC strategy '{strategy_name}' is not compatible with pipeline type '{pipeline_type}'")

    return strategy_info.config
970
+
971
@override
def get_user_manager(self) -> UserManagerHolder:
    """Return a user-manager holder bound to this builder's context state.

    A new ``UserManagerHolder`` wrapping a new ``Context(self._context_state)`` is created on every call.
    """
    return UserManagerHolder(context=Context(self._context_state))
974
+
975
async def add_telemetry_exporter(self, name: str, config: TelemetryExporterBaseConfig) -> None:
    """Add a configured telemetry exporter to the builder.

    The registered ``build_fn`` for ``type(config)`` is called and its result is entered via
    ``enter_async_context`` on the builder's exit stack; the value produced is stored as the exporter
    instance under ``name``.

    Args:
        name (str): The name of the telemetry exporter. Must not already be registered.
        config (TelemetryExporterBaseConfig): The configuration for the exporter.

    Raises:
        ValueError: If a telemetry exporter named ``name`` has already been added.
    """
    if name in self._telemetry_exporters:
        raise ValueError(f"Telemetry exporter '{name}' already exists in the list of telemetry exporters")

    exporter_info = self._registry.get_telemetry_exporter(type(config))

    # ``build_fn`` is called before the exit stack is fetched, preserving the original call order.
    exporter_context_manager = exporter_info.build_fn(config, self)

    # No lock is taken here: if two calls with the same ``name`` interleave at this await, both can pass
    # the duplicate check above and the later one overwrites the earlier entry below.
    exporter = await self._get_exit_stack().enter_async_context(exporter_context_manager)
    self._telemetry_exporters[name] = ConfiguredTelemetryExporter(config=config, instance=exporter)
993
+
994
def _log_build_failure(self,
                       component_name: str,
                       component_type: str,
                       completed_components: list[tuple[str, str]],
                       remaining_components: list[tuple[str, str]],
                       original_error: Exception) -> None:
    """
    Log a detailed report of a failed build step at ERROR level.

    The report names the failing component, lists every component built before it, lists every
    component still waiting to be built, and finally logs the original error with its traceback.

    Args:
        component_name (str): The name of the component that failed to build
        component_type (str): The type of the component that failed to build
        completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
        remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
        original_error (Exception): The original exception that caused the failure
    """
    logger.error("Failed to initialize component %s (%s)", component_name, component_type)

    if completed_components:
        logger.error("Successfully built components:")
        for name, comp_type in completed_components:
            logger.error("- %s (%s)", name, comp_type)
    else:
        logger.error("No components were successfully built before this failure")

    if remaining_components:
        logger.error("Remaining components to build:")
        for name, comp_type in remaining_components:
            logger.error("- %s (%s)", name, comp_type)
    else:
        logger.error("No remaining components to build")

    # Pass the exception itself rather than ``True``. ``exc_info=True`` pulls the traceback from
    # ``sys.exc_info()``, which is empty ("NoneType: None") if this helper is called outside an active
    # ``except`` block. The instance carries its own ``__traceback__`` either way.
    logger.error("Original error: %s", original_error, exc_info=original_error)
1027
+
1028
def _log_build_failure_component(self,
                                 failing_component: ComponentInstanceData,
                                 completed_components: list[tuple[str, str]],
                                 remaining_components: list[tuple[str, str]],
                                 original_error: Exception) -> None:
    """
    Report a failed component build, using the component's name and group value as its identity.

    Args:
        failing_component (ComponentInstanceData): The component whose build raised
        completed_components (list[tuple[str, str]]): (name, type) pairs of components already built
        remaining_components (list[tuple[str, str]]): (name, type) pairs of components not yet built
        original_error (Exception): The exception raised by the failing build
    """
    self._log_build_failure(failing_component.name,
                            failing_component.component_group.value,
                            completed_components,
                            remaining_components,
                            original_error)
1050
+
1051
def _log_build_failure_workflow(self,
                                completed_components: list[tuple[str, str]],
                                remaining_components: list[tuple[str, str]],
                                original_error: Exception) -> None:
    """
    Report a failed workflow build under the pseudo-component ``<workflow>`` of type ``workflow``.

    Args:
        completed_components (list[tuple[str, str]]): (name, type) pairs of components already built
        remaining_components (list[tuple[str, str]]): (name, type) pairs of components not yet built
        original_error (Exception): The exception raised while building the workflow
    """
    self._log_build_failure(component_name="<workflow>",
                            component_type="workflow",
                            completed_components=completed_components,
                            remaining_components=remaining_components,
                            original_error=original_error)
1064
+
1065
+ async def populate_builder(self, config: Config, skip_workflow: bool = False):
1066
+ """
1067
+ Populate the builder with components and optionally set up the workflow.
1068
+
1069
+ Args:
1070
+ config (Config): The configuration object containing component definitions.
1071
+ skip_workflow (bool): If True, skips the workflow instantiation step. Defaults to False.
1072
+
1073
+ """
1074
+ # Generate the build sequence
1075
+ build_sequence = build_dependency_sequence(config)
1076
+
1077
+ # Initialize progress tracking
1078
+ completed_components = []
1079
+ remaining_components = [(str(comp.name), comp.component_group.value) for comp in build_sequence
1080
+ if not comp.is_root]
1081
+ if not skip_workflow:
1082
+ remaining_components.append(("<workflow>", "workflow"))
1083
+
1084
+ # Loop over all objects and add to the workflow builder
1085
+ for component_instance in build_sequence:
1086
+ try:
1087
+ # Remove from remaining as we start building (if not root)
1088
+ if not component_instance.is_root:
1089
+ remaining_components.remove(
1090
+ (str(component_instance.name), component_instance.component_group.value))
1091
+
1092
+ # Instantiate a the llm
1093
+ if component_instance.component_group == ComponentGroup.LLMS:
1094
+ await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
1095
+ # Instantiate a the embedder
1096
+ elif component_instance.component_group == ComponentGroup.EMBEDDERS:
1097
+ await self.add_embedder(component_instance.name,
1098
+ cast(EmbedderBaseConfig, component_instance.config))
1099
+ # Instantiate a memory client
1100
+ elif component_instance.component_group == ComponentGroup.MEMORY:
1101
+ await self.add_memory_client(component_instance.name,
1102
+ cast(MemoryBaseConfig, component_instance.config))
1103
+ # Instantiate a object store client
1104
+ elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
1105
+ await self.add_object_store(component_instance.name,
1106
+ cast(ObjectStoreBaseConfig, component_instance.config))
1107
+ # Instantiate a retriever client
1108
+ elif component_instance.component_group == ComponentGroup.RETRIEVERS:
1109
+ await self.add_retriever(component_instance.name,
1110
+ cast(RetrieverBaseConfig, component_instance.config))
1111
+ # Instantiate a function group
1112
+ elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
1113
+ await self.add_function_group(component_instance.name,
1114
+ cast(FunctionGroupBaseConfig, component_instance.config))
1115
+ # Instantiate a function
1116
+ elif component_instance.component_group == ComponentGroup.FUNCTIONS:
1117
+ # If the function is the root, set it as the workflow later
1118
+ if (not component_instance.is_root):
1119
+ await self.add_function(component_instance.name,
1120
+ cast(FunctionBaseConfig, component_instance.config))
1121
+ elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
1122
+ await self.add_ttc_strategy(component_instance.name,
1123
+ cast(TTCStrategyBaseConfig, component_instance.config))
1124
+
1125
+ elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
1126
+ await self.add_auth_provider(component_instance.name,
1127
+ cast(AuthProviderBaseConfig, component_instance.config))
1128
+ else:
1129
+ raise ValueError(f"Unknown component group {component_instance.component_group}")
1130
+
1131
+ # Add to completed after successful build (if not root)
1132
+ if not component_instance.is_root:
1133
+ completed_components.append(
1134
+ (str(component_instance.name), component_instance.component_group.value))
1135
+
1136
+ except Exception as e:
1137
+ self._log_build_failure_component(component_instance, completed_components, remaining_components, e)
1138
+ raise
1139
+
1140
+ # Instantiate the workflow
1141
+ if not skip_workflow:
1142
+ try:
1143
+ # Remove workflow from remaining as we start building
1144
+ remaining_components.remove(("<workflow>", "workflow"))
1145
+ await self.set_workflow(config.workflow)
1146
+ completed_components.append(("<workflow>", "workflow"))
1147
+ except Exception as e:
1148
+ self._log_build_failure_workflow(completed_components, remaining_components, e)
1149
+ raise
1150
+
1151
@classmethod
@asynccontextmanager
async def from_config(cls, config: Config):
    """Async context manager yielding a builder populated from ``config``.

    The builder is constructed with ``config.general`` and entered as an async context manager. It is
    then filled by ``populate_builder(config)``, which includes the workflow, and yielded for the body of
    the caller's ``async with``.
    """
    new_builder = cls(general_config=config.general)
    async with new_builder as builder:
        await builder.populate_builder(config)
        yield builder
1158
+
1159
+
1160
+ class ChildBuilder(Builder):
1161
+
1162
def __init__(self, workflow_builder: WorkflowBuilder) -> None:
    """Wrap ``workflow_builder`` and start with an empty dependency record.

    Args:
        workflow_builder (WorkflowBuilder): The parent builder that calls are forwarded to.
    """
    # Filled in by the get_* methods below as components are requested through this builder.
    self._dependencies = FunctionDependencies()
    self._workflow_builder = workflow_builder
1167
+
1168
@property
def dependencies(self) -> FunctionDependencies:
    """The dependency record accumulated by lookups made through this child builder."""
    recorded = self._dependencies
    return recorded
1171
+
1172
@override
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
    """Forward to the parent builder's ``add_function``; nothing is recorded as a dependency."""
    parent = self._workflow_builder
    return await parent.add_function(name, config)
1175
+
1176
@override
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
    """Forward to the parent builder's ``add_function_group``; nothing is recorded as a dependency."""
    parent = self._workflow_builder
    return await parent.add_function_group(name, config)
1179
+
1180
@override
async def get_function(self, name: str) -> Function:
    """Fetch function ``name`` from the parent builder and record it as a dependency.

    Requesting another function is taken to mean the caller uses it. The dependency is recorded
    only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    resolved = await parent.get_function(name)
    deps.add_function(name)
    return resolved
1188
+
1189
@override
async def get_function_group(self, name: str) -> FunctionGroup:
    """Fetch function group ``name`` from the parent builder and record it as a dependency.

    The dependency is recorded only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    resolved_group = await parent.get_function_group(name)
    deps.add_function_group(name)
    return resolved_group
1197
+
1198
@override
def get_function_config(self, name: str) -> FunctionBaseConfig:
    """Return the parent builder's config for function ``name``; no dependency is recorded."""
    parent = self._workflow_builder
    return parent.get_function_config(name)
1201
+
1202
@override
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
    """Return the parent builder's config for function group ``name``; no dependency is recorded."""
    parent = self._workflow_builder
    return parent.get_function_group_config(name)
1205
+
1206
@override
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
    """Forward to the parent builder's ``set_workflow``."""
    parent = self._workflow_builder
    return await parent.set_workflow(config)
1209
+
1210
@override
def get_workflow(self) -> Function:
    """Return the parent builder's workflow."""
    parent = self._workflow_builder
    return parent.get_workflow()
1213
+
1214
@override
def get_workflow_config(self) -> FunctionBaseConfig:
    """Return the parent builder's workflow config."""
    parent = self._workflow_builder
    return parent.get_workflow_config()
1217
+
1218
@override
async def get_tools(self,
                    tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
                    wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
    """Fetch wrapped tools from the parent builder and record each requested name as a dependency.

    A name is recorded as a function-group dependency if the parent builder has a function group by
    that name, otherwise as a function dependency. Nothing is recorded if the parent lookup raises.
    """
    wrapped = await self._workflow_builder.get_tools(tool_names, wrapper_type)
    known_groups = self._workflow_builder._function_groups
    deps = self._dependencies
    for requested in tool_names:
        record = deps.add_function_group if requested in known_groups else deps.add_function
        record(requested)
    return wrapped
1229
+
1230
@override
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
    """Fetch ``fn_name`` as a wrapped tool from the parent builder and record it as a function dependency.

    Requesting a function as a tool is taken to mean the caller uses it. The dependency is recorded
    only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    wrapped_tool = await parent.get_tool(fn_name, wrapper_type)
    deps.add_function(fn_name)
    return wrapped_tool
1238
+
1239
@override
async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
    """Forward to the parent builder's ``add_llm``."""
    parent = self._workflow_builder
    return await parent.add_llm(name, config)
1242
+
1243
@experimental(feature_name="Authentication")
@override
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
    """Forward to the parent builder's ``add_auth_provider``."""
    parent = self._workflow_builder
    return await parent.add_auth_provider(name, config)
1247
+
1248
@override
async def get_auth_provider(self, auth_provider_name: str):
    """Return the parent builder's auth provider ``auth_provider_name``; no dependency is recorded."""
    parent = self._workflow_builder
    return await parent.get_auth_provider(auth_provider_name)
1251
+
1252
@override
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
    """Fetch LLM ``llm_name`` from the parent builder and record it as a dependency.

    The dependency is recorded only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    wrapped_llm = await parent.get_llm(llm_name, wrapper_type)
    deps.add_llm(llm_name)
    return wrapped_llm
1259
+
1260
@override
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
    """Return the parent builder's config for LLM ``llm_name``; no dependency is recorded."""
    parent = self._workflow_builder
    return parent.get_llm_config(llm_name)
1263
+
1264
@override
async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
    """Forward to the parent builder's ``add_embedder``; its return value is discarded."""
    parent = self._workflow_builder
    await parent.add_embedder(name, config)
1267
+
1268
@override
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
    """Fetch embedder ``embedder_name`` from the parent builder and record it as a dependency.

    The dependency is recorded only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    wrapped_embedder = await parent.get_embedder(embedder_name, wrapper_type)
    deps.add_embedder(embedder_name)
    return wrapped_embedder
1275
+
1276
@override
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
    """Return the parent builder's config for embedder ``embedder_name``; no dependency is recorded."""
    parent = self._workflow_builder
    return parent.get_embedder_config(embedder_name)
1279
+
1280
@override
async def add_memory_client(self, name: str, config: MemoryBaseConfig) -> MemoryEditor:
    """Forward to the parent builder's ``add_memory_client``."""
    parent = self._workflow_builder
    return await parent.add_memory_client(name, config)
1283
+
1284
@override
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
    """
    Return the instantiated memory client ``memory_name`` and record it as a dependency.

    The dependency is recorded only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    client = await parent.get_memory_client(memory_name)
    deps.add_memory_client(memory_name)
    return client
1294
+
1295
@override
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
    """Return the parent builder's config for memory client ``memory_name``; no dependency is recorded."""
    parent = self._workflow_builder
    return parent.get_memory_client_config(memory_name=memory_name)
1298
+
1299
@override
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
    """Forward to the parent builder's ``add_object_store``."""
    parent = self._workflow_builder
    return await parent.add_object_store(name, config)
1302
+
1303
@override
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
    """
    Return the instantiated object store client ``object_store_name`` and record it as a dependency.

    The dependency is recorded only after the parent lookup returns successfully.
    """
    parent, deps = self._workflow_builder, self._dependencies
    store = await parent.get_object_store_client(object_store_name)
    deps.add_object_store(object_store_name)
    return store
1313
+
1314
@override
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
    """Return the parent builder's config for object store ``object_store_name``; no dependency is recorded."""
    parent = self._workflow_builder
    return parent.get_object_store_config(object_store_name)
1317
+
1318
@override
@experimental(feature_name="TTC")
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
    """Forward to the parent builder's ``add_ttc_strategy``; its return value is discarded."""
    parent = self._workflow_builder
    await parent.add_ttc_strategy(name, config)
1322
+
1323
@override
async def get_ttc_strategy(self,
                           strategy_name: str | TTCStrategyRef,
                           pipeline_type: PipelineTypeEnum,
                           stage_type: StageTypeEnum) -> StrategyBase:
    """Forward to the parent builder's ``get_ttc_strategy``; no dependency is recorded."""
    parent = self._workflow_builder
    strategy = await parent.get_ttc_strategy(strategy_name=strategy_name,
                                             pipeline_type=pipeline_type,
                                             stage_type=stage_type)
    return strategy
1331
+
1332
@override
async def get_ttc_strategy_config(self,
                                  strategy_name: str | TTCStrategyRef,
                                  pipeline_type: PipelineTypeEnum,
                                  stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
    """Forward to the parent builder's ``get_ttc_strategy_config``; no dependency is recorded."""
    parent = self._workflow_builder
    strategy_config = await parent.get_ttc_strategy_config(strategy_name=strategy_name,
                                                           pipeline_type=pipeline_type,
                                                           stage_type=stage_type)
    return strategy_config
1340
+
1341
@override
async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
    """Forward to the parent builder's ``add_retriever``; its return value is discarded."""
    parent = self._workflow_builder
    await parent.add_retriever(name, config)
1344
+
1345
@override
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
    """Return the parent builder's retriever ``retriever_name``; no dependency is recorded.

    ``wrapper_type`` is forwarded only when it is truthy; otherwise the parent's default is used.
    """
    parent = self._workflow_builder
    if wrapper_type:
        return await parent.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
    return await parent.get_retriever(retriever_name=retriever_name)
1350
+
1351
@override
async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
    """Return the parent builder's config for retriever ``retriever_name``; no dependency is recorded."""
    parent = self._workflow_builder
    return await parent.get_retriever_config(retriever_name=retriever_name)
1354
+
1355
@override
def get_user_manager(self) -> UserManagerHolder:
    """Return the parent builder's user-manager holder."""
    parent = self._workflow_builder
    return parent.get_user_manager()
1358
+
1359
@override
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
    """Return the dependencies the parent builder recorded for function ``fn_name``."""
    parent = self._workflow_builder
    return parent.get_function_dependencies(fn_name)
1362
+
1363
@override
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
    """Return the dependencies the parent builder recorded for function group ``fn_name``."""
    parent = self._workflow_builder
    return parent.get_function_group_dependencies(fn_name)