nvidia-nat 1.2.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435) hide show
  1. aiq/agent/__init__.py +0 -0
  2. aiq/agent/base.py +239 -0
  3. aiq/agent/dual_node.py +67 -0
  4. aiq/agent/react_agent/__init__.py +0 -0
  5. aiq/agent/react_agent/agent.py +355 -0
  6. aiq/agent/react_agent/output_parser.py +104 -0
  7. aiq/agent/react_agent/prompt.py +41 -0
  8. aiq/agent/react_agent/register.py +149 -0
  9. aiq/agent/reasoning_agent/__init__.py +0 -0
  10. aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
  11. aiq/agent/register.py +23 -0
  12. aiq/agent/rewoo_agent/__init__.py +0 -0
  13. aiq/agent/rewoo_agent/agent.py +411 -0
  14. aiq/agent/rewoo_agent/prompt.py +108 -0
  15. aiq/agent/rewoo_agent/register.py +158 -0
  16. aiq/agent/tool_calling_agent/__init__.py +0 -0
  17. aiq/agent/tool_calling_agent/agent.py +119 -0
  18. aiq/agent/tool_calling_agent/register.py +106 -0
  19. aiq/authentication/__init__.py +14 -0
  20. aiq/authentication/api_key/__init__.py +14 -0
  21. aiq/authentication/api_key/api_key_auth_provider.py +96 -0
  22. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  23. aiq/authentication/api_key/register.py +26 -0
  24. aiq/authentication/exceptions/__init__.py +14 -0
  25. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  26. aiq/authentication/http_basic_auth/__init__.py +0 -0
  27. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  28. aiq/authentication/http_basic_auth/register.py +30 -0
  29. aiq/authentication/interfaces.py +93 -0
  30. aiq/authentication/oauth2/__init__.py +14 -0
  31. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  32. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  33. aiq/authentication/oauth2/register.py +25 -0
  34. aiq/authentication/register.py +21 -0
  35. aiq/builder/__init__.py +0 -0
  36. aiq/builder/builder.py +285 -0
  37. aiq/builder/component_utils.py +316 -0
  38. aiq/builder/context.py +264 -0
  39. aiq/builder/embedder.py +24 -0
  40. aiq/builder/eval_builder.py +161 -0
  41. aiq/builder/evaluator.py +29 -0
  42. aiq/builder/framework_enum.py +24 -0
  43. aiq/builder/front_end.py +73 -0
  44. aiq/builder/function.py +344 -0
  45. aiq/builder/function_base.py +380 -0
  46. aiq/builder/function_info.py +627 -0
  47. aiq/builder/intermediate_step_manager.py +174 -0
  48. aiq/builder/llm.py +25 -0
  49. aiq/builder/retriever.py +25 -0
  50. aiq/builder/user_interaction_manager.py +74 -0
  51. aiq/builder/workflow.py +148 -0
  52. aiq/builder/workflow_builder.py +1117 -0
  53. aiq/cli/__init__.py +14 -0
  54. aiq/cli/cli_utils/__init__.py +0 -0
  55. aiq/cli/cli_utils/config_override.py +231 -0
  56. aiq/cli/cli_utils/validation.py +37 -0
  57. aiq/cli/commands/__init__.py +0 -0
  58. aiq/cli/commands/configure/__init__.py +0 -0
  59. aiq/cli/commands/configure/channel/__init__.py +0 -0
  60. aiq/cli/commands/configure/channel/add.py +28 -0
  61. aiq/cli/commands/configure/channel/channel.py +36 -0
  62. aiq/cli/commands/configure/channel/remove.py +30 -0
  63. aiq/cli/commands/configure/channel/update.py +30 -0
  64. aiq/cli/commands/configure/configure.py +33 -0
  65. aiq/cli/commands/evaluate.py +139 -0
  66. aiq/cli/commands/info/__init__.py +14 -0
  67. aiq/cli/commands/info/info.py +39 -0
  68. aiq/cli/commands/info/list_channels.py +32 -0
  69. aiq/cli/commands/info/list_components.py +129 -0
  70. aiq/cli/commands/info/list_mcp.py +213 -0
  71. aiq/cli/commands/registry/__init__.py +14 -0
  72. aiq/cli/commands/registry/publish.py +88 -0
  73. aiq/cli/commands/registry/pull.py +118 -0
  74. aiq/cli/commands/registry/registry.py +38 -0
  75. aiq/cli/commands/registry/remove.py +108 -0
  76. aiq/cli/commands/registry/search.py +155 -0
  77. aiq/cli/commands/sizing/__init__.py +14 -0
  78. aiq/cli/commands/sizing/calc.py +297 -0
  79. aiq/cli/commands/sizing/sizing.py +27 -0
  80. aiq/cli/commands/start.py +246 -0
  81. aiq/cli/commands/uninstall.py +81 -0
  82. aiq/cli/commands/validate.py +47 -0
  83. aiq/cli/commands/workflow/__init__.py +14 -0
  84. aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  85. aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
  86. aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  87. aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
  88. aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  89. aiq/cli/commands/workflow/workflow.py +37 -0
  90. aiq/cli/commands/workflow/workflow_commands.py +313 -0
  91. aiq/cli/entrypoint.py +135 -0
  92. aiq/cli/main.py +44 -0
  93. aiq/cli/register_workflow.py +488 -0
  94. aiq/cli/type_registry.py +1000 -0
  95. aiq/data_models/__init__.py +14 -0
  96. aiq/data_models/api_server.py +694 -0
  97. aiq/data_models/authentication.py +231 -0
  98. aiq/data_models/common.py +171 -0
  99. aiq/data_models/component.py +54 -0
  100. aiq/data_models/component_ref.py +168 -0
  101. aiq/data_models/config.py +406 -0
  102. aiq/data_models/dataset_handler.py +123 -0
  103. aiq/data_models/discovery_metadata.py +335 -0
  104. aiq/data_models/embedder.py +27 -0
  105. aiq/data_models/evaluate.py +127 -0
  106. aiq/data_models/evaluator.py +26 -0
  107. aiq/data_models/front_end.py +26 -0
  108. aiq/data_models/function.py +30 -0
  109. aiq/data_models/function_dependencies.py +72 -0
  110. aiq/data_models/interactive.py +246 -0
  111. aiq/data_models/intermediate_step.py +302 -0
  112. aiq/data_models/invocation_node.py +38 -0
  113. aiq/data_models/llm.py +27 -0
  114. aiq/data_models/logging.py +26 -0
  115. aiq/data_models/memory.py +27 -0
  116. aiq/data_models/object_store.py +44 -0
  117. aiq/data_models/profiler.py +54 -0
  118. aiq/data_models/registry_handler.py +26 -0
  119. aiq/data_models/retriever.py +30 -0
  120. aiq/data_models/retry_mixin.py +35 -0
  121. aiq/data_models/span.py +187 -0
  122. aiq/data_models/step_adaptor.py +64 -0
  123. aiq/data_models/streaming.py +33 -0
  124. aiq/data_models/swe_bench_model.py +54 -0
  125. aiq/data_models/telemetry_exporter.py +26 -0
  126. aiq/data_models/ttc_strategy.py +30 -0
  127. aiq/embedder/__init__.py +0 -0
  128. aiq/embedder/langchain_client.py +41 -0
  129. aiq/embedder/nim_embedder.py +59 -0
  130. aiq/embedder/openai_embedder.py +43 -0
  131. aiq/embedder/register.py +24 -0
  132. aiq/eval/__init__.py +14 -0
  133. aiq/eval/config.py +60 -0
  134. aiq/eval/dataset_handler/__init__.py +0 -0
  135. aiq/eval/dataset_handler/dataset_downloader.py +106 -0
  136. aiq/eval/dataset_handler/dataset_filter.py +52 -0
  137. aiq/eval/dataset_handler/dataset_handler.py +254 -0
  138. aiq/eval/evaluate.py +506 -0
  139. aiq/eval/evaluator/__init__.py +14 -0
  140. aiq/eval/evaluator/base_evaluator.py +73 -0
  141. aiq/eval/evaluator/evaluator_model.py +45 -0
  142. aiq/eval/intermediate_step_adapter.py +99 -0
  143. aiq/eval/rag_evaluator/__init__.py +0 -0
  144. aiq/eval/rag_evaluator/evaluate.py +178 -0
  145. aiq/eval/rag_evaluator/register.py +143 -0
  146. aiq/eval/register.py +23 -0
  147. aiq/eval/remote_workflow.py +133 -0
  148. aiq/eval/runners/__init__.py +14 -0
  149. aiq/eval/runners/config.py +39 -0
  150. aiq/eval/runners/multi_eval_runner.py +54 -0
  151. aiq/eval/runtime_event_subscriber.py +52 -0
  152. aiq/eval/swe_bench_evaluator/__init__.py +0 -0
  153. aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. aiq/eval/swe_bench_evaluator/register.py +36 -0
  155. aiq/eval/trajectory_evaluator/__init__.py +0 -0
  156. aiq/eval/trajectory_evaluator/evaluate.py +75 -0
  157. aiq/eval/trajectory_evaluator/register.py +40 -0
  158. aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. aiq/eval/tunable_rag_evaluator/register.py +52 -0
  161. aiq/eval/usage_stats.py +41 -0
  162. aiq/eval/utils/__init__.py +0 -0
  163. aiq/eval/utils/output_uploader.py +140 -0
  164. aiq/eval/utils/tqdm_position_registry.py +40 -0
  165. aiq/eval/utils/weave_eval.py +184 -0
  166. aiq/experimental/__init__.py +0 -0
  167. aiq/experimental/decorators/__init__.py +0 -0
  168. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  169. aiq/experimental/test_time_compute/__init__.py +0 -0
  170. aiq/experimental/test_time_compute/editing/__init__.py +0 -0
  171. aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. aiq/experimental/test_time_compute/functions/__init__.py +0 -0
  175. aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
  177. aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
  178. aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  179. aiq/experimental/test_time_compute/models/__init__.py +0 -0
  180. aiq/experimental/test_time_compute/models/editor_config.py +132 -0
  181. aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. aiq/experimental/test_time_compute/models/search_config.py +120 -0
  183. aiq/experimental/test_time_compute/models/selection_config.py +154 -0
  184. aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. aiq/experimental/test_time_compute/register.py +36 -0
  189. aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. aiq/experimental/test_time_compute/search/__init__.py +0 -0
  194. aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. aiq/experimental/test_time_compute/selection/__init__.py +0 -0
  198. aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. aiq/front_ends/__init__.py +14 -0
  204. aiq/front_ends/console/__init__.py +14 -0
  205. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  206. aiq/front_ends/console/console_front_end_config.py +32 -0
  207. aiq/front_ends/console/console_front_end_plugin.py +96 -0
  208. aiq/front_ends/console/register.py +25 -0
  209. aiq/front_ends/cron/__init__.py +14 -0
  210. aiq/front_ends/fastapi/__init__.py +14 -0
  211. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
  215. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
  218. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. aiq/front_ends/fastapi/job_store.py +183 -0
  222. aiq/front_ends/fastapi/main.py +72 -0
  223. aiq/front_ends/fastapi/message_handler.py +298 -0
  224. aiq/front_ends/fastapi/message_validator.py +345 -0
  225. aiq/front_ends/fastapi/register.py +25 -0
  226. aiq/front_ends/fastapi/response_helpers.py +195 -0
  227. aiq/front_ends/fastapi/step_adaptor.py +321 -0
  228. aiq/front_ends/mcp/__init__.py +14 -0
  229. aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
  230. aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
  231. aiq/front_ends/mcp/register.py +27 -0
  232. aiq/front_ends/mcp/tool_converter.py +242 -0
  233. aiq/front_ends/register.py +22 -0
  234. aiq/front_ends/simple_base/__init__.py +14 -0
  235. aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  236. aiq/llm/__init__.py +0 -0
  237. aiq/llm/aws_bedrock_llm.py +57 -0
  238. aiq/llm/nim_llm.py +46 -0
  239. aiq/llm/openai_llm.py +46 -0
  240. aiq/llm/register.py +23 -0
  241. aiq/llm/utils/__init__.py +14 -0
  242. aiq/llm/utils/env_config_value.py +94 -0
  243. aiq/llm/utils/error.py +17 -0
  244. aiq/memory/__init__.py +20 -0
  245. aiq/memory/interfaces.py +183 -0
  246. aiq/memory/models.py +112 -0
  247. aiq/meta/module_to_distro.json +3 -0
  248. aiq/meta/pypi.md +58 -0
  249. aiq/object_store/__init__.py +20 -0
  250. aiq/object_store/in_memory_object_store.py +76 -0
  251. aiq/object_store/interfaces.py +84 -0
  252. aiq/object_store/models.py +36 -0
  253. aiq/object_store/register.py +20 -0
  254. aiq/observability/__init__.py +14 -0
  255. aiq/observability/exporter/__init__.py +14 -0
  256. aiq/observability/exporter/base_exporter.py +449 -0
  257. aiq/observability/exporter/exporter.py +78 -0
  258. aiq/observability/exporter/file_exporter.py +33 -0
  259. aiq/observability/exporter/processing_exporter.py +322 -0
  260. aiq/observability/exporter/raw_exporter.py +52 -0
  261. aiq/observability/exporter/span_exporter.py +265 -0
  262. aiq/observability/exporter_manager.py +335 -0
  263. aiq/observability/mixin/__init__.py +14 -0
  264. aiq/observability/mixin/batch_config_mixin.py +26 -0
  265. aiq/observability/mixin/collector_config_mixin.py +23 -0
  266. aiq/observability/mixin/file_mixin.py +288 -0
  267. aiq/observability/mixin/file_mode.py +23 -0
  268. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  269. aiq/observability/mixin/serialize_mixin.py +61 -0
  270. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  271. aiq/observability/processor/__init__.py +14 -0
  272. aiq/observability/processor/batching_processor.py +310 -0
  273. aiq/observability/processor/callback_processor.py +42 -0
  274. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  275. aiq/observability/processor/processor.py +71 -0
  276. aiq/observability/register.py +96 -0
  277. aiq/observability/utils/__init__.py +14 -0
  278. aiq/observability/utils/dict_utils.py +236 -0
  279. aiq/observability/utils/time_utils.py +31 -0
  280. aiq/plugins/.namespace +1 -0
  281. aiq/profiler/__init__.py +0 -0
  282. aiq/profiler/calc/__init__.py +14 -0
  283. aiq/profiler/calc/calc_runner.py +627 -0
  284. aiq/profiler/calc/calculations.py +288 -0
  285. aiq/profiler/calc/data_models.py +188 -0
  286. aiq/profiler/calc/plot.py +345 -0
  287. aiq/profiler/callbacks/__init__.py +0 -0
  288. aiq/profiler/callbacks/agno_callback_handler.py +295 -0
  289. aiq/profiler/callbacks/base_callback_class.py +20 -0
  290. aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. aiq/profiler/callbacks/token_usage_base_model.py +27 -0
  294. aiq/profiler/data_frame_row.py +51 -0
  295. aiq/profiler/data_models.py +24 -0
  296. aiq/profiler/decorators/__init__.py +0 -0
  297. aiq/profiler/decorators/framework_wrapper.py +131 -0
  298. aiq/profiler/decorators/function_tracking.py +254 -0
  299. aiq/profiler/forecasting/__init__.py +0 -0
  300. aiq/profiler/forecasting/config.py +18 -0
  301. aiq/profiler/forecasting/model_trainer.py +75 -0
  302. aiq/profiler/forecasting/models/__init__.py +22 -0
  303. aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. aiq/profiler/forecasting/models/linear_model.py +196 -0
  305. aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
  306. aiq/profiler/inference_metrics_model.py +28 -0
  307. aiq/profiler/inference_optimization/__init__.py +0 -0
  308. aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. aiq/profiler/inference_optimization/data_models.py +386 -0
  312. aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. aiq/profiler/inference_optimization/llm_metrics.py +212 -0
  316. aiq/profiler/inference_optimization/prompt_caching.py +163 -0
  317. aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. aiq/profiler/intermediate_property_adapter.py +102 -0
  320. aiq/profiler/profile_runner.py +473 -0
  321. aiq/profiler/utils.py +184 -0
  322. aiq/registry_handlers/__init__.py +0 -0
  323. aiq/registry_handlers/local/__init__.py +0 -0
  324. aiq/registry_handlers/local/local_handler.py +176 -0
  325. aiq/registry_handlers/local/register_local.py +37 -0
  326. aiq/registry_handlers/metadata_factory.py +60 -0
  327. aiq/registry_handlers/package_utils.py +567 -0
  328. aiq/registry_handlers/pypi/__init__.py +0 -0
  329. aiq/registry_handlers/pypi/pypi_handler.py +251 -0
  330. aiq/registry_handlers/pypi/register_pypi.py +40 -0
  331. aiq/registry_handlers/register.py +21 -0
  332. aiq/registry_handlers/registry_handler_base.py +157 -0
  333. aiq/registry_handlers/rest/__init__.py +0 -0
  334. aiq/registry_handlers/rest/register_rest.py +56 -0
  335. aiq/registry_handlers/rest/rest_handler.py +237 -0
  336. aiq/registry_handlers/schemas/__init__.py +0 -0
  337. aiq/registry_handlers/schemas/headers.py +42 -0
  338. aiq/registry_handlers/schemas/package.py +68 -0
  339. aiq/registry_handlers/schemas/publish.py +63 -0
  340. aiq/registry_handlers/schemas/pull.py +82 -0
  341. aiq/registry_handlers/schemas/remove.py +36 -0
  342. aiq/registry_handlers/schemas/search.py +91 -0
  343. aiq/registry_handlers/schemas/status.py +47 -0
  344. aiq/retriever/__init__.py +0 -0
  345. aiq/retriever/interface.py +37 -0
  346. aiq/retriever/milvus/__init__.py +14 -0
  347. aiq/retriever/milvus/register.py +81 -0
  348. aiq/retriever/milvus/retriever.py +228 -0
  349. aiq/retriever/models.py +74 -0
  350. aiq/retriever/nemo_retriever/__init__.py +14 -0
  351. aiq/retriever/nemo_retriever/register.py +60 -0
  352. aiq/retriever/nemo_retriever/retriever.py +190 -0
  353. aiq/retriever/register.py +22 -0
  354. aiq/runtime/__init__.py +14 -0
  355. aiq/runtime/loader.py +215 -0
  356. aiq/runtime/runner.py +190 -0
  357. aiq/runtime/session.py +158 -0
  358. aiq/runtime/user_metadata.py +130 -0
  359. aiq/settings/__init__.py +0 -0
  360. aiq/settings/global_settings.py +318 -0
  361. aiq/test/.namespace +1 -0
  362. aiq/tool/__init__.py +0 -0
  363. aiq/tool/chat_completion.py +74 -0
  364. aiq/tool/code_execution/README.md +151 -0
  365. aiq/tool/code_execution/__init__.py +0 -0
  366. aiq/tool/code_execution/code_sandbox.py +267 -0
  367. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. aiq/tool/code_execution/register.py +74 -0
  374. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. aiq/tool/code_execution/utils.py +100 -0
  376. aiq/tool/datetime_tools.py +42 -0
  377. aiq/tool/document_search.py +141 -0
  378. aiq/tool/github_tools/__init__.py +0 -0
  379. aiq/tool/github_tools/create_github_commit.py +133 -0
  380. aiq/tool/github_tools/create_github_issue.py +87 -0
  381. aiq/tool/github_tools/create_github_pr.py +106 -0
  382. aiq/tool/github_tools/get_github_file.py +106 -0
  383. aiq/tool/github_tools/get_github_issue.py +166 -0
  384. aiq/tool/github_tools/get_github_pr.py +256 -0
  385. aiq/tool/github_tools/update_github_issue.py +100 -0
  386. aiq/tool/mcp/__init__.py +14 -0
  387. aiq/tool/mcp/exceptions.py +142 -0
  388. aiq/tool/mcp/mcp_client.py +255 -0
  389. aiq/tool/mcp/mcp_tool.py +96 -0
  390. aiq/tool/memory_tools/__init__.py +0 -0
  391. aiq/tool/memory_tools/add_memory_tool.py +79 -0
  392. aiq/tool/memory_tools/delete_memory_tool.py +67 -0
  393. aiq/tool/memory_tools/get_memory_tool.py +72 -0
  394. aiq/tool/nvidia_rag.py +95 -0
  395. aiq/tool/register.py +38 -0
  396. aiq/tool/retriever.py +89 -0
  397. aiq/tool/server_tools.py +66 -0
  398. aiq/utils/__init__.py +0 -0
  399. aiq/utils/data_models/__init__.py +0 -0
  400. aiq/utils/data_models/schema_validator.py +58 -0
  401. aiq/utils/debugging_utils.py +43 -0
  402. aiq/utils/dump_distro_mapping.py +32 -0
  403. aiq/utils/exception_handlers/__init__.py +0 -0
  404. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  405. aiq/utils/exception_handlers/mcp.py +211 -0
  406. aiq/utils/exception_handlers/schemas.py +114 -0
  407. aiq/utils/io/__init__.py +0 -0
  408. aiq/utils/io/model_processing.py +28 -0
  409. aiq/utils/io/yaml_tools.py +119 -0
  410. aiq/utils/log_utils.py +37 -0
  411. aiq/utils/metadata_utils.py +74 -0
  412. aiq/utils/optional_imports.py +142 -0
  413. aiq/utils/producer_consumer_queue.py +178 -0
  414. aiq/utils/reactive/__init__.py +0 -0
  415. aiq/utils/reactive/base/__init__.py +0 -0
  416. aiq/utils/reactive/base/observable_base.py +65 -0
  417. aiq/utils/reactive/base/observer_base.py +55 -0
  418. aiq/utils/reactive/base/subject_base.py +79 -0
  419. aiq/utils/reactive/observable.py +59 -0
  420. aiq/utils/reactive/observer.py +76 -0
  421. aiq/utils/reactive/subject.py +131 -0
  422. aiq/utils/reactive/subscription.py +49 -0
  423. aiq/utils/settings/__init__.py +0 -0
  424. aiq/utils/settings/global_settings.py +197 -0
  425. aiq/utils/string_utils.py +38 -0
  426. aiq/utils/type_converter.py +290 -0
  427. aiq/utils/type_utils.py +484 -0
  428. aiq/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
  430. nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
  433. nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
  434. nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,406 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import sys
