nvidia-nat 1.1.0a20251020__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (480)
  1. aiq/__init__.py +66 -0
  2. nat/agent/__init__.py +0 -0
  3. nat/agent/base.py +265 -0
  4. nat/agent/dual_node.py +72 -0
  5. nat/agent/prompt_optimizer/__init__.py +0 -0
  6. nat/agent/prompt_optimizer/prompt.py +68 -0
  7. nat/agent/prompt_optimizer/register.py +149 -0
  8. nat/agent/react_agent/__init__.py +0 -0
  9. nat/agent/react_agent/agent.py +394 -0
  10. nat/agent/react_agent/output_parser.py +104 -0
  11. nat/agent/react_agent/prompt.py +44 -0
  12. nat/agent/react_agent/register.py +168 -0
  13. nat/agent/reasoning_agent/__init__.py +0 -0
  14. nat/agent/reasoning_agent/reasoning_agent.py +227 -0
  15. nat/agent/register.py +23 -0
  16. nat/agent/rewoo_agent/__init__.py +0 -0
  17. nat/agent/rewoo_agent/agent.py +593 -0
  18. nat/agent/rewoo_agent/prompt.py +107 -0
  19. nat/agent/rewoo_agent/register.py +175 -0
  20. nat/agent/tool_calling_agent/__init__.py +0 -0
  21. nat/agent/tool_calling_agent/agent.py +246 -0
  22. nat/agent/tool_calling_agent/register.py +129 -0
  23. nat/authentication/__init__.py +14 -0
  24. nat/authentication/api_key/__init__.py +14 -0
  25. nat/authentication/api_key/api_key_auth_provider.py +96 -0
  26. nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
  27. nat/authentication/api_key/register.py +26 -0
  28. nat/authentication/credential_validator/__init__.py +14 -0
  29. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  30. nat/authentication/exceptions/__init__.py +14 -0
  31. nat/authentication/exceptions/api_key_exceptions.py +38 -0
  32. nat/authentication/http_basic_auth/__init__.py +0 -0
  33. nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
  34. nat/authentication/http_basic_auth/register.py +30 -0
  35. nat/authentication/interfaces.py +96 -0
  36. nat/authentication/oauth2/__init__.py +14 -0
  37. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
  38. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
  39. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  40. nat/authentication/oauth2/register.py +25 -0
  41. nat/authentication/register.py +20 -0
  42. nat/builder/__init__.py +0 -0
  43. nat/builder/builder.py +317 -0
  44. nat/builder/component_utils.py +320 -0
  45. nat/builder/context.py +321 -0
  46. nat/builder/embedder.py +24 -0
  47. nat/builder/eval_builder.py +166 -0
  48. nat/builder/evaluator.py +29 -0
  49. nat/builder/framework_enum.py +25 -0
  50. nat/builder/front_end.py +73 -0
  51. nat/builder/function.py +714 -0
  52. nat/builder/function_base.py +380 -0
  53. nat/builder/function_info.py +625 -0
  54. nat/builder/intermediate_step_manager.py +206 -0
  55. nat/builder/llm.py +25 -0
  56. nat/builder/retriever.py +25 -0
  57. nat/builder/user_interaction_manager.py +78 -0
  58. nat/builder/workflow.py +160 -0
  59. nat/builder/workflow_builder.py +1365 -0
  60. nat/cli/__init__.py +14 -0
  61. nat/cli/cli_utils/__init__.py +0 -0
  62. nat/cli/cli_utils/config_override.py +231 -0
  63. nat/cli/cli_utils/validation.py +37 -0
  64. nat/cli/commands/__init__.py +0 -0
  65. nat/cli/commands/configure/__init__.py +0 -0
  66. nat/cli/commands/configure/channel/__init__.py +0 -0
  67. nat/cli/commands/configure/channel/add.py +28 -0
  68. nat/cli/commands/configure/channel/channel.py +34 -0
  69. nat/cli/commands/configure/channel/remove.py +30 -0
  70. nat/cli/commands/configure/channel/update.py +30 -0
  71. nat/cli/commands/configure/configure.py +33 -0
  72. nat/cli/commands/evaluate.py +139 -0
  73. nat/cli/commands/info/__init__.py +14 -0
  74. nat/cli/commands/info/info.py +47 -0
  75. nat/cli/commands/info/list_channels.py +32 -0
  76. nat/cli/commands/info/list_components.py +128 -0
  77. nat/cli/commands/mcp/__init__.py +14 -0
  78. nat/cli/commands/mcp/mcp.py +986 -0
  79. nat/cli/commands/object_store/__init__.py +14 -0
  80. nat/cli/commands/object_store/object_store.py +227 -0
  81. nat/cli/commands/optimize.py +90 -0
  82. nat/cli/commands/registry/__init__.py +14 -0
  83. nat/cli/commands/registry/publish.py +88 -0
  84. nat/cli/commands/registry/pull.py +118 -0
  85. nat/cli/commands/registry/registry.py +36 -0
  86. nat/cli/commands/registry/remove.py +108 -0
  87. nat/cli/commands/registry/search.py +153 -0
  88. nat/cli/commands/sizing/__init__.py +14 -0
  89. nat/cli/commands/sizing/calc.py +297 -0
  90. nat/cli/commands/sizing/sizing.py +27 -0
  91. nat/cli/commands/start.py +257 -0
  92. nat/cli/commands/uninstall.py +81 -0
  93. nat/cli/commands/validate.py +47 -0
  94. nat/cli/commands/workflow/__init__.py +14 -0
  95. nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  96. nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
  97. nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
  98. nat/cli/commands/workflow/templates/register.py.j2 +4 -0
  99. nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
  100. nat/cli/commands/workflow/workflow.py +37 -0
  101. nat/cli/commands/workflow/workflow_commands.py +403 -0
  102. nat/cli/entrypoint.py +141 -0
  103. nat/cli/main.py +60 -0
  104. nat/cli/register_workflow.py +522 -0
  105. nat/cli/type_registry.py +1069 -0
  106. nat/control_flow/__init__.py +0 -0
  107. nat/control_flow/register.py +20 -0
  108. nat/control_flow/router_agent/__init__.py +0 -0
  109. nat/control_flow/router_agent/agent.py +329 -0
  110. nat/control_flow/router_agent/prompt.py +48 -0
  111. nat/control_flow/router_agent/register.py +91 -0
  112. nat/control_flow/sequential_executor.py +166 -0
  113. nat/data_models/__init__.py +14 -0
  114. nat/data_models/agent.py +34 -0
  115. nat/data_models/api_server.py +843 -0
  116. nat/data_models/authentication.py +245 -0
  117. nat/data_models/common.py +171 -0
  118. nat/data_models/component.py +60 -0
  119. nat/data_models/component_ref.py +179 -0
  120. nat/data_models/config.py +434 -0
  121. nat/data_models/dataset_handler.py +169 -0
  122. nat/data_models/discovery_metadata.py +305 -0
  123. nat/data_models/embedder.py +27 -0
  124. nat/data_models/evaluate.py +130 -0
  125. nat/data_models/evaluator.py +26 -0
  126. nat/data_models/front_end.py +26 -0
  127. nat/data_models/function.py +64 -0
  128. nat/data_models/function_dependencies.py +80 -0
  129. nat/data_models/gated_field_mixin.py +242 -0
  130. nat/data_models/interactive.py +246 -0
  131. nat/data_models/intermediate_step.py +302 -0
  132. nat/data_models/invocation_node.py +38 -0
  133. nat/data_models/llm.py +27 -0
  134. nat/data_models/logging.py +26 -0
  135. nat/data_models/memory.py +27 -0
  136. nat/data_models/object_store.py +44 -0
  137. nat/data_models/optimizable.py +119 -0
  138. nat/data_models/optimizer.py +149 -0
  139. nat/data_models/profiler.py +54 -0
  140. nat/data_models/registry_handler.py +26 -0
  141. nat/data_models/retriever.py +30 -0
  142. nat/data_models/retry_mixin.py +35 -0
  143. nat/data_models/span.py +228 -0
  144. nat/data_models/step_adaptor.py +64 -0
  145. nat/data_models/streaming.py +33 -0
  146. nat/data_models/swe_bench_model.py +54 -0
  147. nat/data_models/telemetry_exporter.py +26 -0
  148. nat/data_models/temperature_mixin.py +44 -0
  149. nat/data_models/thinking_mixin.py +86 -0
  150. nat/data_models/top_p_mixin.py +44 -0
  151. nat/data_models/ttc_strategy.py +30 -0
  152. nat/embedder/__init__.py +0 -0
  153. nat/embedder/azure_openai_embedder.py +46 -0
  154. nat/embedder/nim_embedder.py +59 -0
  155. nat/embedder/openai_embedder.py +42 -0
  156. nat/embedder/register.py +22 -0
  157. nat/eval/__init__.py +14 -0
  158. nat/eval/config.py +62 -0
  159. nat/eval/dataset_handler/__init__.py +0 -0
  160. nat/eval/dataset_handler/dataset_downloader.py +106 -0
  161. nat/eval/dataset_handler/dataset_filter.py +52 -0
  162. nat/eval/dataset_handler/dataset_handler.py +431 -0
  163. nat/eval/evaluate.py +565 -0
  164. nat/eval/evaluator/__init__.py +14 -0
  165. nat/eval/evaluator/base_evaluator.py +77 -0
  166. nat/eval/evaluator/evaluator_model.py +58 -0
  167. nat/eval/intermediate_step_adapter.py +99 -0
  168. nat/eval/rag_evaluator/__init__.py +0 -0
  169. nat/eval/rag_evaluator/evaluate.py +178 -0
  170. nat/eval/rag_evaluator/register.py +143 -0
  171. nat/eval/register.py +26 -0
  172. nat/eval/remote_workflow.py +133 -0
  173. nat/eval/runners/__init__.py +14 -0
  174. nat/eval/runners/config.py +39 -0
  175. nat/eval/runners/multi_eval_runner.py +54 -0
  176. nat/eval/runtime_evaluator/__init__.py +14 -0
  177. nat/eval/runtime_evaluator/evaluate.py +123 -0
  178. nat/eval/runtime_evaluator/register.py +100 -0
  179. nat/eval/runtime_event_subscriber.py +52 -0
  180. nat/eval/swe_bench_evaluator/__init__.py +0 -0
  181. nat/eval/swe_bench_evaluator/evaluate.py +215 -0
  182. nat/eval/swe_bench_evaluator/register.py +36 -0
  183. nat/eval/trajectory_evaluator/__init__.py +0 -0
  184. nat/eval/trajectory_evaluator/evaluate.py +75 -0
  185. nat/eval/trajectory_evaluator/register.py +40 -0
  186. nat/eval/tunable_rag_evaluator/__init__.py +0 -0
  187. nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
  188. nat/eval/tunable_rag_evaluator/register.py +52 -0
  189. nat/eval/usage_stats.py +41 -0
  190. nat/eval/utils/__init__.py +0 -0
  191. nat/eval/utils/eval_trace_ctx.py +89 -0
  192. nat/eval/utils/output_uploader.py +140 -0
  193. nat/eval/utils/tqdm_position_registry.py +40 -0
  194. nat/eval/utils/weave_eval.py +193 -0
  195. nat/experimental/__init__.py +0 -0
  196. nat/experimental/decorators/__init__.py +0 -0
  197. nat/experimental/decorators/experimental_warning_decorator.py +154 -0
  198. nat/experimental/test_time_compute/__init__.py +0 -0
  199. nat/experimental/test_time_compute/editing/__init__.py +0 -0
  200. nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
  201. nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
  202. nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
  203. nat/experimental/test_time_compute/functions/__init__.py +0 -0
  204. nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
  205. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
  206. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
  207. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
  208. nat/experimental/test_time_compute/models/__init__.py +0 -0
  209. nat/experimental/test_time_compute/models/editor_config.py +132 -0
  210. nat/experimental/test_time_compute/models/scoring_config.py +112 -0
  211. nat/experimental/test_time_compute/models/search_config.py +120 -0
  212. nat/experimental/test_time_compute/models/selection_config.py +154 -0
  213. nat/experimental/test_time_compute/models/stage_enums.py +43 -0
  214. nat/experimental/test_time_compute/models/strategy_base.py +67 -0
  215. nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
  216. nat/experimental/test_time_compute/models/ttc_item.py +48 -0
  217. nat/experimental/test_time_compute/register.py +35 -0
  218. nat/experimental/test_time_compute/scoring/__init__.py +0 -0
  219. nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
  220. nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
  221. nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
  222. nat/experimental/test_time_compute/search/__init__.py +0 -0
  223. nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
  224. nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
  225. nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
  226. nat/experimental/test_time_compute/selection/__init__.py +0 -0
  227. nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
  228. nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
  229. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
  230. nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
  231. nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
  232. nat/front_ends/__init__.py +14 -0
  233. nat/front_ends/console/__init__.py +14 -0
  234. nat/front_ends/console/authentication_flow_handler.py +285 -0
  235. nat/front_ends/console/console_front_end_config.py +32 -0
  236. nat/front_ends/console/console_front_end_plugin.py +108 -0
  237. nat/front_ends/console/register.py +25 -0
  238. nat/front_ends/cron/__init__.py +14 -0
  239. nat/front_ends/fastapi/__init__.py +14 -0
  240. nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
  241. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
  242. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
  243. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  244. nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
  245. nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
  246. nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
  247. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
  248. nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
  249. nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
  250. nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  251. nat/front_ends/fastapi/job_store.py +602 -0
  252. nat/front_ends/fastapi/main.py +64 -0
  253. nat/front_ends/fastapi/message_handler.py +344 -0
  254. nat/front_ends/fastapi/message_validator.py +351 -0
  255. nat/front_ends/fastapi/register.py +25 -0
  256. nat/front_ends/fastapi/response_helpers.py +195 -0
  257. nat/front_ends/fastapi/step_adaptor.py +319 -0
  258. nat/front_ends/fastapi/utils.py +57 -0
  259. nat/front_ends/mcp/__init__.py +14 -0
  260. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  261. nat/front_ends/mcp/mcp_front_end_config.py +90 -0
  262. nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
  263. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
  264. nat/front_ends/mcp/memory_profiler.py +320 -0
  265. nat/front_ends/mcp/register.py +27 -0
  266. nat/front_ends/mcp/tool_converter.py +290 -0
  267. nat/front_ends/register.py +21 -0
  268. nat/front_ends/simple_base/__init__.py +14 -0
  269. nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
  270. nat/llm/__init__.py +0 -0
  271. nat/llm/aws_bedrock_llm.py +69 -0
  272. nat/llm/azure_openai_llm.py +57 -0
  273. nat/llm/litellm_llm.py +69 -0
  274. nat/llm/nim_llm.py +58 -0
  275. nat/llm/openai_llm.py +54 -0
  276. nat/llm/register.py +27 -0
  277. nat/llm/utils/__init__.py +14 -0
  278. nat/llm/utils/env_config_value.py +93 -0
  279. nat/llm/utils/error.py +17 -0
  280. nat/llm/utils/thinking.py +215 -0
  281. nat/memory/__init__.py +20 -0
  282. nat/memory/interfaces.py +183 -0
  283. nat/memory/models.py +112 -0
  284. nat/meta/pypi.md +58 -0
  285. nat/object_store/__init__.py +20 -0
  286. nat/object_store/in_memory_object_store.py +76 -0
  287. nat/object_store/interfaces.py +84 -0
  288. nat/object_store/models.py +38 -0
  289. nat/object_store/register.py +19 -0
  290. nat/observability/__init__.py +14 -0
  291. nat/observability/exporter/__init__.py +14 -0
  292. nat/observability/exporter/base_exporter.py +449 -0
  293. nat/observability/exporter/exporter.py +78 -0
  294. nat/observability/exporter/file_exporter.py +33 -0
  295. nat/observability/exporter/processing_exporter.py +550 -0
  296. nat/observability/exporter/raw_exporter.py +52 -0
  297. nat/observability/exporter/span_exporter.py +308 -0
  298. nat/observability/exporter_manager.py +335 -0
  299. nat/observability/mixin/__init__.py +14 -0
  300. nat/observability/mixin/batch_config_mixin.py +26 -0
  301. nat/observability/mixin/collector_config_mixin.py +23 -0
  302. nat/observability/mixin/file_mixin.py +288 -0
  303. nat/observability/mixin/file_mode.py +23 -0
  304. nat/observability/mixin/redaction_config_mixin.py +42 -0
  305. nat/observability/mixin/resource_conflict_mixin.py +134 -0
  306. nat/observability/mixin/serialize_mixin.py +61 -0
  307. nat/observability/mixin/tagging_config_mixin.py +62 -0
  308. nat/observability/mixin/type_introspection_mixin.py +496 -0
  309. nat/observability/processor/__init__.py +14 -0
  310. nat/observability/processor/batching_processor.py +308 -0
  311. nat/observability/processor/callback_processor.py +42 -0
  312. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  313. nat/observability/processor/intermediate_step_serializer.py +28 -0
  314. nat/observability/processor/processor.py +74 -0
  315. nat/observability/processor/processor_factory.py +70 -0
  316. nat/observability/processor/redaction/__init__.py +24 -0
  317. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  318. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  319. nat/observability/processor/redaction/redaction_processor.py +177 -0
  320. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  321. nat/observability/processor/span_tagging_processor.py +68 -0
  322. nat/observability/register.py +114 -0
  323. nat/observability/utils/__init__.py +14 -0
  324. nat/observability/utils/dict_utils.py +236 -0
  325. nat/observability/utils/time_utils.py +31 -0
  326. nat/plugins/.namespace +1 -0
  327. nat/profiler/__init__.py +0 -0
  328. nat/profiler/calc/__init__.py +14 -0
  329. nat/profiler/calc/calc_runner.py +626 -0
  330. nat/profiler/calc/calculations.py +288 -0
  331. nat/profiler/calc/data_models.py +188 -0
  332. nat/profiler/calc/plot.py +345 -0
  333. nat/profiler/callbacks/__init__.py +0 -0
  334. nat/profiler/callbacks/agno_callback_handler.py +295 -0
  335. nat/profiler/callbacks/base_callback_class.py +20 -0
  336. nat/profiler/callbacks/langchain_callback_handler.py +297 -0
  337. nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
  338. nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  339. nat/profiler/callbacks/token_usage_base_model.py +27 -0
  340. nat/profiler/data_frame_row.py +51 -0
  341. nat/profiler/data_models.py +24 -0
  342. nat/profiler/decorators/__init__.py +0 -0
  343. nat/profiler/decorators/framework_wrapper.py +180 -0
  344. nat/profiler/decorators/function_tracking.py +411 -0
  345. nat/profiler/forecasting/__init__.py +0 -0
  346. nat/profiler/forecasting/config.py +18 -0
  347. nat/profiler/forecasting/model_trainer.py +75 -0
  348. nat/profiler/forecasting/models/__init__.py +22 -0
  349. nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
  350. nat/profiler/forecasting/models/linear_model.py +197 -0
  351. nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
  352. nat/profiler/inference_metrics_model.py +28 -0
  353. nat/profiler/inference_optimization/__init__.py +0 -0
  354. nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  355. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
  356. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  357. nat/profiler/inference_optimization/data_models.py +386 -0
  358. nat/profiler/inference_optimization/experimental/__init__.py +0 -0
  359. nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  360. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
  361. nat/profiler/inference_optimization/llm_metrics.py +212 -0
  362. nat/profiler/inference_optimization/prompt_caching.py +163 -0
  363. nat/profiler/inference_optimization/token_uniqueness.py +107 -0
  364. nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
  365. nat/profiler/intermediate_property_adapter.py +102 -0
  366. nat/profiler/parameter_optimization/__init__.py +0 -0
  367. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  368. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  369. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  370. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  371. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  372. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  373. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  374. nat/profiler/profile_runner.py +478 -0
  375. nat/profiler/utils.py +186 -0
  376. nat/registry_handlers/__init__.py +0 -0
  377. nat/registry_handlers/local/__init__.py +0 -0
  378. nat/registry_handlers/local/local_handler.py +176 -0
  379. nat/registry_handlers/local/register_local.py +37 -0
  380. nat/registry_handlers/metadata_factory.py +60 -0
  381. nat/registry_handlers/package_utils.py +570 -0
  382. nat/registry_handlers/pypi/__init__.py +0 -0
  383. nat/registry_handlers/pypi/pypi_handler.py +248 -0
  384. nat/registry_handlers/pypi/register_pypi.py +40 -0
  385. nat/registry_handlers/register.py +20 -0
  386. nat/registry_handlers/registry_handler_base.py +157 -0
  387. nat/registry_handlers/rest/__init__.py +0 -0
  388. nat/registry_handlers/rest/register_rest.py +56 -0
  389. nat/registry_handlers/rest/rest_handler.py +236 -0
  390. nat/registry_handlers/schemas/__init__.py +0 -0
  391. nat/registry_handlers/schemas/headers.py +42 -0
  392. nat/registry_handlers/schemas/package.py +68 -0
  393. nat/registry_handlers/schemas/publish.py +68 -0
  394. nat/registry_handlers/schemas/pull.py +82 -0
  395. nat/registry_handlers/schemas/remove.py +36 -0
  396. nat/registry_handlers/schemas/search.py +91 -0
  397. nat/registry_handlers/schemas/status.py +47 -0
  398. nat/retriever/__init__.py +0 -0
  399. nat/retriever/interface.py +41 -0
  400. nat/retriever/milvus/__init__.py +14 -0
  401. nat/retriever/milvus/register.py +81 -0
  402. nat/retriever/milvus/retriever.py +228 -0
  403. nat/retriever/models.py +77 -0
  404. nat/retriever/nemo_retriever/__init__.py +14 -0
  405. nat/retriever/nemo_retriever/register.py +60 -0
  406. nat/retriever/nemo_retriever/retriever.py +190 -0
  407. nat/retriever/register.py +21 -0
  408. nat/runtime/__init__.py +14 -0
  409. nat/runtime/loader.py +220 -0
  410. nat/runtime/runner.py +292 -0
  411. nat/runtime/session.py +223 -0
  412. nat/runtime/user_metadata.py +130 -0
  413. nat/settings/__init__.py +0 -0
  414. nat/settings/global_settings.py +329 -0
  415. nat/test/.namespace +1 -0
  416. nat/tool/__init__.py +0 -0
  417. nat/tool/chat_completion.py +77 -0
  418. nat/tool/code_execution/README.md +151 -0
  419. nat/tool/code_execution/__init__.py +0 -0
  420. nat/tool/code_execution/code_sandbox.py +267 -0
  421. nat/tool/code_execution/local_sandbox/.gitignore +1 -0
  422. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  423. nat/tool/code_execution/local_sandbox/__init__.py +13 -0
  424. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
  425. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
  426. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
  427. nat/tool/code_execution/register.py +74 -0
  428. nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
  429. nat/tool/code_execution/utils.py +100 -0
  430. nat/tool/datetime_tools.py +82 -0
  431. nat/tool/document_search.py +141 -0
  432. nat/tool/github_tools.py +450 -0
  433. nat/tool/memory_tools/__init__.py +0 -0
  434. nat/tool/memory_tools/add_memory_tool.py +79 -0
  435. nat/tool/memory_tools/delete_memory_tool.py +66 -0
  436. nat/tool/memory_tools/get_memory_tool.py +72 -0
  437. nat/tool/nvidia_rag.py +95 -0
  438. nat/tool/register.py +31 -0
  439. nat/tool/retriever.py +95 -0
  440. nat/tool/server_tools.py +66 -0
  441. nat/utils/__init__.py +0 -0
  442. nat/utils/callable_utils.py +70 -0
  443. nat/utils/data_models/__init__.py +0 -0
  444. nat/utils/data_models/schema_validator.py +58 -0
  445. nat/utils/debugging_utils.py +43 -0
  446. nat/utils/decorators.py +210 -0
  447. nat/utils/dump_distro_mapping.py +32 -0
  448. nat/utils/exception_handlers/__init__.py +0 -0
  449. nat/utils/exception_handlers/automatic_retries.py +342 -0
  450. nat/utils/exception_handlers/schemas.py +114 -0
  451. nat/utils/io/__init__.py +0 -0
  452. nat/utils/io/model_processing.py +28 -0
  453. nat/utils/io/yaml_tools.py +119 -0
  454. nat/utils/log_levels.py +25 -0
  455. nat/utils/log_utils.py +37 -0
  456. nat/utils/metadata_utils.py +74 -0
  457. nat/utils/optional_imports.py +142 -0
  458. nat/utils/producer_consumer_queue.py +178 -0
  459. nat/utils/reactive/__init__.py +0 -0
  460. nat/utils/reactive/base/__init__.py +0 -0
  461. nat/utils/reactive/base/observable_base.py +65 -0
  462. nat/utils/reactive/base/observer_base.py +55 -0
  463. nat/utils/reactive/base/subject_base.py +79 -0
  464. nat/utils/reactive/observable.py +59 -0
  465. nat/utils/reactive/observer.py +76 -0
  466. nat/utils/reactive/subject.py +131 -0
  467. nat/utils/reactive/subscription.py +49 -0
  468. nat/utils/settings/__init__.py +0 -0
  469. nat/utils/settings/global_settings.py +195 -0
  470. nat/utils/string_utils.py +38 -0
  471. nat/utils/type_converter.py +299 -0
  472. nat/utils/type_utils.py +488 -0
  473. nat/utils/url_utils.py +27 -0
  474. nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
  475. nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
  476. nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
  477. nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
  478. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  479. nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
  480. nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
