nvidia-nat 1.4.0a20251120__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (492)
  1. aiq/__init__.py +1 -1
  2. nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
  3. nat/agent/auto_memory_wrapper/agent.py +278 -0
  4. nat/agent/auto_memory_wrapper/register.py +227 -0
  5. nat/agent/auto_memory_wrapper/state.py +30 -0
  6. nat/agent/base.py +1 -1
  7. nat/agent/dual_node.py +1 -1
  8. nat/agent/prompt_optimizer/prompt.py +1 -1
  9. nat/agent/prompt_optimizer/register.py +1 -1
  10. nat/agent/react_agent/agent.py +16 -9
  11. nat/agent/react_agent/output_parser.py +2 -2
  12. nat/agent/react_agent/prompt.py +3 -2
  13. nat/agent/react_agent/register.py +2 -2
  14. nat/agent/react_agent/register_per_user_agent.py +104 -0
  15. nat/agent/reasoning_agent/reasoning_agent.py +1 -1
  16. nat/agent/register.py +3 -1
  17. nat/agent/responses_api_agent/__init__.py +1 -1
  18. nat/agent/responses_api_agent/register.py +1 -1
  19. nat/agent/rewoo_agent/agent.py +9 -4
  20. nat/agent/rewoo_agent/prompt.py +1 -1
  21. nat/agent/rewoo_agent/register.py +1 -1
  22. nat/agent/tool_calling_agent/agent.py +5 -4
  23. nat/agent/tool_calling_agent/register.py +1 -1
  24. nat/authentication/__init__.py +1 -1
  25. nat/authentication/api_key/__init__.py +1 -1
  26. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  27. nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
  28. nat/authentication/api_key/register.py +1 -1
  29. nat/authentication/credential_validator/__init__.py +1 -1
  30. nat/authentication/credential_validator/bearer_token_validator.py +1 -1
  31. nat/authentication/exceptions/__init__.py +1 -1
  32. nat/authentication/exceptions/api_key_exceptions.py +1 -1
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  34. nat/authentication/http_basic_auth/register.py +1 -1
  35. nat/authentication/interfaces.py +1 -1
  36. nat/authentication/oauth2/__init__.py +1 -1
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
  40. nat/authentication/oauth2/register.py +1 -1
  41. nat/authentication/register.py +1 -1
  42. nat/builder/builder.py +511 -1
  43. nat/builder/child_builder.py +385 -0
  44. nat/builder/component_utils.py +28 -4
  45. nat/builder/context.py +17 -1
  46. nat/builder/embedder.py +1 -1
  47. nat/builder/eval_builder.py +19 -7
  48. nat/builder/evaluator.py +1 -1
  49. nat/builder/framework_enum.py +2 -1
  50. nat/builder/front_end.py +1 -1
  51. nat/builder/function.py +40 -3
  52. nat/builder/function_base.py +1 -1
  53. nat/builder/function_info.py +1 -1
  54. nat/builder/intermediate_step_manager.py +1 -1
  55. nat/builder/llm.py +1 -1
  56. nat/builder/per_user_workflow_builder.py +843 -0
  57. nat/builder/retriever.py +1 -1
  58. nat/builder/sync_builder.py +571 -0
  59. nat/builder/user_interaction_manager.py +1 -1
  60. nat/builder/workflow.py +1 -1
  61. nat/builder/workflow_builder.py +536 -424
  62. nat/cli/__init__.py +1 -1
  63. nat/cli/cli_utils/config_override.py +1 -1
  64. nat/cli/cli_utils/validation.py +32 -1
  65. nat/cli/commands/configure/channel/add.py +1 -1
  66. nat/cli/commands/configure/channel/channel.py +1 -1
  67. nat/cli/commands/configure/channel/remove.py +1 -1
  68. nat/cli/commands/configure/channel/update.py +1 -1
  69. nat/cli/commands/configure/configure.py +1 -1
  70. nat/cli/commands/evaluate.py +87 -13
  71. nat/cli/commands/finetune.py +132 -0
  72. nat/cli/commands/info/__init__.py +1 -1
  73. nat/cli/commands/info/info.py +1 -1
  74. nat/cli/commands/info/list_channels.py +1 -1
  75. nat/cli/commands/info/list_components.py +1 -1
  76. nat/cli/commands/object_store/__init__.py +1 -1
  77. nat/cli/commands/object_store/object_store.py +1 -1
  78. nat/cli/commands/optimize.py +1 -1
  79. nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
  80. nat/cli/commands/red_teaming/red_teaming.py +138 -0
  81. nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
  82. nat/cli/commands/registry/__init__.py +1 -1
  83. nat/cli/commands/registry/publish.py +1 -1
  84. nat/cli/commands/registry/pull.py +1 -1
  85. nat/cli/commands/registry/registry.py +1 -1
  86. nat/cli/commands/registry/remove.py +1 -1
  87. nat/cli/commands/registry/search.py +1 -1
  88. nat/cli/commands/sizing/__init__.py +1 -1
  89. nat/cli/commands/sizing/calc.py +1 -1
  90. nat/cli/commands/sizing/sizing.py +1 -1
  91. nat/cli/commands/start.py +1 -1
  92. nat/cli/commands/uninstall.py +1 -1
  93. nat/cli/commands/validate.py +1 -1
  94. nat/cli/commands/workflow/__init__.py +1 -1
  95. nat/cli/commands/workflow/workflow.py +1 -1
  96. nat/cli/commands/workflow/workflow_commands.py +3 -2
  97. nat/cli/entrypoint.py +15 -37
  98. nat/cli/main.py +2 -2
  99. nat/cli/plugin_loader.py +69 -0
  100. nat/cli/register_workflow.py +183 -5
  101. nat/cli/type_registry.py +169 -3
  102. nat/control_flow/register.py +1 -1
  103. nat/control_flow/router_agent/agent.py +1 -1
  104. nat/control_flow/router_agent/prompt.py +1 -1
  105. nat/control_flow/router_agent/register.py +1 -1
  106. nat/control_flow/sequential_executor.py +28 -7
  107. nat/data_models/__init__.py +1 -1
  108. nat/data_models/agent.py +1 -1
  109. nat/data_models/api_server.py +38 -3
  110. nat/data_models/authentication.py +1 -1
  111. nat/data_models/common.py +1 -1
  112. nat/data_models/component.py +7 -1
  113. nat/data_models/component_ref.py +34 -1
  114. nat/data_models/config.py +62 -1
  115. nat/data_models/dataset_handler.py +15 -2
  116. nat/data_models/discovery_metadata.py +1 -1
  117. nat/data_models/embedder.py +1 -1
  118. nat/data_models/evaluate.py +6 -1
  119. nat/data_models/evaluator.py +1 -1
  120. nat/data_models/finetuning.py +260 -0
  121. nat/data_models/front_end.py +1 -1
  122. nat/data_models/function.py +1 -1
  123. nat/data_models/function_dependencies.py +1 -1
  124. nat/data_models/gated_field_mixin.py +1 -1
  125. nat/data_models/interactive.py +1 -1
  126. nat/data_models/intermediate_step.py +29 -2
  127. nat/data_models/invocation_node.py +1 -1
  128. nat/data_models/llm.py +1 -1
  129. nat/data_models/logging.py +1 -1
  130. nat/data_models/memory.py +1 -1
  131. nat/data_models/middleware.py +3 -1
  132. nat/data_models/object_store.py +1 -1
  133. nat/data_models/openai_mcp.py +1 -1
  134. nat/data_models/optimizable.py +1 -1
  135. nat/data_models/optimizer.py +1 -1
  136. nat/data_models/profiler.py +1 -1
  137. nat/data_models/registry_handler.py +1 -1
  138. nat/data_models/retriever.py +1 -1
  139. nat/data_models/retry_mixin.py +1 -1
  140. nat/data_models/runtime_enum.py +1 -1
  141. nat/data_models/span.py +1 -1
  142. nat/data_models/step_adaptor.py +1 -1
  143. nat/data_models/streaming.py +1 -1
  144. nat/data_models/swe_bench_model.py +1 -1
  145. nat/data_models/telemetry_exporter.py +1 -1
  146. nat/data_models/thinking_mixin.py +1 -1
  147. nat/data_models/ttc_strategy.py +1 -1
  148. nat/embedder/azure_openai_embedder.py +1 -1
  149. nat/embedder/nim_embedder.py +1 -1
  150. nat/embedder/openai_embedder.py +1 -1
  151. nat/embedder/register.py +1 -1
  152. nat/eval/__init__.py +1 -1
  153. nat/eval/config.py +8 -1
  154. nat/eval/dataset_handler/dataset_downloader.py +1 -1
  155. nat/eval/dataset_handler/dataset_filter.py +1 -1
  156. nat/eval/dataset_handler/dataset_handler.py +4 -2
  157. nat/eval/evaluate.py +217 -80
  158. nat/eval/evaluator/__init__.py +1 -1
  159. nat/eval/evaluator/base_evaluator.py +2 -2
  160. nat/eval/evaluator/evaluator_model.py +3 -2
  161. nat/eval/intermediate_step_adapter.py +1 -1
  162. nat/eval/llm_validator.py +336 -0
  163. nat/eval/rag_evaluator/evaluate.py +17 -10
  164. nat/eval/rag_evaluator/register.py +1 -1
  165. nat/eval/red_teaming_evaluator/__init__.py +14 -0
  166. nat/eval/red_teaming_evaluator/data_models.py +66 -0
  167. nat/eval/red_teaming_evaluator/evaluate.py +327 -0
  168. nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
  169. nat/eval/red_teaming_evaluator/register.py +55 -0
  170. nat/eval/register.py +2 -1
  171. nat/eval/remote_workflow.py +1 -1
  172. nat/eval/runners/__init__.py +1 -1
  173. nat/eval/runners/config.py +1 -1
  174. nat/eval/runners/multi_eval_runner.py +1 -1
  175. nat/eval/runners/red_teaming_runner/__init__.py +24 -0
  176. nat/eval/runners/red_teaming_runner/config.py +282 -0
  177. nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
  178. nat/eval/runners/red_teaming_runner/runner.py +867 -0
  179. nat/eval/runtime_evaluator/__init__.py +1 -1
  180. nat/eval/runtime_evaluator/evaluate.py +1 -1
  181. nat/eval/runtime_evaluator/register.py +1 -1
  182. nat/eval/runtime_event_subscriber.py +1 -1
  183. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  184. nat/eval/swe_bench_evaluator/register.py +1 -1
  185. nat/eval/trajectory_evaluator/evaluate.py +2 -2
  186. nat/eval/trajectory_evaluator/register.py +1 -1
  187. nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
  188. nat/eval/tunable_rag_evaluator/register.py +1 -1
  189. nat/eval/usage_stats.py +1 -1
  190. nat/eval/utils/eval_trace_ctx.py +1 -1
  191. nat/eval/utils/output_uploader.py +1 -1
  192. nat/eval/utils/tqdm_position_registry.py +1 -1
  193. nat/eval/utils/weave_eval.py +1 -1
  194. nat/experimental/decorators/experimental_warning_decorator.py +1 -1
  195. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
  196. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
  197. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
  198. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  199. nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
  200. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
  201. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  202. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  203. nat/experimental/test_time_compute/models/editor_config.py +1 -1
  204. nat/experimental/test_time_compute/models/scoring_config.py +1 -1
  205. nat/experimental/test_time_compute/models/search_config.py +20 -2
  206. nat/experimental/test_time_compute/models/selection_config.py +33 -2
  207. nat/experimental/test_time_compute/models/stage_enums.py +1 -1
  208. nat/experimental/test_time_compute/models/strategy_base.py +1 -1
  209. nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
  210. nat/experimental/test_time_compute/models/ttc_item.py +1 -1
  211. nat/experimental/test_time_compute/register.py +4 -1
  212. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
  213. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
  214. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
  215. nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
  216. nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
  217. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
  218. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
  219. nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
  220. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
  221. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  222. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
  223. nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
  224. nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
  225. nat/finetuning/__init__.py +24 -0
  226. nat/finetuning/finetuning_runtime.py +143 -0
  227. nat/finetuning/interfaces/__init__.py +24 -0
  228. nat/finetuning/interfaces/finetuning_runner.py +261 -0
  229. nat/finetuning/interfaces/trainer_adapter.py +103 -0
  230. nat/finetuning/interfaces/trajectory_builder.py +115 -0
  231. nat/finetuning/utils/__init__.py +15 -0
  232. nat/finetuning/utils/parsers/__init__.py +15 -0
  233. nat/finetuning/utils/parsers/adk_parser.py +141 -0
  234. nat/finetuning/utils/parsers/base_parser.py +238 -0
  235. nat/finetuning/utils/parsers/common.py +91 -0
  236. nat/finetuning/utils/parsers/langchain_parser.py +267 -0
  237. nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
  238. nat/front_ends/__init__.py +1 -1
  239. nat/front_ends/console/__init__.py +1 -1
  240. nat/front_ends/console/authentication_flow_handler.py +1 -1
  241. nat/front_ends/console/console_front_end_config.py +4 -1
  242. nat/front_ends/console/console_front_end_plugin.py +5 -4
  243. nat/front_ends/console/register.py +1 -1
  244. nat/front_ends/cron/__init__.py +1 -1
  245. nat/front_ends/fastapi/__init__.py +1 -1
  246. nat/front_ends/fastapi/async_job.py +128 -0
  247. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  248. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
  249. nat/front_ends/fastapi/dask_client_mixin.py +1 -1
  250. nat/front_ends/fastapi/fastapi_front_end_config.py +1 -1
  251. nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
  252. nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
  253. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +195 -60
  254. nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
  255. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
  256. nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
  257. nat/front_ends/fastapi/job_store.py +23 -11
  258. nat/front_ends/fastapi/main.py +1 -1
  259. nat/front_ends/fastapi/message_handler.py +27 -4
  260. nat/front_ends/fastapi/message_validator.py +54 -2
  261. nat/front_ends/fastapi/register.py +1 -1
  262. nat/front_ends/fastapi/response_helpers.py +16 -15
  263. nat/front_ends/fastapi/step_adaptor.py +1 -1
  264. nat/front_ends/fastapi/utils.py +1 -1
  265. nat/front_ends/register.py +1 -2
  266. nat/front_ends/simple_base/__init__.py +1 -1
  267. nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
  268. nat/llm/aws_bedrock_llm.py +1 -1
  269. nat/llm/azure_openai_llm.py +10 -1
  270. nat/llm/dynamo_llm.py +363 -0
  271. nat/llm/huggingface_llm.py +177 -0
  272. nat/llm/litellm_llm.py +1 -1
  273. nat/llm/nim_llm.py +1 -1
  274. nat/llm/openai_llm.py +1 -1
  275. nat/llm/register.py +3 -1
  276. nat/llm/utils/__init__.py +1 -1
  277. nat/llm/utils/env_config_value.py +1 -1
  278. nat/llm/utils/error.py +1 -1
  279. nat/llm/utils/thinking.py +1 -1
  280. nat/memory/__init__.py +1 -1
  281. nat/memory/interfaces.py +1 -1
  282. nat/memory/models.py +1 -1
  283. nat/meta/pypi.md +1 -1
  284. nat/middleware/__init__.py +5 -5
  285. nat/middleware/cache/__init__.py +14 -0
  286. nat/middleware/{cache_middleware.py → cache/cache_middleware.py} +39 -42
  287. nat/middleware/cache/cache_middleware_config.py +44 -0
  288. nat/middleware/cache/register.py +33 -0
  289. nat/middleware/defense/__init__.py +14 -0
  290. nat/middleware/defense/defense_middleware.py +362 -0
  291. nat/middleware/defense/defense_middleware_content_guard.py +455 -0
  292. nat/middleware/defense/defense_middleware_data_models.py +91 -0
  293. nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
  294. nat/middleware/defense/defense_middleware_pii.py +356 -0
  295. nat/middleware/defense/register.py +82 -0
  296. nat/middleware/dynamic/__init__.py +14 -0
  297. nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
  298. nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
  299. nat/middleware/dynamic/register.py +34 -0
  300. nat/middleware/function_middleware.py +236 -52
  301. nat/middleware/logging/__init__.py +14 -0
  302. nat/middleware/logging/logging_middleware.py +67 -0
  303. nat/middleware/logging/logging_middleware_config.py +28 -0
  304. nat/middleware/logging/register.py +33 -0
  305. nat/middleware/middleware.py +142 -28
  306. nat/middleware/red_teaming/__init__.py +14 -0
  307. nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
  308. nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
  309. nat/middleware/red_teaming/register.py +47 -0
  310. nat/middleware/register.py +7 -20
  311. nat/middleware/utils/__init__.py +14 -0
  312. nat/middleware/utils/workflow_inventory.py +155 -0
  313. nat/object_store/__init__.py +1 -1
  314. nat/object_store/in_memory_object_store.py +1 -1
  315. nat/object_store/interfaces.py +1 -1
  316. nat/object_store/models.py +1 -1
  317. nat/object_store/register.py +1 -1
  318. nat/observability/__init__.py +1 -1
  319. nat/observability/exporter/__init__.py +1 -1
  320. nat/observability/exporter/base_exporter.py +1 -1
  321. nat/observability/exporter/exporter.py +1 -1
  322. nat/observability/exporter/file_exporter.py +1 -1
  323. nat/observability/exporter/processing_exporter.py +1 -1
  324. nat/observability/exporter/raw_exporter.py +1 -1
  325. nat/observability/exporter/span_exporter.py +7 -1
  326. nat/observability/exporter_manager.py +1 -1
  327. nat/observability/mixin/__init__.py +1 -1
  328. nat/observability/mixin/batch_config_mixin.py +1 -1
  329. nat/observability/mixin/collector_config_mixin.py +1 -1
  330. nat/observability/mixin/file_mixin.py +1 -1
  331. nat/observability/mixin/file_mode.py +1 -1
  332. nat/observability/mixin/redaction_config_mixin.py +1 -1
  333. nat/observability/mixin/resource_conflict_mixin.py +1 -1
  334. nat/observability/mixin/serialize_mixin.py +1 -1
  335. nat/observability/mixin/tagging_config_mixin.py +1 -1
  336. nat/observability/mixin/type_introspection_mixin.py +1 -1
  337. nat/observability/processor/__init__.py +1 -1
  338. nat/observability/processor/batching_processor.py +1 -1
  339. nat/observability/processor/callback_processor.py +1 -1
  340. nat/observability/processor/falsy_batch_filter_processor.py +1 -1
  341. nat/observability/processor/intermediate_step_serializer.py +1 -1
  342. nat/observability/processor/processor.py +1 -1
  343. nat/observability/processor/processor_factory.py +1 -1
  344. nat/observability/processor/redaction/__init__.py +1 -1
  345. nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
  346. nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
  347. nat/observability/processor/redaction/redaction_processor.py +1 -1
  348. nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
  349. nat/observability/processor/span_tagging_processor.py +1 -1
  350. nat/observability/register.py +1 -1
  351. nat/observability/utils/__init__.py +1 -1
  352. nat/observability/utils/dict_utils.py +1 -1
  353. nat/observability/utils/time_utils.py +1 -1
  354. nat/profiler/calc/__init__.py +1 -1
  355. nat/profiler/calc/calc_runner.py +3 -3
  356. nat/profiler/calc/calculations.py +1 -1
  357. nat/profiler/calc/data_models.py +1 -1
  358. nat/profiler/calc/plot.py +30 -3
  359. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  360. nat/profiler/callbacks/base_callback_class.py +1 -1
  361. nat/profiler/callbacks/langchain_callback_handler.py +33 -3
  362. nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
  363. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  364. nat/profiler/callbacks/token_usage_base_model.py +1 -1
  365. nat/profiler/data_frame_row.py +1 -1
  366. nat/profiler/data_models.py +1 -1
  367. nat/profiler/decorators/framework_wrapper.py +16 -1
  368. nat/profiler/decorators/function_tracking.py +1 -1
  369. nat/profiler/forecasting/config.py +1 -1
  370. nat/profiler/forecasting/model_trainer.py +1 -1
  371. nat/profiler/forecasting/models/__init__.py +1 -1
  372. nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
  373. nat/profiler/forecasting/models/linear_model.py +1 -1
  374. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  375. nat/profiler/inference_metrics_model.py +1 -1
  376. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  377. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  378. nat/profiler/inference_optimization/data_models.py +1 -1
  379. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
  380. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  381. nat/profiler/inference_optimization/llm_metrics.py +1 -1
  382. nat/profiler/inference_optimization/prompt_caching.py +1 -1
  383. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  384. nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
  385. nat/profiler/intermediate_property_adapter.py +1 -1
  386. nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
  387. nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
  388. nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
  389. nat/profiler/parameter_optimization/parameter_selection.py +1 -1
  390. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  391. nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
  392. nat/profiler/parameter_optimization/update_helpers.py +1 -1
  393. nat/profiler/profile_runner.py +1 -1
  394. nat/profiler/utils.py +1 -1
  395. nat/registry_handlers/local/local_handler.py +1 -1
  396. nat/registry_handlers/local/register_local.py +1 -1
  397. nat/registry_handlers/metadata_factory.py +1 -1
  398. nat/registry_handlers/package_utils.py +1 -1
  399. nat/registry_handlers/pypi/pypi_handler.py +1 -1
  400. nat/registry_handlers/pypi/register_pypi.py +1 -1
  401. nat/registry_handlers/register.py +1 -1
  402. nat/registry_handlers/registry_handler_base.py +1 -1
  403. nat/registry_handlers/rest/register_rest.py +1 -1
  404. nat/registry_handlers/rest/rest_handler.py +1 -1
  405. nat/registry_handlers/schemas/headers.py +1 -1
  406. nat/registry_handlers/schemas/package.py +1 -1
  407. nat/registry_handlers/schemas/publish.py +1 -1
  408. nat/registry_handlers/schemas/pull.py +1 -1
  409. nat/registry_handlers/schemas/remove.py +1 -1
  410. nat/registry_handlers/schemas/search.py +1 -1
  411. nat/registry_handlers/schemas/status.py +1 -1
  412. nat/retriever/interface.py +1 -1
  413. nat/retriever/milvus/__init__.py +1 -1
  414. nat/retriever/milvus/register.py +1 -1
  415. nat/retriever/milvus/retriever.py +1 -1
  416. nat/retriever/models.py +1 -1
  417. nat/retriever/nemo_retriever/__init__.py +1 -1
  418. nat/retriever/nemo_retriever/register.py +1 -1
  419. nat/retriever/nemo_retriever/retriever.py +5 -5
  420. nat/retriever/register.py +1 -1
  421. nat/runtime/__init__.py +1 -1
  422. nat/runtime/loader.py +10 -3
  423. nat/runtime/metrics.py +180 -0
  424. nat/runtime/runner.py +1 -5
  425. nat/runtime/session.py +451 -32
  426. nat/runtime/user_metadata.py +1 -1
  427. nat/settings/global_settings.py +1 -1
  428. nat/tool/chat_completion.py +1 -1
  429. nat/tool/code_execution/README.md +1 -1
  430. nat/tool/code_execution/code_sandbox.py +1 -1
  431. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
  432. nat/tool/code_execution/local_sandbox/__init__.py +1 -1
  433. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  434. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
  435. nat/tool/code_execution/register.py +1 -1
  436. nat/tool/code_execution/utils.py +1 -1
  437. nat/tool/datetime_tools.py +1 -1
  438. nat/tool/document_search.py +1 -1
  439. nat/tool/github_tools.py +1 -1
  440. nat/tool/memory_tools/add_memory_tool.py +1 -1
  441. nat/tool/memory_tools/delete_memory_tool.py +1 -1
  442. nat/tool/memory_tools/get_memory_tool.py +1 -1
  443. nat/tool/nvidia_rag.py +2 -2
  444. nat/tool/register.py +1 -1
  445. nat/tool/retriever.py +1 -1
  446. nat/tool/server_tools.py +1 -1
  447. nat/utils/__init__.py +8 -5
  448. nat/utils/callable_utils.py +1 -1
  449. nat/utils/data_models/schema_validator.py +1 -1
  450. nat/utils/debugging_utils.py +1 -1
  451. nat/utils/decorators.py +1 -1
  452. nat/utils/dump_distro_mapping.py +1 -1
  453. nat/utils/exception_handlers/automatic_retries.py +3 -3
  454. nat/utils/exception_handlers/schemas.py +1 -1
  455. nat/utils/io/model_processing.py +1 -1
  456. nat/utils/io/supress_logs.py +33 -0
  457. nat/utils/io/yaml_tools.py +1 -1
  458. nat/utils/log_levels.py +1 -1
  459. nat/utils/log_utils.py +13 -1
  460. nat/utils/metadata_utils.py +1 -1
  461. nat/utils/optional_imports.py +1 -1
  462. nat/utils/producer_consumer_queue.py +1 -1
  463. nat/utils/reactive/base/observable_base.py +1 -1
  464. nat/utils/reactive/base/observer_base.py +1 -1
  465. nat/utils/reactive/base/subject_base.py +1 -1
  466. nat/utils/reactive/observable.py +1 -1
  467. nat/utils/reactive/observer.py +1 -1
  468. nat/utils/reactive/subject.py +1 -1
  469. nat/utils/reactive/subscription.py +1 -1
  470. nat/utils/responses_api.py +1 -1
  471. nat/utils/settings/global_settings.py +1 -1
  472. nat/utils/string_utils.py +1 -1
  473. nat/utils/type_converter.py +18 -5
  474. nat/utils/type_utils.py +1 -1
  475. nat/utils/url_utils.py +1 -1
  476. {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +39 -14
  477. nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
  478. nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
  479. nat/cli/commands/mcp/mcp.py +0 -986
  480. nat/front_ends/mcp/introspection_token_verifier.py +0 -73
  481. nat/front_ends/mcp/mcp_front_end_config.py +0 -109
  482. nat/front_ends/mcp/mcp_front_end_plugin.py +0 -155
  483. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -388
  484. nat/front_ends/mcp/memory_profiler.py +0 -320
  485. nat/front_ends/mcp/register.py +0 -27
  486. nat/front_ends/mcp/tool_converter.py +0 -321
  487. nvidia_nat-1.4.0a20251120.dist-info/RECORD +0 -488
  488. nvidia_nat-1.4.0a20251120.dist-info/entry_points.txt +0 -23
  489. {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
  490. {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  491. {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
  492. {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -0,0 +1,127 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import re
18
+
19
+ from nat.builder.builder import Builder
20
+ from nat.builder.framework_enum import LLMFrameworkEnum
21
+ from nat.cli.register_workflow import register_ttc_strategy
22
+ from nat.experimental.test_time_compute.models.selection_config import LLMJudgeSelectionConfig
23
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
24
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
25
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
26
+ from nat.experimental.test_time_compute.models.ttc_item import TTCItem
27
+ from nat.utils.io.model_processing import remove_r1_think_tags
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
class LLMJudgeSelection(StrategyBase):
    """
    A selection strategy that uses a configured Judge LLM to select the best response.

    Candidate outputs are rendered as a 1-based numbered list into
    ``config.selection_template`` and sent to the judge LLM. The judge's reply is
    expected to contain ``SELECTED ITEM: <number>``; when no valid selection can be
    parsed, the first candidate is returned as a fallback.
    """

    # Judge answer marker, e.g. "SELECTED ITEM: 2" (matched case-insensitively).
    _SELECTION_PATTERN = re.compile(r'SELECTED ITEM:\s*(\d+)', re.IGNORECASE)

    def __init__(self, config: LLMJudgeSelectionConfig) -> None:
        super().__init__(config)
        self.config = config
        # Set by `build_components`; `ainvoke` raises until it has been called.
        self.judge_llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Builds the Judge LLM configured in the strategy.

        Args:
            builder: Builder used to resolve ``config.judge_llm`` as a LangChain-wrapped LLM.
        """
        logger.debug("Building components for LLMJudgeSelection")
        self.judge_llm_bound = await builder.get_llm(self.config.judge_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        return [PipelineTypeEnum.CUSTOM, PipelineTypeEnum.PLANNING, PipelineTypeEnum.AGENT_EXECUTION]

    def stage_type(self) -> StageTypeEnum:
        return StageTypeEnum.SELECTION

    @staticmethod
    def _response_text(response) -> str:
        """
        Return the judge's reply as plain text.

        Objects without a ``content`` attribute are converted with ``str()``. A string
        ``content`` is returned as-is. A list ``content`` (chat-model content blocks:
        strings or dicts carrying a ``"text"`` entry) has its text parts concatenated.
        Any other ``content`` value is converted with ``str()``.
        """
        if not hasattr(response, 'content'):
            return str(response)
        content = response.content
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return str(content)

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Select the best item using the configured Judge LLM.

        Args:
            items: Candidate items to choose from.
            original_prompt: Query shown to the judge. When empty, ``items[0].input`` is
                used, or ``"Unknown Query"`` if that is also empty.
            agent_context: Accepted for interface compatibility; not used by this strategy.
            **kwargs: Ignored.

        Returns:
            ``[]`` when ``items`` is empty. Otherwise a single-element list containing the
            judge's choice, whose ``metadata["judge_reasoning"]`` is set to the judge's
            reply, or ``[items[0]]`` when no valid selection can be parsed.

        Raises:
            ValueError: If ``build_components`` has not been called.
            ImportError: If langchain-core is not installed.
        """
        if not self.judge_llm_bound:
            raise ValueError("Judge LLM not bound. Ensure `build_components` has been called.")

        if not items:
            logger.warning("No items provided for selection.")
            return []

        try:
            from langchain_core.prompts import PromptTemplate
            from pydantic import BaseModel
        except ImportError as exc:
            raise ImportError("langchain-core is not installed.") from exc

        # Render candidates as a 1-based numbered list; the judge answers with that number.
        entries = []
        for position, item in enumerate(items, start=1):
            item_output = str(item.output.model_dump()) if isinstance(item.output, BaseModel) else str(item.output)
            entries.append(f"{position}. {remove_r1_think_tags(item_output)}\n\n")
        results_str = "".join(entries)

        prompt_template = PromptTemplate(
            template=self.config.selection_template,
            input_variables=["original_prompt", "results"],
            validate_template=True,
        )

        # Use input from first item if original_prompt is missing
        query = original_prompt if original_prompt else (items[0].input or "Unknown Query")

        prompt = (await prompt_template.ainvoke(input={"original_prompt": query, "results": results_str})).to_string()

        logger.info("Asking Judge LLM to select the best response.")
        judge_response = await self.judge_llm_bound.ainvoke(prompt)
        # Normalize to text first: message content may be a list of blocks, not a str.
        judge_content = remove_r1_think_tags(self._response_text(judge_response))

        # Expected format: 'SELECTED ITEM: <number>' (1-based).
        match = self._SELECTION_PATTERN.search(judge_content)
        if match:
            # The group is guaranteed parseable by int() because it matched `\d+`.
            index = int(match.group(1)) - 1
            if 0 <= index < len(items):
                logger.info("Judge selected item %d", index + 1)
                selected_item = items[index]
                # Attach the judge's reasoning to the chosen item (mutates it in place).
                if selected_item.metadata is None:
                    selected_item.metadata = {}
                selected_item.metadata["judge_reasoning"] = judge_content
                return [selected_item]
            logger.warning("Judge selected index %d which is out of range.", index + 1)

        logger.warning("Could not parse valid selection from judge response. "
                       "Returning first item as fallback.")
        # Fallback to first item
        return [items[0]]
121
+
122
+
123
@register_ttc_strategy(config_type=LLMJudgeSelectionConfig)
async def register_llm_judge_selection(config: LLMJudgeSelectionConfig, builder: Builder):
    """
    Register the LLM-judge selection strategy.

    Creates an ``LLMJudgeSelection`` from ``config``, builds its judge LLM through
    ``builder``, and yields the ready instance.
    """
    selector = LLMJudgeSelection(config)
    await selector.build_components(builder)
    yield selector
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -0,0 +1,24 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from nat.finetuning.interfaces.finetuning_runner import Trainer
17
+ from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
18
+ from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
19
+
20
# Public finetuning interfaces re-exported at the package level.
__all__ = ["Trainer", "TrajectoryBuilder", "TrainerAdapter"]
@@ -0,0 +1,143 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Finetuning runtime for NAT that orchestrates the training process."""
17
+
18
+ import asyncio
19
+ import logging
20
+
21
+ from nat.data_models.finetuning import FinetuneRunConfig
22
+ from nat.data_models.finetuning import TrainingStatusEnum
23
+ from nat.finetuning.interfaces.finetuning_runner import Trainer
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
async def run_finetuning(runner: Trainer) -> None:
    """
    Run finetuning with an already-initialized Trainer and always clean it up afterwards.

    Trains for ``runner.run_config.num_epochs`` epochs via ``runner.run``, logs the final
    status of every job, fetches and logs metrics for the last job, and logs an overall
    summary. ``runner.cleanup()`` is awaited whether training succeeds, fails or is cancelled.

    Args:
        runner: An instance of the Trainer to run finetuning with. Its ``run_config`` must
            already be populated (i.e. ``initialize`` has been called).

    Raises:
        Exception: Any error raised during training is logged and re-raised. An error raised
            by ``runner.cleanup()`` is re-raised only when training itself succeeded; if
            training already failed, the cleanup error is logged instead so that it does not
            mask the original failure.
    """
    training_succeeded = False
    try:
        logger.info("Initializing finetuning runner...")
        num_epochs = runner.run_config.num_epochs

        logger.info("Starting training for %d epochs...", num_epochs)
        job_statuses = await runner.run(num_epochs)

        for status in job_statuses:
            logger.info("Job %s completed with status: %s", status.run_id, status.status)
            if status.message:
                logger.info(" Message: %s", status.message)

        # Metrics are best-effort: a failure to fetch them must not fail the run.
        if job_statuses:
            final_run_id = job_statuses[-1].run_id
            try:
                metrics = await runner.get_metrics(final_run_id)
                logger.info("Final metrics: %s", metrics)
            except (ValueError, RuntimeError) as e:
                logger.warning("Failed to retrieve metrics: %s", e)

        _log_job_summary(job_statuses)
        training_succeeded = True

    except Exception as e:
        logger.error("Finetuning failed: %s", e)
        raise
    finally:
        logger.info("Cleaning up finetuning resources...")
        try:
            await runner.cleanup()
        except Exception as cleanup_error:
            if training_succeeded:
                raise
            # Training already failed (or was cancelled); keep that error as the one that
            # propagates and only record the cleanup failure.
            logger.exception("Cleanup failed after finetuning error: %s", cleanup_error)
        else:
            logger.info("Cleanup completed")


def _log_job_summary(job_statuses: list) -> None:
    """Log one overall outcome message derived from the final status of every training job."""
    if not job_statuses:
        logger.warning("Finetuning completed with no jobs executed.")
        return

    total_jobs = len(job_statuses)
    failed_jobs = sum(1 for s in job_statuses if s.status == TrainingStatusEnum.FAILED)
    canceled_jobs = sum(1 for s in job_statuses if s.status == TrainingStatusEnum.CANCELED)
    completed_jobs = sum(1 for s in job_statuses if s.status == TrainingStatusEnum.COMPLETED)

    if failed_jobs:
        logger.error("Finetuning completed with %d failed job(s) out of %d total.", failed_jobs, total_jobs)
    elif canceled_jobs:
        logger.warning("Finetuning was canceled. %d job(s) were canceled out of %d total.",
                       canceled_jobs,
                       total_jobs)
    elif completed_jobs == total_jobs:
        logger.info("Finetuning completed successfully!")
    else:
        # Some jobs may still be pending or running (unexpected state).
        logger.warning("Finetuning finished with %d completed, %d pending/running job(s).",
                       completed_jobs,
                       total_jobs - completed_jobs)
93
+
94
+
95
async def finetuning_main(run_config: FinetuneRunConfig) -> None:
    """
    Load the workflow configuration, build the finetuning components and run training.

    The configuration file named by ``run_config.config_file`` is loaded, ``run_config`` is
    attached to its finetuning section as ``run_configuration``, and the trajectory builder,
    trainer adapter and trainer named there are resolved from a ``WorkflowBuilder``. The
    trainer is then initialized and handed to ``run_finetuning``.

    Args:
        run_config: FinetuneRunConfig object containing finetuning settings.

    Raises:
        ValueError: If finetuning is not enabled in the loaded configuration.
    """
    # Imported inside the coroutine rather than at module import time.
    from nat.builder.workflow_builder import WorkflowBuilder
    from nat.runtime.loader import load_config

    loaded_config = load_config(config_file=run_config.config_file)
    ft_section = loaded_config.finetuning
    ft_section.run_configuration = run_config

    if not ft_section.enabled:
        raise ValueError("Finetuning is not enabled in the provided configuration.")

    async with WorkflowBuilder.from_config(config=loaded_config) as workflow_builder:
        logger.info("Initializing finetuning components...")
        traj_builder = await workflow_builder.get_trajectory_builder(ft_section.trajectory_builder)
        adapter = await workflow_builder.get_trainer_adapter(ft_section.trainer_adapter)
        logger.info("Finetuning components initialized.")

        trainer_id = ft_section.trainer
        trainer = await workflow_builder.get_trainer(trainer_id, trajectory_builder=traj_builder, trainer_adapter=adapter)
        await trainer.initialize(run_config=ft_section)
        logger.info("Initialized trainer: %s", trainer_id)

        await run_finetuning(trainer)
134
+
135
+
136
def run_finetuning_sync(run_config: FinetuneRunConfig) -> None:
    """
    Blocking entry point that runs ``finetuning_main`` to completion via ``asyncio.run``.

    Because ``asyncio.run`` is used, this must not be called from a thread that already has
    a running event loop.

    Args:
        run_config: FinetuneRunConfig object containing finetuning settings.
    """
    main_coroutine = finetuning_main(run_config)
    asyncio.run(main_coroutine)
@@ -0,0 +1,24 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from nat.finetuning.interfaces.finetuning_runner import Trainer
17
+ from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
18
+ from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
19
+
20
# Public finetuning interfaces re-exported at the package level.
__all__ = ["Trainer", "TrajectoryBuilder", "TrainerAdapter"]
@@ -0,0 +1,261 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from abc import ABC
18
+ from abc import abstractmethod
19
+ from typing import Any
20
+
21
+ from nat.data_models.finetuning import FinetuneConfig
22
+ from nat.data_models.finetuning import FinetuneRunConfig
23
+ from nat.data_models.finetuning import TrainerConfig
24
+ from nat.data_models.finetuning import TrainingJobRef
25
+ from nat.data_models.finetuning import TrainingJobStatus
26
+ from nat.data_models.finetuning import TrajectoryCollection
27
+ from nat.eval.config import EvaluationRunOutput
28
+ from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
29
+ from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class Trainer(ABC):
    """
    Abstract interface for running finetuning workflows.

    The Trainer orchestrates the entire finetuning process by:
    1. Running evaluations to generate trajectories via TrajectoryBuilder
    2. Submitting trajectories for training via TrainerAdapter
    3. Managing multiple epochs of training

    ``initialize`` dereferences the TrajectoryBuilder and TrainerAdapter, so both must be
    bound (e.g. via ``bind_components``) before it is called.
    """

    def __init__(self, trainer_config: TrainerConfig, **kwargs) -> None:
        """
        Initialize the Trainer.

        Args:
            trainer_config: Configuration for the trainer backend.
            **kwargs: Accepted for compatibility with subclasses/factories; ignored here.
        """
        self.trainer_config = trainer_config
        # Populated by ``initialize``.
        self.run_config: FinetuneConfig | None = None
        self.curriculum_config = None
        # Populated by ``bind_components``.
        self.trajectory_builder: TrajectoryBuilder | None = None
        self.trainer_adapter: TrainerAdapter | None = None

        # Curriculum learning state; populated by ``initialize``.
        self._curriculum_state: dict[str, Any] | None = None

    async def bind_components(self, trajectory_builder: TrajectoryBuilder, trainer_adapter: TrainerAdapter) -> None:
        """
        Bind the TrajectoryBuilder and TrainerAdapter components.

        Args:
            trajectory_builder: Instance of TrajectoryBuilder
            trainer_adapter: Instance of TrainerAdapter
        """
        self.trajectory_builder = trajectory_builder
        self.trainer_adapter = trainer_adapter

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Initialize the runner and its components.

        Stores ``run_config``, seeds the curriculum-learning state from
        ``run_config.curriculum_learning``, copies ``run_config.reward_function`` onto
        ``trainer_config.reward``, then initializes the bound TrajectoryBuilder and
        TrainerAdapter (in that order) with ``run_config``. Subclasses may extend this,
        e.g. to verify connectivity to backend services.

        Args:
            run_config: The finetuning configuration for this run.
        """
        self.run_config = run_config
        # NOTE(review): assumes curriculum_learning is always set; get_curriculum_state
        # tolerates a None curriculum_config, which would fail here — confirm.
        self.curriculum_config = self.run_config.curriculum_learning
        self._curriculum_state = {
            "current_percentile": self.curriculum_config.initial_percentile,
            "last_expansion_epoch": -1,
            "total_groups": 0,
            "included_groups": set(),
        }
        self.trainer_config.reward = self.run_config.reward_function

        await self.trajectory_builder.initialize(run_config)
        await self.trainer_adapter.initialize(run_config)

    @abstractmethod
    async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef:
        """
        Run a single epoch of training.

        Args:
            epoch: The current epoch number (0-indexed)
            run_id: Unique identifier for this training run

        Returns:
            TrainingJobRef: Reference to the submitted training job
        """
        raise NotImplementedError

    @abstractmethod
    async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
        """
        Run the complete finetuning workflow for the specified number of epochs.

        Args:
            num_epochs: Number of epochs to train

        Returns:
            list[TrainingJobStatus]: Status of all training jobs
        """
        raise NotImplementedError

    @abstractmethod
    async def get_metrics(self, run_id: str) -> dict[str, Any]:
        """
        Get training metrics for a specific run.

        Args:
            run_id: The run identifier

        Returns:
            dict: Metrics from the training run
        """
        raise NotImplementedError

    @abstractmethod
    async def cleanup(self) -> None:
        """
        Clean up any resources used by the runner.
        """
        raise NotImplementedError

    @abstractmethod
    def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log training progress for monitoring.

        Args:
            epoch: Current epoch number
            metrics: Dictionary of metrics to log
            output_dir: Optional output directory override
        """
        raise NotImplementedError

    async def run_validation_evaluation(self, epoch: int, run_id: str) -> dict[str, Any]:
        """
        Run evaluation on the validation dataset to collect rewards without training.

        Builds a temporary FinetuneRunConfig that points at the validation dataset (and
        at the validation config file when one is set, otherwise the main config file).
        It swaps that into the bound TrajectoryBuilder's ``run_config.run_configuration``,
        runs an evaluation, and restores the original run configuration afterwards, even
        if evaluation fails.

        Args:
            epoch: Current epoch number (0-indexed).
            run_id: Unique identifier for this training run (unused by this base implementation).

        Returns:
            dict: The metrics from ``_calculate_validation_metrics`` plus ``epoch`` and
            ``dataset_type``. If evaluation raises, the error is logged and a dict carrying
            ``error`` with zeroed ``avg_reward``/``num_examples`` is returned instead.
        """
        logger.info("Running validation evaluation for epoch %d", epoch + 1)

        run_configuration = self.run_config.run_configuration
        config = run_configuration.validation_config_file or run_configuration.config_file

        # Create a temporary run config with the validation dataset.
        validation_run_config = FinetuneRunConfig(config_file=config,
                                                  dataset=run_configuration.validation_dataset,
                                                  result_json_path=run_configuration.result_json_path,
                                                  endpoint=run_configuration.endpoint,
                                                  endpoint_timeout=run_configuration.endpoint_timeout,
                                                  override=run_configuration.override)

        # Reuse the bound trajectory builder, temporarily pointed at the validation run config.
        validation_builder = self.trajectory_builder
        original_run_config = validation_builder.run_config.run_configuration

        try:
            validation_builder.run_config.run_configuration = validation_run_config

            eval_output = await validation_builder.run_eval()

            validation_metrics = self._calculate_validation_metrics(eval_output)
            validation_metrics["epoch"] = epoch
            validation_metrics["dataset_type"] = "validation"

            # Log with the same 1-indexed epoch number as the start-of-validation message.
            logger.info("Validation metrics for epoch %d: %s", epoch + 1, validation_metrics)
            return validation_metrics

        except Exception as e:
            # Validation is best-effort: keep the traceback in the logs and report the
            # failure through the returned metrics instead of aborting training.
            logger.exception("Error during validation evaluation: %s", e)
            return {"epoch": epoch, "dataset_type": "validation", "error": str(e), "avg_reward": 0.0, "num_examples": 0}
        finally:
            # Restore the original run config so training continues on the training dataset.
            validation_builder.run_config.run_configuration = original_run_config

    def _calculate_validation_metrics(self, eval_output: EvaluationRunOutput) -> dict[str, Any]:
        """
        Calculate validation metrics from evaluation output.

        Only scores from the evaluator whose name equals ``trainer_config.reward.name`` are
        used. Subclasses can override this for backend-specific metrics.

        Args:
            eval_output: Output from evaluation run

        Returns:
            dict: ``avg_reward``, ``min_reward``, ``max_reward`` and ``num_examples``;
            all zero when no matching scores are found.
        """
        metrics = {"avg_reward": 0.0, "min_reward": 0.0, "max_reward": 0.0, "num_examples": 0}

        rewards = []
        for metric_name, metric_value in eval_output.evaluation_results:
            if metric_name == self.trainer_config.reward.name:
                rewards.extend(reward_item.score for reward_item in metric_value.eval_output_items)

        if rewards:
            metrics["avg_reward"] = sum(rewards) / len(rewards)
            metrics["min_reward"] = min(rewards)
            metrics["max_reward"] = max(rewards)
            metrics["num_examples"] = len(rewards)

        return metrics

    def apply_curriculum_learning(self, trajectory_collection: TrajectoryCollection,
                                  epoch: int) -> TrajectoryCollection:
        """
        Apply curriculum learning to filter trajectory groups based on difficulty.

        Args:
            trajectory_collection: Trajectories collected for the current epoch.
            epoch: Current epoch number.

        Returns:
            TrajectoryCollection: The filtered trajectories (in backends that implement this).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Curriculum learning not implemented for this backend.")

    def get_curriculum_state(self) -> dict[str, Any]:
        """
        Get the current state of curriculum learning.

        Requires ``initialize`` to have run; before that, ``_curriculum_state`` is None.

        Returns:
            dict: Current curriculum state including percentile and group statistics,
            with ``included_groups`` converted from a set to a list (unspecified order)
            so the result is JSON-serializable.
        """
        state = {
            "current_percentile": self._curriculum_state["current_percentile"],
            "last_expansion_epoch": self._curriculum_state["last_expansion_epoch"],
            "total_groups": self._curriculum_state["total_groups"],
            "included_groups": list(self._curriculum_state["included_groups"]),
            "config": self.curriculum_config.model_dump() if self.curriculum_config else None,
        }
        return state