nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (492) hide show
  1. aiq/__init__.py +1 -1
  2. nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
  3. nat/agent/auto_memory_wrapper/agent.py +278 -0
  4. nat/agent/auto_memory_wrapper/register.py +227 -0
  5. nat/agent/auto_memory_wrapper/state.py +30 -0
  6. nat/agent/base.py +1 -1
  7. nat/agent/dual_node.py +1 -1
  8. nat/agent/prompt_optimizer/prompt.py +1 -1
  9. nat/agent/prompt_optimizer/register.py +1 -1
  10. nat/agent/react_agent/agent.py +16 -9
  11. nat/agent/react_agent/output_parser.py +2 -2
  12. nat/agent/react_agent/prompt.py +3 -2
  13. nat/agent/react_agent/register.py +2 -2
  14. nat/agent/react_agent/register_per_user_agent.py +104 -0
  15. nat/agent/reasoning_agent/reasoning_agent.py +1 -1
  16. nat/agent/register.py +3 -1
  17. nat/agent/responses_api_agent/__init__.py +1 -1
  18. nat/agent/responses_api_agent/register.py +1 -1
  19. nat/agent/rewoo_agent/agent.py +9 -4
  20. nat/agent/rewoo_agent/prompt.py +1 -1
  21. nat/agent/rewoo_agent/register.py +1 -1
  22. nat/agent/tool_calling_agent/agent.py +5 -4
  23. nat/agent/tool_calling_agent/register.py +1 -1
  24. nat/authentication/__init__.py +1 -1
  25. nat/authentication/api_key/__init__.py +1 -1
  26. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  27. nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
  28. nat/authentication/api_key/register.py +1 -1
  29. nat/authentication/credential_validator/__init__.py +1 -1
  30. nat/authentication/credential_validator/bearer_token_validator.py +1 -1
  31. nat/authentication/exceptions/__init__.py +1 -1
  32. nat/authentication/exceptions/api_key_exceptions.py +1 -1
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  34. nat/authentication/http_basic_auth/register.py +1 -1
  35. nat/authentication/interfaces.py +1 -1
  36. nat/authentication/oauth2/__init__.py +1 -1
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
  40. nat/authentication/oauth2/register.py +1 -1
  41. nat/authentication/register.py +1 -1
  42. nat/builder/builder.py +563 -1
  43. nat/builder/child_builder.py +385 -0
  44. nat/builder/component_utils.py +34 -4
  45. nat/builder/context.py +34 -1
  46. nat/builder/embedder.py +1 -1
  47. nat/builder/eval_builder.py +19 -7
  48. nat/builder/evaluator.py +1 -1
  49. nat/builder/framework_enum.py +3 -1
  50. nat/builder/front_end.py +1 -1
  51. nat/builder/function.py +113 -5
  52. nat/builder/function_base.py +1 -1
  53. nat/builder/function_info.py +1 -1
  54. nat/builder/intermediate_step_manager.py +1 -1
  55. nat/builder/llm.py +1 -1
  56. nat/builder/per_user_workflow_builder.py +843 -0
  57. nat/builder/retriever.py +1 -1
  58. nat/builder/sync_builder.py +571 -0
  59. nat/builder/user_interaction_manager.py +1 -1
  60. nat/builder/workflow.py +5 -3
  61. nat/builder/workflow_builder.py +619 -378
  62. nat/cli/__init__.py +1 -1
  63. nat/cli/cli_utils/config_override.py +1 -1
  64. nat/cli/cli_utils/validation.py +32 -1
  65. nat/cli/commands/configure/channel/add.py +1 -1
  66. nat/cli/commands/configure/channel/channel.py +1 -1
  67. nat/cli/commands/configure/channel/remove.py +1 -1
  68. nat/cli/commands/configure/channel/update.py +1 -1
  69. nat/cli/commands/configure/configure.py +1 -1
  70. nat/cli/commands/evaluate.py +87 -13
  71. nat/cli/commands/finetune.py +132 -0
  72. nat/cli/commands/info/__init__.py +1 -1
  73. nat/cli/commands/info/info.py +1 -1
  74. nat/cli/commands/info/list_channels.py +1 -1
  75. nat/cli/commands/info/list_components.py +1 -1
  76. nat/cli/commands/object_store/__init__.py +1 -1
  77. nat/cli/commands/object_store/object_store.py +1 -1
  78. nat/cli/commands/optimize.py +1 -1
  79. nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
  80. nat/cli/commands/red_teaming/red_teaming.py +138 -0
  81. nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
  82. nat/cli/commands/registry/__init__.py +1 -1
  83. nat/cli/commands/registry/publish.py +1 -1
  84. nat/cli/commands/registry/pull.py +1 -1
  85. nat/cli/commands/registry/registry.py +1 -1
  86. nat/cli/commands/registry/remove.py +1 -1
  87. nat/cli/commands/registry/search.py +1 -1
  88. nat/cli/commands/sizing/__init__.py +1 -1
  89. nat/cli/commands/sizing/calc.py +1 -1
  90. nat/cli/commands/sizing/sizing.py +1 -1
  91. nat/cli/commands/start.py +1 -1
  92. nat/cli/commands/uninstall.py +1 -1
  93. nat/cli/commands/validate.py +1 -1
  94. nat/cli/commands/workflow/__init__.py +1 -1
  95. nat/cli/commands/workflow/workflow.py +1 -1
  96. nat/cli/commands/workflow/workflow_commands.py +3 -2
  97. nat/cli/entrypoint.py +15 -37
  98. nat/cli/main.py +2 -2
  99. nat/cli/plugin_loader.py +69 -0
  100. nat/cli/register_workflow.py +233 -5
  101. nat/cli/type_registry.py +237 -3
  102. nat/control_flow/register.py +1 -1
  103. nat/control_flow/router_agent/agent.py +1 -1
  104. nat/control_flow/router_agent/prompt.py +1 -1
  105. nat/control_flow/router_agent/register.py +1 -1
  106. nat/control_flow/sequential_executor.py +28 -7
  107. nat/data_models/__init__.py +1 -1
  108. nat/data_models/agent.py +1 -1
  109. nat/data_models/api_server.py +38 -3
  110. nat/data_models/authentication.py +1 -1
  111. nat/data_models/common.py +1 -1
  112. nat/data_models/component.py +9 -1
  113. nat/data_models/component_ref.py +45 -1
  114. nat/data_models/config.py +78 -1
  115. nat/data_models/dataset_handler.py +15 -2
  116. nat/data_models/discovery_metadata.py +1 -1
  117. nat/data_models/embedder.py +1 -1
  118. nat/data_models/evaluate.py +6 -1
  119. nat/data_models/evaluator.py +1 -1
  120. nat/data_models/finetuning.py +260 -0
  121. nat/data_models/front_end.py +1 -1
  122. nat/data_models/function.py +15 -2
  123. nat/data_models/function_dependencies.py +1 -1
  124. nat/data_models/gated_field_mixin.py +1 -1
  125. nat/data_models/interactive.py +1 -1
  126. nat/data_models/intermediate_step.py +29 -2
  127. nat/data_models/invocation_node.py +1 -1
  128. nat/data_models/llm.py +1 -1
  129. nat/data_models/logging.py +1 -1
  130. nat/data_models/memory.py +1 -1
  131. nat/data_models/middleware.py +37 -0
  132. nat/data_models/object_store.py +1 -1
  133. nat/data_models/openai_mcp.py +1 -1
  134. nat/data_models/optimizable.py +1 -1
  135. nat/data_models/optimizer.py +1 -1
  136. nat/data_models/profiler.py +1 -1
  137. nat/data_models/registry_handler.py +1 -1
  138. nat/data_models/retriever.py +1 -1
  139. nat/data_models/retry_mixin.py +1 -1
  140. nat/data_models/runtime_enum.py +26 -0
  141. nat/data_models/span.py +1 -1
  142. nat/data_models/step_adaptor.py +1 -1
  143. nat/data_models/streaming.py +1 -1
  144. nat/data_models/swe_bench_model.py +1 -1
  145. nat/data_models/telemetry_exporter.py +1 -1
  146. nat/data_models/thinking_mixin.py +1 -1
  147. nat/data_models/ttc_strategy.py +1 -1
  148. nat/embedder/azure_openai_embedder.py +1 -1
  149. nat/embedder/nim_embedder.py +1 -1
  150. nat/embedder/openai_embedder.py +1 -1
  151. nat/embedder/register.py +1 -1
  152. nat/eval/__init__.py +1 -1
  153. nat/eval/config.py +8 -1
  154. nat/eval/dataset_handler/dataset_downloader.py +1 -1
  155. nat/eval/dataset_handler/dataset_filter.py +1 -1
  156. nat/eval/dataset_handler/dataset_handler.py +4 -2
  157. nat/eval/evaluate.py +226 -81
  158. nat/eval/evaluator/__init__.py +1 -1
  159. nat/eval/evaluator/base_evaluator.py +2 -2
  160. nat/eval/evaluator/evaluator_model.py +3 -2
  161. nat/eval/intermediate_step_adapter.py +1 -1
  162. nat/eval/llm_validator.py +336 -0
  163. nat/eval/rag_evaluator/evaluate.py +17 -10
  164. nat/eval/rag_evaluator/register.py +1 -1
  165. nat/eval/red_teaming_evaluator/__init__.py +14 -0
  166. nat/eval/red_teaming_evaluator/data_models.py +66 -0
  167. nat/eval/red_teaming_evaluator/evaluate.py +327 -0
  168. nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
  169. nat/eval/red_teaming_evaluator/register.py +55 -0
  170. nat/eval/register.py +2 -1
  171. nat/eval/remote_workflow.py +1 -1
  172. nat/eval/runners/__init__.py +1 -1
  173. nat/eval/runners/config.py +1 -1
  174. nat/eval/runners/multi_eval_runner.py +1 -1
  175. nat/eval/runners/red_teaming_runner/__init__.py +24 -0
  176. nat/eval/runners/red_teaming_runner/config.py +282 -0
  177. nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
  178. nat/eval/runners/red_teaming_runner/runner.py +867 -0
  179. nat/eval/runtime_evaluator/__init__.py +1 -1
  180. nat/eval/runtime_evaluator/evaluate.py +1 -1
  181. nat/eval/runtime_evaluator/register.py +1 -1
  182. nat/eval/runtime_event_subscriber.py +1 -1
  183. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  184. nat/eval/swe_bench_evaluator/register.py +1 -1
  185. nat/eval/trajectory_evaluator/evaluate.py +2 -2
  186. nat/eval/trajectory_evaluator/register.py +1 -1
  187. nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
  188. nat/eval/tunable_rag_evaluator/register.py +1 -1
  189. nat/eval/usage_stats.py +1 -1
  190. nat/eval/utils/eval_trace_ctx.py +1 -1
  191. nat/eval/utils/output_uploader.py +1 -1
  192. nat/eval/utils/tqdm_position_registry.py +1 -1
  193. nat/eval/utils/weave_eval.py +1 -1
  194. nat/experimental/decorators/experimental_warning_decorator.py +1 -1
  195. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
  196. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
  197. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
  198. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  199. nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
  200. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
  201. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  202. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  203. nat/experimental/test_time_compute/models/editor_config.py +1 -1
  204. nat/experimental/test_time_compute/models/scoring_config.py +1 -1
  205. nat/experimental/test_time_compute/models/search_config.py +20 -2
  206. nat/experimental/test_time_compute/models/selection_config.py +33 -2
  207. nat/experimental/test_time_compute/models/stage_enums.py +1 -1
  208. nat/experimental/test_time_compute/models/strategy_base.py +1 -1
  209. nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
  210. nat/experimental/test_time_compute/models/ttc_item.py +1 -1
  211. nat/experimental/test_time_compute/register.py +4 -1
  212. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
  213. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
  214. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
  215. nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
  216. nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
  217. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
  218. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
  219. nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
  220. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
  221. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  222. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
  223. nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
  224. nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
  225. nat/finetuning/__init__.py +24 -0
  226. nat/finetuning/finetuning_runtime.py +143 -0
  227. nat/finetuning/interfaces/__init__.py +24 -0
  228. nat/finetuning/interfaces/finetuning_runner.py +261 -0
  229. nat/finetuning/interfaces/trainer_adapter.py +103 -0
  230. nat/finetuning/interfaces/trajectory_builder.py +115 -0
  231. nat/finetuning/utils/__init__.py +15 -0
  232. nat/finetuning/utils/parsers/__init__.py +15 -0
  233. nat/finetuning/utils/parsers/adk_parser.py +141 -0
  234. nat/finetuning/utils/parsers/base_parser.py +238 -0
  235. nat/finetuning/utils/parsers/common.py +91 -0
  236. nat/finetuning/utils/parsers/langchain_parser.py +267 -0
  237. nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
  238. nat/front_ends/__init__.py +1 -1
  239. nat/front_ends/console/__init__.py +1 -1
  240. nat/front_ends/console/authentication_flow_handler.py +1 -1
  241. nat/front_ends/console/console_front_end_config.py +4 -1
  242. nat/front_ends/console/console_front_end_plugin.py +5 -4
  243. nat/front_ends/console/register.py +1 -1
  244. nat/front_ends/cron/__init__.py +1 -1
  245. nat/front_ends/fastapi/__init__.py +1 -1
  246. nat/front_ends/fastapi/async_job.py +128 -0
  247. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  248. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
  249. nat/front_ends/fastapi/dask_client_mixin.py +1 -1
  250. nat/front_ends/fastapi/fastapi_front_end_config.py +23 -1
  251. nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
  252. nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
  253. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +318 -59
  254. nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
  255. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
  256. nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
  257. nat/front_ends/fastapi/job_store.py +23 -11
  258. nat/front_ends/fastapi/main.py +1 -1
  259. nat/front_ends/fastapi/message_handler.py +27 -4
  260. nat/front_ends/fastapi/message_validator.py +54 -2
  261. nat/front_ends/fastapi/register.py +1 -1
  262. nat/front_ends/fastapi/response_helpers.py +16 -15
  263. nat/front_ends/fastapi/step_adaptor.py +1 -1
  264. nat/front_ends/fastapi/utils.py +1 -1
  265. nat/front_ends/register.py +1 -2
  266. nat/front_ends/simple_base/__init__.py +1 -1
  267. nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
  268. nat/llm/aws_bedrock_llm.py +1 -1
  269. nat/llm/azure_openai_llm.py +10 -1
  270. nat/llm/dynamo_llm.py +363 -0
  271. nat/llm/huggingface_llm.py +177 -0
  272. nat/llm/litellm_llm.py +1 -1
  273. nat/llm/nim_llm.py +1 -1
  274. nat/llm/openai_llm.py +1 -1
  275. nat/llm/register.py +3 -1
  276. nat/llm/utils/__init__.py +1 -1
  277. nat/llm/utils/env_config_value.py +1 -1
  278. nat/llm/utils/error.py +1 -1
  279. nat/llm/utils/thinking.py +1 -1
  280. nat/memory/__init__.py +1 -1
  281. nat/memory/interfaces.py +1 -1
  282. nat/memory/models.py +1 -1
  283. nat/meta/pypi.md +1 -1
  284. nat/middleware/__init__.py +35 -0
  285. nat/middleware/cache/__init__.py +14 -0
  286. nat/middleware/cache/cache_middleware.py +253 -0
  287. nat/middleware/cache/cache_middleware_config.py +44 -0
  288. nat/middleware/cache/register.py +33 -0
  289. nat/middleware/defense/__init__.py +14 -0
  290. nat/middleware/defense/defense_middleware.py +362 -0
  291. nat/middleware/defense/defense_middleware_content_guard.py +455 -0
  292. nat/middleware/defense/defense_middleware_data_models.py +91 -0
  293. nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
  294. nat/middleware/defense/defense_middleware_pii.py +356 -0
  295. nat/middleware/defense/register.py +82 -0
  296. nat/middleware/dynamic/__init__.py +14 -0
  297. nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
  298. nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
  299. nat/middleware/dynamic/register.py +34 -0
  300. nat/middleware/function_middleware.py +370 -0
  301. nat/middleware/logging/__init__.py +14 -0
  302. nat/middleware/logging/logging_middleware.py +67 -0
  303. nat/middleware/logging/logging_middleware_config.py +28 -0
  304. nat/middleware/logging/register.py +33 -0
  305. nat/middleware/middleware.py +298 -0
  306. nat/middleware/red_teaming/__init__.py +14 -0
  307. nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
  308. nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
  309. nat/middleware/red_teaming/register.py +47 -0
  310. nat/middleware/register.py +22 -0
  311. nat/middleware/utils/__init__.py +14 -0
  312. nat/middleware/utils/workflow_inventory.py +155 -0
  313. nat/object_store/__init__.py +1 -1
  314. nat/object_store/in_memory_object_store.py +1 -1
  315. nat/object_store/interfaces.py +1 -1
  316. nat/object_store/models.py +1 -1
  317. nat/object_store/register.py +1 -1
  318. nat/observability/__init__.py +1 -1
  319. nat/observability/exporter/__init__.py +1 -1
  320. nat/observability/exporter/base_exporter.py +1 -1
  321. nat/observability/exporter/exporter.py +1 -1
  322. nat/observability/exporter/file_exporter.py +1 -1
  323. nat/observability/exporter/processing_exporter.py +1 -1
  324. nat/observability/exporter/raw_exporter.py +1 -1
  325. nat/observability/exporter/span_exporter.py +7 -1
  326. nat/observability/exporter_manager.py +1 -1
  327. nat/observability/mixin/__init__.py +1 -1
  328. nat/observability/mixin/batch_config_mixin.py +1 -1
  329. nat/observability/mixin/collector_config_mixin.py +1 -1
  330. nat/observability/mixin/file_mixin.py +1 -1
  331. nat/observability/mixin/file_mode.py +1 -1
  332. nat/observability/mixin/redaction_config_mixin.py +1 -1
  333. nat/observability/mixin/resource_conflict_mixin.py +1 -1
  334. nat/observability/mixin/serialize_mixin.py +1 -1
  335. nat/observability/mixin/tagging_config_mixin.py +1 -1
  336. nat/observability/mixin/type_introspection_mixin.py +1 -1
  337. nat/observability/processor/__init__.py +1 -1
  338. nat/observability/processor/batching_processor.py +1 -1
  339. nat/observability/processor/callback_processor.py +1 -1
  340. nat/observability/processor/falsy_batch_filter_processor.py +1 -1
  341. nat/observability/processor/intermediate_step_serializer.py +1 -1
  342. nat/observability/processor/processor.py +1 -1
  343. nat/observability/processor/processor_factory.py +1 -1
  344. nat/observability/processor/redaction/__init__.py +1 -1
  345. nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
  346. nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
  347. nat/observability/processor/redaction/redaction_processor.py +1 -1
  348. nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
  349. nat/observability/processor/span_tagging_processor.py +1 -1
  350. nat/observability/register.py +1 -1
  351. nat/observability/utils/__init__.py +1 -1
  352. nat/observability/utils/dict_utils.py +1 -1
  353. nat/observability/utils/time_utils.py +1 -1
  354. nat/profiler/calc/__init__.py +1 -1
  355. nat/profiler/calc/calc_runner.py +3 -3
  356. nat/profiler/calc/calculations.py +1 -1
  357. nat/profiler/calc/data_models.py +1 -1
  358. nat/profiler/calc/plot.py +30 -3
  359. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  360. nat/profiler/callbacks/base_callback_class.py +1 -1
  361. nat/profiler/callbacks/langchain_callback_handler.py +33 -3
  362. nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
  363. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  364. nat/profiler/callbacks/token_usage_base_model.py +1 -1
  365. nat/profiler/data_frame_row.py +1 -1
  366. nat/profiler/data_models.py +1 -1
  367. nat/profiler/decorators/framework_wrapper.py +32 -1
  368. nat/profiler/decorators/function_tracking.py +1 -1
  369. nat/profiler/forecasting/config.py +1 -1
  370. nat/profiler/forecasting/model_trainer.py +1 -1
  371. nat/profiler/forecasting/models/__init__.py +1 -1
  372. nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
  373. nat/profiler/forecasting/models/linear_model.py +1 -1
  374. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  375. nat/profiler/inference_metrics_model.py +1 -1
  376. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  377. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  378. nat/profiler/inference_optimization/data_models.py +1 -1
  379. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
  380. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  381. nat/profiler/inference_optimization/llm_metrics.py +1 -1
  382. nat/profiler/inference_optimization/prompt_caching.py +1 -1
  383. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  384. nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
  385. nat/profiler/intermediate_property_adapter.py +1 -1
  386. nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
  387. nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
  388. nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
  389. nat/profiler/parameter_optimization/parameter_selection.py +1 -1
  390. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  391. nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
  392. nat/profiler/parameter_optimization/update_helpers.py +1 -1
  393. nat/profiler/profile_runner.py +1 -1
  394. nat/profiler/utils.py +1 -1
  395. nat/registry_handlers/local/local_handler.py +1 -1
  396. nat/registry_handlers/local/register_local.py +1 -1
  397. nat/registry_handlers/metadata_factory.py +1 -1
  398. nat/registry_handlers/package_utils.py +1 -1
  399. nat/registry_handlers/pypi/pypi_handler.py +1 -1
  400. nat/registry_handlers/pypi/register_pypi.py +1 -1
  401. nat/registry_handlers/register.py +1 -1
  402. nat/registry_handlers/registry_handler_base.py +1 -1
  403. nat/registry_handlers/rest/register_rest.py +1 -1
  404. nat/registry_handlers/rest/rest_handler.py +1 -1
  405. nat/registry_handlers/schemas/headers.py +1 -1
  406. nat/registry_handlers/schemas/package.py +1 -1
  407. nat/registry_handlers/schemas/publish.py +1 -1
  408. nat/registry_handlers/schemas/pull.py +1 -1
  409. nat/registry_handlers/schemas/remove.py +1 -1
  410. nat/registry_handlers/schemas/search.py +1 -1
  411. nat/registry_handlers/schemas/status.py +1 -1
  412. nat/retriever/interface.py +1 -1
  413. nat/retriever/milvus/__init__.py +1 -1
  414. nat/retriever/milvus/register.py +12 -4
  415. nat/retriever/milvus/retriever.py +103 -41
  416. nat/retriever/models.py +1 -1
  417. nat/retriever/nemo_retriever/__init__.py +1 -1
  418. nat/retriever/nemo_retriever/register.py +1 -1
  419. nat/retriever/nemo_retriever/retriever.py +5 -5
  420. nat/retriever/register.py +1 -1
  421. nat/runtime/__init__.py +1 -1
  422. nat/runtime/loader.py +10 -3
  423. nat/runtime/metrics.py +180 -0
  424. nat/runtime/runner.py +13 -6
  425. nat/runtime/session.py +458 -32
  426. nat/runtime/user_metadata.py +1 -1
  427. nat/settings/global_settings.py +1 -1
  428. nat/tool/chat_completion.py +1 -1
  429. nat/tool/code_execution/README.md +1 -1
  430. nat/tool/code_execution/code_sandbox.py +2 -2
  431. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
  432. nat/tool/code_execution/local_sandbox/__init__.py +1 -1
  433. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  434. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
  435. nat/tool/code_execution/register.py +1 -1
  436. nat/tool/code_execution/utils.py +1 -1
  437. nat/tool/datetime_tools.py +1 -1
  438. nat/tool/document_search.py +1 -1
  439. nat/tool/github_tools.py +1 -1
  440. nat/tool/memory_tools/add_memory_tool.py +1 -1
  441. nat/tool/memory_tools/delete_memory_tool.py +1 -1
  442. nat/tool/memory_tools/get_memory_tool.py +1 -1
  443. nat/tool/nvidia_rag.py +2 -2
  444. nat/tool/register.py +1 -1
  445. nat/tool/retriever.py +1 -1
  446. nat/tool/server_tools.py +1 -1
  447. nat/utils/__init__.py +8 -5
  448. nat/utils/callable_utils.py +1 -1
  449. nat/utils/data_models/schema_validator.py +1 -1
  450. nat/utils/debugging_utils.py +1 -1
  451. nat/utils/decorators.py +1 -1
  452. nat/utils/dump_distro_mapping.py +1 -1
  453. nat/utils/exception_handlers/automatic_retries.py +3 -3
  454. nat/utils/exception_handlers/schemas.py +1 -1
  455. nat/utils/io/model_processing.py +1 -1
  456. nat/utils/io/supress_logs.py +33 -0
  457. nat/utils/io/yaml_tools.py +1 -1
  458. nat/utils/log_levels.py +1 -1
  459. nat/utils/log_utils.py +13 -1
  460. nat/utils/metadata_utils.py +1 -1
  461. nat/utils/optional_imports.py +1 -1
  462. nat/utils/producer_consumer_queue.py +1 -1
  463. nat/utils/reactive/base/observable_base.py +1 -1
  464. nat/utils/reactive/base/observer_base.py +1 -1
  465. nat/utils/reactive/base/subject_base.py +1 -1
  466. nat/utils/reactive/observable.py +1 -1
  467. nat/utils/reactive/observer.py +1 -1
  468. nat/utils/reactive/subject.py +1 -1
  469. nat/utils/reactive/subscription.py +1 -1
  470. nat/utils/responses_api.py +1 -1
  471. nat/utils/settings/global_settings.py +1 -1
  472. nat/utils/string_utils.py +1 -1
  473. nat/utils/type_converter.py +18 -5
  474. nat/utils/type_utils.py +1 -1
  475. nat/utils/url_utils.py +1 -1
  476. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +46 -15
  477. nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
  478. nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
  479. nat/cli/commands/mcp/mcp.py +0 -986
  480. nat/front_ends/mcp/introspection_token_verifier.py +0 -73
  481. nat/front_ends/mcp/mcp_front_end_config.py +0 -109
  482. nat/front_ends/mcp/mcp_front_end_plugin.py +0 -151
  483. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -362
  484. nat/front_ends/mcp/memory_profiler.py +0 -320
  485. nat/front_ends/mcp/register.py +0 -27
  486. nat/front_ends/mcp/tool_converter.py +0 -321
  487. nvidia_nat-1.4.0a20251112.dist-info/RECORD +0 -481
  488. nvidia_nat-1.4.0a20251112.dist-info/entry_points.txt +0 -22
  489. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
  490. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  491. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
  492. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -177,3 +177,47 @@ class TTCStrategyRef(ComponentRef):