@@ -0,0 +1,411 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import inspect
18
+ import uuid
19
+ from collections.abc import Callable
20
+ from typing import Any
21
+ from typing import TypeVar
22
+ from typing import cast
23
+ from typing import overload
24
+
25
+ from pydantic import BaseModel
26
+
27
+ from nat.builder.context import Context
28
+ from nat.builder.intermediate_step_manager import IntermediateStepManager
29
+ from nat.data_models.intermediate_step import IntermediateStepPayload
30
+ from nat.data_models.intermediate_step import IntermediateStepType
31
+ from nat.data_models.intermediate_step import TraceMetadata
32
+
33
+
34
# --- Helper function to recursively serialize any object into JSON-friendly data ---
def _serialize_data(obj: Any) -> Any:
    """Convert ``obj`` into a structure that can be passed to ``json.dumps(...)``.

    Conversion rules:

    - Pydantic models are dumped with ``model_dump()``. The resulting dict is
      then serialized recursively with these same rules, so nested field values
      that are not JSON-native are converted as well.
    - Dicts become dicts with ``str`` keys and recursively serialized values.
    - Lists, tuples, sets and frozensets become lists of serialized items
      (set/frozenset iteration order is not guaranteed).
    - ``str``, ``int``, ``float``, ``bool`` and ``None`` are returned unchanged.
    - Anything else falls back to ``str(obj)``.

    Args:
        obj: Arbitrary value to serialize.

    Returns:
        A JSON-friendly representation of ``obj``.
    """
    if isinstance(obj, BaseModel):
        # model_dump() (default python mode) can still contain values that
        # json.dumps rejects, so run its output through the same rules as any
        # other dict instead of returning it raw.
        return _serialize_data(obj.model_dump())

    if isinstance(obj, dict):
        return {str(k): _serialize_data(v) for k, v in obj.items()}
    if isinstance(obj, list | tuple | set | frozenset):
        return [_serialize_data(item) for item in obj]

    if isinstance(obj, str | int | float | bool | type(None)):
        return obj

    # Fallback: last-resort textual representation for unknown types.
    return str(obj)
