nvidia-nat 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +256 -0
  4. nat/agent/dual_node.py +67 -0
  5. nat/agent/react_agent/__init__.py +0 -0
  6. nat/agent/react_agent/agent.py +363 -0
  7. nat/agent/react_agent/output_parser.py +104 -0
  8. nat/agent/react_agent/prompt.py +44 -0
  9. nat/agent/react_agent/register.py +149 -0
  10. nat/agent/reasoning_agent/__init__.py +0 -0
  11. nat/agent/reasoning_agent/reasoning_agent.py +225 -0
  12. nat/agent/register.py +23 -0
  13. nat/agent/rewoo_agent/__init__.py +0 -0
  14. nat/agent/rewoo_agent/agent.py +415 -0
  15. nat/agent/rewoo_agent/prompt.py +110 -0
  16. nat/agent/rewoo_agent/register.py +157 -0
  17. nat/agent/tool_calling_agent/__init__.py +0 -0
  18. nat/agent/tool_calling_agent/agent.py +119 -0
  19. nat/agent/tool_calling_agent/register.py +106 -0
  20. nat/authentication/__init__.py +14 -0
  21. nat/authentication/api_key/__init__.py +14 -0
  22. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  23. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  24. nat/authentication/api_key/register.py +26 -0
  25. nat/authentication/exceptions/__init__.py +14 -0
  26. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  27. nat/authentication/http_basic_auth/__init__.py +0 -0
  28. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  29. nat/authentication/http_basic_auth/register.py +30 -0
  30. nat/authentication/interfaces.py +93 -0
  31. nat/authentication/oauth2/__init__.py +14 -0
  32. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  33. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  34. nat/authentication/oauth2/register.py +25 -0
  35. nat/authentication/register.py +21 -0
  36. nat/builder/__init__.py +0 -0
  37. nat/builder/builder.py +285 -0
  38. nat/builder/component_utils.py +316 -0
  39. nat/builder/context.py +270 -0
  40. nat/builder/embedder.py +24 -0
  41. nat/builder/eval_builder.py +161 -0
  42. nat/builder/evaluator.py +29 -0
  43. nat/builder/framework_enum.py +24 -0
  44. nat/builder/front_end.py +73 -0
  45. nat/builder/function.py +344 -0
  46. nat/builder/function_base.py +380 -0
  47. nat/builder/function_info.py +627 -0
  48. nat/builder/intermediate_step_manager.py +174 -0
  49. nat/builder/llm.py +25 -0
  50. nat/builder/retriever.py +25 -0
  51. nat/builder/user_interaction_manager.py +78 -0
  52. nat/builder/workflow.py +148 -0
  53. nat/builder/workflow_builder.py +1117 -0
  54. nat/cli/__init__.py +14 -0
  55. nat/cli/cli_utils/__init__.py +0 -0
  56. nat/cli/cli_utils/config_override.py +231 -0
  57. nat/cli/cli_utils/validation.py +37 -0
  58. nat/cli/commands/__init__.py +0 -0
  59. nat/cli/commands/configure/__init__.py +0 -0
  60. nat/cli/commands/configure/channel/__init__.py +0 -0
  61. nat/cli/commands/configure/channel/add.py +28 -0
  62. nat/cli/commands/configure/channel/channel.py +34 -0
  63. nat/cli/commands/configure/channel/remove.py +30 -0
  64. nat/cli/commands/configure/channel/update.py +30 -0
  65. nat/cli/commands/configure/configure.py +33 -0
  66. nat/cli/commands/evaluate.py +139 -0
  67. nat/cli/commands/info/__init__.py +14 -0
  68. nat/cli/commands/info/info.py +37 -0
  69. nat/cli/commands/info/list_channels.py +32 -0
  70. nat/cli/commands/info/list_components.py +129 -0
  71. nat/cli/commands/info/list_mcp.py +304 -0
  72. nat/cli/commands/registry/__init__.py +14 -0
  73. nat/cli/commands/registry/publish.py +88 -0
  74. nat/cli/commands/registry/pull.py +118 -0
  75. nat/cli/commands/registry/registry.py +36 -0
  76. nat/cli/commands/registry/remove.py +108 -0
  77. nat/cli/commands/registry/search.py +155 -0
  78. nat/cli/commands/sizing/__init__.py +14 -0
  79. nat/cli/commands/sizing/calc.py +297 -0
  80. nat/cli/commands/sizing/sizing.py +27 -0
  81. nat/cli/commands/start.py +246 -0
  82. nat/cli/commands/uninstall.py +81 -0
  83. nat/cli/commands/validate.py +47 -0
  84. nat/cli/commands/workflow/__init__.py +14 -0
  85. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  86. nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
  87. nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  88. nat/cli/commands/workflow/templates/register.py.j2 +5 -0
  89. nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  90. nat/cli/commands/workflow/workflow.py +37 -0
  91. nat/cli/commands/workflow/workflow_commands.py +317 -0
  92. nat/cli/entrypoint.py +135 -0
  93. nat/cli/main.py +57 -0
  94. nat/cli/register_workflow.py +488 -0
  95. nat/cli/type_registry.py +1000 -0
  96. nat/data_models/__init__.py +14 -0
  97. nat/data_models/api_server.py +716 -0
  98. nat/data_models/authentication.py +231 -0
  99. nat/data_models/common.py +171 -0
  100. nat/data_models/component.py +58 -0
  101. nat/data_models/component_ref.py +168 -0
  102. nat/data_models/config.py +410 -0
  103. nat/data_models/dataset_handler.py +169 -0
  104. nat/data_models/discovery_metadata.py +305 -0
  105. nat/data_models/embedder.py +27 -0
  106. nat/data_models/evaluate.py +127 -0
  107. nat/data_models/evaluator.py +26 -0
  108. nat/data_models/front_end.py +26 -0
  109. nat/data_models/function.py +30 -0
  110. nat/data_models/function_dependencies.py +72 -0
  111. nat/data_models/interactive.py +246 -0
  112. nat/data_models/intermediate_step.py +302 -0
  113. nat/data_models/invocation_node.py +38 -0
  114. nat/data_models/llm.py +27 -0
  115. nat/data_models/logging.py +26 -0
  116. nat/data_models/memory.py +27 -0
  117. nat/data_models/object_store.py +44 -0
  118. nat/data_models/profiler.py +54 -0
  119. nat/data_models/registry_handler.py +26 -0
  120. nat/data_models/retriever.py +30 -0
  121. nat/data_models/retry_mixin.py +35 -0
  122. nat/data_models/span.py +190 -0
  123. nat/data_models/step_adaptor.py +64 -0
  124. nat/data_models/streaming.py +33 -0
  125. nat/data_models/swe_bench_model.py +54 -0
  126. nat/data_models/telemetry_exporter.py +26 -0
  127. nat/data_models/ttc_strategy.py +30 -0
  128. nat/embedder/__init__.py +0 -0
  129. nat/embedder/nim_embedder.py +59 -0
  130. nat/embedder/openai_embedder.py +43 -0
  131. nat/embedder/register.py +22 -0
  132. nat/eval/__init__.py +14 -0
  133. nat/eval/config.py +60 -0
  134. nat/eval/dataset_handler/__init__.py +0 -0
  135. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  136. nat/eval/dataset_handler/dataset_filter.py +52 -0
  137. nat/eval/dataset_handler/dataset_handler.py +367 -0
  138. nat/eval/evaluate.py +510 -0
  139. nat/eval/evaluator/__init__.py +14 -0
  140. nat/eval/evaluator/base_evaluator.py +77 -0
  141. nat/eval/evaluator/evaluator_model.py +45 -0
  142. nat/eval/intermediate_step_adapter.py +99 -0
  143. nat/eval/rag_evaluator/__init__.py +0 -0
  144. nat/eval/rag_evaluator/evaluate.py +178 -0
  145. nat/eval/rag_evaluator/register.py +143 -0
  146. nat/eval/register.py +23 -0
  147. nat/eval/remote_workflow.py +133 -0
  148. nat/eval/runners/__init__.py +14 -0
  149. nat/eval/runners/config.py +39 -0
  150. nat/eval/runners/multi_eval_runner.py +54 -0
  151. nat/eval/runtime_event_subscriber.py +52 -0
  152. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  153. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. nat/eval/swe_bench_evaluator/register.py +36 -0
  155. nat/eval/trajectory_evaluator/__init__.py +0 -0
  156. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  157. nat/eval/trajectory_evaluator/register.py +40 -0
  158. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. nat/eval/tunable_rag_evaluator/register.py +52 -0
  161. nat/eval/usage_stats.py +41 -0
  162. nat/eval/utils/__init__.py +0 -0
  163. nat/eval/utils/output_uploader.py +140 -0
  164. nat/eval/utils/tqdm_position_registry.py +40 -0
  165. nat/eval/utils/weave_eval.py +184 -0
  166. nat/experimental/__init__.py +0 -0
  167. nat/experimental/decorators/__init__.py +0 -0
  168. nat/experimental/decorators/experimental_warning_decorator.py +134 -0
  169. nat/experimental/test_time_compute/__init__.py +0 -0
  170. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  171. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  175. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  177. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  178. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  179. nat/experimental/test_time_compute/models/__init__.py +0 -0
  180. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  181. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. nat/experimental/test_time_compute/models/search_config.py +120 -0
  183. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  184. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. nat/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. nat/experimental/test_time_compute/register.py +36 -0
  189. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. nat/experimental/test_time_compute/search/__init__.py +0 -0
  194. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  198. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. nat/front_ends/__init__.py +14 -0
  204. nat/front_ends/console/__init__.py +14 -0
  205. nat/front_ends/console/authentication_flow_handler.py +233 -0
  206. nat/front_ends/console/console_front_end_config.py +32 -0
  207. nat/front_ends/console/console_front_end_plugin.py +96 -0
  208. nat/front_ends/console/register.py +25 -0
  209. nat/front_ends/cron/__init__.py +14 -0
  210. nat/front_ends/fastapi/__init__.py +14 -0
  211. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
  215. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
  218. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. nat/front_ends/fastapi/job_store.py +183 -0
  222. nat/front_ends/fastapi/main.py +72 -0
  223. nat/front_ends/fastapi/message_handler.py +320 -0
  224. nat/front_ends/fastapi/message_validator.py +352 -0
  225. nat/front_ends/fastapi/register.py +25 -0
  226. nat/front_ends/fastapi/response_helpers.py +195 -0
  227. nat/front_ends/fastapi/step_adaptor.py +319 -0
  228. nat/front_ends/mcp/__init__.py +14 -0
  229. nat/front_ends/mcp/mcp_front_end_config.py +36 -0
  230. nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
  231. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
  232. nat/front_ends/mcp/register.py +27 -0
  233. nat/front_ends/mcp/tool_converter.py +241 -0
  234. nat/front_ends/register.py +22 -0
  235. nat/front_ends/simple_base/__init__.py +14 -0
  236. nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  237. nat/llm/__init__.py +0 -0
  238. nat/llm/aws_bedrock_llm.py +57 -0
  239. nat/llm/nim_llm.py +46 -0
  240. nat/llm/openai_llm.py +46 -0
  241. nat/llm/register.py +23 -0
  242. nat/llm/utils/__init__.py +14 -0
  243. nat/llm/utils/env_config_value.py +94 -0
  244. nat/llm/utils/error.py +17 -0
  245. nat/memory/__init__.py +20 -0
  246. nat/memory/interfaces.py +183 -0
  247. nat/memory/models.py +112 -0
  248. nat/meta/pypi.md +58 -0
  249. nat/object_store/__init__.py +20 -0
  250. nat/object_store/in_memory_object_store.py +76 -0
  251. nat/object_store/interfaces.py +84 -0
  252. nat/object_store/models.py +38 -0
  253. nat/object_store/register.py +20 -0
  254. nat/observability/__init__.py +14 -0
  255. nat/observability/exporter/__init__.py +14 -0
  256. nat/observability/exporter/base_exporter.py +449 -0
  257. nat/observability/exporter/exporter.py +78 -0
  258. nat/observability/exporter/file_exporter.py +33 -0
  259. nat/observability/exporter/processing_exporter.py +322 -0
  260. nat/observability/exporter/raw_exporter.py +52 -0
  261. nat/observability/exporter/span_exporter.py +288 -0
  262. nat/observability/exporter_manager.py +335 -0
  263. nat/observability/mixin/__init__.py +14 -0
  264. nat/observability/mixin/batch_config_mixin.py +26 -0
  265. nat/observability/mixin/collector_config_mixin.py +23 -0
  266. nat/observability/mixin/file_mixin.py +288 -0
  267. nat/observability/mixin/file_mode.py +23 -0
  268. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  269. nat/observability/mixin/serialize_mixin.py +61 -0
  270. nat/observability/mixin/type_introspection_mixin.py +183 -0
  271. nat/observability/processor/__init__.py +14 -0
  272. nat/observability/processor/batching_processor.py +310 -0
  273. nat/observability/processor/callback_processor.py +42 -0
  274. nat/observability/processor/intermediate_step_serializer.py +28 -0
  275. nat/observability/processor/processor.py +71 -0
  276. nat/observability/register.py +96 -0
  277. nat/observability/utils/__init__.py +14 -0
  278. nat/observability/utils/dict_utils.py +236 -0
  279. nat/observability/utils/time_utils.py +31 -0
  280. nat/plugins/.namespace +1 -0
  281. nat/profiler/__init__.py +0 -0
  282. nat/profiler/calc/__init__.py +14 -0
  283. nat/profiler/calc/calc_runner.py +627 -0
  284. nat/profiler/calc/calculations.py +288 -0
  285. nat/profiler/calc/data_models.py +188 -0
  286. nat/profiler/calc/plot.py +345 -0
  287. nat/profiler/callbacks/__init__.py +0 -0
  288. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  289. nat/profiler/callbacks/base_callback_class.py +20 -0
  290. nat/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  294. nat/profiler/data_frame_row.py +51 -0
  295. nat/profiler/data_models.py +24 -0
  296. nat/profiler/decorators/__init__.py +0 -0
  297. nat/profiler/decorators/framework_wrapper.py +131 -0
  298. nat/profiler/decorators/function_tracking.py +254 -0
  299. nat/profiler/forecasting/__init__.py +0 -0
  300. nat/profiler/forecasting/config.py +18 -0
  301. nat/profiler/forecasting/model_trainer.py +75 -0
  302. nat/profiler/forecasting/models/__init__.py +22 -0
  303. nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. nat/profiler/forecasting/models/linear_model.py +197 -0
  305. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  306. nat/profiler/inference_metrics_model.py +28 -0
  307. nat/profiler/inference_optimization/__init__.py +0 -0
  308. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. nat/profiler/inference_optimization/data_models.py +386 -0
  312. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  316. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  317. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. nat/profiler/intermediate_property_adapter.py +102 -0
  320. nat/profiler/profile_runner.py +473 -0
  321. nat/profiler/utils.py +184 -0
  322. nat/registry_handlers/__init__.py +0 -0
  323. nat/registry_handlers/local/__init__.py +0 -0
  324. nat/registry_handlers/local/local_handler.py +176 -0
  325. nat/registry_handlers/local/register_local.py +37 -0
  326. nat/registry_handlers/metadata_factory.py +60 -0
  327. nat/registry_handlers/package_utils.py +571 -0
  328. nat/registry_handlers/pypi/__init__.py +0 -0
  329. nat/registry_handlers/pypi/pypi_handler.py +251 -0
  330. nat/registry_handlers/pypi/register_pypi.py +40 -0
  331. nat/registry_handlers/register.py +21 -0
  332. nat/registry_handlers/registry_handler_base.py +157 -0
  333. nat/registry_handlers/rest/__init__.py +0 -0
  334. nat/registry_handlers/rest/register_rest.py +56 -0
  335. nat/registry_handlers/rest/rest_handler.py +237 -0
  336. nat/registry_handlers/schemas/__init__.py +0 -0
  337. nat/registry_handlers/schemas/headers.py +42 -0
  338. nat/registry_handlers/schemas/package.py +68 -0
  339. nat/registry_handlers/schemas/publish.py +68 -0
  340. nat/registry_handlers/schemas/pull.py +82 -0
  341. nat/registry_handlers/schemas/remove.py +36 -0
  342. nat/registry_handlers/schemas/search.py +91 -0
  343. nat/registry_handlers/schemas/status.py +47 -0
  344. nat/retriever/__init__.py +0 -0
  345. nat/retriever/interface.py +41 -0
  346. nat/retriever/milvus/__init__.py +14 -0
  347. nat/retriever/milvus/register.py +81 -0
  348. nat/retriever/milvus/retriever.py +228 -0
  349. nat/retriever/models.py +77 -0
  350. nat/retriever/nemo_retriever/__init__.py +14 -0
  351. nat/retriever/nemo_retriever/register.py +60 -0
  352. nat/retriever/nemo_retriever/retriever.py +190 -0
  353. nat/retriever/register.py +22 -0
  354. nat/runtime/__init__.py +14 -0
  355. nat/runtime/loader.py +220 -0
  356. nat/runtime/runner.py +195 -0
  357. nat/runtime/session.py +162 -0
  358. nat/runtime/user_metadata.py +130 -0
  359. nat/settings/__init__.py +0 -0
  360. nat/settings/global_settings.py +318 -0
  361. nat/test/.namespace +1 -0
  362. nat/tool/__init__.py +0 -0
  363. nat/tool/chat_completion.py +74 -0
  364. nat/tool/code_execution/README.md +151 -0
  365. nat/tool/code_execution/__init__.py +0 -0
  366. nat/tool/code_execution/code_sandbox.py +267 -0
  367. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. nat/tool/code_execution/register.py +74 -0
  374. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. nat/tool/code_execution/utils.py +100 -0
  376. nat/tool/datetime_tools.py +42 -0
  377. nat/tool/document_search.py +141 -0
  378. nat/tool/github_tools/__init__.py +0 -0
  379. nat/tool/github_tools/create_github_commit.py +133 -0
  380. nat/tool/github_tools/create_github_issue.py +87 -0
  381. nat/tool/github_tools/create_github_pr.py +106 -0
  382. nat/tool/github_tools/get_github_file.py +106 -0
  383. nat/tool/github_tools/get_github_issue.py +166 -0
  384. nat/tool/github_tools/get_github_pr.py +256 -0
  385. nat/tool/github_tools/update_github_issue.py +100 -0
  386. nat/tool/mcp/__init__.py +14 -0
  387. nat/tool/mcp/exceptions.py +142 -0
  388. nat/tool/mcp/mcp_client.py +255 -0
  389. nat/tool/mcp/mcp_tool.py +96 -0
  390. nat/tool/memory_tools/__init__.py +0 -0
  391. nat/tool/memory_tools/add_memory_tool.py +79 -0
  392. nat/tool/memory_tools/delete_memory_tool.py +67 -0
  393. nat/tool/memory_tools/get_memory_tool.py +72 -0
  394. nat/tool/nvidia_rag.py +95 -0
  395. nat/tool/register.py +38 -0
  396. nat/tool/retriever.py +94 -0
  397. nat/tool/server_tools.py +66 -0
  398. nat/utils/__init__.py +0 -0
  399. nat/utils/data_models/__init__.py +0 -0
  400. nat/utils/data_models/schema_validator.py +58 -0
  401. nat/utils/debugging_utils.py +43 -0
  402. nat/utils/dump_distro_mapping.py +32 -0
  403. nat/utils/exception_handlers/__init__.py +0 -0
  404. nat/utils/exception_handlers/automatic_retries.py +289 -0
  405. nat/utils/exception_handlers/mcp.py +211 -0
  406. nat/utils/exception_handlers/schemas.py +114 -0
  407. nat/utils/io/__init__.py +0 -0
  408. nat/utils/io/model_processing.py +28 -0
  409. nat/utils/io/yaml_tools.py +119 -0
  410. nat/utils/log_utils.py +37 -0
  411. nat/utils/metadata_utils.py +74 -0
  412. nat/utils/optional_imports.py +142 -0
  413. nat/utils/producer_consumer_queue.py +178 -0
  414. nat/utils/reactive/__init__.py +0 -0
  415. nat/utils/reactive/base/__init__.py +0 -0
  416. nat/utils/reactive/base/observable_base.py +65 -0
  417. nat/utils/reactive/base/observer_base.py +55 -0
  418. nat/utils/reactive/base/subject_base.py +79 -0
  419. nat/utils/reactive/observable.py +59 -0
  420. nat/utils/reactive/observer.py +76 -0
  421. nat/utils/reactive/subject.py +131 -0
  422. nat/utils/reactive/subscription.py +49 -0
  423. nat/utils/settings/__init__.py +0 -0
  424. nat/utils/settings/global_settings.py +197 -0
  425. nat/utils/string_utils.py +38 -0
  426. nat/utils/type_converter.py +290 -0
  427. nat/utils/type_utils.py +484 -0
  428. nat/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0.dist-info/METADATA +365 -0
  430. nvidia_nat-1.2.0.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
  433. nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  434. nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,410 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import sys
