nvidia-nat 1.2.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +239 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +355 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +41 -0
- aiq/agent/react_agent/register.py +149 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +411 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +119 -0
- aiq/agent/tool_calling_agent/register.py +106 -0
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +96 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +285 -0
- aiq/builder/component_utils.py +316 -0
- aiq/builder/context.py +264 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +161 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +344 -0
- aiq/builder/function_base.py +380 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +174 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +74 -0
- aiq/builder/workflow.py +148 -0
- aiq/builder/workflow_builder.py +1117 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +213 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +297 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +246 -0
- aiq/cli/commands/uninstall.py +81 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +135 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +488 -0
- aiq/cli/type_registry.py +1000 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +694 -0
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +171 -0
- aiq/data_models/component.py +54 -0
- aiq/data_models/component_ref.py +168 -0
- aiq/data_models/config.py +406 -0
- aiq/data_models/dataset_handler.py +123 -0
- aiq/data_models/discovery_metadata.py +335 -0
- aiq/data_models/embedder.py +27 -0
- aiq/data_models/evaluate.py +127 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +72 -0
- aiq/data_models/interactive.py +246 -0
- aiq/data_models/intermediate_step.py +302 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +27 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +27 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +54 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/data_models/ttc_strategy.py +30 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +59 -0
- aiq/embedder/openai_embedder.py +43 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +60 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +254 -0
- aiq/eval/evaluate.py +506 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +45 -0
- aiq/eval/intermediate_step_adapter.py +99 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +178 -0
- aiq/eval/rag_evaluator/register.py +143 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +133 -0
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +75 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
- aiq/eval/tunable_rag_evaluator/register.py +52 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +140 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/test_time_compute/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/test_time_compute/functions/__init__.py +0 -0
- aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/test_time_compute/models/__init__.py +0 -0
- aiq/experimental/test_time_compute/models/editor_config.py +132 -0
- aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
- aiq/experimental/test_time_compute/models/search_config.py +120 -0
- aiq/experimental/test_time_compute/models/selection_config.py +154 -0
- aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
- aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
- aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
- aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
- aiq/experimental/test_time_compute/register.py +36 -0
- aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
- aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/test_time_compute/search/__init__.py +0 -0
- aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/test_time_compute/selection/__init__.py +0 -0
- aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +96 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +183 -0
- aiq/front_ends/fastapi/main.py +72 -0
- aiq/front_ends/fastapi/message_handler.py +298 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +321 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +46 -0
- aiq/llm/openai_llm.py +46 -0
- aiq/llm/register.py +23 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +76 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +322 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +265 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +310 -0
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +71 -0
- aiq/observability/register.py +96 -0
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +627 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +188 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +28 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +473 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +567 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +215 -0
- aiq/runtime/runner.py +190 -0
- aiq/runtime/session.py +158 -0
- aiq/runtime/user_metadata.py +130 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +151 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +267 -0
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- aiq/tool/code_execution/register.py +74 -0
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +255 -0
- aiq/tool/mcp/mcp_tool.py +96 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +38 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +66 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +290 -0
- aiq/utils/type_utils.py +484 -0
- aiq/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
- nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1117 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import dataclasses
|
|
17
|
+
import inspect
|
|
18
|
+
import logging
|
|
19
|
+
import warnings
|
|
20
|
+
from contextlib import AbstractAsyncContextManager
|
|
21
|
+
from contextlib import AsyncExitStack
|
|
22
|
+
from contextlib import asynccontextmanager
|
|
23
|
+
|
|
24
|
+
from aiq.authentication.interfaces import AuthProviderBase
|
|
25
|
+
from aiq.builder.builder import Builder
|
|
26
|
+
from aiq.builder.builder import UserManagerHolder
|
|
27
|
+
from aiq.builder.component_utils import ComponentInstanceData
|
|
28
|
+
from aiq.builder.component_utils import build_dependency_sequence
|
|
29
|
+
from aiq.builder.context import AIQContext
|
|
30
|
+
from aiq.builder.context import AIQContextState
|
|
31
|
+
from aiq.builder.embedder import EmbedderProviderInfo
|
|
32
|
+
from aiq.builder.framework_enum import LLMFrameworkEnum
|
|
33
|
+
from aiq.builder.function import Function
|
|
34
|
+
from aiq.builder.function import LambdaFunction
|
|
35
|
+
from aiq.builder.function_info import FunctionInfo
|
|
36
|
+
from aiq.builder.llm import LLMProviderInfo
|
|
37
|
+
from aiq.builder.retriever import RetrieverProviderInfo
|
|
38
|
+
from aiq.builder.workflow import Workflow
|
|
39
|
+
from aiq.cli.type_registry import GlobalTypeRegistry
|
|
40
|
+
from aiq.cli.type_registry import TypeRegistry
|
|
41
|
+
from aiq.data_models.authentication import AuthProviderBaseConfig
|
|
42
|
+
from aiq.data_models.component import ComponentGroup
|
|
43
|
+
from aiq.data_models.component_ref import AuthenticationRef
|
|
44
|
+
from aiq.data_models.component_ref import EmbedderRef
|
|
45
|
+
from aiq.data_models.component_ref import FunctionRef
|
|
46
|
+
from aiq.data_models.component_ref import LLMRef
|
|
47
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
48
|
+
from aiq.data_models.component_ref import ObjectStoreRef
|
|
49
|
+
from aiq.data_models.component_ref import RetrieverRef
|
|
50
|
+
from aiq.data_models.component_ref import TTCStrategyRef
|
|
51
|
+
from aiq.data_models.config import AIQConfig
|
|
52
|
+
from aiq.data_models.config import GeneralConfig
|
|
53
|
+
from aiq.data_models.embedder import EmbedderBaseConfig
|
|
54
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
55
|
+
from aiq.data_models.function_dependencies import FunctionDependencies
|
|
56
|
+
from aiq.data_models.llm import LLMBaseConfig
|
|
57
|
+
from aiq.data_models.memory import MemoryBaseConfig
|
|
58
|
+
from aiq.data_models.object_store import ObjectStoreBaseConfig
|
|
59
|
+
from aiq.data_models.retriever import RetrieverBaseConfig
|
|
60
|
+
from aiq.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
61
|
+
from aiq.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
62
|
+
from aiq.experimental.decorators.experimental_warning_decorator import aiq_experimental
|
|
63
|
+
from aiq.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
64
|
+
from aiq.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
65
|
+
from aiq.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
66
|
+
from aiq.memory.interfaces import MemoryEditor
|
|
67
|
+
from aiq.object_store.interfaces import ObjectStore
|
|
68
|
+
from aiq.observability.exporter.base_exporter import BaseExporter
|
|
69
|
+
from aiq.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
70
|
+
from aiq.profiler.utils import detect_llm_frameworks_in_build_fn
|
|
71
|
+
from aiq.utils.type_utils import override
|
|
72
|
+
|
|
73
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclasses.dataclass()
class ConfiguredTelemetryExporter:
    """Pairs a telemetry exporter's configuration with its exporter instance."""

    # Configuration object describing the telemetry exporter.
    config: TelemetryExporterBaseConfig
    # The exporter object associated with ``config``.
    instance: BaseExporter
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclasses.dataclass()
class ConfiguredFunction:
    """Pairs a function's configuration with its ``Function`` instance."""

    # Configuration object describing the function.
    config: FunctionBaseConfig
    # The callable ``Function`` associated with ``config``.
    instance: Function
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclasses.dataclass()
class ConfiguredLLM:
    """Pairs an LLM configuration with its provider-info instance."""

    # Configuration object describing the LLM.
    config: LLMBaseConfig
    # Provider information associated with ``config``.
    instance: LLMProviderInfo
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclasses.dataclass()
class ConfiguredEmbedder:
    """Pairs an embedder configuration with its provider-info instance."""

    # Configuration object describing the embedder.
    config: EmbedderBaseConfig
    # Provider information associated with ``config``.
    instance: EmbedderProviderInfo
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclasses.dataclass()
class ConfiguredMemory:
    """Pairs a memory configuration with its ``MemoryEditor`` instance."""

    # Configuration object describing the memory client.
    config: MemoryBaseConfig
    # The memory editor associated with ``config``.
    instance: MemoryEditor
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclasses.dataclass()
class ConfiguredObjectStore:
    """Pairs an object-store configuration with its ``ObjectStore`` instance."""

    # Configuration object describing the object store.
    config: ObjectStoreBaseConfig
    # The object store associated with ``config``.
    instance: ObjectStore
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclasses.dataclass()
class ConfiguredRetriever:
    """Pairs a retriever configuration with its provider-info instance."""

    # Configuration object describing the retriever.
    config: RetrieverBaseConfig
    # Provider information associated with ``config``.
    instance: RetrieverProviderInfo
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@dataclasses.dataclass()
class ConfiguredAuthProvider:
    """Pairs an authentication-provider configuration with its provider instance."""

    # Configuration object describing the authentication provider.
    config: AuthProviderBaseConfig
    # The authentication provider associated with ``config``.
    instance: AuthProviderBase
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclasses.dataclass()
class ConfiguredTTCStrategy:
    """Pairs a test-time-compute strategy configuration with its strategy instance."""

    # Configuration object describing the TTC strategy.
    config: TTCStrategyBaseConfig
    # The strategy object associated with ``config``.
    instance: StrategyBase
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# pylint: disable=too-many-public-methods
|
|
131
|
+
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
132
|
+
|
|
133
|
+
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
    """
    Create an empty builder. Components are registered afterwards, and the async resources
    (exit stack, logging handlers, telemetry exporters) are set up in ``__aenter__``.

    Parameters
    ----------
    general_config : GeneralConfig | None, optional
        General settings for the workflow, including telemetry. A default ``GeneralConfig()`` is used
        when None, by default None
    registry : TypeRegistry | None, optional
        Type registry consulted to resolve build functions (e.g. logging methods). Falls back to
        ``GlobalTypeRegistry.get()`` when None, by default None
    """

    self.general_config = GeneralConfig() if general_config is None else general_config

    self._registry = GlobalTypeRegistry.get() if registry is None else registry

    # Telemetry sinks, keyed by name
    self._logging_handlers: dict[str, logging.Handler] = {}
    self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}

    # Registered functions, plus the function designated as the workflow entry point
    self._functions: dict[str, ConfiguredFunction] = {}
    self._workflow: ConfiguredFunction | None = None

    # Remaining component kinds, each keyed by component name
    self._llms: dict[str, ConfiguredLLM] = {}
    self._auth_providers: dict[str, ConfiguredAuthProvider] = {}
    self._embedders: dict[str, ConfiguredEmbedder] = {}
    self._memory_clients: dict[str, ConfiguredMemory] = {}
    self._object_stores: dict[str, ConfiguredObjectStore] = {}
    self._retrievers: dict[str, ConfiguredRetriever] = {}
    self._ttc_strategies: dict[str, ConfiguredTTCStrategy] = {}

    self._context_state = AIQContextState.get()

    # Created in __aenter__; holds the async contexts entered while the builder is open
    self._exit_stack: AsyncExitStack | None = None

    # Function name -> the other components/functions it depends on
    self.function_dependencies: dict[str, FunctionDependencies] = {}
    # NOTE(review): presumably the name of the function currently being built; maintained elsewhere in the class
    self.current_function_building: str | None = None