177
177
  @override
178
178
  def component_group(self):
179
179
  return ComponentGroup.TTC_STRATEGIES
180
+
181
+
182
+ class MiddlewareRef(ComponentRef):
183
+ """
184
+ A reference to middleware in a NAT configuration object.
185
+ """
186
+
187
+ @property
188
+ @override
189
+ def component_group(self):
190
+ return ComponentGroup.MIDDLEWARE
191
+
192
+
193
+ class TrainerRef(ComponentRef):
194
+ """
195
+ A reference to a trainer in a NAT configuration object.
196
+ """
197
+
198
+ @property
199
+ @override
200
+ def component_group(self):
201
+ return ComponentGroup.TRAINERS
202
+
203
+
204
+ class TrajectoryBuilderRef(ComponentRef):
205
+ """
206
+ A reference to a trajectory builder in a NAT configuration object.
207
+ """
208
+
209
+ @property
210
+ @override
211
+ def component_group(self):
212
+ return ComponentGroup.TRAJECTORY_BUILDERS
213
+
214
+
215
+ class TrainerAdapterRef(ComponentRef):
216
+ """
217
+ A reference to a trainer adapter in a NAT configuration object.
218
+ """
219
+
220
+ @property
221
+ @override
222
+ def component_group(self):
223
+ return ComponentGroup.TRAINER_ADAPTERS
nat/data_models/config.py CHANGED
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +16,7 @@
16
16
  import logging