18
+ import typing
19
+
20
+ from pydantic import BaseModel
21
+ from pydantic import ConfigDict
22
+ from pydantic import Discriminator
23
+ from pydantic import ValidationError
24
+ from pydantic import ValidationInfo
25
+ from pydantic import ValidatorFunctionWrapHandler
26
+ from pydantic import field_validator
27
+
28
+ from nat.data_models.evaluate import EvalConfig
29
+ from nat.data_models.front_end import FrontEndBaseConfig
30
+ from nat.data_models.function import EmptyFunctionConfig
31
+ from nat.data_models.function import FunctionBaseConfig
32
+ from nat.data_models.logging import LoggingBaseConfig
33
+ from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
34
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
35
+ from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
36
+
37
+ from .authentication import AuthProviderBaseConfig
38
+ from .common import HashableBaseModel
39
+ from .common import TypedBaseModel
40
+ from .embedder import EmbedderBaseConfig
41
+ from .llm import LLMBaseConfig
42
+ from .memory import MemoryBaseConfig
43
+ from .object_store import ObjectStoreBaseConfig
44
+ from .retriever import RetrieverBaseConfig
45
+
46
logger = logging.getLogger(__name__)

# Name of the ``GlobalTypeRegistry`` method that lists the registered types for each config field whose
# validator delegates to ``_process_validation_error``. ``front_end`` is the field validated by
# ``GeneralConfig``; ``front_ends`` is kept so any existing caller using that name still resolves.
_REGISTERED_KEYS_GETTER_BY_FIELD: dict[str, str] = {
    "workflow": "get_registered_functions",
    "functions": "get_registered_functions",
    "authentication": "get_registered_auth_providers",
    "llms": "get_registered_llm_providers",
    "embedders": "get_registered_embedder_providers",
    "memory": "get_registered_memorys",
    "object_stores": "get_registered_object_stores",
    "retrievers": "get_registered_retriever_providers",
    "tracing": "get_registered_telemetry_exporters",
    "logging": "get_registered_logging_method",
    "evaluators": "get_registered_evaluators",
    "front_end": "get_registered_front_ends",
    "front_ends": "get_registered_front_ends",
    "ttc_strategies": "get_registered_ttc_strategies",
}