|
|
166
|
+
|
|
167
|
+
async def __aenter__(self):
    """
    Open the builder: create the exit stack, attach every configured logging handler to the root
    logger, and register every configured telemetry exporter.

    ``async with`` does not call ``__aexit__`` when ``__aenter__`` raises. If any step here fails,
    this method therefore detaches the handlers it attached and closes the exit stack itself
    before re-raising, so partially-built resources are not leaked.

    Returns
    -------
    WorkflowBuilder
        This builder.

    Raises
    ------
    TypeError
        If a configured logging method does not yield a ``logging.Handler``.
    """

    self._exit_stack = AsyncExitStack()

    # Get the telemetry info from the config
    telemetry_config = self.general_config.telemetry

    root_logger = logging.getLogger()

    # Keys of handlers attached during this call, so a failure can undo exactly those
    attached_keys: list[str] = []

    try:
        for key, logging_config in telemetry_config.logging.items():
            # Use the same pattern as tracing, but for logging
            logging_info = self._registry.get_logging_method(type(logging_config))
            handler = await self._exit_stack.enter_async_context(logging_info.build_fn(logging_config, self))

            # Type check
            if not isinstance(handler, logging.Handler):
                raise TypeError(f"Expected a logging.Handler from {key}, got {type(handler)}")

            # Store them in a dict so we can un-register them if needed
            self._logging_handlers[key] = handler
            attached_keys.append(key)

            # Now attach to AIQ Toolkit's root logger
            root_logger.addHandler(handler)

        # Add the telemetry exporters
        for key, telemetry_exporter_config in telemetry_config.tracing.items():
            await self.add_telemetry_exporter(key, telemetry_exporter_config)

    except BaseException:
        # __aexit__ will not run for a failed __aenter__, so roll back the partial setup here.
        for key in attached_keys:
            root_logger.removeHandler(self._logging_handlers.pop(key))
        # Closes every context entered so far (logging handlers and anything entered by
        # add_telemetry_exporter on this stack).
        await self._exit_stack.aclose()
        raise

    return self
