nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (492)
  1. aiq/__init__.py +1 -1
  2. nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
  3. nat/agent/auto_memory_wrapper/agent.py +278 -0
  4. nat/agent/auto_memory_wrapper/register.py +227 -0
  5. nat/agent/auto_memory_wrapper/state.py +30 -0
  6. nat/agent/base.py +1 -1
  7. nat/agent/dual_node.py +1 -1
  8. nat/agent/prompt_optimizer/prompt.py +1 -1
  9. nat/agent/prompt_optimizer/register.py +1 -1
  10. nat/agent/react_agent/agent.py +16 -9
  11. nat/agent/react_agent/output_parser.py +2 -2
  12. nat/agent/react_agent/prompt.py +3 -2
  13. nat/agent/react_agent/register.py +2 -2
  14. nat/agent/react_agent/register_per_user_agent.py +104 -0
  15. nat/agent/reasoning_agent/reasoning_agent.py +1 -1
  16. nat/agent/register.py +3 -1
  17. nat/agent/responses_api_agent/__init__.py +1 -1
  18. nat/agent/responses_api_agent/register.py +1 -1
  19. nat/agent/rewoo_agent/agent.py +9 -4
  20. nat/agent/rewoo_agent/prompt.py +1 -1
  21. nat/agent/rewoo_agent/register.py +1 -1
  22. nat/agent/tool_calling_agent/agent.py +5 -4
  23. nat/agent/tool_calling_agent/register.py +1 -1
  24. nat/authentication/__init__.py +1 -1
  25. nat/authentication/api_key/__init__.py +1 -1
  26. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  27. nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
  28. nat/authentication/api_key/register.py +1 -1
  29. nat/authentication/credential_validator/__init__.py +1 -1
  30. nat/authentication/credential_validator/bearer_token_validator.py +1 -1
  31. nat/authentication/exceptions/__init__.py +1 -1
  32. nat/authentication/exceptions/api_key_exceptions.py +1 -1
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  34. nat/authentication/http_basic_auth/register.py +1 -1
  35. nat/authentication/interfaces.py +1 -1
  36. nat/authentication/oauth2/__init__.py +1 -1
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
  40. nat/authentication/oauth2/register.py +1 -1
  41. nat/authentication/register.py +1 -1
  42. nat/builder/builder.py +563 -1
  43. nat/builder/child_builder.py +385 -0
  44. nat/builder/component_utils.py +34 -4
  45. nat/builder/context.py +34 -1
  46. nat/builder/embedder.py +1 -1
  47. nat/builder/eval_builder.py +19 -7
  48. nat/builder/evaluator.py +1 -1
  49. nat/builder/framework_enum.py +3 -1
  50. nat/builder/front_end.py +1 -1
  51. nat/builder/function.py +113 -5
  52. nat/builder/function_base.py +1 -1
  53. nat/builder/function_info.py +1 -1
  54. nat/builder/intermediate_step_manager.py +1 -1
  55. nat/builder/llm.py +1 -1
  56. nat/builder/per_user_workflow_builder.py +843 -0
  57. nat/builder/retriever.py +1 -1
  58. nat/builder/sync_builder.py +571 -0
  59. nat/builder/user_interaction_manager.py +1 -1
  60. nat/builder/workflow.py +5 -3
  61. nat/builder/workflow_builder.py +619 -378
  62. nat/cli/__init__.py +1 -1
  63. nat/cli/cli_utils/config_override.py +1 -1
  64. nat/cli/cli_utils/validation.py +32 -1
  65. nat/cli/commands/configure/channel/add.py +1 -1
  66. nat/cli/commands/configure/channel/channel.py +1 -1
  67. nat/cli/commands/configure/channel/remove.py +1 -1
  68. nat/cli/commands/configure/channel/update.py +1 -1
  69. nat/cli/commands/configure/configure.py +1 -1
  70. nat/cli/commands/evaluate.py +87 -13
  71. nat/cli/commands/finetune.py +132 -0
  72. nat/cli/commands/info/__init__.py +1 -1
  73. nat/cli/commands/info/info.py +1 -1
  74. nat/cli/commands/info/list_channels.py +1 -1
  75. nat/cli/commands/info/list_components.py +1 -1
  76. nat/cli/commands/object_store/__init__.py +1 -1
  77. nat/cli/commands/object_store/object_store.py +1 -1
  78. nat/cli/commands/optimize.py +1 -1
  79. nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
  80. nat/cli/commands/red_teaming/red_teaming.py +138 -0
  81. nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
  82. nat/cli/commands/registry/__init__.py +1 -1
  83. nat/cli/commands/registry/publish.py +1 -1
  84. nat/cli/commands/registry/pull.py +1 -1
  85. nat/cli/commands/registry/registry.py +1 -1
  86. nat/cli/commands/registry/remove.py +1 -1
  87. nat/cli/commands/registry/search.py +1 -1
  88. nat/cli/commands/sizing/__init__.py +1 -1
  89. nat/cli/commands/sizing/calc.py +1 -1
  90. nat/cli/commands/sizing/sizing.py +1 -1
  91. nat/cli/commands/start.py +1 -1
  92. nat/cli/commands/uninstall.py +1 -1
  93. nat/cli/commands/validate.py +1 -1
  94. nat/cli/commands/workflow/__init__.py +1 -1
  95. nat/cli/commands/workflow/workflow.py +1 -1
  96. nat/cli/commands/workflow/workflow_commands.py +3 -2
  97. nat/cli/entrypoint.py +15 -37
  98. nat/cli/main.py +2 -2
  99. nat/cli/plugin_loader.py +69 -0
  100. nat/cli/register_workflow.py +233 -5
  101. nat/cli/type_registry.py +237 -3
  102. nat/control_flow/register.py +1 -1
  103. nat/control_flow/router_agent/agent.py +1 -1
  104. nat/control_flow/router_agent/prompt.py +1 -1
  105. nat/control_flow/router_agent/register.py +1 -1
  106. nat/control_flow/sequential_executor.py +28 -7
  107. nat/data_models/__init__.py +1 -1
  108. nat/data_models/agent.py +1 -1
  109. nat/data_models/api_server.py +38 -3
  110. nat/data_models/authentication.py +1 -1
  111. nat/data_models/common.py +1 -1
  112. nat/data_models/component.py +9 -1
  113. nat/data_models/component_ref.py +45 -1
  114. nat/data_models/config.py +78 -1
  115. nat/data_models/dataset_handler.py +15 -2
  116. nat/data_models/discovery_metadata.py +1 -1
  117. nat/data_models/embedder.py +1 -1
  118. nat/data_models/evaluate.py +6 -1
  119. nat/data_models/evaluator.py +1 -1
  120. nat/data_models/finetuning.py +260 -0
  121. nat/data_models/front_end.py +1 -1
  122. nat/data_models/function.py +15 -2
  123. nat/data_models/function_dependencies.py +1 -1
  124. nat/data_models/gated_field_mixin.py +1 -1
  125. nat/data_models/interactive.py +1 -1
  126. nat/data_models/intermediate_step.py +29 -2
  127. nat/data_models/invocation_node.py +1 -1
  128. nat/data_models/llm.py +1 -1
  129. nat/data_models/logging.py +1 -1
  130. nat/data_models/memory.py +1 -1
  131. nat/data_models/middleware.py +37 -0
  132. nat/data_models/object_store.py +1 -1
  133. nat/data_models/openai_mcp.py +1 -1
  134. nat/data_models/optimizable.py +1 -1
  135. nat/data_models/optimizer.py +1 -1
  136. nat/data_models/profiler.py +1 -1
  137. nat/data_models/registry_handler.py +1 -1
  138. nat/data_models/retriever.py +1 -1
  139. nat/data_models/retry_mixin.py +1 -1
  140. nat/data_models/runtime_enum.py +26 -0
  141. nat/data_models/span.py +1 -1
  142. nat/data_models/step_adaptor.py +1 -1
  143. nat/data_models/streaming.py +1 -1
  144. nat/data_models/swe_bench_model.py +1 -1
  145. nat/data_models/telemetry_exporter.py +1 -1
  146. nat/data_models/thinking_mixin.py +1 -1
  147. nat/data_models/ttc_strategy.py +1 -1
  148. nat/embedder/azure_openai_embedder.py +1 -1
  149. nat/embedder/nim_embedder.py +1 -1
  150. nat/embedder/openai_embedder.py +1 -1
  151. nat/embedder/register.py +1 -1
  152. nat/eval/__init__.py +1 -1
  153. nat/eval/config.py +8 -1
  154. nat/eval/dataset_handler/dataset_downloader.py +1 -1
  155. nat/eval/dataset_handler/dataset_filter.py +1 -1
  156. nat/eval/dataset_handler/dataset_handler.py +4 -2
  157. nat/eval/evaluate.py +226 -81
  158. nat/eval/evaluator/__init__.py +1 -1
  159. nat/eval/evaluator/base_evaluator.py +2 -2
  160. nat/eval/evaluator/evaluator_model.py +3 -2
  161. nat/eval/intermediate_step_adapter.py +1 -1
  162. nat/eval/llm_validator.py +336 -0
  163. nat/eval/rag_evaluator/evaluate.py +17 -10
  164. nat/eval/rag_evaluator/register.py +1 -1
  165. nat/eval/red_teaming_evaluator/__init__.py +14 -0
  166. nat/eval/red_teaming_evaluator/data_models.py +66 -0
  167. nat/eval/red_teaming_evaluator/evaluate.py +327 -0
  168. nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
  169. nat/eval/red_teaming_evaluator/register.py +55 -0
  170. nat/eval/register.py +2 -1
  171. nat/eval/remote_workflow.py +1 -1
  172. nat/eval/runners/__init__.py +1 -1
  173. nat/eval/runners/config.py +1 -1
  174. nat/eval/runners/multi_eval_runner.py +1 -1
  175. nat/eval/runners/red_teaming_runner/__init__.py +24 -0
  176. nat/eval/runners/red_teaming_runner/config.py +282 -0
  177. nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
  178. nat/eval/runners/red_teaming_runner/runner.py +867 -0
  179. nat/eval/runtime_evaluator/__init__.py +1 -1
  180. nat/eval/runtime_evaluator/evaluate.py +1 -1
  181. nat/eval/runtime_evaluator/register.py +1 -1
  182. nat/eval/runtime_event_subscriber.py +1 -1
  183. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  184. nat/eval/swe_bench_evaluator/register.py +1 -1
  185. nat/eval/trajectory_evaluator/evaluate.py +2 -2
  186. nat/eval/trajectory_evaluator/register.py +1 -1
  187. nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
  188. nat/eval/tunable_rag_evaluator/register.py +1 -1
  189. nat/eval/usage_stats.py +1 -1
  190. nat/eval/utils/eval_trace_ctx.py +1 -1
  191. nat/eval/utils/output_uploader.py +1 -1
  192. nat/eval/utils/tqdm_position_registry.py +1 -1
  193. nat/eval/utils/weave_eval.py +1 -1
  194. nat/experimental/decorators/experimental_warning_decorator.py +1 -1
  195. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
  196. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
  197. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
  198. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  199. nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
  200. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
  201. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  202. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  203. nat/experimental/test_time_compute/models/editor_config.py +1 -1
  204. nat/experimental/test_time_compute/models/scoring_config.py +1 -1
  205. nat/experimental/test_time_compute/models/search_config.py +20 -2
  206. nat/experimental/test_time_compute/models/selection_config.py +33 -2
  207. nat/experimental/test_time_compute/models/stage_enums.py +1 -1
  208. nat/experimental/test_time_compute/models/strategy_base.py +1 -1
  209. nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
  210. nat/experimental/test_time_compute/models/ttc_item.py +1 -1
  211. nat/experimental/test_time_compute/register.py +4 -1
  212. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
  213. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
  214. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
  215. nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
  216. nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
  217. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
  218. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
  219. nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
  220. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
  221. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  222. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
  223. nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
  224. nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
  225. nat/finetuning/__init__.py +24 -0
  226. nat/finetuning/finetuning_runtime.py +143 -0
  227. nat/finetuning/interfaces/__init__.py +24 -0
  228. nat/finetuning/interfaces/finetuning_runner.py +261 -0
  229. nat/finetuning/interfaces/trainer_adapter.py +103 -0
  230. nat/finetuning/interfaces/trajectory_builder.py +115 -0
  231. nat/finetuning/utils/__init__.py +15 -0
  232. nat/finetuning/utils/parsers/__init__.py +15 -0
  233. nat/finetuning/utils/parsers/adk_parser.py +141 -0
  234. nat/finetuning/utils/parsers/base_parser.py +238 -0
  235. nat/finetuning/utils/parsers/common.py +91 -0
  236. nat/finetuning/utils/parsers/langchain_parser.py +267 -0
  237. nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
  238. nat/front_ends/__init__.py +1 -1
  239. nat/front_ends/console/__init__.py +1 -1
  240. nat/front_ends/console/authentication_flow_handler.py +1 -1
  241. nat/front_ends/console/console_front_end_config.py +4 -1
  242. nat/front_ends/console/console_front_end_plugin.py +5 -4
  243. nat/front_ends/console/register.py +1 -1
  244. nat/front_ends/cron/__init__.py +1 -1
  245. nat/front_ends/fastapi/__init__.py +1 -1
  246. nat/front_ends/fastapi/async_job.py +128 -0
  247. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  248. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
  249. nat/front_ends/fastapi/dask_client_mixin.py +1 -1
  250. nat/front_ends/fastapi/fastapi_front_end_config.py +23 -1
  251. nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
  252. nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
  253. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +318 -59
  254. nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
  255. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
  256. nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
  257. nat/front_ends/fastapi/job_store.py +23 -11
  258. nat/front_ends/fastapi/main.py +1 -1
  259. nat/front_ends/fastapi/message_handler.py +27 -4
  260. nat/front_ends/fastapi/message_validator.py +54 -2
  261. nat/front_ends/fastapi/register.py +1 -1
  262. nat/front_ends/fastapi/response_helpers.py +16 -15
  263. nat/front_ends/fastapi/step_adaptor.py +1 -1
  264. nat/front_ends/fastapi/utils.py +1 -1
  265. nat/front_ends/register.py +1 -2
  266. nat/front_ends/simple_base/__init__.py +1 -1
  267. nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
  268. nat/llm/aws_bedrock_llm.py +1 -1
  269. nat/llm/azure_openai_llm.py +10 -1
  270. nat/llm/dynamo_llm.py +363 -0
  271. nat/llm/huggingface_llm.py +177 -0
  272. nat/llm/litellm_llm.py +1 -1
  273. nat/llm/nim_llm.py +1 -1
  274. nat/llm/openai_llm.py +1 -1
  275. nat/llm/register.py +3 -1
  276. nat/llm/utils/__init__.py +1 -1
  277. nat/llm/utils/env_config_value.py +1 -1
  278. nat/llm/utils/error.py +1 -1
  279. nat/llm/utils/thinking.py +1 -1
  280. nat/memory/__init__.py +1 -1
  281. nat/memory/interfaces.py +1 -1
  282. nat/memory/models.py +1 -1
  283. nat/meta/pypi.md +1 -1
  284. nat/middleware/__init__.py +35 -0
  285. nat/middleware/cache/__init__.py +14 -0
  286. nat/middleware/cache/cache_middleware.py +253 -0
  287. nat/middleware/cache/cache_middleware_config.py +44 -0
  288. nat/middleware/cache/register.py +33 -0
  289. nat/middleware/defense/__init__.py +14 -0
  290. nat/middleware/defense/defense_middleware.py +362 -0
  291. nat/middleware/defense/defense_middleware_content_guard.py +455 -0
  292. nat/middleware/defense/defense_middleware_data_models.py +91 -0
  293. nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
  294. nat/middleware/defense/defense_middleware_pii.py +356 -0
  295. nat/middleware/defense/register.py +82 -0
  296. nat/middleware/dynamic/__init__.py +14 -0
  297. nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
  298. nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
  299. nat/middleware/dynamic/register.py +34 -0
  300. nat/middleware/function_middleware.py +370 -0
  301. nat/middleware/logging/__init__.py +14 -0
  302. nat/middleware/logging/logging_middleware.py +67 -0
  303. nat/middleware/logging/logging_middleware_config.py +28 -0
  304. nat/middleware/logging/register.py +33 -0
  305. nat/middleware/middleware.py +298 -0
  306. nat/middleware/red_teaming/__init__.py +14 -0
  307. nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
  308. nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
  309. nat/middleware/red_teaming/register.py +47 -0
  310. nat/middleware/register.py +22 -0
  311. nat/middleware/utils/__init__.py +14 -0
  312. nat/middleware/utils/workflow_inventory.py +155 -0
  313. nat/object_store/__init__.py +1 -1
  314. nat/object_store/in_memory_object_store.py +1 -1
  315. nat/object_store/interfaces.py +1 -1
  316. nat/object_store/models.py +1 -1
  317. nat/object_store/register.py +1 -1
  318. nat/observability/__init__.py +1 -1
  319. nat/observability/exporter/__init__.py +1 -1
  320. nat/observability/exporter/base_exporter.py +1 -1
  321. nat/observability/exporter/exporter.py +1 -1
  322. nat/observability/exporter/file_exporter.py +1 -1
  323. nat/observability/exporter/processing_exporter.py +1 -1
  324. nat/observability/exporter/raw_exporter.py +1 -1
  325. nat/observability/exporter/span_exporter.py +7 -1
  326. nat/observability/exporter_manager.py +1 -1
  327. nat/observability/mixin/__init__.py +1 -1
  328. nat/observability/mixin/batch_config_mixin.py +1 -1
  329. nat/observability/mixin/collector_config_mixin.py +1 -1
  330. nat/observability/mixin/file_mixin.py +1 -1
  331. nat/observability/mixin/file_mode.py +1 -1
  332. nat/observability/mixin/redaction_config_mixin.py +1 -1
  333. nat/observability/mixin/resource_conflict_mixin.py +1 -1
  334. nat/observability/mixin/serialize_mixin.py +1 -1
  335. nat/observability/mixin/tagging_config_mixin.py +1 -1
  336. nat/observability/mixin/type_introspection_mixin.py +1 -1
  337. nat/observability/processor/__init__.py +1 -1
  338. nat/observability/processor/batching_processor.py +1 -1
  339. nat/observability/processor/callback_processor.py +1 -1
  340. nat/observability/processor/falsy_batch_filter_processor.py +1 -1
  341. nat/observability/processor/intermediate_step_serializer.py +1 -1
  342. nat/observability/processor/processor.py +1 -1
  343. nat/observability/processor/processor_factory.py +1 -1
  344. nat/observability/processor/redaction/__init__.py +1 -1
  345. nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
  346. nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
  347. nat/observability/processor/redaction/redaction_processor.py +1 -1
  348. nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
  349. nat/observability/processor/span_tagging_processor.py +1 -1
  350. nat/observability/register.py +1 -1
  351. nat/observability/utils/__init__.py +1 -1
  352. nat/observability/utils/dict_utils.py +1 -1
  353. nat/observability/utils/time_utils.py +1 -1
  354. nat/profiler/calc/__init__.py +1 -1
  355. nat/profiler/calc/calc_runner.py +3 -3
  356. nat/profiler/calc/calculations.py +1 -1
  357. nat/profiler/calc/data_models.py +1 -1
  358. nat/profiler/calc/plot.py +30 -3
  359. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  360. nat/profiler/callbacks/base_callback_class.py +1 -1
  361. nat/profiler/callbacks/langchain_callback_handler.py +33 -3
  362. nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
  363. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  364. nat/profiler/callbacks/token_usage_base_model.py +1 -1
  365. nat/profiler/data_frame_row.py +1 -1
  366. nat/profiler/data_models.py +1 -1
  367. nat/profiler/decorators/framework_wrapper.py +32 -1
  368. nat/profiler/decorators/function_tracking.py +1 -1
  369. nat/profiler/forecasting/config.py +1 -1
  370. nat/profiler/forecasting/model_trainer.py +1 -1
  371. nat/profiler/forecasting/models/__init__.py +1 -1
  372. nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
  373. nat/profiler/forecasting/models/linear_model.py +1 -1
  374. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  375. nat/profiler/inference_metrics_model.py +1 -1
  376. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  377. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  378. nat/profiler/inference_optimization/data_models.py +1 -1
  379. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
  380. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  381. nat/profiler/inference_optimization/llm_metrics.py +1 -1
  382. nat/profiler/inference_optimization/prompt_caching.py +1 -1
  383. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  384. nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
  385. nat/profiler/intermediate_property_adapter.py +1 -1
  386. nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
  387. nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
  388. nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
  389. nat/profiler/parameter_optimization/parameter_selection.py +1 -1
  390. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  391. nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
  392. nat/profiler/parameter_optimization/update_helpers.py +1 -1
  393. nat/profiler/profile_runner.py +1 -1
  394. nat/profiler/utils.py +1 -1
  395. nat/registry_handlers/local/local_handler.py +1 -1
  396. nat/registry_handlers/local/register_local.py +1 -1
  397. nat/registry_handlers/metadata_factory.py +1 -1
  398. nat/registry_handlers/package_utils.py +1 -1
  399. nat/registry_handlers/pypi/pypi_handler.py +1 -1
  400. nat/registry_handlers/pypi/register_pypi.py +1 -1
  401. nat/registry_handlers/register.py +1 -1
  402. nat/registry_handlers/registry_handler_base.py +1 -1
  403. nat/registry_handlers/rest/register_rest.py +1 -1
  404. nat/registry_handlers/rest/rest_handler.py +1 -1
  405. nat/registry_handlers/schemas/headers.py +1 -1
  406. nat/registry_handlers/schemas/package.py +1 -1
  407. nat/registry_handlers/schemas/publish.py +1 -1
  408. nat/registry_handlers/schemas/pull.py +1 -1
  409. nat/registry_handlers/schemas/remove.py +1 -1
  410. nat/registry_handlers/schemas/search.py +1 -1
  411. nat/registry_handlers/schemas/status.py +1 -1
  412. nat/retriever/interface.py +1 -1
  413. nat/retriever/milvus/__init__.py +1 -1
  414. nat/retriever/milvus/register.py +12 -4
  415. nat/retriever/milvus/retriever.py +103 -41
  416. nat/retriever/models.py +1 -1
  417. nat/retriever/nemo_retriever/__init__.py +1 -1
  418. nat/retriever/nemo_retriever/register.py +1 -1
  419. nat/retriever/nemo_retriever/retriever.py +5 -5
  420. nat/retriever/register.py +1 -1
  421. nat/runtime/__init__.py +1 -1
  422. nat/runtime/loader.py +10 -3
  423. nat/runtime/metrics.py +180 -0
  424. nat/runtime/runner.py +13 -6
  425. nat/runtime/session.py +458 -32
  426. nat/runtime/user_metadata.py +1 -1
  427. nat/settings/global_settings.py +1 -1
  428. nat/tool/chat_completion.py +1 -1
  429. nat/tool/code_execution/README.md +1 -1
  430. nat/tool/code_execution/code_sandbox.py +2 -2
  431. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
  432. nat/tool/code_execution/local_sandbox/__init__.py +1 -1
  433. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  434. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
  435. nat/tool/code_execution/register.py +1 -1
  436. nat/tool/code_execution/utils.py +1 -1
  437. nat/tool/datetime_tools.py +1 -1
  438. nat/tool/document_search.py +1 -1
  439. nat/tool/github_tools.py +1 -1
  440. nat/tool/memory_tools/add_memory_tool.py +1 -1
  441. nat/tool/memory_tools/delete_memory_tool.py +1 -1
  442. nat/tool/memory_tools/get_memory_tool.py +1 -1
  443. nat/tool/nvidia_rag.py +2 -2
  444. nat/tool/register.py +1 -1
  445. nat/tool/retriever.py +1 -1
  446. nat/tool/server_tools.py +1 -1
  447. nat/utils/__init__.py +8 -5
  448. nat/utils/callable_utils.py +1 -1
  449. nat/utils/data_models/schema_validator.py +1 -1
  450. nat/utils/debugging_utils.py +1 -1
  451. nat/utils/decorators.py +1 -1
  452. nat/utils/dump_distro_mapping.py +1 -1
  453. nat/utils/exception_handlers/automatic_retries.py +3 -3
  454. nat/utils/exception_handlers/schemas.py +1 -1
  455. nat/utils/io/model_processing.py +1 -1
  456. nat/utils/io/supress_logs.py +33 -0
  457. nat/utils/io/yaml_tools.py +1 -1
  458. nat/utils/log_levels.py +1 -1
  459. nat/utils/log_utils.py +13 -1
  460. nat/utils/metadata_utils.py +1 -1
  461. nat/utils/optional_imports.py +1 -1
  462. nat/utils/producer_consumer_queue.py +1 -1
  463. nat/utils/reactive/base/observable_base.py +1 -1
  464. nat/utils/reactive/base/observer_base.py +1 -1
  465. nat/utils/reactive/base/subject_base.py +1 -1
  466. nat/utils/reactive/observable.py +1 -1
  467. nat/utils/reactive/observer.py +1 -1
  468. nat/utils/reactive/subject.py +1 -1
  469. nat/utils/reactive/subscription.py +1 -1
  470. nat/utils/responses_api.py +1 -1
  471. nat/utils/settings/global_settings.py +1 -1
  472. nat/utils/string_utils.py +1 -1
  473. nat/utils/type_converter.py +18 -5
  474. nat/utils/type_utils.py +1 -1
  475. nat/utils/url_utils.py +1 -1
  476. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +46 -15
  477. nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
  478. nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
  479. nat/cli/commands/mcp/mcp.py +0 -986
  480. nat/front_ends/mcp/introspection_token_verifier.py +0 -73
  481. nat/front_ends/mcp/mcp_front_end_config.py +0 -109
  482. nat/front_ends/mcp/mcp_front_end_plugin.py +0 -151
  483. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -362
  484. nat/front_ends/mcp/memory_profiler.py +0 -320
  485. nat/front_ends/mcp/register.py +0 -27
  486. nat/front_ends/mcp/tool_converter.py +0 -321
  487. nvidia_nat-1.4.0a20251112.dist-info/RECORD +0 -481
  488. nvidia_nat-1.4.0a20251112.dist-info/entry_points.txt +0 -22
  489. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
  490. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  491. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
  492. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,103 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC
