nvidia-nat 1.2.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +239 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +355 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +41 -0
- aiq/agent/react_agent/register.py +149 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +411 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +119 -0
- aiq/agent/tool_calling_agent/register.py +106 -0
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +96 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +285 -0
- aiq/builder/component_utils.py +316 -0
- aiq/builder/context.py +264 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +161 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +344 -0
- aiq/builder/function_base.py +380 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +174 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +74 -0
- aiq/builder/workflow.py +148 -0
- aiq/builder/workflow_builder.py +1117 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +213 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +297 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +246 -0
- aiq/cli/commands/uninstall.py +81 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +135 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +488 -0
- aiq/cli/type_registry.py +1000 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +694 -0
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +171 -0
- aiq/data_models/component.py +54 -0
- aiq/data_models/component_ref.py +168 -0
- aiq/data_models/config.py +406 -0
- aiq/data_models/dataset_handler.py +123 -0
- aiq/data_models/discovery_metadata.py +335 -0
- aiq/data_models/embedder.py +27 -0
- aiq/data_models/evaluate.py +127 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +72 -0
- aiq/data_models/interactive.py +246 -0
- aiq/data_models/intermediate_step.py +302 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +27 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +27 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +54 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/data_models/ttc_strategy.py +30 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +59 -0
- aiq/embedder/openai_embedder.py +43 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +60 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +254 -0
- aiq/eval/evaluate.py +506 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +45 -0
- aiq/eval/intermediate_step_adapter.py +99 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +178 -0
- aiq/eval/rag_evaluator/register.py +143 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +133 -0
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +75 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
- aiq/eval/tunable_rag_evaluator/register.py +52 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +140 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/test_time_compute/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/test_time_compute/functions/__init__.py +0 -0
- aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/test_time_compute/models/__init__.py +0 -0
- aiq/experimental/test_time_compute/models/editor_config.py +132 -0
- aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
- aiq/experimental/test_time_compute/models/search_config.py +120 -0
- aiq/experimental/test_time_compute/models/selection_config.py +154 -0
- aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
- aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
- aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
- aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
- aiq/experimental/test_time_compute/register.py +36 -0
- aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
- aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/test_time_compute/search/__init__.py +0 -0
- aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/test_time_compute/selection/__init__.py +0 -0
- aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +96 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +183 -0
- aiq/front_ends/fastapi/main.py +72 -0
- aiq/front_ends/fastapi/message_handler.py +298 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +321 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +46 -0
- aiq/llm/openai_llm.py +46 -0
- aiq/llm/register.py +23 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +76 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +322 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +265 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +310 -0
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +71 -0
- aiq/observability/register.py +96 -0
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +627 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +188 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +28 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +473 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +567 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +215 -0
- aiq/runtime/runner.py +190 -0
- aiq/runtime/session.py +158 -0
- aiq/runtime/user_metadata.py +130 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +151 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +267 -0
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- aiq/tool/code_execution/register.py +74 -0
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +255 -0
- aiq/tool/mcp/mcp_tool.py +96 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +38 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +66 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +290 -0
- aiq/utils/type_utils.py +484 -0
- aiq/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
- nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
aiq/cli/type_registry.py
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
from collections.abc import AsyncIterator
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from contextlib import AbstractAsyncContextManager
|
|
21
|
+
from contextlib import contextmanager
|
|
22
|
+
from copy import deepcopy
|
|
23
|
+
from functools import cached_property
|
|
24
|
+
from logging import Handler
|
|
25
|
+
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
from pydantic import ConfigDict
|
|
28
|
+
from pydantic import Field
|
|
29
|
+
from pydantic import Tag
|
|
30
|
+
from pydantic import computed_field
|
|
31
|
+
from pydantic import field_validator
|
|
32
|
+
|
|
33
|
+
from aiq.authentication.interfaces import AuthProviderBase
|
|
34
|
+
from aiq.builder.builder import Builder
|
|
35
|
+
from aiq.builder.builder import EvalBuilder
|
|
36
|
+
from aiq.builder.embedder import EmbedderProviderInfo
|
|
37
|
+
from aiq.builder.evaluator import EvaluatorInfo
|
|
38
|
+
from aiq.builder.front_end import FrontEndBase
|
|
39
|
+
from aiq.builder.function import Function
|
|
40
|
+
from aiq.builder.function_base import FunctionBase
|
|
41
|
+
from aiq.builder.function_info import FunctionInfo
|
|
42
|
+
from aiq.builder.llm import LLMProviderInfo
|
|
43
|
+
from aiq.builder.retriever import RetrieverProviderInfo
|
|
44
|
+
from aiq.data_models.authentication import AuthProviderBaseConfig
|
|
45
|
+
from aiq.data_models.authentication import AuthProviderBaseConfigT
|
|
46
|
+
from aiq.data_models.common import TypedBaseModelT
|
|
47
|
+
from aiq.data_models.component import AIQComponentEnum
|
|
48
|
+
from aiq.data_models.config import AIQConfig
|
|
49
|
+
from aiq.data_models.discovery_metadata import DiscoveryMetadata
|
|
50
|
+
from aiq.data_models.embedder import EmbedderBaseConfig
|
|
51
|
+
from aiq.data_models.embedder import EmbedderBaseConfigT
|
|
52
|
+
from aiq.data_models.evaluator import EvaluatorBaseConfig
|
|
53
|
+
from aiq.data_models.evaluator import EvaluatorBaseConfigT
|
|
54
|
+
from aiq.data_models.front_end import FrontEndBaseConfig
|
|
55
|
+
from aiq.data_models.front_end import FrontEndConfigT
|
|
56
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
57
|
+
from aiq.data_models.function import FunctionConfigT
|
|
58
|
+
from aiq.data_models.llm import LLMBaseConfig
|
|
59
|
+
from aiq.data_models.llm import LLMBaseConfigT
|
|
60
|
+
from aiq.data_models.logging import LoggingBaseConfig
|
|
61
|
+
from aiq.data_models.logging import LoggingMethodConfigT
|
|
62
|
+
from aiq.data_models.memory import MemoryBaseConfig
|
|
63
|
+
from aiq.data_models.memory import MemoryBaseConfigT
|
|
64
|
+
from aiq.data_models.object_store import ObjectStoreBaseConfig
|
|
65
|
+
from aiq.data_models.object_store import ObjectStoreBaseConfigT
|
|
66
|
+
from aiq.data_models.registry_handler import RegistryHandlerBaseConfig
|
|
67
|
+
from aiq.data_models.registry_handler import RegistryHandlerBaseConfigT
|
|
68
|
+
from aiq.data_models.retriever import RetrieverBaseConfig
|
|
69
|
+
from aiq.data_models.retriever import RetrieverBaseConfigT
|
|
70
|
+
from aiq.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
71
|
+
from aiq.data_models.telemetry_exporter import TelemetryExporterConfigT
|
|
72
|
+
from aiq.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
73
|
+
from aiq.data_models.ttc_strategy import TTCStrategyBaseConfigT
|
|
74
|
+
from aiq.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
75
|
+
from aiq.memory.interfaces import MemoryEditor
|
|
76
|
+
from aiq.object_store.interfaces import ObjectStore
|
|
77
|
+
from aiq.observability.exporter.base_exporter import BaseExporter
|
|
78
|
+
from aiq.registry_handlers.registry_handler_base import AbstractRegistryHandler
|
|
79
|
+
|
|
80
|
+
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------------------------------------------
# "Build" callable signatures.
# Each takes a component config (plus a Builder, EvalBuilder or AIQConfig where listed) and is typed as returning an
# AsyncIterator over the constructed object. Registry handlers are the exception: they receive only their config.
# -----------------------------------------------------------------------------------------------------------------
AuthProviderBuildCallableT = Callable[[AuthProviderBaseConfigT, Builder], AsyncIterator[AuthProviderBase]]
EmbedderClientBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncIterator[typing.Any]]
EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncIterator[EmbedderProviderInfo]]
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
FrontEndBuildCallableT = Callable[[FrontEndConfigT, AIQConfig], AsyncIterator[FrontEndBase]]
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
LoggingMethodBuildCallableT = Callable[[LoggingMethodConfigT, Builder], AsyncIterator[Handler]]
MemoryBuildCallableT = Callable[[MemoryBaseConfigT, Builder], AsyncIterator[MemoryEditor]]
ObjectStoreBuildCallableT = Callable[[ObjectStoreBaseConfigT, Builder], AsyncIterator[ObjectStore]]
RegistryHandlerBuildCallableT = Callable[[RegistryHandlerBaseConfigT], AsyncIterator[AbstractRegistryHandler]]
RetrieverClientBuildCallableT = Callable[[RetrieverBaseConfigT, Builder], AsyncIterator[typing.Any]]
RetrieverProviderBuildCallableT = Callable[[RetrieverBaseConfigT, Builder], AsyncIterator[RetrieverProviderInfo]]
TelemetryExporterBuildCallableT = Callable[[TelemetryExporterConfigT, Builder], AsyncIterator[BaseExporter]]
# Tool wrappers take (name, Function, Builder) and return a framework-specific object, hence typing.Any.
ToolWrapperBuildCallableT = Callable[[str, Function, Builder], typing.Any]

