nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
nat/builder/builder.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import typing
|
|
18
|
+
from abc import ABC
|
|
19
|
+
from abc import abstractmethod
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
24
|
+
from nat.builder.context import Context
|
|
25
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
26
|
+
from nat.builder.function import Function
|
|
27
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
28
|
+
from nat.data_models.component_ref import AuthenticationRef
|
|
29
|
+
from nat.data_models.component_ref import EmbedderRef
|
|
30
|
+
from nat.data_models.component_ref import FunctionRef
|
|
31
|
+
from nat.data_models.component_ref import LLMRef
|
|
32
|
+
from nat.data_models.component_ref import MemoryRef
|
|
33
|
+
from nat.data_models.component_ref import ObjectStoreRef
|
|
34
|
+
from nat.data_models.component_ref import RetrieverRef
|
|
35
|
+
from nat.data_models.component_ref import TTCStrategyRef
|
|
36
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
|
37
|
+
from nat.data_models.evaluator import EvaluatorBaseConfig
|
|
38
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
39
|
+
from nat.data_models.function_dependencies import FunctionDependencies
|
|
40
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
41
|
+
from nat.data_models.memory import MemoryBaseConfig
|
|
42
|
+
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
43
|
+
from nat.data_models.retriever import RetrieverBaseConfig
|
|
44
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
45
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
46
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
47
|
+
from nat.memory.interfaces import MemoryEditor
|
|
48
|
+
from nat.object_store.interfaces import ObjectStore
|
|
49
|
+
from nat.retriever.interface import Retriever
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class UserManagerHolder:
    """Small wrapper that exposes the user manager held by a :class:`Context`.

    ``Builder.get_user_manager`` returns an instance of this class.
    """

    def __init__(self, context: Context) -> None:
        # Store only the context; the user manager is looked up from it on
        # every ``get_id`` call rather than being cached here.
        self._context = context

    def get_id(self):
        """Return the result of ``self._context.user_manager.get_id()``."""
        manager = self._context.user_manager
        return manager.get_id()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