17
+ from abc import abstractmethod
18
+ from typing import Any
19
+
20
+ from nat.data_models.finetuning import FinetuneConfig
21
+ from nat.data_models.finetuning import TrainerAdapterConfig
22
+ from nat.data_models.finetuning import TrainingJobRef
23
+ from nat.data_models.finetuning import TrainingJobStatus
24
+ from nat.data_models.finetuning import TrajectoryCollection
25
+
26
+
27
class TrainerAdapter(ABC):
    """
    Abstract adapter that sends trajectories to a remote training backend for weight updates.

    Concrete subclasses implement the backend-specific operations: health checks,
    job submission, status queries, waiting for completion, and progress logging.
    """

    def __init__(self, adapter_config: TrainerAdapterConfig):
        """
        Store the adapter configuration.

        Args:
            adapter_config (TrainerAdapterConfig): Configuration for this trainer adapter.
        """
        self.adapter_config = adapter_config
        # Set by ``initialize``; remains None until then.
        self.run_config: FinetuneConfig | None = None

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Asynchronously initialize any resources needed for the trainer adapter.

        The base implementation stores ``run_config`` on the instance and copies its
        ``reward_function`` onto ``adapter_config.reward``. Subclasses that override
        this method should call ``super().initialize(run_config)`` to keep that behavior.

        Args:
            run_config (FinetuneConfig): Configuration of the current finetuning run.
        """
        self.run_config = run_config
        # NOTE: this mutates the config object that was passed to ``__init__``.
        self.adapter_config.reward = self.run_config.reward_function

    @abstractmethod
    async def is_healthy(self) -> bool:
        """
        Check the health of the remote training backend.

        Returns:
            bool: True if the backend is healthy, False otherwise.
        """
        raise NotImplementedError

    @abstractmethod
    async def submit(self, trajectories: TrajectoryCollection) -> TrainingJobRef:
        """
        Submit trajectories to the remote training backend.

        Args:
            trajectories (TrajectoryCollection): The collection of trajectories to submit.

        Returns:
            TrainingJobRef: Reference to the submitted training job.
        """
        raise NotImplementedError

    @abstractmethod
    async def status(self, ref: TrainingJobRef) -> TrainingJobStatus:
        """
        Get the status of a submitted training job.

        Args:
            ref (TrainingJobRef): Reference to the training job.

        Returns:
            TrainingJobStatus: The current status of the training job.
        """
        raise NotImplementedError

    @abstractmethod
    async def wait_until_complete(self, ref: TrainingJobRef, poll_interval: float = 10.0) -> TrainingJobStatus:
        """
        Wait until the training job is complete.

        Args:
            ref (TrainingJobRef): Reference to the training job.
            poll_interval (float): Time in seconds between status checks.

        Returns:
            TrainingJobStatus: The final status of the training job.
        """
        raise NotImplementedError

    @abstractmethod
    def log_progress(self, ref: TrainingJobRef, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log training adapter progress.

        Args:
            ref (TrainingJobRef): Reference to the training job.
            metrics (dict[str, Any]): Metrics to log.
            output_dir (str | None): Optional output directory override.
        """
        raise NotImplementedError