17
17
  import sys
18
18
  import typing
19
+ from datetime import timedelta
19
20
 
20
21
  from pydantic import BaseModel
21
22
  from pydantic import ConfigDict
@@ -27,6 +28,10 @@ from pydantic import ValidatorFunctionWrapHandler
27
28
  from pydantic import field_validator
28
29
 
29
30
  from nat.data_models.evaluate import EvalConfig
31
+ from nat.data_models.finetuning import FinetuneConfig
32
+ from nat.data_models.finetuning import TrainerAdapterConfig
33
+ from nat.data_models.finetuning import TrainerConfig
34
+ from nat.data_models.finetuning import TrajectoryBuilderConfig
30
35
  from nat.data_models.front_end import FrontEndBaseConfig
31
36
  from nat.data_models.function import EmptyFunctionConfig
32
37
  from nat.data_models.function import FunctionBaseConfig
@@ -43,6 +48,7 @@ from .common import TypedBaseModel
43
48
  from .embedder import EmbedderBaseConfig
44
49
  from .llm import LLMBaseConfig
45
50
  from .memory import MemoryBaseConfig
51
+ from .middleware import FunctionMiddlewareBaseConfig
46
52
  from .object_store import ObjectStoreBaseConfig
47
53
  from .retriever import RetrieverBaseConfig