|
|
194
|
+
|
|
195
|
+
async def __aexit__(self, *exc_details):
    """
    Close the builder: detach the logging handlers that ``__aenter__`` added to the root logger, then
    unwind the exit stack, forwarding any exception details to the contexts it holds.
    """

    assert self._exit_stack is not None, "Exit stack not initialized"

    # Detach handlers before the stack tears down their contexts
    root_logger = logging.getLogger()
    for attached_handler in self._logging_handlers.values():
        root_logger.removeHandler(attached_handler)

    # NOTE(review): the stack's return value (whether an inner context suppressed the exception)
    # is discarded, so exceptions always propagate out of this builder.
    await self._exit_stack.__aexit__(*exc_details)
|
|
203
|
+
|
|
204
|
+
def build(self, entry_function: str | None = None) -> Workflow:
    """
    Assemble a workflow object from every component added to this builder.

    Parameters
    ----------
    entry_function : str | None, optional
        Name of the registered function to use as the workflow's entry point. When None, the function set via
        ``set_workflow`` is used. By default None

    Returns
    -------
    Workflow
        The assembled workflow.

    Raises
    ------
    ValueError
        If no workflow has been set, or if ``entry_function`` does not name a registered function.
    """
    if self._workflow is None:
        raise ValueError("Must set a workflow before building")

    def configs_of(registry):
        # name -> configuration object the component was built from
        return {key: entry.config for key, entry in registry.items()}

    def instances_of(registry):
        # name -> live component instance
        return {key: entry.instance for key, entry in registry.items()}

    # Snapshot the configuration of every added component
    config = AIQConfig(general=self.general_config,
                       functions=configs_of(self._functions),
                       workflow=self._workflow.config,
                       llms=configs_of(self._llms),
                       embedders=configs_of(self._embedders),
                       memory=configs_of(self._memory_clients),
                       object_stores=configs_of(self._object_stores),
                       retrievers=configs_of(self._retrievers),
                       ttc_strategies=configs_of(self._ttc_strategies))

    entry_fn_obj = self.get_workflow() if entry_function is None else self.get_function(entry_function)

    return Workflow.from_entry_fn(config=config,
                                  entry_fn=entry_fn_obj,
                                  functions=instances_of(self._functions),
                                  llms=instances_of(self._llms),
                                  embeddings=instances_of(self._embedders),
                                  memory=instances_of(self._memory_clients),
                                  object_stores=instances_of(self._object_stores),
                                  telemetry_exporters=instances_of(self._telemetry_exporters),
                                  retrievers=instances_of(self._retrievers),
                                  ttc_strategies=instances_of(self._ttc_strategies),
                                  context_state=self._context_state)
|
|
302
|
+
|
|
303
|
+
def _get_exit_stack(self) -> AsyncExitStack:
    """
    Return the exit stack created by ``__aenter__``.

    Raises
    ------
    ValueError
        If the builder has not been entered with ``async with``.
    """
    stack = self._exit_stack
    if stack is not None:
        return stack

    raise ValueError(
        "Exit stack not initialized. Did you forget to call `async with WorkflowBuilder() as builder`?")
|
|
310
|
+
|
|
311
|
+
async def _build_function(self, name: str, config: FunctionBaseConfig) -> ConfiguredFunction:
    """
    Build one function from its configuration and wrap the result in a ``ConfiguredFunction``.

    The registered build function is entered on the builder's exit stack. A ``ChildBuilder`` is passed to it,
    and the dependencies recorded by that child builder are stored under ``name``.

    Parameters
    ----------
    name : str
        Instance name the function is registered under.
    config : FunctionBaseConfig
        Configuration whose type selects the registered build function.

    Returns
    -------
    ConfiguredFunction
        The configuration paired with the built ``Function`` instance.

    Raises
    ------
    ValueError
        If the build function yields something other than a plain function, a ``FunctionInfo`` or a ``Function``.
    """
    registration = self._registry.get_function(type(config))

    child = ChildBuilder(self)

    # Framework-specific settings (e.g. LlamaIndex) must be in place before the function is built, and we can't
    # know ahead of time which functions instantiate framework agents, so wrap every build function. Only the
    # first call is slower (import cost).
    llm_instances = {llm_name: configured.instance for llm_name, configured in self._llms.items()}
    frameworks = detect_llm_frameworks_in_build_fn(registration)
    wrapped_build_fn = chain_wrapped_build_fn(registration.build_fn, llm_instances, frameworks)

    # Mark what is being built and seed an empty dependency set before the build starts.
    # NOTE(review): these are keyed by `config.type`, while the final dependencies below are stored under `name`;
    # confirm whether other code relies on the `config.type` key before unifying them.
    self.current_function_building = config.type
    self.function_dependencies[config.type] = FunctionDependencies()

    result = await self._get_exit_stack().enter_async_context(wrapped_build_fn(config, child))

    self.function_dependencies[name] = child.dependencies

    # Normalise: plain function -> FunctionInfo -> Function
    if inspect.isfunction(result):
        result = FunctionInfo.from_fn(result)

    if isinstance(result, FunctionInfo):
        result = LambdaFunction.from_info(config=config, info=result, instance_name=name)

    if not isinstance(result, Function):
        raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
                         f"returned from the function builder. Got {type(result)}")

    return ConfiguredFunction(config=config, instance=result)
|
|
350
|
+
|
|
351
|
+
@override
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
    """
    Build the function described by ``config`` and register it under ``name``.

    Returns
    -------
    Function
        The built function instance.

    Raises
    ------
    ValueError
        If a function named ``name`` is already registered.
    """
    if name in self._functions:
        raise ValueError(f"Function `{name}` already exists in the list of functions")

    configured = await self._build_function(name=name, config=config)
    self._functions[name] = configured

    return configured.instance
|
|
362
|
+
|
|
363
|
+
@override
def get_function(self, name: str | FunctionRef) -> Function:
    """
    Return the function instance registered under ``name``.

    Raises
    ------
    ValueError
        If no function with that name has been added.
    """
    if name in self._functions:
        return self._functions[name].instance

    raise ValueError(f"Function `{name}` not found")
|
|
370
|
+
|
|
371
|
+
@override
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
    """
    Return the configuration the function registered under ``name`` was built from.

    Raises
    ------
    ValueError
        If no function with that name has been added.
    """
    if name in self._functions:
        return self._functions[name].config

    raise ValueError(f"Function `{name}` not found")
|
|
377
|
+
|
|
378
|
+
@override
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
    """
    Build the workflow (entry) function from ``config`` under the reserved name ``"<workflow>"``.

    A warning is emitted if a workflow had already been set; the previous one is replaced.

    Returns
    -------
    Function
        The built workflow function instance.
    """
    if self._workflow is not None:
        warnings.warn("Overwriting existing workflow")

    configured = await self._build_function(name="<workflow>", config=config)
    self._workflow = configured

    return configured.instance
