nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +1 -1
- nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
- nat/agent/auto_memory_wrapper/agent.py +278 -0
- nat/agent/auto_memory_wrapper/register.py +227 -0
- nat/agent/auto_memory_wrapper/state.py +30 -0
- nat/agent/base.py +1 -1
- nat/agent/dual_node.py +1 -1
- nat/agent/prompt_optimizer/prompt.py +1 -1
- nat/agent/prompt_optimizer/register.py +1 -1
- nat/agent/react_agent/agent.py +16 -9
- nat/agent/react_agent/output_parser.py +2 -2
- nat/agent/react_agent/prompt.py +3 -2
- nat/agent/react_agent/register.py +2 -2
- nat/agent/react_agent/register_per_user_agent.py +104 -0
- nat/agent/reasoning_agent/reasoning_agent.py +1 -1
- nat/agent/register.py +3 -1
- nat/agent/responses_api_agent/__init__.py +1 -1
- nat/agent/responses_api_agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +9 -4
- nat/agent/rewoo_agent/prompt.py +1 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +5 -4
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/__init__.py +1 -1
- nat/authentication/api_key/__init__.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
- nat/authentication/api_key/register.py +1 -1
- nat/authentication/credential_validator/__init__.py +1 -1
- nat/authentication/credential_validator/bearer_token_validator.py +1 -1
- nat/authentication/exceptions/__init__.py +1 -1
- nat/authentication/exceptions/api_key_exceptions.py +1 -1
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/http_basic_auth/register.py +1 -1
- nat/authentication/interfaces.py +1 -1
- nat/authentication/oauth2/__init__.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
- nat/authentication/oauth2/register.py +1 -1
- nat/authentication/register.py +1 -1
- nat/builder/builder.py +563 -1
- nat/builder/child_builder.py +385 -0
- nat/builder/component_utils.py +34 -4
- nat/builder/context.py +34 -1
- nat/builder/embedder.py +1 -1
- nat/builder/eval_builder.py +19 -7
- nat/builder/evaluator.py +1 -1
- nat/builder/framework_enum.py +3 -1
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +113 -5
- nat/builder/function_base.py +1 -1
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +1 -1
- nat/builder/llm.py +1 -1
- nat/builder/per_user_workflow_builder.py +843 -0
- nat/builder/retriever.py +1 -1
- nat/builder/sync_builder.py +571 -0
- nat/builder/user_interaction_manager.py +1 -1
- nat/builder/workflow.py +5 -3
- nat/builder/workflow_builder.py +619 -378
- nat/cli/__init__.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/cli_utils/validation.py +32 -1
- nat/cli/commands/configure/channel/add.py +1 -1
- nat/cli/commands/configure/channel/channel.py +1 -1
- nat/cli/commands/configure/channel/remove.py +1 -1
- nat/cli/commands/configure/channel/update.py +1 -1
- nat/cli/commands/configure/configure.py +1 -1
- nat/cli/commands/evaluate.py +87 -13
- nat/cli/commands/finetune.py +132 -0
- nat/cli/commands/info/__init__.py +1 -1
- nat/cli/commands/info/info.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +1 -1
- nat/cli/commands/object_store/__init__.py +1 -1
- nat/cli/commands/object_store/object_store.py +1 -1
- nat/cli/commands/optimize.py +1 -1
- nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
- nat/cli/commands/red_teaming/red_teaming.py +138 -0
- nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
- nat/cli/commands/registry/__init__.py +1 -1
- nat/cli/commands/registry/publish.py +1 -1
- nat/cli/commands/registry/pull.py +1 -1
- nat/cli/commands/registry/registry.py +1 -1
- nat/cli/commands/registry/remove.py +1 -1
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/sizing/__init__.py +1 -1
- nat/cli/commands/sizing/calc.py +1 -1
- nat/cli/commands/sizing/sizing.py +1 -1
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/validate.py +1 -1
- nat/cli/commands/workflow/__init__.py +1 -1
- nat/cli/commands/workflow/workflow.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/entrypoint.py +15 -37
- nat/cli/main.py +2 -2
- nat/cli/plugin_loader.py +69 -0
- nat/cli/register_workflow.py +233 -5
- nat/cli/type_registry.py +237 -3
- nat/control_flow/register.py +1 -1
- nat/control_flow/router_agent/agent.py +1 -1
- nat/control_flow/router_agent/prompt.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +28 -7
- nat/data_models/__init__.py +1 -1
- nat/data_models/agent.py +1 -1
- nat/data_models/api_server.py +38 -3
- nat/data_models/authentication.py +1 -1
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +9 -1
- nat/data_models/component_ref.py +45 -1
- nat/data_models/config.py +78 -1
- nat/data_models/dataset_handler.py +15 -2
- nat/data_models/discovery_metadata.py +1 -1
- nat/data_models/embedder.py +1 -1
- nat/data_models/evaluate.py +6 -1
- nat/data_models/evaluator.py +1 -1
- nat/data_models/finetuning.py +260 -0
- nat/data_models/front_end.py +1 -1
- nat/data_models/function.py +15 -2
- nat/data_models/function_dependencies.py +1 -1
- nat/data_models/gated_field_mixin.py +1 -1
- nat/data_models/interactive.py +1 -1
- nat/data_models/intermediate_step.py +29 -2
- nat/data_models/invocation_node.py +1 -1
- nat/data_models/llm.py +1 -1
- nat/data_models/logging.py +1 -1
- nat/data_models/memory.py +1 -1
- nat/data_models/middleware.py +37 -0
- nat/data_models/object_store.py +1 -1
- nat/data_models/openai_mcp.py +1 -1
- nat/data_models/optimizable.py +1 -1
- nat/data_models/optimizer.py +1 -1
- nat/data_models/profiler.py +1 -1
- nat/data_models/registry_handler.py +1 -1
- nat/data_models/retriever.py +1 -1
- nat/data_models/retry_mixin.py +1 -1
- nat/data_models/runtime_enum.py +26 -0
- nat/data_models/span.py +1 -1
- nat/data_models/step_adaptor.py +1 -1
- nat/data_models/streaming.py +1 -1
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/telemetry_exporter.py +1 -1
- nat/data_models/thinking_mixin.py +1 -1
- nat/data_models/ttc_strategy.py +1 -1
- nat/embedder/azure_openai_embedder.py +1 -1
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +1 -1
- nat/eval/__init__.py +1 -1
- nat/eval/config.py +8 -1
- nat/eval/dataset_handler/dataset_downloader.py +1 -1
- nat/eval/dataset_handler/dataset_filter.py +1 -1
- nat/eval/dataset_handler/dataset_handler.py +4 -2
- nat/eval/evaluate.py +226 -81
- nat/eval/evaluator/__init__.py +1 -1
- nat/eval/evaluator/base_evaluator.py +2 -2
- nat/eval/evaluator/evaluator_model.py +3 -2
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/llm_validator.py +336 -0
- nat/eval/rag_evaluator/evaluate.py +17 -10
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/red_teaming_evaluator/__init__.py +14 -0
- nat/eval/red_teaming_evaluator/data_models.py +66 -0
- nat/eval/red_teaming_evaluator/evaluate.py +327 -0
- nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
- nat/eval/red_teaming_evaluator/register.py +55 -0
- nat/eval/register.py +2 -1
- nat/eval/remote_workflow.py +1 -1
- nat/eval/runners/__init__.py +1 -1
- nat/eval/runners/config.py +1 -1
- nat/eval/runners/multi_eval_runner.py +1 -1
- nat/eval/runners/red_teaming_runner/__init__.py +24 -0
- nat/eval/runners/red_teaming_runner/config.py +282 -0
- nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
- nat/eval/runners/red_teaming_runner/runner.py +867 -0
- nat/eval/runtime_evaluator/__init__.py +1 -1
- nat/eval/runtime_evaluator/evaluate.py +1 -1
- nat/eval/runtime_evaluator/register.py +1 -1
- nat/eval/runtime_event_subscriber.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/swe_bench_evaluator/register.py +1 -1
- nat/eval/trajectory_evaluator/evaluate.py +2 -2
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
- nat/eval/tunable_rag_evaluator/register.py +1 -1
- nat/eval/usage_stats.py +1 -1
- nat/eval/utils/eval_trace_ctx.py +1 -1
- nat/eval/utils/output_uploader.py +1 -1
- nat/eval/utils/tqdm_position_registry.py +1 -1
- nat/eval/utils/weave_eval.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -1
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/editor_config.py +1 -1
- nat/experimental/test_time_compute/models/scoring_config.py +1 -1
- nat/experimental/test_time_compute/models/search_config.py +20 -2
- nat/experimental/test_time_compute/models/selection_config.py +33 -2
- nat/experimental/test_time_compute/models/stage_enums.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +1 -1
- nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
- nat/experimental/test_time_compute/models/ttc_item.py +1 -1
- nat/experimental/test_time_compute/register.py +4 -1
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
- nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
- nat/finetuning/__init__.py +24 -0
- nat/finetuning/finetuning_runtime.py +143 -0
- nat/finetuning/interfaces/__init__.py +24 -0
- nat/finetuning/interfaces/finetuning_runner.py +261 -0
- nat/finetuning/interfaces/trainer_adapter.py +103 -0
- nat/finetuning/interfaces/trajectory_builder.py +115 -0
- nat/finetuning/utils/__init__.py +15 -0
- nat/finetuning/utils/parsers/__init__.py +15 -0
- nat/finetuning/utils/parsers/adk_parser.py +141 -0
- nat/finetuning/utils/parsers/base_parser.py +238 -0
- nat/finetuning/utils/parsers/common.py +91 -0
- nat/finetuning/utils/parsers/langchain_parser.py +267 -0
- nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
- nat/front_ends/__init__.py +1 -1
- nat/front_ends/console/__init__.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +1 -1
- nat/front_ends/console/console_front_end_config.py +4 -1
- nat/front_ends/console/console_front_end_plugin.py +5 -4
- nat/front_ends/console/register.py +1 -1
- nat/front_ends/cron/__init__.py +1 -1
- nat/front_ends/fastapi/__init__.py +1 -1
- nat/front_ends/fastapi/async_job.py +128 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
- nat/front_ends/fastapi/dask_client_mixin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_config.py +23 -1
- nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +318 -59
- nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
- nat/front_ends/fastapi/job_store.py +23 -11
- nat/front_ends/fastapi/main.py +1 -1
- nat/front_ends/fastapi/message_handler.py +27 -4
- nat/front_ends/fastapi/message_validator.py +54 -2
- nat/front_ends/fastapi/register.py +1 -1
- nat/front_ends/fastapi/response_helpers.py +16 -15
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/fastapi/utils.py +1 -1
- nat/front_ends/register.py +1 -2
- nat/front_ends/simple_base/__init__.py +1 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
- nat/llm/aws_bedrock_llm.py +1 -1
- nat/llm/azure_openai_llm.py +10 -1
- nat/llm/dynamo_llm.py +363 -0
- nat/llm/huggingface_llm.py +177 -0
- nat/llm/litellm_llm.py +1 -1
- nat/llm/nim_llm.py +1 -1
- nat/llm/openai_llm.py +1 -1
- nat/llm/register.py +3 -1
- nat/llm/utils/__init__.py +1 -1
- nat/llm/utils/env_config_value.py +1 -1
- nat/llm/utils/error.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/memory/__init__.py +1 -1
- nat/memory/interfaces.py +1 -1
- nat/memory/models.py +1 -1
- nat/meta/pypi.md +1 -1
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache/__init__.py +14 -0
- nat/middleware/cache/cache_middleware.py +253 -0
- nat/middleware/cache/cache_middleware_config.py +44 -0
- nat/middleware/cache/register.py +33 -0
- nat/middleware/defense/__init__.py +14 -0
- nat/middleware/defense/defense_middleware.py +362 -0
- nat/middleware/defense/defense_middleware_content_guard.py +455 -0
- nat/middleware/defense/defense_middleware_data_models.py +91 -0
- nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
- nat/middleware/defense/defense_middleware_pii.py +356 -0
- nat/middleware/defense/register.py +82 -0
- nat/middleware/dynamic/__init__.py +14 -0
- nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
- nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
- nat/middleware/dynamic/register.py +34 -0
- nat/middleware/function_middleware.py +370 -0
- nat/middleware/logging/__init__.py +14 -0
- nat/middleware/logging/logging_middleware.py +67 -0
- nat/middleware/logging/logging_middleware_config.py +28 -0
- nat/middleware/logging/register.py +33 -0
- nat/middleware/middleware.py +298 -0
- nat/middleware/red_teaming/__init__.py +14 -0
- nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
- nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
- nat/middleware/red_teaming/register.py +47 -0
- nat/middleware/register.py +22 -0
- nat/middleware/utils/__init__.py +14 -0
- nat/middleware/utils/workflow_inventory.py +155 -0
- nat/object_store/__init__.py +1 -1
- nat/object_store/in_memory_object_store.py +1 -1
- nat/object_store/interfaces.py +1 -1
- nat/object_store/models.py +1 -1
- nat/object_store/register.py +1 -1
- nat/observability/__init__.py +1 -1
- nat/observability/exporter/__init__.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +1 -1
- nat/observability/exporter/raw_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +7 -1
- nat/observability/exporter_manager.py +1 -1
- nat/observability/mixin/__init__.py +1 -1
- nat/observability/mixin/batch_config_mixin.py +1 -1
- nat/observability/mixin/collector_config_mixin.py +1 -1
- nat/observability/mixin/file_mixin.py +1 -1
- nat/observability/mixin/file_mode.py +1 -1
- nat/observability/mixin/redaction_config_mixin.py +1 -1
- nat/observability/mixin/resource_conflict_mixin.py +1 -1
- nat/observability/mixin/serialize_mixin.py +1 -1
- nat/observability/mixin/tagging_config_mixin.py +1 -1
- nat/observability/mixin/type_introspection_mixin.py +1 -1
- nat/observability/processor/__init__.py +1 -1
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/callback_processor.py +1 -1
- nat/observability/processor/falsy_batch_filter_processor.py +1 -1
- nat/observability/processor/intermediate_step_serializer.py +1 -1
- nat/observability/processor/processor.py +1 -1
- nat/observability/processor/processor_factory.py +1 -1
- nat/observability/processor/redaction/__init__.py +1 -1
- nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
- nat/observability/processor/redaction/redaction_processor.py +1 -1
- nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
- nat/observability/processor/span_tagging_processor.py +1 -1
- nat/observability/register.py +1 -1
- nat/observability/utils/__init__.py +1 -1
- nat/observability/utils/dict_utils.py +1 -1
- nat/observability/utils/time_utils.py +1 -1
- nat/profiler/calc/__init__.py +1 -1
- nat/profiler/calc/calc_runner.py +3 -3
- nat/profiler/calc/calculations.py +1 -1
- nat/profiler/calc/data_models.py +1 -1
- nat/profiler/calc/plot.py +30 -3
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/base_callback_class.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +33 -3
- nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/callbacks/token_usage_base_model.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/data_models.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +32 -1
- nat/profiler/decorators/function_tracking.py +1 -1
- nat/profiler/forecasting/config.py +1 -1
- nat/profiler/forecasting/model_trainer.py +1 -1
- nat/profiler/forecasting/models/__init__.py +1 -1
- nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_metrics_model.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +1 -1
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/inference_optimization/llm_metrics.py +1 -1
- nat/profiler/inference_optimization/prompt_caching.py +1 -1
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
- nat/profiler/intermediate_property_adapter.py +1 -1
- nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
- nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
- nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
- nat/profiler/parameter_optimization/parameter_selection.py +1 -1
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
- nat/profiler/parameter_optimization/update_helpers.py +1 -1
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +1 -1
- nat/registry_handlers/local/register_local.py +1 -1
- nat/registry_handlers/metadata_factory.py +1 -1
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +1 -1
- nat/registry_handlers/pypi/register_pypi.py +1 -1
- nat/registry_handlers/register.py +1 -1
- nat/registry_handlers/registry_handler_base.py +1 -1
- nat/registry_handlers/rest/register_rest.py +1 -1
- nat/registry_handlers/rest/rest_handler.py +1 -1
- nat/registry_handlers/schemas/headers.py +1 -1
- nat/registry_handlers/schemas/package.py +1 -1
- nat/registry_handlers/schemas/publish.py +1 -1
- nat/registry_handlers/schemas/pull.py +1 -1
- nat/registry_handlers/schemas/remove.py +1 -1
- nat/registry_handlers/schemas/search.py +1 -1
- nat/registry_handlers/schemas/status.py +1 -1
- nat/retriever/interface.py +1 -1
- nat/retriever/milvus/__init__.py +1 -1
- nat/retriever/milvus/register.py +12 -4
- nat/retriever/milvus/retriever.py +103 -41
- nat/retriever/models.py +1 -1
- nat/retriever/nemo_retriever/__init__.py +1 -1
- nat/retriever/nemo_retriever/register.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +5 -5
- nat/retriever/register.py +1 -1
- nat/runtime/__init__.py +1 -1
- nat/runtime/loader.py +10 -3
- nat/runtime/metrics.py +180 -0
- nat/runtime/runner.py +13 -6
- nat/runtime/session.py +458 -32
- nat/runtime/user_metadata.py +1 -1
- nat/settings/global_settings.py +1 -1
- nat/tool/chat_completion.py +1 -1
- nat/tool/code_execution/README.md +1 -1
- nat/tool/code_execution/code_sandbox.py +2 -2
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
- nat/tool/code_execution/local_sandbox/__init__.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
- nat/tool/code_execution/register.py +1 -1
- nat/tool/code_execution/utils.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +1 -1
- nat/tool/memory_tools/delete_memory_tool.py +1 -1
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +2 -2
- nat/tool/register.py +1 -1
- nat/tool/retriever.py +1 -1
- nat/tool/server_tools.py +1 -1
- nat/utils/__init__.py +8 -5
- nat/utils/callable_utils.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/debugging_utils.py +1 -1
- nat/utils/decorators.py +1 -1
- nat/utils/dump_distro_mapping.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +3 -3
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/model_processing.py +1 -1
- nat/utils/io/supress_logs.py +33 -0
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/log_levels.py +1 -1
- nat/utils/log_utils.py +13 -1
- nat/utils/metadata_utils.py +1 -1
- nat/utils/optional_imports.py +1 -1
- nat/utils/producer_consumer_queue.py +1 -1
- nat/utils/reactive/base/observable_base.py +1 -1
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/base/subject_base.py +1 -1
- nat/utils/reactive/observable.py +1 -1
- nat/utils/reactive/observer.py +1 -1
- nat/utils/reactive/subject.py +1 -1
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/responses_api.py +1 -1
- nat/utils/settings/global_settings.py +1 -1
- nat/utils/string_utils.py +1 -1
- nat/utils/type_converter.py +18 -5
- nat/utils/type_utils.py +1 -1
- nat/utils/url_utils.py +1 -1
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +46 -15
- nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
- nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
- nat/cli/commands/mcp/mcp.py +0 -986
- nat/front_ends/mcp/introspection_token_verifier.py +0 -73
- nat/front_ends/mcp/mcp_front_end_config.py +0 -109
- nat/front_ends/mcp/mcp_front_end_plugin.py +0 -151
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -362
- nat/front_ends/mcp/memory_profiler.py +0 -320
- nat/front_ends/mcp/register.py +0 -27
- nat/front_ends/mcp/tool_converter.py +0 -321
- nvidia_nat-1.4.0a20251112.dist-info/RECORD +0 -481
- nvidia_nat-1.4.0a20251112.dist-info/entry_points.txt +0 -22
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import re
|
|
18
|
+
|
|
19
|
+
from nat.builder.builder import Builder
|
|
20
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
21
|
+
from nat.cli.register_workflow import register_ttc_strategy
|
|
22
|
+
from nat.experimental.test_time_compute.models.selection_config import LLMJudgeSelectionConfig
|
|
23
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
24
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
25
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
26
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
27
|
+
from nat.utils.io.model_processing import remove_r1_think_tags
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)