51
+
52
+
53
def _prepare_serialized_args_kwargs(*args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
    """Return JSON-friendly copies of a call's positional and keyword arguments.

    Every value is passed through ``_serialize_data``; keyword names are kept as-is.

    Returns:
        A ``(positional, keyword)`` pair: a list holding one serialized entry per
        positional argument, and a dict mapping each keyword name to its
        serialized value.
    """
    positional: list[Any] = list(map(_serialize_data, args))
    keyword: dict[str, Any] = {}
    for name, value in kwargs.items():
        keyword[name] = _serialize_data(value)
    return positional, keyword
58
+
59
+
60
def push_intermediate_step(step_manager: IntermediateStepManager,
                           identifier: str,
                           function_name: str,
                           event_type: IntermediateStepType,
                           args: Any = None,
                           kwargs: Any = None,
                           output: Any = None,
                           metadata: dict[str, Any] | None = None) -> None:
    """Push an intermediate step to the NAT Event Stream.

    Builds a ``TraceMetadata`` and wraps it in an ``IntermediateStepPayload``,
    then hands that payload to ``step_manager.push_intermediate_step``.

    Args:
        step_manager: Manager that receives the payload.
        identifier: Value stored as the payload's ``UUID``.
        function_name: Value stored as the payload's ``name``.
        event_type: Kind of step being recorded.
        args: Positional inputs; stored as the first element of ``span_inputs``.
        kwargs: Keyword inputs; stored as the second element of ``span_inputs``.
        output: Stored as ``span_outputs``.
        metadata: Stored as ``provided_metadata``.
    """
    trace = TraceMetadata(span_inputs=[args, kwargs], span_outputs=output, provided_metadata=metadata)
    step_manager.push_intermediate_step(
        IntermediateStepPayload(UUID=identifier, event_type=event_type, name=function_name, metadata=trace))
80
+
81
+
82
# Type variable for overloads: bound to any callable, so the ``track_function``
# overloads can declare that the decorated result has the same type as the
# function passed in.
F = TypeVar('F', bound=Callable[..., Any])
84
+
85
+
86
# Overloads for different function types: static-typing signatures for the two
# ways ``track_function`` can be applied (with a function directly, or as a
# decorator factory with keyword-only ``metadata``).
@overload
def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
    """Overload for when a function is passed directly.

    Covers bare ``@track_function`` and ``track_function(fn, metadata=...)``;
    the result is typed as the same callable type ``F`` as ``func``.
    """
    ...


@overload
def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
    """Overload for decorator factory usage (when called with parentheses).

    Covers ``@track_function()`` and ``@track_function(metadata={...})``: the
    call returns a decorator that takes a callable of type ``F`` and returns
    the same type.
    """
    ...
97
+
98
+
99
+ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
100
+ """
101
+ Decorator that can wrap any type of function (sync, async, generator,
102
+ async generator) and executes "tracking logic" around it.
103
+
104
+ - If the function is async, it will be wrapped in an async function.
105
+ - If the function is a generator, it will be wrapped in a generator function.
106
+ - If the function is an async generator, it will be wrapped in an async generator function.
107
+ - If the function is sync, it will be wrapped in a sync function.
108
+ """
109
+ function_name: str = func.__name__ if func else "<unknown_function>"
110
+
111
+ # If called as @track_function(...) but not immediately passed a function
112
+ if func is None:
113
+
114
+ def decorator_wrapper(actual_func):
115
+ return track_function(actual_func, metadata=metadata)
116
+
117
+ return decorator_wrapper
118
+
119
+ # --- Validate metadata ---
120
+ if metadata is not None:
121
+ if not isinstance(metadata, dict):
122
+ raise TypeError("metadata must be a dict[str, Any].")
123
+ if any(not isinstance(k, str) for k in metadata.keys()):
124
+ raise TypeError("All metadata keys must be strings.")
125
+
126
+ # --- Now detect the function type and wrap accordingly ---
127
+ if inspect.isasyncgenfunction(func):
128
+ # ---------------------
129
+ # ASYNC GENERATOR
130
+ # ---------------------
131
+
132
+ @functools.wraps(func)
133
+ async def async_gen_wrapper(*args, **kwargs):
134
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
135
+ # 1) Serialize input
136
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
137
+
138
+ invocation_id = str(uuid.uuid4())
139
+ push_intermediate_step(step_manager,
140
+ invocation_id,
141
+ function_name,
142
+ IntermediateStepType.SPAN_START,
143
+ args=serialized_args,
144
+ kwargs=serialized_kwargs,
145
+ metadata=metadata)
146
+
147
+ # 2) Call the original async generator
148
+ async for item in func(*args, **kwargs):
149
+ # 3) Serialize the yielded item before yielding it
150
+ serialized_item = _serialize_data(item)
151
+ push_intermediate_step(step_manager,
152
+ invocation_id,
153
+ function_name,
154
+ IntermediateStepType.SPAN_CHUNK,
155
+ args=serialized_args,
156
+ kwargs=serialized_kwargs,
157
+ output=serialized_item,
158
+ metadata=metadata)
159
+ yield item # yield the original item
160
+
161
+ push_intermediate_step(step_manager,
162
+ invocation_id,
163
+ function_name,
164
+ IntermediateStepType.SPAN_END,
165
+ args=serialized_args,
166
+ kwargs=serialized_kwargs,
167
+ output=None,
168
+ metadata=metadata)
169
+
170
+ # 4) Post-yield logic if any
171
+
172
+ return async_gen_wrapper
173
+
174
+ if inspect.iscoroutinefunction(func):
175
+ # ---------------------
176
+ # ASYNC FUNCTION
177
+ # ---------------------
178
+ @functools.wraps(func)
179
+ async def async_wrapper(*args, **kwargs):
180
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
181
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
182
+ invocation_id = str(uuid.uuid4())
183
+ push_intermediate_step(step_manager,
184
+ invocation_id,
185
+ function_name,
186
+ IntermediateStepType.SPAN_START,
187
+ args=serialized_args,
188
+ kwargs=serialized_kwargs,
189
+ metadata=metadata)
190
+
191
+ result = await func(*args, **kwargs)
192
+
193
+ serialized_result = _serialize_data(result)
194
+ push_intermediate_step(step_manager,
195
+ invocation_id,
196
+ function_name,
197
+ IntermediateStepType.SPAN_END,
198
+ args=serialized_args,
199
+ kwargs=serialized_kwargs,
200
+ output=serialized_result,
201
+ metadata=metadata)
202
+
203
+ return result
204
+
205
+ return async_wrapper
206
+
207
+ if inspect.isgeneratorfunction(func):
208
+ # ---------------------
209
+ # SYNC GENERATOR
210
+ # ---------------------
211
+ @functools.wraps(func)
212
+ def sync_gen_wrapper(*args, **kwargs):
213
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
214
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
215
+ invocation_id = str(uuid.uuid4())
216
+ push_intermediate_step(step_manager,
217
+ invocation_id,
218
+ function_name,
219
+ IntermediateStepType.SPAN_START,
220
+ args=serialized_args,
221
+ kwargs=serialized_kwargs,
222
+ metadata=metadata)
223
+
224
+ for item in func(*args, **kwargs):
225
+ serialized_item = _serialize_data(item)
226
+ push_intermediate_step(step_manager,
227
+ invocation_id,
228
+ function_name,
229
+ IntermediateStepType.SPAN_CHUNK,
230
+ args=serialized_args,
231
+ kwargs=serialized_kwargs,
232
+ output=serialized_item,
233
+ metadata=metadata)
234
+
235
+ yield item # yield the original item
236
+
237
+ push_intermediate_step(step_manager,
238
+ invocation_id,
239
+ function_name,
240
+ IntermediateStepType.SPAN_END,
241
+ args=serialized_args,
242
+ kwargs=serialized_kwargs,
243
+ output=None,
244
+ metadata=metadata)
245
+
246
+ return sync_gen_wrapper
247
+
248
+ @functools.wraps(func)
249
+ def sync_wrapper(*args, **kwargs):
250
+ step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
251
+ serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
252
+ invocation_id = str(uuid.uuid4())
253
+ push_intermediate_step(step_manager,
254
+ invocation_id,
255
+ function_name,
256
+ IntermediateStepType.SPAN_START,
257
+ args=serialized_args,
258
+ kwargs=serialized_kwargs,
259
+ metadata=metadata)
260
+
261
+ result = func(*args, **kwargs)
262
+
263
+ serialized_result = _serialize_data(result)
264
+ push_intermediate_step(step_manager,
265
+ invocation_id,
266
+ function_name,
267
+ IntermediateStepType.SPAN_END,
268
+ args=serialized_args,
269
+ kwargs=serialized_kwargs,
270
+ output=serialized_result,
271
+ metadata=metadata)
272
+
273
+ return result
274
+
275
+ return sync_wrapper
276
+
277
+
278
# Overloads for track_unregistered_function (typing only; the runtime implementation follows them)
@overload
def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
    """
    Typing signature for bare use (``@track_unregistered_function``) or a direct
    call that passes the function: the decorated function's type ``F`` is preserved.
    """
283
+
284
+
285
@overload
def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
    """
    Typing signature for factory use with keyword arguments only
    (``@track_unregistered_function(name=..., metadata=...)``): the call returns a
    decorator that preserves the decorated function's type ``F``.
    """
289
+
290
+
291
def track_unregistered_function(func: Callable[..., Any] | None = None,
                                *,
                                name: str | None = None,
                                metadata: dict[str, Any] | None = None) -> Callable[..., Any]:
    """
    Decorator that wraps any function with scope management and automatic tracking.

    - Sets the active function context using the tracked name
    - Leverages ``Context.push_active_function`` for tracking instead of emitting
      intermediate steps itself, avoiding duplicate tracking entries
    - Supports sync/async functions and sync/async generators
    - Generator wrappers explicitly close the wrapped generator before leaving the
      active-function scope. The wrapped generator's cleanup code therefore runs
      inside that scope even when the consumer stops iterating early. In that
      early-stop case no output is recorded via ``set_output``.

    Can be used bare (``@track_unregistered_function``) or with keyword arguments
    (``@track_unregistered_function(name=..., metadata=...)``).

    Args:
        func: The function to wrap (auto-detected when used without parentheses).
        name: Custom name to use for tracking instead of ``func.__name__``. A falsy
            value (``None`` or ``""``) falls back to ``func.__name__``.
        metadata: Additional metadata to include in tracking. It is wrapped in a
            ``TraceMetadata`` and passed to ``push_active_function``.

    Returns:
        The wrapped function when ``func`` is given; otherwise a decorator that
        accepts the function to wrap.

    Raises:
        TypeError: If ``metadata`` is not a dict or has non-string keys. This is
            raised at decoration time, not at call time.
    """

    # If called with parameters: @track_unregistered_function(name="...", metadata={...})
    if func is None:

        def decorator_wrapper(actual_func: Callable[..., Any]) -> Callable[..., Any]:
            # Cast to ensure type checker understands this returns a callable
            return cast(Callable[..., Any], track_unregistered_function(actual_func, name=name, metadata=metadata))

        return decorator_wrapper

    # Direct decoration: @track_unregistered_function or recursive call with actual function
    function_name: str = name if name else func.__name__

    # --- Validate metadata ---
    if metadata is not None:
        if not isinstance(metadata, dict):
            raise TypeError("metadata must be a dict[str, Any].")
        if any(not isinstance(k, str) for k in metadata.keys()):
            raise TypeError("All metadata keys must be strings.")

    trace_metadata = TraceMetadata(provided_metadata=metadata)

    # --- Now detect the function type and wrap accordingly ---
    if inspect.isasyncgenfunction(func):
        # ---------------------
        # ASYNC GENERATOR
        # ---------------------

        @functools.wraps(func)
        async def async_gen_wrapper(*args, **kwargs):
            context = Context.get()
            # Recorded span input: positional args followed by the kwargs dict.
            input_data = (
                *args,
                kwargs,
            )
            # Only do context management - let push_active_function handle tracking
            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
                final_outputs = []
                agen = func(*args, **kwargs)
                try:
                    async for item in agen:
                        final_outputs.append(item)
                        yield item
                finally:
                    # If the consumer stops early, close the wrapped async generator now,
                    # while still inside the active-function scope. Otherwise its cleanup
                    # is deferred to event-loop finalization after this scope has exited.
                    # aclose() on an already-exhausted async generator is a no-op.
                    await agen.aclose()

                # Only reached when the wrapped generator was fully consumed.
                manager.set_output(final_outputs)

        return async_gen_wrapper

    if inspect.iscoroutinefunction(func):
        # ---------------------
        # ASYNC FUNCTION
        # ---------------------
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            context = Context.get()
            # Recorded span input: positional args followed by the kwargs dict.
            input_data = (
                *args,
                kwargs,
            )

            # Only do context management - let push_active_function handle tracking
            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
                result = await func(*args, **kwargs)
                manager.set_output(result)
                return result

        return async_wrapper

    if inspect.isgeneratorfunction(func):
        # ---------------------
        # SYNC GENERATOR
        # ---------------------
        @functools.wraps(func)
        def sync_gen_wrapper(*args, **kwargs):
            context = Context.get()
            # Recorded span input: positional args followed by the kwargs dict.
            input_data = (
                *args,
                kwargs,
            )

            # Only do context management - let push_active_function handle tracking
            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
                final_outputs = []
                gen = func(*args, **kwargs)
                try:
                    for item in gen:
                        final_outputs.append(item)
                        yield item
                finally:
                    # Close the wrapped generator inside the active-function scope if the
                    # consumer stops early, without relying on interpreter refcounting.
                    # close() on an already-exhausted generator is a no-op.
                    gen.close()

                # Only reached when the wrapped generator was fully consumed.
                manager.set_output(final_outputs)

        return sync_gen_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        context = Context.get()
        # Recorded span input: positional args followed by the kwargs dict.
        input_data = (
            *args,
            kwargs,
        )

        # Only do context management - let push_active_function handle tracking
        with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
            result = func(*args, **kwargs)
            manager.set_output(result)
            return result

    return sync_wrapper
File without changes
@@ -0,0 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
# Package-wide defaults for the forecasting module.
# Model type used when a caller does not pick one (e.g. ModelTrainer's default).
DEFAULT_MODEL_TYPE: str = "randomforest"
# Default matrix length. NOTE(review): its consumer is not visible here; confirm its meaning.
DEFAULT_MATRIX_LENGTH: int = 10
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/model_trainer.py
17
+
18
+ import logging
19
+
20
+ from nat.profiler.forecasting.config import DEFAULT_MODEL_TYPE
21
+ from nat.profiler.forecasting.models import ForecastingBaseModel
22
+ from nat.profiler.forecasting.models import LinearModel
23
+ from nat.profiler.forecasting.models import RandomForestModel
24
+ from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
25
+
26
# Module-level logger; nothing in this module currently logs through it.
logger = logging.getLogger(__name__)
27
+
28
+
29
def create_model(model_type: str) -> ForecastingBaseModel:
    """
    Factory that maps a model-type string to a newly constructed model instance.

    Supported values are ``"linear"`` (``LinearModel``) and ``"randomforest"``
    (``RandomForestModel``). To support another model (e.g. a future
    ``PolynomialModel``), add a branch here that returns an instance of a
    ``ForecastingBaseModel`` subclass.

    Parameters
    ----------
    model_type: str
        Identifier of the model to build. Matching is exact and case-sensitive.

    Returns
    -------
    ForecastingBaseModel
        A new instance of the requested model class.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported identifiers.
    """
    if model_type == "linear":
        return LinearModel()
    if model_type == "randomforest":
        return RandomForestModel()

    raise ValueError(f"Unsupported model_type: {model_type}")
41
+
42
+
43
class ModelTrainer:
    """
    Builds a forecasting model by type and fits it on profiler stats.

    The trainer performs no preprocessing of its own. ``train`` passes the raw
    stats straight to the model's ``fit`` method and returns that same model
    instance.

    Parameters
    ----------
    model_type: str, default = DEFAULT_MODEL_TYPE ("randomforest")
        The type of model to train. Options include "linear" and "randomforest".

    Raises
    ------
    ValueError
        If ``model_type`` is not supported. ``create_model`` raises this during
        construction.
    """

    def __init__(self, model_type: str = DEFAULT_MODEL_TYPE):
        self.model_type = model_type
        # The model is created eagerly, so an unsupported type fails here rather
        # than at train time.
        self._model = create_model(self.model_type)

    def train(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> ForecastingBaseModel:
        """
        Fit the model on ``raw_stats`` and return it.

        Parameters
        ----------
        raw_stats: list[list[IntermediatePropertyAdaptor]]
            Stats collected by the profiler, passed to the model's ``fit`` unchanged.

        Returns
        -------
        ForecastingBaseModel
            The trainer's own model instance after ``fit`` has been called on it.
            Repeated calls invoke ``fit`` again on that same instance.
        """

        self._model.fit(raw_stats)

        return self._model
@@ -0,0 +1,22 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/models/__init__.py
17
+
18
+ from .forecasting_base_model import ForecastingBaseModel
19
+ from .linear_model import LinearModel
20
+ from .random_forest_regressor import RandomForestModel
21
+
22
# Public API of the forecasting ``models`` subpackage.
__all__ = [
    "ForecastingBaseModel",
    "LinearModel",
    "RandomForestModel",
]
@@ -0,0 +1,42 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # forecasting/models/forecasting_base_model.py
17
+
18
+ from abc import ABC
19
+ from abc import abstractmethod
20
+
21
+ import numpy as np
22
+
23
+
24
class ForecastingBaseModel(ABC):
    """
    Common interface implemented by every forecasting model in this package.

    Both ``fit`` and ``predict`` are abstract. The class cannot be instantiated
    directly, and concrete models must override both methods.
    """

    @abstractmethod
    def fit(self, raw_stats):
        """
        Train (or fine-tune) the model on ``raw_stats``.
        """

    @abstractmethod
    def predict(self, raw_stats) -> np.ndarray:
        """
        Produce predictions for ``raw_stats`` with the trained model.

        Implementations are expected to return an ``np.ndarray`` of shape ``(N, 4)``.
        """