|
|
389
|
+
|
|
390
|
+
@override
def get_workflow(self) -> Function:
    """
    Return the workflow function instance.

    Raises
    ------
    ValueError
        If ``set_workflow`` has not been called.
    """
    configured = self._workflow
    if configured is None:
        raise ValueError("No workflow set")
    return configured.instance
|
|
397
|
+
|
|
398
|
+
@override
def get_workflow_config(self) -> FunctionBaseConfig:
    """
    Return the configuration the workflow function was built from.

    Raises
    ------
    ValueError
        If ``set_workflow`` has not been called.
    """
    configured = self._workflow
    if configured is None:
        raise ValueError("No workflow set")
    return configured.config
|
|
404
|
+
|
|
405
|
+
@override
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
    """
    Return the dependencies recorded while ``fn_name`` was being built.

    Raises
    ------
    KeyError
        If nothing has been recorded under ``fn_name`` (unlike the other getters, which raise ValueError).
    """
    recorded = self.function_dependencies[fn_name]
    return recorded
|
|
408
|
+
|
|
409
|
+
@override
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
    """
    Wrap a registered function as a tool for the requested LLM framework.

    Parameters
    ----------
    fn_name : str | FunctionRef
        Name of a function previously added to this builder.
    wrapper_type : LLMFrameworkEnum | str
        Framework whose registered tool wrapper should be used.

    Returns
    -------
    Whatever the registered tool wrapper's ``build_fn`` returns for this function.

    Raises
    ------
    ValueError
        If ``fn_name`` is not a registered function.
    Exception
        Any error from the registry lookup or the wrapper is logged and re-raised unchanged.
    """
    if fn_name not in self._functions:
        raise ValueError(f"Function `{fn_name}` not found in list of functions")

    fn = self._functions[fn_name]

    try:
        # Using the registry, get the tool wrapper for the requested framework
        tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)

        # Wrap in the correct wrapper
        return tool_wrapper_reg.build_fn(fn_name, fn.instance, self)
    except Exception:
        logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
        # Bare raise keeps the original traceback without adding a re-raise frame
        raise
