nvidia-nat 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435) hide show
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +256 -0
  4. nat/agent/dual_node.py +67 -0
  5. nat/agent/react_agent/__init__.py +0 -0
  6. nat/agent/react_agent/agent.py +363 -0
  7. nat/agent/react_agent/output_parser.py +104 -0
  8. nat/agent/react_agent/prompt.py +44 -0
  9. nat/agent/react_agent/register.py +149 -0
  10. nat/agent/reasoning_agent/__init__.py +0 -0
  11. nat/agent/reasoning_agent/reasoning_agent.py +225 -0
  12. nat/agent/register.py +23 -0
  13. nat/agent/rewoo_agent/__init__.py +0 -0
  14. nat/agent/rewoo_agent/agent.py +415 -0
  15. nat/agent/rewoo_agent/prompt.py +110 -0
  16. nat/agent/rewoo_agent/register.py +157 -0
  17. nat/agent/tool_calling_agent/__init__.py +0 -0
  18. nat/agent/tool_calling_agent/agent.py +119 -0
  19. nat/agent/tool_calling_agent/register.py +106 -0
  20. nat/authentication/__init__.py +14 -0
  21. nat/authentication/api_key/__init__.py +14 -0
  22. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  23. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  24. nat/authentication/api_key/register.py +26 -0
  25. nat/authentication/exceptions/__init__.py +14 -0
  26. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  27. nat/authentication/http_basic_auth/__init__.py +0 -0
  28. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  29. nat/authentication/http_basic_auth/register.py +30 -0
  30. nat/authentication/interfaces.py +93 -0
  31. nat/authentication/oauth2/__init__.py +14 -0
  32. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  33. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  34. nat/authentication/oauth2/register.py +25 -0
  35. nat/authentication/register.py +21 -0
  36. nat/builder/__init__.py +0 -0
  37. nat/builder/builder.py +285 -0
  38. nat/builder/component_utils.py +316 -0
  39. nat/builder/context.py +270 -0
  40. nat/builder/embedder.py +24 -0
  41. nat/builder/eval_builder.py +161 -0
  42. nat/builder/evaluator.py +29 -0
  43. nat/builder/framework_enum.py +24 -0
  44. nat/builder/front_end.py +73 -0
  45. nat/builder/function.py +344 -0
  46. nat/builder/function_base.py +380 -0
  47. nat/builder/function_info.py +627 -0
  48. nat/builder/intermediate_step_manager.py +174 -0
  49. nat/builder/llm.py +25 -0
  50. nat/builder/retriever.py +25 -0
  51. nat/builder/user_interaction_manager.py +78 -0
  52. nat/builder/workflow.py +148 -0
  53. nat/builder/workflow_builder.py +1117 -0
  54. nat/cli/__init__.py +14 -0
  55. nat/cli/cli_utils/__init__.py +0 -0
  56. nat/cli/cli_utils/config_override.py +231 -0
  57. nat/cli/cli_utils/validation.py +37 -0
  58. nat/cli/commands/__init__.py +0 -0
  59. nat/cli/commands/configure/__init__.py +0 -0
  60. nat/cli/commands/configure/channel/__init__.py +0 -0
  61. nat/cli/commands/configure/channel/add.py +28 -0
  62. nat/cli/commands/configure/channel/channel.py +34 -0
  63. nat/cli/commands/configure/channel/remove.py +30 -0
  64. nat/cli/commands/configure/channel/update.py +30 -0
  65. nat/cli/commands/configure/configure.py +33 -0
  66. nat/cli/commands/evaluate.py +139 -0
  67. nat/cli/commands/info/__init__.py +14 -0
  68. nat/cli/commands/info/info.py +37 -0
  69. nat/cli/commands/info/list_channels.py +32 -0
  70. nat/cli/commands/info/list_components.py +129 -0
  71. nat/cli/commands/info/list_mcp.py +304 -0
  72. nat/cli/commands/registry/__init__.py +14 -0
  73. nat/cli/commands/registry/publish.py +88 -0
  74. nat/cli/commands/registry/pull.py +118 -0
  75. nat/cli/commands/registry/registry.py +36 -0
  76. nat/cli/commands/registry/remove.py +108 -0
  77. nat/cli/commands/registry/search.py +155 -0
  78. nat/cli/commands/sizing/__init__.py +14 -0
  79. nat/cli/commands/sizing/calc.py +297 -0
  80. nat/cli/commands/sizing/sizing.py +27 -0
  81. nat/cli/commands/start.py +246 -0
  82. nat/cli/commands/uninstall.py +81 -0
  83. nat/cli/commands/validate.py +47 -0
  84. nat/cli/commands/workflow/__init__.py +14 -0
  85. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  86. nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
  87. nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  88. nat/cli/commands/workflow/templates/register.py.j2 +5 -0
  89. nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  90. nat/cli/commands/workflow/workflow.py +37 -0
  91. nat/cli/commands/workflow/workflow_commands.py +317 -0
  92. nat/cli/entrypoint.py +135 -0
  93. nat/cli/main.py +57 -0
  94. nat/cli/register_workflow.py +488 -0
  95. nat/cli/type_registry.py +1000 -0
  96. nat/data_models/__init__.py +14 -0
  97. nat/data_models/api_server.py +716 -0
  98. nat/data_models/authentication.py +231 -0
  99. nat/data_models/common.py +171 -0
  100. nat/data_models/component.py +58 -0
  101. nat/data_models/component_ref.py +168 -0
  102. nat/data_models/config.py +410 -0
  103. nat/data_models/dataset_handler.py +169 -0
  104. nat/data_models/discovery_metadata.py +305 -0
  105. nat/data_models/embedder.py +27 -0
  106. nat/data_models/evaluate.py +127 -0
  107. nat/data_models/evaluator.py +26 -0
  108. nat/data_models/front_end.py +26 -0
  109. nat/data_models/function.py +30 -0
  110. nat/data_models/function_dependencies.py +72 -0
  111. nat/data_models/interactive.py +246 -0
  112. nat/data_models/intermediate_step.py +302 -0
  113. nat/data_models/invocation_node.py +38 -0
  114. nat/data_models/llm.py +27 -0
  115. nat/data_models/logging.py +26 -0
  116. nat/data_models/memory.py +27 -0
  117. nat/data_models/object_store.py +44 -0
  118. nat/data_models/profiler.py +54 -0
  119. nat/data_models/registry_handler.py +26 -0
  120. nat/data_models/retriever.py +30 -0
  121. nat/data_models/retry_mixin.py +35 -0
  122. nat/data_models/span.py +190 -0
  123. nat/data_models/step_adaptor.py +64 -0
  124. nat/data_models/streaming.py +33 -0
  125. nat/data_models/swe_bench_model.py +54 -0
  126. nat/data_models/telemetry_exporter.py +26 -0
  127. nat/data_models/ttc_strategy.py +30 -0
  128. nat/embedder/__init__.py +0 -0
  129. nat/embedder/nim_embedder.py +59 -0
  130. nat/embedder/openai_embedder.py +43 -0
  131. nat/embedder/register.py +22 -0
  132. nat/eval/__init__.py +14 -0
  133. nat/eval/config.py +60 -0
  134. nat/eval/dataset_handler/__init__.py +0 -0
  135. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  136. nat/eval/dataset_handler/dataset_filter.py +52 -0
  137. nat/eval/dataset_handler/dataset_handler.py +367 -0
  138. nat/eval/evaluate.py +510 -0
  139. nat/eval/evaluator/__init__.py +14 -0
  140. nat/eval/evaluator/base_evaluator.py +77 -0
  141. nat/eval/evaluator/evaluator_model.py +45 -0
  142. nat/eval/intermediate_step_adapter.py +99 -0
  143. nat/eval/rag_evaluator/__init__.py +0 -0
  144. nat/eval/rag_evaluator/evaluate.py +178 -0
  145. nat/eval/rag_evaluator/register.py +143 -0
  146. nat/eval/register.py +23 -0
  147. nat/eval/remote_workflow.py +133 -0
  148. nat/eval/runners/__init__.py +14 -0
  149. nat/eval/runners/config.py +39 -0
  150. nat/eval/runners/multi_eval_runner.py +54 -0
  151. nat/eval/runtime_event_subscriber.py +52 -0
  152. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  153. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. nat/eval/swe_bench_evaluator/register.py +36 -0
  155. nat/eval/trajectory_evaluator/__init__.py +0 -0
  156. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  157. nat/eval/trajectory_evaluator/register.py +40 -0
  158. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. nat/eval/tunable_rag_evaluator/register.py +52 -0
  161. nat/eval/usage_stats.py +41 -0
  162. nat/eval/utils/__init__.py +0 -0
  163. nat/eval/utils/output_uploader.py +140 -0
  164. nat/eval/utils/tqdm_position_registry.py +40 -0
  165. nat/eval/utils/weave_eval.py +184 -0
  166. nat/experimental/__init__.py +0 -0
  167. nat/experimental/decorators/__init__.py +0 -0
  168. nat/experimental/decorators/experimental_warning_decorator.py +134 -0
  169. nat/experimental/test_time_compute/__init__.py +0 -0
  170. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  171. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  175. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  177. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  178. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  179. nat/experimental/test_time_compute/models/__init__.py +0 -0
  180. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  181. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. nat/experimental/test_time_compute/models/search_config.py +120 -0
  183. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  184. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. nat/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. nat/experimental/test_time_compute/register.py +36 -0
  189. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. nat/experimental/test_time_compute/search/__init__.py +0 -0
  194. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  198. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. nat/front_ends/__init__.py +14 -0
  204. nat/front_ends/console/__init__.py +14 -0
  205. nat/front_ends/console/authentication_flow_handler.py +233 -0
  206. nat/front_ends/console/console_front_end_config.py +32 -0
  207. nat/front_ends/console/console_front_end_plugin.py +96 -0
  208. nat/front_ends/console/register.py +25 -0
  209. nat/front_ends/cron/__init__.py +14 -0
  210. nat/front_ends/fastapi/__init__.py +14 -0
  211. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
  215. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
  218. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. nat/front_ends/fastapi/job_store.py +183 -0
  222. nat/front_ends/fastapi/main.py +72 -0
  223. nat/front_ends/fastapi/message_handler.py +320 -0
  224. nat/front_ends/fastapi/message_validator.py +352 -0
  225. nat/front_ends/fastapi/register.py +25 -0
  226. nat/front_ends/fastapi/response_helpers.py +195 -0
  227. nat/front_ends/fastapi/step_adaptor.py +319 -0
  228. nat/front_ends/mcp/__init__.py +14 -0
  229. nat/front_ends/mcp/mcp_front_end_config.py +36 -0
  230. nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
  231. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
  232. nat/front_ends/mcp/register.py +27 -0
  233. nat/front_ends/mcp/tool_converter.py +241 -0
  234. nat/front_ends/register.py +22 -0
  235. nat/front_ends/simple_base/__init__.py +14 -0
  236. nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  237. nat/llm/__init__.py +0 -0
  238. nat/llm/aws_bedrock_llm.py +57 -0
  239. nat/llm/nim_llm.py +46 -0
  240. nat/llm/openai_llm.py +46 -0
  241. nat/llm/register.py +23 -0
  242. nat/llm/utils/__init__.py +14 -0
  243. nat/llm/utils/env_config_value.py +94 -0
  244. nat/llm/utils/error.py +17 -0
  245. nat/memory/__init__.py +20 -0
  246. nat/memory/interfaces.py +183 -0
  247. nat/memory/models.py +112 -0
  248. nat/meta/pypi.md +58 -0
  249. nat/object_store/__init__.py +20 -0
  250. nat/object_store/in_memory_object_store.py +76 -0
  251. nat/object_store/interfaces.py +84 -0
  252. nat/object_store/models.py +38 -0
  253. nat/object_store/register.py +20 -0
  254. nat/observability/__init__.py +14 -0
  255. nat/observability/exporter/__init__.py +14 -0
  256. nat/observability/exporter/base_exporter.py +449 -0
  257. nat/observability/exporter/exporter.py +78 -0
  258. nat/observability/exporter/file_exporter.py +33 -0
  259. nat/observability/exporter/processing_exporter.py +322 -0
  260. nat/observability/exporter/raw_exporter.py +52 -0
  261. nat/observability/exporter/span_exporter.py +288 -0
  262. nat/observability/exporter_manager.py +335 -0
  263. nat/observability/mixin/__init__.py +14 -0
  264. nat/observability/mixin/batch_config_mixin.py +26 -0
  265. nat/observability/mixin/collector_config_mixin.py +23 -0
  266. nat/observability/mixin/file_mixin.py +288 -0
  267. nat/observability/mixin/file_mode.py +23 -0
  268. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  269. nat/observability/mixin/serialize_mixin.py +61 -0
  270. nat/observability/mixin/type_introspection_mixin.py +183 -0
  271. nat/observability/processor/__init__.py +14 -0
  272. nat/observability/processor/batching_processor.py +310 -0
  273. nat/observability/processor/callback_processor.py +42 -0
  274. nat/observability/processor/intermediate_step_serializer.py +28 -0
  275. nat/observability/processor/processor.py +71 -0
  276. nat/observability/register.py +96 -0
  277. nat/observability/utils/__init__.py +14 -0
  278. nat/observability/utils/dict_utils.py +236 -0
  279. nat/observability/utils/time_utils.py +31 -0
  280. nat/plugins/.namespace +1 -0
  281. nat/profiler/__init__.py +0 -0
  282. nat/profiler/calc/__init__.py +14 -0
  283. nat/profiler/calc/calc_runner.py +627 -0
  284. nat/profiler/calc/calculations.py +288 -0
  285. nat/profiler/calc/data_models.py +188 -0
  286. nat/profiler/calc/plot.py +345 -0
  287. nat/profiler/callbacks/__init__.py +0 -0
  288. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  289. nat/profiler/callbacks/base_callback_class.py +20 -0
  290. nat/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  294. nat/profiler/data_frame_row.py +51 -0
  295. nat/profiler/data_models.py +24 -0
  296. nat/profiler/decorators/__init__.py +0 -0
  297. nat/profiler/decorators/framework_wrapper.py +131 -0
  298. nat/profiler/decorators/function_tracking.py +254 -0
  299. nat/profiler/forecasting/__init__.py +0 -0
  300. nat/profiler/forecasting/config.py +18 -0
  301. nat/profiler/forecasting/model_trainer.py +75 -0
  302. nat/profiler/forecasting/models/__init__.py +22 -0
  303. nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. nat/profiler/forecasting/models/linear_model.py +197 -0
  305. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  306. nat/profiler/inference_metrics_model.py +28 -0
  307. nat/profiler/inference_optimization/__init__.py +0 -0
  308. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. nat/profiler/inference_optimization/data_models.py +386 -0
  312. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  316. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  317. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. nat/profiler/intermediate_property_adapter.py +102 -0
  320. nat/profiler/profile_runner.py +473 -0
  321. nat/profiler/utils.py +184 -0
  322. nat/registry_handlers/__init__.py +0 -0
  323. nat/registry_handlers/local/__init__.py +0 -0
  324. nat/registry_handlers/local/local_handler.py +176 -0
  325. nat/registry_handlers/local/register_local.py +37 -0
  326. nat/registry_handlers/metadata_factory.py +60 -0
  327. nat/registry_handlers/package_utils.py +571 -0
  328. nat/registry_handlers/pypi/__init__.py +0 -0
  329. nat/registry_handlers/pypi/pypi_handler.py +251 -0
  330. nat/registry_handlers/pypi/register_pypi.py +40 -0
  331. nat/registry_handlers/register.py +21 -0
  332. nat/registry_handlers/registry_handler_base.py +157 -0
  333. nat/registry_handlers/rest/__init__.py +0 -0
  334. nat/registry_handlers/rest/register_rest.py +56 -0
  335. nat/registry_handlers/rest/rest_handler.py +237 -0
  336. nat/registry_handlers/schemas/__init__.py +0 -0
  337. nat/registry_handlers/schemas/headers.py +42 -0
  338. nat/registry_handlers/schemas/package.py +68 -0
  339. nat/registry_handlers/schemas/publish.py +68 -0
  340. nat/registry_handlers/schemas/pull.py +82 -0
  341. nat/registry_handlers/schemas/remove.py +36 -0
  342. nat/registry_handlers/schemas/search.py +91 -0
  343. nat/registry_handlers/schemas/status.py +47 -0
  344. nat/retriever/__init__.py +0 -0
  345. nat/retriever/interface.py +41 -0
  346. nat/retriever/milvus/__init__.py +14 -0
  347. nat/retriever/milvus/register.py +81 -0
  348. nat/retriever/milvus/retriever.py +228 -0
  349. nat/retriever/models.py +77 -0
  350. nat/retriever/nemo_retriever/__init__.py +14 -0
  351. nat/retriever/nemo_retriever/register.py +60 -0
  352. nat/retriever/nemo_retriever/retriever.py +190 -0
  353. nat/retriever/register.py +22 -0
  354. nat/runtime/__init__.py +14 -0
  355. nat/runtime/loader.py +220 -0
  356. nat/runtime/runner.py +195 -0
  357. nat/runtime/session.py +162 -0
  358. nat/runtime/user_metadata.py +130 -0
  359. nat/settings/__init__.py +0 -0
  360. nat/settings/global_settings.py +318 -0
  361. nat/test/.namespace +1 -0
  362. nat/tool/__init__.py +0 -0
  363. nat/tool/chat_completion.py +74 -0
  364. nat/tool/code_execution/README.md +151 -0
  365. nat/tool/code_execution/__init__.py +0 -0
  366. nat/tool/code_execution/code_sandbox.py +267 -0
  367. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. nat/tool/code_execution/register.py +74 -0
  374. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. nat/tool/code_execution/utils.py +100 -0
  376. nat/tool/datetime_tools.py +42 -0
  377. nat/tool/document_search.py +141 -0
  378. nat/tool/github_tools/__init__.py +0 -0
  379. nat/tool/github_tools/create_github_commit.py +133 -0
  380. nat/tool/github_tools/create_github_issue.py +87 -0
  381. nat/tool/github_tools/create_github_pr.py +106 -0
  382. nat/tool/github_tools/get_github_file.py +106 -0
  383. nat/tool/github_tools/get_github_issue.py +166 -0
  384. nat/tool/github_tools/get_github_pr.py +256 -0
  385. nat/tool/github_tools/update_github_issue.py +100 -0
  386. nat/tool/mcp/__init__.py +14 -0
  387. nat/tool/mcp/exceptions.py +142 -0
  388. nat/tool/mcp/mcp_client.py +255 -0
  389. nat/tool/mcp/mcp_tool.py +96 -0
  390. nat/tool/memory_tools/__init__.py +0 -0
  391. nat/tool/memory_tools/add_memory_tool.py +79 -0
  392. nat/tool/memory_tools/delete_memory_tool.py +67 -0
  393. nat/tool/memory_tools/get_memory_tool.py +72 -0
  394. nat/tool/nvidia_rag.py +95 -0
  395. nat/tool/register.py +38 -0
  396. nat/tool/retriever.py +94 -0
  397. nat/tool/server_tools.py +66 -0
  398. nat/utils/__init__.py +0 -0
  399. nat/utils/data_models/__init__.py +0 -0
  400. nat/utils/data_models/schema_validator.py +58 -0
  401. nat/utils/debugging_utils.py +43 -0
  402. nat/utils/dump_distro_mapping.py +32 -0
  403. nat/utils/exception_handlers/__init__.py +0 -0
  404. nat/utils/exception_handlers/automatic_retries.py +289 -0
  405. nat/utils/exception_handlers/mcp.py +211 -0
  406. nat/utils/exception_handlers/schemas.py +114 -0
  407. nat/utils/io/__init__.py +0 -0
  408. nat/utils/io/model_processing.py +28 -0
  409. nat/utils/io/yaml_tools.py +119 -0
  410. nat/utils/log_utils.py +37 -0
  411. nat/utils/metadata_utils.py +74 -0
  412. nat/utils/optional_imports.py +142 -0
  413. nat/utils/producer_consumer_queue.py +178 -0
  414. nat/utils/reactive/__init__.py +0 -0
  415. nat/utils/reactive/base/__init__.py +0 -0
  416. nat/utils/reactive/base/observable_base.py +65 -0
  417. nat/utils/reactive/base/observer_base.py +55 -0
  418. nat/utils/reactive/base/subject_base.py +79 -0
  419. nat/utils/reactive/observable.py +59 -0
  420. nat/utils/reactive/observer.py +76 -0
  421. nat/utils/reactive/subject.py +131 -0
  422. nat/utils/reactive/subscription.py +49 -0
  423. nat/utils/settings/__init__.py +0 -0
  424. nat/utils/settings/global_settings.py +197 -0
  425. nat/utils/string_utils.py +38 -0
  426. nat/utils/type_converter.py +290 -0
  427. nat/utils/type_utils.py +484 -0
  428. nat/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0.dist-info/METADATA +365 -0
  430. nvidia_nat-1.2.0.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
  433. nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  434. nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,254 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import inspect
