nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480) hide show
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1069 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+ from collections.abc import AsyncIterator
19
+ from collections.abc import Callable
20
+ from contextlib import AbstractAsyncContextManager
21
+ from contextlib import contextmanager
22
+ from copy import deepcopy
23
+ from functools import cached_property
24
+ from logging import Handler
25
+
26
+ from pydantic import BaseModel
27
+ from pydantic import ConfigDict
28
+ from pydantic import Field
29
+ from pydantic import Tag
30
+ from pydantic import computed_field
31
+ from pydantic import field_validator
32
+
33
+ from nat.authentication.interfaces import AuthProviderBase
34
+ from nat.builder.builder import Builder
35
+ from nat.builder.builder import EvalBuilder
36
+ from nat.builder.embedder import EmbedderProviderInfo
37
+ from nat.builder.evaluator import EvaluatorInfo
38
+ from nat.builder.front_end import FrontEndBase
39
+ from nat.builder.function import Function
40
+ from nat.builder.function import FunctionGroup
41
+ from nat.builder.function_base import FunctionBase
42
+ from nat.builder.function_info import FunctionInfo
43
+ from nat.builder.llm import LLMProviderInfo
44
+ from nat.builder.retriever import RetrieverProviderInfo
45
+ from nat.data_models.authentication import AuthProviderBaseConfig
46
+ from nat.data_models.authentication import AuthProviderBaseConfigT
47
+ from nat.data_models.common import TypedBaseModelT
48
+ from nat.data_models.component import ComponentEnum
49
+ from nat.data_models.config import Config
50
+ from nat.data_models.discovery_metadata import DiscoveryMetadata
51
+ from nat.data_models.embedder import EmbedderBaseConfig
52
+ from nat.data_models.embedder import EmbedderBaseConfigT
53
+ from nat.data_models.evaluator import EvaluatorBaseConfig
54
+ from nat.data_models.evaluator import EvaluatorBaseConfigT
55
+ from nat.data_models.front_end import FrontEndBaseConfig
56
+ from nat.data_models.front_end import FrontEndConfigT
57
+ from nat.data_models.function import FunctionBaseConfig
58
+ from nat.data_models.function import FunctionConfigT
59
+ from nat.data_models.function import FunctionGroupBaseConfig
60
+ from nat.data_models.function import FunctionGroupConfigT
61
+ from nat.data_models.llm import LLMBaseConfig
62
+ from nat.data_models.llm import LLMBaseConfigT
63
+ from nat.data_models.logging import LoggingBaseConfig
64
+ from nat.data_models.logging import LoggingMethodConfigT
65
+ from nat.data_models.memory import MemoryBaseConfig
66
+ from nat.data_models.memory import MemoryBaseConfigT
67
+ from nat.data_models.object_store import ObjectStoreBaseConfig
68
+ from nat.data_models.object_store import ObjectStoreBaseConfigT
69
+ from nat.data_models.registry_handler import RegistryHandlerBaseConfig
70
+ from nat.data_models.registry_handler import RegistryHandlerBaseConfigT
71
+ from nat.data_models.retriever import RetrieverBaseConfig
72
+ from nat.data_models.retriever import RetrieverBaseConfigT
73
+ from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
74
+ from nat.data_models.telemetry_exporter import TelemetryExporterConfigT
75
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
76
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfigT
77
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
78
+ from nat.memory.interfaces import MemoryEditor
79
+ from nat.object_store.interfaces import ObjectStore
80
+ from nat.observability.exporter.base_exporter import BaseExporter
81
+ from nat.registry_handlers.registry_handler_base import AbstractRegistryHandler
82
+
83
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------------------------------------------------
# Build callables: each takes a component config (plus a `Builder`, an `EvalBuilder`, or the full workflow `Config`)
# and returns an async iterator that yields the constructed component.
# ---------------------------------------------------------------------------------------------------------------------
AuthProviderBuildCallableT = Callable[[AuthProviderBaseConfigT, Builder], AsyncIterator[AuthProviderBase]]
EmbedderClientBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncIterator[typing.Any]]
EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncIterator[EmbedderProviderInfo]]
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder],
                                  AsyncIterator[FunctionInfo | Callable | FunctionBase]]
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
LoggingMethodBuildCallableT = Callable[[LoggingMethodConfigT, Builder], AsyncIterator[Handler]]
MemoryBuildCallableT = Callable[[MemoryBaseConfigT, Builder], AsyncIterator[MemoryEditor]]
ObjectStoreBuildCallableT = Callable[[ObjectStoreBaseConfigT, Builder], AsyncIterator[ObjectStore]]
# Registry handlers are built from their config alone; no builder is passed.
RegistryHandlerBuildCallableT = Callable[[RegistryHandlerBaseConfigT], AsyncIterator[AbstractRegistryHandler]]
RetrieverClientBuildCallableT = Callable[[RetrieverBaseConfigT, Builder], AsyncIterator[typing.Any]]
RetrieverProviderBuildCallableT = Callable[[RetrieverBaseConfigT, Builder],
                                           AsyncIterator[RetrieverProviderInfo]]
TelemetryExporterBuildCallableT = Callable[[TelemetryExporterConfigT, Builder], AsyncIterator[BaseExporter]]
# Tool wrappers take (name, function, builder) and return the framework-specific tool directly.
ToolWrapperBuildCallableT = Callable[[str, Function, Builder], typing.Any]

# ---------------------------------------------------------------------------------------------------------------------
# Registered callables: the form stored on the `Registered*` entries below. They take the same arguments as the
# build callables but return an async context manager (presumably the build function after being wrapped by the
# registration decorators — confirm against the register module).
# ---------------------------------------------------------------------------------------------------------------------
AuthProviderRegisteredCallableT = Callable[[AuthProviderBaseConfigT, Builder],
                                           AbstractAsyncContextManager[AuthProviderBase]]
