nvidia-nat 1.2.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. aiq/agent/__init__.py +0 -0
  2. aiq/agent/base.py +239 -0
  3. aiq/agent/dual_node.py +67 -0
  4. aiq/agent/react_agent/__init__.py +0 -0
  5. aiq/agent/react_agent/agent.py +355 -0
  6. aiq/agent/react_agent/output_parser.py +104 -0
  7. aiq/agent/react_agent/prompt.py +41 -0
  8. aiq/agent/react_agent/register.py +149 -0
  9. aiq/agent/reasoning_agent/__init__.py +0 -0
  10. aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
  11. aiq/agent/register.py +23 -0
  12. aiq/agent/rewoo_agent/__init__.py +0 -0
  13. aiq/agent/rewoo_agent/agent.py +411 -0
  14. aiq/agent/rewoo_agent/prompt.py +108 -0
  15. aiq/agent/rewoo_agent/register.py +158 -0
  16. aiq/agent/tool_calling_agent/__init__.py +0 -0
  17. aiq/agent/tool_calling_agent/agent.py +119 -0
  18. aiq/agent/tool_calling_agent/register.py +106 -0
  19. aiq/authentication/__init__.py +14 -0
  20. aiq/authentication/api_key/__init__.py +14 -0
  21. aiq/authentication/api_key/api_key_auth_provider.py +96 -0
  22. aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
  23. aiq/authentication/api_key/register.py +26 -0
  24. aiq/authentication/exceptions/__init__.py +14 -0
  25. aiq/authentication/exceptions/api_key_exceptions.py +38 -0
  26. aiq/authentication/http_basic_auth/__init__.py +0 -0
  27. aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  28. aiq/authentication/http_basic_auth/register.py +30 -0
  29. aiq/authentication/interfaces.py +93 -0
  30. aiq/authentication/oauth2/__init__.py +14 -0
  31. aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
  32. aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  33. aiq/authentication/oauth2/register.py +25 -0
  34. aiq/authentication/register.py +21 -0
  35. aiq/builder/__init__.py +0 -0
  36. aiq/builder/builder.py +285 -0
  37. aiq/builder/component_utils.py +316 -0
  38. aiq/builder/context.py +264 -0
  39. aiq/builder/embedder.py +24 -0
  40. aiq/builder/eval_builder.py +161 -0
  41. aiq/builder/evaluator.py +29 -0
  42. aiq/builder/framework_enum.py +24 -0
  43. aiq/builder/front_end.py +73 -0
  44. aiq/builder/function.py +344 -0
  45. aiq/builder/function_base.py +380 -0
  46. aiq/builder/function_info.py +627 -0
  47. aiq/builder/intermediate_step_manager.py +174 -0
  48. aiq/builder/llm.py +25 -0
  49. aiq/builder/retriever.py +25 -0
  50. aiq/builder/user_interaction_manager.py +74 -0
  51. aiq/builder/workflow.py +148 -0
  52. aiq/builder/workflow_builder.py +1117 -0
  53. aiq/cli/__init__.py +14 -0
  54. aiq/cli/cli_utils/__init__.py +0 -0
  55. aiq/cli/cli_utils/config_override.py +231 -0
  56. aiq/cli/cli_utils/validation.py +37 -0
  57. aiq/cli/commands/__init__.py +0 -0
  58. aiq/cli/commands/configure/__init__.py +0 -0
  59. aiq/cli/commands/configure/channel/__init__.py +0 -0
  60. aiq/cli/commands/configure/channel/add.py +28 -0
  61. aiq/cli/commands/configure/channel/channel.py +36 -0
  62. aiq/cli/commands/configure/channel/remove.py +30 -0
  63. aiq/cli/commands/configure/channel/update.py +30 -0
  64. aiq/cli/commands/configure/configure.py +33 -0
  65. aiq/cli/commands/evaluate.py +139 -0
  66. aiq/cli/commands/info/__init__.py +14 -0
  67. aiq/cli/commands/info/info.py +39 -0
  68. aiq/cli/commands/info/list_channels.py +32 -0
  69. aiq/cli/commands/info/list_components.py +129 -0
  70. aiq/cli/commands/info/list_mcp.py +213 -0
  71. aiq/cli/commands/registry/__init__.py +14 -0
  72. aiq/cli/commands/registry/publish.py +88 -0
  73. aiq/cli/commands/registry/pull.py +118 -0
  74. aiq/cli/commands/registry/registry.py +38 -0
  75. aiq/cli/commands/registry/remove.py +108 -0
  76. aiq/cli/commands/registry/search.py +155 -0
  77. aiq/cli/commands/sizing/__init__.py +14 -0
  78. aiq/cli/commands/sizing/calc.py +297 -0
  79. aiq/cli/commands/sizing/sizing.py +27 -0
  80. aiq/cli/commands/start.py +246 -0
  81. aiq/cli/commands/uninstall.py +81 -0
  82. aiq/cli/commands/validate.py +47 -0
  83. aiq/cli/commands/workflow/__init__.py +14 -0
  84. aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  85. aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
  86. aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  87. aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
  88. aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  89. aiq/cli/commands/workflow/workflow.py +37 -0
  90. aiq/cli/commands/workflow/workflow_commands.py +313 -0
  91. aiq/cli/entrypoint.py +135 -0
  92. aiq/cli/main.py +44 -0
  93. aiq/cli/register_workflow.py +488 -0
  94. aiq/cli/type_registry.py +1000 -0
  95. aiq/data_models/__init__.py +14 -0
  96. aiq/data_models/api_server.py +694 -0
  97. aiq/data_models/authentication.py +231 -0
  98. aiq/data_models/common.py +171 -0
  99. aiq/data_models/component.py +54 -0
  100. aiq/data_models/component_ref.py +168 -0
  101. aiq/data_models/config.py +406 -0
  102. aiq/data_models/dataset_handler.py +123 -0
  103. aiq/data_models/discovery_metadata.py +335 -0
  104. aiq/data_models/embedder.py +27 -0
  105. aiq/data_models/evaluate.py +127 -0
  106. aiq/data_models/evaluator.py +26 -0
  107. aiq/data_models/front_end.py +26 -0
  108. aiq/data_models/function.py +30 -0
  109. aiq/data_models/function_dependencies.py +72 -0
  110. aiq/data_models/interactive.py +246 -0
  111. aiq/data_models/intermediate_step.py +302 -0
  112. aiq/data_models/invocation_node.py +38 -0
  113. aiq/data_models/llm.py +27 -0
  114. aiq/data_models/logging.py +26 -0
  115. aiq/data_models/memory.py +27 -0
  116. aiq/data_models/object_store.py +44 -0
  117. aiq/data_models/profiler.py +54 -0
  118. aiq/data_models/registry_handler.py +26 -0
  119. aiq/data_models/retriever.py +30 -0
  120. aiq/data_models/retry_mixin.py +35 -0
  121. aiq/data_models/span.py +187 -0
  122. aiq/data_models/step_adaptor.py +64 -0
  123. aiq/data_models/streaming.py +33 -0
  124. aiq/data_models/swe_bench_model.py +54 -0
  125. aiq/data_models/telemetry_exporter.py +26 -0
  126. aiq/data_models/ttc_strategy.py +30 -0
  127. aiq/embedder/__init__.py +0 -0
  128. aiq/embedder/langchain_client.py +41 -0
  129. aiq/embedder/nim_embedder.py +59 -0
  130. aiq/embedder/openai_embedder.py +43 -0
  131. aiq/embedder/register.py +24 -0
  132. aiq/eval/__init__.py +14 -0
  133. aiq/eval/config.py +60 -0
  134. aiq/eval/dataset_handler/__init__.py +0 -0
  135. aiq/eval/dataset_handler/dataset_downloader.py +106 -0
  136. aiq/eval/dataset_handler/dataset_filter.py +52 -0
  137. aiq/eval/dataset_handler/dataset_handler.py +254 -0
  138. aiq/eval/evaluate.py +506 -0
  139. aiq/eval/evaluator/__init__.py +14 -0
  140. aiq/eval/evaluator/base_evaluator.py +73 -0
  141. aiq/eval/evaluator/evaluator_model.py +45 -0
  142. aiq/eval/intermediate_step_adapter.py +99 -0
  143. aiq/eval/rag_evaluator/__init__.py +0 -0
  144. aiq/eval/rag_evaluator/evaluate.py +178 -0
  145. aiq/eval/rag_evaluator/register.py +143 -0
  146. aiq/eval/register.py +23 -0
  147. aiq/eval/remote_workflow.py +133 -0
  148. aiq/eval/runners/__init__.py +14 -0
  149. aiq/eval/runners/config.py +39 -0
  150. aiq/eval/runners/multi_eval_runner.py +54 -0
  151. aiq/eval/runtime_event_subscriber.py +52 -0
  152. aiq/eval/swe_bench_evaluator/__init__.py +0 -0
  153. aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
  154. aiq/eval/swe_bench_evaluator/register.py +36 -0
  155. aiq/eval/trajectory_evaluator/__init__.py +0 -0
  156. aiq/eval/trajectory_evaluator/evaluate.py +75 -0
  157. aiq/eval/trajectory_evaluator/register.py +40 -0
  158. aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
  159. aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
  160. aiq/eval/tunable_rag_evaluator/register.py +52 -0
  161. aiq/eval/usage_stats.py +41 -0
  162. aiq/eval/utils/__init__.py +0 -0
  163. aiq/eval/utils/output_uploader.py +140 -0
  164. aiq/eval/utils/tqdm_position_registry.py +40 -0
  165. aiq/eval/utils/weave_eval.py +184 -0
  166. aiq/experimental/__init__.py +0 -0
  167. aiq/experimental/decorators/__init__.py +0 -0
  168. aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
  169. aiq/experimental/test_time_compute/__init__.py +0 -0
  170. aiq/experimental/test_time_compute/editing/__init__.py +0 -0
  171. aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  172. aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  173. aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  174. aiq/experimental/test_time_compute/functions/__init__.py +0 -0
  175. aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  176. aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
  177. aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
  178. aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
  179. aiq/experimental/test_time_compute/models/__init__.py +0 -0
  180. aiq/experimental/test_time_compute/models/editor_config.py +132 -0
  181. aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
  182. aiq/experimental/test_time_compute/models/search_config.py +120 -0
  183. aiq/experimental/test_time_compute/models/selection_config.py +154 -0
  184. aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
  185. aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
  186. aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
  187. aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
  188. aiq/experimental/test_time_compute/register.py +36 -0
  189. aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
  190. aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  191. aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  192. aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  193. aiq/experimental/test_time_compute/search/__init__.py +0 -0
  194. aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  195. aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  196. aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  197. aiq/experimental/test_time_compute/selection/__init__.py +0 -0
  198. aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  199. aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  200. aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
  201. aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  202. aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  203. aiq/front_ends/__init__.py +14 -0
  204. aiq/front_ends/console/__init__.py +14 -0
  205. aiq/front_ends/console/authentication_flow_handler.py +233 -0
  206. aiq/front_ends/console/console_front_end_config.py +32 -0
  207. aiq/front_ends/console/console_front_end_plugin.py +96 -0
  208. aiq/front_ends/console/register.py +25 -0
  209. aiq/front_ends/cron/__init__.py +14 -0
  210. aiq/front_ends/fastapi/__init__.py +14 -0
  211. aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  212. aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  213. aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
  214. aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
  215. aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  216. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
  217. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
  218. aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
  219. aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  220. aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  221. aiq/front_ends/fastapi/job_store.py +183 -0
  222. aiq/front_ends/fastapi/main.py +72 -0
  223. aiq/front_ends/fastapi/message_handler.py +298 -0
  224. aiq/front_ends/fastapi/message_validator.py +345 -0
  225. aiq/front_ends/fastapi/register.py +25 -0
  226. aiq/front_ends/fastapi/response_helpers.py +195 -0
  227. aiq/front_ends/fastapi/step_adaptor.py +321 -0
  228. aiq/front_ends/mcp/__init__.py +14 -0
  229. aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
  230. aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
  231. aiq/front_ends/mcp/register.py +27 -0
  232. aiq/front_ends/mcp/tool_converter.py +242 -0
  233. aiq/front_ends/register.py +22 -0
  234. aiq/front_ends/simple_base/__init__.py +14 -0
  235. aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
  236. aiq/llm/__init__.py +0 -0
  237. aiq/llm/aws_bedrock_llm.py +57 -0
  238. aiq/llm/nim_llm.py +46 -0
  239. aiq/llm/openai_llm.py +46 -0
  240. aiq/llm/register.py +23 -0
  241. aiq/llm/utils/__init__.py +14 -0
  242. aiq/llm/utils/env_config_value.py +94 -0
  243. aiq/llm/utils/error.py +17 -0
  244. aiq/memory/__init__.py +20 -0
  245. aiq/memory/interfaces.py +183 -0
  246. aiq/memory/models.py +112 -0
  247. aiq/meta/module_to_distro.json +3 -0
  248. aiq/meta/pypi.md +58 -0
  249. aiq/object_store/__init__.py +20 -0
  250. aiq/object_store/in_memory_object_store.py +76 -0
  251. aiq/object_store/interfaces.py +84 -0
  252. aiq/object_store/models.py +36 -0
  253. aiq/object_store/register.py +20 -0
  254. aiq/observability/__init__.py +14 -0
  255. aiq/observability/exporter/__init__.py +14 -0
  256. aiq/observability/exporter/base_exporter.py +449 -0
  257. aiq/observability/exporter/exporter.py +78 -0
  258. aiq/observability/exporter/file_exporter.py +33 -0
  259. aiq/observability/exporter/processing_exporter.py +322 -0
  260. aiq/observability/exporter/raw_exporter.py +52 -0
  261. aiq/observability/exporter/span_exporter.py +265 -0
  262. aiq/observability/exporter_manager.py +335 -0
  263. aiq/observability/mixin/__init__.py +14 -0
  264. aiq/observability/mixin/batch_config_mixin.py +26 -0
  265. aiq/observability/mixin/collector_config_mixin.py +23 -0
  266. aiq/observability/mixin/file_mixin.py +288 -0
  267. aiq/observability/mixin/file_mode.py +23 -0
  268. aiq/observability/mixin/resource_conflict_mixin.py +134 -0
  269. aiq/observability/mixin/serialize_mixin.py +61 -0
  270. aiq/observability/mixin/type_introspection_mixin.py +183 -0
  271. aiq/observability/processor/__init__.py +14 -0
  272. aiq/observability/processor/batching_processor.py +310 -0
  273. aiq/observability/processor/callback_processor.py +42 -0
  274. aiq/observability/processor/intermediate_step_serializer.py +28 -0
  275. aiq/observability/processor/processor.py +71 -0
  276. aiq/observability/register.py +96 -0
  277. aiq/observability/utils/__init__.py +14 -0
  278. aiq/observability/utils/dict_utils.py +236 -0
  279. aiq/observability/utils/time_utils.py +31 -0
  280. aiq/plugins/.namespace +1 -0
  281. aiq/profiler/__init__.py +0 -0
  282. aiq/profiler/calc/__init__.py +14 -0
  283. aiq/profiler/calc/calc_runner.py +627 -0
  284. aiq/profiler/calc/calculations.py +288 -0
  285. aiq/profiler/calc/data_models.py +188 -0
  286. aiq/profiler/calc/plot.py +345 -0
  287. aiq/profiler/callbacks/__init__.py +0 -0
  288. aiq/profiler/callbacks/agno_callback_handler.py +295 -0
  289. aiq/profiler/callbacks/base_callback_class.py +20 -0
  290. aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
  291. aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
  292. aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  293. aiq/profiler/callbacks/token_usage_base_model.py +27 -0
  294. aiq/profiler/data_frame_row.py +51 -0
  295. aiq/profiler/data_models.py +24 -0
  296. aiq/profiler/decorators/__init__.py +0 -0
  297. aiq/profiler/decorators/framework_wrapper.py +131 -0
  298. aiq/profiler/decorators/function_tracking.py +254 -0
  299. aiq/profiler/forecasting/__init__.py +0 -0
  300. aiq/profiler/forecasting/config.py +18 -0
  301. aiq/profiler/forecasting/model_trainer.py +75 -0
  302. aiq/profiler/forecasting/models/__init__.py +22 -0
  303. aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
  304. aiq/profiler/forecasting/models/linear_model.py +196 -0
  305. aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
  306. aiq/profiler/inference_metrics_model.py +28 -0
  307. aiq/profiler/inference_optimization/__init__.py +0 -0
  308. aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  309. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  310. aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  311. aiq/profiler/inference_optimization/data_models.py +386 -0
  312. aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
  313. aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  314. aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  315. aiq/profiler/inference_optimization/llm_metrics.py +212 -0
  316. aiq/profiler/inference_optimization/prompt_caching.py +163 -0
  317. aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
  318. aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
  319. aiq/profiler/intermediate_property_adapter.py +102 -0
  320. aiq/profiler/profile_runner.py +473 -0
  321. aiq/profiler/utils.py +184 -0
  322. aiq/registry_handlers/__init__.py +0 -0
  323. aiq/registry_handlers/local/__init__.py +0 -0
  324. aiq/registry_handlers/local/local_handler.py +176 -0
  325. aiq/registry_handlers/local/register_local.py +37 -0
  326. aiq/registry_handlers/metadata_factory.py +60 -0
  327. aiq/registry_handlers/package_utils.py +567 -0
  328. aiq/registry_handlers/pypi/__init__.py +0 -0
  329. aiq/registry_handlers/pypi/pypi_handler.py +251 -0
  330. aiq/registry_handlers/pypi/register_pypi.py +40 -0
  331. aiq/registry_handlers/register.py +21 -0
  332. aiq/registry_handlers/registry_handler_base.py +157 -0
  333. aiq/registry_handlers/rest/__init__.py +0 -0
  334. aiq/registry_handlers/rest/register_rest.py +56 -0
  335. aiq/registry_handlers/rest/rest_handler.py +237 -0
  336. aiq/registry_handlers/schemas/__init__.py +0 -0
  337. aiq/registry_handlers/schemas/headers.py +42 -0
  338. aiq/registry_handlers/schemas/package.py +68 -0
  339. aiq/registry_handlers/schemas/publish.py +63 -0
  340. aiq/registry_handlers/schemas/pull.py +82 -0
  341. aiq/registry_handlers/schemas/remove.py +36 -0
  342. aiq/registry_handlers/schemas/search.py +91 -0
  343. aiq/registry_handlers/schemas/status.py +47 -0
  344. aiq/retriever/__init__.py +0 -0
  345. aiq/retriever/interface.py +37 -0
  346. aiq/retriever/milvus/__init__.py +14 -0
  347. aiq/retriever/milvus/register.py +81 -0
  348. aiq/retriever/milvus/retriever.py +228 -0
  349. aiq/retriever/models.py +74 -0
  350. aiq/retriever/nemo_retriever/__init__.py +14 -0
  351. aiq/retriever/nemo_retriever/register.py +60 -0
  352. aiq/retriever/nemo_retriever/retriever.py +190 -0
  353. aiq/retriever/register.py +22 -0
  354. aiq/runtime/__init__.py +14 -0
  355. aiq/runtime/loader.py +215 -0
  356. aiq/runtime/runner.py +190 -0
  357. aiq/runtime/session.py +158 -0
  358. aiq/runtime/user_metadata.py +130 -0
  359. aiq/settings/__init__.py +0 -0
  360. aiq/settings/global_settings.py +318 -0
  361. aiq/test/.namespace +1 -0
  362. aiq/tool/__init__.py +0 -0
  363. aiq/tool/chat_completion.py +74 -0
  364. aiq/tool/code_execution/README.md +151 -0
  365. aiq/tool/code_execution/__init__.py +0 -0
  366. aiq/tool/code_execution/code_sandbox.py +267 -0
  367. aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
  368. aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  369. aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
  370. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  371. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  372. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  373. aiq/tool/code_execution/register.py +74 -0
  374. aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
  375. aiq/tool/code_execution/utils.py +100 -0
  376. aiq/tool/datetime_tools.py +42 -0
  377. aiq/tool/document_search.py +141 -0
  378. aiq/tool/github_tools/__init__.py +0 -0
  379. aiq/tool/github_tools/create_github_commit.py +133 -0
  380. aiq/tool/github_tools/create_github_issue.py +87 -0
  381. aiq/tool/github_tools/create_github_pr.py +106 -0
  382. aiq/tool/github_tools/get_github_file.py +106 -0
  383. aiq/tool/github_tools/get_github_issue.py +166 -0
  384. aiq/tool/github_tools/get_github_pr.py +256 -0
  385. aiq/tool/github_tools/update_github_issue.py +100 -0
  386. aiq/tool/mcp/__init__.py +14 -0
  387. aiq/tool/mcp/exceptions.py +142 -0
  388. aiq/tool/mcp/mcp_client.py +255 -0
  389. aiq/tool/mcp/mcp_tool.py +96 -0
  390. aiq/tool/memory_tools/__init__.py +0 -0
  391. aiq/tool/memory_tools/add_memory_tool.py +79 -0
  392. aiq/tool/memory_tools/delete_memory_tool.py +67 -0
  393. aiq/tool/memory_tools/get_memory_tool.py +72 -0
  394. aiq/tool/nvidia_rag.py +95 -0
  395. aiq/tool/register.py +38 -0
  396. aiq/tool/retriever.py +89 -0
  397. aiq/tool/server_tools.py +66 -0
  398. aiq/utils/__init__.py +0 -0
  399. aiq/utils/data_models/__init__.py +0 -0
  400. aiq/utils/data_models/schema_validator.py +58 -0
  401. aiq/utils/debugging_utils.py +43 -0
  402. aiq/utils/dump_distro_mapping.py +32 -0
  403. aiq/utils/exception_handlers/__init__.py +0 -0
  404. aiq/utils/exception_handlers/automatic_retries.py +289 -0
  405. aiq/utils/exception_handlers/mcp.py +211 -0
  406. aiq/utils/exception_handlers/schemas.py +114 -0
  407. aiq/utils/io/__init__.py +0 -0
  408. aiq/utils/io/model_processing.py +28 -0
  409. aiq/utils/io/yaml_tools.py +119 -0
  410. aiq/utils/log_utils.py +37 -0
  411. aiq/utils/metadata_utils.py +74 -0
  412. aiq/utils/optional_imports.py +142 -0
  413. aiq/utils/producer_consumer_queue.py +178 -0
  414. aiq/utils/reactive/__init__.py +0 -0
  415. aiq/utils/reactive/base/__init__.py +0 -0
  416. aiq/utils/reactive/base/observable_base.py +65 -0
  417. aiq/utils/reactive/base/observer_base.py +55 -0
  418. aiq/utils/reactive/base/subject_base.py +79 -0
  419. aiq/utils/reactive/observable.py +59 -0
  420. aiq/utils/reactive/observer.py +76 -0
  421. aiq/utils/reactive/subject.py +131 -0
  422. aiq/utils/reactive/subscription.py +49 -0
  423. aiq/utils/settings/__init__.py +0 -0
  424. aiq/utils/settings/global_settings.py +197 -0
  425. aiq/utils/string_utils.py +38 -0
  426. aiq/utils/type_converter.py +290 -0
  427. aiq/utils/type_utils.py +484 -0
  428. aiq/utils/url_utils.py +27 -0
  429. nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
  430. nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
  431. nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
  432. nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
  433. nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
  434. nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
  435. nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,254 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import inspect