@@ -0,0 +1,115 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC
17
+ from abc import abstractmethod
18
+ from typing import Any
19
+
20
+ from nat.data_models.finetuning import FinetuneConfig
21
+ from nat.data_models.finetuning import TrajectoryBuilderConfig
22
+ from nat.data_models.finetuning import TrajectoryCollection
23
+ from nat.eval.config import EvaluationRunOutput
24
+ from nat.eval.evaluator.evaluator_model import EvalOutputItem
25
+ from nat.utils.io.supress_logs import suppress_logs
26
+
27
+
28
class TrajectoryBuilder(ABC):
    """
    Abstract interface for building trajectories from episode items.

    ``initialize()`` must be called before ``run_eval()``, which reads the
    stored ``run_config``. Subclasses implement ``start_run()``,
    ``finalize()`` and ``log_progress()``, and may override
    ``compute_reward()``.
    """

    def __init__(self, trajectory_builder_config: TrajectoryBuilderConfig):
        self.trajectory_builder_config = trajectory_builder_config
        # Set by ``initialize()``; ``None`` until then.
        self.run_config: FinetuneConfig | None = None

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Asynchronously initialize any resources needed for the trajectory builder.

        Stores ``run_config`` and copies its ``reward_function`` onto the
        builder config's ``reward`` field.

        Args:
            run_config: The fine-tuning run configuration.
        """
        self.run_config = run_config
        # Propagate the run-level reward function into the builder config.
        self.trajectory_builder_config.reward = self.run_config.reward_function

    async def run_eval(self) -> EvaluationRunOutput:
        """
        Run NAT Evaluation to generate episode items for trajectory building.

        The evaluation runs inside ``suppress_logs(prefix="nat.eval")``.

        Returns:
            EvaluationRunOutput: The output of the evaluation run.

        Raises:
            RuntimeError: If ``initialize()`` has not been called yet.
        """
        if self.run_config is None:
            raise RuntimeError("TrajectoryBuilder.initialize() must be called before run_eval().")

        # Deferred import; presumably to avoid an import cycle or import-time
        # cost -- TODO confirm.
        from nat.eval.evaluate import EvaluationRun
        from nat.eval.evaluate import EvaluationRunConfig

        run_cfg = self.run_config.run_configuration
        eval_cfg = EvaluationRunConfig(config_file=run_cfg.config_file,
                                       dataset=run_cfg.dataset,
                                       result_json_path=run_cfg.result_json_path,
                                       endpoint=run_cfg.endpoint,
                                       endpoint_timeout=run_cfg.endpoint_timeout,
                                       override=run_cfg.override)

        async with suppress_logs(prefix="nat.eval"):
            evaluation_output = await EvaluationRun(config=eval_cfg).run_and_evaluate()

        return evaluation_output

    @abstractmethod
    async def start_run(self, run_id: str, meta: dict | None = None) -> None:
        """
        Initialize any resources needed for the trajectory builder.

        Args:
            run_id (str): The unique identifier for the training run.
            meta (dict | None): Metadata associated with the training run.
        """
        raise NotImplementedError

    @abstractmethod
    async def finalize(self, run_id: str, meta: dict | None = None) -> TrajectoryCollection:
        """
        Finalize the trajectory building process and return the constructed trajectories.

        Args:
            run_id (str): The unique identifier for the training run.
            meta (dict | None): Metadata associated with the training run.

        Returns:
            TrajectoryCollection: The constructed trajectories.
        """
        raise NotImplementedError

    async def compute_reward(self, output_item: EvalOutputItem, meta: dict | None = None) -> float:
        """
        Compute reward for a given EvalOutputItem.

        The default implementation returns ``output_item.score`` converted to
        ``float``, or ``0.0`` when the score is ``None``. ``meta`` is not used
        here; it is accepted for subclasses that override this method.

        Args:
            output_item (EvalOutputItem): The evaluation output item.
            meta (dict | None): Metadata associated with the training run.

        Returns:
            float: The computed reward.
        """
        return float(output_item.score) if output_item.score is not None else 0.0

    @abstractmethod
    def log_progress(self, run_id: str, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log trajectory building progress.

        Args:
            run_id: The training run ID
            metrics: Dictionary of metrics to log
            output_dir: Optional output directory override
        """
        raise NotImplementedError
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
@@ -0,0 +1,141 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import logging
18
+
19
+ from nat.data_models.intermediate_step import IntermediateStep
20
+ from nat.data_models.intermediate_step import IntermediateStepType
21
+ from nat.finetuning.utils.parsers.common import extract_content
22
+ from nat.finetuning.utils.parsers.common import parse_generic_message
23
+
24
logger = logging.getLogger(__name__)