def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
    """Improve reporting of a component validation failure raised inside a ``mode="wrap"`` field validator.

    Two error types are handled specially:

    * ``union_tag_invalid``: the requested component type (the discriminator tag) did not resolve. For the first
      such error, the registered type names for ``info.field_name`` are looked up in ``GlobalTypeRegistry``.
      A message is then logged saying whether the type was not found or is ambiguous by its local name.
      The error itself is kept unchanged.
    * ``missing``: the second element of the error location is dropped and the (modified) errors are re-raised as
      a new ``ValidationError``.

    When no ``missing`` error is present this returns ``None``; the calling validator then re-raises ``err``.

    Args:
        err: The validation error raised by the wrapped ``handler``.
        handler: The wrapped validator handler. Unused here; kept for interface compatibility.
        info: Validation info; ``info.field_name`` selects which registry listing is consulted.

    Raises:
        ValidationError: Rebuilt from the adjusted errors when at least one ``missing`` error was seen.
        AssertionError: If ``info.field_name`` has no registry mapping, or if exactly one registered type matches
            the requested local name (an unexpected internal state).
    """
    from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import

    new_errors = []
    logged_once = False
    needs_reraise = False
    for e in err.errors():

        error_type = e['type']
        if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
            requested_type = e["ctx"]["tag"]

            getter_name = _REGISTERED_KEYS_GETTER_BY_FIELD.get(info.field_name)
            if getter_name is None:
                # Raised explicitly rather than via ``assert`` so the check is not stripped under ``python -O``
                # (which would otherwise leave ``registered_keys`` unbound below).
                raise AssertionError(f"Unknown field name {info.field_name} in validator")
            registered_keys = getattr(GlobalTypeRegistry.get(), getter_name)()

            # Check whether there are multiple full types which match this short (local) type name
            matching_keys = [k for k in registered_keys if k.local_name == requested_type]

            if len(matching_keys) == 1:
                # A unique local-name match is expected to have been resolved by the discriminator already.
                raise AssertionError("Exact match should have been found. Contact developers")

            matching_key_names = [x.full_type for x in matching_keys]
            registered_key_names = [x.full_type for x in registered_keys]

            if not matching_keys:
                # This is a case where the requested type is not found. Show a helpful message about what is
                # available
                logger.error(("Requested %s type `%s` not found. "
                              "Have you ensured the necessary package has been installed with `uv pip install`?"
                              "\nAvailable %s names:\n - %s\n"),
                             info.field_name,
                             requested_type,
                             info.field_name,
                             '\n - '.join(registered_key_names))
            else:
                # This is a case where the requested type is ambiguous.
                logger.error(("Requested %s type `%s` is ambiguous. "
                              "Matched multiple %s by their local name: %s. "
                              "Please use the fully qualified %s name."
                              "\nAvailable %s names:\n - %s\n"),
                             info.field_name,
                             requested_type,
                             info.field_name,
                             matching_key_names,
                             info.field_name,
                             info.field_name,
                             '\n - '.join(registered_key_names))

            # Only show one error
            logged_once = True

        elif error_type == 'missing':
            location = e["loc"]
            if len(location) > 1:
                # Drop the second location element, intended to be the discriminator tag (the "_type" field).
                # NOTE(review): for non-dict fields (e.g. ``workflow``, ``front_end``) the tag may sit at index 0,
                # in which case this drops the missing field's name instead -- confirm against pydantic's loc layout.
                e['loc'] = (location[0], ) + location[2:]
            needs_reraise = True

        new_errors.append(e)

    if needs_reraise:
        raise ValidationError.from_exception_data(title=err.title, line_errors=new_errors)