class LLMJudgeSelection(StrategyBase):
    """
    A selection strategy that uses a configured Judge LLM to select the best response.

    Candidate outputs are numbered starting at 1 and rendered, together with the
    original prompt, into ``config.selection_template``. The judge is expected to
    answer with ``SELECTED ITEM: <number>``. The matching item is returned. If the
    answer cannot be parsed or names an out-of-range item, the first candidate is
    returned as a fallback.
    """

    def __init__(self, config: LLMJudgeSelectionConfig) -> None:
        super().__init__(config)
        self.config = config
        # Set by ``build_components``; ``ainvoke`` refuses to run while this is None.
        self.judge_llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Builds the Judge LLM configured in the strategy.

        Args:
            builder: Builder used to resolve ``config.judge_llm`` as a LangChain-wrapped LLM.
        """
        logger.debug("Building components for LLMJudgeSelection")
        self.judge_llm_bound = await builder.get_llm(self.config.judge_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        """Return the pipeline types in which this selector may be used."""
        return [PipelineTypeEnum.CUSTOM, PipelineTypeEnum.PLANNING, PipelineTypeEnum.AGENT_EXECUTION]

    def stage_type(self) -> StageTypeEnum:
        """Return the pipeline stage implemented by this strategy (selection)."""
        return StageTypeEnum.SELECTION

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Select the best item using the configured Judge LLM.

        Args:
            items: Candidate items; each item's ``output`` is shown to the judge.
            original_prompt: Query the candidates answer. When empty, ``items[0].input``
                is used instead, or ``"Unknown Query"`` if that is also empty.
            agent_context: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A one-element list containing the chosen item (or the first item as a
            fallback), or ``[]`` when ``items`` is empty. When possible, the judge's
            reasoning is stored under ``metadata["judge_reasoning"]`` of the chosen item.

        Raises:
            ValueError: If ``build_components`` has not been called.
            ImportError: If ``langchain-core`` is not installed.
        """
        if self.judge_llm_bound is None:
            raise ValueError("Judge LLM not bound. Ensure `build_components` has been called.")

        if not items:
            logger.warning("No items provided for selection.")
            return []

        try:
            from langchain_core.prompts import PromptTemplate
            from pydantic import BaseModel
        except ImportError as exc:
            raise ImportError("langchain-core is not installed.") from exc

        # Format the results for the prompt; numbering starts at 1 so the judge's
        # answer maps back to ``items[number - 1]``.
        rendered_items = []
        for number, item in enumerate(items, start=1):
            item_output = (str(item.output.model_dump()) if isinstance(item.output, BaseModel) else str(item.output))
            rendered_items.append(f"{number}. {remove_r1_think_tags(item_output)}\n\n")
        results_str = "".join(rendered_items)

        prompt_template = PromptTemplate(
            template=self.config.selection_template,
            input_variables=["original_prompt", "results"],
            validate_template=True,
        )

        # Use input from first item if original_prompt is missing
        query = original_prompt if original_prompt else (items[0].input or "Unknown Query")

        prompt = (await prompt_template.ainvoke(input={"original_prompt": query, "results": results_str})).to_string()

        logger.info("Asking Judge LLM to select the best response.")
        judge_response = await self.judge_llm_bound.ainvoke(prompt)
        raw_content = judge_response.content if hasattr(judge_response, 'content') else judge_response
        # Chat models may return non-string (structured) content; coerce to text so the
        # think-tag stripping and the regex below always operate on a string.
        judge_content = remove_r1_think_tags(raw_content if isinstance(raw_content, str) else str(raw_content))

        # Parse selection
        # Expected format: 'SELECTED ITEM: <number>'
        match = re.search(r'SELECTED ITEM:\s*(\d+)', judge_content, re.IGNORECASE)
        if match:
            try:
                index = int(match.group(1)) - 1
            except ValueError:
                logger.warning("Failed to parse integer from judge selection.")
            else:
                if 0 <= index < len(items):
                    logger.info("Judge selected item %d", index + 1)
                    selected_item = items[index]
                    self._attach_judge_reasoning(selected_item, judge_content)
                    return [selected_item]
                logger.warning("Judge selected index %d which is out of range.", index + 1)

        logger.warning("Could not parse valid selection from judge response. "
                       "Returning first item as fallback.")
        # Fallback to first item
        return [items[0]]

    @staticmethod
    def _attach_judge_reasoning(selected_item: TTCItem, judge_content: str) -> None:
        """Best-effort: record the judge's reasoning in the selected item's metadata."""
        if selected_item.metadata is None:
            selected_item.metadata = {}
        try:
            selected_item.metadata["judge_reasoning"] = judge_content
        except TypeError:
            # The metadata value does not support item assignment (e.g. a string or a
            # model object); keep the selection rather than failing the whole stage.
            logger.debug("Selected item's metadata of type %s does not support item assignment; "
                         "judge reasoning not attached.",
                         type(selected_item.metadata).__name__)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@register_ttc_strategy(config_type=LLMJudgeSelectionConfig)
