nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import typing
|
|
20
|
+
from functools import partial
|
|
21
|
+
from urllib.parse import urljoin
|
|
22
|
+
|
|
23
|
+
import httpx
|
|
24
|
+
from langchain_core.retrievers import BaseRetriever
|
|
25
|
+
from pydantic import BaseModel
|
|
26
|
+
from pydantic import Field
|
|
27
|
+
from pydantic import HttpUrl
|
|
28
|
+
|
|
29
|
+
from nat.retriever.interface import Retriever
|
|
30
|
+
from nat.retriever.models import Document
|
|
31
|
+
from nat.retriever.models import RetrieverError
|
|
32
|
+
from nat.retriever.models import RetrieverOutput
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Collection(BaseModel):
    """
    A single collection entry as listed by the Nemo Retriever service.
    """

    # Identifier interpolated into the collection's search URL.
    id: str
    # Name that callers use to look a collection up.
    name: str
    # Free-form value; no structure is enforced on it here.
    meta: typing.Any
    pipeline: str
    # Kept as the raw string; it is not parsed into a datetime here.
    created_at: str
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RetrieverPayload(BaseModel):
    """
    Request body posted to a collection's search endpoint.
    """

    # Text of the search query.
    query: str
    # Number of chunks requested; validated to satisfy 0 < top_k <= 50.
    top_k: int = Field(gt=0, le=50)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class CollectionUnavailableError(RetrieverError):
    """
    Raised when no collections are available, a named collection cannot be found,
    or retrieving documents from a collection fails.
    """
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class NemoRetriever(Retriever):
    """
    Client for retrieving document chunks from a Nemo Retriever service.
    """

    def __init__(self, uri: str | HttpUrl, timeout: int = 60, nvidia_api_key: str | None = None, **kwargs):
        """
        Create a client for the service located at ``uri``.

        Args:
            uri (str | HttpUrl): Base URL of the Nemo Retriever service.
            timeout (int): Timeout handed to the HTTP client used for each search.
            nvidia_api_key (str | None): API key sent as a bearer token. Falls back to the ``NVIDIA_API_KEY``
                environment variable when not provided.
            kwargs (dict): Accepted for call compatibility; not used by this initializer.
        """
        self.base_url = str(uri)
        self.timeout = timeout
        self._search_func = self._search
        self.api_key = nvidia_api_key if nvidia_api_key else os.getenv('NVIDIA_API_KEY')
        self._bound_params = []
        if not self.api_key:
            logger.warning("No API key was specified as part of configuration or as an environment variable.")

    def bind(self, **kwargs) -> None:
        """
        Bind default values to the search method. Cannot bind the 'query' parameter.

        Bindings accumulate across calls: a later call adds to (or overrides the value of) earlier bindings.

        Args:
            kwargs (dict): Key value pairs corresponding to the default values of search parameters.
        """
        # 'query' must always be supplied at search time, so it is silently dropped here.
        bound = {k: v for k, v in kwargs.items() if k != "query"}
        self._search_func = partial(self._search_func, **bound)
        # The partial above stacks on top of previous bindings, so the bookkeeping has to accumulate as well;
        # otherwise parameters bound by an earlier call would be reported as unbound again.
        self._bound_params = list(dict.fromkeys([*self._bound_params, *bound]))
        logger.debug("Binding paramaters for search function: %s", bound)

    def get_unbound_params(self) -> list[str]:
        """
        Returns a list of unbound parameters which will need to be passed to the search function.
        """
        return [param for param in ["query", "collection_name", "top_k"] if param not in self._bound_params]

    async def get_collections(self, client) -> list[Collection]:
        """
        Get a list of all available collections as pydantic `Collection` objects

        Args:
            client: HTTP client exposing an awaitable ``get(url)``; ``_search`` passes an ``httpx.AsyncClient``.

        Returns:
            list[Collection]: The validated collections reported by the service.

        Raises:
            CollectionUnavailableError: If the service reports no collections.
        """
        # NOTE: the path is absolute, so urljoin replaces any path component present in base_url.
        response = await client.get(urljoin(self.base_url, "/v1/collections"))
        response.raise_for_status()

        # Parse the body once; a missing or null 'collections' entry counts as "no collections".
        raw_collections = response.json().get("collections") or []
        if not raw_collections:
            raise CollectionUnavailableError(f"No collections available at {self.base_url}")

        return [Collection.model_validate(item) for item in raw_collections]

    async def get_collection_by_name(self, collection_name, client) -> Collection:
        """
        Retrieve a collection using its name. Will return the first collection found if the name is ambiguous.

        Args:
            collection_name: Name of the collection to look up.
            client: HTTP client forwarded to ``get_collections``.

        Raises:
            CollectionUnavailableError: If no collection with that name exists.
        """
        collections = await self.get_collections(client)
        if (collection := next((c for c in collections if c.name == collection_name), None)) is None:
            raise CollectionUnavailableError(f"Collection {collection_name} not found")
        return collection

    async def search(self, query: str, **kwargs):
        """
        Run a search for ``query``, combining ``kwargs`` with any defaults registered through ``bind``.
        """
        return await self._search_func(query=query, **kwargs)

    async def _search(
        self,
        query: str,
        collection_name: str,
        top_k: int,
        output_fields: list[str] | None = None,
    ):
        """
        Retrieve document chunks from the configured Nemo Retriever Service.

        Args:
            query (str): Text of the search query.
            collection_name (str): Name of the collection to search.
            top_k (int): Number of chunks to request; validated by ``RetrieverPayload``.
            output_fields (list[str] | None): Fields to keep on each chunk; all fields are kept when empty.

        Returns:
            The wrapped search results. A response without chunks yields an empty result.

        Raises:
            CollectionUnavailableError: If anything goes wrong while resolving the collection or searching it.
        """
        # Only send credentials when a key is configured, instead of the literal string "Bearer None".
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else None
        try:
            async with httpx.AsyncClient(headers=headers, timeout=self.timeout) as client:
                collection = await self.get_collection_by_name(collection_name, client)
                url = urljoin(self.base_url, f"/v1/collections/{collection.id}/search")

                payload = RetrieverPayload(query=query, top_k=top_k)
                response = await client.post(url, content=json.dumps(payload.model_dump(mode="python")))

                logger.debug("response.status_code=%s", response.status_code)

                response.raise_for_status()
                # A missing or null 'chunks' entry means "no results" rather than an error.
                chunks = response.json().get("chunks") or []

                # Handle output fields
                output = [_flatten(chunk, output_fields) for chunk in chunks]

                return _wrap_nemo_results(output=output, content_field="content")

        except Exception as e:
            logger.exception("Encountered an error when retrieving results from Nemo Retriever: %s", e)
            raise CollectionUnavailableError(
                f"Error when retrieving documents from {collection_name} for query '{query}'") from e
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _wrap_nemo_results(output: list[dict], content_field: str):
    """
    Convert a list of flattened chunks into a ``RetrieverOutput``.

    Each chunk is wrapped individually; ``content_field`` names the key holding the chunk's text.
    """
    documents = []
    for chunk in output:
        documents.append(_wrap_nemo_single_results(chunk, content_field=content_field))
    return RetrieverOutput(results=documents)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _wrap_nemo_single_results(output: dict, content_field: str):
    """
    Convert one flattened chunk into a ``Document``.

    The value stored under ``content_field`` becomes the page content; every other key/value pair is
    carried over as metadata. A ``KeyError`` is raised when ``content_field`` is absent.
    """
    # Work on a copy so the caller's dictionary is left untouched.
    metadata = dict(output)
    page_content = metadata.pop(content_field)
    return Document(page_content=page_content, metadata=metadata)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _flatten(obj: dict, output_fields: list[str]) -> list[str]:
