nvidia-nat 1.1.0a20251020__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +265 -0
- nat/agent/dual_node.py +72 -0
- nat/agent/prompt_optimizer/__init__.py +0 -0
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +394 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +168 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +227 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +593 -0
- nat/agent/rewoo_agent/prompt.py +107 -0
- nat/agent/rewoo_agent/register.py +175 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +246 -0
- nat/agent/tool_calling_agent/register.py +129 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +96 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +20 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +317 -0
- nat/builder/component_utils.py +320 -0
- nat/builder/context.py +321 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +166 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +25 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +714 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +625 -0
- nat/builder/intermediate_step_manager.py +206 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +160 -0
- nat/builder/workflow_builder.py +1365 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +47 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +128 -0
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +153 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +257 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
- nat/cli/commands/workflow/templates/register.py.j2 +4 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +403 -0
- nat/cli/entrypoint.py +141 -0
- nat/cli/main.py +60 -0
- nat/cli/register_workflow.py +522 -0
- nat/cli/type_registry.py +1069 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +843 -0
- nat/data_models/authentication.py +245 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +60 -0
- nat/data_models/component_ref.py +179 -0
- nat/data_models/config.py +434 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +130 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +64 -0
- nat/data_models/function_dependencies.py +80 -0
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +228 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +42 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +62 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +431 -0
- nat/eval/evaluate.py +565 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +58 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +26 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +193 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +154 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +67 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +35 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +285 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +108 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +602 -0
- nat/front_ends/fastapi/main.py +64 -0
- nat/front_ends/fastapi/message_handler.py +344 -0
- nat/front_ends/fastapi/message_validator.py +351 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +90 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +290 -0
- nat/front_ends/register.py +21 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +69 -0
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +58 -0
- nat/llm/openai_llm.py +54 -0
- nat/llm/register.py +27 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +93 -0
- nat/llm/utils/error.py +17 -0
- nat/llm/utils/thinking.py +215 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +19 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +550 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +308 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +496 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +308 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +74 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +114 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +626 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +297 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +180 -0
- nat/profiler/decorators/function_tracking.py +411 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +478 -0
- nat/profiler/utils.py +186 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +570 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +248 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +20 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +236 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +21 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +292 -0
- nat/runtime/session.py +223 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +329 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +77 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +82 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +66 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +31 -0
- nat/tool/retriever.py +95 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/decorators.py +210 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +342 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_levels.py +25 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +195 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +299 -0
- nat/utils/type_utils.py +488 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
- nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
- nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
- nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import re
|
|
18
|
+
|
|
19
|
+
from nat.builder.builder import Builder
|
|
20
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
21
|
+
from nat.cli.register_workflow import register_ttc_strategy
|
|
22
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
23
|
+
from nat.experimental.test_time_compute.models.selection_config import LLMBasedPlanSelectionConfig
|
|
24
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
25
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
26
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
27
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
28
|
+
from nat.utils.io.model_processing import remove_r1_think_tags
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LLMBasedPlanSelector(StrategyBase):
    """
    Selection strategy that asks an LLM to choose one plan out of a list of candidate plans.

    The candidate plans are numbered starting at 1, rendered into ``config.selection_template`` and sent to the
    LLM configured as ``config.selection_llm``. The LLM response must begin with ``SELECTED PLAN: <number>``.
    """

    def __init__(self, config: TTCStrategyBaseConfig) -> None:
        super().__init__(config)
        # Populated by `build_components`; `ainvoke` rejects anything that is not a LangChain `BaseChatModel`.
        self.llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Build the components required for the selector.

        Args:
            builder (Builder): The builder used to resolve ``config.selection_llm`` into a LangChain-wrapped LLM.
        """
        self.llm_bound = await builder.get_llm(self.config.selection_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        """
        Return the pipeline types this strategy can be used in (planning only).
        """
        return [PipelineTypeEnum.PLANNING]

    def stage_type(self) -> StageTypeEnum:
        """
        Return the pipeline stage this strategy implements (selection).
        """
        return StageTypeEnum.SELECTION

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Select a single planning item by asking the selection LLM to choose among the candidates.

        Args:
            items (list[TTCItem]): The list of planning items to select from.
            original_prompt (str | None): The prompt the user provided the agent.
            agent_context (str | None): The context of the agent, if applicable.
            **kwargs: Accepted for interface compatibility; not used by this strategy.

        Returns:
            list[TTCItem]: A single-element list containing the selected planning item.

        Raises:
            ImportError: If langchain-core is not installed.
            ValueError: If the bound LLM is not a ``BaseChatModel``, if the LLM response does not start with
                ``SELECTED PLAN: <number>``, or if the number is not a valid 1-based index into ``items``.
        """

        try:
            from langchain_core.language_models import BaseChatModel
            from langchain_core.prompts import PromptTemplate
        except ImportError as e:
            raise ImportError("langchain-core is not installed. Please install it to use LLMBasedPlanSelector.\n"
                              "This error can be resolved by installing nvidia-nat-langchain.") from e

        if not isinstance(self.llm_bound, BaseChatModel):
            raise ValueError("The `selection_llm` must be an instance of `BaseChatModel`.")

        model: BaseChatModel = self.llm_bound

        # Present the plans as a 1-based numbered list; the LLM answers with one of these numbers.
        plans = "".join(f"{number}. {remove_r1_think_tags(item.plan)}\n" for number, item in enumerate(items, start=1))

        prompt_template = PromptTemplate(
            template=self.config.selection_template,
            input_variables=["original_prompt", "context", "plans"],
            validate_template=True,
        )

        prompt = (await prompt_template.ainvoke(input={
            "original_prompt": original_prompt, "context": agent_context, "plans": plans
        })).to_string()

        selected_plan_index = remove_r1_think_tags((await model.ainvoke(prompt)).content)

        # The model response is expected to start with 'SELECTED PLAN: {plan number}'.
        # Use a regex to extract the plan number from the response string.
        if not isinstance(selected_plan_index, str):
            logger.warning("Invalid response from LLM for selected plan index: %s.", selected_plan_index)
            raise ValueError("Unable to parse the selected plan index.")
        selected_plan_index = selected_plan_index.strip()
        match = re.match(r'^\s*SELECTED PLAN:\s+(\d+)', selected_plan_index)
        if not match:
            logger.warning("Could not parse the selected plan index from the response: %s.", selected_plan_index)
            raise ValueError("The response format for selecting the plan is incorrect.")

        # The LLM answers with a 1-based plan number; convert it to a 0-based list index.
        plan_number = int(match.group(1))
        selected_index = plan_number - 1
        if not 0 <= selected_index < len(items):
            logger.warning("Selected plan index %d is out of range for %d plan(s).", plan_number, len(items))
            raise ValueError(f"Selected plan index {plan_number} is out of range for {len(items)} plan(s). "
                             f"LLM response: {selected_plan_index}")

        # Return the selected planning item
        return [items[selected_index]]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@register_ttc_strategy(config_type=LLMBasedPlanSelectionConfig)