18
+ import typing
19
+
20
+ from pydantic import BaseModel
21
+ from pydantic import ConfigDict
22
+ from pydantic import Discriminator
23
+ from pydantic import ValidationError
24
+ from pydantic import ValidationInfo
25
+ from pydantic import ValidatorFunctionWrapHandler
26
+ from pydantic import field_validator
27
+
28
+ from aiq.data_models.evaluate import EvalConfig
29
+ from aiq.data_models.front_end import FrontEndBaseConfig
30
+ from aiq.data_models.function import EmptyFunctionConfig
31
+ from aiq.data_models.function import FunctionBaseConfig
32
+ from aiq.data_models.logging import LoggingBaseConfig
33
+ from aiq.data_models.telemetry_exporter import TelemetryExporterBaseConfig
34
+ from aiq.data_models.ttc_strategy import TTCStrategyBaseConfig
35
+ from aiq.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
36
+
37
+ from .authentication import AuthProviderBaseConfig
38
+ from .common import HashableBaseModel
39
+ from .common import TypedBaseModel
40
+ from .embedder import EmbedderBaseConfig
41
+ from .llm import LLMBaseConfig
42
+ from .memory import MemoryBaseConfig
43
+ from .object_store import ObjectStoreBaseConfig
44
+ from .retriever import RetrieverBaseConfig
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
# Maps each config field whose validator delegates to ``_process_validation_error`` to the name of the
# ``GlobalTypeRegistry`` method that lists the component types registered for that field.
# Both "front_end" (the field validated by ``GeneralConfig``) and the legacy "front_ends" spelling are accepted.
_FIELD_REGISTRY_GETTERS: dict[str, str] = {
    "workflow": "get_registered_functions",
    "functions": "get_registered_functions",
    "authentication": "get_registered_auth_providers",
    "llms": "get_registered_llm_providers",
    "embedders": "get_registered_embedder_providers",
    "memory": "get_registered_memorys",
    "object_stores": "get_registered_object_stores",
    "retrievers": "get_registered_retriever_providers",
    "tracing": "get_registered_telemetry_exporters",
    "logging": "get_registered_logging_method",
    "evaluators": "get_registered_evaluators",
    "front_end": "get_registered_front_ends",
    "front_ends": "get_registered_front_ends",
    "ttc_strategies": "get_registered_ttc_strategies",
}