# Underscore-prefixed aliases of the shared helpers in
# ``nat.finetuning.utils.parsers.common``, kept for backwards compatibility
# and for use by the parsing functions in this module.
_extract_content, _parse_generic_message = extract_content, parse_generic_message
29
+
30
+
31
def parse_to_openai_message(message: IntermediateStep) -> dict | list[dict]:
    """
    Convert an IntermediateStep to OpenAI-compatible message data.

    Dispatches on ``message.event_type``:

    - ``LLM_END``   -> a single assistant message.
    - ``TOOL_END``  -> a single function-role response message.
    - ``LLM_START`` -> the input message(s) sent to the LLM; a list when the
      input held more than one message.
    - anything else -> best-effort parse via the shared generic parser.

    Args:
        message: An IntermediateStep object representing a single event.

    Returns:
        A message dictionary, or a list of message dictionaries for
        ``LLM_START`` events with multiple input messages.
    """
    event_type = message.event_type

    if event_type == IntermediateStepType.LLM_END:
        # Assistant message from the LLM response
        return _parse_assistant_message(message)
    if event_type == IntermediateStepType.TOOL_END:
        # Tool/Function response message
        return _parse_tool_message(message)
    if event_type == IntermediateStepType.LLM_START:
        # User/system messages from the LLM input
        return _parse_input_message(message)

    # For other types, try to infer from the data
    return _parse_generic_message(message)