134
+
135
+
136
class TelemetryConfig(BaseModel):
    """Telemetry settings: named logging-method configs and named tracing (telemetry exporter) configs."""

    # Mapping of name -> logging method configuration.
    logging: dict[str, LoggingBaseConfig] = {}
    # Mapping of name -> telemetry exporter configuration.
    tracing: dict[str, TelemetryExporterBaseConfig] = {}

    @field_validator("logging", "tracing", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """Run the default validation; on failure, route the error through ``_process_validation_error``.

        ``_process_validation_error`` may raise its own rebuilt ``ValidationError``; otherwise the original error
        is re-raised unchanged.
        """
        try:
            validated = handler(value)
        except ValidationError as err:
            _process_validation_error(err, handler, info)
            raise
        return validated

    @classmethod
    def rebuild_annotations(cls):
        """Point the ``tracing``/``logging`` fields at the discriminated unions currently known to the registry.

        Returns:
            The result of ``cls.model_rebuild(force=True)`` when at least one field annotation changed, otherwise
            ``False``.
        """
        from nat.cli.type_registry import GlobalTypeRegistry

        registry = GlobalTypeRegistry.get()

        def _tagged_mapping(base_config):
            # dict[str, <registry union for base_config>] discriminated by the TypedBaseModel tag.
            return dict[str,
                        typing.Annotated[registry.compute_annotation(base_config),
                                         Discriminator(TypedBaseModel.discriminator)]]

        # Built and checked in the same order as before: tracing first, then logging.
        desired_annotations = {
            "tracing": _tagged_mapping(TelemetryExporterBaseConfig),
            "logging": _tagged_mapping(LoggingBaseConfig),
        }

        annotations_changed = False
        for field_name, annotation in desired_annotations.items():
            field_info = cls.model_fields.get(field_name)
            if field_info is None or field_info.annotation == annotation:
                continue
            field_info.annotation = annotation
            annotations_changed = True

        return cls.model_rebuild(force=True) if annotations_changed else False
182
+
183
+
184
class GeneralConfig(BaseModel):
    """Global options: event-loop choice, telemetry settings, and the front end used to serve the workflow."""

    model_config = ConfigDict(protected_namespaces=())

    use_uvloop: bool = True
    """
    Whether to use uvloop for the event loop. This can provide a significant speedup in some cases. Disable to provide
    better error messages when debugging.
    """

    # Logging and tracing configuration.
    telemetry: TelemetryConfig = TelemetryConfig()

    # FrontEnd Configuration; defaults to the FastAPI front end.
    front_end: FrontEndBaseConfig = FastApiFrontEndConfig()

    @field_validator("front_end", mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """Run the default validation; on failure, route the error through ``_process_validation_error``.

        ``_process_validation_error`` may raise its own rebuilt ``ValidationError``; otherwise the original error
        is re-raised unchanged.
        """
        try:
            validated = handler(value)
        except ValidationError as err:
            _process_validation_error(err, handler, info)
            raise
        return validated

    @classmethod
    def rebuild_annotations(cls):
        """Refresh the ``front_end`` annotation from the registry and cascade to ``TelemetryConfig``.

        Returns:
            The result of ``cls.model_rebuild(force=True)`` when the ``front_end`` annotation changed or
            ``TelemetryConfig.rebuild_annotations()`` reported a rebuild, otherwise ``False``.
        """
        from nat.cli.type_registry import GlobalTypeRegistry

        registry = GlobalTypeRegistry.get()

        front_end_annotation = typing.Annotated[registry.compute_annotation(FrontEndBaseConfig),
                                                Discriminator(TypedBaseModel.discriminator)]

        field_info = cls.model_fields.get("front_end")
        needs_rebuild = field_info is not None and field_info.annotation != front_end_annotation
        if needs_rebuild:
            field_info.annotation = front_end_annotation

        # Evaluated unconditionally (no short-circuit) so the nested telemetry model is always refreshed.
        telemetry_rebuilt = TelemetryConfig.rebuild_annotations()
        if telemetry_rebuilt:
            needs_rebuild = True

        if not needs_rebuild:
            return False
        return cls.model_rebuild(force=True)
233
+
234
+
235
class Config(HashableBaseModel):
    """Top-level workflow configuration.

    Each component section (``functions``, ``llms``, ``embedders``, ...) maps a user-chosen name to a component
    configuration; ``workflow`` holds a single function configuration. The declared annotations use the base config
    types and are replaced by :meth:`rebuild_annotations` with the annotations computed by the global type registry.
    """

    model_config = ConfigDict(extra="forbid")

    # Global Options
    general: GeneralConfig = GeneralConfig()

    # Functions Configuration
    functions: dict[str, FunctionBaseConfig] = {}

    # LLMs Configuration
    llms: dict[str, LLMBaseConfig] = {}

    # Embedders Configuration
    embedders: dict[str, EmbedderBaseConfig] = {}

    # Memory Configuration
    memory: dict[str, MemoryBaseConfig] = {}

    # Object Stores Configuration
    object_stores: dict[str, ObjectStoreBaseConfig] = {}

    # Retriever Configuration
    retrievers: dict[str, RetrieverBaseConfig] = {}

    # TTC Strategies
    ttc_strategies: dict[str, TTCStrategyBaseConfig] = {}

    # Workflow Configuration
    workflow: FunctionBaseConfig = EmptyFunctionConfig()

    # Authentication Configuration
    authentication: dict[str, AuthProviderBaseConfig] = {}

    # Evaluation Options
    eval: EvalConfig = EvalConfig()

    def print_summary(self, stream: typing.TextIO | None = None):
        """Print a summary of the configuration.

        Args:
            stream: Text stream to write to. When omitted, the *current* ``sys.stdout`` is used. It is looked up at
                call time rather than bound at import time, so redirection of ``sys.stdout`` (for example
                ``contextlib.redirect_stdout``) is honoured.
        """
        if stream is None:
            stream = sys.stdout

        stream.write("\nConfiguration Summary:\n")
        stream.write("-" * 20 + "\n")
        if self.workflow:
            stream.write(f"Workflow Type: {self.workflow.type}\n")

        stream.write(f"Number of Functions: {len(self.functions)}\n")
        stream.write(f"Number of LLMs: {len(self.llms)}\n")
        stream.write(f"Number of Embedders: {len(self.embedders)}\n")
        stream.write(f"Number of Memory: {len(self.memory)}\n")
        stream.write(f"Number of Object Stores: {len(self.object_stores)}\n")
        stream.write(f"Number of Retrievers: {len(self.retrievers)}\n")
        stream.write(f"Number of TTC Strategies: {len(self.ttc_strategies)}\n")
        stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")

    # Every component section is listed here, including ``object_stores``, which was previously the only section
    # missing from this validator.
    # NOTE(review): confirm that _process_validation_error has a branch for the "object_stores" field name.
    @field_validator("functions",
                     "llms",
                     "embedders",
                     "memory",
                     "object_stores",
                     "retrievers",
                     "workflow",
                     "ttc_strategies",
                     "authentication",
                     mode="wrap")
    @classmethod
    def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
        """Run the default validation; on failure, pass the error to ``_process_validation_error`` and re-raise."""
        try:
            return handler(value)
        except ValidationError as err:
            _process_validation_error(err, handler, info)
            raise

    @classmethod
    def rebuild_annotations(cls):
        """Re-derive the component field annotations from the global type registry.

        For each component section, the target annotation is the registry-computed annotation for that section's
        base config type, wrapped with ``Discriminator(TypedBaseModel.discriminator)``. Sections other than
        ``workflow`` are additionally wrapped as ``dict[str, ...]``. Fields whose annotation differs are updated in
        place. ``GeneralConfig`` and ``EvalConfig`` are always refreshed as well.

        Returns:
            ``False`` when nothing changed; otherwise the result of ``cls.model_rebuild(force=True)``.
        """
        from nat.cli.type_registry import GlobalTypeRegistry

        type_registry = GlobalTypeRegistry.get()

        def _discriminated(base_type: type) -> typing.Any:
            # Registry-computed annotation for ``base_type``, tagged with the shared TypedBaseModel discriminator.
            return typing.Annotated[type_registry.compute_annotation(base_type),
                                    Discriminator(TypedBaseModel.discriminator)]

        # Target annotation for every component field. All sections except ``workflow`` are name -> config maps.
        expected_annotations: dict[str, typing.Any] = {
            "authentication": dict[str, _discriminated(AuthProviderBaseConfig)],
            "llms": dict[str, _discriminated(LLMBaseConfig)],
            "embedders": dict[str, _discriminated(EmbedderBaseConfig)],
            "functions": dict[str, _discriminated(FunctionBaseConfig)],
            "memory": dict[str, _discriminated(MemoryBaseConfig)],
            "object_stores": dict[str, _discriminated(ObjectStoreBaseConfig)],
            "retrievers": dict[str, _discriminated(RetrieverBaseConfig)],
            "ttc_strategies": dict[str, _discriminated(TTCStrategyBaseConfig)],
            "workflow": _discriminated(FunctionBaseConfig),
        }

        should_rebuild = False
        for field_name, annotation in expected_annotations.items():
            field_info = cls.model_fields.get(field_name)
            if field_info is not None and field_info.annotation != annotation:
                field_info.annotation = annotation
                should_rebuild = True

        # Both nested models are refreshed unconditionally (no short-circuiting). This model has fields of both
        # types, so a change in either also forces a rebuild here.
        if GeneralConfig.rebuild_annotations():
            should_rebuild = True

        if EvalConfig.rebuild_annotations():
            should_rebuild = True

        if should_rebuild:
            return cls.model_rebuild(force=True)

        return False
407
+
408
+
409
# Compatibility aliases with previous releases
# ``AIQConfig`` is the very same class object as ``Config`` (not a subclass). It exists only so code written
# against the older name keeps working.
AIQConfig = Config
@@ -0,0 +1,169 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import json
18
+ import typing
19
+ from collections.abc import Callable
20
+ from pathlib import Path
21
+
22
+ import pandas as pd
23
+ from pydantic import BaseModel
24
+ from pydantic import Discriminator
25
+ from pydantic import FilePath
26
+ from pydantic import Tag
27
+
28
+ from nat.data_models.common import BaseModelRegistryTag
29
+ from nat.data_models.common import TypedBaseModel
30
+
31
+
32
class EvalS3Config(BaseModel):
    """S3 connection settings for an evaluation dataset.

    ``endpoint_url`` and ``region_name`` are optional; ``bucket`` and the credential pair are required.
    """

    endpoint_url: typing.Optional[str] = None
    region_name: typing.Optional[str] = None
    bucket: str
    access_key: str
    # NOTE(review): plain ``str`` means the secret is included in repr()/model_dump(); consider pydantic.SecretStr.
    secret_key: str
39
+
40
+
41
class EvalFilterEntryConfig(BaseModel):
    """One side (allow or deny) of an evaluation dataset filter."""

    # Each key maps to the list of values that are allowed (allowlist) or blocked (denylist).
    # NOTE(review): keys are presumably dataset column names — confirm against the filter implementation.
    field: dict[str, list[typing.Union[str, int, float]]] = {}
44
+
45
+
46
class EvalFilterConfig(BaseModel):
    """Optional allowlist and denylist for an evaluation dataset; both default to ``None`` (unset)."""

    allowlist: typing.Optional[EvalFilterEntryConfig] = None
    denylist: typing.Optional[EvalFilterEntryConfig] = None
49
+
50
+
51
class EvalDatasetStructureConfig(BaseModel):
    """Key names under which each part of an evaluation record is stored in the dataset."""

    # NOTE(review): the effect of ``disable`` is decided by the consumer of this config — presumably it turns off
    # the key mapping below.
    disable: bool = False

    # Input / reference answer.
    question_key: str = "question"
    answer_key: str = "answer"

    # Output produced by the workflow under evaluation.
    generated_answer_key: str = "generated_answer"

    # Actual vs. expected intermediate steps.
    trajectory_key: str = "intermediate_steps"
    expected_trajectory_key: str = "expected_intermediate_steps"
58
+
59
+
60
# Base model
class EvalDatasetBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """Settings shared by every evaluation dataset type.

    Concrete dataset types subclass this with a ``name=`` class keyword and provide a ``parser()`` that returns the
    reader callable together with the keyword arguments to call it with.
    """

    # Key holding each entry's identifier (presumably a column name — confirm with the dataset loader).
    id_key: str = "id"
    structure: EvalDatasetStructureConfig = EvalDatasetStructureConfig()

    # Filters (allow/deny lists); an empty filter config by default.
    filter: typing.Optional[EvalFilterConfig] = EvalFilterConfig()

    # Optional S3 source settings.
    s3: typing.Optional[EvalS3Config] = None

    remote_file_path: typing.Optional[str] = None  # only for s3
    # Dataset file location; defaults to .tmp/nat/examples/default/default.json.
    file_path: typing.Union[Path, str] = Path(".tmp/nat/examples/default/default.json")
73
+
74
+
75
class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
    """Dataset stored as a JSON document, loaded with ``pandas.read_json``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``pandas.read_json`` and an empty keyword-argument dict."""
        reader_kwargs: dict = {}
        return (pd.read_json, reader_kwargs)
80
+
81
+
82
+ def read_jsonl(file_path: FilePath):
83
+ with open(file_path, 'r', encoding='utf-8') as f:
84
+ data = [json.loads(line) for line in f]
85
+ return pd.DataFrame(data)
86
+
87
+
88
class EvalDatasetJsonlConfig(EvalDatasetBaseConfig, name="jsonl"):
    """Dataset stored as JSON Lines, loaded with the module-level ``read_jsonl`` helper."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``read_jsonl`` and an empty keyword-argument dict."""
        reader_kwargs: dict = {}
        return (read_jsonl, reader_kwargs)
93
+
94
+
95
class EvalDatasetCsvConfig(EvalDatasetBaseConfig, name="csv"):
    """Dataset stored as CSV, loaded with ``pandas.read_csv``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``pandas.read_csv`` and an empty keyword-argument dict."""
        reader_kwargs: dict = {}
        return (pd.read_csv, reader_kwargs)
100
+
101
+
102
class EvalDatasetParquetConfig(EvalDatasetBaseConfig, name="parquet"):
    """Dataset stored as Parquet, loaded with ``pandas.read_parquet``."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``pandas.read_parquet`` and an empty keyword-argument dict."""
        reader_kwargs: dict = {}
        return (pd.read_parquet, reader_kwargs)
107
+
108
+
109
class EvalDatasetXlsConfig(EvalDatasetBaseConfig, name="xls"):
    """Dataset stored as an Excel workbook, loaded with ``pandas.read_excel`` using the ``openpyxl`` engine."""

    @staticmethod
    def parser() -> tuple[Callable, dict]:
        """Return ``pandas.read_excel`` together with keyword arguments that pin the reader engine."""
        # NOTE(review): openpyxl reads the .xlsx family of workbooks; confirm whether legacy .xls files are meant to
        # be supported, since the "xls" tag suggests they might be.
        excel_kwargs = {"engine": "openpyxl"}
        return (pd.read_excel, excel_kwargs)
114
+
115
+
116
class EvalDatasetCustomConfig(EvalDatasetBaseConfig, name="custom"):
    """
    Configuration for custom dataset type that allows users to specify
    a custom Python function to transform their dataset into EvalInput format.
    """

    function: str  # Direct import path to function, format: "module.path.function_name"
    kwargs: dict[str, typing.Any] = {}  # Additional arguments to pass to the custom function

    def parser(self) -> tuple[Callable, dict]:
        """
        Load and return the custom function for dataset transformation.

        Returns:
            Tuple of (custom_function, kwargs) where custom_function transforms
            a dataset file into an EvalInput object.

        Raises:
            ValueError: If ``function`` is empty, has no module part, or does not name a callable.
            ImportError: If the module part cannot be imported (``ModuleNotFoundError`` when it does not exist).
            AttributeError: If the module has no attribute with the given name.
        """
        custom_function = self._load_custom_function()
        return custom_function, self.kwargs

    def _load_custom_function(self) -> Callable:
        """
        Import and return the custom function using standard Python import path.

        The path is split at its last ".". Everything before it is imported as a module, and the remainder is looked
        up as an attribute of that module.
        """
        if not self.function:
            raise ValueError("Function path cannot be empty")

        # rpartition never raises. A path without "." (or with a leading ".") yields an empty module part, which
        # is reported with a clear message instead of an opaque tuple-unpacking / "Empty module name" error.
        module_path, _, function_name = self.function.rpartition(".")
        if not module_path:
            raise ValueError(f"Invalid function path '{self.function}': expected the format "
                             "'module.path.function_name'")

        # Import the module
        module = importlib.import_module(module_path)

        # A sentinel distinguishes a missing attribute from one whose value is None, using a single lookup.
        missing = object()
        custom_function = getattr(module, function_name, missing)
        if custom_function is missing:
            raise AttributeError(f"Function '{function_name}' not found in module '{module_path}'")

        if not callable(custom_function):
            raise ValueError(f"'{self.function}' is not callable")

        return custom_function
159
+
160
+
161
# Union model with discriminator: every member is tagged with its own ``static_type()``, and the member to validate
# against is chosen by ``TypedBaseModel.discriminator``.
EvalDatasetConfig = typing.Annotated[
    typing.Union[
        typing.Annotated[EvalDatasetJsonConfig, Tag(EvalDatasetJsonConfig.static_type())],
        typing.Annotated[EvalDatasetCsvConfig, Tag(EvalDatasetCsvConfig.static_type())],
        typing.Annotated[EvalDatasetXlsConfig, Tag(EvalDatasetXlsConfig.static_type())],
        typing.Annotated[EvalDatasetParquetConfig, Tag(EvalDatasetParquetConfig.static_type())],
        typing.Annotated[EvalDatasetJsonlConfig, Tag(EvalDatasetJsonlConfig.static_type())],
        typing.Annotated[EvalDatasetCustomConfig, Tag(EvalDatasetCustomConfig.static_type())],
    ],
    Discriminator(TypedBaseModel.discriminator),
]