async def register_llm_judge_selection(config: LLMJudgeSelectionConfig, builder: Builder):
    """
    Factory registered for ``LLMJudgeSelectionConfig``.

    Creates an ``LLMJudgeSelection`` from ``config``, resolves its judge LLM via
    ``builder``, and then yields the ready strategy.
    """
    judge_strategy = LLMJudgeSelection(config)
    # The judge LLM must be bound before the strategy is handed out.
    await judge_strategy.build_components(builder)
    yield judge_strategy
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
17
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
18
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
19
|
+
|
|
20
|
+
# Names re-exported by this package as its public finetuning interface API.
__all__ = ["Trainer", "TrajectoryBuilder", "TrainerAdapter"]
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""Finetuning runtime for NAT that orchestrates the training process."""
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import logging
|
|
20
|
+
|
|
21
|
+
from nat.data_models.finetuning import FinetuneRunConfig
|
|
22
|
+
from nat.data_models.finetuning import TrainingStatusEnum
|
|
23
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _log_job_statuses(job_statuses: list) -> None:
    """Log the final status, and any status message, of every training job."""
    for status in job_statuses:
        logger.info("Job %s completed with status: %s", status.run_id, status.status)
        if status.message:
            logger.info(" Message: %s", status.message)


async def _log_final_metrics(runner: Trainer, run_id: str) -> None:
    """Fetch and log metrics for ``run_id``; ValueError/RuntimeError during retrieval is logged, not raised."""
    try:
        metrics = await runner.get_metrics(run_id)
    except (ValueError, RuntimeError) as e:
        logger.warning("Failed to retrieve metrics: %s", e)
    else:
        logger.info("Final metrics: %s", metrics)


def _log_outcome_summary(job_statuses: list) -> None:
    """Log a single overall outcome line derived from the per-job statuses."""
    if not job_statuses:
        logger.warning("Finetuning completed with no jobs executed.")
        return

    total = len(job_statuses)
    failed_jobs = sum(1 for s in job_statuses if s.status == TrainingStatusEnum.FAILED)
    canceled_jobs = sum(1 for s in job_statuses if s.status == TrainingStatusEnum.CANCELED)
    completed_jobs = sum(1 for s in job_statuses if s.status == TrainingStatusEnum.COMPLETED)

    # Precedence: any failure outranks cancellation, which outranks success.
    if failed_jobs:
        logger.error("Finetuning completed with %d failed job(s) out of %d total.", failed_jobs, total)
    elif canceled_jobs:
        logger.warning("Finetuning was canceled. %d job(s) were canceled out of %d total.", canceled_jobs, total)
    elif completed_jobs == total:
        logger.info("Finetuning completed successfully!")
    else:
        # Remaining jobs are neither failed, canceled nor completed, i.e. still pending/running
        # (unexpected state once run() has returned).
        logger.warning("Finetuning finished with %d completed, %d pending/running job(s).",
                       completed_jobs,
                       total - completed_jobs)


async def run_finetuning(runner: Trainer) -> None:
    """
    Run finetuning with the given trainer, log the outcome and always clean up.

    Args:
        runner: A Trainer that has already been initialized; the number of epochs
            is read from ``runner.run_config.num_epochs``.

    Raises:
        Exception: Any error raised during training is logged and re-raised. If
            ``runner.cleanup()`` then fails as well, that cleanup error is only
            logged so it does not replace (mask) the original training error.
            A cleanup error after a successful run is still raised.
    """
    training_failed = False
    try:
        logger.info("Initializing finetuning runner...")

        num_epochs = runner.run_config.num_epochs

        logger.info("Starting training for %d epochs...", num_epochs)
        job_statuses = await runner.run(num_epochs)

        _log_job_statuses(job_statuses)
        if job_statuses:
            # Metrics are reported for the last job only.
            await _log_final_metrics(runner, job_statuses[-1].run_id)
        _log_outcome_summary(job_statuses)

    except Exception as e:
        training_failed = True
        # logger.error (not exception): the bare ``raise`` keeps the traceback for the caller.
        logger.error("Finetuning failed: %s", e)
        raise
    finally:
        # Always cleanup resources
        logger.info("Cleaning up finetuning resources...")
        try:
            await runner.cleanup()
        except Exception as cleanup_error:
            if not training_failed:
                # No earlier failure to preserve: surface the cleanup error.
                raise
            # Do not let a cleanup failure hide the training error that is propagating.
            logger.error("Cleanup also failed after finetuning error: %s", cleanup_error)
        else:
            logger.info("Cleanup completed")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
async def finetuning_main(run_config: FinetuneRunConfig) -> None:
    """
    Load the workflow configuration, build the finetuning components and train.

    The loaded config's ``finetuning`` section receives ``run_config`` as its
    ``run_configuration``. The trajectory builder, trainer adapter and trainer
    named in that section are then obtained from a ``WorkflowBuilder``. The
    trainer is initialized and handed to ``run_finetuning``.

    Args:
        run_config: FinetuneRunConfig object containing finetuning settings

    Raises:
        ValueError: If finetuning is not enabled in the loaded configuration.
    """
    # Imported lazily, as in the rest of this runtime entry point.
    from nat.builder.workflow_builder import WorkflowBuilder
    from nat.runtime.loader import load_config

    full_config = load_config(config_file=run_config.config_file)
    ft_config = full_config.finetuning
    ft_config.run_configuration = run_config

    if not ft_config.enabled:
        raise ValueError("Finetuning is not enabled in the provided configuration.")

    async with WorkflowBuilder.from_config(config=full_config) as workflow_builder:
        logger.info("Initializing finetuning components...")
        builder_name, adapter_name = ft_config.trajectory_builder, ft_config.trainer_adapter
        trajectory_builder = await workflow_builder.get_trajectory_builder(builder_name)
        trainer_adapter = await workflow_builder.get_trainer_adapter(adapter_name)
        logger.info("Finetuning components initialized.")

        trainer_name = ft_config.trainer
        trainer = await workflow_builder.get_trainer(
            trainer_name,
            trajectory_builder=trajectory_builder,
            trainer_adapter=trainer_adapter,
        )
        await trainer.initialize(run_config=ft_config)
        logger.info("Initialized trainer: %s", trainer_name)

        await run_finetuning(trainer)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def run_finetuning_sync(run_config: FinetuneRunConfig) -> None:
    """
    Run finetuning to completion from synchronous code.

    The async entry point ``finetuning_main`` is executed on a new event loop via
    ``asyncio.run``, so this cannot be called while another event loop is
    running in the same thread.

    Args:
        run_config: FinetuneRunConfig object containing finetuning settings
    """
    main_coroutine = finetuning_main(run_config)
    asyncio.run(main_coroutine)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
17
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
18
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
19
|
+
|
|
20
|
+
# Names re-exported by this package as its public finetuning interface API.
__all__ = ["Trainer", "TrajectoryBuilder", "TrainerAdapter"]
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from abc import ABC
|
|
18
|
+
from abc import abstractmethod
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from nat.data_models.finetuning import FinetuneConfig
|
|
22
|
+
from nat.data_models.finetuning import FinetuneRunConfig
|
|
23
|
+
from nat.data_models.finetuning import TrainerConfig
|
|
24
|
+
from nat.data_models.finetuning import TrainingJobRef
|
|
25
|
+
from nat.data_models.finetuning import TrainingJobStatus
|
|
26
|
+
from nat.data_models.finetuning import TrajectoryCollection
|
|
27
|
+
from nat.eval.config import EvaluationRunOutput
|
|
28
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
29
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Trainer(ABC):
    """
    Abstract interface for running finetuning workflows.

    The Trainer orchestrates the entire finetuning process by:
    1. Running evaluations to generate trajectories via TrajectoryBuilder
    2. Submitting trajectories for training via TrainerAdapter
    3. Managing multiple epochs of training

    ``bind_components`` must be called before ``initialize``, because
    ``initialize`` forwards the run configuration to both bound components.
    """

    def __init__(self, trainer_config: TrainerConfig, **kwargs) -> None:
        """
        Initialize the Trainer.

        Only ``trainer_config`` is stored here. The run configuration, the
        curriculum configuration/state and the components are populated later
        by ``initialize`` and ``bind_components``.

        Args:
            trainer_config: Configuration for the trainer backend.
            **kwargs: Accepted so subclasses can take extra arguments; ignored
                by this base class.
        """
        self.trainer_config = trainer_config
        self.run_config: FinetuneConfig | None = None
        self.curriculum_config = None
        self.trajectory_builder: TrajectoryBuilder | None = None
        self.trainer_adapter: TrainerAdapter | None = None

        # Curriculum learning state; becomes a dict once ``initialize`` has run.
        self._curriculum_state: dict[str, Any] | None = None

    async def bind_components(self, trajectory_builder: TrajectoryBuilder, trainer_adapter: TrainerAdapter) -> None:
        """
        Bind the TrajectoryBuilder and TrainerAdapter components.

        Args:
            trajectory_builder: Instance of TrajectoryBuilder
            trainer_adapter: Instance of TrainerAdapter
        """
        self.trajectory_builder = trajectory_builder
        self.trainer_adapter = trainer_adapter

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Initialize the runner and its components.

        The base implementation does the following, in order:
        - stores ``run_config``;
        - seeds the curriculum learning state from
          ``run_config.curriculum_learning``;
        - copies ``run_config.reward_function`` onto ``trainer_config.reward``;
        - initializes the bound TrajectoryBuilder and TrainerAdapter with
          ``run_config``.

        Overriding subclasses should also verify connectivity to their backend
        services.

        Args:
            run_config: The finetuning configuration for this run.
        """
        self.run_config = run_config
        self.curriculum_config = self.run_config.curriculum_learning
        self._curriculum_state = {
            "current_percentile": self.curriculum_config.initial_percentile,
            "last_expansion_epoch": -1,
            "total_groups": 0,
            "included_groups": set()
        }
        # The reward name is later used to pick reward scores out of evaluation results.
        self.trainer_config.reward = self.run_config.reward_function

        await self.trajectory_builder.initialize(run_config)
        await self.trainer_adapter.initialize(run_config)

    @abstractmethod
    async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef:
        """
        Run a single epoch of training.

        Args:
            epoch: The current epoch number (0-indexed)
            run_id: Unique identifier for this training run

        Returns:
            TrainingJobRef: Reference to the submitted training job
        """
        raise NotImplementedError

    @abstractmethod
    async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
        """
        Run the complete finetuning workflow for the specified number of epochs.

        Args:
            num_epochs: Number of epochs to train

        Returns:
            list[TrainingJobStatus]: Status of all training jobs
        """
        raise NotImplementedError

    @abstractmethod
    async def get_metrics(self, run_id: str) -> dict[str, Any]:
        """
        Get training metrics for a specific run.

        Args:
            run_id: The run identifier

        Returns:
            dict: Metrics from the training run
        """
        raise NotImplementedError

    @abstractmethod
    async def cleanup(self) -> None:
        """
        Clean up any resources used by the runner.
        """
        raise NotImplementedError

    @abstractmethod
    def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log training progress for monitoring.

        Args:
            epoch: Current epoch number
            metrics: Dictionary of metrics to log
            output_dir: Optional output directory override
        """
        raise NotImplementedError

    async def run_validation_evaluation(self, epoch: int, run_id: str) -> dict[str, Any]:
        """
        Run evaluation on the validation dataset to collect rewards.

        The bound TrajectoryBuilder is reused. Its run configuration is
        temporarily replaced by a FinetuneRunConfig that points at the
        validation dataset and, if set, the validation config file. The original
        configuration is restored afterwards, even if evaluation fails. No
        training is performed.

        Args:
            epoch: Current epoch number (0-indexed).
            run_id: Unique identifier for this training run (not used by the base
                implementation).

        Returns:
            dict: The metrics from ``_calculate_validation_metrics`` plus
            ``epoch`` and ``dataset_type``. On failure, the error is logged and
            a dict with ``error``, ``avg_reward`` 0.0 and ``num_examples`` 0 is
            returned instead of raising.
        """
        logger.info("Running validation evaluation for epoch %d", epoch + 1)

        run_configuration = self.run_config.run_configuration
        # Fall back to the training config file when no dedicated validation config is given.
        config = run_configuration.validation_config_file or run_configuration.config_file

        # Create a temporary run config with validation dataset
        validation_run_config = FinetuneRunConfig(config_file=config,
                                                  dataset=run_configuration.validation_dataset,
                                                  result_json_path=run_configuration.result_json_path,
                                                  endpoint=run_configuration.endpoint,
                                                  endpoint_timeout=run_configuration.endpoint_timeout,
                                                  override=run_configuration.override)

        # Reuse the existing trajectory builder; only its run configuration is swapped.
        validation_builder = self.trajectory_builder
        original_run_config = validation_builder.run_config.run_configuration

        try:
            validation_builder.run_config.run_configuration = validation_run_config

            eval_output = await validation_builder.run_eval()

            validation_metrics = self._calculate_validation_metrics(eval_output)
            validation_metrics["epoch"] = epoch
            validation_metrics["dataset_type"] = "validation"

            # 1-based in the log, matching the "Running validation evaluation" message above.
            logger.info("Validation metrics for epoch %d: %s", epoch + 1, validation_metrics)
            return validation_metrics

        except Exception as e:
            # Validation is best-effort and the exception is swallowed here, so log the
            # full traceback rather than just the message.
            logger.exception("Error during validation evaluation: %s", e)
            return {"epoch": epoch, "dataset_type": "validation", "error": str(e), "avg_reward": 0.0, "num_examples": 0}
        finally:
            # Restore original run config
            validation_builder.run_config.run_configuration = original_run_config

    def _calculate_validation_metrics(self, eval_output: EvaluationRunOutput) -> dict[str, Any]:
        """
        Calculate reward statistics from an evaluation run.

        Scores are collected from every evaluation result whose name equals
        ``self.trainer_config.reward.name``. Subclasses can override this to add
        backend-specific metrics.

        Args:
            eval_output: Output from evaluation run

        Returns:
            dict: ``avg_reward``, ``min_reward``, ``max_reward`` and
            ``num_examples``. All are zero when no matching scores are found.
        """
        metrics = {"avg_reward": 0.0, "min_reward": 0.0, "max_reward": 0.0, "num_examples": 0}

        rewards = [
            reward_item.score for metric_name, metric_value in eval_output.evaluation_results
            if metric_name == self.trainer_config.reward.name for reward_item in metric_value.eval_output_items
        ]

        if rewards:
            metrics["avg_reward"] = sum(rewards) / len(rewards)
            metrics["min_reward"] = min(rewards)
            metrics["max_reward"] = max(rewards)
            metrics["num_examples"] = len(rewards)

        return metrics

    def apply_curriculum_learning(self, trajectory_collection: TrajectoryCollection,
                                  epoch: int) -> TrajectoryCollection:
        """
        Apply curriculum learning to filter trajectory groups based on difficulty.

        Backends that support curriculum learning must override this; the base
        implementation always raises.

        Args:
            trajectory_collection: Trajectories collected for the current epoch.
            epoch: Current epoch number.

        Returns:
            TrajectoryCollection: The filtered trajectories.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Curriculum learning not implemented for this backend.")

    def get_curriculum_state(self) -> dict[str, Any]:
        """
        Get the current state of curriculum learning.

        Must be called after ``initialize``, which creates the state.

        Returns:
            dict: Current curriculum state including percentile and group statistics
        """
        # Convert set to list for JSON serialization
        state = {
            "current_percentile": self._curriculum_state["current_percentile"],
            "last_expansion_epoch": self._curriculum_state["last_expansion_epoch"],
            "total_groups": self._curriculum_state["total_groups"],
            "included_groups": list(self._curriculum_state["included_groups"]),
            "config": self.curriculum_config.model_dump() if self.curriculum_config else None
        }
        return state
|