|
|
426
|
+
|
|
427
|
+
@override
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
    """
    Build the LLM provider described by ``config`` and register it under ``name``.

    The provider's build function is entered on the builder's exit stack, so it is closed with the builder.

    Raises
    ------
    ValueError
        If an LLM named ``name`` is already registered.
    Exception
        Any error while building the provider is logged and re-raised unchanged.
    """
    if name in self._llms:
        raise ValueError(f"LLM `{name}` already exists in the list of LLMs")

    try:
        llm_info = self._registry.get_llm_provider(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(llm_info.build_fn(config, self))

        self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
    except Exception:
        logger.error("Error adding llm `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
442
|
+
|
|
443
|
+
@override
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
    """
    Create a framework-specific client for a registered LLM.

    A new client is built on each call from the LLM's stored config and entered on the builder's exit stack.

    Raises
    ------
    ValueError
        If ``llm_name`` is not registered.
    Exception
        Any error while building the client is logged and re-raised unchanged.
    """
    if llm_name not in self._llms:
        raise ValueError(f"LLM `{llm_name}` not found")

    try:
        # Get llm info
        llm_info = self._llms[llm_name]

        # Generate wrapped client from registered client info
        client_info = self._registry.get_llm_client(config_type=type(llm_info.config), wrapper_type=wrapper_type)

        client = await self._get_exit_stack().enter_async_context(client_info.build_fn(llm_info.config, self))

        # Return a frameworks specific client
        return client
    except Exception:
        logger.error("Error getting llm `%s` with wrapper `%s`", llm_name, wrapper_type, exc_info=True)
        raise
|
|
463
|
+
|
|
464
|
+
@override
def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
    """
    Return the configuration the LLM registered under ``llm_name`` was built from.

    Raises
    ------
    ValueError
        If no LLM with that name has been added.
    """
    if llm_name in self._llms:
        # The LLM's configuration object
        return self._llms[llm_name].config

    raise ValueError(f"LLM `{llm_name}` not found")
|
|
472
|
+
|
|
473
|
+
@aiq_experimental(feature_name="Authentication")
@override
async def add_auth_provider(self, name: str | AuthenticationRef,
                            config: AuthProviderBaseConfig) -> AuthProviderBase:
    """
    Add an authentication provider to the workflow by constructing it from a configuration object.

    Note: The Authentication Provider API is experimental and the API may change in future releases.

    Parameters
    ----------
    name : str | AuthenticationRef
        The name of the authentication provider to add.
    config : AuthProviderBaseConfig
        The configuration for the authentication provider.

    Returns
    -------
    AuthProviderBase
        The authentication provider instance.

    Raises
    ------
    ValueError
        If the authentication provider is already in the list of authentication providers.
    Exception
        Any error while building the provider is logged and re-raised unchanged.
    """
    if name in self._auth_providers:
        raise ValueError(f"Authentication `{name}` already exists in the list of Authentication Providers")

    try:
        authentication_info = self._registry.get_auth_provider(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(authentication_info.build_fn(config, self))

        self._auth_providers[name] = ConfiguredAuthProvider(config=config, instance=info_obj)

        return info_obj
    except Exception:
        logger.error("Error adding authentication `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
514
|
+
|
|
515
|
+
@override
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
    """
    Look up a previously added authentication provider by name.

    Note: The Authentication Provider API is experimental and the API may change in future releases.

    Parameters
    ----------
    auth_provider_name : str
        Name the provider was registered under.

    Returns
    -------
    AuthProviderBase
        The registered authentication provider instance.

    Raises
    ------
    ValueError
        If no provider with that name has been added.
    """
    if auth_provider_name in self._auth_providers:
        return self._auth_providers[auth_provider_name].instance

    raise ValueError(f"Authentication `{auth_provider_name}` not found")
|
|
542
|
+
|
|
543
|
+
@override
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
    """
    Build the embedder provider described by ``config`` and register it under ``name``.

    Raises
    ------
    ValueError
        If an embedder named ``name`` is already registered.
    Exception
        Any error while building the provider is logged and re-raised unchanged.
    """
    if name in self._embedders:
        raise ValueError(f"Embedder `{name}` already exists in the list of embedders")

    try:
        embedder_info = self._registry.get_embedder_provider(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(embedder_info.build_fn(config, self))

        self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
    except Exception:
        logger.error("Error adding embedder `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
559
|
+
|
|
560
|
+
@override
async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str):
    """
    Create a framework-specific client for a registered embedder.

    A new client is built on each call from the embedder's stored config and entered on the builder's exit stack.

    Raises
    ------
    ValueError
        If ``embedder_name`` is not registered.
    Exception
        Any error while building the client is logged and re-raised unchanged.
    """
    if embedder_name not in self._embedders:
        raise ValueError(f"Embedder `{embedder_name}` not found")

    try:
        # Get embedder info
        embedder_info = self._embedders[embedder_name]

        # Generate wrapped client from registered client info
        client_info = self._registry.get_embedder_client(config_type=type(embedder_info.config),
                                                         wrapper_type=wrapper_type)
        client = await self._get_exit_stack().enter_async_context(client_info.build_fn(embedder_info.config, self))

        # Return a frameworks specific client
        return client
    except Exception:
        logger.error("Error getting embedder `%s` with wrapper `%s`", embedder_name, wrapper_type, exc_info=True)
        raise
|
|
580
|
+
|
|
581
|
+
@override
def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
    """
    Return the configuration the embedder registered under ``embedder_name`` was built from.

    Raises
    ------
    ValueError
        If no embedder with that name has been added.
    """
    if embedder_name not in self._embedders:
        # Report the component kind correctly (previously said "Tool"), matching get_embedder
        raise ValueError(f"Embedder `{embedder_name}` not found")

    # Return the embedder configuration object
    return self._embedders[embedder_name].config
|
|
589
|
+
|
|
590
|
+
@override
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
    """
    Build the memory client described by ``config`` and register it under ``name``.

    Returns
    -------
    MemoryEditor
        The built memory client.

    Raises
    ------
    ValueError
        If a memory client named ``name`` is already registered.
    Exception
        Any error while building the client is logged and re-raised unchanged, as in the other ``add_*`` methods.
    """
    if name in self._memory_clients:
        raise ValueError(f"Memory `{name}` already exists in the list of memories")

    try:
        memory_info = self._registry.get_memory(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(memory_info.build_fn(config, self))

        self._memory_clients[name] = ConfiguredMemory(config=config, instance=info_obj)

        return info_obj
    except Exception:
        logger.error("Error adding memory client `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
603
|
+
|
|
604
|
+
@override
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
    """
    Return the instantiated memory client registered under ``memory_name``.

    Raises
    ------
    ValueError
        If no memory client with that name has been added.
    """
    if memory_name in self._memory_clients:
        return self._memory_clients[memory_name].instance

    raise ValueError(f"Memory `{memory_name}` not found")
|
|
613
|
+
|
|
614
|
+
@override
def get_memory_client_config(self, memory_name: str | MemoryRef) -> MemoryBaseConfig:
    """
    Return the configuration the memory client registered under ``memory_name`` was built from.

    Raises
    ------
    ValueError
        If no memory client with that name has been added.
    """
    if memory_name in self._memory_clients:
        # The memory client's configuration object
        return self._memory_clients[memory_name].config

    raise ValueError(f"Memory `{memory_name}` not found")
|
|
622
|
+
|
|
623
|
+
@override
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
    """
    Build the object store described by ``config`` and register it under ``name``.

    Returns
    -------
    ObjectStore
        The built object store client.

    Raises
    ------
    ValueError
        If an object store named ``name`` is already registered.
    Exception
        Any error while building the store is logged and re-raised unchanged, as in the other ``add_*`` methods.
    """
    if name in self._object_stores:
        raise ValueError(f"Object store `{name}` already exists in the list of object stores")

    try:
        object_store_info = self._registry.get_object_store(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(object_store_info.build_fn(config, self))

        self._object_stores[name] = ConfiguredObjectStore(config=config, instance=info_obj)

        return info_obj
    except Exception:
        logger.error("Error adding object store `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
635
|
+
|
|
636
|
+
@override
async def get_object_store_client(self, object_store_name: str | ObjectStoreRef) -> ObjectStore:
    """
    Return the object store client registered under ``object_store_name``.

    Raises
    ------
    ValueError
        If no object store with that name has been added.
    """
    if object_store_name in self._object_stores:
        return self._object_stores[object_store_name].instance

    raise ValueError(f"Object store `{object_store_name}` not found")
|
|
642
|
+
|
|
643
|
+
@override
def get_object_store_config(self, object_store_name: str | ObjectStoreRef) -> ObjectStoreBaseConfig:
    """
    Return the configuration the object store registered under ``object_store_name`` was built from.

    Raises
    ------
    ValueError
        If no object store with that name has been added.
    """
    if object_store_name in self._object_stores:
        return self._object_stores[object_store_name].config

    raise ValueError(f"Object store `{object_store_name}` not found")
|
|
649
|
+
|
|
650
|
+
@override
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
    """
    Build the retriever provider described by ``config`` and register it under ``name``.

    Raises
    ------
    ValueError
        If a retriever named ``name`` is already registered.
    Exception
        Any error while building the provider is logged and re-raised unchanged.
    """
    if name in self._retrievers:
        raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")

    try:
        retriever_info = self._registry.get_retriever_provider(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(retriever_info.build_fn(config, self))

        self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
    except Exception:
        logger.error("Error adding retriever `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
669
|
+
|
|
670
|
+
@override
async def get_retriever(self,
                        retriever_name: str | RetrieverRef,
                        wrapper_type: LLMFrameworkEnum | str | None = None):
    """
    Create a client for a registered retriever, optionally wrapped for a specific framework.

    A new client is built on each call from the retriever's stored config and entered on the builder's exit stack.

    Raises
    ------
    ValueError
        If ``retriever_name`` is not registered.
    Exception
        Any error while building the client is logged and re-raised unchanged.
    """
    if retriever_name not in self._retrievers:
        raise ValueError(f"Retriever '{retriever_name}' not found")

    try:
        # Get retriever info
        retriever_info = self._retrievers[retriever_name]

        # Generate wrapped client from registered client info
        client_info = self._registry.get_retriever_client(config_type=type(retriever_info.config),
                                                          wrapper_type=wrapper_type)

        client = await self._get_exit_stack().enter_async_context(client_info.build_fn(retriever_info.config, self))

        # Return a frameworks specific client
        return client
    except Exception:
        logger.error("Error getting retriever `%s` with wrapper `%s`", retriever_name, wrapper_type, exc_info=True)
        raise
|
|
693
|
+
|
|
694
|
+
@override
async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
    """
    Return the configuration the retriever registered under ``retriever_name`` was built from.

    Raises
    ------
    ValueError
        If no retriever with that name has been added.
    """
    if retriever_name in self._retrievers:
        return self._retrievers[retriever_name].config

    raise ValueError(f"Retriever `{retriever_name}` not found")
|
|
701
|
+
|
|
702
|
+
@aiq_experimental(feature_name="TTC")
@override
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
    """
    Build the TTC strategy described by ``config`` and register it under ``name``.

    Raises
    ------
    ValueError
        If a TTC strategy named ``name`` is already registered.
    Exception
        Any error while building the strategy is logged and re-raised unchanged.
    """
    if name in self._ttc_strategies:
        raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")

    try:
        ttc_strategy_info = self._registry.get_ttc_strategy(type(config))

        info_obj = await self._get_exit_stack().enter_async_context(ttc_strategy_info.build_fn(config, self))

        self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
    except Exception:
        logger.error("Error adding TTC strategy `%s` with config `%s`", name, config, exc_info=True)
        raise
|
|
719
|
+
|
|
720
|
+
@override
async def get_ttc_strategy(self,
                           strategy_name: str | TTCStrategyRef,
                           pipeline_type: PipelineTypeEnum,
                           stage_type: StageTypeEnum) -> StrategyBase:
    """
    Return a registered TTC strategy after checking it fits the requested stage and pipeline.

    On success, the strategy instance's pipeline type is set to ``pipeline_type`` before it is returned.

    Raises
    ------
    ValueError
        If the strategy is not registered, or its stage type / supported pipeline types do not match.
    """
    if strategy_name not in self._ttc_strategies:
        raise ValueError(f"TTC strategy '{strategy_name}' not found")

    try:
        # Get strategy info
        ttc_strategy_info = self._ttc_strategies[strategy_name]

        instance = ttc_strategy_info.instance

        if not stage_type == instance.stage_type():
            raise ValueError(f"TTC strategy '{strategy_name}' is not compatible with stage type '{stage_type}'")

        if pipeline_type not in instance.supported_pipeline_types():
            raise ValueError(
                f"TTC strategy '{strategy_name}' is not compatible with pipeline type '{pipeline_type}'")

        instance.set_pipeline_type(pipeline_type)

        return instance
    except Exception:
        logger.error("Error getting TTC strategy `%s`", strategy_name, exc_info=True)
        raise
|
|
748
|
+
|
|
749
|
+
@override
async def get_ttc_strategy_config(self,
                                  strategy_name: str | TTCStrategyRef,
                                  pipeline_type: PipelineTypeEnum,
                                  stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
    """
    Return a registered TTC strategy's configuration after checking it fits the requested stage and pipeline.

    Unlike ``get_ttc_strategy``, the instance's pipeline type is not modified.

    Raises
    ------
    ValueError
        If the strategy is not registered, or its stage type / supported pipeline types do not match.
    """
    if strategy_name not in self._ttc_strategies:
        raise ValueError(f"TTC strategy '{strategy_name}' not found")

    configured = self._ttc_strategies[strategy_name]
    strategy = configured.instance
    strategy_config = configured.config

    stage_matches = stage_type == strategy.stage_type()
    if not stage_matches:
        raise ValueError(f"TTC strategy '{strategy_name}' is not compatible with stage type '{stage_type}'")

    pipeline_supported = pipeline_type in strategy.supported_pipeline_types()
    if not pipeline_supported:
        raise ValueError(f"TTC strategy '{strategy_name}' is not compatible with pipeline type '{pipeline_type}'")

    return strategy_config
|
|
768
|
+
|
|
769
|
+
@override
def get_user_manager(self):
    """Return a ``UserManagerHolder`` bound to an ``AIQContext`` wrapping this builder's context state."""
    context = AIQContext(self._context_state)
    return UserManagerHolder(context=context)
|
|
772
|
+
|
|
773
|
+
async def add_telemetry_exporter(self, name: str, config: TelemetryExporterBaseConfig) -> None:
    """Add a configured telemetry exporter to the builder.

    Args:
        name (str): The name of the telemetry exporter
        config (TelemetryExporterBaseConfig): The configuration for the exporter

    Raises:
        ValueError: If an exporter with the same name has already been added.
    """
    if name in self._telemetry_exporters:
        raise ValueError(f"Telemetry exporter '{name}' already exists in the list of telemetry exporters")

    registration = self._registry.get_telemetry_exporter(type(config))

    # Create the exporter's async context manager, then enter it on the builder's exit stack so the exporter is
    # torn down when the builder closes. (No locking is involved here.)
    exporter_cm = registration.build_fn(config, self)
    exporter = await self._get_exit_stack().enter_async_context(exporter_cm)

    self._telemetry_exporters[name] = ConfiguredTelemetryExporter(config=config, instance=exporter)
|
|
791
|
+
|
|
792
|
+
def _log_build_failure(self,
                       component_name: str,
                       component_type: str,
                       completed_components: list[tuple[str, str]],
                       remaining_components: list[tuple[str, str]],
                       original_error: Exception) -> None:
    """
    Log, at error level, which component failed, what had already been built, what was still pending, and the
    original exception with its traceback.

    Args:
        component_name (str): The name of the component that failed to build
        component_type (str): The type of the component that failed to build
        completed_components (list[tuple[str, str]]): (name, type) tuples for successfully built components
        remaining_components (list[tuple[str, str]]): (name, type) tuples for components still to be built
        original_error (Exception): The original exception that caused the failure
    """
    logger.error("Failed to initialize component %s (%s)", component_name, component_type)

    # (components, header when non-empty, message when empty)
    sections = (
        (completed_components, "Successfully built components:",
         "No components were successfully built before this failure"),
        (remaining_components, "Remaining components to build:", "No remaining components to build"),
    )
    for components, header, empty_message in sections:
        if not components:
            logger.error(empty_message)
            continue
        logger.error(header)
        for entry_name, entry_type in components:
            logger.error("- %s (%s)", entry_name, entry_type)

    logger.error("Original error:", exc_info=original_error)
|
|
825
|
+
|
|
826
|
+
def _log_build_failure_component(self,
                                 failing_component: ComponentInstanceData,
                                 completed_components: list[tuple[str, str]],
                                 remaining_components: list[tuple[str, str]],
                                 original_error: Exception) -> None:
    """
    Log a build failure for a regular component, taking its name and group value from ``failing_component``.

    Args:
        failing_component (ComponentInstanceData): The component that failed to build
        completed_components (list[tuple[str, str]]): (name, type) tuples for successfully built components
        remaining_components (list[tuple[str, str]]): (name, type) tuples for components still to be built
        original_error (Exception): The original exception that caused the failure
    """
    self._log_build_failure(failing_component.name,
                            failing_component.component_group.value,
                            completed_components,
                            remaining_components,
                            original_error)
|
|
848
|
+
|
|
849
|
+
def _log_build_failure_workflow(self,
                                completed_components: list[tuple[str, str]],
                                remaining_components: list[tuple[str, str]],
                                original_error: Exception) -> None:
    """
    Log a build failure of the workflow itself, reported as component ``<workflow>`` of type ``workflow``.

    Args:
        completed_components (list[tuple[str, str]]): (name, type) tuples for successfully built components
        remaining_components (list[tuple[str, str]]): (name, type) tuples for components still to be built
        original_error (Exception): The original exception that caused the failure
    """
    self._log_build_failure(component_name="<workflow>",
                            component_type="workflow",
                            completed_components=completed_components,
                            remaining_components=remaining_components,
                            original_error=original_error)
|
|
862
|
+
|
|
863
|
+
async def populate_builder(self, config: AIQConfig, skip_workflow: bool = False):
    """
    Populate the builder with every component in ``config`` and, optionally, the workflow.

    Components are added in the order produced by ``build_dependency_sequence``. The root function is not added
    as a regular function; it becomes the workflow via ``set_workflow`` unless ``skip_workflow`` is True. On any
    failure, the progress so far is logged and the exception is re-raised.

    Args:
        config (AIQConfig): The configuration object containing component definitions.
        skip_workflow (bool): If True, skips the workflow instantiation step. Defaults to False.

    Raises:
        ValueError: If a component belongs to an unknown component group.
    """
    workflow_entry = ("<workflow>", "workflow")

    # Generate the build sequence
    build_sequence = build_dependency_sequence(config)

    # (name, group) bookkeeping used only for failure reporting
    completed_components: list[tuple[str, str]] = []
    remaining_components: list[tuple[str, str]] = [
        (str(comp.name), comp.component_group.value) for comp in build_sequence if not comp.is_root
    ]
    if not skip_workflow:
        remaining_components.append(workflow_entry)

    # Component group -> builder method that instantiates it
    adders = {
        ComponentGroup.LLMS: self.add_llm,
        ComponentGroup.EMBEDDERS: self.add_embedder,
        ComponentGroup.MEMORY: self.add_memory_client,
        ComponentGroup.OBJECT_STORES: self.add_object_store,
        ComponentGroup.RETRIEVERS: self.add_retriever,
        ComponentGroup.FUNCTIONS: self.add_function,
        ComponentGroup.TTC_STRATEGIES: self.add_ttc_strategy,
        ComponentGroup.AUTHENTICATION: self.add_auth_provider,
    }

    for component_instance in build_sequence:
        try:
            group = component_instance.component_group
            progress_entry = None
            if not component_instance.is_root:
                progress_entry = (str(component_instance.name), group.value)
                remaining_components.remove(progress_entry)

            adder = adders.get(group)
            if adder is None:
                raise ValueError(f"Unknown component group {component_instance.component_group}")

            # The root function is set as the workflow later instead of being added as a function
            if not (group == ComponentGroup.FUNCTIONS and component_instance.is_root):
                await adder(component_instance.name, component_instance.config)

            if progress_entry is not None:
                completed_components.append(progress_entry)
        except Exception as e:
            self._log_build_failure_component(component_instance, completed_components, remaining_components, e)
            raise

    if skip_workflow:
        return

    # Instantiate the workflow
    try:
        remaining_components.remove(workflow_entry)
        await self.set_workflow(config.workflow)
        completed_components.append(workflow_entry)
    except Exception as e:
        self._log_build_failure_workflow(completed_components, remaining_components, e)
        raise
|
|
937
|
+
|
|
938
|
+
@classmethod
@asynccontextmanager
async def from_config(cls, config: AIQConfig):
    """
    Build, populate and yield a builder for ``config`` as an async context manager.

    A builder is constructed from ``config.general`` and entered with ``async with``.
    ``populate_builder(config)`` runs inside that context before the builder is
    yielded. When the caller's ``async with`` block finishes, the builder's own
    context is exited as well.

    Args:
        config: The full configuration; ``config.general`` seeds the builder and
            the whole object is handed to ``populate_builder``.

    Yields:
        The populated builder instance.
    """
    builder_context = cls(general_config=config.general)
    async with builder_context as populated:
        await populated.populate_builder(config)
        yield populated
|
|
945
|
+
|
|
946
|
+
|
|
947
|
+
class ChildBuilder(Builder):
|
|
948
|
+
|
|
949
|
+
def __init__(self, workflow_builder: WorkflowBuilder) -> None:
    """
    Wrap ``workflow_builder`` and create a fresh dependency tracker.

    Args:
        workflow_builder: The parent builder that every operation is delegated to.
    """
    # Tracker for names handed out through this child builder's lookup methods.
    self._dependencies = FunctionDependencies()
    self._workflow_builder = workflow_builder
|
|
954
|
+
|
|
955
|
+
@property
def dependencies(self) -> FunctionDependencies:
    """The dependency tracker that lookups such as ``get_function``/``get_llm`` record names into."""
    tracked = self._dependencies
    return tracked
|
|
958
|
+
|
|
959
|
+
@override
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
    """Create function ``name`` via the parent builder; nothing is recorded as a dependency."""
    parent = self._workflow_builder
    return await parent.add_function(name, config)
|
|
962
|
+
|
|
963
|
+
@override
def get_function(self, name: str) -> Function:
    """
    Fetch function ``name`` from the parent builder and record it as a dependency.

    The name is recorded only after the parent lookup returns, so a lookup that
    raises leaves the dependency tracker untouched.
    """
    # Requesting another function is taken as evidence that the caller uses it.
    resolved = self._workflow_builder.get_function(name)
    self._dependencies.add_function(name)
    return resolved
|
|
971
|
+
|
|
972
|
+
@override
def get_function_config(self, name: str) -> FunctionBaseConfig:
    """Return the parent builder's config for function ``name`` (not recorded as a dependency)."""
    parent = self._workflow_builder
    return parent.get_function_config(name)
|
|
975
|
+
|
|
976
|
+
@override
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
    """Set the workflow on the parent builder from ``config`` and return the parent's result."""
    parent = self._workflow_builder
    return await parent.set_workflow(config)
|
|
979
|
+
|
|
980
|
+
@override
def get_workflow(self) -> Function:
    """Return whatever the parent builder reports as the workflow."""
    parent = self._workflow_builder
    return parent.get_workflow()
|
|
983
|
+
|
|
984
|
+
@override
def get_workflow_config(self) -> FunctionBaseConfig:
    """Return whatever the parent builder reports as the workflow's config."""
    parent = self._workflow_builder
    return parent.get_workflow_config()
|
|
987
|
+
|
|
988
|
+
@override
def get_tool(self, fn_name: str, wrapper_type: LLMFrameworkEnum | str):
    """
    Fetch function ``fn_name`` as a tool for ``wrapper_type`` and record it as a function dependency.

    The dependency is recorded only after the parent lookup returns.
    """
    # Asking for a function as a tool counts as using that function.
    wrapped_tool = self._workflow_builder.get_tool(fn_name, wrapper_type)
    self._dependencies.add_function(fn_name)
    return wrapped_tool
|
|
996
|
+
|
|
997
|
+
@override
async def add_llm(self, name: str, config: LLMBaseConfig):
    """Register LLM ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_llm(name, config)
|
|
1000
|
+
|
|
1001
|
+
@override
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig):
    """Register auth provider ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_auth_provider(name, config)
|
|
1004
|
+
|
|
1005
|
+
@override
async def get_auth_provider(self, auth_provider_name: str):
    """Fetch auth provider ``auth_provider_name`` from the parent builder."""
    # NOTE(review): unlike get_llm/get_embedder, nothing is recorded in
    # self._dependencies here — confirm this is intentional.
    parent = self._workflow_builder
    return await parent.get_auth_provider(auth_provider_name)
|
|
1008
|
+
|
|
1009
|
+
@override
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str):
    """
    Fetch LLM ``llm_name`` wrapped for ``wrapper_type`` and record it as an LLM dependency.

    The dependency is recorded only after the awaited parent lookup returns.
    """
    llm_client = await self._workflow_builder.get_llm(llm_name, wrapper_type)
    self._dependencies.add_llm(llm_name)
    return llm_client
|
|
1016
|
+
|
|
1017
|
+
@override
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
    """Return the parent builder's config for LLM ``llm_name`` (not recorded as a dependency)."""
    parent = self._workflow_builder
    return parent.get_llm_config(llm_name)
|
|
1020
|
+
|
|
1021
|
+
@override
async def add_embedder(self, name: str, config: EmbedderBaseConfig):
    """Register embedder ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_embedder(name, config)
|
|
1024
|
+
|
|
1025
|
+
@override
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str):
    """
    Fetch embedder ``embedder_name`` wrapped for ``wrapper_type`` and record it as a dependency.

    The dependency is recorded only after the awaited parent lookup returns.
    """
    embedder_client = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
    self._dependencies.add_embedder(embedder_name)
    return embedder_client
|
|
1032
|
+
|
|
1033
|
+
@override
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
    """Return the parent builder's config for embedder ``embedder_name`` (not recorded as a dependency)."""
    parent = self._workflow_builder
    return parent.get_embedder_config(embedder_name)
|
|
1036
|
+
|
|
1037
|
+
@override
async def add_memory_client(self, name: str, config: MemoryBaseConfig) -> MemoryEditor:
    """Register memory client ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_memory_client(name, config)
|
|
1040
|
+
|
|
1041
|
+
@override
def get_memory_client(self, memory_name: str) -> MemoryEditor:
    """
    Fetch memory client ``memory_name`` from the parent builder and record it as a dependency.

    The dependency is recorded only after the parent lookup returns.
    """
    editor = self._workflow_builder.get_memory_client(memory_name)
    self._dependencies.add_memory_client(memory_name)
    return editor
|
|
1051
|
+
|
|
1052
|
+
@override
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
    """Return the parent builder's config for memory client ``memory_name`` (not recorded as a dependency)."""
    parent = self._workflow_builder
    return parent.get_memory_client_config(memory_name=memory_name)
|
|
1055
|
+
|
|
1056
|
+
@override
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
    """Register object store ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_object_store(name, config)
|
|
1059
|
+
|
|
1060
|
+
@override
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
    """
    Fetch object store client ``object_store_name`` from the parent builder and record it as a dependency.

    The dependency is recorded only after the awaited parent lookup returns.
    """
    store = await self._workflow_builder.get_object_store_client(object_store_name)
    self._dependencies.add_object_store(object_store_name)
    return store
|
|
1070
|
+
|
|
1071
|
+
@override
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
    """Return the parent builder's config for object store ``object_store_name`` (not recorded as a dependency)."""
    parent = self._workflow_builder
    return parent.get_object_store_config(object_store_name)
|
|
1074
|
+
|
|
1075
|
+
@override
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig):
    """Register TTC strategy ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_ttc_strategy(name, config)
|
|
1078
|
+
|
|
1079
|
+
@override
async def get_ttc_strategy(self,
                           strategy_name: str | TTCStrategyRef,
                           pipeline_type: PipelineTypeEnum,
                           stage_type: StageTypeEnum) -> StrategyBase:
    """
    Fetch a TTC strategy from the parent builder (not recorded as a dependency).

    All three arguments are forwarded to the parent by keyword.
    """
    parent = self._workflow_builder
    return await parent.get_ttc_strategy(
        strategy_name=strategy_name,
        pipeline_type=pipeline_type,
        stage_type=stage_type,
    )
|
|
1087
|
+
|
|
1088
|
+
@override
async def get_ttc_strategy_config(self,
                                  strategy_name: str | TTCStrategyRef,
                                  pipeline_type: PipelineTypeEnum,
                                  stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
    """
    Fetch a TTC strategy's config from the parent builder (not recorded as a dependency).

    All three arguments are forwarded to the parent by keyword.
    """
    parent = self._workflow_builder
    return await parent.get_ttc_strategy_config(
        strategy_name=strategy_name,
        pipeline_type=pipeline_type,
        stage_type=stage_type,
    )
|
|
1096
|
+
|
|
1097
|
+
@override
async def add_retriever(self, name: str, config: RetrieverBaseConfig):
    """Register retriever ``name`` with the parent builder and return the parent's result."""
    parent = self._workflow_builder
    return await parent.add_retriever(name, config)
|
|
1100
|
+
|
|
1101
|
+
@override
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None):
    """
    Fetch retriever ``retriever_name`` from the parent builder.

    A truthy ``wrapper_type`` is forwarded by keyword. A falsy one (``None`` or
    an empty string) is omitted entirely, so the parent uses whatever default
    it defines for that parameter.
    """
    # NOTE(review): unlike get_llm/get_embedder, the retriever is not recorded in
    # self._dependencies — confirm whether that is intentional.
    parent = self._workflow_builder
    if wrapper_type:
        return await parent.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
    return await parent.get_retriever(retriever_name=retriever_name)
|
|
1106
|
+
|
|
1107
|
+
@override
async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
    """Fetch the parent builder's config for retriever ``retriever_name`` (not recorded as a dependency)."""
    parent = self._workflow_builder
    return await parent.get_retriever_config(retriever_name=retriever_name)
|
|
1110
|
+
|
|
1111
|
+
@override
def get_user_manager(self) -> UserManagerHolder:
    """Return the parent builder's user manager holder."""
    parent = self._workflow_builder
    return parent.get_user_manager()
|
|
1114
|
+
|
|
1115
|
+
@override
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
    """
    Return the parent builder's recorded dependencies for function ``fn_name``.

    This does not return this child builder's own tracker; that is exposed by
    the ``dependencies`` property.
    """
    parent = self._workflow_builder
    return parent.get_function_dependencies(fn_name)
|