59
+
60
+
61
def _normalize_input_item(item) -> dict:
    """Coerce one LLM input item into an OpenAI-style message dict."""
    if not isinstance(item, dict):
        return {"role": "user", "content": str(item)}
    if "role" not in item or "content" not in item:
        # Not message-shaped: keep its data as a JSON-encoded user message.
        return {"role": "user", "content": json.dumps(item)}
    return item


def _parse_input_message(message: IntermediateStep) -> dict | list[dict]:
    """
    Parse user or system messages from an LLM_START event.

    Each input item is normalized to an OpenAI-style message dict:

    - non-dict items become user messages holding ``str(item)``;
    - dicts lacking ``role`` or ``content`` become user messages holding
      their JSON serialization;
    - well-formed dicts are passed through unchanged.

    A missing or empty payload yields a single empty user message. A payload
    that is itself a single ``str`` or ``dict`` is treated as one input item,
    rather than being iterated character-by-character or key-by-key.

    Returns:
        A single message dict when there are zero or one input items,
        otherwise a list of message dicts in input order.
    """
    payload = message.data.payload if message.data else None

    if isinstance(payload, (str, dict)):
        items = [payload] if payload else []
    else:
        items = list(payload) if payload else []

    if not items:
        return {"role": "user", "content": ""}
    if len(items) == 1:
        return _normalize_input_item(items[0])
    return [_normalize_input_item(item) for item in items]