EmbedderClientRegisteredCallableT = Callable[[EmbedderBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
EmbedderProviderRegisteredCallableT = Callable[[EmbedderBaseConfigT, Builder],
                                               AbstractAsyncContextManager[EmbedderProviderInfo]]
EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder],
                                        AbstractAsyncContextManager[EvaluatorInfo]]
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
                                       AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder],
                                            AbstractAsyncContextManager[FunctionGroup]]
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
LoggingMethodRegisteredCallableT = Callable[[LoggingMethodConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
MemoryRegisteredCallableT = Callable[[MemoryBaseConfigT, Builder], AbstractAsyncContextManager[MemoryEditor]]
ObjectStoreRegisteredCallableT = Callable[[ObjectStoreBaseConfigT, Builder], AbstractAsyncContextManager[ObjectStore]]
RegistryHandlerRegisteredCallableT = Callable[[RegistryHandlerBaseConfigT],
                                              AbstractAsyncContextManager[AbstractRegistryHandler]]
RetrieverClientRegisteredCallableT = Callable[[RetrieverBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
RetrieverProviderRegisteredCallableT = Callable[[RetrieverBaseConfigT, Builder],
                                                AbstractAsyncContextManager[RetrieverProviderInfo]]
TeleExporterRegisteredCallableT = Callable[[TelemetryExporterConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
126
+
127
+
128
class RegisteredInfo(BaseModel, typing.Generic[TypedBaseModelT]):
    """
    Common, immutable base for every registry entry. Pairs a component's fully-qualified type name
    (`module_name/local_name`) with its configuration class and discovery metadata.
    """

    model_config = ConfigDict(frozen=True)

    full_type: str
    config_type: type[TypedBaseModelT]
    discovery_metadata: DiscoveryMetadata = DiscoveryMetadata()

    @computed_field
    @cached_property
    def module_name(self) -> str:
        """The portion of `full_type` before the `/` separator."""
        return self.full_type.split("/")[0]

    @computed_field
    @cached_property
    def local_name(self) -> str:
        """The portion of `full_type` after the `/` separator."""
        return self.full_type.split("/")[-1]

    @field_validator("full_type", mode="after")
    @classmethod
    def validate_full_type(cls, full_type: str) -> str:
        """
        Ensure `full_type` has the form `module_name/local_name`.

        Raises:
            ValueError: If there is not exactly one `/` separator, or if either side of it is empty.
        """
        parts = full_type.split("/")

        # Exactly one separator, and both components must be non-empty: "/name" and "module/" would otherwise
        # yield an empty `module_name` or `local_name`.
        if len(parts) != 2 or not all(parts):
            raise ValueError(f"Invalid full type: {full_type}. Expected format: `module_name/local_name`")

        return full_type
155
+
156
+
157
class RegisteredTelemetryExporter(RegisteredInfo[TelemetryExporterBaseConfig]):
    """
    Registry entry for a telemetry exporter, keyed by its `TelemetryExporterBaseConfig` subclass.
    """

    # Factory invoked with (config, builder); hidden from repr to keep entries readable.
    build_fn: TeleExporterRegisteredCallableT = Field(repr=False)
160
+
161
+
162
class RegisteredLoggingMethod(RegisteredInfo[LoggingBaseConfig]):
    """
    Registry entry for a logging method, keyed by its `LoggingBaseConfig` subclass.
    """

    # Factory invoked with (config, builder); hidden from repr.
    build_fn: LoggingMethodRegisteredCallableT = Field(repr=False)
165
+
166
+
167
class RegisteredFrontEndInfo(RegisteredInfo[FrontEndBaseConfig]):
    """
    Registry entry for a front end. A front end is a workflow's entry point and is in charge of orchestrating it;
    its factory receives the front-end config together with the complete workflow `Config`.
    """

    build_fn: FrontEndRegisteredCallableT = Field(repr=False)
174
+
175
+
176
class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
    """
    Registry entry for a function: a workflow building block with predefined inputs, outputs and a description.
    """

    build_fn: FunctionRegisteredCallableT = Field(repr=False)
    # Empty unless supplied; presumably the LLM framework names this function should be wrapped for.
    framework_wrappers: list[str] = Field(default_factory=list)
184
+
185
+
186
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
    """
    Registry entry for a function group: a set of functions that share configuration and resources. The factory
    produces a `FunctionGroup`.
    """

    build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
    # Empty unless supplied; presumably the LLM framework names the group's functions should be wrapped for.
    framework_wrappers: list[str] = Field(default_factory=list)
194
+
195
+
196
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
    """
    Registry entry for an LLM provider, i.e. the party that operates the models (e.g. NIMs, OpenAI, Anthropic).
    The factory produces an `LLMProviderInfo`.
    """

    build_fn: LLMProviderRegisteredCallableT = Field(repr=False)
203
+
204
+
205
class RegisteredAuthProviderInfo(RegisteredInfo[AuthProviderBaseConfig]):
    """
    Registry entry for an authentication provider, the component that carries out the authentication process.
    The factory produces an `AuthProviderBase`.
    """

    build_fn: AuthProviderRegisteredCallableT = Field(repr=False)
211
+
212
+
213
class RegisteredLLMClientInfo(RegisteredInfo[LLMBaseConfig]):
    """
    Registry entry for an LLM client: the framework-specific object that talks to an LLM provider. Each entry binds
    one provider config type to one LLM framework, named by `llm_framework`.
    """

    llm_framework: str
    build_fn: LLMClientRegisteredCallableT = Field(repr=False)
221
+
222
+
223
class RegisteredEmbedderProviderInfo(RegisteredInfo[EmbedderBaseConfig]):
    """
    Registry entry for an embedder provider, i.e. the party that operates the embedding models (e.g. NIMs, OpenAI).
    The factory produces an `EmbedderProviderInfo`.
    """

    build_fn: EmbedderProviderRegisteredCallableT = Field(repr=False)
230
+
231
+
232
class RegisteredEmbedderClientInfo(RegisteredInfo[EmbedderBaseConfig]):
    """
    Registry entry for an embedder client: the framework-specific object that talks to an embedder provider. Each
    entry binds one provider config type to one LLM framework, named by `llm_framework`.
    """

    llm_framework: str
    build_fn: EmbedderClientRegisteredCallableT = Field(repr=False)
240
+
241
+
242
class RegisteredEvaluatorInfo(RegisteredInfo[EvaluatorBaseConfig]):
    """
    Registry entry for an evaluator (for example a RAG or trajectory evaluator). Its factory receives an
    `EvalBuilder` rather than a plain `Builder`.
    """

    build_fn: EvaluatorRegisteredCallableT = Field(repr=False)
248
+
249
+
250
class RegisteredMemoryInfo(RegisteredInfo[MemoryBaseConfig]):
    """
    Registry entry for a memory backend; the factory produces an object implementing the `MemoryEditor` interface.
    """

    build_fn: MemoryRegisteredCallableT = Field(repr=False)
256
+
257
+
258
class RegisteredObjectStoreInfo(RegisteredInfo[ObjectStoreBaseConfig]):
    """
    Registry entry for an object store; the factory produces an object implementing the `ObjectStore` interface.
    """

    build_fn: ObjectStoreRegisteredCallableT = Field(repr=False)
264
+
265
+
266
class RegisteredTTCStrategyInfo(RegisteredInfo[TTCStrategyBaseConfig]):
    """
    Registry entry for a test-time-compute (TTC) strategy; the factory produces a `StrategyBase`.
    """

    build_fn: TTCStrategyRegisterCallableT = Field(repr=False)
272
+
273
+
274
class RegisteredToolWrapper(BaseModel):
    """
    Registry entry for a tool wrapper, which adapts workflow functions to one particular LLM framework. Unlike the
    other entries it has no configuration class of its own; its factory takes a name, a `Function` and a `Builder`.
    """

    llm_framework: str
    build_fn: ToolWrapperBuildCallableT = Field(repr=False)
    discovery_metadata: DiscoveryMetadata
283
+
284
+
285
class RegisteredRetrieverProviderInfo(RegisteredInfo[RetrieverBaseConfig]):
    """
    Registry entry for a retriever provider, which adheres to the retriever interface. The factory produces a
    `RetrieverProviderInfo`.
    """

    build_fn: RetrieverProviderRegisteredCallableT = Field(repr=False)
291
+
292
+
293
class RegisteredRetrieverClientInfo(RegisteredInfo[RetrieverBaseConfig]):
    """
    Registry entry for a retriever client: the LLM-framework-specific interface exposed over a retriever.
    `llm_framework` is required but may be `None`.
    """
    llm_framework: str | None
    build_fn: RetrieverClientRegisteredCallableT = Field(repr=False)
300
+
301
+
302
class RegisteredRegistryHandlerInfo(RegisteredInfo[RegistryHandlerBaseConfig]):
    """
    Represents a registered registry handler. Registry handlers are built from a `RegistryHandlerBaseConfig` alone
    (no builder is involved) and produce an `AbstractRegistryHandler`.
    """

    build_fn: RegistryHandlerRegisteredCallableT = Field(repr=False)
309
+
310
+
311
class RegisteredPackage(BaseModel):
    """
    Associates a package name with the discovery metadata recorded for it.
    """

    package_name: str
    discovery_metadata: DiscoveryMetadata
314
+
315
+
316
+ class TypeRegistry:
317
+
318
def __init__(self) -> None:
    """Create an empty registry: one lookup table per component kind plus the change-hook bookkeeping."""

    # --- Observability: telemetry exporters and logging methods, keyed by config class -------------------------
    self._registered_telemetry_exporters: dict[type[TelemetryExporterBaseConfig], RegisteredTelemetryExporter] = {}
    self._registered_logging_methods: dict[type[LoggingBaseConfig], RegisteredLoggingMethod] = {}

    # --- Workflow surface: front ends, functions and function groups -------------------------------------------
    self._registered_front_end_infos: dict[type[FrontEndBaseConfig], RegisteredFrontEndInfo] = {}
    self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
    self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}

    # --- LLMs: providers, plus clients indexed both by provider config and by framework name -------------------
    self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
    self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
    self._llm_client_framework_to_provider: dict[str, dict[type[LLMBaseConfig], RegisteredLLMClientInfo]] = {}

    # --- Authentication providers ------------------------------------------------------------------------------
    self._registered_auth_provider_infos: dict[type[AuthProviderBaseConfig], RegisteredAuthProviderInfo] = {}

    # --- Embedders: same two-way client index as the LLMs ------------------------------------------------------
    self._registered_embedder_provider_infos: dict[type[EmbedderBaseConfig], RegisteredEmbedderProviderInfo] = {}
    self._embedder_client_provider_to_framework: dict[type[EmbedderBaseConfig],
                                                      dict[str, RegisteredEmbedderClientInfo]] = {}
    self._embedder_client_framework_to_provider: dict[str, dict[type[EmbedderBaseConfig],
                                                                RegisteredEmbedderClientInfo]] = {}

    # --- Evaluators, memory backends and object stores ---------------------------------------------------------
    self._registered_evaluator_infos: dict[type[EvaluatorBaseConfig], RegisteredEvaluatorInfo] = {}
    self._registered_memory_infos: dict[type[MemoryBaseConfig], RegisteredMemoryInfo] = {}
    self._registered_object_store_infos: dict[type[ObjectStoreBaseConfig], RegisteredObjectStoreInfo] = {}

    # --- Retrievers: the client framework key may be None ------------------------------------------------------
    self._registered_retriever_provider_infos: dict[type[RetrieverBaseConfig], RegisteredRetrieverProviderInfo] = {}
    self._retriever_client_provider_to_framework: dict[type[RetrieverBaseConfig],
                                                       dict[str | None, RegisteredRetrieverClientInfo]] = {}
    self._retriever_client_framework_to_provider: dict[str | None, dict[type[RetrieverBaseConfig],
                                                                        RegisteredRetrieverClientInfo]] = {}

    # --- Registry handlers, tool wrappers (keyed by a string), TTC strategies and packages ---------------------
    self._registered_registry_handler_infos: dict[type[RegistryHandlerBaseConfig], RegisteredRegistryHandlerInfo] = {}
    self._registered_tool_wrappers: dict[str, RegisteredToolWrapper] = {}
    self._registered_ttc_strategies: dict[type[TTCStrategyBaseConfig], RegisteredTTCStrategyInfo] = {}
    self._registered_packages: dict[str, RegisteredPackage] = {}

    # --- Change notification: callbacks run by `_registration_changed` unless paused ---------------------------
    self._registration_changed_hooks: list[Callable[[], None]] = []
    self._registration_changed_hooks_active: bool = True

    # NOTE(review): not used in the visible part of this class; purpose to be confirmed.
    self._registered_channel_map = {}
384
+
385
def _registration_changed(self):
    """Run every registration-changed hook, unless notifications are currently paused."""

    if self._registration_changed_hooks_active:
        logger.debug("Registration changed. Notifying hooks.")

        for notify in self._registration_changed_hooks:
            notify()

def add_registration_changed_hook(self, cb: Callable[[], typing.Any]) -> None:
    """Subscribe `cb`; it is called with no arguments each time the registrations change."""

    hooks = self._registration_changed_hooks
    hooks.append(cb)
398
+
399
@contextmanager
def pause_registration_changed_hooks(self):
    """
    Context manager that suppresses registration-changed notifications while its block runs.

    On exit, the previous hook state is restored and the hooks are run once, so a burst of registrations produces a
    single notification. Pauses may be nested. Leaving an inner pause restores the still-paused state instead of
    re-enabling hooks early, so only the outermost pause actually notifies.
    """
    previously_active = self._registration_changed_hooks_active
    self._registration_changed_hooks_active = False

    try:
        yield
    finally:
        self._registration_changed_hooks_active = previously_active

        # Announce whatever was registered during the pause. When this pause is nested inside another, the flag is
        # still False here and `_registration_changed` is a no-op, leaving the notification to the outer pause.
        self._registration_changed()
411
+
412
def register_telemetry_exporter(self, registration: RegisteredTelemetryExporter):
    """
    Add a telemetry exporter and notify the registration-changed hooks.

    Raises:
        ValueError: If an exporter is already registered for the same config type.
    """
    exporters = self._registered_telemetry_exporters
    key = registration.config_type

    if key in exporters:
        raise ValueError(f"A telemetry exporter with the same config type `{key}` has already "
                         "been registered.")

    exporters[key] = registration
    self._registration_changed()

def get_telemetry_exporter(self, config_type: type[TelemetryExporterBaseConfig]) -> RegisteredTelemetryExporter:
    """
    Return the telemetry exporter registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    exporters = self._registered_telemetry_exporters
    try:
        return exporters[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered telemetry exporter for config `{config_type}`. "
                       f"Registered configs: {set(exporters)}") from exc

def get_registered_telemetry_exporters(self) -> list[RegisteredInfo[TelemetryExporterBaseConfig]]:
    """Return all registered telemetry exporters in registration order."""
    return [*self._registered_telemetry_exporters.values()]
433
+
434
def register_logging_method(self, registration: RegisteredLoggingMethod):
    """
    Add a logging method and notify the registration-changed hooks.

    Raises:
        ValueError: If a logging method is already registered for the same config type.
    """
    methods = self._registered_logging_methods
    key = registration.config_type

    if key in methods:
        raise ValueError(f"A logging method with the same config type `{key}` has already "
                         "been registered.")

    methods[key] = registration
    self._registration_changed()

def get_logging_method(self, config_type: type[LoggingBaseConfig]) -> RegisteredLoggingMethod:
    """
    Return the logging method registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the known config types.
    """
    methods = self._registered_logging_methods
    try:
        return methods[config_type]
    except KeyError as exc:
        raise KeyError(f"No logging method found for config `{config_type}`. "
                       f"Known: {set(methods)}") from exc

def get_registered_logging_method(self) -> list[RegisteredInfo[LoggingBaseConfig]]:
    """Return all registered logging methods in registration order."""
    return [*self._registered_logging_methods.values()]
454
+
455
def register_front_end(self, registration: RegisteredFrontEndInfo):
    """
    Add a front end and notify the registration-changed hooks.

    Raises:
        ValueError: If a front end is already registered for the same config type.
    """
    front_ends = self._registered_front_end_infos
    key = registration.config_type

    if key in front_ends:
        raise ValueError(f"A front end with the same config type `{key}` has already been "
                         "registered.")

    front_ends[key] = registration
    self._registration_changed()

def get_front_end(self, config_type: type[FrontEndBaseConfig]) -> RegisteredFrontEndInfo:
    """
    Return the front end registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    front_ends = self._registered_front_end_infos
    try:
        return front_ends[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered front end for config `{config_type}`. "
                       f"Registered configs: {set(front_ends)}") from exc

def get_registered_front_ends(self) -> list[RegisteredInfo[FrontEndBaseConfig]]:
    """Return all registered front ends in registration order."""
    return [*self._registered_front_end_infos.values()]
476
+
477
def register_function(self, registration: RegisteredFunctionInfo):
    """
    Add a function and notify the registration-changed hooks.

    Raises:
        ValueError: If a function is already registered for the same config type.
    """
    functions = self._registered_functions
    key = registration.config_type

    if key in functions:
        raise ValueError(f"A function with the same config type `{key}` has already been "
                         "registered.")

    functions[key] = registration
    self._registration_changed()

def get_function(self, config_type: type[FunctionBaseConfig]) -> RegisteredFunctionInfo:
    """
    Return the function registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    functions = self._registered_functions
    try:
        return functions[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered function for config `{config_type}`. "
                       f"Registered configs: {set(functions)}") from exc

def get_registered_functions(self) -> list[RegisteredInfo[FunctionBaseConfig]]:
    """Return all registered functions in registration order."""
    return [*self._registered_functions.values()]
498
+
499
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
    """
    Add a function group and notify the registration-changed hooks.

    Args:
        registration: The function group entry to store, keyed by its config type.

    Raises:
        ValueError: If a function group is already registered for the same config type.
    """
    groups = self._registered_function_groups
    key = registration.config_type

    if key in groups:
        raise ValueError(f"A function group with the same config type `{key}` has already been "
                         "registered.")

    groups[key] = registration
    self._registration_changed()

def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
    """
    Look up a function group by its config type.

    Args:
        config_type: The function group configuration class.

    Returns:
        RegisteredFunctionGroupInfo: The matching registry entry.

    Raises:
        KeyError: If no function group is registered for `config_type`.
    """
    groups = self._registered_function_groups
    try:
        return groups[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
                       f"Registered configs: {set(groups)}") from exc

def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
    """
    List every registered function group.

    Returns:
        list[RegisteredInfo[FunctionGroupBaseConfig]]: The entries, in registration order.
    """
    return [*self._registered_function_groups.values()]
542
+
543
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
    """
    Add an LLM provider and notify the registration-changed hooks.

    Raises:
        ValueError: If a provider is already registered for the same config type.
    """
    providers = self._registered_llm_provider_infos
    key = info.config_type

    if key in providers:
        raise ValueError(f"An LLM provider with the same config type `{key}` has already been registered.")

    providers[key] = info
    self._registration_changed()

def get_llm_provider(self, config_type: type[LLMBaseConfig]) -> RegisteredLLMProviderInfo:
    """
    Return the LLM provider registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    providers = self._registered_llm_provider_infos
    try:
        return providers[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered LLM provider for config `{config_type}`. "
                       f"Registered configs: {set(providers)}") from exc

def get_registered_llm_providers(self) -> list[RegisteredInfo[LLMBaseConfig]]:
    """Return all registered LLM providers in registration order."""
    return [*self._registered_llm_provider_infos.values()]
563
+
564
def register_auth_provider(self, info: RegisteredAuthProviderInfo):
    """
    Add an authentication provider and notify the registration-changed hooks.

    Raises:
        ValueError: If a provider is already registered for the same config type.
    """
    providers = self._registered_auth_provider_infos
    key = info.config_type

    if key in providers:
        raise ValueError(f"An Authentication Provider with the same config type `{key}` has already been "
                         "registered.")

    providers[key] = info
    self._registration_changed()

def get_auth_provider(self, config_type: type[AuthProviderBaseConfig]) -> RegisteredAuthProviderInfo:
    """
    Return the authentication provider registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    providers = self._registered_auth_provider_infos
    try:
        return providers[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered Authentication Provider for config `{config_type}`. "
                       f"Registered configs: {set(providers)}") from exc

def get_registered_auth_providers(self) -> list[RegisteredInfo[AuthProviderBaseConfig]]:
    """Return all registered authentication providers in registration order."""
    return [*self._registered_auth_provider_infos.values()]
584
+
585
def register_llm_client(self, info: RegisteredLLMClientInfo):
    """
    Register an LLM client that adapts the provider `info.config_type` to the framework `info.llm_framework`.

    The entry is indexed both by provider (then framework) and by framework (then provider), and the
    registration-changed hooks are notified.

    Raises:
        ValueError: If a client for the same provider/framework pair is already registered.
    """
    provider = info.config_type
    framework = info.llm_framework

    if framework in self._llm_client_provider_to_framework.get(provider, {}):
        raise ValueError(f"An LLM client with the same config type `{info.config_type}` "
                         f"and LLM framework `{info.llm_framework}` has already been registered.")

    self._llm_client_provider_to_framework.setdefault(provider, {})[framework] = info
    self._llm_client_framework_to_provider.setdefault(framework, {})[provider] = info

    self._registration_changed()

def get_llm_client(self, config_type: type[LLMBaseConfig], wrapper_type: str) -> RegisteredLLMClientInfo:
    """
    Return the client that converts the LLM provider `config_type` into the framework `wrapper_type`.

    Raises:
        KeyError: If no such conversion is registered. The message lists the providers that *do* have a client for
            `wrapper_type`, so the user can pick a compatible LLM configuration.
    """
    try:
        return self._llm_client_provider_to_framework[config_type][wrapper_type]
    except KeyError as err:
        # Suggest only providers that can actually be converted to the requested framework. Listing every provider
        # (as before) included ones, possibly `config_type` itself, that would fail in exactly the same way.
        supported_providers = set(self._llm_client_framework_to_provider.get(wrapper_type, {}))
        raise KeyError(f"An invalid LLM config and wrapper combination was supplied. Config: `{config_type}`, "
                       f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} LLM client but "
                       f"there is no registered conversion from that LLM provider to LLM framework: "
                       f"{wrapper_type}. "
                       f"Please provide an LLM configuration from one of the following providers: "
                       f"{supported_providers}") from err
610
+
611
def register_embedder_provider(self, info: RegisteredEmbedderProviderInfo):
    """
    Add an embedder provider and notify the registration-changed hooks.

    Raises:
        ValueError: If a provider is already registered for the same config type.
    """
    providers = self._registered_embedder_provider_infos
    key = info.config_type

    if key in providers:
        raise ValueError(f"An Embedder provider with the same config type `{key}` has already been "
                         "registered.")

    providers[key] = info
    self._registration_changed()

def get_embedder_provider(self, config_type: type[EmbedderBaseConfig]) -> RegisteredEmbedderProviderInfo:
    """
    Return the embedder provider registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    providers = self._registered_embedder_provider_infos
    try:
        return providers[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered Embedder provider for config `{config_type}`. "
                       f"Registered configs: {set(providers)}") from exc

def get_registered_embedder_providers(self) -> list[RegisteredInfo[EmbedderBaseConfig]]:
    """Return all registered embedder providers in registration order."""
    return [*self._registered_embedder_provider_infos.values()]
632
+
633
def register_embedder_client(self, info: RegisteredEmbedderClientInfo):
    """
    Register an embedder client that adapts the provider `info.config_type` to the framework `info.llm_framework`.

    The entry is indexed both by provider (then framework) and by framework (then provider), and the
    registration-changed hooks are notified.

    Raises:
        ValueError: If a client for the same provider/framework pair is already registered.
    """
    provider = info.config_type
    framework = info.llm_framework

    if framework in self._embedder_client_provider_to_framework.get(provider, {}):
        # Name the framework too: the same config type may legitimately have clients for other frameworks.
        raise ValueError(f"An Embedder client with the same config type `{info.config_type}` "
                         f"and LLM framework `{info.llm_framework}` has already been registered.")

    self._embedder_client_provider_to_framework.setdefault(provider, {})[framework] = info
    self._embedder_client_framework_to_provider.setdefault(framework, {})[provider] = info

    self._registration_changed()

def get_embedder_client(self, config_type: type[EmbedderBaseConfig],
                        wrapper_type: str) -> RegisteredEmbedderClientInfo:
    """
    Return the client that converts the embedder provider `config_type` into the framework `wrapper_type`.

    Raises:
        KeyError: If no such conversion is registered. The message lists the providers that *do* have a client for
            `wrapper_type`.
    """
    try:
        return self._embedder_client_provider_to_framework[config_type][wrapper_type]
    except KeyError as err:
        # Suggest only providers that can actually be converted to the requested framework.
        supported_providers = set(self._embedder_client_framework_to_provider.get(wrapper_type, {}))
        raise KeyError(
            f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
            f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
            f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
            "Please provide an Embedder configuration from one of the following providers: "
            f"{supported_providers}") from err
659
+
660
def register_evaluator(self, info: RegisteredEvaluatorInfo):
    """
    Add an evaluator and notify the registration-changed hooks.

    Raises:
        ValueError: If an evaluator is already registered for the same config type.
    """
    evaluators = self._registered_evaluator_infos
    key = info.config_type

    if key in evaluators:
        raise ValueError(f"An Evaluator with the same config type `{key}` has already been "
                         "registered.")

    evaluators[key] = info
    self._registration_changed()

def get_evaluator(self, config_type: type[EvaluatorBaseConfig]) -> RegisteredEvaluatorInfo:
    """
    Return the evaluator registered for `config_type`.

    Raises:
        KeyError: If none is registered; the message lists the registered config types.
    """
    evaluators = self._registered_evaluator_infos
    try:
        return evaluators[config_type]
    except KeyError as exc:
        raise KeyError(f"Could not find a registered Evaluator for config `{config_type}`. "
                       f"Registered configs: {set(evaluators)}") from exc

def get_registered_evaluators(self) -> list[RegisteredInfo[EvaluatorBaseConfig]]:
    """Return all registered evaluators in registration order."""
    return [*self._registered_evaluator_infos.values()]
681
+
682
+ def register_memory(self, info: RegisteredMemoryInfo):
683
+
684
+ if (info.config_type in self._registered_memory_infos):
685
+ raise ValueError(
686
+ f"A Memory client with the same config type `{info.config_type}` has already been registered.")
687
+
688
+ self._registered_memory_infos[info.config_type] = info
689
+
690
+ self._registration_changed()
691
+
692
def get_memory(self, config_type: type[MemoryBaseConfig]) -> RegisteredMemoryInfo:
    """Return the memory-client registration stored for ``config_type``.

    Raises:
        KeyError: If no memory client is registered for ``config_type``; the message lists the
            registered config types.
    """
    memory_infos = self._registered_memory_infos
    try:
        registered = memory_infos[config_type]
    except KeyError as err:
        raise KeyError(f"Could not find a registered Memory client for config `{config_type}`. "
                       f"Registered configs: {set(memory_infos.keys())}") from err
    else:
        return registered
699
+
700
def get_registered_memorys(self) -> list[RegisteredInfo[MemoryBaseConfig]]:
    """Return a new list holding every registered memory-client record."""
    memory_infos = self._registered_memory_infos
    return [*memory_infos.values()]
703
+
704
def register_object_store(self, info: RegisteredObjectStoreInfo):
    """Register an object store under its config type.

    Args:
        info: Registration record; ``info.config_type`` is used as the lookup key.

    Raises:
        ValueError: If an object store is already registered for ``info.config_type``.
    """
    store_infos = self._registered_object_store_infos
    already_registered = info.config_type in store_infos
    if already_registered:
        raise ValueError(f"An Object Store with the same config type `{info.config_type}` has already been "
                         "registered.")

    store_infos[info.config_type] = info
    self._registration_changed()
713
+
714
def get_object_store(self, config_type: type[ObjectStoreBaseConfig]) -> RegisteredObjectStoreInfo:
    """Return the object-store registration stored for ``config_type``.

    Raises:
        KeyError: If no object store is registered for ``config_type``; the message lists the
            registered config types.
    """
    store_infos = self._registered_object_store_infos
    try:
        found = store_infos[config_type]
    except KeyError as err:
        raise KeyError(f"Could not find a registered Object Store for config `{config_type}`. "
                       f"Registered configs: {set(store_infos.keys())}") from err
    return found
721
+
722
def get_registered_object_stores(self) -> list[RegisteredInfo[ObjectStoreBaseConfig]]:
    """Return a new list holding every registered object-store record."""
    return [*self._registered_object_store_infos.values()]
725
+
726
def register_retriever_provider(self, info: RegisteredRetrieverProviderInfo):
    """Register a retriever provider under its config type.

    Args:
        info: Registration record; ``info.config_type`` is used as the lookup key.

    Raises:
        ValueError: If a retriever provider is already registered for ``info.config_type``.
    """
    provider_infos = self._registered_retriever_provider_infos
    if info.config_type not in provider_infos:
        provider_infos[info.config_type] = info
        self._registration_changed()
        return

    raise ValueError(
        f"A Retriever provider with the same config type `{info.config_type}` has already been registered")
735
+
736
def get_retriever_provider(self, config_type: type[RetrieverBaseConfig]) -> RegisteredRetrieverProviderInfo:
    """Return the retriever-provider registration stored for ``config_type``.

    Raises:
        KeyError: If no retriever provider is registered for ``config_type``; the message lists
            the registered config types.
    """
    provider_infos = self._registered_retriever_provider_infos
    try:
        registered = provider_infos[config_type]
    except KeyError as err:
        raise KeyError(f"Could not find a registered Retriever provider for config `{config_type}`. "
                       f"Registered configs: {set(provider_infos.keys())}") from err
    else:
        return registered
743
+
744
def get_registered_retriever_providers(self) -> list[RegisteredInfo[RetrieverBaseConfig]]:
    """Return a new list holding every registered retriever-provider record."""
    provider_infos = self._registered_retriever_provider_infos
    return [*provider_infos.values()]
747
+
748
def register_retriever_client(self, info: RegisteredRetrieverClientInfo):
    """Register a retriever client for one (retriever config type, LLM framework) pair.

    The registration is indexed in both directions: config type -> framework -> info, and
    framework -> config type -> info.

    Args:
        info: Registration record carrying ``config_type`` and ``llm_framework``.

    Raises:
        ValueError: If a client is already registered for the same config type and LLM framework.
    """
    if (info.config_type in self._retriever_client_provider_to_framework
            and info.llm_framework in self._retriever_client_provider_to_framework[info.config_type]):
        # Both halves of the message must be f-strings; otherwise the framework placeholder is
        # emitted literally instead of being interpolated.
        raise ValueError(f"A Retriever client with the same config type `{info.config_type}` "
                         f"and LLM framework `{info.llm_framework}` has already been registered.")

    self._retriever_client_provider_to_framework.setdefault(info.config_type, {})[info.llm_framework] = info
    self._retriever_client_framework_to_provider.setdefault(info.llm_framework, {})[info.config_type] = info

    self._registration_changed()
759
+
760
def get_retriever_client(self, config_type: type[RetrieverBaseConfig],
                         wrapper_type: str | None) -> RegisteredRetrieverClientInfo:
    """Return the retriever client registered for ``config_type`` and framework ``wrapper_type``.

    Raises:
        KeyError: If either the config type or the framework has no registration; the message
            lists the config types that do have retriever clients.
    """
    by_provider = self._retriever_client_provider_to_framework
    try:
        return by_provider[config_type][wrapper_type]
    except KeyError as err:
        raise KeyError(
            f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
            f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
            f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
            "Please provide a Retriever configuration from one of the following providers: "
            f"{set(by_provider.keys())}") from err
774
+
775
def register_tool_wrapper(self, registration: RegisteredToolWrapper):
    """Register the tool wrapper for an LLM framework.

    Args:
        registration: Wrapper record; ``registration.llm_framework`` is used as the lookup key.

    Raises:
        ValueError: If a wrapper is already registered for that LLM framework.
    """
    wrappers = self._registered_tool_wrappers
    framework = registration.llm_framework
    if framework in wrappers:
        raise ValueError(f"A tool wrapper for the LLM framework `{registration.llm_framework}` has already been "
                         "registered.")

    wrappers[framework] = registration
    self._registration_changed()
784
+
785
def get_tool_wrapper(self, llm_framework: str) -> RegisteredToolWrapper:
    """Return the tool wrapper registered for ``llm_framework``.

    Raises:
        KeyError: If no wrapper is registered for ``llm_framework``; the message lists the
            registered frameworks.
    """
    wrappers = self._registered_tool_wrappers
    try:
        wrapper = wrappers[llm_framework]
    except KeyError as err:
        raise KeyError(f"Could not find a registered tool wrapper for LLM framework `{llm_framework}`. "
                       f"Registered LLM frameworks: {set(wrappers.keys())}") from err
    return wrapper
792
+
793
def register_ttc_strategy(self, info: RegisteredTTCStrategyInfo):
    """Register a TTC strategy under its config type.

    Args:
        info: Registration record; ``info.config_type`` is used as the lookup key.

    Raises:
        ValueError: If a TTC strategy is already registered for ``info.config_type``.
    """
    strategies = self._registered_ttc_strategies
    if info.config_type not in strategies:
        strategies[info.config_type] = info
        self._registration_changed()
        return

    raise ValueError(
        f"An TTC strategy with the same config type `{info.config_type}` has already been registered.")
801
+
802
def get_ttc_strategy(self, config_type: type[TTCStrategyBaseConfig]) -> RegisteredTTCStrategyInfo:
    """Return the TTC strategy registration stored for ``config_type``.

    Raises:
        KeyError: If no TTC strategy is registered for ``config_type``; the message lists the
            registered config types.
    """
    try:
        return self._registered_ttc_strategies[config_type]
    except KeyError as err:
        # Catch only the missing-key case. A broad ``except Exception`` would re-label unrelated
        # failures (e.g. an unhashable ``config_type``) as "not registered".
        raise KeyError(f"Could not find a registered TTC strategy for config `{config_type}`. "
                       f"Registered configs: {set(self._registered_ttc_strategies.keys())}") from err
808
+
809
def get_registered_ttc_strategies(self) -> list[RegisteredInfo[TTCStrategyBaseConfig]]:
    """Return a new list holding every registered TTC strategy record."""
    return [*self._registered_ttc_strategies.values()]
811
+
812
def register_registry_handler(self, info: RegisteredRegistryHandlerInfo):
    """Register a registry handler under its config type and its channel (static) type name.

    The handler is stored twice: keyed by ``info.config_type`` and keyed by
    ``info.config_type.static_type()`` in the channel map.

    Args:
        info: Registration record for the registry handler.

    Raises:
        ValueError: If a registry handler is already registered for ``info.config_type``.
    """
    # Check the registry-handler table itself. Checking an unrelated table would let duplicate
    # handlers silently overwrite each other.
    if info.config_type in self._registered_registry_handler_infos:
        raise ValueError(
            f"A Registry Handler with the same config type `{info.config_type}` has already been registered.")

    self._registered_registry_handler_infos[info.config_type] = info
    self._registered_channel_map[info.config_type.static_type()] = info

    self._registration_changed()
822
+
823
def get_registry_handler(self, config_type: type[RegistryHandlerBaseConfig]) -> RegisteredRegistryHandlerInfo:
    """Return the registry-handler registration stored for ``config_type``.

    Raises:
        KeyError: If no registry handler is registered for ``config_type``; the message lists
            the registered config types.
    """
    handler_infos = self._registered_registry_handler_infos
    try:
        handler = handler_infos[config_type]
    except KeyError as err:
        raise KeyError(f"Could not find a registered Registry Handler for config `{config_type}`. "
                       f"Registered configs: {set(handler_infos.keys())}") from err
    else:
        return handler
830
+
831
def get_registered_registry_handlers(self) -> list[RegisteredInfo[RegistryHandlerBaseConfig]]:
    """Return a new list holding every registered registry-handler record."""
    handler_infos = self._registered_registry_handler_infos
    return [*handler_infos.values()]
834
+
835
def register_package(self, package_name: str, package_version: str | None = None):
    """Record discovery metadata for a package, keyed by the package record's name.

    Unlike the component ``register_*`` methods, an existing entry for the same name is
    overwritten rather than rejected.

    Args:
        package_name: Name of the package to describe.
        package_version: Optional version forwarded to ``DiscoveryMetadata.from_package_name``.
    """
    metadata = DiscoveryMetadata.from_package_name(package_name=package_name, package_version=package_version)
    registered = RegisteredPackage(discovery_metadata=metadata, package_name=package_name)

    self._registered_packages[registered.package_name] = registered
    self._registration_changed()
843
+
844
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
    """Return the registrations stored for ``component_type``.

    Client component types (LLM, embedder and retriever clients) are stored nested as
    ``{config_type: {llm_framework: info}}``. For those, a new flat dict keyed by each info's
    ``discovery_metadata.component_name`` is built. Every other supported type returns the
    registry's internal dict itself, not a copy.

    Raises:
        ValueError: If ``component_type`` is not a supported component type.
    """
    # Component types backed by a single flat dict attribute. Names are resolved lazily with
    # getattr so only the requested attribute is ever read.
    flat_attributes = (
        (ComponentEnum.FRONT_END, "_registered_front_end_infos"),
        (ComponentEnum.AUTHENTICATION_PROVIDER, "_registered_auth_provider_infos"),
        (ComponentEnum.FUNCTION, "_registered_functions"),
        (ComponentEnum.FUNCTION_GROUP, "_registered_function_groups"),
        (ComponentEnum.TOOL_WRAPPER, "_registered_tool_wrappers"),
        (ComponentEnum.LLM_PROVIDER, "_registered_llm_provider_infos"),
        (ComponentEnum.EMBEDDER_PROVIDER, "_registered_embedder_provider_infos"),
        (ComponentEnum.RETRIEVER_PROVIDER, "_registered_retriever_provider_infos"),
        (ComponentEnum.EVALUATOR, "_registered_evaluator_infos"),
        (ComponentEnum.MEMORY, "_registered_memory_infos"),
        (ComponentEnum.OBJECT_STORE, "_registered_object_store_infos"),
        (ComponentEnum.REGISTRY_HANDLER, "_registered_registry_handler_infos"),
        (ComponentEnum.LOGGING, "_registered_logging_methods"),
        (ComponentEnum.TRACING, "_registered_telemetry_exporters"),
        (ComponentEnum.PACKAGE, "_registered_packages"),
        (ComponentEnum.TTC_STRATEGY, "_registered_ttc_strategies"),
    )
    # Component types backed by a nested provider -> framework -> info mapping.
    nested_attributes = (
        (ComponentEnum.LLM_CLIENT, "_llm_client_provider_to_framework"),
        (ComponentEnum.EMBEDDER_CLIENT, "_embedder_client_provider_to_framework"),
        (ComponentEnum.RETRIEVER_CLIENT, "_retriever_client_provider_to_framework"),
    )

    for candidate, attribute in flat_attributes:
        if component_type == candidate:
            return getattr(self, attribute)

    for candidate, attribute in nested_attributes:
        if component_type == candidate:
            provider_to_framework = getattr(self, attribute)
            # When two infos share a component name, the one visited later wins.
            return {
                info.discovery_metadata.component_name: info
                for by_framework in provider_to_framework.values()
                for info in by_framework.values()
            }

    raise ValueError(f"Supplied an unsupported component type {component_type}")
916
+
917
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
    """Return the registered type names for ``component_type``.

    - Provider-style components report each registered config type's ``static_type()``.
    - Client components (LLM, embedder and retriever clients) report each registration's
      ``discovery_metadata.component_name``.
    - Tool wrappers report their LLM framework keys; packages report their package-name keys.

    Raises:
        ValueError: If ``component_type`` is not a supported component type.
    """

    def _client_names(provider_to_framework: dict) -> list[str]:
        # Flatten {config_type: {llm_framework: info}} into one plain name per registration.
        # Names are appended bare (not wrapped in single-element lists) to honour list[str].
        return [
            info.discovery_metadata.component_name
            for by_framework in provider_to_framework.values()
            for info in by_framework.values()
        ]

    if component_type == ComponentEnum.FUNCTION:
        return [i.static_type() for i in self._registered_functions]

    if component_type == ComponentEnum.FUNCTION_GROUP:
        return [i.static_type() for i in self._registered_function_groups]

    if component_type == ComponentEnum.TOOL_WRAPPER:
        return list(self._registered_tool_wrappers)

    if component_type == ComponentEnum.LLM_PROVIDER:
        return [i.static_type() for i in self._registered_llm_provider_infos]

    if component_type == ComponentEnum.LLM_CLIENT:
        return _client_names(self._llm_client_provider_to_framework)

    if component_type == ComponentEnum.EMBEDDER_PROVIDER:
        return [i.static_type() for i in self._registered_embedder_provider_infos]

    if component_type == ComponentEnum.EMBEDDER_CLIENT:
        return _client_names(self._embedder_client_provider_to_framework)

    # Retriever and object-store components are supported by get_infos_by_type; report them
    # here too instead of raising.
    if component_type == ComponentEnum.RETRIEVER_PROVIDER:
        return [i.static_type() for i in self._registered_retriever_provider_infos]

    if component_type == ComponentEnum.RETRIEVER_CLIENT:
        return _client_names(self._retriever_client_provider_to_framework)

    if component_type == ComponentEnum.EVALUATOR:
        return [i.static_type() for i in self._registered_evaluator_infos]

    if component_type == ComponentEnum.MEMORY:
        return [i.static_type() for i in self._registered_memory_infos]

    if component_type == ComponentEnum.OBJECT_STORE:
        return [i.static_type() for i in self._registered_object_store_infos]

    if component_type == ComponentEnum.REGISTRY_HANDLER:
        return [i.static_type() for i in self._registered_registry_handler_infos]

    if component_type == ComponentEnum.LOGGING:
        return [i.static_type() for i in self._registered_logging_methods]

    if component_type == ComponentEnum.TRACING:
        return [i.static_type() for i in self._registered_telemetry_exporters]

    if component_type == ComponentEnum.PACKAGE:
        return list(self._registered_packages)

    if component_type == ComponentEnum.TTC_STRATEGY:
        return [i.static_type() for i in self._registered_ttc_strategies]

    raise ValueError(f"Supplied an unsupported component type {component_type}")
970
+
971
def get_registered_channel_info_by_channel_type(self, channel_type: str) -> RegisteredRegistryHandlerInfo:
    """Return the registry handler registered for a channel type name.

    ``channel_type`` is matched against the keys of the channel map. ``register_registry_handler``
    populates that map with each handler config type's ``static_type()``.

    Raises:
        KeyError: If no registry handler is registered for ``channel_type``; the message lists
            the registered channel types.
    """
    try:
        return self._registered_channel_map[channel_type]
    except KeyError as err:
        # Match the sibling getters: explain the miss instead of surfacing a bare key.
        raise KeyError(f"Could not find a registered Registry Handler for channel type `{channel_type}`. "
                       f"Registered channel types: {set(self._registered_channel_map.keys())}") from err
973
+
974
+ def _do_compute_annotation(self, cls: type[TypedBaseModelT], registrations: list[RegisteredInfo[TypedBaseModelT]]):
975
+
976
+ while (len(registrations) < 2):
977
+ registrations.append(RegisteredInfo[TypedBaseModelT](full_type=f"_ignore/{len(registrations)}",
978
+ config_type=cls))
979
+
980
+ short_names: dict[str, int] = {}
981
+ type_list: list[tuple[str, type[TypedBaseModelT]]] = []
982
+
983
+ # For all keys in the list, split the key by / and increment the count of the last element
984
+ for key in registrations:
985
+ short_names[key.local_name] = short_names.get(key.local_name, 0) + 1
986
+
987
+ type_list.append((key.full_type, key.config_type))
988
+
989
+ # Now loop again and if the short name is unique, then create two entries, for the short and full name
990
+ for key in registrations:
991
+
992
+ if (short_names[key.local_name] == 1):
993
+ type_list.append((key.local_name, key.config_type))
994
+
995
+ return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
996
+
997
def compute_annotation(self, cls: type[TypedBaseModelT]):
    """Return the tagged-union annotation covering every registered config in ``cls``'s family.

    ``cls`` is tested against the supported base config classes in a fixed order. The first base
    it subclasses selects which registrations are passed to ``_do_compute_annotation``.

    Raises:
        ValueError: If ``cls`` does not subclass any supported base config class.
    """
    # Ordered (base config class, registration-getter method name) pairs. Order matters because
    # the first matching base wins; a getter is looked up only after its base matches.
    families = (
        (AuthProviderBaseConfig, "get_registered_auth_providers"),
        (EmbedderBaseConfig, "get_registered_embedder_providers"),
        (EvaluatorBaseConfig, "get_registered_evaluators"),
        (FrontEndBaseConfig, "get_registered_front_ends"),
        (FunctionBaseConfig, "get_registered_functions"),
        (FunctionGroupBaseConfig, "get_registered_function_groups"),
        (LLMBaseConfig, "get_registered_llm_providers"),
        (MemoryBaseConfig, "get_registered_memorys"),
        (ObjectStoreBaseConfig, "get_registered_object_stores"),
        (RegistryHandlerBaseConfig, "get_registered_registry_handlers"),
        (RetrieverBaseConfig, "get_registered_retriever_providers"),
        (TelemetryExporterBaseConfig, "get_registered_telemetry_exporters"),
        (LoggingBaseConfig, "get_registered_logging_method"),
        (TTCStrategyBaseConfig, "get_registered_ttc_strategies"),
    )

    for base_config, getter_name in families:
        if issubclass(cls, base_config):
            return self._do_compute_annotation(cls, getattr(self, getter_name)())

    raise ValueError(f"Supplied an unsupported component type {cls}")
1042
+
1043
+
1044
class GlobalTypeRegistry:
    """Process-wide holder for the active ``TypeRegistry``.

    ``push()`` temporarily installs a deep copy of the active registry. Registrations made inside
    the ``with`` block therefore apply to the copy and are dropped when the block exits.
    """

    _global_registry: TypeRegistry = TypeRegistry()

    @staticmethod
    def get() -> TypeRegistry:
        """Return the registry that is currently active."""
        return GlobalTypeRegistry._global_registry

    @staticmethod
    @contextmanager
    def push():
        """Make a deep copy of the active registry the active one for a ``with`` block.

        Yields:
            The copied registry, which stays active until the block exits.

        On exit, whether normal or via an exception, the previous registry is reinstated and its
        ``_registration_changed()`` is called. Presumably this lets change hooks resync with the
        restored state (the hook mechanism itself is not visible here).
        """
        outer = GlobalTypeRegistry._global_registry
        scoped = deepcopy(outer)

        try:
            GlobalTypeRegistry._global_registry = scoped
            yield scoped
        finally:
            GlobalTypeRegistry._global_registry = outer
            outer._registration_changed()
1066
+
1067
+
1068
def _rebuild_config_annotations():
    """Rebuild ``Config``'s computed annotations; used as a registration-changed hook."""
    Config.rebuild_annotations()


# Finally, update the Config object each time the registry changes
GlobalTypeRegistry.get().add_registration_changed_hook(_rebuild_config_annotations)