62
|
+
|
|
63
|
+
@abstractmethod
|
|
64
|
+
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
65
|
+
pass
|
|
66
|
+
|
|
67
|
+
@abstractmethod
|
|
68
|
+
def get_function(self, name: str | FunctionRef) -> Function:
|
|
69
|
+
pass
|
|
70
|
+
|
|
71
|
+
def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
72
|
+
|
|
73
|
+
return [self.get_function(name) for name in function_names]
|
|
74
|
+
|
|
75
|
+
@abstractmethod
|
|
76
|
+
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
77
|
+
pass
|
|
78
|
+
|
|
79
|
+
@abstractmethod
|
|
80
|
+
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
81
|
+
pass
|
|
82
|
+
|
|
83
|
+
@abstractmethod
|
|
84
|
+
def get_workflow(self) -> Function:
|
|
85
|
+
pass
|
|
86
|
+
|
|
87
|
+
@abstractmethod
|
|
88
|
+
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
89
|
+
pass
|
|
90
|
+
|
|
91
|
+
def get_tools(self, tool_names: Sequence[str | FunctionRef],
|
|
92
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
93
|
+
|
|
94
|
+
return [self.get_tool(fn_name=n, wrapper_type=wrapper_type) for n in tool_names]
|
|
95
|
+
|
|
96
|
+
@abstractmethod
|
|
97
|
+
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
98
|
+
pass
|
|
99
|
+
|
|
100
|
+
@abstractmethod
|
|
101
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
102
|
+
pass
|
|
103
|
+
|
|
104
|
+
@abstractmethod
|
|
105
|
+
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
106
|
+
pass
|
|
107
|
+
|
|
108
|
+
async def get_llms(self, llm_names: Sequence[str | LLMRef],
|
|
109
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
110
|
+
|
|
111
|
+
coros = [self.get_llm(llm_name=n, wrapper_type=wrapper_type) for n in llm_names]
|
|
112
|
+
|
|
113
|
+
llms = await asyncio.gather(*coros, return_exceptions=False)
|
|
114
|
+
|
|
115
|
+
return list(llms)
|
|
116
|
+
|
|
117
|
+
@abstractmethod
|
|
118
|
+
def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
|
|
119
|
+
pass
|
|
120
|
+
|
|
121
|
+
@abstractmethod
|
|
122
|
+
async def add_auth_provider(self, name: str | AuthenticationRef, config: AuthProviderBaseConfig):
|
|
123
|
+
pass
|
|
124
|
+
|
|
125
|
+
@abstractmethod
|
|
126
|
+
async def get_auth_provider(self, auth_provider_name: str | AuthenticationRef) -> AuthProviderBase:
|
|
127
|
+
pass
|
|
128
|
+
|
|
129
|
+
async def get_auth_providers(self, auth_provider_names: list[str | AuthenticationRef]):
|
|
130
|
+
|
|
131
|
+
coros = [self.get_auth_provider(auth_provider_name=n) for n in auth_provider_names]
|
|
132
|
+
|
|
133
|
+
auth_providers = await asyncio.gather(*coros, return_exceptions=False)
|
|
134
|
+
|
|
135
|
+
return list(auth_providers)
|
|
136
|
+
|
|
137
|
+
@abstractmethod
|
|
138
|
+
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
|
|
139
|
+
pass
|
|
140
|
+
|
|
141
|
+
async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
|
|
142
|
+
"""
|
|
143
|
+
Return a list of all object store clients.
|
|
144
|
+
"""
|
|
145
|
+
return list(await asyncio.gather(*[self.get_object_store_client(name) for name in object_store_names]))
|
|
146
|
+
|
|
147
|
+
@abstractmethod
|
|
148
|
+
async def get_object_store_client(self, object_store_name: str | ObjectStoreRef) -> ObjectStore:
|
|
149
|
+
pass
|
|
150
|
+
|
|
151
|
+
@abstractmethod
|
|
152
|
+
def get_object_store_config(self, object_store_name: str | ObjectStoreRef) -> ObjectStoreBaseConfig:
|
|
153
|
+
pass
|
|
154
|
+
|
|
155
|
+
@abstractmethod
|
|
156
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
157
|
+
pass
|
|
158
|
+
|
|
159
|
+
async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
|
|
160
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
161
|
+
|
|
162
|
+
coros = [self.get_embedder(embedder_name=n, wrapper_type=wrapper_type) for n in embedder_names]
|
|
163
|
+
|
|
164
|
+
embedders = await asyncio.gather(*coros, return_exceptions=False)
|
|
165
|
+
|
|
166
|
+
return list(embedders)
|
|
167
|
+
|
|
168
|
+
@abstractmethod
|
|
169
|
+
async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
170
|
+
pass
|
|
171
|
+
|
|
172
|
+
@abstractmethod
|
|
173
|
+
def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
|
|
174
|
+
pass
|
|
175
|
+
|
|
176
|
+
@abstractmethod
|
|
177
|
+
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
|
|
178
|
+
pass
|
|
179
|
+
|
|
180
|
+
def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
181
|
+
"""
|
|
182
|
+
Return a list of memory clients for the specified names.
|
|
183
|
+
"""
|
|
184
|
+
return [self.get_memory_client(n) for n in memory_names]
|
|
185
|
+
|
|
186
|
+
@abstractmethod
|
|
187
|
+
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
188
|
+
"""
|
|
189
|
+
Return the instantiated memory client for the given name.
|
|
190
|
+
"""
|
|
191
|
+
pass
|
|
192
|
+
|
|
193
|
+
@abstractmethod
|
|
194
|
+
def get_memory_client_config(self, memory_name: str | MemoryRef) -> MemoryBaseConfig:
|
|
195
|
+
pass
|
|
196
|
+
|
|
197
|
+
@abstractmethod
|
|
198
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
199
|
+
pass
|
|
200
|
+
|
|
201
|
+
async def get_retrievers(self,
|
|
202
|
+
retriever_names: Sequence[str | RetrieverRef],
|
|
203
|
+
wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
204
|
+
|
|
205
|
+
tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
|
|
206
|
+
|
|
207
|
+
retrievers = await asyncio.gather(*tasks, return_exceptions=False)
|
|
208
|
+
|
|
209
|
+
return list(retrievers)
|
|
210
|
+
|
|
211
|
+
@typing.overload
|
|
212
|
+
async def get_retriever(self, retriever_name: str | RetrieverRef,
|
|
213
|
+
wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
214
|
+
...
|
|
215
|
+
|
|
216
|
+
@typing.overload
|
|
217
|
+
async def get_retriever(self, retriever_name: str | RetrieverRef, wrapper_type: None) -> Retriever:
|
|
218
|
+
...
|
|
219
|
+
|
|
220
|
+
@typing.overload
|
|
221
|
+
async def get_retriever(self, retriever_name: str | RetrieverRef) -> Retriever:
|
|
222
|
+
...
|
|
223
|
+
|
|
224
|
+
@abstractmethod
|
|
225
|
+
async def get_retriever(self,
|
|
226
|
+
retriever_name: str | RetrieverRef,
|
|
227
|
+
wrapper_type: LLMFrameworkEnum | str | None = None) -> typing.Any:
|
|
228
|
+
pass
|
|
229
|
+
|
|
230
|
+
@abstractmethod
|
|
231
|
+
async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
|
|
232
|
+
pass
|
|
233
|
+
|
|
234
|
+
@abstractmethod
|
|
235
|
+
async def add_ttc_strategy(self, name: str | str, config: TTCStrategyBaseConfig):
|
|
236
|
+
pass
|
|
237
|
+
|
|
238
|
+
@abstractmethod
async def get_ttc_strategy(self,
                           strategy_name: str | TTCStrategyRef,
                           pipeline_type: PipelineTypeEnum,
                           stage_type: StageTypeEnum):
    """Return the TTC strategy ``strategy_name`` for the given pipeline and stage type.

    Args:
        strategy_name: Name of, or reference to, the strategy component.
        pipeline_type: Pipeline type the strategy is requested for.
        stage_type: Stage type the strategy is requested for.
    """
|
|
244
|
+
|
|
245
|
+
@abstractmethod
async def get_ttc_strategy_config(self,
                                  strategy_name: str | TTCStrategyRef,
                                  pipeline_type: PipelineTypeEnum,
                                  stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
    """Return the ``TTCStrategyBaseConfig`` of ``strategy_name`` for a pipeline and stage type.

    Args:
        strategy_name: Name of, or reference to, the strategy component.
        pipeline_type: Pipeline type the configuration is requested for.
        stage_type: Stage type the configuration is requested for.
    """
|
|
251
|
+
|
|
252
|
+
@abstractmethod
def get_user_manager(self) -> UserManagerHolder:
    """Return the ``UserManagerHolder`` associated with this builder."""
|
|
255
|
+
|
|
256
|
+
@abstractmethod
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
    """Return the ``FunctionDependencies`` for the function named ``fn_name``."""
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class EvalBuilder(Builder):
    """``Builder`` interface extended with evaluation-related abstract accessors.

    Subclasses must provide evaluator registration and lookup. They must also expose a maximum
    concurrency value, an output directory, and all tools for a given framework wrapper type.
    """

    @abstractmethod
    async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
        """Add an evaluator component ``name`` described by ``config``."""

    @abstractmethod
    def get_evaluator(self, evaluator_name: str) -> typing.Any:
        """Return the evaluator named ``evaluator_name``."""

    @abstractmethod
    def get_evaluator_config(self, evaluator_name: str) -> EvaluatorBaseConfig:
        """Return the ``EvaluatorBaseConfig`` associated with ``evaluator_name``."""

    @abstractmethod
    def get_max_concurrency(self) -> int:
        """Return the maximum concurrency as an ``int``."""

    @abstractmethod
    def get_output_dir(self) -> Path:
        """Return the output directory as a ``Path``."""

    @abstractmethod
    def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
        """Return all tools, wrapped for ``wrapper_type``, as a list."""
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
from collections.abc import Generator
|
|
19
|
+
from collections.abc import Iterable
|
|
20
|
+
|
|
21
|
+
import networkx as nx
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
|
|
24
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
25
|
+
from nat.data_models.common import TypedBaseModel
|
|
26
|
+
from nat.data_models.component import ComponentGroup
|
|
27
|
+
from nat.data_models.component_ref import ComponentRef
|
|
28
|
+
from nat.data_models.component_ref import ComponentRefNode
|
|
29
|
+
from nat.data_models.component_ref import generate_instance_id
|
|
30
|
+
from nat.data_models.config import Config
|
|
31
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
|
32
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
34
|
+
from nat.data_models.memory import MemoryBaseConfig
|
|
35
|
+
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
36
|
+
from nat.data_models.retriever import RetrieverBaseConfig
|
|
37
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
38
|
+
from nat.utils.type_utils import DecomposedType
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)

# Order in which component groups are processed when building the dependency map.
# Functions come last, after every non-function group (presumably because functions may
# reference components from the other groups).
_component_group_order: list[ComponentGroup] = [
    ComponentGroup.AUTHENTICATION,
    ComponentGroup.EMBEDDERS,
    ComponentGroup.LLMS,
    ComponentGroup.MEMORY,
    ComponentGroup.OBJECT_STORES,
    ComponentGroup.RETRIEVERS,
    ComponentGroup.TTC_STRATEGIES,
    ComponentGroup.FUNCTIONS,
]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ComponentInstanceData(BaseModel):
    """Metadata for one component runtime instance, used to generate build sequences.

    Attributes:
        component_group (ComponentGroup): The component group (config section) in a NAT
            configuration object that the instance belongs to.
        name (ComponentRef): The name of the component runtime instance.
        config (TypedBaseModel): The runtime instance's configuration object.
        instance_id (str): Unique identifier for each runtime instance.
        is_root (bool): Whether this runtime instance is the root of the workflow; defaults
            to ``False``.
    """

    # Field declaration order is kept as-is; pydantic uses it for model_fields, repr and dumps.
    component_group: ComponentGroup
    name: ComponentRef
    config: TypedBaseModel
    instance_id: str
    is_root: bool = False
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def iterate_leaf_to_root(graph: nx.DiGraph) -> Generator[ComponentRefNode]:
    """Yield the nodes of a directed graph layer by layer, from the leaves up to the root(s).

    Each pass snapshots every node whose out-degree is zero (a leaf). It yields those nodes one
    at a time and removes each one from ``graph`` right after it has been yielded. Passes repeat
    until one finds no leaves.

    The walk is iterative rather than recursive, for two reasons. A deep dependency chain cannot
    exhaust Python's recursion limit. Each ``yield`` also avoids being relayed through a stack of
    nested ``yield from`` generators.

    Note:
        ``graph`` is mutated: nodes are removed as they are yielded, so pass a copy if the
        caller still needs it. Nodes that belong to a cycle never reach out-degree zero and are
        therefore never yielded.

    Args:
        graph (nx.DiGraph): A networkx directed graph object.

    Yields:
        ComponentRefNode: The graph's nodes in leaf-to-root order. (Nodes added by
        ``update_dependency_graph`` may also be plain instance-id strings.)
    """

    while True:
        # Snapshot this layer's leaves before removing any of them. Nodes that become leaves
        # because of these removals are picked up by the next pass.
        leaf_nodes = [node for node, degree in graph.out_degree() if degree == 0]
        if not leaf_nodes:
            return

        for leaf_node in leaf_nodes:
            yield leaf_node
            graph.remove_node(leaf_node)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
    """Map a runtime-instance configuration object to its component group.

    Args:
        component (TypedBaseModel): The configuration object to classify.

    Returns:
        ComponentGroup | None: The group whose base config class ``component`` is an instance
        of, or ``None`` when it matches none of the known component base configs.
    """

    # Checked top to bottom and the first match wins. A config deriving from more than one
    # base therefore resolves to the earliest entry.
    base_to_group = (
        (AuthProviderBaseConfig, ComponentGroup.AUTHENTICATION),
        (EmbedderBaseConfig, ComponentGroup.EMBEDDERS),
        (FunctionBaseConfig, ComponentGroup.FUNCTIONS),
        (LLMBaseConfig, ComponentGroup.LLMS),
        (MemoryBaseConfig, ComponentGroup.MEMORY),
        (ObjectStoreBaseConfig, ComponentGroup.OBJECT_STORES),
        (RetrieverBaseConfig, ComponentGroup.RETRIEVERS),
        (TTCStrategyBaseConfig, ComponentGroup.TTC_STRATEGIES),
    )

    return next((group for base, group in base_to_group if isinstance(component, base)), None)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
                                     type_hint: type[typing.Any]) -> Generator[tuple[str, ComponentRefNode]]:
    """Discover ``ComponentRef`` values nested anywhere inside a configuration value.

    This generator only reports what it finds and does not modify any graph itself;
    ``update_dependency_graph`` consumes the results. ``value`` is walked according to
    ``type_hint``:

    * a ``ComponentRef`` value under a non-generic hint (and not a pydantic model) is yielded;
    * ``tuple``/``list``/``set`` hints recurse into each element using the first type argument;
    * ``dict`` hints recurse into each value using the second type argument;
    * pydantic models recurse into each declared field using that field's annotation;
    * finally, union hints re-visit ``value`` with every member that matches it (or is
      ``typing.Any``), and other generic hints re-visit ``value`` with every type argument.

    Because of the final re-visit step, the same reference may be yielded more than once.

    Args:
        cls (TypedBaseModel): The configuration object of the runtime instance being scanned.
            Its instance id is reported with every reference found.
        value (typing.Any): The current traversed value from the configuration object.
        type_hint (type[typing.Any]): The type of the current traversed value from the configuration object.

    Yields:
        tuple[str, ComponentRefNode]: ``(instance id of cls, node for the discovered reference)``.
    """

    decomposed_type = DecomposedType(type_hint)

    if (value is None):
        return

    if ((decomposed_type.origin is None) and (not issubclass(type(value), BaseModel))):
        if issubclass(type(value), ComponentRef):
            instance_id = generate_instance_id(cls)
            value_node = ComponentRefNode(ref_name=value, component_group=value.component_group)
            yield instance_id, value_node

    elif ((decomposed_type.origin in (tuple, list, set)) and (isinstance(value, Iterable))):
        for v in value:
            yield from recursive_componentref_discovery(cls, v, decomposed_type.args[0])
    elif ((decomposed_type.origin in (dict, type(typing.TypedDict))) and (isinstance(value, dict))):
        for v in value.values():
            yield from recursive_componentref_discovery(cls, v, decomposed_type.args[1])
    elif (issubclass(type(value), BaseModel)):
        # Read field metadata from the model class: accessing ``model_fields`` on an instance is
        # deprecated as of Pydantic 2.11.
        for field, field_info in type(value).model_fields.items():
            field_data = getattr(value, field)
            yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
    if (decomposed_type.is_union):
        for arg in decomposed_type.args:
            if arg is typing.Any or (isinstance(value, DecomposedType(arg).root)):
                yield from recursive_componentref_discovery(cls, value, arg)
    else:
        for arg in decomposed_type.args:
            yield from recursive_componentref_discovery(cls, value, arg)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def update_dependency_graph(config: "Config", instance_config: TypedBaseModel,
                            dependency_graph: nx.DiGraph) -> nx.DiGraph:
    """Add the component references of one runtime instance to the dependency graph.

    Two edges are added for every ``ComponentRef`` found, recursively, in ``instance_config``'s
    fields:

    * ``instance id -> ComponentRefNode``;
    * ``ComponentRefNode -> instance id of the referenced component's config``.

    The second edge links the reference to the concrete component it resolves to. Leaf-to-root
    iteration therefore reaches that component before the instance that references it. A
    reference whose name is not declared in ``config`` gets only the first edge.

    Args:
        config (Config): A NAT configuration object with runtime instance details.
        instance_config (TypedBaseModel): A component's runtime instance configuration object.
        dependency_graph (nx.DiGraph): A graph tracking runtime instance component dependencies;
            it is modified in place.

    Returns:
        nx.DiGraph: The same ``dependency_graph`` object, updated with the provided runtime instance.
    """

    # Read field metadata from the model class: accessing ``model_fields`` on an instance is
    # deprecated as of Pydantic 2.11.
    for field_name, field_info in type(instance_config).model_fields.items():

        for instance_id, value_node in recursive_componentref_discovery(
                instance_config,
                getattr(instance_config, field_name),
                field_info.annotation):  # type: ignore

            # add immediate edge
            dependency_graph.add_edge(instance_id, value_node)

            # add dependency edge to ensure connections to leaf nodes exist
            # (``.value`` matches how the rest of this module looks up config sections)
            dependency_component_dict = getattr(config, value_node.component_group.value)
            dependency_component_instance_config = dependency_component_dict.get(value_node.ref_name)

            if dependency_component_instance_config is None:
                # Dangling reference: there is no declared component to link to. Skip the edge
                # rather than key it on an id generated from ``None``, which would merge every
                # dangling reference into one shared bogus node.
                logger.debug("Component reference '%s' not found in config section '%s'",
                             value_node.ref_name,
                             value_node.component_group.value)
                continue

            dependency_component_instance_id = generate_instance_id(dependency_component_instance_config)
            dependency_graph.add_edge(value_node, dependency_component_instance_id)

    return dependency_graph
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def config_to_dependency_objects(config: "Config") -> tuple[dict[str, ComponentInstanceData], nx.DiGraph]:
    """Index every runtime instance in ``config`` and build its dependency graph.

    Component groups are visited in ``_component_group_order``, and each group's components in
    the config's own order, so the resulting map keeps as much declaration order as possible.
    The workflow is registered after all other components and is flagged as the root.

    Args:
        config (Config): The NAT workflow configuration object.

    Returns:
        tuple[dict[str, ComponentInstanceData], nx.DiGraph]: A map from runtime instance id to
        that instance's metadata, and a dependency graph of nested components.
    """

    instance_map: dict[str, ComponentInstanceData] = {}
    graph: nx.DiGraph = nx.DiGraph()

    for group in _component_group_order:
        group_components = getattr(config, group.value)
        assert isinstance(group_components, dict), "Config components must be a dictionary"

        for name, instance_config in group_components.items():
            instance_id = generate_instance_id(instance_config)
            instance_map[instance_id] = ComponentInstanceData(component_group=group,
                                                              instance_id=instance_id,
                                                              name=name,
                                                              config=instance_config)
            graph = update_dependency_graph(config=config,
                                            instance_config=instance_config,
                                            dependency_graph=graph)

    # The workflow instance carries the root flag and must be registered last.
    workflow_id = generate_instance_id(config.workflow)
    instance_map[workflow_id] = ComponentInstanceData(component_group=ComponentGroup.FUNCTIONS,
                                                      instance_id=workflow_id,
                                                      name="<workflow>",  # type: ignore
                                                      config=config.workflow,
                                                      is_root=True)
    graph = update_dependency_graph(config=config, instance_config=config.workflow, dependency_graph=graph)

    return instance_map, graph
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _resolve_sequence_node(config: "Config", dependency_map: dict[str, ComponentInstanceData],
                           node: typing.Any) -> ComponentInstanceData | None:
    """Convert one dependency-graph node into the ``ComponentInstanceData`` it stands for.

    Args:
        config (Config): A NAT configuration object.
        dependency_map (dict[str, ComponentInstanceData]): Instance id -> metadata map produced
            by ``config_to_dependency_objects``.
        node (typing.Any): A graph node, either a ``ComponentRefNode`` or an instance id.

    Returns:
        ComponentInstanceData | None: The instance data, or ``None`` when the node does not
        correspond to a component in the current configuration.
    """
    if isinstance(node, ComponentRefNode) and issubclass(type(node.ref_name), ComponentRef):
        component_group_configs = getattr(config, node.component_group.value)
        node_config = component_group_configs.get(node.ref_name, None)

        # Only nodes that are valid in the current instance configuration are resolved
        if node_config is None:
            return None

        return ComponentInstanceData(name=node.ref_name,
                                     component_group=node.component_group.value,  # type: ignore
                                     config=node_config,
                                     instance_id=generate_instance_id(node_config))

    return dependency_map.get(node, None)


def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
    """Generate the dependency (build) sequence from a NAT configuration object.

    The sequence is assembled in three steps:

    1. Instances reached by walking the dependency graph from leaves to roots are placed in that
       order, each instance at most once.
    2. Instances the walk never reaches are placed in front, in dependency-map order. These are
       components absent from the graph because they neither reference nor are referenced, or
       nodes caught in a cycle.
    3. The root (workflow) instance is moved to the end.

    Args:
        config (Config): A NAT configuration object.

    Returns:
        list[ComponentInstanceData]: A list representing the instantiation sequence, so that every
        runtime instance reference is valid when it is built.

    Raises:
        AssertionError: If the sequence length differs from the number of declared components
            plus the workflow (an internal bug).
    """

    # One entry per component in every processed group, plus one for the workflow. The count is
    # derived from ``_component_group_order``, so it always covers exactly the groups that
    # ``config_to_dependency_objects`` indexes.
    total_node_count = sum(len(getattr(config, group.value)) for group in _component_group_order) + 1

    dependency_map, dependency_graph = config_to_dependency_objects(config=config)

    dependency_sequence: list[ComponentInstanceData] = []
    instance_ids: set[str] = set()

    # Walk a copy: iterate_leaf_to_root removes nodes from the graph it is given.
    for node in iterate_leaf_to_root(dependency_graph.copy()):  # type: ignore
        component_instance = _resolve_sequence_node(config, dependency_map, node)

        # Skip nodes not valid in this configuration, and instances that were already placed.
        # (Deduplication is by instance id; graph nodes are never equal to ComponentInstanceData.)
        if component_instance is None or component_instance.instance_id in instance_ids:
            continue

        dependency_sequence.append(component_instance)
        instance_ids.add(component_instance.instance_id)

    # Add instances the graph walk never reached at the front, preserving map order
    remaining_dependency_sequence = [
        instance for instance_id, instance in dependency_map.items() if instance_id not in instance_ids
    ]
    dependency_sequence = remaining_dependency_sequence + dependency_sequence

    # Find the root node and make sure it is the last node in the sequence
    dependency_sequence = ([x for x in dependency_sequence if not x.is_root] +
                           [x for x in dependency_sequence if x.is_root])

    assert len(dependency_sequence) == total_node_count, "Dependency sequence generation failed. Report as bug."

    return dependency_sequence
|