18
+ import uuid
19
+ from typing import Any
20
+
21
+ from pydantic import BaseModel
22
+
23
+ from aiq.builder.context import AIQContext
24
+ from aiq.builder.intermediate_step_manager import IntermediateStepManager
25
+ from aiq.data_models.intermediate_step import IntermediateStepPayload
26
+ from aiq.data_models.intermediate_step import IntermediateStepType
27
+ from aiq.data_models.intermediate_step import TraceMetadata
28
+
29
+
30
# --- Helper function to recursively serialize any object into JSON-friendly data ---
def _serialize_data(obj: Any) -> Any:
    """Convert ``obj`` into a structure that can be passed to ``json.dumps(...)``.

    Conversion rules:

    - Pydantic models are dumped with ``model_dump()`` and the resulting dict is
      serialized recursively. The python-mode dump can still contain values that
      ``json.dumps`` rejects (``datetime``, ``UUID``, ``bytes``, tuples,
      non-string keys, ...), so it is run through the same rules below.
    - Dicts become dicts with ``str`` keys and recursively serialized values.
    - Lists, tuples and sets become lists of recursively serialized items.
    - ``str``, ``int``, ``float``, ``bool`` and ``None`` are returned unchanged.
    - Anything else falls back to ``str(obj)``.
    """
    if isinstance(obj, BaseModel):
        # Re-run the dump through this function so nested non-JSON values are converted too.
        return _serialize_data(obj.model_dump())

    if isinstance(obj, dict):
        return {str(k): _serialize_data(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [_serialize_data(item) for item in obj]

    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj

    # Fallback: best-effort textual representation for arbitrary objects.
    return str(obj)
47
+
48
+
49
def _prepare_serialized_args_kwargs(*args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
    """Return JSON-friendly copies of a call's positional and keyword arguments.

    Positional arguments are returned as a list, keyword arguments as a dict with
    the same keys; every value is converted with ``_serialize_data``.
    """
    positional = list(map(_serialize_data, args))
    keyword = dict(zip(kwargs.keys(), map(_serialize_data, kwargs.values())))
    return positional, keyword
54
+
55
+
56
def push_intermediate_step(step_manager: IntermediateStepManager,
                           identifier: str,
                           function_name: str,
                           event_type: IntermediateStepType,
                           args: Any = None,
                           kwargs: Any = None,
                           output: Any = None,
                           metadata: dict[str, Any] | None = None) -> None:
    """Push an intermediate step to the AIQ Toolkit Event Stream.

    Args:
        step_manager: Manager whose ``push_intermediate_step`` receives the payload.
        identifier: Value used as the payload ``UUID``.
        function_name: Name recorded on the payload.
        event_type: Kind of step being emitted.
        args: Positional inputs, stored as the first element of ``span_inputs``.
        kwargs: Keyword inputs, stored as the second element of ``span_inputs``.
        output: Value stored as ``span_outputs``.
        metadata: User metadata stored as ``provided_metadata``.
    """
    trace_metadata = TraceMetadata(span_inputs=[args, kwargs], span_outputs=output, provided_metadata=metadata)
    step_payload = IntermediateStepPayload(UUID=identifier,
                                           event_type=event_type,
                                           name=function_name,
                                           metadata=trace_metadata)
    step_manager.push_intermediate_step(step_payload)
76
+
77
+
78
+ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
79
+ """
80
+ Decorator that can wrap any type of function (sync, async, generator,
81
+ async generator) and executes "tracking logic" around it.
82
+
83
+ - If the function is async, it will be wrapped in an async function.
84
+ - If the function is a generator, it will be wrapped in a generator function.
85
+ - If the function is an async generator, it will be wrapped in an async generator function.
86
+ - If the function is sync, it will be wrapped in a sync function.
87
+ """
88
+ function_name: str = func.__name__ if func else "<unknown_function>"
89
+
90
+ # If called as @track_function(...) but not immediately passed a function
91
+ if func is None:
92
+
93
+ def decorator_wrapper(actual_func):
94
+ return track_function(actual_func, metadata=metadata)
95
+
96
+ return decorator_wrapper
97
+
98
+ # --- Validate metadata ---
99
+ if metadata is not None:
100
+ if not isinstance(metadata, dict):
101
+ raise TypeError("metadata must be a dict[str, Any].")
102
+ if any(not isinstance(k, str) for k in metadata.keys()):
103
+ raise TypeError("All metadata keys must be strings.")
104
+
105
+ # --- Now detect the function type and wrap accordingly ---
106
+ if inspect.isasyncgenfunction(func):
107
+ # ---------------------
108
+ # ASYNC GENERATOR
109
+ # ---------------------
110
+
111
+ @functools.wraps(func)
112
+ async def async_gen_wrapper(*args, **kwargs):
113
+ step_manager: IntermediateStepManager = AIQContext.get().intermediate_step_manager
114
+ # 1) Serialize input
115
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
116
+
117
+ invocation_id = str(uuid.uuid4())
118
+ push_intermediate_step(step_manager,
119
+ invocation_id,
120
+ function_name,
121
+ IntermediateStepType.SPAN_START,
122
+ args=serialized_args,
123
+ kwargs=serialized_kwargs,
124
+ metadata=metadata)
125
+
126
+ # 2) Call the original async generator
127
+ async for item in func(*args, **kwargs):
128
+ # 3) Serialize the yielded item before yielding it
129
+ serialized_item = _serialize_data(item)
130
+ push_intermediate_step(step_manager,
131
+ invocation_id,
132
+ function_name,
133
+ IntermediateStepType.SPAN_CHUNK,
134
+ args=serialized_args,
135
+ kwargs=serialized_kwargs,
136
+ output=serialized_item,
137
+ metadata=metadata)
138
+ yield item # yield the original item
139
+
140
+ push_intermediate_step(step_manager,
141
+ invocation_id,
142
+ function_name,
143
+ IntermediateStepType.SPAN_END,
144
+ args=serialized_args,
145
+ kwargs=serialized_kwargs,
146
+ output=None,
147
+ metadata=metadata)
148
+
149
+ # 4) Post-yield logic if any
150
+
151
+ return async_gen_wrapper
152
+
153
+ if inspect.iscoroutinefunction(func):
154
+ # ---------------------
155
+ # ASYNC FUNCTION
156
+ # ---------------------
157
+ @functools.wraps(func)
158
+ async def async_wrapper(*args, **kwargs):
159
+ step_manager: IntermediateStepManager = AIQContext.get().intermediate_step_manager
160
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
161
+ invocation_id = str(uuid.uuid4())
162
+ push_intermediate_step(step_manager,
163
+ invocation_id,
164
+ function_name,
165
+ IntermediateStepType.SPAN_START,
166
+ args=serialized_args,
167
+ kwargs=serialized_kwargs,
168
+ metadata=metadata)
169
+
170
+ result = await func(*args, **kwargs)
171
+
172
+ serialized_result = _serialize_data(result)
173
+ push_intermediate_step(step_manager,
174
+ invocation_id,
175
+ function_name,
176
+ IntermediateStepType.SPAN_END,
177
+ args=serialized_args,
178
+ kwargs=serialized_kwargs,
179
+ output=serialized_result,
180
+ metadata=metadata)
181
+
182
+ return result
183
+
184
+ return async_wrapper
185
+
186
+ if inspect.isgeneratorfunction(func):
187
+ # ---------------------
188
+ # SYNC GENERATOR
189
+ # ---------------------
190
+ @functools.wraps(func)
191
+ def sync_gen_wrapper(*args, **kwargs):
192
+ step_manager: IntermediateStepManager = AIQContext.get().intermediate_step_manager
193
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
194
+ invocation_id = str(uuid.uuid4())
195
+ push_intermediate_step(step_manager,
196
+ invocation_id,
197
+ function_name,
198
+ IntermediateStepType.SPAN_START,
199
+ args=serialized_args,
200
+ kwargs=serialized_kwargs,
201
+ metadata=metadata)
202
+
203
+ for item in func(*args, **kwargs):
204
+ serialized_item = _serialize_data(item)
205
+ push_intermediate_step(step_manager,
206
+ invocation_id,
207
+ function_name,
208
+ IntermediateStepType.SPAN_CHUNK,
209
+ args=serialized_args,
210
+ kwargs=serialized_kwargs,
211
+ output=serialized_item,
212
+ metadata=metadata)
213
+
214
+ yield item # yield the original item
215
+
216
+ push_intermediate_step(step_manager,
217
+ invocation_id,
218
+ function_name,
219
+ IntermediateStepType.SPAN_END,
220
+ args=serialized_args,
221
+ kwargs=serialized_kwargs,
222
+ output=None,
223
+ metadata=metadata)
224
+
225
+ return sync_gen_wrapper
226
+
227
+ @functools.wraps(func)
228
+ def sync_wrapper(*args, **kwargs):
229
+ step_manager: IntermediateStepManager = AIQContext.get().intermediate_step_manager
230
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
231
+ invocation_id = str(uuid.uuid4())
232
+ push_intermediate_step(step_manager,
233
+ invocation_id,
234
+ function_name,
235
+ IntermediateStepType.SPAN_START,
236
+ args=serialized_args,
237
+ kwargs=serialized_kwargs,
238
+ metadata=metadata)
239
+
240
+ result = func(*args, **kwargs)
241
+
242
+ serialized_result = _serialize_data(result)
243
+ push_intermediate_step(step_manager,
244
+ invocation_id,
245
+ function_name,
246
+ IntermediateStepType.SPAN_END,
247
+ args=serialized_args,
248
+ kwargs=serialized_kwargs,
249
+ output=serialized_result,
250
+ metadata=metadata)
251
+
252
+ return result
253
+
254
+ return sync_wrapper
File without changes
@@ -0,0 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Global constants and defaults for the forecasting package.
17
+ DEFAULT_MODEL_TYPE = "randomforest"
18
+ DEFAULT_MATRIX_LENGTH = 10
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/model_trainer.py
17
+
18
+ import logging
19
+
20
+ from aiq.profiler.forecasting.config import DEFAULT_MODEL_TYPE
21
+ from aiq.profiler.forecasting.models import ForecastingBaseModel
22
+ from aiq.profiler.forecasting.models import LinearModel
23
+ from aiq.profiler.forecasting.models import RandomForestModel
24
+ from aiq.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
def create_model(model_type: str) -> ForecastingBaseModel:
    """
    Instantiate the forecasting model named by ``model_type``.

    Supported names are "linear" (``LinearModel``) and "randomforest"
    (``RandomForestModel``). Register additional model classes here.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names.
    """
    if model_type not in ("linear", "randomforest"):
        raise ValueError(f"Unsupported model_type: {model_type}")

    if model_type == "linear":
        return LinearModel()
    return RandomForestModel()
41
+
42
+
43
class ModelTrainer:
    """
    Build a forecasting model of the requested type and fit it on profiler stats.

    Parameters
    ----------
    model_type: str, default = "randomforest"
        Which model to train. ``create_model`` accepts "linear" and "randomforest".
    """

    def __init__(self, model_type: str = DEFAULT_MODEL_TYPE):
        self.model_type = model_type
        # The concrete model is created eagerly, so an unsupported type fails at construction time.
        self._model = create_model(model_type)

    def train(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> ForecastingBaseModel:
        """
        Fit the underlying model on ``raw_stats`` and hand it back.

        Parameters
        ----------
        raw_stats: list[list[IntermediatePropertyAdaptor]]
            Stats collected by the profiler.

        Returns
        -------
        ForecastingBaseModel
            The same model instance held by this trainer, now fitted.
        """
        model = self._model
        model.fit(raw_stats)
        return model
@@ -0,0 +1,22 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/models/__init__.py
17
+
18
+ from .forecasting_base_model import ForecastingBaseModel
19
+ from .linear_model import LinearModel
20
+ from .random_forest_regressor import RandomForestModel
21
+
22
+ __all__ = ["ForecastingBaseModel", "LinearModel", "RandomForestModel"]
@@ -0,0 +1,40 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/models/forecasting_base_model.py
17
+
18
+ from abc import ABC, abstractmethod
19
+ import numpy as np
20
+
21
+
22
class ForecastingBaseModel(ABC):
    """
    Interface every forecasting model in this package must implement.

    Subclasses provide ``fit`` (training) and ``predict`` (inference). The
    class cannot be instantiated until both are overridden.
    """

    @abstractmethod
    def fit(self, raw_stats):
        """
        Train or fine-tune the model on ``raw_stats``.
        """

    @abstractmethod
    def predict(self, raw_stats) -> np.ndarray:
        """
        Run the trained model on ``raw_stats`` and return the predictions
        as a np.ndarray. The exact shape is defined by the concrete model.
        """
@@ -0,0 +1,196 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ import numpy as np
19
+
20
+ from aiq.profiler.forecasting.models.forecasting_base_model import ForecastingBaseModel
21
+ from aiq.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class LinearModel(ForecastingBaseModel):
    """
    A linear regression model that conforms to the ForecastingBaseModel interface.

    Each profiled run is converted into a matrix with one row per completed LLM
    call and the columns ``[seconds_between_calls, prompt_tokens, completion_tokens]``.
    Each training sample pairs a window of the most recent ``matrix_length`` calls
    (zero-padded at the top, then flattened) with a summary of the calls that follow:
    - the mean of column 0 over those later calls;
    - the sums of the remaining columns over those later calls;
    - all zeros for the final call.
    """

    # Number of per-call feature columns produced by ``_extract_token_usage_meta``.
    _N_FEATURES = 3

    def __init__(self):
        super().__init__()

        try:
            from sklearn.linear_model import LinearRegression
        except ImportError:
            logger.error("scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
                         "profiling model or install `aiq[profiler]` to install all necessary profiling packages.")

            raise

        self.model = LinearRegression()
        # Window length chosen during training; required by ``predict``.
        self.matrix_length = None

    def fit(self, raw_stats: list[list[IntermediatePropertyAdaptor]]):
        """
        Train the model on the runs in ``raw_stats``.

        The training arrays have these shapes:
        - X: (N, matrix_length * n_features)
        - y: (N, n_features)

        N is the total number of completed LLM calls across all runs. This also
        sets ``self.matrix_length``, which ``predict`` relies on.
        """
        x_flat, y_flat = self._prep_for_model_training(raw_stats)

        logger.info("Training dataset size: X=%s, y=%s", x_flat.shape, y_flat.shape)

        self.model.fit(x_flat, y_flat)

    def predict(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> np.ndarray:
        """
        Predict using the fitted linear model.

        Only the first run in ``raw_stats`` is used. Its most recent
        ``matrix_length`` calls form one input sample.

        Returns an array of shape (1, n_features).
        """
        x_mat = self._prep_single(raw_stats)
        # ``fit`` trained on row-major flattened windows. Reshape the single
        # (matrix_length, n_features) window into one sample with that same layout.
        x_row = x_mat.reshape(1, -1)
        return self.model.predict(x_row)

    def _prep_single(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> np.ndarray:
        """
        Build the (matrix_length, n_features) input window from the first run in ``raw_stats``.
        """
        arr, _ = self._extract_token_usage_meta(raw_stats)
        arr = arr[0]
        n_rows = arr.shape[0]

        matrix_length = self.matrix_length

        assert matrix_length is not None, "matrix_length must be set before calling _prep_single"

        if n_rows >= matrix_length:
            # Keep the latest matrix_length rows
            x_mat = arr[-matrix_length:, :]
        else:
            # Pad with zeros at the top
            pad_size = matrix_length - n_rows
            pad_block = np.zeros((pad_size, arr.shape[1]), dtype=arr.dtype)
            x_mat = np.vstack([pad_block, arr])

        return x_mat

    def _prep_for_model_training(self, raw_stats: list[list[IntermediatePropertyAdaptor]]):
        """
        Build the flattened (X, y) training arrays and record the chosen ``matrix_length``.
        """
        raw_matrices, matrix_length = self._extract_token_usage_meta(raw_stats)

        self.matrix_length = matrix_length

        x_list = []
        y_list = []
        for arr in raw_matrices:
            for x_mat, y_mat in self._preprocess_for_forecasting(arr, matrix_length):
                x_list.append(x_mat)
                y_list.append(y_mat)

        return self._flatten_features(x_list, y_list)

    def _extract_token_usage_meta(self, all_requests_data: list[list[IntermediatePropertyAdaptor]]):
        """
        Convert profiler stats into one feature matrix per run.

        Returns
        -------
        tuple[list[np.ndarray], int]
            - One array per run, of shape (n_calls, 3), with the columns
              ``[seconds_between_calls, prompt_tokens, completion_tokens]``.
              Each row corresponds to one LLM_END event.
            - The recommended window length: the mean number of calls per run,
              rounded up.

        Raises
        ------
        ValueError
            If ``all_requests_data`` is empty.
        """
        import math

        if not all_requests_data:
            raise ValueError("No profiler runs provided; cannot build forecasting features.")

        all_run_data = []
        call_stack_sizes = []

        for prompt in all_requests_data:
            run_data = []
            # LLM_START carries the inter-call gap; LLM_END carries token usage. Join them on UUID.
            seconds_between_call_map = {}

            for stat in prompt:
                if stat.event_type.value == "LLM_START":
                    seconds_between_call_map[stat.UUID] = stat.seconds_between_calls

                if stat.event_type.value == "LLM_END":
                    run_data.append([
                        seconds_between_call_map[stat.UUID],
                        stat.token_usage.prompt_tokens,
                        stat.token_usage.completion_tokens,
                    ])

            all_run_data.append(run_data)
            call_stack_sizes.append(len(run_data))

        # reshape keeps runs without any completed LLM call 2-D (shape (0, 3)).
        # Downstream padding and axis-1 indexing then keep working for them.
        all_run_data = [np.array(run).reshape(-1, self._N_FEATURES) for run in all_run_data]
        recommended_matrix_length = math.ceil(sum(call_stack_sizes) / len(call_stack_sizes))

        return all_run_data, recommended_matrix_length

    def _preprocess_for_forecasting(self, arr: np.ndarray, matrix_length: int):
        """
        Generate one (input_array, output_array) pair per row of ``arr``.

        ``arr`` is a 2D NumPy array of shape (n_rows, n_cols). For row ``i``:
        - input_array, of shape (matrix_length, n_cols): rows ``0..i``, trimmed to
          the last ``matrix_length`` rows or zero-padded at the top.
        - output_array, of shape (1, n_cols): the mean of column 0 over rows after
          ``i``, followed by the sums of the remaining columns over those rows.
          It is all zeros for the last row.
        """
        n_rows = arr.shape[0]
        if n_rows == 0:
            # A run with no completed LLM calls contributes no samples.
            return []

        # Derive the output width from the data rather than hard-coding it, so y
        # always has exactly one entry per feature column.
        n_cols = arr.shape[1]

        # partial_sums[i] = sum of arr[i:] per column
        partial_sums = np.flip(np.cumsum(np.flip(arr, axis=0), axis=0), axis=0)

        samples = []
        for i in range(n_rows):
            x_untrimmed = arr[:i + 1, :]
            current_len = x_untrimmed.shape[0]
            if current_len > matrix_length:
                x_mat = x_untrimmed[-matrix_length:, :]
            elif current_len < matrix_length:
                pad_block = np.zeros((matrix_length - current_len, n_cols), dtype=arr.dtype)
                x_mat = np.vstack([pad_block, x_untrimmed])
            else:
                x_mat = x_untrimmed

            if i == n_rows - 1:
                y_vec = np.zeros(n_cols, dtype=arr.dtype)
            else:
                n_below = n_rows - (i + 1)
                sum_below = partial_sums[i + 1]
                avg_col0 = sum_below[0] / n_below
                y_vec = np.concatenate(([avg_col0], sum_below[1:]))

            samples.append((x_mat, y_vec.reshape(1, n_cols)))

        return samples

    def _flatten_features(self, x_list, y_list):
        """
        Stack per-sample windows and targets into 2-D training arrays.

        Parameters
        ----------
        x_list: list of arrays, each of shape (matrix_length, n_cols)
        y_list: list of arrays, each of shape (1, n_cols)

        Returns
        -------
        - x_flat: np.array of shape (N, matrix_length * n_cols)
        - y_flat: np.array of shape (N, n_cols)
        """
        flattened_x = [x_mat.flatten() for x_mat in x_list]
        flattened_y = [y_mat.flatten() for y_mat in y_list]

        x_flat = np.array(flattened_x)
        y_flat = np.array(flattened_y)
        logger.debug("Flattened features to shapes: %s (X), %s (y).", x_flat.shape, y_flat.shape)
        return x_flat, y_flat