# -----------------------------------------------------------------------------------------------------------------
# "Registered" callable signatures.
# Same inputs as the Build signatures above, but typed as returning an async context manager over the built object.
# The yielded types are kept in step with the corresponding Build signatures.
# -----------------------------------------------------------------------------------------------------------------
AuthProviderRegisteredCallableT = Callable[[AuthProviderBaseConfigT, Builder],
                                           AbstractAsyncContextManager[AuthProviderBase]]
EmbedderClientRegisteredCallableT = Callable[[EmbedderBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
EmbedderProviderRegisteredCallableT = Callable[[EmbedderBaseConfigT, Builder],
                                               AbstractAsyncContextManager[EmbedderProviderInfo]]
EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AbstractAsyncContextManager[EvaluatorInfo]]
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, AIQConfig], AbstractAsyncContextManager[FrontEndBase]]
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
                                       AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
# Yields a logging Handler, matching LoggingMethodBuildCallableT (previously typed as typing.Any).
LoggingMethodRegisteredCallableT = Callable[[LoggingMethodConfigT, Builder], AbstractAsyncContextManager[Handler]]
MemoryRegisteredCallableT = Callable[[MemoryBaseConfigT, Builder], AbstractAsyncContextManager[MemoryEditor]]
ObjectStoreRegisteredCallableT = Callable[[ObjectStoreBaseConfigT, Builder], AbstractAsyncContextManager[ObjectStore]]
RegistryHandlerRegisteredCallableT = Callable[[RegistryHandlerBaseConfigT],
                                              AbstractAsyncContextManager[AbstractRegistryHandler]]