async def register_llm_based_plan_selection(config: LLMBasedPlanSelectionConfig, builder: Builder):
    """
    Register the LLMBasedPlanSelector with the provided configuration.

    Args:
        config (LLMBasedPlanSelectionConfig): The configuration for the selector.
        builder (Builder): The builder supplied by the caller; used to resolve the selector's LLM.

    Yields:
        LLMBasedPlanSelector: The selector with its components built.
    """
    selector = LLMBasedPlanSelector(config)
    # Build against the builder we were handed, not a freshly constructed `Builder()`: the selection LLM has to
    # be resolved from the builder that was passed to this registration function.
    await selector.build_components(builder)
    yield selector
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from nat.builder.builder import Builder
|
|
19
|
+
from nat.cli.register_workflow import register_ttc_strategy
|
|
20
|
+
from nat.experimental.test_time_compute.models.selection_config import ThresholdSelectionConfig
|
|
21
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
22
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
23
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
24
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ThresholdSelector(StrategyBase):
    """
    Selection strategy that keeps only the TTCItems whose ``score`` is at least
    ``config.threshold``. Items whose ``score`` is ``None`` are dropped.
    """

    async def build_components(self, builder: Builder) -> None:
        """Nothing to build; this method intentionally does no work."""
        return None

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        """Return the pipeline types this strategy reports: ``TOOL_USE`` only."""
        return [PipelineTypeEnum.TOOL_USE]

    def stage_type(self) -> StageTypeEnum:
        """Return the stage this strategy reports: ``SELECTION``."""
        return StageTypeEnum.SELECTION

    async def ainvoke(self,
                      items: list[TTCItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[TTCItem]:
        """
        Filter ``items`` by score.

        Args:
            items: Candidate items to filter.
            original_prompt: Not used by this strategy.
            agent_context: Not used by this strategy.
            **kwargs: Not used by this strategy.

        Returns:
            The items, in their original order, whose score is not ``None`` and is
            greater than or equal to ``self.config.threshold``.
        """
        cutoff = self.config.threshold

        kept: list[TTCItem] = []
        for candidate in items:
            # Unscored items can never satisfy the threshold
            if candidate.score is None:
                continue
            if candidate.score >= cutoff:
                kept.append(candidate)

        logger.info("ThresholdSelector: %d items => %d items (threshold=%.1f)", len(items), len(kept), cutoff)
        return kept
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@register_ttc_strategy(config_type=ThresholdSelectionConfig)
async def register_threshold_selector(config: ThresholdSelectionConfig, builder: Builder):
    """
    Register the ThresholdSelector with the provided configuration.

    ``builder`` is accepted to match the registration signature but is not used here.
    """
    yield ThresholdSelector(config)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
import secrets
|
|
19
|
+
import webbrowser
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from dataclasses import field
|
|
22
|
+
|
|
23
|
+
import click
|
|
24
|
+
import httpx
|
|
25
|
+
import pkce
|
|
26
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
27
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
28
|
+
from fastapi import FastAPI
|
|
29
|
+
from fastapi import Request
|
|
30
|
+
|
|
31
|
+
from nat.authentication.interfaces import FlowHandlerBase
|
|
32
|
+
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
|
33
|
+
from nat.data_models.authentication import AuthenticatedContext
|
|
34
|
+
from nat.data_models.authentication import AuthFlowType
|
|
35
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
36
|
+
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# --------------------------------------------------------------------------- #
|
|
42
|
+
# Helpers #
|
|
43
|
+
# --------------------------------------------------------------------------- #
|
|
44
|
+
@dataclass
|
|
45
|
+
class _FlowState:
|
|
46
|
+
future: asyncio.Future = field(default_factory=asyncio.Future, init=False)
|
|
47
|
+
challenge: str | None = None
|
|
48
|
+
verifier: str | None = None
|
|
49
|
+
token_url: str | None = None
|
|
50
|
+
use_pkce: bool | None = None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# --------------------------------------------------------------------------- #
|
|
54
|
+
# Main handler #
|
|
55
|
+
# --------------------------------------------------------------------------- #
|
|
56
|
+
class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
    """
    Authentication helper for CLI / console environments. Supports:

    • HTTP Basic (username/password)
    • OAuth 2 Authorization‑Code with optional PKCE
    """

    # ----------------------------- lifecycle ----------------------------- #
    def __init__(self) -> None:
        super().__init__()
        self._server_controller: _FastApiFrontEndController | None = None
        # Strong reference to the background server task (see _start_redirect_server)
        self._server_task: asyncio.Task | None = None
        self._redirect_app: FastAPI | None = None
        # In-flight OAuth2 flows keyed by their ``state`` parameter
        self._flows: dict[str, _FlowState] = {}
        self._active_flows = 0
        # Guards flow registration and redirect-server start/stop
        self._server_lock = asyncio.Lock()
        self._oauth_client: AsyncOAuth2Client | None = None

    # ----------------------------- public API ---------------------------- #
    async def authenticate(
        self,
        config: AuthProviderBaseConfig,
        method: AuthFlowType,
    ) -> AuthenticatedContext:
        """
        Run the authentication flow selected by ``method``.

        Args:
            config: Provider configuration. Must be an ``OAuth2AuthCodeFlowProviderConfig``
                when ``method`` is ``OAUTH2_AUTHORIZATION_CODE``.
            method: The flow to run.

        Returns:
            The ``AuthenticatedContext`` produced by the selected flow.

        Raises:
            ValueError: OAuth2 authorization-code flow requested with a config of another type.
            NotImplementedError: ``method`` is neither HTTP Basic nor OAuth2 authorization code.
        """
        if method == AuthFlowType.HTTP_BASIC:
            return self._handle_http_basic()
        if method == AuthFlowType.OAUTH2_AUTHORIZATION_CODE:
            if (not isinstance(config, OAuth2AuthCodeFlowProviderConfig)):
                raise ValueError("Requested OAuth2 Authorization Code Flow but passed invalid config")

            return await self._handle_oauth2_auth_code_flow(config)

        raise NotImplementedError(f"Auth method “{method}” not supported.")

    # --------------------- OAuth2 helper factories ----------------------- #
    def construct_oauth_client(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
        """
        Separated for easy overriding in tests (to inject ASGITransport).

        The created client is also stored on ``self._oauth_client``, which the redirect
        handler uses to exchange the authorization code for a token.

        Raises:
            RuntimeError: The client could not be constructed from ``cfg``.
        """
        try:
            client = AsyncOAuth2Client(
                client_id=cfg.client_id,
                client_secret=cfg.client_secret,
                redirect_uri=cfg.redirect_uri,
                scope=" ".join(cfg.scopes) if cfg.scopes else None,
                token_endpoint=cfg.token_url,
                token_endpoint_auth_method=cfg.token_endpoint_auth_method,
                code_challenge_method="S256" if cfg.use_pkce else None,
            )
            self._oauth_client = client
            return client
        except (OAuthError, ValueError, TypeError) as e:
            raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e

    def _create_authorization_url(self,
                                  client: AsyncOAuth2Client,
                                  config: OAuth2AuthCodeFlowProviderConfig,
                                  state: str,
                                  verifier: str | None = None,
                                  challenge: str | None = None) -> str:
        """
        Create OAuth authorization URL with proper error handling.

        Args:
            client: The OAuth2 client instance
            config: OAuth2 configuration
            state: OAuth state parameter
            verifier: PKCE verifier (if using PKCE)
            challenge: PKCE challenge (if using PKCE)

        Returns:
            The authorization URL

        Raises:
            RuntimeError: The client rejected the parameters (OAuth, value or type error).
        """
        try:
            auth_url, _ = client.create_authorization_url(
                config.authorization_url,
                state=state,
                code_verifier=verifier if config.use_pkce else None,
                code_challenge=challenge if config.use_pkce else None,
                **(config.authorization_kwargs or {})
            )
            return auth_url
        except (OAuthError, ValueError, TypeError) as e:
            raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e

    # --------------------------- HTTP Basic ------------------------------ #
    @staticmethod
    def _handle_http_basic() -> AuthenticatedContext:
        """
        Prompt for a username and password on the console and build HTTP Basic credentials.

        Returns:
            An ``AuthenticatedContext`` whose ``Authorization`` header carries the
            base64-encoded ``username:password`` pair and whose metadata holds the raw
            ``username`` and ``password``.
        """
        username = click.prompt("Username", type=str)
        password = click.prompt("Password", type=str, hide_input=True)

        import base64
        credentials = f"{username}:{password}"
        encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("ascii")

        return AuthenticatedContext(
            # RFC 7617: base64(user:pass) credentials use the "Basic" scheme, not "Bearer"
            headers={"Authorization": f"Basic {encoded_credentials}"},
            metadata={
                "username": username, "password": password
            },
        )

    # --------------------- OAuth2 Authorization‑Code --------------------- #
    async def _handle_oauth2_auth_code_flow(self, cfg: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
        """
        Run one OAuth2 authorization-code flow: open the browser, wait for the redirect
        handler to deliver a token, and wrap it in an ``AuthenticatedContext``.

        Raises:
            RuntimeError: The browser could not be opened, the flow timed out after
                five minutes, or the redirect handler reported a failure.
        """
        state = secrets.token_urlsafe(16)
        flow_state = _FlowState()
        client = self.construct_oauth_client(cfg)

        flow_state.token_url = cfg.token_url
        flow_state.use_pkce = cfg.use_pkce

        # PKCE bits
        if cfg.use_pkce:
            verifier, challenge = pkce.generate_pkce_pair()
            flow_state.verifier = verifier
            flow_state.challenge = challenge

        # Create authorization URL using helper function
        auth_url = self._create_authorization_url(client=client,
                                                  config=cfg,
                                                  state=state,
                                                  verifier=flow_state.verifier,
                                                  challenge=flow_state.challenge)

        # Register flow + maybe spin up redirect handler
        async with self._server_lock:
            if (not self._redirect_app):
                self._redirect_app = await self._build_redirect_app()

            await self._start_redirect_server()

            self._flows[state] = flow_state
            self._active_flows += 1

        # Everything after registration runs under one try/finally so that the flow is
        # always unregistered (and the server stopped when idle), including when the
        # browser fails to open.
        try:
            try:
                webbrowser.open(auth_url)
                click.echo("Your browser has been opened for authentication.")
            except Exception as e:
                logger.error("Browser open failed: %s", e)
                raise RuntimeError(f"Browser open failed: {e}") from e

            # Wait for the redirect to land
            try:
                token = await asyncio.wait_for(flow_state.future, timeout=300)
            except asyncio.TimeoutError as exc:
                # asyncio.TimeoutError is the builtin TimeoutError on 3.11+ and a
                # distinct class before that; naming it covers both.
                raise RuntimeError("Authentication timed out (5 min).") from exc
        finally:
            async with self._server_lock:
                self._flows.pop(state, None)
                self._active_flows -= 1

                if self._active_flows == 0:
                    await self._stop_redirect_server()

        return AuthenticatedContext(
            headers={"Authorization": f"Bearer {token['access_token']}"},
            metadata={
                "expires_at": token.get("expires_at"), "raw_token": token
            },
        )

    # --------------- redirect server / in‑process app -------------------- #
    async def _build_redirect_app(self) -> FastAPI:
        """
        Build the FastAPI app that serves the OAuth2 redirect at ``/auth/redirect``.

        The handler looks up the flow by its ``state`` query parameter, exchanges the
        authorization response for a token through ``self._oauth_client``, and resolves
        the flow's future with the token or with a ``RuntimeError`` describing the failure.
        """
        app = FastAPI()

        @app.get("/auth/redirect")
        async def handle_redirect(request: Request):
            state = request.query_params.get("state")
            if not state or state not in self._flows:
                return "Invalid state; restart authentication."
            flow_state = self._flows[state]

            def fail(error: Exception) -> None:
                # The future may already be done (e.g. cancelled by the timeout in the
                # waiting coroutine); setting it again would raise InvalidStateError.
                if not flow_state.future.done():
                    flow_state.future.set_exception(error)

            try:
                token = await self._oauth_client.fetch_token(  # type: ignore[arg-type]
                    url=flow_state.token_url,
                    authorization_response=str(request.url),
                    code_verifier=flow_state.verifier if flow_state.use_pkce else None,
                    state=state,
                )
            except OAuthError as e:
                fail(RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
                return "Authentication failed: Authorization server rejected the request. You may close this tab."
            except httpx.HTTPError as e:
                fail(RuntimeError(f"Network error during token fetch: {e}"))
                return "Authentication failed: Network error occurred. You may close this tab."
            except Exception as e:
                fail(RuntimeError(f"Authentication failed: {e}"))
                return "Authentication failed: An unexpected error occurred. You may close this tab."

            if flow_state.future.done():
                # Nobody is waiting on this flow any more; do not claim success
                return "Invalid state; restart authentication."
            flow_state.future.set_result(token)
            return "Authentication successful – you may close this tab."

        return app

    async def _start_redirect_server(self) -> None:
        """
        Start the redirect server on localhost:8000 unless one is already running.

        Raises:
            RuntimeError: The redirect app has not been built or the server failed to start.
        """
        # If the server is already running, do nothing
        if self._server_controller:
            return
        try:
            if not self._redirect_app:
                raise RuntimeError("Redirect app not built.")

            self._server_controller = _FastApiFrontEndController(self._redirect_app)

            # Keep a strong reference: the event loop only holds weak references to
            # tasks, so an unreferenced task can be garbage-collected before it finishes.
            self._server_task = asyncio.create_task(
                self._server_controller.start_server(host="localhost", port=8000))

            # Give the server a moment to bind sockets before we return
            await asyncio.sleep(0.3)
        except Exception as exc:  # noqa: BLE001
            raise RuntimeError(f"Failed to start redirect server: {exc}") from exc

    async def _stop_redirect_server(self) -> None:
        """Stop the redirect server if one is running and forget its controller."""
        if self._server_controller:
            await self._server_controller.stop_server()
            self._server_controller = None

    # ------------------------- test helpers ------------------------------ #
    @property
    def redirect_app(self) -> FastAPI | None:
        """
        The in‑memory redirect app, exposed for testing purposes. ``None`` until the
        first OAuth2 flow has built it.
        """
        return self._redirect_app
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from nat.data_models.front_end import FrontEndBaseConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ConsoleFrontEndConfig(FrontEndBaseConfig, name="console"):
    """
    A front end that allows a NAT workflow to be run from the console.
    """

    # Supplied under the "input" alias
    input_query: list[str] | None = Field(default=None,
                                          alias="input",
                                          description="A single input to submit to the workflow.")
    # Alternative to ``input_query``: inputs are read from this file instead
    input_file: Path | None = Field(default=None,
                                    description="Path to a json file of inputs to submit to the workflow.")
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
import click
|
|
20
|
+
from colorama import Fore
|
|
21
|
+
|
|
22
|
+
from nat.data_models.interactive import HumanPromptModelType
|
|
23
|
+
from nat.data_models.interactive import HumanResponse
|
|
24
|
+
from nat.data_models.interactive import HumanResponseText
|
|
25
|
+
from nat.data_models.interactive import InteractionPrompt
|
|
26
|
+
from nat.front_ends.console.authentication_flow_handler import ConsoleAuthenticationFlowHandler
|
|
27
|
+
from nat.front_ends.console.console_front_end_config import ConsoleFrontEndConfig
|
|
28
|
+
from nat.front_ends.simple_base.simple_front_end_plugin_base import SimpleFrontEndPluginBase
|
|
29
|
+
from nat.runtime.session import SessionManager
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
async def prompt_for_input_cli(question: InteractionPrompt) -> HumanResponse:
    """
    Console callback for human-in-the-loop prompts.

    Shows the prompt text with ``click.prompt`` and returns what the user typed wrapped
    in a ``HumanResponseText``. Only the text input type is handled.

    Raises:
        ValueError: The prompt's input type is anything other than text.
    """
    if question.content.input_type != HumanPromptModelType.TEXT:
        raise ValueError("Unsupported human prompt input type. The run command only supports the 'HumanPromptText' "
                         "input type. Please use the 'serve' command to ensure full support for all input types.")

    typed_line = click.prompt(text=question.content.text)
    return HumanResponseText(text=typed_line)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
    """Front-end plugin that runs a workflow from the console and reports its result."""

    def __init__(self, full_config):
        super().__init__(full_config=full_config)

        # Set the authentication flow handler
        self.auth_flow_handler = ConsoleAuthenticationFlowHandler()

    async def pre_run(self):
        """
        Validate that exactly one of ``input_query`` / ``input_file`` is configured.

        Raises:
            click.UsageError: Both are set, or neither is set.
        """
        if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
            raise click.UsageError("Must specify either --input or --input_file, not both")
        if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
            raise click.UsageError("Must specify either --input or --input_file")

    async def run_workflow(self, session_manager: SessionManager):
        """
        Run the workflow on the configured input and emit the result.

        With ``input_query`` every query runs concurrently in its own session; with
        ``input_file`` the opened file is handed to a single workflow run. The result is
        logged at INFO, and additionally printed when no root ``StreamHandler`` would
        show INFO records.

        Raises:
            AssertionError: ``session_manager`` is ``None``, or neither input is usable.
        """
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``;
        # AssertionError is kept so the exception type is unchanged.
        if session_manager is None:
            raise AssertionError("Session manager must be provided")
        runner_outputs = None

        if (self.front_end_config.input_query):

            async def run_single_query(query):

                async with session_manager.session(
                        user_input_callback=prompt_for_input_cli,
                        user_authentication_callback=self.auth_flow_handler.authenticate) as session:
                    async with session.run(query) as runner:
                        base_output = await runner.result(to_type=str)

                        return base_output

            # Convert to a list
            input_list = list(self.front_end_config.input_query)
            logger.debug("Processing input: %s", self.front_end_config.input_query)

            # Make `return_exceptions=False` explicit; all exceptions are raised instead of being silenced
            runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list],
                                                  return_exceptions=False)

        elif (self.front_end_config.input_file):

            # Run the workflow
            with open(self.front_end_config.input_file, encoding="utf-8") as f:

                async with session_manager.workflow.run(f) as runner:
                    runner_outputs = await runner.result(to_type=str)
        else:
            raise AssertionError("Should not reach here. Should have been caught by pre_run")

        line = f"{'-' * 50}"
        prefix = f"{line}\n{Fore.GREEN}Workflow Result:\n"
        suffix = f"{Fore.RESET}\n{line}"

        logger.info(f"{prefix}%s{suffix}", runner_outputs)

        # (handler is a stream handler) => (level > INFO)
        # Exact type match on purpose: subclasses of StreamHandler are not counted.
        effective_level_too_high = all(
            type(h) is not logging.StreamHandler or h.level > logging.INFO for h in logging.getLogger().handlers)
        if effective_level_too_high:
            print(f"{prefix}{runner_outputs}{suffix}")
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.cli.register_workflow import register_front_end
|
|
17
|
+
from nat.data_models.config import Config
|
|
18
|
+
from nat.front_ends.console.console_front_end_config import ConsoleFrontEndConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_front_end(config_type=ConsoleFrontEndConfig)
async def register_fastapi_front_end(config: ConsoleFrontEndConfig, full_config: Config):
    """
    Register the console front end.

    The plugin module is imported inside the function body rather than at module level.
    ``config`` is accepted to match the registration signature; only ``full_config`` is
    passed to the plugin.
    """
    from nat.front_ends.console.console_front_end_plugin import ConsoleFrontEndPlugin

    plugin = ConsoleFrontEndPlugin(full_config=full_config)
    yield plugin
|