|
|
164
|
+
base_fields = [
|
|
165
|
+
"format",
|
|
166
|
+
"id",
|
|
167
|
+
]
|
|
168
|
+
if not output_fields:
|
|
169
|
+
output_fields = [
|
|
170
|
+
"format",
|
|
171
|
+
"id",
|
|
172
|
+
]
|
|
173
|
+
output_fields.extend(list(obj["metadata"].keys()))
|
|
174
|
+
data = {"content": obj.get("content")}
|
|
175
|
+
for field in base_fields:
|
|
176
|
+
if field in output_fields:
|
|
177
|
+
data.update({field: obj[field]})
|
|
178
|
+
|
|
179
|
+
data.update({k: v for k, v in obj['metadata'].items() if k in output_fields})
|
|
180
|
+
return data
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class NemoLangchainRetriever(BaseRetriever, BaseModel):
    """
    LangChain retriever adapter that delegates to a ``NemoRetriever`` client.

    Only asynchronous retrieval is supported.
    """

    # Client that performs the actual search.
    client: NemoRetriever

    def _get_relevant_documents(self, query, *, run_manager, **kwargs):
        """
        Synchronous retrieval is not available; always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    async def _aget_relevant_documents(self, query, *, run_manager, **kwargs):
        """
        Search the wrapped client for ``query``, forwarding ``kwargs`` as search parameters.

        ``run_manager`` is accepted to satisfy the retriever interface and is not used.
        """
        results = await self.client.search(query, **kwargs)
        return results
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# pylint: disable=unused-import
|
|
17
|
+
# flake8: noqa
|
|
18
|
+
# isort:skip_file
|
|
19
|
+
|
|
20
|
+
# Import any providers which need to be automatically registered here
|
|
21
|
+
import nat.retriever.milvus.register
|
|
22
|
+
import nat.retriever.nemo_retriever.register
|
nat/runtime/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
nat/runtime/loader.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import importlib.metadata
|
|
19
|
+
import logging
|
|
20
|
+
import time
|
|
21
|
+
from contextlib import asynccontextmanager
|
|
22
|
+
from enum import IntFlag
|
|
23
|
+
from enum import auto
|
|
24
|
+
from functools import lru_cache
|
|
25
|
+
from functools import reduce
|
|
26
|
+
|
|
27
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
28
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
29
|
+
from nat.data_models.config import Config
|
|
30
|
+
from nat.runtime.session import SessionManager
|
|
31
|
+
from nat.utils.data_models.schema_validator import validate_schema
|
|
32
|
+
from nat.utils.debugging_utils import is_debugger_attached
|
|
33
|
+
from nat.utils.io.yaml_tools import yaml_load
|
|
34
|
+
from nat.utils.type_utils import StrPath
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PluginTypes(IntFlag):
    """
    Bit flags naming the categories of plugins that can be discovered through entry points.

    Individual members can be combined with ``|`` to request several categories at once.
    """

    COMPONENT = auto()
    """
    Workflow building blocks such as tools, LLMs and retrievers.
    """
    FRONT_END = auto()
    """
    Front ends used to serve a workflow, for example FastAPI or Gradio.
    """
    EVALUATOR = auto()
    """
    Evaluators applied to a workflow, for example RAGAS or SWE-bench.
    """
    AUTHENTICATION = auto()
    """
    API authentication providers for a workflow, for example OAuth2 or API key.
    """
    REGISTRY_HANDLER = auto()
    """
    Plugins published under the registry handler entry point groups.
    """

    # Composite flags covering commonly requested groups of plugin types
    CONFIG_OBJECT = AUTHENTICATION | EVALUATOR | FRONT_END | COMPONENT
    """
    Every plugin type that may be referenced from a NAT configuration file.
    """
    ALL = CONFIG_OBJECT | REGISTRY_HANDLER
    """
    Every plugin type.
    """
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def load_config(config_file: StrPath) -> Config:
    """
    Read a NAT configuration file and return it as a validated ``Config`` object.

    The plugin types that can appear in a configuration file (``PluginTypes.CONFIG_OBJECT``) are discovered and
    registered first; the YAML file is then loaded and validated against the ``Config`` schema.

    Parameters
    ----------
    config_file : StrPath
        The path to the configuration file

    Returns
    -------
    Config
        The validated Config object
    """

    # Register the config-object plugins before the file is validated
    discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)

    raw_config = yaml_load(config_file)

    # Check the raw YAML content against the NAT schemas and hand back the resulting object
    return validate_schema(raw_config, Config)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@asynccontextmanager
async def load_workflow(config_file: StrPath, max_concurrency: int = -1):
    """
    Load a NAT configuration file, build its workflow and yield a ``SessionManager`` wrapping it. This is the
    primary entry point for running NAT workflows.

    This is an async context manager: the workflow builder stays open for as long as the caller remains inside
    the ``async with`` block.

    Parameters
    ----------
    config_file : StrPath
        The path to the configuration file
    max_concurrency : int, optional
        The maximum number of parallel workflow invocations to support. Specifying 0 or -1 will allow an unlimited
        count, by default -1
    """

    config = load_config(config_file)

    # The session manager has to be yielded from inside the builder context, otherwise the builder cleans up
    # before the caller gets to use the workflow
    async with WorkflowBuilder.from_config(config=config) as builder:

        session_manager = SessionManager(builder.build(), max_concurrency=max_concurrency)

        yield session_manager
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@lru_cache
def discover_entrypoints(plugin_type: PluginTypes):
    """
    Discover all the requested plugin types which were registered via an entry point group and return them.

    The result is cached per ``plugin_type`` value, so the installed entry points are only scanned once for each
    distinct flag combination.

    Parameters
    ----------
    plugin_type : PluginTypes
        One flag, or several flags combined with ``|``, selecting which entry point groups to search

    Returns
    -------
    list
        The matching entry points, in group order. The list is empty when ``plugin_type`` selects no groups.
    """

    entry_points = importlib.metadata.entry_points()

    plugin_groups = []

    # Add the specified plugin type to the list of groups to load
    # The aiq entrypoints are intentionally left in the list to maintain backwards compatibility.
    if (plugin_type & PluginTypes.COMPONENT):
        plugin_groups.extend(["aiq.plugins", "aiq.components", "nat.plugins", "nat.components"])
    if (plugin_type & PluginTypes.FRONT_END):
        plugin_groups.extend(["aiq.front_ends", "nat.front_ends"])
    if (plugin_type & PluginTypes.REGISTRY_HANDLER):
        plugin_groups.extend(["aiq.registry_handlers", "nat.registry_handlers"])
    if (plugin_type & PluginTypes.EVALUATOR):
        plugin_groups.extend(["aiq.evaluators", "nat.evaluators"])
    if (plugin_type & PluginTypes.AUTHENTICATION):
        plugin_groups.extend(["aiq.authentication_providers", "nat.authentication_providers"])

    # Flatten the entry points of every selected group into one list. An explicit loop (rather than reduce()
    # without an initial value) also works when no group was selected, returning an empty list instead of raising.
    nat_plugins = []
    for group in plugin_groups:
        nat_plugins.extend(entry_points.select(group=group))

    return nat_plugins
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@lru_cache
def get_all_entrypoints_distro_mapping() -> dict[str, str]:
    """
    Map every NAT entry point module, and each of its parent packages, to a distribution name.

    For an entry point whose module is ``a.b.c`` the keys ``a``, ``a.b`` and ``a.b.c`` are all set to the name of
    the distribution that provides the entry point. When several entry points share a module prefix, the one
    visited last determines the value stored for that prefix.
    """

    distro_by_module = {}

    for ep in discover_entrypoints(PluginTypes.ALL):
        parts = ep.module.split(".")

        # Record the distribution for each dotted prefix of the module path, shortest first
        for depth in range(1, len(parts) + 1):
            distro_by_module[".".join(parts[:depth])] = ep.dist.name

    return distro_by_module
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def discover_and_register_plugins(plugin_type: PluginTypes):
    """
    Discover all the requested plugin types which were registered via an entry point group and register them into the
    GlobalTypeRegistry.

    Every matching entry point is loaded in turn. A plugin that fails to load is logged and skipped so the
    remaining plugins are still processed; no ``Exception`` raised while loading a plugin propagates to the caller.

    NOTE(review): registration itself is not visible here - the plugin modules are presumably expected to register
    themselves with the GlobalTypeRegistry as a side effect of being loaded.

    Parameters
    ----------
    plugin_type : PluginTypes
        One flag, or several flags combined with ``|``, selecting which plugin types to load
    """

    # Get the entry points for the specified groups
    nat_plugins = discover_entrypoints(plugin_type)

    # Pause registration hooks for performance. This is useful when loading a large number of plugins.
    with GlobalTypeRegistry.get().pause_registration_changed_hooks():

        for index, entry_point in enumerate(nat_plugins):
            try:
                logger.debug("Loading module '%s' from entry point '%s'...", entry_point.module, entry_point.name)

                # perf_counter() is monotonic, so the measurement is not affected by system clock adjustments
                start_time = time.perf_counter()

                entry_point.load()

                elapsed_time = (time.perf_counter() - start_time) * 1000

                logger.debug("Loading module '%s' from entry point '%s'...Complete (%f ms)",
                             entry_point.module,
                             entry_point.name,
                             elapsed_time)

                # Emit an extra debug message if the plugin took a long time to load. This can be useful for
                # debugging slow imports. The threshold is 300 ms for the first entry point attempted and 150 ms
                # for the others. Triple the threshold if a debugger is attached.
                slow_threshold_ms = (300.0 if index == 0 else 150.0) * (3 if is_debugger_attached() else 1)

                if (elapsed_time > slow_threshold_ms):
                    logger.debug(
                        "Loading module '%s' from entry point '%s' took a long time (%f ms). "
                        "Ensure all imports are inside your registered functions.",
                        entry_point.module,
                        entry_point.name,
                        elapsed_time)

            except ImportError:
                # A plugin with missing dependencies must not prevent the other plugins from loading
                logger.warning("Failed to import plugin '%s'", entry_point.name, exc_info=True)

            except Exception:
                # logger.exception() already records the active traceback
                logger.exception("An error occurred while loading plugin '%s'", entry_point.name)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# Backwards-compatible alias: keeps the older aiq-prefixed name importable
get_all_aiq_entrypoints_distro_mapping = get_all_entrypoints_distro_mapping
|
nat/runtime/runner.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
from nat.builder.context import Context
|
|
21
|
+
from nat.builder.context import ContextState
|
|
22
|
+
from nat.builder.function import Function
|
|
23
|
+
from nat.data_models.invocation_node import InvocationNode
|
|
24
|
+
from nat.observability.exporter_manager import ExporterManager
|
|
25
|
+
from nat.utils.reactive.subject import Subject
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class UserManagerBase:
    """
    Empty base class that declares no attributes or methods of its own.
    """
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RunnerState(Enum):
    """
    Lifecycle states tracked by a ``Runner``.
    """

    # Constructed, context not entered yet
    UNINITIALIZED = 0
    # Context entered, workflow not started yet
    INITIALIZED = 1
    # Workflow invocation in progress
    RUNNING = 2
    # Workflow finished without raising
    COMPLETED = 3
    # Workflow raised an exception
    FAILED = 4
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
_T = typing.TypeVar("_T")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Runner:
    """
    Runs a single workflow invocation inside an async context manager.

    A runner is entered exactly once with ``async with``. Inside the block exactly one of ``result()`` or
    ``result_stream()`` is used to execute the entry function; the runner tracks its progress with ``RunnerState``.
    """

    def __init__(self,
                 input_message: typing.Any,
                 entry_fn: Function,
                 context_state: ContextState,
                 exporter_manager: ExporterManager):
        """
        The Runner class is used to run a workflow. It delegates type conversion to the entry function and runs the
        workflow with the exporter manager started.

        Parameters
        ----------
        input_message : typing.Any
            The input message to the workflow
        entry_fn : Function
            The entry function to the workflow
        context_state : ContextState
            The context state to use
        exporter_manager : ExporterManager
            The exporter manager to use

        Raises
        ------
        ValueError
            If ``entry_fn`` is None
        """

        if (entry_fn is None):
            raise ValueError("entry_fn cannot be None")

        self._entry_fn = entry_fn
        self._context_state = context_state
        self._context = Context(self._context_state)

        self._state = RunnerState.UNINITIALIZED

        # Token returned when the input message is set on the context; needed to restore the previous value on exit
        self._input_message_token = None

        # The input message is stored as-is and handed to the entry function when the workflow is run
        self._input_message = input_message

        self._exporter_manager = exporter_manager

    @property
    def context(self) -> Context:
        """The ``Context`` created from this runner's context state."""
        return self._context

    def convert(self, value: typing.Any, to_type: type[_T]) -> _T:
        """Convert ``value`` to ``to_type`` by delegating to the entry function's ``convert``."""
        return self._entry_fn.convert(value, to_type)

    def _complete_event_stream(self) -> None:
        """Signal completion on the event stream currently set on the context state, if there is one."""
        event_stream = self._context_state.event_stream.get()
        if event_stream:
            event_stream.on_complete()

    async def __aenter__(self):
        """
        Enter the runner context: publish the input message, a fresh event stream and the root invocation node on
        the context state.

        Raises
        ------
        ValueError
            If the context has already been entered
        """

        # Validate the state before touching the context so that a rejected re-entry neither overwrites the values
        # set by the first entry nor replaces the token needed to restore the input message on exit
        if (self._state != RunnerState.UNINITIALIZED):
            raise ValueError("Cannot enter the context more than once")

        # Set the input message on the context
        self._input_message_token = self._context_state.input_message.set(self._input_message)

        # Create reactive event stream
        self._context_state.event_stream.set(Subject())
        self._context_state.active_function.set(InvocationNode(
            function_name="root",
            function_id="root",
        ))

        self._state = RunnerState.INITIALIZED

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Leave the runner context and restore the input message that was set before entering.

        Raises
        ------
        ValueError
            If the context was never entered, or if the block exits normally without the workflow having been run
            to completion or failure
        """

        if (self._input_message_token is None):
            raise ValueError("Cannot exit the context without entering it")

        self._context_state.input_message.reset(self._input_message_token)

        # Only report an unfinished workflow on a clean exit. When an exception is already propagating (for example
        # a cancellation, or an error raised before the workflow was started) raising here would replace it and
        # hide the real cause from the caller.
        if (exc_type is None and self._state not in (RunnerState.COMPLETED, RunnerState.FAILED)):
            raise ValueError("Cannot exit the context without completing the workflow")

    @typing.overload
    async def result(self) -> typing.Any:
        ...

    @typing.overload
    async def result(self, to_type: type[_T]) -> _T:
        ...

    async def result(self, to_type: type | None = None):
        """
        Run the workflow once and return its single output.

        Parameters
        ----------
        to_type : type | None, optional
            Forwarded to the entry function's ``ainvoke`` as the requested output type, by default None

        Returns
        -------
        typing.Any
            The value returned by the entry function

        Raises
        ------
        ValueError
            If the runner is not in the INITIALIZED state, or the entry function has no single output
        Exception
            Any exception raised while running the workflow is logged and re-raised after the runner has been marked
            as FAILED
        """

        if (self._state != RunnerState.INITIALIZED):
            raise ValueError("Cannot run the workflow without entering the context")

        try:
            self._state = RunnerState.RUNNING

            if (not self._entry_fn.has_single_output):
                raise ValueError("Workflow does not support single output")

            async with self._exporter_manager.start(context_state=self._context_state):
                # Run the workflow
                result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)

                # Close the intermediate stream
                self._complete_event_stream()

                self._state = RunnerState.COMPLETED

                return result
        except Exception as e:
            logger.exception("Error running workflow: %s", e)

            # Subscribers must still be told that no more events will arrive
            self._complete_event_stream()
            self._state = RunnerState.FAILED

            raise

    async def result_stream(self, to_type: type | None = None):
        """
        Run the workflow once and yield its streaming output item by item.

        This is an async generator, so nothing (including the state check) runs until iteration starts.

        Parameters
        ----------
        to_type : type | None, optional
            Forwarded to the entry function's ``astream`` as the requested item type, by default None

        Raises
        ------
        ValueError
            If the runner is not in the INITIALIZED state, or the entry function has no streaming output
        Exception
            Any exception raised while running the workflow is logged and re-raised after the runner has been marked
            as FAILED
        """

        if (self._state != RunnerState.INITIALIZED):
            raise ValueError("Cannot run the workflow without entering the context")

        try:
            self._state = RunnerState.RUNNING

            if (not self._entry_fn.has_streaming_output):
                raise ValueError("Workflow does not support streaming output")

            # Run the workflow
            async with self._exporter_manager.start(context_state=self._context_state):
                async for m in self._entry_fn.astream(self._input_message, to_type=to_type):
                    yield m

                self._state = RunnerState.COMPLETED

                # Close the intermediate stream
                self._complete_event_stream()

        except Exception as e:
            logger.exception("Error running workflow: %s", e)

            # Subscribers must still be told that no more events will arrive
            self._complete_event_stream()
            self._state = RunnerState.FAILED

            raise
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# Backwards-compatible names kept for code written against previous releases
AIQRunner = Runner
AIQRunnerState = RunnerState
|