def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
    """
    Improve the reporting of a ``ValidationError`` raised while validating a component field.

    - For the first ``union_tag_invalid`` error (an unknown or ambiguous ``_type``), log a message listing the
      component types registered for ``info.field_name``.
    - For ``missing`` errors, drop the second element of the error location (intended to be the ``_type`` tag)
      and re-raise a new ``ValidationError`` containing all of ``err``'s errors with those locations rewritten.

    Args:
        err: The validation error raised by the wrapped validator.
        handler: The wrapped pydantic validator. Not used here; kept so the call signature stays stable.
        info: Validation info for the field being validated; ``info.field_name`` selects the registry to consult.

    Raises:
        ValidationError: When at least one ``missing`` error was encountered.
        AssertionError: If ``info.field_name`` is not a known component field, or if the requested type matches
            exactly one registered type by local name (an internal inconsistency). These are raised explicitly
            so they are not stripped when Python runs with ``-O``.

    Returns ``None`` otherwise; the visible callers then re-raise ``err`` themselves.
    """
    from aiq.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import

    new_errors = []
    logged_once = False
    needs_reraise = False
    for e in err.errors():

        error_type = e['type']
        if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
            requested_type = e["ctx"]["tag"]

            getter_name = _FIELD_REGISTRY_GETTERS.get(info.field_name)
            if getter_name is None:
                raise AssertionError(f"Unknown field name {info.field_name} in validator")

            registered_keys = getattr(GlobalTypeRegistry.get(), getter_name)()

            # Check and see if there are multiple full types which match this short type
            matching_keys = [k for k in registered_keys if k.local_name == requested_type]

            if len(matching_keys) == 1:
                # A unique local-name match should have been resolved by the discriminator itself.
                raise AssertionError("Exact match should have been found. Contact developers")

            matching_key_names = [x.full_type for x in matching_keys]
            registered_key_names = [x.full_type for x in registered_keys]

            if (len(matching_keys) == 0):
                # This is a case where the requested type is not found. Show a helpful message about what is
                # available
                logger.error(("Requested %s type `%s` not found. "
                              "Have you ensured the necessary package has been installed with `uv pip install`?"
                              "\nAvailable %s names:\n - %s\n"),
                             info.field_name,
                             requested_type,
                             info.field_name,
                             '\n - '.join(registered_key_names))
            else:
                # This is a case where the requested type is ambiguous.
                logger.error(("Requested %s type `%s` is ambiguous. "
                              "Matched multiple %s by their local name: %s. "
                              "Please use the fully qualified %s name."
                              "\nAvailable %s names:\n - %s\n"),
                             info.field_name,
                             requested_type,
                             info.field_name,
                             matching_key_names,
                             info.field_name,
                             info.field_name,
                             '\n - '.join(registered_key_names))

            # Only show one error
            logged_once = True

        elif error_type == 'missing':
            location = e["loc"]
            if len(location) > 1:  # remove the _type field from the location
                # NOTE(review): this assumes the tag is the second element, i.e. ``(key, tag, ...)`` as produced
                # for ``dict[str, ...]`` fields. For single-valued fields such as ``workflow`` or ``front_end``
                # the tag is presumably the first element, so this may drop the missing field's name — confirm.
                e['loc'] = (location[0], ) + location[2:]
            needs_reraise = True

        new_errors.append(e)

    if needs_reraise:
        raise ValidationError.from_exception_data(title=err.title, line_errors=new_errors)