RetrieverClientRegisteredCallableT = Callable[[RetrieverBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
RetrieverProviderRegisteredCallableT = Callable[[RetrieverBaseConfigT, Builder],
                                                AbstractAsyncContextManager[RetrieverProviderInfo]]
# Yields a BaseExporter, matching TelemetryExporterBuildCallableT (previously typed as typing.Any).
TeleExporterRegisteredCallableT = Callable[[TelemetryExporterConfigT, Builder],
                                           AbstractAsyncContextManager[BaseExporter]]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class RegisteredInfo(BaseModel, typing.Generic[TypedBaseModelT]):
    """
    Common base class for entries stored in the type registry.

    Attributes:
        full_type: Registry identifier in the form ``module_name/local_name``. It is validated on construction to
            contain exactly one ``/`` with a non-empty name on each side.
        config_type: The configuration class associated with this entry.
        discovery_metadata: Discovery metadata for the entry; defaults to an empty ``DiscoveryMetadata()``.
    """

    # Registry entries cannot be mutated after construction.
    model_config = ConfigDict(frozen=True)

    full_type: str
    config_type: type[TypedBaseModelT]
    discovery_metadata: DiscoveryMetadata = DiscoveryMetadata()

    @computed_field
    @cached_property
    def module_name(self) -> str:
        """Return the part of ``full_type`` before the ``/`` separator."""
        return self.full_type.split("/")[0]

    @computed_field
    @cached_property
    def local_name(self) -> str:
        """Return the part of ``full_type`` after the ``/`` separator."""
        return self.full_type.split("/")[-1]

    @field_validator("full_type", mode="after")
    @classmethod
    def validate_full_type(cls, full_type: str) -> str:
        """
        Validate that ``full_type`` has the form ``module_name/local_name``.

        Args:
            full_type: The candidate registry identifier.

        Returns:
            The unchanged ``full_type`` when it is valid.

        Raises:
            ValueError: If ``full_type`` does not contain exactly one ``/`` or either side of it is empty.
        """
        parts = full_type.split("/")

        # Besides requiring a single separator, reject "/name" and "module/" so that module_name and local_name
        # can never come back as empty strings.
        if len(parts) != 2 or not all(parts):
            raise ValueError(f"Invalid full type: {full_type}. Expected format: `module_name/local_name`")

        return full_type
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class RegisteredTelemetryExporter(RegisteredInfo[TelemetryExporterBaseConfig]):
    """
    Registry entry for a telemetry exporter.

    ``build_fn`` has the ``TeleExporterRegisteredCallableT`` signature: it receives the exporter config and a Builder
    and returns an async context manager. It is excluded from the model's repr.
    """

    build_fn: TeleExporterRegisteredCallableT = Field(repr=False)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class RegisteredLoggingMethod(RegisteredInfo[LoggingBaseConfig]):
    """
    Registry entry for a logging method.

    ``build_fn`` has the ``LoggingMethodRegisteredCallableT`` signature: it receives the logging config and a Builder
    and returns an async context manager. It is excluded from the model's repr.
    """

    build_fn: LoggingMethodRegisteredCallableT = Field(repr=False)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class RegisteredFrontEndInfo(RegisteredInfo[FrontEndBaseConfig]):
    """
    Registration record for a front end: the entry point that starts and orchestrates a workflow.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: FrontEndRegisteredCallableT = Field(repr=False)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
    """
    Registration record for a function, the workflow building block with predefined inputs, outputs and a
    description.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: FunctionRegisteredCallableT = Field(repr=False)
    # Names of the framework wrappers associated with this function; empty by default.
    framework_wrappers: list[str] = Field(default_factory=list)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
    """
    Registration record for an LLM provider, i.e. the party that operates the model (e.g. NIMs, OpenAI,
    Anthropic).
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: LLMProviderRegisteredCallableT = Field(repr=False)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class RegisteredAuthProviderInfo(RegisteredInfo[AuthProviderBaseConfig]):
    """
    Registration record for an authentication provider, a component that carries out authentication.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: AuthProviderRegisteredCallableT = Field(repr=False)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class RegisteredLLMClientInfo(RegisteredInfo[LLMBaseConfig]):
    """
    Registration record for an LLM client: a framework-specific client that talks to an LLM provider.
    """

    # Name of the LLM framework this client targets.
    llm_framework: str
    # Excluded from repr to keep printed registrations readable.
    build_fn: LLMClientRegisteredCallableT = Field(repr=False)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class RegisteredEmbedderProviderInfo(RegisteredInfo[EmbedderBaseConfig]):
    """
    Registration record for an embedder provider, i.e. the party that operates the embedding model (e.g.
    NIMs, OpenAI).
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: EmbedderProviderRegisteredCallableT = Field(repr=False)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class RegisteredEmbedderClientInfo(RegisteredInfo[EmbedderBaseConfig]):
    """
    Registration record for an embedder client: a framework-specific client that talks to an embedder
    provider.
    """

    # Name of the LLM framework this client targets.
    llm_framework: str
    # Excluded from repr to keep printed registrations readable.
    build_fn: EmbedderClientRegisteredCallableT = Field(repr=False)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class RegisteredEvaluatorInfo(RegisteredInfo[EvaluatorBaseConfig]):
    """
    Registration record for an evaluator (for example a RAG or trajectory evaluator).
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: EvaluatorRegisteredCallableT = Field(repr=False)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class RegisteredMemoryInfo(RegisteredInfo[MemoryBaseConfig]):
    """
    Registration record for a memory component implementing the memory interface.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: MemoryRegisteredCallableT = Field(repr=False)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class RegisteredObjectStoreInfo(RegisteredInfo[ObjectStoreBaseConfig]):
    """
    Registration record for an object store implementing the object store interface.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: ObjectStoreRegisteredCallableT = Field(repr=False)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class RegisteredTTCStrategyInfo(RegisteredInfo[TTCStrategyBaseConfig]):
    """
    Registration record for a TTC strategy and the function that builds it.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: TTCStrategyRegisterCallableT = Field(repr=False)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class RegisteredToolWrapper(BaseModel):
    """
    Registration record for a tool wrapper, which adapts workflow functions to one LLM framework.

    Unlike the other registrations it carries no config type; it is identified by ``llm_framework`` alone.
    """

    # Name of the LLM framework the wrapper targets (also its registry key).
    llm_framework: str
    # Excluded from repr to keep printed registrations readable.
    build_fn: ToolWrapperBuildCallableT = Field(repr=False)
    discovery_metadata: DiscoveryMetadata
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class RegisteredRetrieverProviderInfo(RegisteredInfo[RetrieverBaseConfig]):
    """
    Registration record for a retriever provider implementing the retriever interface.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: RetrieverProviderRegisteredCallableT = Field(repr=False)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class RegisteredRetrieverClientInfo(RegisteredInfo[RetrieverBaseConfig]):
    """
    Registration record for a retriever client: an LLM-framework-specific interface onto a retriever.
    """

    # Target LLM framework; ``None`` is an accepted value here.
    llm_framework: str | None
    # Excluded from repr to keep printed registrations readable.
    build_fn: RetrieverClientRegisteredCallableT = Field(repr=False)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class RegisteredRegistryHandlerInfo(RegisteredInfo[RegistryHandlerBaseConfig]):
    """
    Represents a registered registry handler. Its ``build_fn`` receives a ``RegistryHandlerBaseConfig`` and
    returns an async context manager that yields an ``AbstractRegistryHandler``.
    """

    # Excluded from repr to keep printed registrations readable.
    build_fn: RegistryHandlerRegisteredCallableT = Field(repr=False)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class RegisteredPackage(BaseModel):
    """Registration record for an installed package and its discovery metadata."""

    package_name: str
    discovery_metadata: DiscoveryMetadata
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
302
|
+
|
|
303
|
+
def __init__(self) -> None:
    """Create an empty registry with one lookup table per component kind.

    LLM, embedder and retriever clients are indexed in both directions
    (provider -> framework and framework -> provider).
    """
    # --- telemetry, logging and front ends -----------------------------------
    self._registered_telemetry_exporters: dict[type[TelemetryExporterBaseConfig],
                                               RegisteredTelemetryExporter] = {}
    self._registered_logging_methods: dict[type[LoggingBaseConfig], RegisteredLoggingMethod] = {}
    self._registered_front_end_infos: dict[type[FrontEndBaseConfig], RegisteredFrontEndInfo] = {}

    # --- functions -----------------------------------------------------------
    self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}

    # --- LLM providers and the two-way LLM client index ----------------------
    self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
    self._llm_client_provider_to_framework: dict[type[LLMBaseConfig],
                                                 dict[str, RegisteredLLMClientInfo]] = {}
    self._llm_client_framework_to_provider: dict[str,
                                                 dict[type[LLMBaseConfig], RegisteredLLMClientInfo]] = {}

    # --- authentication providers --------------------------------------------
    self._registered_auth_provider_infos: dict[type[AuthProviderBaseConfig],
                                               RegisteredAuthProviderInfo] = {}

    # --- embedder providers and the two-way embedder client index ------------
    self._registered_embedder_provider_infos: dict[type[EmbedderBaseConfig],
                                                   RegisteredEmbedderProviderInfo] = {}
    self._embedder_client_provider_to_framework: dict[type[EmbedderBaseConfig],
                                                      dict[str, RegisteredEmbedderClientInfo]] = {}
    self._embedder_client_framework_to_provider: dict[str,
                                                      dict[type[EmbedderBaseConfig],
                                                           RegisteredEmbedderClientInfo]] = {}

    # --- evaluators, memory and object stores --------------------------------
    self._registered_evaluator_infos: dict[type[EvaluatorBaseConfig], RegisteredEvaluatorInfo] = {}
    self._registered_memory_infos: dict[type[MemoryBaseConfig], RegisteredMemoryInfo] = {}
    self._registered_object_store_infos: dict[type[ObjectStoreBaseConfig], RegisteredObjectStoreInfo] = {}

    # --- retriever providers and the two-way retriever client index ----------
    # Retriever clients may use ``None`` as their framework key.
    self._registered_retriever_provider_infos: dict[type[RetrieverBaseConfig],
                                                    RegisteredRetrieverProviderInfo] = {}
    self._retriever_client_provider_to_framework: dict[type[RetrieverBaseConfig],
                                                       dict[str | None, RegisteredRetrieverClientInfo]] = {}
    self._retriever_client_framework_to_provider: dict[str | None,
                                                       dict[type[RetrieverBaseConfig],
                                                            RegisteredRetrieverClientInfo]] = {}

    # --- registry handlers, tool wrappers, TTC strategies, packages ----------
    self._registered_registry_handler_infos: dict[type[RegistryHandlerBaseConfig],
                                                  RegisteredRegistryHandlerInfo] = {}
    self._registered_tool_wrappers: dict[str, RegisteredToolWrapper] = {}
    self._registered_ttc_strategies: dict[type[TTCStrategyBaseConfig], RegisteredTTCStrategyInfo] = {}
    self._registered_packages: dict[str, RegisteredPackage] = {}

    # --- change notification -------------------------------------------------
    self._registration_changed_hooks: list[Callable[[], None]] = []
    self._registration_changed_hooks_active: bool = True

    # Registry handlers keyed by ``config_type.static_type()``.
    self._registered_channel_map = {}
|
|
366
|
+
|
|
367
|
+
def _registration_changed(self):
    """Call every registered change hook, unless hooks are currently paused."""
    if self._registration_changed_hooks_active:
        logger.debug("Registration changed. Notifying hooks.")
        for callback in self._registration_changed_hooks:
            callback()
|
|
376
|
+
|
|
377
|
+
def add_registration_changed_hook(self, cb: Callable[[], typing.Any]) -> None:
    """Subscribe ``cb`` (called with no arguments) to registry change notifications."""
    hooks = self._registration_changed_hooks
    hooks.append(cb)
|
|
380
|
+
|
|
381
|
+
@contextmanager
def pause_registration_changed_hooks(self):
    """Suppress change-hook notifications for the duration of a ``with`` block.

    On exit the pause state that was in effect on entry is restored and one
    notification is attempted, so hooks run once for the whole batch of
    registrations made inside the block. Nested pauses are supported: inner
    exits leave the registry paused, and only the outermost exit re-enables
    and fires the hooks.
    """
    # Remember the state on entry so a nested pause does not re-enable hooks early.
    was_active = self._registration_changed_hooks_active
    self._registration_changed_hooks_active = False

    try:
        yield
    finally:
        self._registration_changed_hooks_active = was_active

        # Ensure that the registration changed hooks are called. While an outer pause
        # is still in effect this is a no-op, because `_registration_changed` returns
        # early when hooks are inactive.
        self._registration_changed()
|
|
393
|
+
|
|
394
|
+
def register_telemetry_exporter(self, registration: RegisteredTelemetryExporter):
    """Add a telemetry exporter keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If an exporter with the same config type is already registered.
    """
    config_type = registration.config_type
    exporters = self._registered_telemetry_exporters
    if config_type in exporters:
        raise ValueError(f"A telemetry exporter with the same config type `{config_type}` has already "
                         "been registered.")
    exporters[config_type] = registration
    self._registration_changed()
|
|
403
|
+
|
|
404
|
+
def get_telemetry_exporter(self, config_type: type[TelemetryExporterBaseConfig]) -> RegisteredTelemetryExporter:
    """Return the telemetry exporter registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    exporters = self._registered_telemetry_exporters
    try:
        return exporters[config_type]
    except KeyError as err:
        known_configs = set(exporters.keys())
        raise KeyError(f"Could not find a registered telemetry exporter for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
411
|
+
|
|
412
|
+
def get_registered_telemetry_exporters(self) -> list[RegisteredInfo[TelemetryExporterBaseConfig]]:
    """Return a new list of every registered telemetry exporter, in insertion order."""
    return [*self._registered_telemetry_exporters.values()]
|
|
415
|
+
|
|
416
|
+
def register_logging_method(self, registration: RegisteredLoggingMethod):
    """Add a logging method keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a logging method with the same config type is already registered.
    """
    config_type = registration.config_type
    methods = self._registered_logging_methods
    if config_type in methods:
        raise ValueError(f"A logging method with the same config type `{config_type}` has already "
                         "been registered.")
    methods[config_type] = registration
    self._registration_changed()
|
|
425
|
+
|
|
426
|
+
def get_logging_method(self, config_type: type[LoggingBaseConfig]) -> RegisteredLoggingMethod:
    """Return the logging method registered for ``config_type``.

    Raises:
        KeyError: If none is registered. The message uses the same wording as the
            other ``get_*`` lookups and lists the registered configs.
    """
    try:
        return self._registered_logging_methods[config_type]
    except KeyError as err:
        raise KeyError(f"Could not find a registered logging method for config `{config_type}`. "
                       f"Registered configs: {set(self._registered_logging_methods.keys())}") from err
|
|
432
|
+
|
|
433
|
+
def get_registered_logging_method(self) -> list[RegisteredInfo[LoggingBaseConfig]]:
    """Return a new list of every registered logging method, in insertion order."""
    return [*self._registered_logging_methods.values()]
|
|
436
|
+
|
|
437
|
+
def register_front_end(self, registration: RegisteredFrontEndInfo):
    """Add a front end keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a front end with the same config type is already registered.
    """
    config_type = registration.config_type
    front_ends = self._registered_front_end_infos
    if config_type in front_ends:
        raise ValueError(f"A front end with the same config type `{config_type}` has already been "
                         "registered.")
    front_ends[config_type] = registration
    self._registration_changed()
|
|
446
|
+
|
|
447
|
+
def get_front_end(self, config_type: type[FrontEndBaseConfig]) -> RegisteredFrontEndInfo:
    """Return the front end registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    front_ends = self._registered_front_end_infos
    try:
        return front_ends[config_type]
    except KeyError as err:
        known_configs = set(front_ends.keys())
        raise KeyError(f"Could not find a registered front end for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
454
|
+
|
|
455
|
+
def get_registered_front_ends(self) -> list[RegisteredInfo[FrontEndBaseConfig]]:
    """Return a new list of every registered front end, in insertion order."""
    return [*self._registered_front_end_infos.values()]
|
|
458
|
+
|
|
459
|
+
def register_function(self, registration: RegisteredFunctionInfo):
    """Add a function keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a function with the same config type is already registered.
    """
    config_type = registration.config_type
    functions = self._registered_functions
    if config_type in functions:
        raise ValueError(f"A function with the same config type `{config_type}` has already been "
                         "registered.")
    functions[config_type] = registration
    self._registration_changed()
|
|
468
|
+
|
|
469
|
+
def get_function(self, config_type: type[FunctionBaseConfig]) -> RegisteredFunctionInfo:
    """Return the function registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    functions = self._registered_functions
    try:
        return functions[config_type]
    except KeyError as err:
        known_configs = set(functions.keys())
        raise KeyError(f"Could not find a registered function for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
476
|
+
|
|
477
|
+
def get_registered_functions(self) -> list[RegisteredInfo[FunctionBaseConfig]]:
    """Return a new list of every registered function, in insertion order."""
    return [*self._registered_functions.values()]
|
|
480
|
+
|
|
481
|
+
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
    """Add an LLM provider keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a provider with the same config type is already registered.
    """
    config_type = info.config_type
    providers = self._registered_llm_provider_infos
    if config_type in providers:
        raise ValueError(f"An LLM provider with the same config type `{config_type}` has already been registered.")
    providers[config_type] = info
    self._registration_changed()
|
|
490
|
+
|
|
491
|
+
def get_llm_provider(self, config_type: type[LLMBaseConfig]) -> RegisteredLLMProviderInfo:
    """Return the LLM provider registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    providers = self._registered_llm_provider_infos
    try:
        return providers[config_type]
    except KeyError as err:
        known_configs = set(providers.keys())
        raise KeyError(f"Could not find a registered LLM provider for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
498
|
+
|
|
499
|
+
def get_registered_llm_providers(self) -> list[RegisteredInfo[LLMBaseConfig]]:
    """Return a new list of every registered LLM provider, in insertion order."""
    return [*self._registered_llm_provider_infos.values()]
|
|
501
|
+
|
|
502
|
+
def register_auth_provider(self, info: RegisteredAuthProviderInfo):
    """Add an authentication provider keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a provider with the same config type is already registered.
    """
    config_type = info.config_type
    providers = self._registered_auth_provider_infos
    if config_type in providers:
        raise ValueError(f"An Authentication Provider with the same config type `{config_type}` has already been "
                         "registered.")
    providers[config_type] = info
    self._registration_changed()
|
|
512
|
+
|
|
513
|
+
def get_auth_provider(self, config_type: type[AuthProviderBaseConfig]) -> RegisteredAuthProviderInfo:
    """Return the authentication provider registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    providers = self._registered_auth_provider_infos
    try:
        return providers[config_type]
    except KeyError as err:
        known_configs = set(providers.keys())
        raise KeyError(f"Could not find a registered Authentication Provider for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
519
|
+
|
|
520
|
+
def get_registered_auth_providers(self) -> list[RegisteredInfo[AuthProviderBaseConfig]]:
    """Return a new list of every registered authentication provider, in insertion order."""
    return [*self._registered_auth_provider_infos.values()]
|
|
522
|
+
|
|
523
|
+
def register_llm_client(self, info: RegisteredLLMClientInfo):
    """Index an LLM client under both (provider, framework) and (framework, provider), then notify hooks.

    Raises:
        ValueError: If a client for the same config type and framework is already registered.
    """
    frameworks_for_provider = self._llm_client_provider_to_framework.get(info.config_type, {})
    if info.llm_framework in frameworks_for_provider:
        raise ValueError(f"An LLM client with the same config type `{info.config_type}` "
                         f"and LLM framework `{info.llm_framework}` has already been registered.")

    self._llm_client_provider_to_framework.setdefault(info.config_type, {})[info.llm_framework] = info
    self._llm_client_framework_to_provider.setdefault(info.llm_framework, {})[info.config_type] = info
    self._registration_changed()
|
|
534
|
+
|
|
535
|
+
def get_llm_client(self, config_type: type[LLMBaseConfig], wrapper_type: str) -> RegisteredLLMClientInfo:
    """Return the LLM client that adapts provider ``config_type`` to framework ``wrapper_type``.

    Raises:
        KeyError: If that provider/framework pair has no registered client; the message
            lists the providers that do have clients.
    """
    by_provider = self._llm_client_provider_to_framework
    try:
        return by_provider[config_type][wrapper_type]
    except KeyError as err:
        providers = set(by_provider.keys())
        raise KeyError(f"An invalid LLM config and wrapper combination was supplied. Config: `{config_type}`, "
                       f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} LLM client but "
                       "there is no registered conversion from that LLM provider to LLM framework: "
                       f"{wrapper_type}. "
                       "Please provide an LLM configuration from one of the following providers: "
                       f"{providers}") from err
|
|
548
|
+
|
|
549
|
+
def register_embedder_provider(self, info: RegisteredEmbedderProviderInfo):
    """Add an embedder provider keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a provider with the same config type is already registered.
    """
    config_type = info.config_type
    providers = self._registered_embedder_provider_infos
    if config_type in providers:
        raise ValueError(f"An Embedder provider with the same config type `{config_type}` has already been "
                         "registered.")
    providers[config_type] = info
    self._registration_changed()
|
|
558
|
+
|
|
559
|
+
def get_embedder_provider(self, config_type: type[EmbedderBaseConfig]) -> RegisteredEmbedderProviderInfo:
    """Return the embedder provider registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    providers = self._registered_embedder_provider_infos
    try:
        return providers[config_type]
    except KeyError as err:
        known_configs = set(providers.keys())
        raise KeyError(f"Could not find a registered Embedder provider for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
566
|
+
|
|
567
|
+
def get_registered_embedder_providers(self) -> list[RegisteredInfo[EmbedderBaseConfig]]:
    """Return a new list of every registered embedder provider, in insertion order."""
    return [*self._registered_embedder_provider_infos.values()]
|
|
570
|
+
|
|
571
|
+
def register_embedder_client(self, info: RegisteredEmbedderClientInfo):
    """Index an embedder client under both (provider, framework) and (framework, provider), then notify hooks.

    Raises:
        ValueError: If a client for the same config type and framework is already registered.
            The message names both, matching ``register_llm_client``.
    """
    if (info.config_type in self._embedder_client_provider_to_framework
            and info.llm_framework in self._embedder_client_provider_to_framework[info.config_type]):
        # Name the framework too: the same config type may legitimately have clients for other frameworks.
        raise ValueError(f"An Embedder client with the same config type `{info.config_type}` "
                         f"and LLM framework `{info.llm_framework}` has already been registered.")

    self._embedder_client_provider_to_framework.setdefault(info.config_type, {})[info.llm_framework] = info
    self._embedder_client_framework_to_provider.setdefault(info.llm_framework, {})[info.config_type] = info

    self._registration_changed()
|
|
582
|
+
|
|
583
|
+
def get_embedder_client(self, config_type: type[EmbedderBaseConfig],
                        wrapper_type: str) -> RegisteredEmbedderClientInfo:
    """Return the embedder client that adapts provider ``config_type`` to framework ``wrapper_type``.

    Raises:
        KeyError: If that provider/framework pair has no registered client; the message
            names the requested wrapper and lists the providers that do have clients.
    """
    try:
        client_info = self._embedder_client_provider_to_framework[config_type][wrapper_type]
    except KeyError as err:
        # Every segment that interpolates a value must be an f-string; otherwise the
        # literal text "{wrapper_type}" ends up in the error message.
        raise KeyError(
            f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
            f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
            f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
            "Please provide an Embedder configuration from one of the following providers: "
            f"{set(self._embedder_client_provider_to_framework.keys())}") from err

    return client_info
|
|
597
|
+
|
|
598
|
+
def register_evaluator(self, info: RegisteredEvaluatorInfo):
    """Add an evaluator keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If an evaluator with the same config type is already registered.
    """
    config_type = info.config_type
    evaluators = self._registered_evaluator_infos
    if config_type in evaluators:
        raise ValueError(f"An Evaluator with the same config type `{config_type}` has already been "
                         "registered.")
    evaluators[config_type] = info
    self._registration_changed()
|
|
607
|
+
|
|
608
|
+
def get_evaluator(self, config_type: type[EvaluatorBaseConfig]) -> RegisteredEvaluatorInfo:
    """Return the evaluator registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    evaluators = self._registered_evaluator_infos
    try:
        return evaluators[config_type]
    except KeyError as err:
        known_configs = set(evaluators.keys())
        raise KeyError(f"Could not find a registered Evaluator for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
615
|
+
|
|
616
|
+
def get_registered_evaluators(self) -> list[RegisteredInfo[EvaluatorBaseConfig]]:
    """Return a new list of every registered evaluator, in insertion order."""
    return [*self._registered_evaluator_infos.values()]
|
|
619
|
+
|
|
620
|
+
def register_memory(self, info: RegisteredMemoryInfo):
    """Add a memory component keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a memory component with the same config type is already registered.
    """
    config_type = info.config_type
    memories = self._registered_memory_infos
    if config_type in memories:
        raise ValueError(f"A Memory client with the same config type `{config_type}` has already been registered.")
    memories[config_type] = info
    self._registration_changed()
|
|
629
|
+
|
|
630
|
+
def get_memory(self, config_type: type[MemoryBaseConfig]) -> RegisteredMemoryInfo:
    """Return the memory component registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    memories = self._registered_memory_infos
    try:
        return memories[config_type]
    except KeyError as err:
        known_configs = set(memories.keys())
        raise KeyError(f"Could not find a registered Memory client for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
637
|
+
|
|
638
|
+
def get_registered_memorys(self) -> list[RegisteredInfo[MemoryBaseConfig]]:
    """Return a new list of every registered memory component, in insertion order."""
    return [*self._registered_memory_infos.values()]
|
|
641
|
+
|
|
642
|
+
def register_object_store(self, info: RegisteredObjectStoreInfo):
    """Add an object store keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If an object store with the same config type is already registered.
    """
    config_type = info.config_type
    stores = self._registered_object_store_infos
    if config_type in stores:
        raise ValueError(f"An Object Store with the same config type `{config_type}` has already been "
                         "registered.")
    stores[config_type] = info
    self._registration_changed()
|
|
651
|
+
|
|
652
|
+
def get_object_store(self, config_type: type[ObjectStoreBaseConfig]) -> RegisteredObjectStoreInfo:
    """Return the object store registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    stores = self._registered_object_store_infos
    try:
        return stores[config_type]
    except KeyError as err:
        known_configs = set(stores.keys())
        raise KeyError(f"Could not find a registered Object Store for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
659
|
+
|
|
660
|
+
def get_registered_object_stores(self) -> list[RegisteredInfo[ObjectStoreBaseConfig]]:
    """Return a new list of every registered object store, in insertion order."""
    return [*self._registered_object_store_infos.values()]
|
|
663
|
+
|
|
664
|
+
def register_retriever_provider(self, info: RegisteredRetrieverProviderInfo):
    """Add a retriever provider keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a provider with the same config type is already registered.
    """
    config_type = info.config_type
    providers = self._registered_retriever_provider_infos
    if config_type in providers:
        raise ValueError(
            f"A Retriever provider with the same config type `{config_type}` has already been registered")
    providers[config_type] = info
    self._registration_changed()
|
|
673
|
+
|
|
674
|
+
def get_retriever_provider(self, config_type: type[RetrieverBaseConfig]) -> RegisteredRetrieverProviderInfo:
    """Return the retriever provider registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    providers = self._registered_retriever_provider_infos
    try:
        return providers[config_type]
    except KeyError as err:
        known_configs = set(providers.keys())
        raise KeyError(f"Could not find a registered Retriever provider for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
681
|
+
|
|
682
|
+
def get_registered_retriever_providers(self) -> list[RegisteredInfo[RetrieverBaseConfig]]:
    """Return a new list of every registered retriever provider, in insertion order."""
    return [*self._registered_retriever_provider_infos.values()]
|
|
685
|
+
|
|
686
|
+
def register_retriever_client(self, info: RegisteredRetrieverClientInfo):
    """Index a retriever client under both (provider, framework) and (framework, provider), then notify hooks.

    ``info.llm_framework`` may be ``None``; it is used as a key like any other value.

    Raises:
        ValueError: If a client for the same config type and framework is already registered.
    """
    if (info.config_type in self._retriever_client_provider_to_framework
            and info.llm_framework in self._retriever_client_provider_to_framework[info.config_type]):
        # Both segments interpolate values, so both must be f-strings.
        raise ValueError(f"A Retriever client with the same config type `{info.config_type}` "
                         f"and LLM framework `{info.llm_framework}` has already been registered.")

    self._retriever_client_provider_to_framework.setdefault(info.config_type, {})[info.llm_framework] = info
    self._retriever_client_framework_to_provider.setdefault(info.llm_framework, {})[info.config_type] = info

    self._registration_changed()
|
|
697
|
+
|
|
698
|
+
def get_retriever_client(self, config_type: type[RetrieverBaseConfig],
                         wrapper_type: str | None) -> RegisteredRetrieverClientInfo:
    """Return the retriever client that adapts provider ``config_type`` to framework ``wrapper_type``.

    ``wrapper_type`` may be ``None``, matching clients registered without a framework.

    Raises:
        KeyError: If that provider/framework pair has no registered client; the message
            names the requested wrapper and lists the providers that do have clients.
    """
    try:
        client_info = self._retriever_client_provider_to_framework[config_type][wrapper_type]
    except KeyError as err:
        # Every segment that interpolates a value must be an f-string; otherwise the
        # literal text "{wrapper_type}" ends up in the error message.
        raise KeyError(
            f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
            f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
            f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
            "Please provide a Retriever configuration from one of the following providers: "
            f"{set(self._retriever_client_provider_to_framework.keys())}") from err

    return client_info
|
|
712
|
+
|
|
713
|
+
def register_tool_wrapper(self, registration: RegisteredToolWrapper):
    """Add a tool wrapper keyed by its LLM framework name, then notify change hooks.

    Raises:
        ValueError: If a wrapper for the same framework is already registered.
    """
    framework = registration.llm_framework
    wrappers = self._registered_tool_wrappers
    if framework in wrappers:
        raise ValueError(f"A tool wrapper for the LLM framework `{framework}` has already been "
                         "registered.")
    wrappers[framework] = registration
    self._registration_changed()
|
|
722
|
+
|
|
723
|
+
def get_tool_wrapper(self, llm_framework: str) -> RegisteredToolWrapper:
    """Return the tool wrapper registered for ``llm_framework``.

    Raises:
        KeyError: If none is registered; the message lists the known frameworks.
    """
    wrappers = self._registered_tool_wrappers
    try:
        return wrappers[llm_framework]
    except KeyError as err:
        known_frameworks = set(wrappers.keys())
        raise KeyError(f"Could not find a registered tool wrapper for LLM framework `{llm_framework}`. "
                       f"Registered LLM frameworks: {known_frameworks}") from err
|
|
730
|
+
|
|
731
|
+
def register_ttc_strategy(self, info: RegisteredTTCStrategyInfo):
    """Add a TTC strategy keyed by its config type, then notify change hooks.

    Raises:
        ValueError: If a strategy with the same config type is already registered.
    """
    if (info.config_type in self._registered_ttc_strategies):
        raise ValueError(
            f"A TTC strategy with the same config type `{info.config_type}` has already been registered.")

    self._registered_ttc_strategies[info.config_type] = info

    self._registration_changed()
|
|
739
|
+
|
|
740
|
+
def get_ttc_strategy(self, config_type: type[TTCStrategyBaseConfig]) -> RegisteredTTCStrategyInfo:
    """Return the TTC strategy registered for ``config_type``.

    Raises:
        KeyError: If none is registered. Like the other ``get_*`` lookups, the message
            lists the registered configs.
    """
    try:
        return self._registered_ttc_strategies[config_type]
    except KeyError as err:
        # Only a lookup miss is translated; unrelated errors propagate unchanged.
        raise KeyError(f"Could not find a registered TTC strategy for config `{config_type}`. "
                       f"Registered configs: {set(self._registered_ttc_strategies.keys())}") from err
|
|
746
|
+
|
|
747
|
+
def get_registered_ttc_strategies(self) -> list[RegisteredInfo[TTCStrategyBaseConfig]]:
    """Return a new list of every registered TTC strategy, in insertion order."""
    return [*self._registered_ttc_strategies.values()]
|
|
749
|
+
|
|
750
|
+
def register_registry_handler(self, info: RegisteredRegistryHandlerInfo):
    """Register a registry handler and index it by its config's static type.

    The handler is stored under ``info.config_type`` and also in the channel map
    under ``info.config_type.static_type()``; change hooks are then notified.

    Raises:
        ValueError: If a registry handler with the same config type is already registered.
    """
    # Duplicate detection must consult the registry-handler table itself; checking
    # another component's table lets duplicate handlers silently overwrite each other.
    if (info.config_type in self._registered_registry_handler_infos):
        raise ValueError(
            f"A Registry Handler with the same config type `{info.config_type}` has already been registered.")

    self._registered_registry_handler_infos[info.config_type] = info
    self._registered_channel_map[info.config_type.static_type()] = info

    self._registration_changed()
|
|
760
|
+
|
|
761
|
+
def get_registry_handler(self, config_type: type[RegistryHandlerBaseConfig]) -> RegisteredRegistryHandlerInfo:
    """Return the registry handler registered for ``config_type``.

    Raises:
        KeyError: If none is registered; the message lists the registered configs.
    """
    handlers = self._registered_registry_handler_infos
    try:
        return handlers[config_type]
    except KeyError as err:
        known_configs = set(handlers.keys())
        raise KeyError(f"Could not find a registered Registry Handler for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from err
|
|
768
|
+
|
|
769
|
+
def get_registered_registry_handlers(self) -> list[RegisteredInfo[RegistryHandlerBaseConfig]]:
    """Return a new list of every registered registry handler, in insertion order."""
    return [*self._registered_registry_handler_infos.values()]
|
|
772
|
+
|
|
773
|
+
def register_package(self, package_name: str, package_version: str | None = None):
    """Record a package, with discovery metadata built from its name and optional version.

    A later call with the same ``package_name`` replaces the earlier record. Change hooks
    are notified afterwards.
    """
    metadata = DiscoveryMetadata.from_package_name(package_name=package_name, package_version=package_version)
    record = RegisteredPackage(discovery_metadata=metadata, package_name=package_name)

    self._registered_packages[record.package_name] = record
    self._registration_changed()
|
|
781
|
+
|
|
782
|
+
def get_infos_by_type(self, component_type: AIQComponentEnum) -> dict: # pylint: disable=R0911
|
|
783
|
+
|
|
784
|
+
if component_type == AIQComponentEnum.FRONT_END:
|
|
785
|
+
return self._registered_front_end_infos
|
|
786
|
+
|
|
787
|
+
if component_type == AIQComponentEnum.AUTHENTICATION_PROVIDER:
|
|
788
|
+
return self._registered_auth_provider_infos
|
|
789
|
+
|
|
790
|
+
if component_type == AIQComponentEnum.FUNCTION:
|
|
791
|
+
return self._registered_functions
|
|
792
|
+
|
|
793
|
+
if component_type == AIQComponentEnum.TOOL_WRAPPER:
|
|
794
|
+
return self._registered_tool_wrappers
|
|
795
|
+
|
|
796
|
+
if component_type == AIQComponentEnum.LLM_PROVIDER:
|
|
797
|
+
return self._registered_llm_provider_infos
|
|
798
|
+
|
|
799
|
+
if component_type == AIQComponentEnum.LLM_CLIENT:
|
|
800
|
+
leaf_llm_client_infos = {}
|
|
801
|
+
for framework in self._llm_client_provider_to_framework.values():
|
|
802
|
+
for info in framework.values():
|
|
803
|
+
leaf_llm_client_infos[info.discovery_metadata.component_name] = info
|
|
804
|
+
return leaf_llm_client_infos
|
|
805
|
+
|
|
806
|
+
if component_type == AIQComponentEnum.EMBEDDER_PROVIDER:
|
|
807
|
+
return self._registered_embedder_provider_infos
|
|
808
|
+
|
|
809
|
+
if component_type == AIQComponentEnum.EMBEDDER_CLIENT:
|
|
810
|
+
leaf_embedder_client_infos = {}
|
|
811
|
+
for framework in self._embedder_client_provider_to_framework.values():
|
|
812
|
+
for info in framework.values():
|
|
813
|
+
leaf_embedder_client_infos[info.discovery_metadata.component_name] = info
|
|
814
|
+
return leaf_embedder_client_infos
|
|
815
|
+
|
|
816
|
+
if component_type == AIQComponentEnum.RETRIEVER_PROVIDER:
|
|
817
|
+
return self._registered_retriever_provider_infos
|
|
818
|
+
|
|
819
|
+
if component_type == AIQComponentEnum.RETRIEVER_CLIENT:
|
|
820
|
+
leaf_retriever_client_infos = {}
|
|
821
|
+
for framework in self._retriever_client_provider_to_framework.values():
|
|
822
|
+
for info in framework.values():
|
|
823
|
+
leaf_retriever_client_infos[info.discovery_metadata.component_name] = info
|
|
824
|
+
return leaf_retriever_client_infos
|
|
825
|
+
|
|
826
|
+
if component_type == AIQComponentEnum.EVALUATOR:
|
|
827
|
+
return self._registered_evaluator_infos
|
|
828
|
+
|
|
829
|
+
if component_type == AIQComponentEnum.MEMORY:
|
|
830
|
+
return self._registered_memory_infos
|
|
831
|
+
|
|
832
|
+
if component_type == AIQComponentEnum.OBJECT_STORE:
|
|
833
|
+
return self._registered_object_store_infos
|
|
834
|
+
|
|
835
|
+
if component_type == AIQComponentEnum.REGISTRY_HANDLER:
|
|
836
|
+
return self._registered_registry_handler_infos
|
|
837
|
+
|
|
838
|
+
if component_type == AIQComponentEnum.LOGGING:
|
|
839
|
+
return self._registered_logging_methods
|
|
840
|
+
|
|
841
|
+
if component_type == AIQComponentEnum.TRACING:
|
|
842
|
+
return self._registered_telemetry_exporters
|
|
843
|
+
|
|
844
|
+
if component_type == AIQComponentEnum.PACKAGE:
|
|
845
|
+
return self._registered_packages
|
|
846
|
+
|
|
847
|
+
if component_type == AIQComponentEnum.TTC_STRATEGY:
|
|
848
|
+
return self._registered_ttc_strategies
|
|
849
|
+
|
|
850
|
+
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
851
|
+
|
|
852
|
+
def get_registered_types_by_component_type(  # pylint: disable=R0911
        self, component_type: AIQComponentEnum) -> list[str]:
    """Return a flat list of registered type names for ``component_type``.

    For most categories, this is ``static_type()`` of each key in the backing dictionary
    (presumably config classes; verify against the register_* methods). TOOL_WRAPPER and PACKAGE
    return the raw keys. Client categories return the discovery component name of every info under
    every provider/framework pair, in visiting order and without de-duplication.

    Args:
        component_type: The component category to list.

    Returns:
        A flat ``list`` of type names.

    Raises:
        ValueError: If ``component_type`` is not handled here.
    """

    def leaf_component_names(provider_to_framework: dict) -> list[str]:
        # Emit each component name directly. Wrapping each one in a one-element list would return a
        # list of lists and break the ``list[str]`` contract that every other branch honours.
        return [
            info.discovery_metadata.component_name
            for framework in provider_to_framework.values()
            for info in framework.values()
        ]

    if component_type == AIQComponentEnum.FUNCTION:
        return [i.static_type() for i in self._registered_functions]

    if component_type == AIQComponentEnum.TOOL_WRAPPER:
        return list(self._registered_tool_wrappers)

    if component_type == AIQComponentEnum.LLM_PROVIDER:
        return [i.static_type() for i in self._registered_llm_provider_infos]

    if component_type == AIQComponentEnum.LLM_CLIENT:
        return leaf_component_names(self._llm_client_provider_to_framework)

    if component_type == AIQComponentEnum.EMBEDDER_PROVIDER:
        return [i.static_type() for i in self._registered_embedder_provider_infos]

    if component_type == AIQComponentEnum.EMBEDDER_CLIENT:
        return leaf_component_names(self._embedder_client_provider_to_framework)

    if component_type == AIQComponentEnum.EVALUATOR:
        return [i.static_type() for i in self._registered_evaluator_infos]

    if component_type == AIQComponentEnum.MEMORY:
        return [i.static_type() for i in self._registered_memory_infos]

    if component_type == AIQComponentEnum.REGISTRY_HANDLER:
        return [i.static_type() for i in self._registered_registry_handler_infos]

    if component_type == AIQComponentEnum.LOGGING:
        return [i.static_type() for i in self._registered_logging_methods]

    if component_type == AIQComponentEnum.TRACING:
        return [i.static_type() for i in self._registered_telemetry_exporters]

    if component_type == AIQComponentEnum.PACKAGE:
        return list(self._registered_packages)

    if component_type == AIQComponentEnum.TTC_STRATEGY:
        return [i.static_type() for i in self._registered_ttc_strategies]

    raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
903
|
+
|
|
904
|
+
def get_registered_channel_info_by_channel_type(self, channel_type: str) -> RegisteredRegistryHandlerInfo:
    """Return the registry-handler registration indexed under ``channel_type``.

    Raises:
        KeyError: If no handler has been indexed under ``channel_type``.
    """
    channel_map = self._registered_channel_map
    return channel_map[channel_type]
|
|
906
|
+
|
|
907
|
+
def _do_compute_annotation(self, cls: type[TypedBaseModelT], registrations: list[RegisteredInfo[TypedBaseModelT]]):
    """Build a tagged ``typing.Union`` annotation over the given registrations.

    Each registration contributes ``typing.Annotated[config_type, Tag(full_type)]``. If a
    registration's ``local_name`` is unique among ``registrations``, a second member tagged with
    that short ``local_name`` is added as well, so the type can be referenced by either name.

    Args:
        cls: The component base config class. Also used as the type of placeholder entries.
        registrations: Registrations to include. The caller's list is not modified.

    Returns:
        A ``typing.Union`` of ``typing.Annotated`` members.
    """

    # Work on a copy so the placeholder padding below never leaks into the caller's list.
    registrations = list(registrations)

    # Pad to at least two members. A single-member typing.Union collapses to that member itself.
    # NOTE(review): presumably the padding exists so the result stays a Union for the tagged schema.
    while len(registrations) < 2:
        registrations.append(RegisteredInfo[TypedBaseModelT](full_type=f"_ignore/{len(registrations)}",
                                                             config_type=cls))

    short_names: dict[str, int] = {}
    type_list: list[tuple[str, type[TypedBaseModelT]]] = []

    # Count occurrences of each short (local) name, and emit one member per full type name.
    for key in registrations:
        short_names[key.local_name] = short_names.get(key.local_name, 0) + 1

        type_list.append((key.full_type, key.config_type))

    # Short names are only added as tags when they are unambiguous.
    for key in registrations:

        if short_names[key.local_name] == 1:
            type_list.append((key.local_name, key.config_type))

    # pylint: disable=consider-alternative-union-syntax
    return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
930
|
+
|
|
931
|
+
def compute_annotation(self, cls: type[TypedBaseModelT]):
    """Return the tagged-union annotation for the component family that ``cls`` belongs to.

    ``cls`` is tested against the known component base configs in a fixed order. The first base it
    subclasses selects which registered infos are passed to ``_do_compute_annotation``.

    Raises:
        ValueError: If ``cls`` does not derive from any supported component base config.
    """
    # Order matters: the first matching base wins. Getters are named as strings so that only the
    # selected getter is looked up and called.
    base_to_getter = (
        (AuthProviderBaseConfig, "get_registered_auth_providers"),
        (EmbedderBaseConfig, "get_registered_embedder_providers"),
        (EvaluatorBaseConfig, "get_registered_evaluators"),
        (FrontEndBaseConfig, "get_registered_front_ends"),
        (FunctionBaseConfig, "get_registered_functions"),
        (LLMBaseConfig, "get_registered_llm_providers"),
        (MemoryBaseConfig, "get_registered_memorys"),
        (ObjectStoreBaseConfig, "get_registered_object_stores"),
        (RegistryHandlerBaseConfig, "get_registered_registry_handlers"),
        (RetrieverBaseConfig, "get_registered_retriever_providers"),
        (TelemetryExporterBaseConfig, "get_registered_telemetry_exporters"),
        (LoggingBaseConfig, "get_registered_logging_method"),
        (TTCStrategyBaseConfig, "get_registered_ttc_strategies"),
    )

    for base, getter_name in base_to_getter:
        if issubclass(cls, base):
            registrations = getattr(self, getter_name)()
            return self._do_compute_annotation(cls, registrations)

    raise ValueError(f"Supplied an unsupported component type {cls}")
|
|
973
|
+
|
|
974
|
+
|
|
975
|
+
class GlobalTypeRegistry:
    """Holder for the active ``TypeRegistry``, stored as a class attribute.

    ``get()`` returns the active registry. ``push()`` temporarily activates a deep copy, so any
    registrations made inside the ``with`` block are dropped when it exits.
    """

    _global_registry: TypeRegistry = TypeRegistry()

    @staticmethod
    def get() -> TypeRegistry:
        """Return the currently active registry."""
        return GlobalTypeRegistry._global_registry

    @staticmethod
    @contextmanager
    def push():
        """Activate a deep copy of the current registry for the duration of a ``with`` block.

        Yields:
            The copied registry, which is active inside the block.

        On exit (normal or via an exception), the previously active registry is restored and
        ``_registration_changed()`` is called on it. NOTE(review): this presumably re-runs the
        registration-changed hooks so dependents resync to the restored registry.
        """
        previous = GlobalTypeRegistry._global_registry
        scoped = deepcopy(previous)

        try:
            GlobalTypeRegistry._global_registry = scoped
            yield scoped
        finally:
            GlobalTypeRegistry._global_registry = previous
            previous._registration_changed()
|
|
997
|
+
|
|
998
|
+
|
|
999
|
+
# Finally, update the Config object each time the registry changes.
# This runs once, at import time, against whichever registry GlobalTypeRegistry.get() returns then.
# The lambda defers the AIQConfig.rebuild_annotations attribute lookup until the hook is invoked.
GlobalTypeRegistry.get().add_registration_changed_hook(lambda: AIQConfig.rebuild_annotations())
|