48
54
 
@@ -86,6 +92,14 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
86
92
  registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
87
93
  elif (info.field_name == "ttc_strategies"):
88
94
  registered_keys = GlobalTypeRegistry.get().get_registered_ttc_strategies()
95
+ elif (info.field_name == "middleware"):
96
+ registered_keys = GlobalTypeRegistry.get().get_registered_middleware()
97
+ elif (info.field_name == "trainers"):
98
+ registered_keys = GlobalTypeRegistry.get().get_registered_trainers()
99
+ elif (info.field_name == "trainer_adapters"):
100
+ registered_keys = GlobalTypeRegistry.get().get_registered_trainer_adapters()
101
+ elif (info.field_name == "trajectory_builders"):
102
+ registered_keys = GlobalTypeRegistry.get().get_registered_trajectory_builders()
89
103
 
90
104
  else:
91
105
  assert False, f"Unknown field name {info.field_name} in validator"
@@ -201,6 +215,19 @@ class GeneralConfig(BaseModel):
201
215
 
202
216
  telemetry: TelemetryConfig = TelemetryConfig()
203
217
 
218
+ per_user_workflow_timeout: timedelta = Field(
219
+ default=timedelta(minutes=30),
220
+ description="Time after which inactive per-user workflows are cleaned up. "
221
+ "Only applies when workflow is per-user. Defaults to 30 minutes.")
222
+ per_user_workflow_cleanup_interval: timedelta = Field(
223
+ default=timedelta(minutes=5),
224
+ description="Interval for running cleanup of inactive per-user workflows. "
225
+ "Only applies when workflow is per-user. Defaults to 5 minutes.")
226
+ enable_per_user_monitoring: bool = Field(
227
+ default=False,
228
+ description="Enable the /monitor/users endpoint for per-user workflow resource monitoring. "
229
+ "When enabled, exposes metrics like request counts, latency, LLM usage, and memory for each user.")
230
+
204
231
  # FrontEnd Configuration