86
+
87
+
88
# Sentinel distinguishing "field absent" from "field present with value None".
_MISSING = object()


def _lookup_field(obj, key: str):
    """Return ``obj[key]`` for mappings or ``obj.key`` otherwise; ``_MISSING`` if absent."""
    from collections.abc import Mapping

    if obj is None or obj is _MISSING:
        return _MISSING
    if isinstance(obj, Mapping):
        return obj.get(key, _MISSING)
    return getattr(obj, key, _MISSING)


def _parse_assistant_message(message: IntermediateStep) -> dict:
    """
    Parse an assistant message from an LLM_END event.

    The payload and its ``message`` may be either mappings or attribute-style
    objects; both are supported. The returned dict always contains ``role``
    and ``content``. ``logprobs`` is included when the payload has it, and
    ``tool_calls`` when the payload message has a non-None value for it.

    Any unexpected error is logged and degrades to an empty assistant
    message, so one malformed step does not abort trajectory conversion.
    """
    result = {"role": "assistant"}

    try:
        payload = message.data.payload if message.data else None
        if not payload:
            logger.warning("No payload found in LLM_END message data.")
            return {"role": "assistant", "content": ""}

        logprobs = _lookup_field(payload, "logprobs")
        if logprobs is _MISSING:
            logger.warning("No logprobs found in LLM_END message payload.")
        else:
            result["logprobs"] = logprobs

        payload_message = _lookup_field(payload, "message")
        if payload_message is _MISSING or payload_message is None:
            # Previously this crashed on ``"content" in None`` and discarded
            # the logprobs; now it degrades to empty content instead.
            logger.warning("No message found in LLM_END message payload.")

        content = _lookup_field(payload_message, "content")
        if content is _MISSING or content is None:
            result["content"] = ""
        else:
            result["content"] = _extract_content(content)

        tool_calls = _lookup_field(payload_message, "tool_calls")
        if tool_calls is not _MISSING and tool_calls is not None:
            result["tool_calls"] = tool_calls
    except Exception:
        logger.exception("Error parsing assistant message from LLM_END event.")
        return {"role": "assistant", "content": ""}

    return result