18
+ import uuid
19
+ from typing import Any
20
+
21
+ from pydantic import BaseModel
22
+
23
+ from nat.builder.context import Context
24
+ from nat.builder.intermediate_step_manager import IntermediateStepManager
25
+ from nat.data_models.intermediate_step import IntermediateStepPayload
26
+ from nat.data_models.intermediate_step import IntermediateStepType
27
+ from nat.data_models.intermediate_step import TraceMetadata
28
+
29
+
30
+ # --- Helper function to recursively serialize any object into JSON-friendly data ---
31
+ def _serialize_data(obj: Any) -> Any:
32
+ """Convert `obj` into a structure that can be passed to `json.dumps(...)`."""
33
+ if isinstance(obj, BaseModel):
34
+ # Convert Pydantic model to dict
35
+ return obj.model_dump()
36
+
37
+ if isinstance(obj, dict):
38
+ return {str(k): _serialize_data(v) for k, v in obj.items()}
39
+ if isinstance(obj, (list, tuple, set)):
40
+ return [_serialize_data(item) for item in obj]
41
+
42
+ if isinstance(obj, (str, int, float, bool, type(None))):
43
+ return obj
44
+
45
+ # Fallback
46
+ return str(obj)
47
+
48
+
49
def _prepare_serialized_args_kwargs(*args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
    """Build JSON-friendly copies of a call's positional and keyword arguments.

    Every positional value and every keyword value is passed through `_serialize_data`.

    Returns:
        A `(serialized_args, serialized_kwargs)` pair: a list with one entry per positional
        argument and a dict keyed by the original keyword names.
    """
    positional: list[Any] = []
    for value in args:
        positional.append(_serialize_data(value))

    keyword: dict[str, Any] = {}
    for name, value in kwargs.items():
        keyword[name] = _serialize_data(value)

    return positional, keyword
54
+
55
+
56
def push_intermediate_step(step_manager: IntermediateStepManager,
                           identifier: str,
                           function_name: str,
                           event_type: IntermediateStepType,
                           args: Any = None,
                           kwargs: Any = None,
                           output: Any = None,
                           metadata: dict[str, Any] | None = None) -> None:
    """Push an intermediate step to the NAT Event Stream.

    Args:
        step_manager: Manager that receives the assembled payload.
        identifier: Stored as the payload's `UUID`.
        function_name: Stored as the payload's `name`.
        event_type: Step type recorded on the payload.
        args: Recorded as the first element of the trace's `span_inputs`.
        kwargs: Recorded as the second element of the trace's `span_inputs`.
        output: Recorded as the trace's `span_outputs`.
        metadata: Recorded as the trace's `provided_metadata`.
    """
    # Assemble the trace details first, then wrap them in the step payload.
    trace_details = TraceMetadata(span_inputs=[args, kwargs], span_outputs=output, provided_metadata=metadata)
    step = IntermediateStepPayload(UUID=identifier, event_type=event_type, name=function_name, metadata=trace_details)

    step_manager.push_intermediate_step(step)
76
+
77
+
78
+ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
79
+ """
80
+ Decorator that can wrap any type of function (sync, async, generator,
81
+ async generator) and executes "tracking logic" around it.
82
+
83
+ - If the function is async, it will be wrapped in an async function.
84
+ - If the function is a generator, it will be wrapped in a generator function.
85
+ - If the function is an async generator, it will be wrapped in an async generator function.
86
+ - If the function is sync, it will be wrapped in a sync function.
87
+ """
88
+ function_name: str = func.__name__ if func else "<unknown_function>"
89
+
90
+ # If called as @track_function(...) but not immediately passed a function
91
+ if func is None:
92
+
93
+ def decorator_wrapper(actual_func):
94
+ return track_function(actual_func, metadata=metadata)
95
+
96
+ return decorator_wrapper
97
+
98
+ # --- Validate metadata ---
99
+ if metadata is not None:
100
+ if not isinstance(metadata, dict):
101
+ raise TypeError("metadata must be a dict[str, Any].")
102
+ if any(not isinstance(k, str) for k in metadata.keys()):
103
+ raise TypeError("All metadata keys must be strings.")
104
+
105
+ # --- Now detect the function type and wrap accordingly ---
106
+ if inspect.isasyncgenfunction(func):
107
+ # ---------------------
108
+ # ASYNC GENERATOR
109
+ # ---------------------
110
+
111
+ @functools.wraps(func)
112
+ async def async_gen_wrapper(*args, **kwargs):
113
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
114
+ # 1) Serialize input
115
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
116
+
117
+ invocation_id = str(uuid.uuid4())
118
+ push_intermediate_step(step_manager,
119
+ invocation_id,
120
+ function_name,
121
+ IntermediateStepType.SPAN_START,
122
+ args=serialized_args,
123
+ kwargs=serialized_kwargs,
124
+ metadata=metadata)
125
+
126
+ # 2) Call the original async generator
127
+ async for item in func(*args, **kwargs):
128
+ # 3) Serialize the yielded item before yielding it
129
+ serialized_item = _serialize_data(item)
130
+ push_intermediate_step(step_manager,
131
+ invocation_id,
132
+ function_name,
133
+ IntermediateStepType.SPAN_CHUNK,
134
+ args=serialized_args,
135
+ kwargs=serialized_kwargs,
136
+ output=serialized_item,
137
+ metadata=metadata)
138
+ yield item # yield the original item
139
+
140
+ push_intermediate_step(step_manager,
141
+ invocation_id,
142
+ function_name,
143
+ IntermediateStepType.SPAN_END,
144
+ args=serialized_args,
145
+ kwargs=serialized_kwargs,
146
+ output=None,
147
+ metadata=metadata)
148
+
149
+ # 4) Post-yield logic if any
150
+
151
+ return async_gen_wrapper
152
+
153
+ if inspect.iscoroutinefunction(func):
154
+ # ---------------------
155
+ # ASYNC FUNCTION
156
+ # ---------------------
157
+ @functools.wraps(func)
158
+ async def async_wrapper(*args, **kwargs):
159
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
160
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
161
+ invocation_id = str(uuid.uuid4())
162
+ push_intermediate_step(step_manager,
163
+ invocation_id,
164
+ function_name,
165
+ IntermediateStepType.SPAN_START,
166
+ args=serialized_args,
167
+ kwargs=serialized_kwargs,
168
+ metadata=metadata)
169
+
170
+ result = await func(*args, **kwargs)
171
+
172
+ serialized_result = _serialize_data(result)
173
+ push_intermediate_step(step_manager,
174
+ invocation_id,
175
+ function_name,
176
+ IntermediateStepType.SPAN_END,
177
+ args=serialized_args,
178
+ kwargs=serialized_kwargs,
179
+ output=serialized_result,
180
+ metadata=metadata)
181
+
182
+ return result
183
+
184
+ return async_wrapper
185
+
186
+ if inspect.isgeneratorfunction(func):
187
+ # ---------------------
188
+ # SYNC GENERATOR
189
+ # ---------------------
190
+ @functools.wraps(func)
191
+ def sync_gen_wrapper(*args, **kwargs):
192
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
193
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
194
+ invocation_id = str(uuid.uuid4())
195
+ push_intermediate_step(step_manager,
196
+ invocation_id,
197
+ function_name,
198
+ IntermediateStepType.SPAN_START,
199
+ args=serialized_args,
200
+ kwargs=serialized_kwargs,
201
+ metadata=metadata)
202
+
203
+ for item in func(*args, **kwargs):
204
+ serialized_item = _serialize_data(item)
205
+ push_intermediate_step(step_manager,
206
+ invocation_id,
207
+ function_name,
208
+ IntermediateStepType.SPAN_CHUNK,
209
+ args=serialized_args,
210
+ kwargs=serialized_kwargs,
211
+ output=serialized_item,
212
+ metadata=metadata)
213
+
214
+ yield item # yield the original item
215
+
216
+ push_intermediate_step(step_manager,
217
+ invocation_id,
218
+ function_name,
219
+ IntermediateStepType.SPAN_END,
220
+ args=serialized_args,
221
+ kwargs=serialized_kwargs,
222
+ output=None,
223
+ metadata=metadata)
224
+
225
+ return sync_gen_wrapper
226
+
227
+ @functools.wraps(func)
228
+ def sync_wrapper(*args, **kwargs):
229
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
230
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
231
+ invocation_id = str(uuid.uuid4())
232
+ push_intermediate_step(step_manager,
233
+ invocation_id,
234
+ function_name,
235
+ IntermediateStepType.SPAN_START,
236
+ args=serialized_args,
237
+ kwargs=serialized_kwargs,
238
+ metadata=metadata)
239
+
240
+ result = func(*args, **kwargs)
241
+
242
+ serialized_result = _serialize_data(result)
243
+ push_intermediate_step(step_manager,
244
+ invocation_id,
245
+ function_name,
246
+ IntermediateStepType.SPAN_END,
247
+ args=serialized_args,
248
+ kwargs=serialized_kwargs,
249
+ output=serialized_result,
250
+ metadata=metadata)
251
+
252
+ return result
253
+
254
+ return sync_wrapper
File without changes
@@ -0,0 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Global constants and defaults used by the forecasting code
17
+ DEFAULT_MODEL_TYPE = "randomforest"
18
+ DEFAULT_MATRIX_LENGTH = 10
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/model_trainer.py
17
+
18
+ import logging
19
+
20
+ from nat.profiler.forecasting.config import DEFAULT_MODEL_TYPE
21
+ from nat.profiler.forecasting.models import ForecastingBaseModel
22
+ from nat.profiler.forecasting.models import LinearModel
23
+ from nat.profiler.forecasting.models import RandomForestModel
24
+ from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
def create_model(model_type: str) -> ForecastingBaseModel:
    """
    Factory that maps a model name to a newly constructed model instance.

    Parameters
    ----------
    model_type: str
        Name of the model to build. Supported values are "linear" and "randomforest".

    Returns
    -------
    ForecastingBaseModel
        A new ``LinearModel`` for "linear" or a new ``RandomForestModel`` for "randomforest".

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names.
    """
    # Reject unknown names up front so the happy path is a single expression.
    if model_type not in ("linear", "randomforest"):
        raise ValueError(f"Unsupported model_type: {model_type}")

    return LinearModel() if model_type == "linear" else RandomForestModel()
41
+
42
+
43
class ModelTrainer:
    """
    Builds a forecasting model of the requested type and fits it on profiler statistics.

    Parameters
    ----------
    model_type: str, default = "randomforest"
        Name of the model to train. Supported values are "linear" and "randomforest".
    """

    def __init__(self, model_type: str = DEFAULT_MODEL_TYPE):
        self.model_type = model_type
        # The model is created eagerly, so an unsupported type raises here rather than in `train`.
        self._model = create_model(model_type)

    def train(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> ForecastingBaseModel:
        """
        Fit the underlying model on `raw_stats` and hand it back.

        Parameters
        ----------
        raw_stats: list[list[IntermediatePropertyAdaptor]]
            Stats collected by the profiler.

        Returns
        -------
        ForecastingBaseModel
            The model held by this trainer, after `fit` has been called on it.
        """
        model = self._model
        model.fit(raw_stats)
        return model
@@ -0,0 +1,22 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/models/__init__.py
17
+
18
+ from .forecasting_base_model import ForecastingBaseModel
19
+ from .linear_model import LinearModel
20
+ from .random_forest_regressor import RandomForestModel
21
+
22
+ __all__ = ["ForecastingBaseModel", "LinearModel", "RandomForestModel"]
@@ -0,0 +1,40 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/models/forecasting_base_model.py
17
+
18
+ from abc import ABC, abstractmethod
19
+ import numpy as np
20
+
21
+
22
class ForecastingBaseModel(ABC):
    """
    Interface shared by every forecasting model in this package.

    Both `fit` and `predict` are abstract, so the class cannot be instantiated
    until a subclass overrides them.
    """

    @abstractmethod
    def fit(self, raw_stats):
        """
        Train or fine-tune the model on the provided dataset.
        """

    @abstractmethod
    def predict(self, raw_stats) -> np.ndarray:
        """
        Produce predictions with the trained model.

        The result is a np.ndarray of shape (N, 4).
        """
@@ -0,0 +1,197 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ import numpy as np
19
+
20
+ from nat.profiler.forecasting.models.forecasting_base_model import ForecastingBaseModel
21
+ from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class LinearModel(ForecastingBaseModel):
    """
    A linear regression model that conforms to the ForecastingBaseModel interface.

    Every LLM call of a run contributes one feature row
    ``[seconds_between_calls, prompt_tokens, completion_tokens]``. A training sample
    is the window of the most recent ``matrix_length`` rows (zero-padded at the top
    when the run is shorter), flattened into a single feature vector. The target
    summarises the calls that follow the window: the mean of column 0 and the sum of
    every other column.
    """

    def __init__(self):
        super().__init__()

        try:
            from sklearn.linear_model import LinearRegression
        except ImportError:
            logger.error(
                "scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
                "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")

            raise

        self.model = LinearRegression()
        # Number of rows per input window; derived from the training data in `fit`.
        self.matrix_length = None

    def fit(self, raw_stats: list[list[IntermediatePropertyAdaptor]]):
        """
        Fit the regression on profiler statistics.

        X: shape (N, M)  # M = matrix_length * n_features
        y: shape (N, n_features)

        Raises
        ------
        ValueError
            If `raw_stats` is empty.
        """
        x_flat, y_flat = self._prep_for_model_training(raw_stats)

        logger.info("Training dataset size: X=%s, y=%s", x_flat.shape, y_flat.shape)

        self.model.fit(x_flat, y_flat)

    def predict(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> np.ndarray:
        """
        Predict using the fitted linear model.

        Only the first request in `raw_stats` is used.
        Returns shape (1, n_features).
        """
        window = self._prep_single(raw_stats)
        # The regressor was trained on flattened windows, so the 2D window must be
        # collapsed into a single row with matrix_length * n_features columns.
        return self.model.predict(window.reshape(1, -1))

    def _prep_single(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> np.ndarray:
        """
        Build the (matrix_length, n_features) input window for the first request in `raw_stats`.
        """
        runs, _ = self._extract_token_usage_meta(raw_stats)
        arr = runs[0]
        n_rows = arr.shape[0]

        matrix_length = self.matrix_length

        assert matrix_length is not None, "matrix_length must be set before calling _prep_single"

        if n_rows >= matrix_length:
            # Keep the latest matrix_length rows
            return arr[-matrix_length:, :]

        # Pad with zeros at the top
        pad_block = np.zeros((matrix_length - n_rows, arr.shape[1]), dtype=arr.dtype)
        return np.vstack([pad_block, arr])

    def _prep_for_model_training(self, raw_stats: list[list[IntermediatePropertyAdaptor]]):
        """
        Turn profiler statistics into flattened training matrices.

        Stores the window length derived from the data on `self.matrix_length`
        and returns `(x_flat, y_flat)`.
        """
        raw_matrices, matrix_length = self._extract_token_usage_meta(raw_stats)

        self.matrix_length = matrix_length

        # 1) One (window, target) pair per LLM call of every run
        x_list = []
        y_list = []
        for arr in raw_matrices:
            for x_mat, y_mat in self._preprocess_for_forecasting(arr, matrix_length):
                x_list.append(x_mat)
                y_list.append(y_mat)

        # 2) Flatten features
        return self._flatten_features(x_list, y_list)

    def _extract_token_usage_meta(self, all_requests_data: list[list[IntermediatePropertyAdaptor]]):
        """
        Extract one feature matrix per request from the LLM_START / LLM_END events.

        Returns
        -------
        tuple
            `(all_run_data, recommended_matrix_length)`: a list with one array of shape
            (n_llm_calls, 3) per request, and the mean number of LLM calls per request
            rounded up.

        Raises
        ------
        ValueError
            If `all_requests_data` is empty.
        """
        import math

        if not all_requests_data:
            # Without this guard the mean below fails with an opaque ZeroDivisionError.
            raise ValueError("all_requests_data must contain at least one request")

        # seconds_between_calls, prompt_tokens, completion_tokens
        n_features = 3

        all_run_data = []
        call_stack_sizes = []

        for prompt in all_requests_data:
            run_data = []
            seconds_between_call_map = {}

            for stat in prompt:
                event = stat.event_type.value

                if event == "LLM_START":
                    # Remembered so the matching LLM_END (same UUID) can look it up.
                    seconds_between_call_map[stat.UUID] = stat.seconds_between_calls

                if event == "LLM_END":
                    run_data.append([
                        seconds_between_call_map[stat.UUID],
                        stat.token_usage.prompt_tokens,
                        stat.token_usage.completion_tokens,
                    ])

            # The reshape keeps a run without LLM calls two-dimensional, shape (0, 3),
            # so downstream column indexing and padding keep working.
            all_run_data.append(np.array(run_data).reshape(-1, n_features))
            call_stack_sizes.append(len(run_data))

        recommended_matrix_length = math.ceil(sum(call_stack_sizes) / len(call_stack_sizes))

        return all_run_data, recommended_matrix_length

    def _preprocess_for_forecasting(self, arr: np.ndarray, matrix_length: int):
        """
        Given a 2D NumPy array `arr` of shape (n_rows, n_features), generate a list of
        (input_array, output_array) pairs for forecasting, each of shape:

        - input_array: (matrix_length, n_features) after padding/trimming
        - output_array: (1, n_features)

        For row i the output holds, over the rows after i, the mean of column 0 and
        the sum of every other column; for the last row it is all zeros.
        """
        n_rows = arr.shape[0]
        if n_rows == 0:
            return []

        # Derived from the data instead of hard-coded, so the target always has
        # as many entries as the feature rows have columns.
        n_cols = arr.shape[1]

        # partial_sums[i] = sum of arr[i:] per column
        partial_sums = np.flip(np.cumsum(np.flip(arr, axis=0), axis=0), axis=0)

        samples = []
        for i in range(n_rows):
            x_untrimmed = arr[:i + 1, :]
            # Trim or pad
            current_len = x_untrimmed.shape[0]
            if current_len > matrix_length:
                x_mat = x_untrimmed[-matrix_length:, :]
            elif current_len < matrix_length:
                pad_block = np.zeros((matrix_length - current_len, n_cols), dtype=arr.dtype)
                x_mat = np.vstack([pad_block, x_untrimmed])
            else:
                x_mat = x_untrimmed

            # Compute output
            if i == n_rows - 1:
                # Nothing follows the last call.
                y_vec = np.zeros(n_cols, dtype=arr.dtype)
            else:
                n_below = n_rows - (i + 1)
                sum_below = partial_sums[i + 1]
                avg_col0 = sum_below[0] / n_below
                y_vec = np.concatenate(([avg_col0], sum_below[1:]))

            samples.append((x_mat, y_vec.reshape(1, -1)))

        return samples

    def _flatten_features(self, x_list, y_list):
        """
        x_list: list of arrays, each of shape (matrix_length, n_features)
        y_list: list of arrays, each of shape (1, n_features)

        Returns:
            x_flat: np.array of shape (N, matrix_length*n_features)
            y_flat: np.array of shape (N, n_features)
        """
        # zip keeps inputs and targets paired even if the lists differ in length.
        pairs = list(zip(x_list, y_list))

        x_flat = np.array([x_mat.flatten() for x_mat, _ in pairs])
        y_flat = np.array([y_mat.flatten() for _, y_mat in pairs])
        logger.debug("Flattened features to shapes: %s (X), %s (y).", x_flat.shape, y_flat.shape)
        return x_flat, y_flat