134
+
135
+
136
class TelemetryConfig(BaseModel):
    """Telemetry settings: named logging configurations and named tracing exporters."""

    logging: dict[str, LoggingBaseConfig] = {}
    tracing: dict[str, TelemetryExporterBaseConfig] = {}

    @field_validator("logging", "tracing", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """
        Run the default validation for ``logging``/``tracing``; on failure, let ``_process_validation_error``
        log or rewrite the error before it propagates.
        """
        try:
            validated = handler(value)
        except ValidationError as exc:
            # May raise a rewritten ValidationError itself; otherwise the original error is re-raised.
            _process_validation_error(exc, handler, info)
            raise
        return validated

    @classmethod
    def rebuild_annotations(cls):
        """
        Point the ``tracing`` and ``logging`` fields at discriminated unions computed by the global type
        registry, rebuilding the model when either annotation changed.

        Returns:
            The result of ``cls.model_rebuild(force=True)`` if any annotation was replaced, otherwise ``False``.
        """
        from aiq.cli.type_registry import GlobalTypeRegistry

        registry = GlobalTypeRegistry.get()

        # Built in the same order as before: tracing first, then logging.
        desired_annotations = {
            "tracing":
                dict[str,
                     typing.Annotated[registry.compute_annotation(TelemetryExporterBaseConfig),
                                      Discriminator(TypedBaseModel.discriminator)]],
            "logging":
                dict[str,
                     typing.Annotated[registry.compute_annotation(LoggingBaseConfig),
                                      Discriminator(TypedBaseModel.discriminator)]],
        }

        changed = False
        for field_name, annotation in desired_annotations.items():
            field_info = cls.model_fields.get(field_name)
            if field_info is not None and field_info.annotation != annotation:
                field_info.annotation = annotation
                changed = True

        if not changed:
            return False
        return cls.model_rebuild(force=True)
182
+
183
+
184
class GeneralConfig(BaseModel):
    """General settings: event-loop selection, telemetry configuration and the front end to run."""

    model_config = ConfigDict(protected_namespaces=())

    use_uvloop: bool = True
    """
    Whether to use uvloop for the event loop. This can provide a significant speedup in some cases. Disable to provide
    better error messages when debugging.
    """

    telemetry: TelemetryConfig = TelemetryConfig()

    # FrontEnd Configuration
    front_end: FrontEndBaseConfig = FastApiFrontEndConfig()

    @field_validator("front_end", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """
        Run the default validation for ``front_end``; on failure, let ``_process_validation_error``
        log or rewrite the error before it propagates.
        """
        try:
            validated = handler(value)
        except ValidationError as exc:
            # May raise a rewritten ValidationError itself; otherwise the original error is re-raised.
            _process_validation_error(exc, handler, info)
            raise
        return validated

    @classmethod
    def rebuild_annotations(cls):
        """
        Point ``front_end`` at a discriminated union computed by the global type registry and refresh the
        nested ``TelemetryConfig`` annotations, rebuilding this model if anything changed.

        Returns:
            The result of ``cls.model_rebuild(force=True)`` if either this model's or ``TelemetryConfig``'s
            annotations changed, otherwise ``False``.
        """
        from aiq.cli.type_registry import GlobalTypeRegistry

        registry = GlobalTypeRegistry.get()
        front_end_annotation = typing.Annotated[registry.compute_annotation(FrontEndBaseConfig),
                                                Discriminator(TypedBaseModel.discriminator)]

        front_end_changed = False
        field_info = cls.model_fields.get("front_end")
        if field_info is not None and field_info.annotation != front_end_annotation:
            field_info.annotation = front_end_annotation
            front_end_changed = True

        # The nested telemetry model is always refreshed, regardless of the front end result.
        telemetry_changed = bool(TelemetryConfig.rebuild_annotations())

        if front_end_changed or telemetry_changed:
            return cls.model_rebuild(force=True)
        return False
233
+
234
+
235
class AIQConfig(HashableBaseModel):
    """Top-level configuration: named component dictionaries, the workflow, and global/eval options."""

    # Unknown top-level keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Global Options
    general: GeneralConfig = GeneralConfig()

    # Functions Configuration
    functions: dict[str, FunctionBaseConfig] = {}

    # LLMs Configuration
    llms: dict[str, LLMBaseConfig] = {}

    # Embedders Configuration
    embedders: dict[str, EmbedderBaseConfig] = {}

    # Memory Configuration
    memory: dict[str, MemoryBaseConfig] = {}

    # Object Stores Configuration
    object_stores: dict[str, ObjectStoreBaseConfig] = {}

    # Retriever Configuration
    retrievers: dict[str, RetrieverBaseConfig] = {}

    # TTC Strategies
    ttc_strategies: dict[str, TTCStrategyBaseConfig] = {}

    # Workflow Configuration
    workflow: FunctionBaseConfig = EmptyFunctionConfig()

    # Authentication Configuration
    authentication: dict[str, AuthProviderBaseConfig] = {}

    # Evaluation Options
    eval: EvalConfig = EvalConfig()

    def print_summary(self, stream: typing.TextIO | None = None):
        """Print a summary of the configuration.

        Args:
            stream: Destination for the summary. When omitted, the *current* ``sys.stdout`` is looked up at call
                time, so ``contextlib.redirect_stdout`` and test output capture are honoured (a ``sys.stdout``
                default in the signature would be frozen at import time).
        """
        if stream is None:
            stream = sys.stdout

        stream.write("\nConfiguration Summary:\n")
        stream.write("-" * 20 + "\n")
        if self.workflow:
            stream.write(f"Workflow Type: {self.workflow.type}\n")

        stream.write(f"Number of Functions: {len(self.functions)}\n")
        stream.write(f"Number of LLMs: {len(self.llms)}\n")
        stream.write(f"Number of Embedders: {len(self.embedders)}\n")
        stream.write(f"Number of Memory: {len(self.memory)}\n")
        stream.write(f"Number of Object Stores: {len(self.object_stores)}\n")
        stream.write(f"Number of Retrievers: {len(self.retrievers)}\n")
        stream.write(f"Number of TTC Strategies: {len(self.ttc_strategies)}\n")
        stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")

    # Every registry-backed component field is listed here, including ``object_stores``, so that all of them get the
    # same ``_process_validation_error`` treatment on failure.
    @field_validator("functions",
                     "llms",
                     "embedders",
                     "memory",
                     "object_stores",
                     "retrievers",
                     "workflow",
                     "ttc_strategies",
                     "authentication",
                     mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """Run default validation for a component field, passing failures through ``_process_validation_error``.

        Raises:
            ValidationError: re-raised after ``_process_validation_error`` has seen it.
        """
        try:
            return handler(value)
        except ValidationError as err:
            _process_validation_error(err, handler, info)
            raise

    @classmethod
    def rebuild_annotations(cls):
        """Refresh every component field's annotation from the global type registry, then cascade to sub-models.

        Returns:
            ``cls.model_rebuild(force=True)`` if any annotation here changed or a nested model
            (``GeneralConfig`` / ``EvalConfig``) reported a rebuild; otherwise ``False``.
        """
        from aiq.cli.type_registry import GlobalTypeRegistry

        type_registry = GlobalTypeRegistry.get()

        def _discriminated(base_type):
            # The registry's annotation for ``base_type``, discriminated via ``TypedBaseModel.discriminator``.
            return typing.Annotated[type_registry.compute_annotation(base_type),
                                    Discriminator(TypedBaseModel.discriminator)]

        # Entries are listed in the same order the annotations were historically computed.
        expected_annotations = {
            "llms": dict[str, _discriminated(LLMBaseConfig)],
            "authentication": dict[str, _discriminated(AuthProviderBaseConfig)],
            "embedders": dict[str, _discriminated(EmbedderBaseConfig)],
            "functions": dict[str, _discriminated(FunctionBaseConfig)],
            "memory": dict[str, _discriminated(MemoryBaseConfig)],
            "object_stores": dict[str, _discriminated(ObjectStoreBaseConfig)],
            "retrievers": dict[str, _discriminated(RetrieverBaseConfig)],
            "ttc_strategies": dict[str, _discriminated(TTCStrategyBaseConfig)],
            # The workflow is a single function config, not a name -> config mapping.
            "workflow": _discriminated(FunctionBaseConfig),
        }

        should_rebuild = False
        for field_name, annotation in expected_annotations.items():
            field_info = cls.model_fields.get(field_name)
            if field_info is not None and field_info.annotation != annotation:
                field_info.annotation = annotation
                should_rebuild = True

        # Both nested refreshes always run (no short-circuit) so each sub-model stays current.
        if GeneralConfig.rebuild_annotations():
            should_rebuild = True

        if EvalConfig.rebuild_annotations():
            should_rebuild = True

        if should_rebuild:
            return cls.model_rebuild(force=True)

        return False
@@ -0,0 +1,123 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import typing
18
+ from collections.abc import Callable
19
+ from pathlib import Path
20
+
21
+ import pandas as pd
22
+ from pydantic import BaseModel
23
+ from pydantic import Discriminator
24
+ from pydantic import FilePath
25
+ from pydantic import Tag
26
+
27
+ from aiq.data_models.common import BaseModelRegistryTag
28
+ from aiq.data_models.common import TypedBaseModel
29
+
30
+
31
class EvalS3Config(BaseModel):
    """Connection settings for reading an evaluation dataset from S3."""

    # Optional overrides; ``None`` leaves the choice to the client library (presumably the default AWS endpoint).
    endpoint_url: str | None = None
    region_name: str | None = None

    # Required: target bucket and the credentials used to access it.
    bucket: str
    access_key: str
    # NOTE(review): a plain ``str`` will presumably show up in ``repr()``/dumps of this model; consider
    # ``pydantic.SecretStr`` if the config is ever logged — confirm callers before changing the type.
    secret_key: str
38
+
39
+
40
class EvalFilterEntryConfig(BaseModel):
    """A single allow- or deny-list: per-key lists of matching values."""

    # Maps a key (presumably a dataset column name) to the list of allowed/blocked values for that key.
    field: dict[str, list[str | int | float]] = {}
43
+
44
+
45
class EvalFilterConfig(BaseModel):
    """Allow-list and deny-list settings for an evaluation dataset; either side may be ``None``."""

    allowlist: EvalFilterEntryConfig | None = None
    denylist: EvalFilterEntryConfig | None = None
48
+
49
+
50
class EvalDatasetStructureConfig(BaseModel):
    """Key names under which each part of an evaluation record is found in the dataset."""

    # NOTE(review): presumably turns off this key mapping when True — confirm against the dataset loader.
    disable: bool = False

    # Input question and reference answer.
    question_key: str = "question"
    answer_key: str = "answer"
    # Answer produced by the workflow under evaluation.
    generated_answer_key: str = "generated_answer"
    # Actual and expected intermediate-step trajectories.
    trajectory_key: str = "intermediate_steps"
    expected_trajectory_key: str = "expected_intermediate_steps"
57
+
58
+
59
+ # Base model
60
class EvalDatasetBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """Settings common to every evaluation dataset format.

    Concrete formats subclass this with a ``name=`` tag and a ``parser()`` static method returning a
    ``(reader, kwargs)`` pair.
    """

    # Key identifying each dataset entry (presumably unique per entry).
    id_key: str = "id"
    structure: EvalDatasetStructureConfig = EvalDatasetStructureConfig()

    # Filters
    filter: EvalFilterConfig | None = EvalFilterConfig()

    # Optional S3 source; ``remote_file_path`` below is only used together with it.
    s3: EvalS3Config | None = None

    remote_file_path: str | None = None  # only for s3
    file_path: Path | str = Path(".tmp/aiq/examples/default/default.json")
72
+
73
+
74
class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
    """Dataset stored as a JSON document and loaded with ``pandas.read_json``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``(reader, extra_kwargs)``; JSON needs no extra keyword arguments."""
        reader = pd.read_json
        return reader, dict()
79
+
80
+
81
def read_jsonl(file_path: FilePath, **kwargs):
    """Load a JSON Lines file into a ``pandas.DataFrame``, one record per non-blank line.

    Blank or whitespace-only lines (for example a trailing empty line) are skipped rather than crashing
    ``json.loads``. The file is decoded with ``utf-8-sig``: plain UTF-8 files read exactly as before, and a leading
    UTF-8 byte-order mark, if present, is stripped instead of making the first record unparsable.

    Args:
        file_path: Path of the ``.jsonl`` file to read.
        **kwargs: Accepted but ignored.

    Returns:
        A DataFrame built from the list of parsed records (empty if the file has no records).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)
85
+
86
+
87
class EvalDatasetJsonlConfig(EvalDatasetBaseConfig, name="jsonl"):
    """Dataset stored as JSON Lines and loaded with this module's ``read_jsonl`` helper."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``(reader, extra_kwargs)``; the JSONL reader takes no extra keyword arguments."""
        reader = read_jsonl
        return reader, dict()
92
+
93
+
94
class EvalDatasetCsvConfig(EvalDatasetBaseConfig, name="csv"):
    """Dataset stored as CSV and loaded with ``pandas.read_csv``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``(reader, extra_kwargs)``; CSV uses pandas defaults, so no extra keyword arguments."""
        reader = pd.read_csv
        return reader, dict()
99
+
100
+
101
class EvalDatasetParquetConfig(EvalDatasetBaseConfig, name="parquet"):
    """Dataset stored as Parquet and loaded with ``pandas.read_parquet``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``(reader, extra_kwargs)``; Parquet needs no extra keyword arguments."""
        reader = pd.read_parquet
        return reader, dict()
106
+
107
+
108
class EvalDatasetXlsConfig(EvalDatasetBaseConfig, name="xls"):
    """Excel workbook dataset, loaded with ``pandas.read_excel``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``(reader, extra_kwargs)`` for Excel files.

        No ``engine`` is pinned, so pandas chooses the engine from the workbook's actual format. ``.xlsx``/``.xlsm``
        content is still read by openpyxl. Legacy ``.xls`` workbooks, which openpyxl cannot open, are handed to
        pandas' ``.xls`` engine; that engine must be installed or pandas reports the missing optional dependency.
        """
        return pd.read_excel, {}
113
+
114
+
115
# Discriminated union of all supported dataset formats. Each member is tagged with its class's ``static_type()``,
# and ``TypedBaseModel.discriminator`` is the callable pydantic uses to pick the member. Member order is significant
# and kept as: json, csv, xls, parquet, jsonl.
EvalDatasetConfig = typing.Annotated[
    (typing.Annotated[EvalDatasetJsonConfig, Tag(EvalDatasetJsonConfig.static_type())]
     | typing.Annotated[EvalDatasetCsvConfig, Tag(EvalDatasetCsvConfig.static_type())]
     | typing.Annotated[EvalDatasetXlsConfig, Tag(EvalDatasetXlsConfig.static_type())]
     | typing.Annotated[EvalDatasetParquetConfig, Tag(EvalDatasetParquetConfig.static_type())]
     | typing.Annotated[EvalDatasetJsonlConfig, Tag(EvalDatasetJsonlConfig.static_type())]),
    Discriminator(TypedBaseModel.discriminator),
]