nvidia-nat 1.2.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. aiq/agent/__init__.py +0 -0
  2. aiq/agent/base.py +239 -0
  3. aiq/agent/dual_node.py +67 -0
  4. aiq/agent/react_agent/__init__.py +0 -0
  5. aiq/agent/react_agent/agent.py +355 -0
  6. aiq/agent/react_agent/output_parser.py +104 -0
  7. aiq/agent/react_agent/prompt.py +41 -0
  8. aiq/agent/react_agent/register.py +149 -0
  9. aiq/agent/reasoning_agent/__init__.py +0 -0
  10. aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
  11. aiq/agent/register.py +23 -0
  12. aiq/agent/rewoo_agent/__init__.py +0 -0
  13. aiq/agent/rewoo_agent/agent.py +411 -0
  14. aiq/agent/rewoo_agent/prompt.py +108 -0
  15. aiq/agent/rewoo_agent/register.py +158 -0
  16. aiq/agent/tool_calling_agent/__init__.py +0 -0
  17. aiq/agent/tool_calling_agent/agent.py +119 -0
  18. aiq/agent/tool_calling_agent/register.py +106 -0
  19. aiq/authentication/__init__.py +14 -0
  20. aiq/authentication/api_key/__init__.py +14 -0
  21. aiq/authentication/api_key/api_key_auth_provider.py +96 -0
  22. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  23. aiq/authentication/api_key/register.py +26 -0
  24. aiq/authentication/exceptions/__init__.py +14 -0
  25. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  26. aiq/authentication/http_basic_auth/__init__.py +0 -0
  27. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  28. aiq/authentication/http_basic_auth/register.py +30 -0
  29. aiq/authentication/interfaces.py +93 -0
  30. aiq/authentication/oauth2/__init__.py +14 -0
  31. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  32. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  33. aiq/authentication/oauth2/register.py +25 -0
  34. aiq/authentication/register.py +21 -0
  35. aiq/builder/__init__.py +0 -0
  36. aiq/builder/builder.py +285 -0
  37. aiq/builder/component_utils.py +316 -0
  38. aiq/builder/context.py +264 -0
  39. aiq/builder/embedder.py +24 -0
  40. aiq/builder/eval_builder.py +161 -0
  41. aiq/builder/evaluator.py +29 -0
  42. aiq/builder/framework_enum.py +24 -0
  43. aiq/builder/front_end.py +73 -0
  44. aiq/builder/function.py +344 -0
  45. aiq/builder/function_base.py +380 -0
  46. aiq/builder/function_info.py +627 -0
  47. aiq/builder/intermediate_step_manager.py +174 -0
  48. aiq/builder/llm.py +25 -0
  49. aiq/builder/retriever.py +25 -0
  50. aiq/builder/user_interaction_manager.py +74 -0
  51. aiq/builder/workflow.py +148 -0
  52. aiq/builder/workflow_builder.py +1117 -0
  53. aiq/cli/__init__.py +14 -0
  54. aiq/cli/cli_utils/__init__.py +0 -0
  55. aiq/cli/cli_utils/config_override.py +231 -0
  56. aiq/cli/cli_utils/validation.py +37 -0
  57. aiq/cli/commands/__init__.py +0 -0
  58. aiq/cli/commands/configure/__init__.py +0 -0
  59. aiq/cli/commands/configure/channel/__init__.py +0 -0
  60. aiq/cli/commands/configure/channel/add.py +28 -0
  61. aiq/cli/commands/configure/channel/channel.py +36 -0
  62. aiq/cli/commands/configure/channel/remove.py +30 -0
  63. aiq/cli/commands/configure/channel/update.py +30 -0
  64. aiq/cli/commands/configure/configure.py +33 -0
  65. aiq/cli/commands/evaluate.py +139 -0
  66. aiq/cli/commands/info/__init__.py +14 -0
  67. aiq/cli/commands/info/info.py +39 -0
  68. aiq/cli/commands/info/list_channels.py +32 -0
  69. aiq/cli/commands/info/list_components.py +129 -0
  70. aiq/cli/commands/info/list_mcp.py +213 -0
  71. aiq/cli/commands/registry/__init__.py +14 -0
  72. aiq/cli/commands/registry/publish.py +88 -0
  73. aiq/cli/commands/registry/pull.py +118 -0
  74. aiq/cli/commands/registry/registry.py +38 -0
  75. aiq/cli/commands/registry/remove.py +108 -0
  76. aiq/cli/commands/registry/search.py +155 -0
  77. aiq/cli/commands/sizing/__init__.py +14 -0
  78. aiq/cli/commands/sizing/calc.py +297 -0
  79. aiq/cli/commands/sizing/sizing.py +27 -0
  80. aiq/cli/commands/start.py +246 -0
  81. aiq/cli/commands/uninstall.py +81 -0
  82. aiq/cli/commands/validate.py +47 -0
  83. aiq/cli/commands/workflow/__init__.py +14 -0
  84. aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  85. aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
  86. aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  87. aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
  88. aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  89. aiq/cli/commands/workflow/workflow.py +37 -0
  90. aiq/cli/commands/workflow/workflow_commands.py +313 -0
  91. aiq/cli/entrypoint.py +135 -0
  92. aiq/cli/main.py +44 -0
  93. aiq/cli/register_workflow.py +488 -0
  94. aiq/cli/type_registry.py +1000 -0
  95. aiq/data_models/__init__.py +14 -0
  96. aiq/data_models/api_server.py +694 -0
  97. aiq/data_models/authentication.py +231 -0
  98. aiq/data_models/common.py +171 -0
  99. aiq/data_models/component.py +54 -0
  100. aiq/data_models/component_ref.py +168 -0
  101. aiq/data_models/config.py +406 -0
  102. aiq/data_models/dataset_handler.py +123 -0
  103. aiq/data_models/discovery_metadata.py +335 -0
  104. aiq/data_models/embedder.py +27 -0
  105. aiq/data_models/evaluate.py +127 -0
  106. aiq/data_models/evaluator.py +26 -0
  107. aiq/data_models/front_end.py +26 -0
  108. aiq/data_models/function.py +30 -0
  109. aiq/data_models/function_dependencies.py +72 -0
  110. aiq/data_models/interactive.py +246 -0
  111. aiq/data_models/intermediate_step.py +302 -0
  112. aiq/data_models/invocation_node.py +38 -0
  113. aiq/data_models/llm.py +27 -0
  114. aiq/data_models/logging.py +26 -0
  115. aiq/data_models/memory.py +27 -0
  116. aiq/data_models/object_store.py +44 -0
  117. aiq/data_models/profiler.py +54 -0
  118. aiq/data_models/registry_handler.py +26 -0
  119. aiq/data_models/retriever.py +30 -0
  120. aiq/data_models/retry_mixin.py +35 -0
  121. aiq/data_models/span.py +187 -0
  122. aiq/data_models/step_adaptor.py +64 -0
  123. aiq/data_models/streaming.py +33 -0
  124. aiq/data_models/swe_bench_model.py +54 -0
  125. aiq/data_models/telemetry_exporter.py +26 -0
  126. aiq/data_models/ttc_strategy.py +30 -0
  127. aiq/embedder/__init__.py +0 -0
  128. aiq/embedder/langchain_client.py +41 -0
  129. aiq/embedder/nim_embedder.py +59 -0
  130. aiq/embedder/openai_embedder.py +43 -0
  131. aiq/embedder/register.py +24 -0
  132. aiq/eval/__init__.py +14 -0
  133. aiq/eval/config.py +60 -0
  134. aiq/eval/dataset_handler/__init__.py +0 -0
  135. aiq/eval/dataset_handler/dataset_downloader.py +106 -0
  136. aiq/eval/dataset_handler/dataset_filter.py +52 -0
  137. aiq/eval/dataset_handler/dataset_handler.py +254 -0
  138. aiq/eval/evaluate.py +506 -0
  139. aiq/eval/evaluator/__init__.py +14 -0
  140. aiq/eval/evaluator/base_evaluator.py +73 -0
  141. aiq/eval/evaluator/evaluator_model.py +45 -0
  142. aiq/eval/intermediate_step_adapter.py +99 -0
  143. aiq/eval/rag_evaluator/__init__.py +0 -0
  144. aiq/eval/rag_evaluator/evaluate.py +178 -0
  145. aiq/eval/rag_evaluator/register.py +143 -0
  146. aiq/eval/register.py +23 -0
  147. aiq/eval/remote_workflow.py +133 -0
  148. aiq/eval/runners/__init__.py +14 -0
  149. aiq/eval/runners/config.py +39 -0
  150. aiq/eval/runners/multi_eval_runner.py +54 -0
  151. aiq/eval/runtime_event_subscriber.py +52 -0
  152. aiq/eval/swe_bench_evaluator/__init__.py +0 -0
  153. aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. aiq/eval/swe_bench_evaluator/register.py +36 -0
  155. aiq/eval/trajectory_evaluator/__init__.py +0 -0
  156. aiq/eval/trajectory_evaluator/evaluate.py +75 -0
  157. aiq/eval/trajectory_evaluator/register.py +40 -0
  158. aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. aiq/eval/tunable_rag_evaluator/register.py +52 -0
  161. aiq/eval/usage_stats.py +41 -0
  162. aiq/eval/utils/__init__.py +0 -0
  163. aiq/eval/utils/output_uploader.py +140 -0
  164. aiq/eval/utils/tqdm_position_registry.py +40 -0
  165. aiq/eval/utils/weave_eval.py +184 -0
  166. aiq/experimental/__init__.py +0 -0
  167. aiq/experimental/decorators/__init__.py +0 -0
  168. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  169. aiq/experimental/test_time_compute/__init__.py +0 -0
  170. aiq/experimental/test_time_compute/editing/__init__.py +0 -0
  171. aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. aiq/experimental/test_time_compute/functions/__init__.py +0 -0
  175. aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
  177. aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
  178. aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  179. aiq/experimental/test_time_compute/models/__init__.py +0 -0
  180. aiq/experimental/test_time_compute/models/editor_config.py +132 -0
  181. aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. aiq/experimental/test_time_compute/models/search_config.py +120 -0
  183. aiq/experimental/test_time_compute/models/selection_config.py +154 -0
  184. aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. aiq/experimental/test_time_compute/register.py +36 -0
  189. aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. aiq/experimental/test_time_compute/search/__init__.py +0 -0
  194. aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. aiq/experimental/test_time_compute/selection/__init__.py +0 -0
  198. aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. aiq/front_ends/__init__.py +14 -0
  204. aiq/front_ends/console/__init__.py +14 -0
  205. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  206. aiq/front_ends/console/console_front_end_config.py +32 -0
  207. aiq/front_ends/console/console_front_end_plugin.py +96 -0
  208. aiq/front_ends/console/register.py +25 -0
  209. aiq/front_ends/cron/__init__.py +14 -0
  210. aiq/front_ends/fastapi/__init__.py +14 -0
  211. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
  215. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
  218. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. aiq/front_ends/fastapi/job_store.py +183 -0
  222. aiq/front_ends/fastapi/main.py +72 -0
  223. aiq/front_ends/fastapi/message_handler.py +298 -0
  224. aiq/front_ends/fastapi/message_validator.py +345 -0
  225. aiq/front_ends/fastapi/register.py +25 -0
  226. aiq/front_ends/fastapi/response_helpers.py +195 -0
  227. aiq/front_ends/fastapi/step_adaptor.py +321 -0
  228. aiq/front_ends/mcp/__init__.py +14 -0
  229. aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
  230. aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
  231. aiq/front_ends/mcp/register.py +27 -0
  232. aiq/front_ends/mcp/tool_converter.py +242 -0
  233. aiq/front_ends/register.py +22 -0
  234. aiq/front_ends/simple_base/__init__.py +14 -0
  235. aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  236. aiq/llm/__init__.py +0 -0
  237. aiq/llm/aws_bedrock_llm.py +57 -0
  238. aiq/llm/nim_llm.py +46 -0
  239. aiq/llm/openai_llm.py +46 -0
  240. aiq/llm/register.py +23 -0
  241. aiq/llm/utils/__init__.py +14 -0
  242. aiq/llm/utils/env_config_value.py +94 -0
  243. aiq/llm/utils/error.py +17 -0
  244. aiq/memory/__init__.py +20 -0
  245. aiq/memory/interfaces.py +183 -0
  246. aiq/memory/models.py +112 -0
  247. aiq/meta/module_to_distro.json +3 -0
  248. aiq/meta/pypi.md +58 -0
  249. aiq/object_store/__init__.py +20 -0
  250. aiq/object_store/in_memory_object_store.py +76 -0
  251. aiq/object_store/interfaces.py +84 -0
  252. aiq/object_store/models.py +36 -0
  253. aiq/object_store/register.py +20 -0
  254. aiq/observability/__init__.py +14 -0
  255. aiq/observability/exporter/__init__.py +14 -0
  256. aiq/observability/exporter/base_exporter.py +449 -0
  257. aiq/observability/exporter/exporter.py +78 -0
  258. aiq/observability/exporter/file_exporter.py +33 -0
  259. aiq/observability/exporter/processing_exporter.py +322 -0
  260. aiq/observability/exporter/raw_exporter.py +52 -0
  261. aiq/observability/exporter/span_exporter.py +265 -0
  262. aiq/observability/exporter_manager.py +335 -0
  263. aiq/observability/mixin/__init__.py +14 -0
  264. aiq/observability/mixin/batch_config_mixin.py +26 -0
  265. aiq/observability/mixin/collector_config_mixin.py +23 -0
  266. aiq/observability/mixin/file_mixin.py +288 -0
  267. aiq/observability/mixin/file_mode.py +23 -0
  268. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  269. aiq/observability/mixin/serialize_mixin.py +61 -0
  270. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  271. aiq/observability/processor/__init__.py +14 -0
  272. aiq/observability/processor/batching_processor.py +310 -0
  273. aiq/observability/processor/callback_processor.py +42 -0
  274. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  275. aiq/observability/processor/processor.py +71 -0
  276. aiq/observability/register.py +96 -0
  277. aiq/observability/utils/__init__.py +14 -0
  278. aiq/observability/utils/dict_utils.py +236 -0
  279. aiq/observability/utils/time_utils.py +31 -0
  280. aiq/plugins/.namespace +1 -0
  281. aiq/profiler/__init__.py +0 -0
  282. aiq/profiler/calc/__init__.py +14 -0
  283. aiq/profiler/calc/calc_runner.py +627 -0
  284. aiq/profiler/calc/calculations.py +288 -0
  285. aiq/profiler/calc/data_models.py +188 -0
  286. aiq/profiler/calc/plot.py +345 -0
  287. aiq/profiler/callbacks/__init__.py +0 -0
  288. aiq/profiler/callbacks/agno_callback_handler.py +295 -0
  289. aiq/profiler/callbacks/base_callback_class.py +20 -0
  290. aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. aiq/profiler/callbacks/token_usage_base_model.py +27 -0
  294. aiq/profiler/data_frame_row.py +51 -0
  295. aiq/profiler/data_models.py +24 -0
  296. aiq/profiler/decorators/__init__.py +0 -0
  297. aiq/profiler/decorators/framework_wrapper.py +131 -0
  298. aiq/profiler/decorators/function_tracking.py +254 -0
  299. aiq/profiler/forecasting/__init__.py +0 -0
  300. aiq/profiler/forecasting/config.py +18 -0
  301. aiq/profiler/forecasting/model_trainer.py +75 -0
  302. aiq/profiler/forecasting/models/__init__.py +22 -0
  303. aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. aiq/profiler/forecasting/models/linear_model.py +196 -0
  305. aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
  306. aiq/profiler/inference_metrics_model.py +28 -0
  307. aiq/profiler/inference_optimization/__init__.py +0 -0
  308. aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. aiq/profiler/inference_optimization/data_models.py +386 -0
  312. aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. aiq/profiler/inference_optimization/llm_metrics.py +212 -0
  316. aiq/profiler/inference_optimization/prompt_caching.py +163 -0
  317. aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. aiq/profiler/intermediate_property_adapter.py +102 -0
  320. aiq/profiler/profile_runner.py +473 -0
  321. aiq/profiler/utils.py +184 -0
  322. aiq/registry_handlers/__init__.py +0 -0
  323. aiq/registry_handlers/local/__init__.py +0 -0
  324. aiq/registry_handlers/local/local_handler.py +176 -0
  325. aiq/registry_handlers/local/register_local.py +37 -0
  326. aiq/registry_handlers/metadata_factory.py +60 -0
  327. aiq/registry_handlers/package_utils.py +567 -0
  328. aiq/registry_handlers/pypi/__init__.py +0 -0
  329. aiq/registry_handlers/pypi/pypi_handler.py +251 -0
  330. aiq/registry_handlers/pypi/register_pypi.py +40 -0
  331. aiq/registry_handlers/register.py +21 -0
  332. aiq/registry_handlers/registry_handler_base.py +157 -0
  333. aiq/registry_handlers/rest/__init__.py +0 -0
  334. aiq/registry_handlers/rest/register_rest.py +56 -0
  335. aiq/registry_handlers/rest/rest_handler.py +237 -0
  336. aiq/registry_handlers/schemas/__init__.py +0 -0
  337. aiq/registry_handlers/schemas/headers.py +42 -0
  338. aiq/registry_handlers/schemas/package.py +68 -0
  339. aiq/registry_handlers/schemas/publish.py +63 -0
  340. aiq/registry_handlers/schemas/pull.py +82 -0
  341. aiq/registry_handlers/schemas/remove.py +36 -0
  342. aiq/registry_handlers/schemas/search.py +91 -0
  343. aiq/registry_handlers/schemas/status.py +47 -0
  344. aiq/retriever/__init__.py +0 -0
  345. aiq/retriever/interface.py +37 -0
  346. aiq/retriever/milvus/__init__.py +14 -0
  347. aiq/retriever/milvus/register.py +81 -0
  348. aiq/retriever/milvus/retriever.py +228 -0
  349. aiq/retriever/models.py +74 -0
  350. aiq/retriever/nemo_retriever/__init__.py +14 -0
  351. aiq/retriever/nemo_retriever/register.py +60 -0
  352. aiq/retriever/nemo_retriever/retriever.py +190 -0
  353. aiq/retriever/register.py +22 -0
  354. aiq/runtime/__init__.py +14 -0
  355. aiq/runtime/loader.py +215 -0
  356. aiq/runtime/runner.py +190 -0
  357. aiq/runtime/session.py +158 -0
  358. aiq/runtime/user_metadata.py +130 -0
  359. aiq/settings/__init__.py +0 -0
  360. aiq/settings/global_settings.py +318 -0
  361. aiq/test/.namespace +1 -0
  362. aiq/tool/__init__.py +0 -0
  363. aiq/tool/chat_completion.py +74 -0
  364. aiq/tool/code_execution/README.md +151 -0
  365. aiq/tool/code_execution/__init__.py +0 -0
  366. aiq/tool/code_execution/code_sandbox.py +267 -0
  367. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. aiq/tool/code_execution/register.py +74 -0
  374. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. aiq/tool/code_execution/utils.py +100 -0
  376. aiq/tool/datetime_tools.py +42 -0
  377. aiq/tool/document_search.py +141 -0
  378. aiq/tool/github_tools/__init__.py +0 -0
  379. aiq/tool/github_tools/create_github_commit.py +133 -0
  380. aiq/tool/github_tools/create_github_issue.py +87 -0
  381. aiq/tool/github_tools/create_github_pr.py +106 -0
  382. aiq/tool/github_tools/get_github_file.py +106 -0
  383. aiq/tool/github_tools/get_github_issue.py +166 -0
  384. aiq/tool/github_tools/get_github_pr.py +256 -0
  385. aiq/tool/github_tools/update_github_issue.py +100 -0
  386. aiq/tool/mcp/__init__.py +14 -0
  387. aiq/tool/mcp/exceptions.py +142 -0
  388. aiq/tool/mcp/mcp_client.py +255 -0
  389. aiq/tool/mcp/mcp_tool.py +96 -0
  390. aiq/tool/memory_tools/__init__.py +0 -0
  391. aiq/tool/memory_tools/add_memory_tool.py +79 -0
  392. aiq/tool/memory_tools/delete_memory_tool.py +67 -0
  393. aiq/tool/memory_tools/get_memory_tool.py +72 -0
  394. aiq/tool/nvidia_rag.py +95 -0
  395. aiq/tool/register.py +38 -0
  396. aiq/tool/retriever.py +89 -0
  397. aiq/tool/server_tools.py +66 -0
  398. aiq/utils/__init__.py +0 -0
  399. aiq/utils/data_models/__init__.py +0 -0
  400. aiq/utils/data_models/schema_validator.py +58 -0
  401. aiq/utils/debugging_utils.py +43 -0
  402. aiq/utils/dump_distro_mapping.py +32 -0
  403. aiq/utils/exception_handlers/__init__.py +0 -0
  404. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  405. aiq/utils/exception_handlers/mcp.py +211 -0
  406. aiq/utils/exception_handlers/schemas.py +114 -0
  407. aiq/utils/io/__init__.py +0 -0
  408. aiq/utils/io/model_processing.py +28 -0
  409. aiq/utils/io/yaml_tools.py +119 -0
  410. aiq/utils/log_utils.py +37 -0
  411. aiq/utils/metadata_utils.py +74 -0
  412. aiq/utils/optional_imports.py +142 -0
  413. aiq/utils/producer_consumer_queue.py +178 -0
  414. aiq/utils/reactive/__init__.py +0 -0
  415. aiq/utils/reactive/base/__init__.py +0 -0
  416. aiq/utils/reactive/base/observable_base.py +65 -0
  417. aiq/utils/reactive/base/observer_base.py +55 -0
  418. aiq/utils/reactive/base/subject_base.py +79 -0
  419. aiq/utils/reactive/observable.py +59 -0
  420. aiq/utils/reactive/observer.py +76 -0
  421. aiq/utils/reactive/subject.py +131 -0
  422. aiq/utils/reactive/subscription.py +49 -0
  423. aiq/utils/settings/__init__.py +0 -0
  424. aiq/utils/settings/global_settings.py +197 -0
  425. aiq/utils/string_utils.py +38 -0
  426. aiq/utils/type_converter.py +290 -0
  427. aiq/utils/type_utils.py +484 -0
  428. aiq/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
  430. nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
  433. nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
  434. nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,37 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC
17
+ from abc import abstractmethod
18
+
19
+ from aiq.retriever.models import RetrieverOutput
20
+
21
+
22
class AIQRetriever(ABC):
    """
    Abstract interface for interacting with data stores.

    A Retriever is responsible for retrieving data from a configured data store.

    Implementations may integrate with vector stores or other indexing backends that allow for text-based search.
    """

    @abstractmethod
    async def search(self, query: str, **kwargs) -> RetrieverOutput:
        """
        Retrieve items from the data store that match ``query``.

        How matching is performed (e.g. vector similarity search) and how many items are returned
        (e.g. at most a ``top_k`` keyword argument) is implementation dependent.

        Args:
            query: The text to search the data store for.
            **kwargs: Implementation-specific search options.

        Returns:
            RetrieverOutput: The items retrieved from the data store.

        Raises:
            NotImplementedError: Always, in this abstract base implementation; subclasses must override it.
        """
        raise NotImplementedError
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,81 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+ from pydantic import HttpUrl
18
+
19
+ from aiq.builder.builder import Builder
20
+ from aiq.builder.builder import LLMFrameworkEnum
21
+ from aiq.builder.retriever import RetrieverProviderInfo
22
+ from aiq.cli.register_workflow import register_retriever_client
23
+ from aiq.cli.register_workflow import register_retriever_provider
24
+ from aiq.data_models.retriever import RetrieverBaseConfig
25
+
26
+
27
class MilvusRetrieverConfig(RetrieverBaseConfig, name="milvus_retriever"):
    """
    Settings for a retriever that reads documents from a Milvus vector database service.

    Two fields are populated through aliases: ``content_field`` uses ``primary_field`` and
    ``description`` uses ``collection_description``.
    """

    # --- Connection ---------------------------------------------------------------------------
    uri: HttpUrl = Field(description="The uri of Milvus service")
    connection_args: dict = Field(
        default={},
        description="Dictionary of arguments used to connect to and authenticate with the Milvus service",
    )

    # --- Query embedding ----------------------------------------------------------------------
    embedding_model: str = Field(description="The name of the embedding model to use for vectorizing the query")

    # --- Search defaults ----------------------------------------------------------------------
    # A value of None here means "not configured".
    collection_name: str | None = Field(
        default=None,
        description="The name of the milvus collection to search",
    )
    content_field: str = Field(
        default="text",
        alias="primary_field",
        description="Name of the primary field to store/retrieve",
    )
    top_k: int | None = Field(
        default=None,
        gt=0,
        description="The number of results to return",
    )
    output_fields: list[str] | None = Field(
        default=None,
        description="A list of fields to return from the datastore. If 'None', all fields but the vector are returned.",
    )
    search_params: dict = Field(
        default={"metric_type": "L2"},
        description="Search parameters to use when performing vector search",
    )
    vector_field: str = Field(
        default="vector",
        description="Name of the field to compare with the vectorized query",
    )

    # --- Tool metadata ------------------------------------------------------------------------
    description: str | None = Field(
        default=None,
        alias="collection_description",
        description="If present it will be used as the tool description",
    )
