nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
import inspect
|
|
18
|
+
import uuid
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
|
|
23
|
+
from nat.builder.context import Context
|
|
24
|
+
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
25
|
+
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
26
|
+
from nat.data_models.intermediate_step import IntermediateStepType
|
|
27
|
+
from nat.data_models.intermediate_step import TraceMetadata
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# --- Helper function to recursively serialize any object into JSON-friendly data ---
|
|
31
|
+
def _serialize_data(obj: Any) -> Any:
    """Recursively convert ``obj`` into a structure that ``json.dumps(...)`` accepts.

    Conversion rules:

    - Pydantic models are dumped with ``model_dump()`` and the resulting dict is
      itself serialized recursively, so field values that are not JSON-native
      are converted as well.
    - Dicts become dicts with ``str`` keys and serialized values.
    - Lists, tuples and sets become lists of serialized items.
    - ``str``, ``int``, ``float``, ``bool`` and ``None`` are returned unchanged.
    - Anything else falls back to ``str(obj)``.

    Args:
        obj: The value to serialize.

    Returns:
        A JSON-friendly representation of ``obj``.
    """
    if isinstance(obj, BaseModel):
        # model_dump() in its default (python) mode can still contain values such as
        # datetimes, UUIDs or sets; recurse so the JSON-friendly contract holds.
        return _serialize_data(obj.model_dump())

    if isinstance(obj, dict):
        return {str(k): _serialize_data(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [_serialize_data(item) for item in obj]

    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj

    # Fallback: any other object is represented by its string form.
    return str(obj)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _prepare_serialized_args_kwargs(*args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
    """Return JSON-friendly copies of a call's positional and keyword arguments.

    Positional arguments are serialized first, in order, followed by the keyword
    arguments; each value goes through ``_serialize_data``.

    Returns:
        A ``(positional, keyword)`` pair: a list of serialized positional values and
        a dict mapping each keyword name to its serialized value.
    """
    positional = list(map(_serialize_data, args))
    keyword = {name: _serialize_data(value) for name, value in kwargs.items()}
    return positional, keyword
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def push_intermediate_step(step_manager: IntermediateStepManager,
                           identifier: str,
                           function_name: str,
                           event_type: IntermediateStepType,
                           args: Any = None,
                           kwargs: Any = None,
                           output: Any = None,
                           metadata: dict[str, Any] | None = None) -> None:
    """Build an intermediate-step payload and hand it to ``step_manager``.

    Args:
        step_manager: Manager whose ``push_intermediate_step`` receives the payload.
        identifier: Identifier stored as the payload's ``UUID``; the decorators in
            this module reuse one identifier for every step of a single invocation.
        function_name: Name recorded on the step.
        event_type: Kind of step being recorded (e.g. span start, chunk or end).
        args: Positional inputs, stored as the first element of ``span_inputs``.
        kwargs: Keyword inputs, stored as the second element of ``span_inputs``.
        output: Value stored as ``span_outputs``.
        metadata: Caller-supplied metadata stored as ``provided_metadata``.
    """
    trace_metadata = TraceMetadata(span_inputs=[args, kwargs], span_outputs=output, provided_metadata=metadata)

    step = IntermediateStepPayload(UUID=identifier, event_type=event_type, name=function_name, metadata=trace_metadata)

    step_manager.push_intermediate_step(step)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
|
|
79
|
+
"""
|
|
80
|
+
Decorator that can wrap any type of function (sync, async, generator,
|
|
81
|
+
async generator) and executes "tracking logic" around it.
|
|
82
|
+
|
|
83
|
+
- If the function is async, it will be wrapped in an async function.
|
|
84
|
+
- If the function is a generator, it will be wrapped in a generator function.
|
|
85
|
+
- If the function is an async generator, it will be wrapped in an async generator function.
|
|
86
|
+
- If the function is sync, it will be wrapped in a sync function.
|
|
87
|
+
"""
|
|
88
|
+
function_name: str = func.__name__ if func else "<unknown_function>"
|
|
89
|
+
|
|
90
|
+
# If called as @track_function(...) but not immediately passed a function
|
|
91
|
+
if func is None:
|
|
92
|
+
|
|
93
|
+
def decorator_wrapper(actual_func):
|
|
94
|
+
return track_function(actual_func, metadata=metadata)
|
|
95
|
+
|
|
96
|
+
return decorator_wrapper
|
|
97
|
+
|
|
98
|
+
# --- Validate metadata ---
|
|
99
|
+
if metadata is not None:
|
|
100
|
+
if not isinstance(metadata, dict):
|
|
101
|
+
raise TypeError("metadata must be a dict[str, Any].")
|
|
102
|
+
if any(not isinstance(k, str) for k in metadata.keys()):
|
|
103
|
+
raise TypeError("All metadata keys must be strings.")
|
|
104
|
+
|
|
105
|
+
# --- Now detect the function type and wrap accordingly ---
|
|
106
|
+
if inspect.isasyncgenfunction(func):
|
|
107
|
+
# ---------------------
|
|
108
|
+
# ASYNC GENERATOR
|
|
109
|
+
# ---------------------
|
|
110
|
+
|
|
111
|
+
@functools.wraps(func)
|
|
112
|
+
async def async_gen_wrapper(*args, **kwargs):
|
|
113
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
114
|
+
# 1) Serialize input
|
|
115
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
116
|
+
|
|
117
|
+
invocation_id = str(uuid.uuid4())
|
|
118
|
+
push_intermediate_step(step_manager,
|
|
119
|
+
invocation_id,
|
|
120
|
+
function_name,
|
|
121
|
+
IntermediateStepType.SPAN_START,
|
|
122
|
+
args=serialized_args,
|
|
123
|
+
kwargs=serialized_kwargs,
|
|
124
|
+
metadata=metadata)
|
|
125
|
+
|
|
126
|
+
# 2) Call the original async generator
|
|
127
|
+
async for item in func(*args, **kwargs):
|
|
128
|
+
# 3) Serialize the yielded item before yielding it
|
|
129
|
+
serialized_item = _serialize_data(item)
|
|
130
|
+
push_intermediate_step(step_manager,
|
|
131
|
+
invocation_id,
|
|
132
|
+
function_name,
|
|
133
|
+
IntermediateStepType.SPAN_CHUNK,
|
|
134
|
+
args=serialized_args,
|
|
135
|
+
kwargs=serialized_kwargs,
|
|
136
|
+
output=serialized_item,
|
|
137
|
+
metadata=metadata)
|
|
138
|
+
yield item # yield the original item
|
|
139
|
+
|
|
140
|
+
push_intermediate_step(step_manager,
|
|
141
|
+
invocation_id,
|
|
142
|
+
function_name,
|
|
143
|
+
IntermediateStepType.SPAN_END,
|
|
144
|
+
args=serialized_args,
|
|
145
|
+
kwargs=serialized_kwargs,
|
|
146
|
+
output=None,
|
|
147
|
+
metadata=metadata)
|
|
148
|
+
|
|
149
|
+
# 4) Post-yield logic if any
|
|
150
|
+
|
|
151
|
+
return async_gen_wrapper
|
|
152
|
+
|
|
153
|
+
if inspect.iscoroutinefunction(func):
|
|
154
|
+
# ---------------------
|
|
155
|
+
# ASYNC FUNCTION
|
|
156
|
+
# ---------------------
|
|
157
|
+
@functools.wraps(func)
|
|
158
|
+
async def async_wrapper(*args, **kwargs):
|
|
159
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
160
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
161
|
+
invocation_id = str(uuid.uuid4())
|
|
162
|
+
push_intermediate_step(step_manager,
|
|
163
|
+
invocation_id,
|
|
164
|
+
function_name,
|
|
165
|
+
IntermediateStepType.SPAN_START,
|
|
166
|
+
args=serialized_args,
|
|
167
|
+
kwargs=serialized_kwargs,
|
|
168
|
+
metadata=metadata)
|
|
169
|
+
|
|
170
|
+
result = await func(*args, **kwargs)
|
|
171
|
+
|
|
172
|
+
serialized_result = _serialize_data(result)
|
|
173
|
+
push_intermediate_step(step_manager,
|
|
174
|
+
invocation_id,
|
|
175
|
+
function_name,
|
|
176
|
+
IntermediateStepType.SPAN_END,
|
|
177
|
+
args=serialized_args,
|
|
178
|
+
kwargs=serialized_kwargs,
|
|
179
|
+
output=serialized_result,
|
|
180
|
+
metadata=metadata)
|
|
181
|
+
|
|
182
|
+
return result
|
|
183
|
+
|
|
184
|
+
return async_wrapper
|
|
185
|
+
|
|
186
|
+
if inspect.isgeneratorfunction(func):
|
|
187
|
+
# ---------------------
|
|
188
|
+
# SYNC GENERATOR
|
|
189
|
+
# ---------------------
|
|
190
|
+
@functools.wraps(func)
|
|
191
|
+
def sync_gen_wrapper(*args, **kwargs):
|
|
192
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
193
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
194
|
+
invocation_id = str(uuid.uuid4())
|
|
195
|
+
push_intermediate_step(step_manager,
|
|
196
|
+
invocation_id,
|
|
197
|
+
function_name,
|
|
198
|
+
IntermediateStepType.SPAN_START,
|
|
199
|
+
args=serialized_args,
|
|
200
|
+
kwargs=serialized_kwargs,
|
|
201
|
+
metadata=metadata)
|
|
202
|
+
|
|
203
|
+
for item in func(*args, **kwargs):
|
|
204
|
+
serialized_item = _serialize_data(item)
|
|
205
|
+
push_intermediate_step(step_manager,
|
|
206
|
+
invocation_id,
|
|
207
|
+
function_name,
|
|
208
|
+
IntermediateStepType.SPAN_CHUNK,
|
|
209
|
+
args=serialized_args,
|
|
210
|
+
kwargs=serialized_kwargs,
|
|
211
|
+
output=serialized_item,
|
|
212
|
+
metadata=metadata)
|
|
213
|
+
|
|
214
|
+
yield item # yield the original item
|
|
215
|
+
|
|
216
|
+
push_intermediate_step(step_manager,
|
|
217
|
+
invocation_id,
|
|
218
|
+
function_name,
|
|
219
|
+
IntermediateStepType.SPAN_END,
|
|
220
|
+
args=serialized_args,
|
|
221
|
+
kwargs=serialized_kwargs,
|
|
222
|
+
output=None,
|
|
223
|
+
metadata=metadata)
|
|
224
|
+
|
|
225
|
+
return sync_gen_wrapper
|
|
226
|
+
|
|
227
|
+
@functools.wraps(func)
|
|
228
|
+
def sync_wrapper(*args, **kwargs):
|
|
229
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
230
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
231
|
+
invocation_id = str(uuid.uuid4())
|
|
232
|
+
push_intermediate_step(step_manager,
|
|
233
|
+
invocation_id,
|
|
234
|
+
function_name,
|
|
235
|
+
IntermediateStepType.SPAN_START,
|
|
236
|
+
args=serialized_args,
|
|
237
|
+
kwargs=serialized_kwargs,
|
|
238
|
+
metadata=metadata)
|
|
239
|
+
|
|
240
|
+
result = func(*args, **kwargs)
|
|
241
|
+
|
|
242
|
+
serialized_result = _serialize_data(result)
|
|
243
|
+
push_intermediate_step(step_manager,
|
|
244
|
+
invocation_id,
|
|
245
|
+
function_name,
|
|
246
|
+
IntermediateStepType.SPAN_END,
|
|
247
|
+
args=serialized_args,
|
|
248
|
+
kwargs=serialized_kwargs,
|
|
249
|
+
output=serialized_result,
|
|
250
|
+
metadata=metadata)
|
|
251
|
+
|
|
252
|
+
return result
|
|
253
|
+
|
|
254
|
+
return sync_wrapper
|
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# Package-wide defaults for the profiler's forecasting models.

# Model name used by ModelTrainer when no model_type is supplied.
DEFAULT_MODEL_TYPE: str = "randomforest"

# Default forecasting window length (matrix_length).
# NOTE(review): confirm which models consume this; LinearModel derives its own length from the data.
DEFAULT_MATRIX_LENGTH: int = 10
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# forecasting/model_trainer.py
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
|
|
20
|
+
from nat.profiler.forecasting.config import DEFAULT_MODEL_TYPE
|
|
21
|
+
from nat.profiler.forecasting.models import ForecastingBaseModel
|
|
22
|
+
from nat.profiler.forecasting.models import LinearModel
|
|
23
|
+
from nat.profiler.forecasting.models import RandomForestModel
|
|
24
|
+
from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def create_model(model_type: str) -> ForecastingBaseModel:
    """
    Instantiate a fresh, unfitted forecasting model from its type name.

    Parameters
    ----------
    model_type: str
        "linear" for ``LinearModel`` or "randomforest" for ``RandomForestModel``.
        Matching is exact (case-sensitive).

    Returns
    -------
    ForecastingBaseModel
        A new, untrained model instance.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names.

    To support another model, add a branch here that returns an instance of a
    new ``ForecastingBaseModel`` subclass.
    """
    if model_type == "linear":
        return LinearModel()
    if model_type == "randomforest":
        return RandomForestModel()

    # Keep the original message prefix so existing matchers still work; list valid options to aid debugging.
    raise ValueError(f"Unsupported model_type: {model_type}. Supported types: 'linear', 'randomforest'.")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ModelTrainer:
    """
    Thin wrapper that builds a forecasting model by name and fits it.

    Parameters
    ----------
    model_type: str, default = "randomforest"
        Name handed to ``create_model``; "linear" and "randomforest" are accepted.
        An unsupported name raises ``ValueError`` when the trainer is constructed.
    """

    def __init__(self, model_type: str = DEFAULT_MODEL_TYPE):
        self.model_type = model_type
        # Build the model eagerly so a bad model_type fails at construction, not at train time.
        self._model = create_model(model_type)

    def train(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> ForecastingBaseModel:
        """
        Fit the wrapped model on profiler statistics and hand it back.

        Parameters
        ----------
        raw_stats: list[list[IntermediatePropertyAdaptor]]
            Stats collected by the profiler, passed unchanged to the model's ``fit``.

        Returns
        -------
        ForecastingBaseModel
            The same model instance this trainer holds, now fitted.
        """
        model = self._model
        model.fit(raw_stats)
        return model
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# forecasting/models/__init__.py
|
|
17
|
+
|
|
18
|
+
from .forecasting_base_model import ForecastingBaseModel
|
|
19
|
+
from .linear_model import LinearModel
|
|
20
|
+
from .random_forest_regressor import RandomForestModel
|
|
21
|
+
|
|
22
|
+
# Public API of the forecasting.models package.
__all__ = [
    "ForecastingBaseModel",
    "LinearModel",
    "RandomForestModel",
]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# forecasting/models/forecasting_base_model.py
|
|
17
|
+
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ForecastingBaseModel(ABC):
    """
    Abstract base class for all forecasting models in this package.

    Subclasses must implement ``fit`` and ``predict``. Both receive the raw
    profiler statistics rather than pre-built feature matrices, so each
    concrete model is responsible for its own feature extraction.
    """

    @abstractmethod
    def fit(self, raw_stats):
        """
        Train/fine-tune the model on the provided dataset.
        """

    @abstractmethod
    def predict(self, raw_stats) -> np.ndarray:
        """
        Predict using the trained model.

        Returns a 2-D np.ndarray with one row per forecast. The number of
        columns is model-specific (it follows the feature layout the concrete
        model builds from ``raw_stats``), so callers should not assume a
        fixed width.
        """
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from nat.profiler.forecasting.models.forecasting_base_model import ForecastingBaseModel
|
|
21
|
+
from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LinearModel(ForecastingBaseModel):
    """
    A linear regression forecaster that conforms to the ``ForecastingBaseModel`` interface.

    Each LLM call in a profiled run is encoded as one feature row
    ``[seconds_between_calls, prompt_tokens, completion_tokens]``. For every call
    in a run, the model learns to predict, from a fixed-length window of the calls
    seen so far, an aggregate of the calls still to come: the mean of column 0 and
    the sums of the remaining columns.
    """

    # Number of values recorded per LLM call in ``_extract_token_usage_meta``.
    _FEATURES_PER_CALL = 3

    def __init__(self):
        super().__init__()

        try:
            from sklearn.linear_model import LinearRegression
        except ImportError:
            logger.error(
                "scikit-learn is not installed. Please install scikit-learn to use the LinearModel "
                "profiling model or install `nvidia-nat[profiler]` to install all necessary profiling packages.")

            raise

        self.model = LinearRegression()
        # Window length (number of LLM calls per sample); set during fit().
        self.matrix_length = None

    def fit(self, raw_stats: list[list[IntermediatePropertyAdaptor]]):
        """
        Fit the regression on profiler runs.

        Training matrices have shapes
        X: (N, matrix_length * n_features) and y: (N, n_features),
        where N is the total number of LLM calls across all runs.
        """
        x_flat, y_flat = self._prep_for_model_training(raw_stats)

        logger.info("Training dataset size: X=%s, y=%s", x_flat.shape, y_flat.shape)

        self.model.fit(x_flat, y_flat)

    def predict(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> np.ndarray:
        """
        Forecast the remaining workload for the first run in ``raw_stats``.

        Returns
        -------
        np.ndarray
            Shape (1, n_features): a single forecast row in the same layout as the
            training targets.
        """
        x_mat = self._prep_single(raw_stats)
        # The regressor was trained on flattened windows (one row per sample), so the
        # (matrix_length, n_features) window must be flattened into a single sample;
        # passing it unflattened causes a feature-count mismatch in sklearn.
        return self.model.predict(x_mat.reshape(1, -1))

    def _prep_single(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> np.ndarray:
        """
        Build the (matrix_length, n_features) input window for the first run in ``raw_stats``.

        Raises
        ------
        RuntimeError
            If the model has not been fitted (``matrix_length`` is unset).
        """
        runs, _ = self._extract_token_usage_meta(raw_stats)

        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if self.matrix_length is None:
            raise RuntimeError("matrix_length must be set before calling _prep_single")

        return self._window(runs[0], self.matrix_length)

    def _prep_for_model_training(self, raw_stats: list[list[IntermediatePropertyAdaptor]]):
        """
        Turn profiler runs into flattened (X, y) training matrices.

        Also records the window length on ``self.matrix_length`` so prediction
        uses the same input size as training.
        """
        raw_matrices, matrix_length = self._extract_token_usage_meta(raw_stats)

        self.matrix_length = matrix_length

        x_list = []
        y_list = []
        for arr in raw_matrices:
            for x_mat, y_mat in self._preprocess_for_forecasting(arr, matrix_length):
                x_list.append(x_mat)
                y_list.append(y_mat)

        return self._flatten_features(x_list, y_list)

    def _extract_token_usage_meta(self, all_requests_data: list[list[IntermediatePropertyAdaptor]]):
        """
        Convert profiler events into one feature matrix per run.

        Returns
        -------
        tuple[list[np.ndarray], int]
            One (n_calls, 3) array per run, where each row is
            ``[seconds_between_calls, prompt_tokens, completion_tokens]`` for an LLM call,
            and the recommended window length: the mean number of LLM calls per run,
            rounded up.

        Raises
        ------
        ValueError
            If ``all_requests_data`` is empty.
        """
        import math

        if not all_requests_data:
            raise ValueError("No profiler runs were provided; cannot build forecasting data.")

        all_run_data = []
        call_stack_sizes = []

        for prompt in all_requests_data:
            run_data = []
            # seconds_between_calls is read from LLM_START and paired with the matching LLM_END via UUID.
            seconds_between_call_map = {}

            for stat in prompt:
                if stat.event_type.value == "LLM_START":
                    seconds_between_call_map[stat.UUID] = stat.seconds_between_calls

                if stat.event_type.value == "LLM_END":
                    run_data.append([
                        seconds_between_call_map[stat.UUID],
                        stat.token_usage.prompt_tokens,
                        stat.token_usage.completion_tokens,
                    ])

            all_run_data.append(run_data)
            call_stack_sizes.append(len(run_data))

        # reshape keeps runs without any LLM call 2-D, i.e. (0, 3) rather than (0,),
        # so downstream code can rely on ``arr.shape[1]``.
        all_run_data = [np.array(run).reshape(-1, self._FEATURES_PER_CALL) for run in all_run_data]
        recommended_matrix_length = math.ceil(sum(call_stack_sizes) / len(call_stack_sizes))

        return all_run_data, recommended_matrix_length

    @staticmethod
    def _window(arr: np.ndarray, matrix_length: int) -> np.ndarray:
        """
        Return the last ``matrix_length`` rows of the 2-D ``arr``, zero-padding at the top
        when ``arr`` has fewer rows.
        """
        n_rows = arr.shape[0]
        if n_rows >= matrix_length:
            # Explicit start index: ``arr[-0:]`` would return the whole array when matrix_length == 0.
            return arr[n_rows - matrix_length:, :]

        pad_block = np.zeros((matrix_length - n_rows, arr.shape[1]), dtype=arr.dtype)
        return np.vstack([pad_block, arr])

    def _preprocess_for_forecasting(self, arr: np.ndarray, matrix_length: int):
        """
        Given a 2D NumPy array ``arr`` of shape (n_rows, n_features), generate one
        (input_array, output_array) pair per row i:

        - input_array: (matrix_length, n_features), rows 0..i trimmed/padded to matrix_length
        - output_array: (1, n_features), aggregating rows i+1..n_rows-1 as the mean of
          column 0 and the sums of the remaining columns; all zeros for the last row.
        """
        n_rows, n_features = arr.shape

        # partial_sums[i] = sum of arr[i:] per column
        partial_sums = np.flip(np.cumsum(np.flip(arr, axis=0), axis=0), axis=0)

        samples = []
        for i in range(n_rows):
            x_mat = self._window(arr[:i + 1, :], matrix_length)

            if i == n_rows - 1:
                # No calls follow the final one.
                y_vec = np.zeros(n_features, dtype=arr.dtype)
            else:
                n_below = n_rows - (i + 1)
                sum_below = partial_sums[i + 1]
                y_vec = np.concatenate(([sum_below[0] / n_below], sum_below[1:]))

            # Width follows the data (3 features per call) instead of a hard-coded 4.
            samples.append((x_mat, y_vec.reshape(1, n_features)))

        return samples

    def _flatten_features(self, x_list, y_list):
        """
        x_list: list of arrays, each of shape (matrix_length, n_features)
        y_list: list of arrays, each of shape (1, n_features)

        Returns:
            x_flat: np.array of shape (N, matrix_length * n_features)
            y_flat: np.array of shape (N, n_features)
        """
        x_flat = np.array([x_mat.flatten() for x_mat in x_list])
        y_flat = np.array([y_mat.flatten() for y_mat in y_list])
        logger.debug("Flattened features to shapes: %s (X), %s (y).", x_flat.shape, y_flat.shape)
        return x_flat, y_flat
|