nvidia-nat 1.1.0a20251020__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +265 -0
- nat/agent/dual_node.py +72 -0
- nat/agent/prompt_optimizer/__init__.py +0 -0
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +394 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +168 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +227 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +593 -0
- nat/agent/rewoo_agent/prompt.py +107 -0
- nat/agent/rewoo_agent/register.py +175 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +246 -0
- nat/agent/tool_calling_agent/register.py +129 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +96 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +20 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +317 -0
- nat/builder/component_utils.py +320 -0
- nat/builder/context.py +321 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +166 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +25 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +714 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +625 -0
- nat/builder/intermediate_step_manager.py +206 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +160 -0
- nat/builder/workflow_builder.py +1365 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +47 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +128 -0
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +153 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +257 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
- nat/cli/commands/workflow/templates/register.py.j2 +4 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +403 -0
- nat/cli/entrypoint.py +141 -0
- nat/cli/main.py +60 -0
- nat/cli/register_workflow.py +522 -0
- nat/cli/type_registry.py +1069 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +843 -0
- nat/data_models/authentication.py +245 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +60 -0
- nat/data_models/component_ref.py +179 -0
- nat/data_models/config.py +434 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +130 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +64 -0
- nat/data_models/function_dependencies.py +80 -0
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +228 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +42 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +62 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +431 -0
- nat/eval/evaluate.py +565 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +58 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +26 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +193 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +154 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +67 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +35 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +285 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +108 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +602 -0
- nat/front_ends/fastapi/main.py +64 -0
- nat/front_ends/fastapi/message_handler.py +344 -0
- nat/front_ends/fastapi/message_validator.py +351 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +90 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +290 -0
- nat/front_ends/register.py +21 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +69 -0
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +58 -0
- nat/llm/openai_llm.py +54 -0
- nat/llm/register.py +27 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +93 -0
- nat/llm/utils/error.py +17 -0
- nat/llm/utils/thinking.py +215 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +19 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +550 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +308 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +496 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +308 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +74 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +114 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +626 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +297 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +180 -0
- nat/profiler/decorators/function_tracking.py +411 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +478 -0
- nat/profiler/utils.py +186 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +570 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +248 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +20 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +236 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +21 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +292 -0
- nat/runtime/session.py +223 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +329 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +77 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +82 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +66 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +31 -0
- nat/tool/retriever.py +95 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/decorators.py +210 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +342 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_levels.py +25 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +195 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +299 -0
- nat/utils/type_utils.py +488 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
- nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
- nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
- nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
import inspect
|
|
18
|
+
import uuid
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from typing import Any
|
|
21
|
+
from typing import TypeVar
|
|
22
|
+
from typing import cast
|
|
23
|
+
from typing import overload
|
|
24
|
+
|
|
25
|
+
from pydantic import BaseModel
|
|
26
|
+
|
|
27
|
+
from nat.builder.context import Context
|
|
28
|
+
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
29
|
+
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
30
|
+
from nat.data_models.intermediate_step import IntermediateStepType
|
|
31
|
+
from nat.data_models.intermediate_step import TraceMetadata
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# --- Helper function to recursively serialize any object into JSON-friendly data ---
|
|
35
|
+
def _serialize_data(obj: Any) -> Any:
    """Convert ``obj`` into a structure that can be passed to ``json.dumps(...)``.

    - Pydantic models are dumped with ``model_dump()`` and the resulting dict is itself
      converted recursively. A python-mode dump can still hold non-JSON values
      (e.g. datetimes, UUIDs, sets); these are converted too.
    - Dicts are rebuilt with ``str`` keys and converted values.
    - Lists, tuples, sets and frozensets become lists of converted items.
    - ``str``/``int``/``float``/``bool``/``None`` are returned unchanged.
    - Anything else falls back to ``str(obj)``.

    Args:
        obj: Arbitrary value to convert.

    Returns:
        A JSON-friendly representation of ``obj``.
    """
    if isinstance(obj, BaseModel):
        # model_dump() (python mode) does not guarantee JSON-compatible leaves, so run
        # the dumped dict through the same conversion instead of returning it as-is.
        return _serialize_data(obj.model_dump())

    if isinstance(obj, dict):
        return {str(k): _serialize_data(v) for k, v in obj.items()}
    if isinstance(obj, list | tuple | set | frozenset):
        return [_serialize_data(item) for item in obj]

    if isinstance(obj, str | int | float | bool | type(None)):
        return obj

    # Fallback: any other object is recorded by its string form.
    return str(obj)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _prepare_serialized_args_kwargs(*args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
    """Produce JSON-friendly versions of a call's positional and keyword arguments.

    Every positional argument and every keyword value is passed through
    ``_serialize_data``; keyword names are kept unchanged.

    Returns:
        A ``(serialized_args, serialized_kwargs)`` pair: a list in positional order and
        a dict keyed by the original keyword names.
    """
    positional = list(map(_serialize_data, args))
    keyword = {name: _serialize_data(value) for name, value in kwargs.items()}
    return positional, keyword
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def push_intermediate_step(step_manager: IntermediateStepManager,
                           identifier: str,
                           function_name: str,
                           event_type: IntermediateStepType,
                           args: Any = None,
                           kwargs: Any = None,
                           output: Any = None,
                           metadata: dict[str, Any] | None = None) -> None:
    """Publish a single tracking event to the NAT Event Stream.

    Args:
        step_manager: Manager whose ``push_intermediate_step`` receives the payload.
        identifier: Value stored as the payload's ``UUID``.
        function_name: Name recorded on the payload.
        event_type: The intermediate step type to record (e.g. span start/chunk/end).
        args: Positional inputs; stored as the first element of ``span_inputs``.
        kwargs: Keyword inputs; stored as the second element of ``span_inputs``.
        output: Value recorded as ``span_outputs``.
        metadata: Caller-supplied metadata recorded as ``provided_metadata``.
    """
    trace = TraceMetadata(
        span_inputs=[args, kwargs],
        span_outputs=output,
        provided_metadata=metadata,
    )
    step_payload = IntermediateStepPayload(
        UUID=identifier,
        event_type=event_type,
        name=function_name,
        metadata=trace,
    )
    step_manager.push_intermediate_step(step_payload)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Type variable used by the track_function overloads below: the decorated callable's
# exact type is returned, so static type checkers see the same signature after decoration.
F = TypeVar('F', bound=Callable[..., Any])
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Overloads for the two ways the decorator can be applied: bare / direct call, or with keyword arguments
|
|
87
|
+
@overload
def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
    """Typing overload for when a function is passed directly.

    Covers bare ``@track_function`` usage and calls such as ``track_function(fn)``;
    the returned wrapper is typed exactly like ``func``.
    """
    ...
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@overload
def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
    """Typing overload for decorator-factory usage (called with parentheses).

    Covers ``@track_function(metadata={...})`` and ``@track_function()``; the result is a
    decorator that returns a wrapper typed exactly like the function it is applied to.
    """
    ...
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
|
|
100
|
+
"""
|
|
101
|
+
Decorator that can wrap any type of function (sync, async, generator,
|
|
102
|
+
async generator) and executes "tracking logic" around it.
|
|
103
|
+
|
|
104
|
+
- If the function is async, it will be wrapped in an async function.
|
|
105
|
+
- If the function is a generator, it will be wrapped in a generator function.
|
|
106
|
+
- If the function is an async generator, it will be wrapped in an async generator function.
|
|
107
|
+
- If the function is sync, it will be wrapped in a sync function.
|
|
108
|
+
"""
|
|
109
|
+
function_name: str = func.__name__ if func else "<unknown_function>"
|
|
110
|
+
|
|
111
|
+
# If called as @track_function(...) but not immediately passed a function
|
|
112
|
+
if func is None:
|
|
113
|
+
|
|
114
|
+
def decorator_wrapper(actual_func):
|
|
115
|
+
return track_function(actual_func, metadata=metadata)
|
|
116
|
+
|
|
117
|
+
return decorator_wrapper
|
|
118
|
+
|
|
119
|
+
# --- Validate metadata ---
|
|
120
|
+
if metadata is not None:
|
|
121
|
+
if not isinstance(metadata, dict):
|
|
122
|
+
raise TypeError("metadata must be a dict[str, Any].")
|
|
123
|
+
if any(not isinstance(k, str) for k in metadata.keys()):
|
|
124
|
+
raise TypeError("All metadata keys must be strings.")
|
|
125
|
+
|
|
126
|
+
# --- Now detect the function type and wrap accordingly ---
|
|
127
|
+
if inspect.isasyncgenfunction(func):
|
|
128
|
+
# ---------------------
|
|
129
|
+
# ASYNC GENERATOR
|
|
130
|
+
# ---------------------
|
|
131
|
+
|
|
132
|
+
@functools.wraps(func)
|
|
133
|
+
async def async_gen_wrapper(*args, **kwargs):
|
|
134
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
135
|
+
# 1) Serialize input
|
|
136
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
137
|
+
|
|
138
|
+
invocation_id = str(uuid.uuid4())
|
|
139
|
+
push_intermediate_step(step_manager,
|
|
140
|
+
invocation_id,
|
|
141
|
+
function_name,
|
|
142
|
+
IntermediateStepType.SPAN_START,
|
|
143
|
+
args=serialized_args,
|
|
144
|
+
kwargs=serialized_kwargs,
|
|
145
|
+
metadata=metadata)
|
|
146
|
+
|
|
147
|
+
# 2) Call the original async generator
|
|
148
|
+
async for item in func(*args, **kwargs):
|
|
149
|
+
# 3) Serialize the yielded item before yielding it
|
|
150
|
+
serialized_item = _serialize_data(item)
|
|
151
|
+
push_intermediate_step(step_manager,
|
|
152
|
+
invocation_id,
|
|
153
|
+
function_name,
|
|
154
|
+
IntermediateStepType.SPAN_CHUNK,
|
|
155
|
+
args=serialized_args,
|
|
156
|
+
kwargs=serialized_kwargs,
|
|
157
|
+
output=serialized_item,
|
|
158
|
+
metadata=metadata)
|
|
159
|
+
yield item # yield the original item
|
|
160
|
+
|
|
161
|
+
push_intermediate_step(step_manager,
|
|
162
|
+
invocation_id,
|
|
163
|
+
function_name,
|
|
164
|
+
IntermediateStepType.SPAN_END,
|
|
165
|
+
args=serialized_args,
|
|
166
|
+
kwargs=serialized_kwargs,
|
|
167
|
+
output=None,
|
|
168
|
+
metadata=metadata)
|
|
169
|
+
|
|
170
|
+
# 4) Post-yield logic if any
|
|
171
|
+
|
|
172
|
+
return async_gen_wrapper
|
|
173
|
+
|
|
174
|
+
if inspect.iscoroutinefunction(func):
|
|
175
|
+
# ---------------------
|
|
176
|
+
# ASYNC FUNCTION
|
|
177
|
+
# ---------------------
|
|
178
|
+
@functools.wraps(func)
|
|
179
|
+
async def async_wrapper(*args, **kwargs):
|
|
180
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
181
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
182
|
+
invocation_id = str(uuid.uuid4())
|
|
183
|
+
push_intermediate_step(step_manager,
|
|
184
|
+
invocation_id,
|
|
185
|
+
function_name,
|
|
186
|
+
IntermediateStepType.SPAN_START,
|
|
187
|
+
args=serialized_args,
|
|
188
|
+
kwargs=serialized_kwargs,
|
|
189
|
+
metadata=metadata)
|
|
190
|
+
|
|
191
|
+
result = await func(*args, **kwargs)
|
|
192
|
+
|
|
193
|
+
serialized_result = _serialize_data(result)
|
|
194
|
+
push_intermediate_step(step_manager,
|
|
195
|
+
invocation_id,
|
|
196
|
+
function_name,
|
|
197
|
+
IntermediateStepType.SPAN_END,
|
|
198
|
+
args=serialized_args,
|
|
199
|
+
kwargs=serialized_kwargs,
|
|
200
|
+
output=serialized_result,
|
|
201
|
+
metadata=metadata)
|
|
202
|
+
|
|
203
|
+
return result
|
|
204
|
+
|
|
205
|
+
return async_wrapper
|
|
206
|
+
|
|
207
|
+
if inspect.isgeneratorfunction(func):
|
|
208
|
+
# ---------------------
|
|
209
|
+
# SYNC GENERATOR
|
|
210
|
+
# ---------------------
|
|
211
|
+
@functools.wraps(func)
|
|
212
|
+
def sync_gen_wrapper(*args, **kwargs):
|
|
213
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
214
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
215
|
+
invocation_id = str(uuid.uuid4())
|
|
216
|
+
push_intermediate_step(step_manager,
|
|
217
|
+
invocation_id,
|
|
218
|
+
function_name,
|
|
219
|
+
IntermediateStepType.SPAN_START,
|
|
220
|
+
args=serialized_args,
|
|
221
|
+
kwargs=serialized_kwargs,
|
|
222
|
+
metadata=metadata)
|
|
223
|
+
|
|
224
|
+
for item in func(*args, **kwargs):
|
|
225
|
+
serialized_item = _serialize_data(item)
|
|
226
|
+
push_intermediate_step(step_manager,
|
|
227
|
+
invocation_id,
|
|
228
|
+
function_name,
|
|
229
|
+
IntermediateStepType.SPAN_CHUNK,
|
|
230
|
+
args=serialized_args,
|
|
231
|
+
kwargs=serialized_kwargs,
|
|
232
|
+
output=serialized_item,
|
|
233
|
+
metadata=metadata)
|
|
234
|
+
|
|
235
|
+
yield item # yield the original item
|
|
236
|
+
|
|
237
|
+
push_intermediate_step(step_manager,
|
|
238
|
+
invocation_id,
|
|
239
|
+
function_name,
|
|
240
|
+
IntermediateStepType.SPAN_END,
|
|
241
|
+
args=serialized_args,
|
|
242
|
+
kwargs=serialized_kwargs,
|
|
243
|
+
output=None,
|
|
244
|
+
metadata=metadata)
|
|
245
|
+
|
|
246
|
+
return sync_gen_wrapper
|
|
247
|
+
|
|
248
|
+
@functools.wraps(func)
|
|
249
|
+
def sync_wrapper(*args, **kwargs):
|
|
250
|
+
step_manager: IntermediateStepManager = Context.get().intermediate_step_manager
|
|
251
|
+
serialized_args, serialized_kwargs = _prepare_serialized_args_kwargs(*args, **kwargs)
|
|
252
|
+
invocation_id = str(uuid.uuid4())
|
|
253
|
+
push_intermediate_step(step_manager,
|
|
254
|
+
invocation_id,
|
|
255
|
+
function_name,
|
|
256
|
+
IntermediateStepType.SPAN_START,
|
|
257
|
+
args=serialized_args,
|
|
258
|
+
kwargs=serialized_kwargs,
|
|
259
|
+
metadata=metadata)
|
|
260
|
+
|
|
261
|
+
result = func(*args, **kwargs)
|
|
262
|
+
|
|
263
|
+
serialized_result = _serialize_data(result)
|
|
264
|
+
push_intermediate_step(step_manager,
|
|
265
|
+
invocation_id,
|
|
266
|
+
function_name,
|
|
267
|
+
IntermediateStepType.SPAN_END,
|
|
268
|
+
args=serialized_args,
|
|
269
|
+
kwargs=serialized_kwargs,
|
|
270
|
+
output=serialized_result,
|
|
271
|
+
metadata=metadata)
|
|
272
|
+
|
|
273
|
+
return result
|
|
274
|
+
|
|
275
|
+
return sync_wrapper
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# Overloads for track_unregistered_function
@overload
def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
    """Overload for when a function is passed directly."""
    ...


@overload
def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
    """Overload for decorator factory usage (when called with parentheses)."""
    ...


def track_unregistered_function(func: Callable[..., Any] | None = None,
                                *,
                                name: str | None = None,
                                metadata: dict[str, Any] | None = None) -> Callable[..., Any]:
    """
    Decorate an arbitrary function so each call runs inside an active-function scope.

    Every call opens ``Context.get().push_active_function(...)`` with the
    tracked name, the call's inputs and the metadata, and reports the result
    through the scope's ``set_output``. No intermediate steps are pushed here
    directly; tracking is left to ``push_active_function``.

    Usable bare (``@track_unregistered_function``) or as a factory
    (``@track_unregistered_function(name=..., metadata=...)``). Plain
    functions, coroutine functions, generator functions and async generator
    functions are supported.

    For the generator kinds:
    - each yielded item is passed through unchanged;
    - the list of all yielded items is reported as the output once the
      generator is exhausted;
    - a generator that is closed early, or that raises, reports no output.

    Args:
        func: The function to wrap. Supplied automatically when the decorator
            is used without parentheses; ``None`` for the factory form.
        name: Name to track the function under. Falls back to
            ``func.__name__`` when falsy.
        metadata: Additional metadata attached to every tracked call.

    Returns:
        The tracking wrapper. When ``func`` is ``None``, returns a decorator
        that produces the wrapper instead.

    Raises:
        TypeError: If ``metadata`` is not a dict or has non-string keys.
    """

    # Factory form: @track_unregistered_function(name="...", metadata={...})
    if func is None:

        def decorator_wrapper(actual_func: Callable[..., Any]) -> Callable[..., Any]:
            # Re-enter with the concrete function so validation and wrapping
            # live in a single code path.
            wrapped = track_unregistered_function(actual_func, name=name, metadata=metadata)
            return cast(Callable[..., Any], wrapped)

        return decorator_wrapper

    # Bare form (or the re-entry above): wrap ``func`` immediately.
    function_name: str = name or func.__name__

    # --- Validate metadata ---
    if metadata is not None:
        if not isinstance(metadata, dict):
            raise TypeError("metadata must be a dict[str, Any].")
        if not all(isinstance(key, str) for key in metadata.keys()):
            raise TypeError("All metadata keys must be strings.")

    trace_metadata = TraceMetadata(provided_metadata=metadata)

    def _open_scope(call_args: tuple, call_kwargs: dict):
        # Build the active-function context manager for one call. The input
        # payload spreads the positional args and appends the kwargs dict as
        # its final element.
        context = Context.get()
        return context.push_active_function(function_name,
                                            input_data=(*call_args, call_kwargs),
                                            metadata=trace_metadata)

    # --- Pick the wrapper matching the function kind ---
    if inspect.isasyncgenfunction(func):

        @functools.wraps(func)
        async def async_gen_wrapper(*args, **kwargs):
            with _open_scope(args, kwargs) as scope:
                yielded = []
                async for chunk in func(*args, **kwargs):
                    yielded.append(chunk)
                    yield chunk
                scope.set_output(yielded)

        return async_gen_wrapper

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            with _open_scope(args, kwargs) as scope:
                outcome = await func(*args, **kwargs)
                scope.set_output(outcome)
                return outcome

        return async_wrapper

    if inspect.isgeneratorfunction(func):

        @functools.wraps(func)
        def sync_gen_wrapper(*args, **kwargs):
            with _open_scope(args, kwargs) as scope:
                yielded = []
                for chunk in func(*args, **kwargs):
                    yielded.append(chunk)
                    yield chunk
                scope.set_output(yielded)

        return sync_gen_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        with _open_scope(args, kwargs) as scope:
            outcome = func(*args, **kwargs)
            scope.set_output(outcome)
            return outcome

    return sync_wrapper
|
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# Package-wide defaults for the forecasting module.

# Model identifier used by ``ModelTrainer`` when no ``model_type`` is given;
# must be one of the identifiers accepted by ``create_model``.
DEFAULT_MODEL_TYPE: str = "randomforest"

# NOTE(review): no consumer of this constant is visible here; presumably a
# default length/window for forecasting inputs. Confirm against its callers.
DEFAULT_MATRIX_LENGTH: int = 10
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# forecasting/model_trainer.py
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
|
|
20
|
+
from nat.profiler.forecasting.config import DEFAULT_MODEL_TYPE
|
|
21
|
+
from nat.profiler.forecasting.models import ForecastingBaseModel
|
|
22
|
+
from nat.profiler.forecasting.models import LinearModel
|
|
23
|
+
from nat.profiler.forecasting.models import RandomForestModel
|
|
24
|
+
from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def create_model(model_type: str) -> ForecastingBaseModel:
    """
    Build a new forecasting model from its string identifier.

    Parameters
    ----------
    model_type: str
        Model identifier. Supported values:

        - ``"linear"``: returns a :class:`LinearModel`.
        - ``"randomforest"``: returns a :class:`RandomForestModel`.

        Matching is exact and case-sensitive.

    Returns
    -------
    ForecastingBaseModel
        A newly constructed model instance.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported identifiers.

    Notes
    -----
    To support another model class, add a branch below and list its
    identifier in this docstring.
    """
    if model_type == "linear":
        return LinearModel()
    if model_type == "randomforest":
        return RandomForestModel()

    # Unknown identifiers fail loudly rather than silently picking a default.
    raise ValueError(f"Unsupported model_type: {model_type}")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ModelTrainer:
    """
    Build a forecasting model of a chosen type and fit it on profiler stats.

    The model instance is created once, in the constructor, through
    :func:`create_model`. Every call to :meth:`train` fits that same instance
    and returns it. No preprocessing happens in this class; ``raw_stats`` is
    handed to the model's ``fit`` unchanged.

    Parameters
    ----------
    model_type: str, default = "randomforest"
        Identifier of the model to build: ``"linear"`` or ``"randomforest"``.
        An unsupported identifier makes the constructor raise ``ValueError``.
    """

    def __init__(self, model_type: str = DEFAULT_MODEL_TYPE):
        # Remember the identifier, then build the model eagerly so an invalid
        # type is reported at construction time rather than at train time.
        self.model_type = model_type
        self._model = create_model(self.model_type)

    def train(self, raw_stats: list[list[IntermediatePropertyAdaptor]]) -> ForecastingBaseModel:
        """
        Fit the trainer's model on ``raw_stats`` and return it.

        Parameters
        ----------
        raw_stats: list[list[IntermediatePropertyAdaptor]]
            Statistics gathered by the profiler, forwarded as-is to the
            model's ``fit`` method.

        Returns
        -------
        ForecastingBaseModel
            The trainer's own model instance after fitting. Repeated calls
            fit and return the same object.
        """
        model = self._model
        model.fit(raw_stats)
        return model
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# forecasting/models/__init__.py
|
|
17
|
+
|
|
18
|
+
from .forecasting_base_model import ForecastingBaseModel
|
|
19
|
+
from .linear_model import LinearModel
|
|
20
|
+
from .random_forest_regressor import RandomForestModel
|
|
21
|
+
|
|
22
|
+
# Names re-exported as the public API of ``forecasting.models``.
__all__ = [
    "ForecastingBaseModel",
    "LinearModel",
    "RandomForestModel",
]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# forecasting/models/forecasting_base_model.py
|
|
17
|
+
|
|
18
|
+
from abc import ABC
|
|
19
|
+
from abc import abstractmethod
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ForecastingBaseModel(ABC):
    """
    Common interface implemented by every forecasting model in this package.

    Subclasses must override both :meth:`fit` and :meth:`predict`; the base
    class itself cannot be instantiated.
    """

    @abstractmethod
    def fit(self, raw_stats):
        """
        Train or fine-tune the model on ``raw_stats``.

        NOTE(review): no type is declared here; ``ModelTrainer.train`` passes a
        ``list[list[IntermediatePropertyAdaptor]]``. Confirm for other callers.
        """

    @abstractmethod
    def predict(self, raw_stats) -> np.ndarray:
        """
        Predict with the trained model.

        Returns
        -------
        np.ndarray
            Array of shape ``(N, 4)``.
        """
|