120
+
121
+
122
def _parse_tool_message(message: IntermediateStep) -> dict:
    """
    Build a function-role message from a TOOL_END event.

    The content is taken from the step's ``output`` when truthy, otherwise
    from its ``payload``. If neither is available, the content is an empty
    string. The step's ``name`` is attached when it is truthy.
    """
    data = message.data
    raw = (data.output or data.payload) if data else None

    tool_msg = {"role": "function", "content": _extract_content(raw) if raw else ""}

    if message.name:
        tool_msg["name"] = message.name

    return tool_msg
@@ -0,0 +1,238 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+
19
+ from nat.builder.framework_enum import LLMFrameworkEnum
20
+ from nat.data_models.intermediate_step import IntermediateStep
21
+ from nat.data_models.intermediate_step import IntermediateStepState
22
+ from nat.data_models.intermediate_step import IntermediateStepType
23
+ from nat.finetuning.utils.parsers import adk_parser
24
+ from nat.finetuning.utils.parsers import langchain_parser
25
+ from nat.finetuning.utils.parsers import llama_index_parser
26
+
27
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
28
+
29
+
30
def _message_dedup_key(msg: dict) -> tuple:
    """
    Return a hashable ``(role, content)`` key used to de-duplicate messages.

    Non-string content (e.g. a list of content parts) is keyed by its
    ``repr``, so it can be hashed instead of raising ``TypeError``.
    """
    role = msg.get("role")
    content = msg.get("content")
    if content is None or isinstance(content, str):
        return (role, content)
    # Tagged tuple so a repr can never collide with a plain string content.
    return (role, ("repr", repr(content)))


def parse_to_openai_messages(steps: list[IntermediateStep]) -> list[dict]:
    """
    Convert IntermediateStep objects to OpenAI-compatible messages.

    Only ``LLM_START``, ``LLM_END`` and ``TOOL_END`` steps in the ``START``
    or ``END`` state are converted; other event types and streaming chunks
    are skipped. Each step is parsed with the parser for its framework
    (LangChain, LlamaIndex or ADK); steps from other frameworks are skipped.

    An ``LLM_START`` that directly follows a converted ``TOOL_END`` is
    dropped. Messages from multi-message ``LLM_START`` inputs that were
    already emitted (same role and content) are de-duplicated, so shared chat
    history is not repeated.

    Args:
        steps: List of IntermediateStep objects representing the conversation.

    Returns:
        List of dictionaries formatted for OpenAI API consumption.

    Raises:
        ValueError: If a ``TOOL_END``/``LLM_END`` step parses to multiple
            messages, or the resulting message sequence is invalid.
    """
    messages: list[dict] = []
    seen_keys: set[tuple] = set()
    # Event type of the last converted step (not of the last step seen).
    last_event_type = None

    for message in steps:
        if message.event_type not in (IntermediateStepType.LLM_END, IntermediateStepType.LLM_START,
                                      IntermediateStepType.TOOL_END):
            continue

        # Skip LLM_START events that come after TOOL_END events; these
        # represent the assistant processing tool results internally.
        if message.event_type == IntermediateStepType.LLM_START and last_event_type == IntermediateStepType.TOOL_END:
            continue

        # Skip streaming chunks
        if message.event_state not in (IntermediateStepState.START, IntermediateStepState.END):
            continue

        # Parse the message based on framework. An if/elif chain (rather than
        # a dict lookup) keeps the original ``==`` comparison semantics.
        if message.framework == LLMFrameworkEnum.LANGCHAIN:
            parsed_msg = langchain_parser.parse_to_openai_message(message=message)
        elif message.framework == LLMFrameworkEnum.LLAMA_INDEX:
            parsed_msg = llama_index_parser.parse_to_openai_message(message=message)
        elif message.framework == LLMFrameworkEnum.ADK:
            parsed_msg = adk_parser.parse_to_openai_message(message=message)
        else:
            if message.framework is not None:
                logger.warning("Unsupported framework: %s for message %s", message.framework, message)
            continue

        if message.event_type == IntermediateStepType.LLM_START:
            if isinstance(parsed_msg, list):
                # Multi-message input carries shared history: only emit
                # messages not seen before.
                for msg in parsed_msg:
                    key = _message_dedup_key(msg)
                    if key not in seen_keys:
                        messages.append(msg)
                        seen_keys.add(key)
            else:
                # A single input message is always emitted (presumably so a
                # repeated prompt still pairs with its own reply) but is
                # recorded for de-duplicating later histories.
                messages.append(parsed_msg)
                seen_keys.add(_message_dedup_key(parsed_msg))
        else:
            if isinstance(parsed_msg, list):
                raise ValueError("TOOL_END or LLM_END should not produce multiple messages")
            seen_keys.add(_message_dedup_key(parsed_msg))
            messages.append(parsed_msg)

        last_event_type = message.event_type

    # Validate and fix the message sequence
    try:
        messages = _validate_message_sequence(messages)
    except Exception:
        logger.exception("Error validating message sequence.")
        raise

    return messages