205
232
  front_end: FrontEndBaseConfig = FastApiFrontEndConfig()
206
233
 
@@ -253,6 +280,9 @@ class Config(HashableBaseModel):
253
280
  # Function Groups Configuration
254
281
  function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
255
282
 
283
+ # Middleware Configuration
284
+ middleware: dict[str, FunctionMiddlewareBaseConfig] = Field(default_factory=dict)
285
+
256
286
  # LLMs Configuration
257
287
  llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
258
288
 
@@ -283,6 +313,12 @@ class Config(HashableBaseModel):
283
313
  # Evaluation Options
284
314
  eval: EvalConfig = EvalConfig()
285
315
 
316
+ # Finetuning Options
317
+ trainers: dict[str, TrainerConfig] = Field(default_factory=dict)
318
+ trainer_adapters: dict[str, TrainerAdapterConfig] = Field(default_factory=dict)
319
+ trajectory_builders: dict[str, TrajectoryBuilderConfig] = Field(default_factory=dict)
320
+ finetuning: FinetuneConfig = FinetuneConfig()
321
+
286
322
  def print_summary(self, stream: typing.TextIO = sys.stdout):
287
323
  """Print a summary of the configuration"""
288
324
 
@@ -303,6 +339,7 @@ class Config(HashableBaseModel):
303
339
 