51
+
52
+
53
@register_retriever_provider(config_type=MilvusRetrieverConfig)
async def milvus_retriever(retriever_config: MilvusRetrieverConfig, builder: Builder):
    """
    Register the Milvus retriever provider.

    Args:
        retriever_config: The Milvus retriever configuration to wrap.
        builder: The workflow builder (not used by this provider).

    Yields:
        RetrieverProviderInfo: Provider metadata wrapping ``retriever_config``.
    """
    yield RetrieverProviderInfo(config=retriever_config,
                                description="An adapter for a Milvus data store to use with a Retriever Client")
57
+
58
+
59
@register_retriever_client(config_type=MilvusRetrieverConfig, wrapper_type=None)
async def milvus_retriever_client(config: MilvusRetrieverConfig, builder: Builder):
    """
    Build a ``MilvusRetriever`` from ``config`` and yield it to the builder.

    The LangChain-wrapped embedder named by ``config.embedding_model`` is fetched from ``builder``, and a
    ``MilvusClient`` is opened against ``config.uri`` using ``config.connection_args``. Any of
    ``collection_name``, ``top_k``, ``output_fields``, ``search_params`` and ``vector_field`` that are not
    ``None`` in the config are bound onto the retriever as default search arguments.

    The Milvus client is closed when the builder finishes with this component (generator teardown),
    or if building the retriever fails.

    Args:
        config: The Milvus retriever configuration.
        builder: The workflow builder used to resolve the embedder.

    Yields:
        MilvusRetriever: The configured retriever.
    """
    # Imported lazily so pymilvus is only required when this retriever is actually used.
    from pymilvus import MilvusClient

    from aiq.retriever.milvus.retriever import MilvusRetriever

    embedder = await builder.get_embedder(embedder_name=config.embedding_model, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    milvus_client = MilvusClient(uri=str(config.uri), **config.connection_args)
    try:
        retriever = MilvusRetriever(
            client=milvus_client,
            embedder=embedder,
            content_field=config.content_field,
        )

        # Using parameters in the config to set default values which can be overridden during the function call.
        optional_fields = ("collection_name", "top_k", "output_fields", "search_params", "vector_field")
        model_dict = config.model_dump()
        optional_args = {field: model_dict[field] for field in optional_fields if model_dict[field] is not None}

        retriever.bind(**optional_args)

        yield retriever
    finally:
        # Release the connection held by the client once the component is torn down (or setup fails).
        milvus_client.close()
@@ -0,0 +1,228 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from functools import partial
18
+
19
+ from langchain_core.embeddings import Embeddings
20
+ from pymilvus import MilvusClient
21
+ from pymilvus.client.abstract import Hit
22
+
23
+ from aiq.retriever.interface import AIQRetriever
24
+ from aiq.retriever.models import AIQDocument
25
+ from aiq.retriever.models import RetrieverError
26
+ from aiq.retriever.models import RetrieverOutput
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class CollectionNotFoundError(RetrieverError):
    """Raised by MilvusRetriever when the requested collection is not in the client's list of collections."""
33
+
34
+
35
class MilvusRetriever(AIQRetriever):
    """
    Client for retrieving document chunks from a Milvus vectorstore
    """

    def __init__(
        self,
        client: MilvusClient,
        embedder: Embeddings,
        content_field: str = "text",
        use_iterator: bool = False,
    ) -> None:
        """
        Initialize the Milvus Retriever using a preconfigured MilvusClient

        Args:
            client (MilvusClient): Preinstantiated pymilvus.MilvusClient object.
            embedder (Embeddings): Embedder whose ``embed_query`` vectorizes each search query.
            content_field (str): Collection field used as the document text. Defaults to "text".
            use_iterator (bool): If True, search with ``MilvusClient.search_iterator`` instead of
                ``MilvusClient.search``. Defaults to False.

        Raises:
            ValueError: If ``use_iterator`` is True but the client has no ``search_iterator`` method.
        """
        self._client = client
        self._embedder = embedder

        if use_iterator and "search_iterator" not in dir(self._client):
            raise ValueError("This version of the pymilvus.MilvusClient does not support the search iterator.")

        self._search_func = self._search if not use_iterator else self._search_with_iterator
        # NOTE(review): not read anywhere in this class; kept in case external code relies on the attribute.
        self._default_params = None
        # Names of keyword arguments bound via bind(); consulted by get_unbound_params().
        self._bound_params = []
        self.content_field = content_field
        logger.info("Milvus Retriever using %s for search.", self._search_func.__name__)

    def bind(self, **kwargs) -> None:
        """
        Bind default values to the search method. Cannot bind the 'query' parameter.

        Repeated calls accumulate: each call layers another ``functools.partial`` over the search function, so a later
        binding of the same key overrides an earlier one, and keywords passed to ``search`` override both.

        Args:
            kwargs (dict): Key value pairs corresponding to the default values of search parameters.
        """
        if "query" in kwargs:
            logger.warning("Ignoring attempt to bind the 'query' parameter; it must be passed to every search.")
            kwargs = {k: v for k, v in kwargs.items() if k != "query"}
        self._search_func = partial(self._search_func, **kwargs)
        # Accumulate (de-duplicated, order-preserving) instead of overwriting: the partials above accumulate too,
        # so earlier bindings remain in effect and must stay reported as bound.
        self._bound_params = list(dict.fromkeys([*self._bound_params, *kwargs]))
        logger.debug("Binding parameters for search function: %s", kwargs)

    def get_unbound_params(self) -> list[str]:
        """
        Returns a list of unbound parameters which will need to be passed to the search function.
        """
        return [param for param in ["query", "collection_name", "top_k", "filters"] if param not in self._bound_params]

    def _validate_collection(self, collection_name: str) -> bool:
        """Return True if ``collection_name`` appears in the client's ``list_collections()``."""
        return collection_name in self._client.list_collections()

    async def search(self, query: str, **kwargs):
        """
        Search for document chunks relevant to ``query``.

        Args:
            query (str): Query text; it is embedded with the configured embedder.
            kwargs: Search arguments such as ``collection_name``, ``top_k`` or ``filters``. These override any values
                bound via ``bind``.

        Returns:
            RetrieverOutput: The retrieved documents.
        """
        return await self._search_func(query=query, **kwargs)

    async def _search_with_iterator(self,
                                    query: str,
                                    *,
                                    collection_name: str,
                                    top_k: int,
                                    filters: str | None = None,
                                    output_fields: list[str] | None = None,
                                    search_params: dict | None = None,
                                    timeout: float | None = None,
                                    vector_field_name: str | None = "vector",
                                    distance_cutoff: float | None = None,
                                    **kwargs):
        """
        Retrieve document chunks from a Milvus vectorstore using a search iterator, allowing for the retrieval of more
        results.

        Args:
            query (str): Query text to embed and search with.
            collection_name (str): Collection to search.
            top_k (int): Maximum number of results (passed as ``limit``).
            filters (str | None): Milvus boolean filter expression.
            output_fields (list[str] | None): Fields to return. If empty, all schema fields except the vector field are
                returned. The content field is always added.
            search_params (dict | None): Search parameters. Defaults to ``{"metric_type": "L2"}``.
            timeout (float | None): Timeout forwarded to the client.
            vector_field_name (str | None): Vector field to search against (``anns_field``).
            distance_cutoff (float | None): If set, collection stops at the first hit whose distance exceeds this
                value. The comparison treats smaller distances as better, as with the default L2 metric.
            kwargs: ``batch_size`` (default 1000), ``round_decimal`` (default -1) and ``partition_names`` are forwarded
                to ``search_iterator``. Other keys are ignored.

        Raises:
            CollectionNotFoundError: If the collection does not exist.
            RetrieverError: If iterating the results or converting them to documents fails.
        """
        logger.debug("MilvusRetriever searching query: %s, for collection: %s. Returning max %s results",
                     query,
                     collection_name,
                     top_k)

        if not self._validate_collection(collection_name):
            raise CollectionNotFoundError(f"Collection: {collection_name} does not exist")

        # If no output fields are specified, return all of them
        if not output_fields:
            collection_schema = self._client.describe_collection(collection_name)
            output_fields = [
                field["name"] for field in collection_schema.get("fields", []) if field["name"] != vector_field_name
            ]

        # The content field is required to build each AIQDocument. Build a new list rather than appending so that a
        # list bound via bind() (or owned by the caller) is never mutated.
        if self.content_field not in output_fields:
            output_fields = [*output_fields, self.content_field]

        # NOTE(review): embed_query is synchronous and runs on the event loop; consider the async variant.
        search_vector = self._embedder.embed_query(query)

        search_iterator = self._client.search_iterator(
            collection_name=collection_name,
            data=[search_vector],
            batch_size=kwargs.get("batch_size", 1000),
            filter=filters,
            limit=top_k,
            output_fields=output_fields,
            search_params=search_params if search_params else {"metric_type": "L2"},
            timeout=timeout,
            anns_field=vector_field_name,
            round_decimal=kwargs.get("round_decimal", -1),
            partition_names=kwargs.get("partition_names", None),
        )

        results = []
        try:
            while True:
                page = search_iterator.next()
                if len(page) == 0:
                    break

                hits = page.get_res()[0]
                # `is not None` so that an explicit cutoff of 0.0 is honoured rather than treated as "no cutoff".
                if distance_cutoff is not None and hits[-1].distance > distance_cutoff:
                    # This page crosses the cutoff: keep the hits before the first one beyond it, then stop.
                    for hit in hits:
                        if hit.distance > distance_cutoff:
                            break
                        results.append(hit)
                    break
                results.extend(hits)

            return _wrap_milvus_results(results, content_field=self.content_field)

        except Exception as e:
            logger.exception("Exception when retrieving results from milvus for query %s: %s", query, e)
            raise RetrieverError(f"Error when retrieving documents from {collection_name} for query '{query}'") from e
        finally:
            # Always close the iterator, including on the early distance-cutoff exit and on errors.
            search_iterator.close()

    async def _search(self,
                      query: str,
                      *,
                      collection_name: str,
                      top_k: int,
                      filters: str | None = None,
                      output_fields: list[str] | None = None,
                      search_params: dict | None = None,
                      timeout: float | None = None,
                      vector_field_name: str | None = "vector",
                      **kwargs):
        """
        Retrieve document chunks from a Milvus vectorstore

        Args:
            query (str): Query text to embed and search with.
            collection_name (str): Collection to search.
            top_k (int): Maximum number of results (passed as ``limit``).
            filters (str | None): Milvus boolean filter expression.
            output_fields (list[str] | None): Fields to return. If empty, all schema fields except the vector field are
                returned. The content field is always added.
            search_params (dict | None): Search parameters. Defaults to ``{"metric_type": "L2"}``.
            timeout (float | None): Timeout forwarded to the client.
            vector_field_name (str | None): Vector field to search against (``anns_field``).
            kwargs: Accepted for interface compatibility and ignored.

        Raises:
            CollectionNotFoundError: If the collection does not exist.
            ValueError: If the content field or the vector field is not part of the collection schema.
        """
        logger.debug("MilvusRetriever searching query: %s, for collection: %s. Returning max %s results",
                     query,
                     collection_name,
                     top_k)

        if not self._validate_collection(collection_name):
            raise CollectionNotFoundError(f"Collection: {collection_name} does not exist")

        available_fields = [v.get("name") for v in self._client.describe_collection(collection_name).get("fields", [])]

        if self.content_field not in available_fields:
            raise ValueError(f"The specified content field: {self.content_field} is not part of the schema.")

        if vector_field_name not in available_fields:
            raise ValueError(f"The specified vector field name: {vector_field_name} is not part of the schema.")

        # If no output fields are specified, return all of them
        if not output_fields:
            output_fields = [field for field in available_fields if field != vector_field_name]

        # Build a new list instead of appending: output_fields may be a list bound via bind() or owned by the caller,
        # and mutating it would leak the content field into every later call that shares the list.
        if self.content_field not in output_fields:
            output_fields = [*output_fields, self.content_field]

        # NOTE(review): embed_query is synchronous and runs on the event loop; consider the async variant.
        search_vector = self._embedder.embed_query(query)
        res = self._client.search(
            collection_name=collection_name,
            data=[search_vector],
            filter=filters,
            output_fields=output_fields,
            search_params=search_params if search_params else {"metric_type": "L2"},
            timeout=timeout,
            anns_field=vector_field_name,
            limit=top_k,
        )

        # A single query vector was sent, so only the first result list is relevant.
        return _wrap_milvus_results(res[0], content_field=self.content_field)
210
+
211
+
212
def _wrap_milvus_results(res: list[Hit], content_field: str):
    """
    Convert Milvus search hits into a RetrieverOutput.

    Each hit is converted with ``_wrap_milvus_single_results`` and the input order is preserved.
    """
    documents = []
    for hit in res:
        documents.append(_wrap_milvus_single_results(hit, content_field=content_field))
    return RetrieverOutput(results=documents)
214
+
215
+
216
def _wrap_milvus_single_results(res: Hit | dict, content_field: str) -> AIQDocument:
    """
    Convert one Milvus search result into an AIQDocument.

    Both pymilvus ``Hit`` objects (from the iterator path) and plain dicts with ``id``, ``distance`` and ``entity``
    keys are accepted. The content field becomes ``page_content``. Every other returned field, plus the hit's
    ``distance``, goes into ``metadata``.

    Args:
        res (Hit | dict): A single search result.
        content_field (str): Name of the field holding the document text.

    Returns:
        AIQDocument: The converted document. ``document_id`` is the result id converted to ``str``.

    Raises:
        ValueError: If ``res`` is neither a ``Hit`` nor a ``dict``.
        KeyError: If a ``Hit`` lacks the content field, or a dict lacks ``entity``/``id``.
    """
    if not isinstance(res, (Hit, dict)):
        raise ValueError(f"Milvus search returned object of type {type(res)}. Expected 'Hit' or 'dict'.")

    if isinstance(res, Hit):
        fields = res.fields
        distance = res.distance
        raw_id = res.id
        page_content = fields[content_field]
    else:
        fields = res["entity"]
        distance = res.get("distance")
        raw_id = res["id"]
        page_content = fields.get(content_field)

    metadata = {k: v for k, v in fields.items() if k != content_field}
    metadata["distance"] = distance

    # Milvus primary keys may be int64, but AIQDocument.document_id is typed `str | None` and pydantic v2 does not
    # coerce int -> str, so convert explicitly.
    document_id = None if raw_id is None else str(raw_id)
    return AIQDocument(page_content=page_content, metadata=metadata, document_id=document_id)
@@ -0,0 +1,74 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ from typing import Any
20
+
21
+ from pydantic import BaseModel
22
+ from pydantic import Field
23
+
24
+ from aiq.utils.type_converter import GlobalTypeConverter
25
+
26
+
27
class AIQDocument(BaseModel):
    """
    Object representing a retrieved document/chunk from a standard AIQ Toolkit Retriever.
    """
    # Text content of the document/chunk.
    page_content: str = Field(description="Primary content of the document to insert or retrieve")
    # Arbitrary key/value metadata attached to the document.
    metadata: dict[str, Any] = Field(description="Metadata dictionary attached to the AIQDocument")
    # Optional datastore identifier; must be a string (or None).
    document_id: str | None = Field(description="Unique ID for the document, if supported by the configured datastore",
                                    default=None)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AIQDocument:
        """
        Deserialize an AIQDocument from a dictionary representation.

        Args:
            data (dict): A dictionary containing keys
                'page_content', 'metadata', and optionally 'document_id'.

        Returns:
            AIQDocument: A reconstructed AIQDocument instance.
        """
        return cls(**data)
49
+
50
+
51
class RetrieverOutput(BaseModel):
    # Result container returned by retriever searches; the class docstring is left absent on purpose because pydantic
    # would surface it as the JSON-schema description.
    results: list[AIQDocument] = Field(description="A list of retrieved AIQDocuments")

    def __len__(self):
        """Number of documents held in ``results``."""
        document_list = self.results
        return len(document_list)

    def __str__(self):
        """JSON string of ``model_dump()`` (default ``json.dumps`` formatting)."""
        dumped = self.model_dump()
        return json.dumps(dumped)
59
+
60
+
61
class RetrieverError(Exception):
    """Base exception for retriever failures."""
63
+
64
+
65
def retriever_output_to_dict(obj: RetrieverOutput) -> dict:
    """Convert a RetrieverOutput to a plain dict using pydantic's ``model_dump``."""
    as_dict = obj.model_dump()
    return as_dict
67
+
68
+
69
def retriever_output_to_str(obj: RetrieverOutput) -> str:
    """Convert a RetrieverOutput to text via ``str()``, i.e. its ``__str__`` JSON rendering."""
    text = str(obj)
    return text
71
+
72
+
73
# Register the RetrieverOutput -> dict and RetrieverOutput -> str converters with the global type converter,
# in the same order as they are defined above.
for _converter in (retriever_output_to_dict, retriever_output_to_str):
    GlobalTypeConverter.register_converter(_converter)
del _converter
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,60 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+ from pydantic import HttpUrl
18
+
19
+ from aiq.builder.builder import Builder
20
+ from aiq.builder.retriever import RetrieverProviderInfo
21
+ from aiq.cli.register_workflow import register_retriever_client
22
+ from aiq.cli.register_workflow import register_retriever_provider
23
+ from aiq.data_models.retriever import RetrieverBaseConfig
24
+
25
+
26
class NemoRetrieverConfig(RetrieverBaseConfig, name="nemo_retriever"):
    """
    Configuration for a Retriever which pulls data from a Nemo Retriever service.
    """
    # Service location, typed as a pydantic HttpUrl.
    uri: HttpUrl = Field(description="The uri of the Nemo Retriever service.")
    # The three optional fields below are bound as search defaults by nemo_retriever_client when they are not None.
    collection_name: str | None = Field(default=None, description="The name of the collection to search")
    top_k: int | None = Field(default=None, gt=0, le=50, description="The number of results to return")
    output_fields: list[str] | None = Field(
        default=None,
        description="A list of fields to return from the datastore. If 'None', all fields but the vector are returned.",
    )
    timeout: int = Field(description="Maximum time to wait for results to be returned from the service.", default=60)
    # Falls back to the NVIDIA_API_KEY environment variable when left as None (per the field description).
    nvidia_api_key: str | None = Field(
        default=None,
        description="API key used to authenticate with the service. If 'None', will use ENV Variable 'NVIDIA_API_KEY'",
    )
41
+
42
+
43
@register_retriever_provider(config_type=NemoRetrieverConfig)
async def nemo_retriever(retriever_config: NemoRetrieverConfig, builder: Builder):
    """Register the Nemo retriever provider by yielding a RetrieverProviderInfo for the given config."""
    provider_info = RetrieverProviderInfo(
        config=retriever_config,
        description="An adapter for a Nemo data store for use with a Retriever Client",
    )
    yield provider_info
47
+
48
+
49
@register_retriever_client(config_type=NemoRetrieverConfig, wrapper_type=None)
async def nemo_retriever_client(config: NemoRetrieverConfig, builder: Builder):
    """
    Build a NemoRetriever from a NemoRetrieverConfig.

    All config values except ``type``, ``top_k`` and ``collection_name`` go to the NemoRetriever constructor. Then
    ``collection_name``, ``top_k`` and ``output_fields`` are bound as search defaults when they are not None.
    """
    from aiq.retriever.nemo_retriever.retriever import NemoRetriever

    constructor_kwargs = config.model_dump(exclude={"type", "top_k", "collection_name"})
    retriever = NemoRetriever(**constructor_kwargs)

    # A separate dump so the bound defaults never share list objects with the constructor arguments.
    config_values = config.model_dump()
    bound_defaults = {}
    for key in ("collection_name", "top_k", "output_fields"):
        value = config_values[key]
        if value is not None:
            bound_defaults[key] = value

    retriever.bind(**bound_defaults)

    yield retriever