106
+
107
+
108
+ def _validate_message_sequence(messages: list[dict]) -> list[dict]:
109
+ """
110
+ Validate and fix the message sequence to follow OpenAI's expected format.
111
+
112
+ Rules:
113
+
114
+ - System messages can only appear at the beginning
115
+ - After system messages, must alternate between user/tool and assistant
116
+ - Cannot have consecutive user messages or consecutive assistant messages
117
+ - If first non-system messages are not user messages, they will be
118
+ concatenated into a single user message (with a warning)
119
+
120
+ Args:
121
+ messages: List of parsed OpenAI messages
122
+
123
+ Returns:
124
+ list[dict]: The validated (and potentially fixed) message list
125
+
126
+ Raises:
127
+ ValueError: If the message sequence is invalid.
128
+ """
129
+ if not messages:
130
+ return messages
131
+
132
+ # Check system messages are only at the beginning
133
+ found_non_system = False
134
+ for i, msg in enumerate(messages):
135
+ if msg.get("role") == "system":
136
+ if found_non_system:
137
+ raise ValueError(f"System message found at position {i} after "
138
+ "non-system messages. System messages must only "
139
+ "appear at the beginning.")
140
+ else:
141
+ found_non_system = True
142
+
143
+ # Find first non-system message
144
+ first_non_system_idx = 0
145
+ for i, msg in enumerate(messages):
146
+ if msg.get("role") != "system":
147
+ first_non_system_idx = i
148
+ break
149
+
150
+ # Fix non-user messages at the start of trajectory
151
+ # Collect all non-system messages before the first assistant message
152
+ if first_non_system_idx < len(messages):
153
+ # Find the first assistant message
154
+ first_assistant_idx = None
155
+ for i in range(first_non_system_idx, len(messages)):
156
+ if messages[i].get("role") == "assistant":
157
+ first_assistant_idx = i
158
+ break
159
+
160
+ # Check if we need to fix the start of the trajectory
161
+ if first_assistant_idx is not None:
162
+ messages_to_concatenate = []
163
+ for i in range(first_non_system_idx, first_assistant_idx):
164
+ msg = messages[i]
165
+ role = msg.get("role")
166
+ if role != "user":
167
+ # This message should be concatenated
168
+ messages_to_concatenate.append((i, msg))
169
+
170
+ if messages_to_concatenate:
171
+ # Collect all content from non-user messages at the start
172
+ content_parts = []
173
+ indices_to_remove = []
174
+
175
+ for i in range(first_non_system_idx, first_assistant_idx):
176
+ msg = messages[i]
177
+ role = msg.get("role")
178
+ content = msg.get("content", "")
179
+
180
+ if role not in ["user"]:
181
+ # Non-user message that needs to be consolidated
182
+ if content:
183
+ content_parts.append(f"[{role.upper()}]: {content}")
184
+ indices_to_remove.append(i)
185
+ else:
186
+ # User message - include its content
187
+ if content:
188
+ content_parts.append(content)
189
+ indices_to_remove.append(i)
190
+
191
+ # Create a single user message with concatenated content
192
+ if content_parts:
193
+ concatenated_content = "\n\n".join(content_parts)
194
+ new_user_message = {"role": "user", "content": concatenated_content}
195
+
196
+ # Log warning about the modification
197
+ logger.warning(
198
+ "Trajectory had %d non-user messages at the start "
199
+ "before the first assistant message. "
200
+ "Concatenated these into a single user message. "
201
+ "Original roles: %s",
202
+ len(messages_to_concatenate), [msg.get("role") for _, msg in messages_to_concatenate])
203
+
204
+ # Remove the old messages and insert the new one
205
+ # Remove in reverse order to maintain indices
206
+ for idx in reversed(indices_to_remove):
207
+ messages.pop(idx)
208
+
209
+ # Insert the new user message
210
+ messages.insert(first_non_system_idx, new_user_message)
211
+
212
+ # Recalculate first_non_system_idx after potential modifications
213
+ first_non_system_idx = 0
214
+ for i, msg in enumerate(messages):
215
+ if msg.get("role") != "system":
216
+ first_non_system_idx = i
217
+ break
218
+
219
+ # Validate alternating pattern after system messages
220
+ if first_non_system_idx < len(messages):
221
+ prev_role = None
222
+ for i in range(first_non_system_idx, len(messages)):
223
+ role = messages[i].get("role")
224
+
225
+ if prev_role:
226
+ # Check for invalid consecutive roles
227
+ if role == "user" and prev_role == "user":
228
+ raise ValueError(f"Consecutive user messages at positions {i-1} "
229
+ f"and {i}. User messages must be followed by "
230
+ "assistant messages.")
231
+ elif role == "assistant" and prev_role == "assistant":
232
+ raise ValueError(f"Consecutive assistant messages at positions "
233
+ f"{i-1} and {i}. Assistant messages must be "
234
+ "followed by user or tool messages.")
235
+
236
+ prev_role = role
237
+
238
+ return messages