304
340
  @field_validator("functions",
305
341
  "function_groups",
342
+ "middleware",
306
343
  "llms",
307
344
  "embedders",
308
345
  "memory",
@@ -310,6 +347,9 @@ class Config(HashableBaseModel):
310
347
  "workflow",
311
348
  "ttc_strategies",
312
349
  "authentication",
350
+ "trainers",
351
+ "trainer_adapters",
352
+ "trajectory_builders",
313
353
  mode="wrap")
314
354
  @classmethod
315
355
  def validate_components(cls, value: typing.Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
@@ -348,6 +388,10 @@ class Config(HashableBaseModel):
348
388
  typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
349
389
  Discriminator(TypedBaseModel.discriminator)]]
350
390
 
391
+ MiddlewareAnnotation = dict[str,
392
+ typing.Annotated[type_registry.compute_annotation(FunctionMiddlewareBaseConfig),
393
+ Discriminator(TypedBaseModel.discriminator)]]
394
+
351
395
  MemoryAnnotation = dict[str,
352
396
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
353
397
  Discriminator(TypedBaseModel.discriminator)]]
@@ -366,6 +410,18 @@ class Config(HashableBaseModel):
366
410
  WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
367
411
  Discriminator(TypedBaseModel.discriminator)]
368
412
 
413
+ TrainersAnnotation = dict[str,
414
+ typing.Annotated[type_registry.compute_annotation(TrainerConfig),
415
+ Discriminator(TypedBaseModel.discriminator)]]
416
+
417
+ TrainerAdaptersAnnotation = dict[str,
418
+ typing.Annotated[type_registry.compute_annotation(TrainerAdapterConfig),
419
+ Discriminator(TypedBaseModel.discriminator)]]
420
+
421
+ TrajectoryBuildersAnnotation = dict[str,
422
+ typing.Annotated[type_registry.compute_annotation(TrajectoryBuilderConfig),
423
+ Discriminator(TypedBaseModel.discriminator)]]
424
+
369
425
  should_rebuild = False
370
426
 
371
427
  auth_providers_field = cls.model_fields.get("authentication")
@@ -393,6 +449,11 @@ class Config(HashableBaseModel):
393
449
  function_groups_field.annotation = FunctionGroupsAnnotation
394
450
  should_rebuild = True
395
451
 
452
+ middleware_field = cls.model_fields.get("middleware")
453
+ if (middleware_field is not None and middleware_field.annotation != MiddlewareAnnotation):
454
+ middleware_field.annotation = MiddlewareAnnotation
455
+ should_rebuild = True
456
+
396
457
  memory_field = cls.model_fields.get("memory")
397
458
  if memory_field is not None and memory_field.annotation != MemoryAnnotation:
398
459
  memory_field.annotation = MemoryAnnotation
@@ -418,6 +479,22 @@ class Config(HashableBaseModel):
418
479
  workflow_field.annotation = WorkflowAnnotation
419
480
  should_rebuild = True
420
481
 
482
+ trainers_field = cls.model_fields.get("trainers")
483
+ if trainers_field is not None and trainers_field.annotation != TrainersAnnotation:
484
+ trainers_field.annotation = TrainersAnnotation
485
+ should_rebuild = True
486
+
487
+ trainer_adapters_field = cls.model_fields.get("trainer_adapters")
488
+ if trainer_adapters_field is not None and trainer_adapters_field.annotation != TrainerAdaptersAnnotation:
489
+ trainer_adapters_field.annotation = TrainerAdaptersAnnotation
490
+ should_rebuild = True
491
+
492
+ trajectory_builders_field = cls.model_fields.get("trajectory_builders")
493
+ if (trajectory_builders_field is not None
494
+ and trajectory_builders_field.annotation != TrajectoryBuildersAnnotation):
495
+ trajectory_builders_field.annotation = TrajectoryBuildersAnnotation
496
+ should_rebuild = True
497
+
421
498
  if (GeneralConfig.rebuild_annotations()):
422
499
  should_rebuild = True
423
500
 
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +12,6 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
16
15
  import importlib
17
16
  import json
18
17
  import typing
@@ -21,6 +20,7 @@ from pathlib import Path
21
20
 
22
21
  import pandas as pd
23
22
  from pydantic import BaseModel
23
+ from pydantic import ConfigDict
24
24
  from pydantic import Discriminator
25
25
  from pydantic import FilePath
26
26
  from pydantic import Tag
@@ -32,6 +32,8 @@ from nat.data_models.common import TypedBaseModel
32
32
 
33
33
  class EvalS3Config(BaseModel):
34
34
 
35
+ model_config = ConfigDict(extra="forbid")
36
+
35
37
  endpoint_url: str | None = None
36
38
  region_name: str | None = None
37
39
  bucket: str
@@ -40,16 +42,25 @@ class EvalS3Config(BaseModel):
40
42
 
41
43
 
42
44
  class EvalFilterEntryConfig(BaseModel):
45
+
46
+ model_config = ConfigDict(extra="forbid")
47
+
43
48
  # values are lists of allowed/blocked values
44
49
  field: dict[str, list[str | int | float]] = {}
45
50
 
46
51
 
47
52
  class EvalFilterConfig(BaseModel):
53
+
54
+ model_config = ConfigDict(extra="forbid")
55
+
48
56
  allowlist: EvalFilterEntryConfig | None = None
49
57
  denylist: EvalFilterEntryConfig | None = None
50
58
 
51
59
 
52
60
  class EvalDatasetStructureConfig(BaseModel):
61
+
62
+ model_config = ConfigDict(extra="forbid")
63
+
53
64
  disable: bool = False
54
65
  question_key: str = "question"
55
66
  answer_key: str = "answer"
@@ -61,6 +72,8 @@ class EvalDatasetStructureConfig(BaseModel):
61
72
  # Base model
62
73
  class EvalDatasetBaseConfig(TypedBaseModel, BaseModelRegistryTag):
63
74
 
75
+ model_config = ConfigDict(extra="forbid")
76
+
64
77
  id_key: str = "id"
65
78
  structure: EvalDatasetStructureConfig = EvalDatasetStructureConfig()
66
79
 
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -91,6 +91,11 @@ class EvalGeneralConfig(BaseModel):
91
91
  # Inference profiler
92
92
  profiler: ProfilerConfig | None = None
93
93
 
94
+ # When enabled, validates that all LLM endpoints are accessible before starting evaluation.
95
+ # This catches deployment issues early (e.g., 404 errors from canceled training jobs).
96
+ # Recommended for production workflows. Opt-in for now, may become default in future.
97
+ validate_llm_endpoints: bool = False
98
+
94
99
  # overwrite the output_dir with the output config if present
95
100
  @model_validator(mode="before")
96
101
  @classmethod
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -0,0 +1,260 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+ from enum import Enum
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ from pydantic import BaseModel
23
+ from pydantic import Field
24
+ from pydantic import model_validator
25
+
26
+ from .common import BaseModelRegistryTag
27
+ from .common import TypedBaseModel
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class RewardFunctionConfig(BaseModel):
33
+ """
34
+ Configuration for the reward function
35
+ """
36
+ name: str = Field(description="Name of the reward function.")
37
+
38
+
39
+ class TrainerConfig(TypedBaseModel, BaseModelRegistryTag):
40
+ """
41
+ Base configuration for the Trainer
42
+ """
43
+ reward: RewardFunctionConfig | None = Field(
44
+ description="Configuration for the reward function used during training.", default=None)
45
+
46
+
47
+ class TrajectoryBuilderConfig(TypedBaseModel, BaseModelRegistryTag):
48
+ """
49
+ Base configuration for the trajectory builder
50
+ """
51
+ reward: RewardFunctionConfig | None = Field(
52
+ description="Configuration for the reward function used during trajectory building.", default=None)
53
+
54
+
55
+ class TrainerAdapterConfig(TypedBaseModel, BaseModelRegistryTag):
56
+ """
57
+ Configuration for the trainer adapter
58
+ """
59
+ reward: RewardFunctionConfig | None = Field(
60
+ description="Configuration for the reward function used during training.", default=None)
61
+
62
+
63
+ TrainerConfigT = typing.TypeVar("TrainerConfigT", bound=TrainerConfig)
64
+ TrajectoryBuilderConfigT = typing.TypeVar("TrajectoryBuilderConfigT", bound=TrajectoryBuilderConfig)
65
+ TrainerAdapterConfigT = typing.TypeVar("TrainerAdapterConfigT", bound=TrainerAdapterConfig)
66
+
67
+
68
+ class TrainingJobRef(BaseModel):
69
+ """
70
+ A reference to a training job.
71
+ """
72
+ run_id: str = Field(description="The ID of the run this job belongs to.")
73
+ backend: str = Field(description="The backend used for the training job.")
74
+ metadata: dict | None = Field(description="Any additional metadata for the training job.", default=None)
75
+
76
+
77
+ class TrainingStatusEnum(str, Enum):
78
+ PENDING = "pending"
79
+ RUNNING = "running"
80
+ COMPLETED = "completed"
81
+ FAILED = "failed"
82
+ CANCELED = "canceled"
83
+
84
+
85
+ class TrainingJobStatus(BaseModel):
86
+ """
87
+ The status of a training job.
88
+ """
89
+ run_id: str = Field(description="The ID of the run this job belongs to.")
90
+ backend: str = Field(description="The backend used for the training job.")
91
+ status: TrainingStatusEnum = Field(description="The current status of the training job.")
92
+ progress: float | None = Field(description="The progress of the training job as a percentage (0.0 to 100.0).",
93
+ default=None)
94
+ message: str | None = Field(description="Any additional message or information about the training job.",
95
+ default=None)
96
+ metadata: dict | None = Field(description="Any additional metadata for the training job.", default=None)
97
+
98
+
99
+ class EpisodeItemRole(str, Enum):
100
+ USER = "user"
101
+ ASSISTANT = "assistant"
102
+ SYSTEM = "system"
103
+ FUNCTION = "function"
104
+ TOOL = "tool"
105
+ ENVIRONMENT = "environment"
106
+ OTHER = "other"
107
+
108
+
109
+ class EpisodeItem(BaseModel):
110
+ """
111
+ A single step in an episode.
112
+ """
113
+ role: EpisodeItemRole = Field(description="The role of the agent (e.g., 'user', 'assistant').")
114
+ content: str = Field(description="The content of the message.")
115
+ logprobs: Any | None = Field(description="The log probabilities of the tokens in the message.", default=None)
116
+ metadata: dict | None = Field(description="Any additional metadata for the step.", default=None)
117
+
118
+ # Model validator run after construction: logprobs cannot be None if the role is assistant
119
+ @model_validator(mode="after")
120
+ def check_logprobs(self) -> "EpisodeItem":
121
+ if self.role == EpisodeItemRole.ASSISTANT and self.logprobs is None:
122
+ raise ValueError("logprobs must be provided for assistant role.")
123
+ return self
124
+
125
+
126
+ class OpenAIMessage(BaseModel):
127
+ """
128
+ A message in the OpenAI chat format.
129
+ """
130
+ role: str = Field(description="The role of the message (e.g., 'user', 'assistant').")
131
+ content: str = Field(description="The content of the message.")
132
+
133
+
134
+ class DPOItem(BaseModel):
135
+ """
136
+ A single step in an episode for DPO training.
137
+ """
138
+ prompt: list[OpenAIMessage] | str = Field(description="The prompt messages leading to the response.")
139
+ chosen_response: str = Field(description="The response chosen as better by the reward model.")
140
+ rejected_response: str = Field(description="The response rejected as worse by the reward model.")
141
+
142
+
143
+ class Trajectory(BaseModel):
144
+ """
145
+ A trajectory is a sequence of states, actions, and rewards.
146
+ """
147
+ episode: list[EpisodeItem] | list[DPOItem] = Field(description="A list of steps in the episode.")
148
+ reward: float = Field(description="The total reward for the episode.")
149
+ shaped_rewards: list[float] | None = Field(description="The shaped rewards for each step in the episode.",
150
+ default=None)
151
+ metadata: dict | None = Field(description="Any additional metadata for the trajectory.", default=None)
152
+
153
+
154
+ class TrajectoryCollection(BaseModel):
155
+ """
156
+ A collection of trajectories.
157
+ """
158
+ trajectories: list[list[Trajectory]] = Field(
159
+ description="A list of trajectory lists, each inner list contains trajectories for one example.")
160
+ run_id: str = Field(description="The ID of the run this collection belongs to.")
161
+
162
+
163
+ class CurriculumLearningConfig(BaseModel):
164
+ """
165
+ Configuration for curriculum learning in fine-tuning.
166
+
167
+ Curriculum learning progressively introduces harder training examples
168
+ to improve model learning and convergence.
169
+ """
170
+ enabled: bool = Field(default=False, description="Whether to enable curriculum learning")
171
+ initial_percentile: float = Field(default=0.3,
172
+ description="Initial percentile of trajectory groups to include (0.0-1.0). "
173
+ "E.g., 0.3 means start with top 30% easiest groups")
174
+ increment_percentile: float = Field(default=0.2,
175
+ description="Percentile increment when expanding curriculum. "
176
+ "E.g., 0.2 means add 20% more groups each expansion")
177
+ expansion_interval: int = Field(default=5, description="Number of epochs between curriculum expansions", ge=1)
178
+ min_reward_diff: float = Field(default=0.1,
179
+ description="Minimum reward difference within a group to be included. "
180
+ "Groups with all same rewards provide no learning signal")
181
+ sort_ascending: bool = Field(default=False,
182
+ description="If True, sort groups from low to high reward (hard to easy). "
183
+ "If False, sort from high to low reward (easy to hard)")
184
+
185
+ random_subsample: float | None = Field(
186
+ default=None, description="If set, randomly subsample this fraction of trajectories from each group.")
187
+
188
+ @model_validator(mode="after")
189
+ def validate_percentiles(self) -> "CurriculumLearningConfig":
190
+ """Validate that percentile values are in valid range."""
191
+ if not 0.0 < self.initial_percentile <= 1.0:
192
+ raise ValueError("initial_percentile must be between 0 and 1")
193
+ if not 0.0 < self.increment_percentile <= 1.0:
194
+ raise ValueError("increment_percentile must be between 0 and 1")
195
+ return self
196
+
197
+
198
+ class FinetuneRunConfig(BaseModel):
199
+ """
200
+ CLI arguments for configuring and running fine-tuning
201
+ """
202
+ config_file: Path | BaseModel = Field(description="Config file for NAT", default=None)
203
+ dataset: str | Path | None = None # dataset file path can be specified in the config file
204
+ result_json_path: str = "$"
205
+ endpoint: str | None = None # only used when running the workflow remotely
206
+ endpoint_timeout: int = 300
207
+ override: tuple[tuple[str, str], ...] = ()
208
+ validation_dataset: str | Path | None = Field(default=None,
209
+ description="Validation dataset file path for periodic validation")
210
+
211
+ validation_interval: int = Field(default=5, description="Run validation every N epochs", ge=1)
212
+
213
+ validation_config_file: str | Path | None = Field(default=None,
214
+ description="Optional separate config file for validation runs")
215
+
216
+
217
+ class FinetuneConfig(BaseModel):
218
+ """
219
+ Parameters used for a Trainer run
220
+ """
221
+
222
+ enabled: bool = Field(description="Whether fine-tuning is enabled.", default=False)
223
+ trainer: str | None = Field(description="The trainer to use for fine-tuning.", default=None)
224
+ trajectory_builder: str | None = Field(description="The trajectory builder to use for fine-tuning.", default=None)
225
+
226
+ trainer_adapter: str | None = Field(description="The trainer adapter to use for fine-tuning.", default=None)
227
+ reward_function: RewardFunctionConfig | None = Field(description="Configuration for the reward function.",
228
+ default=None)
229
+ target_functions: list[str] = ["<workflow>"]
230
+ target_model: str | None = Field(
231
+ description="Target model name to fine-tune. If None, all intermediate steps will be used without "
232
+ "filtering. This can lead to issues if multiple models are used in the workflow.",
233
+ default=None)
234
+ curriculum_learning: CurriculumLearningConfig = Field(
235
+ default=CurriculumLearningConfig(), description="Configuration for curriculum learning during fine-tuning")
236
+
237
+ num_epochs: int = Field(default=1, description="Number of epochs to run", ge=1)
238
+ output_dir: Path = Field(default=Path("./.tmp/nat/finetuning/"),
239
+ description="Directory for outputs and checkpoints")
240
+
241
+ # Overridden by command line args
242
+ run_configuration: FinetuneRunConfig | None = Field(
243
+ description="Run-time configuration for fine-tuning (overrides CLI arguments).", default=None)
244
+
245
+ # Before validator: when fine-tuning is enabled, the trainer, trajectory builder and
246
+ # trainer adapter must all be set (the reward function and config file are not checked here)
247
+ @model_validator(mode="before")
248
+ def validate_finetuning_enabled(cls, values: dict[str, Any]) -> dict[str, Any]:
249
+ if values.get("enabled", False):
250
+ required_fields = ["trainer", "trajectory_builder", "trainer_adapter"]
251
+ missing_fields = [field for field in required_fields if values.get(field) is None]
252
+ if missing_fields:
253
+ raise ValueError(f"When fine-tuning is enabled, the following fields must be set: "
254
+ f"{', '.join(missing_fields)}")
255
+
256
+ # Warn user their config will be overridden by CLI args
257
+ if "run_configuration" in values and values["run_configuration"] is not None:
258
+ logger.warning("run_configuration will be overridden by CLI arguments during finetuning run.")
259
+
260
+ return values
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +24,16 @@ from .common import TypedBaseModel
24
24
 
25
25
 
26
26
  class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
27
- pass
27
+ """Base configuration for functions.
28
+
29
+ Attributes:
30
+ middleware: List of function middleware names to apply to this function.
31
+ These must match names defined in the `middleware` section of the YAML configuration.
32
+ """
33
+ middleware: list[str] = Field(
34
+ default_factory=list,
35
+ description="List of function middleware names to apply to this function in order",
36
+ )
28
37
 
29
38
 
30
39
  class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
@@ -40,6 +49,10 @@ class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
40
49
  default_factory=list,
41
50
  description="The list of function names which should be excluded from default access to the group",
42
51
  )
52
+ middleware: list[str] = Field(
53
+ default_factory=list,
54
+ description="List of function middleware names to apply to all functions in this group",
55
+ )
43
56
 
44
57
  @field_validator("include", "exclude")
45
58
  @classmethod
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");