nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +1 -1
- nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
- nat/agent/auto_memory_wrapper/agent.py +278 -0
- nat/agent/auto_memory_wrapper/register.py +227 -0
- nat/agent/auto_memory_wrapper/state.py +30 -0
- nat/agent/base.py +1 -1
- nat/agent/dual_node.py +1 -1
- nat/agent/prompt_optimizer/prompt.py +1 -1
- nat/agent/prompt_optimizer/register.py +1 -1
- nat/agent/react_agent/agent.py +16 -9
- nat/agent/react_agent/output_parser.py +2 -2
- nat/agent/react_agent/prompt.py +3 -2
- nat/agent/react_agent/register.py +2 -2
- nat/agent/react_agent/register_per_user_agent.py +104 -0
- nat/agent/reasoning_agent/reasoning_agent.py +1 -1
- nat/agent/register.py +3 -1
- nat/agent/responses_api_agent/__init__.py +1 -1
- nat/agent/responses_api_agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +9 -4
- nat/agent/rewoo_agent/prompt.py +1 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +5 -4
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/__init__.py +1 -1
- nat/authentication/api_key/__init__.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
- nat/authentication/api_key/register.py +1 -1
- nat/authentication/credential_validator/__init__.py +1 -1
- nat/authentication/credential_validator/bearer_token_validator.py +1 -1
- nat/authentication/exceptions/__init__.py +1 -1
- nat/authentication/exceptions/api_key_exceptions.py +1 -1
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/http_basic_auth/register.py +1 -1
- nat/authentication/interfaces.py +1 -1
- nat/authentication/oauth2/__init__.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
- nat/authentication/oauth2/register.py +1 -1
- nat/authentication/register.py +1 -1
- nat/builder/builder.py +563 -1
- nat/builder/child_builder.py +385 -0
- nat/builder/component_utils.py +34 -4
- nat/builder/context.py +34 -1
- nat/builder/embedder.py +1 -1
- nat/builder/eval_builder.py +19 -7
- nat/builder/evaluator.py +1 -1
- nat/builder/framework_enum.py +3 -1
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +113 -5
- nat/builder/function_base.py +1 -1
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +1 -1
- nat/builder/llm.py +1 -1
- nat/builder/per_user_workflow_builder.py +843 -0
- nat/builder/retriever.py +1 -1
- nat/builder/sync_builder.py +571 -0
- nat/builder/user_interaction_manager.py +1 -1
- nat/builder/workflow.py +5 -3
- nat/builder/workflow_builder.py +619 -378
- nat/cli/__init__.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/cli_utils/validation.py +32 -1
- nat/cli/commands/configure/channel/add.py +1 -1
- nat/cli/commands/configure/channel/channel.py +1 -1
- nat/cli/commands/configure/channel/remove.py +1 -1
- nat/cli/commands/configure/channel/update.py +1 -1
- nat/cli/commands/configure/configure.py +1 -1
- nat/cli/commands/evaluate.py +87 -13
- nat/cli/commands/finetune.py +132 -0
- nat/cli/commands/info/__init__.py +1 -1
- nat/cli/commands/info/info.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +1 -1
- nat/cli/commands/object_store/__init__.py +1 -1
- nat/cli/commands/object_store/object_store.py +1 -1
- nat/cli/commands/optimize.py +1 -1
- nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
- nat/cli/commands/red_teaming/red_teaming.py +138 -0
- nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
- nat/cli/commands/registry/__init__.py +1 -1
- nat/cli/commands/registry/publish.py +1 -1
- nat/cli/commands/registry/pull.py +1 -1
- nat/cli/commands/registry/registry.py +1 -1
- nat/cli/commands/registry/remove.py +1 -1
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/sizing/__init__.py +1 -1
- nat/cli/commands/sizing/calc.py +1 -1
- nat/cli/commands/sizing/sizing.py +1 -1
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/validate.py +1 -1
- nat/cli/commands/workflow/__init__.py +1 -1
- nat/cli/commands/workflow/workflow.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/entrypoint.py +15 -37
- nat/cli/main.py +2 -2
- nat/cli/plugin_loader.py +69 -0
- nat/cli/register_workflow.py +233 -5
- nat/cli/type_registry.py +237 -3
- nat/control_flow/register.py +1 -1
- nat/control_flow/router_agent/agent.py +1 -1
- nat/control_flow/router_agent/prompt.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +28 -7
- nat/data_models/__init__.py +1 -1
- nat/data_models/agent.py +1 -1
- nat/data_models/api_server.py +38 -3
- nat/data_models/authentication.py +1 -1
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +9 -1
- nat/data_models/component_ref.py +45 -1
- nat/data_models/config.py +78 -1
- nat/data_models/dataset_handler.py +15 -2
- nat/data_models/discovery_metadata.py +1 -1
- nat/data_models/embedder.py +1 -1
- nat/data_models/evaluate.py +6 -1
- nat/data_models/evaluator.py +1 -1
- nat/data_models/finetuning.py +260 -0
- nat/data_models/front_end.py +1 -1
- nat/data_models/function.py +15 -2
- nat/data_models/function_dependencies.py +1 -1
- nat/data_models/gated_field_mixin.py +1 -1
- nat/data_models/interactive.py +1 -1
- nat/data_models/intermediate_step.py +29 -2
- nat/data_models/invocation_node.py +1 -1
- nat/data_models/llm.py +1 -1
- nat/data_models/logging.py +1 -1
- nat/data_models/memory.py +1 -1
- nat/data_models/middleware.py +37 -0
- nat/data_models/object_store.py +1 -1
- nat/data_models/openai_mcp.py +1 -1
- nat/data_models/optimizable.py +1 -1
- nat/data_models/optimizer.py +1 -1
- nat/data_models/profiler.py +1 -1
- nat/data_models/registry_handler.py +1 -1
- nat/data_models/retriever.py +1 -1
- nat/data_models/retry_mixin.py +1 -1
- nat/data_models/runtime_enum.py +26 -0
- nat/data_models/span.py +1 -1
- nat/data_models/step_adaptor.py +1 -1
- nat/data_models/streaming.py +1 -1
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/telemetry_exporter.py +1 -1
- nat/data_models/thinking_mixin.py +1 -1
- nat/data_models/ttc_strategy.py +1 -1
- nat/embedder/azure_openai_embedder.py +1 -1
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +1 -1
- nat/eval/__init__.py +1 -1
- nat/eval/config.py +8 -1
- nat/eval/dataset_handler/dataset_downloader.py +1 -1
- nat/eval/dataset_handler/dataset_filter.py +1 -1
- nat/eval/dataset_handler/dataset_handler.py +4 -2
- nat/eval/evaluate.py +226 -81
- nat/eval/evaluator/__init__.py +1 -1
- nat/eval/evaluator/base_evaluator.py +2 -2
- nat/eval/evaluator/evaluator_model.py +3 -2
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/llm_validator.py +336 -0
- nat/eval/rag_evaluator/evaluate.py +17 -10
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/red_teaming_evaluator/__init__.py +14 -0
- nat/eval/red_teaming_evaluator/data_models.py +66 -0
- nat/eval/red_teaming_evaluator/evaluate.py +327 -0
- nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
- nat/eval/red_teaming_evaluator/register.py +55 -0
- nat/eval/register.py +2 -1
- nat/eval/remote_workflow.py +1 -1
- nat/eval/runners/__init__.py +1 -1
- nat/eval/runners/config.py +1 -1
- nat/eval/runners/multi_eval_runner.py +1 -1
- nat/eval/runners/red_teaming_runner/__init__.py +24 -0
- nat/eval/runners/red_teaming_runner/config.py +282 -0
- nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
- nat/eval/runners/red_teaming_runner/runner.py +867 -0
- nat/eval/runtime_evaluator/__init__.py +1 -1
- nat/eval/runtime_evaluator/evaluate.py +1 -1
- nat/eval/runtime_evaluator/register.py +1 -1
- nat/eval/runtime_event_subscriber.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/swe_bench_evaluator/register.py +1 -1
- nat/eval/trajectory_evaluator/evaluate.py +2 -2
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
- nat/eval/tunable_rag_evaluator/register.py +1 -1
- nat/eval/usage_stats.py +1 -1
- nat/eval/utils/eval_trace_ctx.py +1 -1
- nat/eval/utils/output_uploader.py +1 -1
- nat/eval/utils/tqdm_position_registry.py +1 -1
- nat/eval/utils/weave_eval.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -1
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/editor_config.py +1 -1
- nat/experimental/test_time_compute/models/scoring_config.py +1 -1
- nat/experimental/test_time_compute/models/search_config.py +20 -2
- nat/experimental/test_time_compute/models/selection_config.py +33 -2
- nat/experimental/test_time_compute/models/stage_enums.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +1 -1
- nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
- nat/experimental/test_time_compute/models/ttc_item.py +1 -1
- nat/experimental/test_time_compute/register.py +4 -1
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
- nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
- nat/finetuning/__init__.py +24 -0
- nat/finetuning/finetuning_runtime.py +143 -0
- nat/finetuning/interfaces/__init__.py +24 -0
- nat/finetuning/interfaces/finetuning_runner.py +261 -0
- nat/finetuning/interfaces/trainer_adapter.py +103 -0
- nat/finetuning/interfaces/trajectory_builder.py +115 -0
- nat/finetuning/utils/__init__.py +15 -0
- nat/finetuning/utils/parsers/__init__.py +15 -0
- nat/finetuning/utils/parsers/adk_parser.py +141 -0
- nat/finetuning/utils/parsers/base_parser.py +238 -0
- nat/finetuning/utils/parsers/common.py +91 -0
- nat/finetuning/utils/parsers/langchain_parser.py +267 -0
- nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
- nat/front_ends/__init__.py +1 -1
- nat/front_ends/console/__init__.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +1 -1
- nat/front_ends/console/console_front_end_config.py +4 -1
- nat/front_ends/console/console_front_end_plugin.py +5 -4
- nat/front_ends/console/register.py +1 -1
- nat/front_ends/cron/__init__.py +1 -1
- nat/front_ends/fastapi/__init__.py +1 -1
- nat/front_ends/fastapi/async_job.py +128 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
- nat/front_ends/fastapi/dask_client_mixin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_config.py +23 -1
- nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +318 -59
- nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
- nat/front_ends/fastapi/job_store.py +23 -11
- nat/front_ends/fastapi/main.py +1 -1
- nat/front_ends/fastapi/message_handler.py +27 -4
- nat/front_ends/fastapi/message_validator.py +54 -2
- nat/front_ends/fastapi/register.py +1 -1
- nat/front_ends/fastapi/response_helpers.py +16 -15
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/fastapi/utils.py +1 -1
- nat/front_ends/register.py +1 -2
- nat/front_ends/simple_base/__init__.py +1 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
- nat/llm/aws_bedrock_llm.py +1 -1
- nat/llm/azure_openai_llm.py +10 -1
- nat/llm/dynamo_llm.py +363 -0
- nat/llm/huggingface_llm.py +177 -0
- nat/llm/litellm_llm.py +1 -1
- nat/llm/nim_llm.py +1 -1
- nat/llm/openai_llm.py +1 -1
- nat/llm/register.py +3 -1
- nat/llm/utils/__init__.py +1 -1
- nat/llm/utils/env_config_value.py +1 -1
- nat/llm/utils/error.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/memory/__init__.py +1 -1
- nat/memory/interfaces.py +1 -1
- nat/memory/models.py +1 -1
- nat/meta/pypi.md +1 -1
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache/__init__.py +14 -0
- nat/middleware/cache/cache_middleware.py +253 -0
- nat/middleware/cache/cache_middleware_config.py +44 -0
- nat/middleware/cache/register.py +33 -0
- nat/middleware/defense/__init__.py +14 -0
- nat/middleware/defense/defense_middleware.py +362 -0
- nat/middleware/defense/defense_middleware_content_guard.py +455 -0
- nat/middleware/defense/defense_middleware_data_models.py +91 -0
- nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
- nat/middleware/defense/defense_middleware_pii.py +356 -0
- nat/middleware/defense/register.py +82 -0
- nat/middleware/dynamic/__init__.py +14 -0
- nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
- nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
- nat/middleware/dynamic/register.py +34 -0
- nat/middleware/function_middleware.py +370 -0
- nat/middleware/logging/__init__.py +14 -0
- nat/middleware/logging/logging_middleware.py +67 -0
- nat/middleware/logging/logging_middleware_config.py +28 -0
- nat/middleware/logging/register.py +33 -0
- nat/middleware/middleware.py +298 -0
- nat/middleware/red_teaming/__init__.py +14 -0
- nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
- nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
- nat/middleware/red_teaming/register.py +47 -0
- nat/middleware/register.py +22 -0
- nat/middleware/utils/__init__.py +14 -0
- nat/middleware/utils/workflow_inventory.py +155 -0
- nat/object_store/__init__.py +1 -1
- nat/object_store/in_memory_object_store.py +1 -1
- nat/object_store/interfaces.py +1 -1
- nat/object_store/models.py +1 -1
- nat/object_store/register.py +1 -1
- nat/observability/__init__.py +1 -1
- nat/observability/exporter/__init__.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +1 -1
- nat/observability/exporter/raw_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +7 -1
- nat/observability/exporter_manager.py +1 -1
- nat/observability/mixin/__init__.py +1 -1
- nat/observability/mixin/batch_config_mixin.py +1 -1
- nat/observability/mixin/collector_config_mixin.py +1 -1
- nat/observability/mixin/file_mixin.py +1 -1
- nat/observability/mixin/file_mode.py +1 -1
- nat/observability/mixin/redaction_config_mixin.py +1 -1
- nat/observability/mixin/resource_conflict_mixin.py +1 -1
- nat/observability/mixin/serialize_mixin.py +1 -1
- nat/observability/mixin/tagging_config_mixin.py +1 -1
- nat/observability/mixin/type_introspection_mixin.py +1 -1
- nat/observability/processor/__init__.py +1 -1
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/callback_processor.py +1 -1
- nat/observability/processor/falsy_batch_filter_processor.py +1 -1
- nat/observability/processor/intermediate_step_serializer.py +1 -1
- nat/observability/processor/processor.py +1 -1
- nat/observability/processor/processor_factory.py +1 -1
- nat/observability/processor/redaction/__init__.py +1 -1
- nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
- nat/observability/processor/redaction/redaction_processor.py +1 -1
- nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
- nat/observability/processor/span_tagging_processor.py +1 -1
- nat/observability/register.py +1 -1
- nat/observability/utils/__init__.py +1 -1
- nat/observability/utils/dict_utils.py +1 -1
- nat/observability/utils/time_utils.py +1 -1
- nat/profiler/calc/__init__.py +1 -1
- nat/profiler/calc/calc_runner.py +3 -3
- nat/profiler/calc/calculations.py +1 -1
- nat/profiler/calc/data_models.py +1 -1
- nat/profiler/calc/plot.py +30 -3
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/base_callback_class.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +33 -3
- nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/callbacks/token_usage_base_model.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/data_models.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +32 -1
- nat/profiler/decorators/function_tracking.py +1 -1
- nat/profiler/forecasting/config.py +1 -1
- nat/profiler/forecasting/model_trainer.py +1 -1
- nat/profiler/forecasting/models/__init__.py +1 -1
- nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_metrics_model.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +1 -1
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/inference_optimization/llm_metrics.py +1 -1
- nat/profiler/inference_optimization/prompt_caching.py +1 -1
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
- nat/profiler/intermediate_property_adapter.py +1 -1
- nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
- nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
- nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
- nat/profiler/parameter_optimization/parameter_selection.py +1 -1
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
- nat/profiler/parameter_optimization/update_helpers.py +1 -1
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +1 -1
- nat/registry_handlers/local/register_local.py +1 -1
- nat/registry_handlers/metadata_factory.py +1 -1
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +1 -1
- nat/registry_handlers/pypi/register_pypi.py +1 -1
- nat/registry_handlers/register.py +1 -1
- nat/registry_handlers/registry_handler_base.py +1 -1
- nat/registry_handlers/rest/register_rest.py +1 -1
- nat/registry_handlers/rest/rest_handler.py +1 -1
- nat/registry_handlers/schemas/headers.py +1 -1
- nat/registry_handlers/schemas/package.py +1 -1
- nat/registry_handlers/schemas/publish.py +1 -1
- nat/registry_handlers/schemas/pull.py +1 -1
- nat/registry_handlers/schemas/remove.py +1 -1
- nat/registry_handlers/schemas/search.py +1 -1
- nat/registry_handlers/schemas/status.py +1 -1
- nat/retriever/interface.py +1 -1
- nat/retriever/milvus/__init__.py +1 -1
- nat/retriever/milvus/register.py +12 -4
- nat/retriever/milvus/retriever.py +103 -41
- nat/retriever/models.py +1 -1
- nat/retriever/nemo_retriever/__init__.py +1 -1
- nat/retriever/nemo_retriever/register.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +5 -5
- nat/retriever/register.py +1 -1
- nat/runtime/__init__.py +1 -1
- nat/runtime/loader.py +10 -3
- nat/runtime/metrics.py +180 -0
- nat/runtime/runner.py +13 -6
- nat/runtime/session.py +458 -32
- nat/runtime/user_metadata.py +1 -1
- nat/settings/global_settings.py +1 -1
- nat/tool/chat_completion.py +1 -1
- nat/tool/code_execution/README.md +1 -1
- nat/tool/code_execution/code_sandbox.py +2 -2
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
- nat/tool/code_execution/local_sandbox/__init__.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
- nat/tool/code_execution/register.py +1 -1
- nat/tool/code_execution/utils.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +1 -1
- nat/tool/memory_tools/delete_memory_tool.py +1 -1
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +2 -2
- nat/tool/register.py +1 -1
- nat/tool/retriever.py +1 -1
- nat/tool/server_tools.py +1 -1
- nat/utils/__init__.py +8 -5
- nat/utils/callable_utils.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/debugging_utils.py +1 -1
- nat/utils/decorators.py +1 -1
- nat/utils/dump_distro_mapping.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +3 -3
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/model_processing.py +1 -1
- nat/utils/io/supress_logs.py +33 -0
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/log_levels.py +1 -1
- nat/utils/log_utils.py +13 -1
- nat/utils/metadata_utils.py +1 -1
- nat/utils/optional_imports.py +1 -1
- nat/utils/producer_consumer_queue.py +1 -1
- nat/utils/reactive/base/observable_base.py +1 -1
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/base/subject_base.py +1 -1
- nat/utils/reactive/observable.py +1 -1
- nat/utils/reactive/observer.py +1 -1
- nat/utils/reactive/subject.py +1 -1
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/responses_api.py +1 -1
- nat/utils/settings/global_settings.py +1 -1
- nat/utils/string_utils.py +1 -1
- nat/utils/type_converter.py +18 -5
- nat/utils/type_utils.py +1 -1
- nat/utils/url_utils.py +1 -1
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +46 -15
- nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
- nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
- nat/cli/commands/mcp/mcp.py +0 -986
- nat/front_ends/mcp/introspection_token_verifier.py +0 -73
- nat/front_ends/mcp/mcp_front_end_config.py +0 -109
- nat/front_ends/mcp/mcp_front_end_plugin.py +0 -151
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -362
- nat/front_ends/mcp/memory_profiler.py +0 -320
- nat/front_ends/mcp/register.py +0 -27
- nat/front_ends/mcp/tool_converter.py +0 -321
- nvidia_nat-1.4.0a20251112.dist-info/RECORD +0 -481
- nvidia_nat-1.4.0a20251112.dist-info/entry_points.txt +0 -22
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2024-
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -28,7 +28,8 @@ from typing import cast
|
|
|
28
28
|
from nat.authentication.interfaces import AuthProviderBase
|
|
29
29
|
from nat.builder.builder import Builder
|
|
30
30
|
from nat.builder.builder import UserManagerHolder
|
|
31
|
-
from nat.builder.
|
|
31
|
+
from nat.builder.child_builder import ChildBuilder
|
|
32
|
+
from nat.builder.component_utils import WORKFLOW_COMPONENT_NAME
|
|
32
33
|
from nat.builder.component_utils import build_dependency_sequence
|
|
33
34
|
from nat.builder.context import Context
|
|
34
35
|
from nat.builder.context import ContextState
|
|
@@ -40,6 +41,7 @@ from nat.builder.function import LambdaFunction
|
|
|
40
41
|
from nat.builder.function_info import FunctionInfo
|
|
41
42
|
from nat.builder.llm import LLMProviderInfo
|
|
42
43
|
from nat.builder.retriever import RetrieverProviderInfo
|
|
44
|
+
from nat.builder.sync_builder import SyncBuilder
|
|
43
45
|
from nat.builder.workflow import Workflow
|
|
44
46
|
from nat.cli.type_registry import GlobalTypeRegistry
|
|
45
47
|
from nat.cli.type_registry import TypeRegistry
|
|
@@ -51,17 +53,25 @@ from nat.data_models.component_ref import FunctionGroupRef
|
|
|
51
53
|
from nat.data_models.component_ref import FunctionRef
|
|
52
54
|
from nat.data_models.component_ref import LLMRef
|
|
53
55
|
from nat.data_models.component_ref import MemoryRef
|
|
56
|
+
from nat.data_models.component_ref import MiddlewareRef
|
|
54
57
|
from nat.data_models.component_ref import ObjectStoreRef
|
|
55
58
|
from nat.data_models.component_ref import RetrieverRef
|
|
59
|
+
from nat.data_models.component_ref import TrainerAdapterRef
|
|
60
|
+
from nat.data_models.component_ref import TrainerRef
|
|
61
|
+
from nat.data_models.component_ref import TrajectoryBuilderRef
|
|
56
62
|
from nat.data_models.component_ref import TTCStrategyRef
|
|
57
63
|
from nat.data_models.config import Config
|
|
58
64
|
from nat.data_models.config import GeneralConfig
|
|
59
65
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
66
|
+
from nat.data_models.finetuning import TrainerAdapterConfig
|
|
67
|
+
from nat.data_models.finetuning import TrainerConfig
|
|
68
|
+
from nat.data_models.finetuning import TrajectoryBuilderConfig
|
|
60
69
|
from nat.data_models.function import FunctionBaseConfig
|
|
61
70
|
from nat.data_models.function import FunctionGroupBaseConfig
|
|
62
71
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
63
72
|
from nat.data_models.llm import LLMBaseConfig
|
|
64
73
|
from nat.data_models.memory import MemoryBaseConfig
|
|
74
|
+
from nat.data_models.middleware import MiddlewareBaseConfig
|
|
65
75
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
66
76
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
67
77
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
@@ -70,12 +80,16 @@ from nat.experimental.decorators.experimental_warning_decorator import experimen
|
|
|
70
80
|
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
71
81
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
72
82
|
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
83
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
84
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
85
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
73
86
|
from nat.memory.interfaces import MemoryEditor
|
|
87
|
+
from nat.middleware.function_middleware import FunctionMiddleware
|
|
88
|
+
from nat.middleware.middleware import Middleware
|
|
74
89
|
from nat.object_store.interfaces import ObjectStore
|
|
75
90
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
76
91
|
from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
77
92
|
from nat.profiler.utils import detect_llm_frameworks_in_build_fn
|
|
78
|
-
from nat.retriever.interface import Retriever
|
|
79
93
|
from nat.utils.type_utils import override
|
|
80
94
|
|
|
81
95
|
logger = logging.getLogger(__name__)
|
|
@@ -141,6 +155,157 @@ class ConfiguredTTCStrategy:
|
|
|
141
155
|
instance: StrategyBase
|
|
142
156
|
|
|
143
157
|
|
|
158
|
+
@dataclasses.dataclass()
class ConfiguredMiddleware:
    """Container holding a middleware's configuration together with its instance."""

    # Configuration object describing this middleware.
    config: MiddlewareBaseConfig
    # The middleware object associated with ``config``.
    instance: Middleware
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@dataclasses.dataclass()
class ConfiguredTrainer:
    """Container holding a trainer's configuration together with its instance."""

    # Configuration object describing this trainer.
    config: TrainerConfig
    # The trainer object associated with ``config``.
    instance: Trainer
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclasses.dataclass()
class ConfiguredTrainerAdapter:
    """Container holding a trainer adapter's configuration together with its instance."""

    # Configuration object describing this trainer adapter.
    config: TrainerAdapterConfig
    # The trainer adapter object associated with ``config``.
    instance: TrainerAdapter
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@dataclasses.dataclass()
class ConfiguredTrajectoryBuilder:
    """Container holding a trajectory builder's configuration together with its instance."""

    # Configuration object describing this trajectory builder.
    config: TrajectoryBuilderConfig
    # The trajectory builder object associated with ``config``.
    instance: TrajectoryBuilder
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _log_build_failure(component_name: str,
                       component_type: str,
                       completed_components: list[tuple[str, str]],
                       remaining_components: list[tuple[str, str]],
                       original_error: Exception) -> None:
    """
    Common method to log comprehensive build failure information.

    Emits, at ERROR level: the component that failed, the components built before it,
    the components not yet built, and finally the original error with its traceback.

    Args:
        component_name (str): The name of the component that failed to build
        component_type (str): The type of the component that failed to build
        completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
        remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
        original_error (Exception): The original exception that caused the failure
    """
    logger.error("Failed to initialize component %s (%s)", component_name, component_type)

    if completed_components:
        logger.error("Successfully built components:")
        for name, comp_type in completed_components:
            logger.error("- %s (%s)", name, comp_type)
    else:
        logger.error("No components were successfully built before this failure")

    if remaining_components:
        logger.error("Remaining components to build:")
        for name, comp_type in remaining_components:
            logger.error("- %s (%s)", name, comp_type)
    else:
        logger.error("No remaining components to build")

    # Pass the exception itself instead of ``True``: ``exc_info=True`` reads ``sys.exc_info()``,
    # which is empty when this helper is invoked outside an ``except`` block. In that case the
    # traceback of ``original_error`` would be dropped (logged as "NoneType: None").
    logger.error("Original error: %s", original_error, exc_info=original_error)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
async def _build_function_impl(
|
|
217
|
+
*,
|
|
218
|
+
name: str,
|
|
219
|
+
config: FunctionBaseConfig,
|
|
220
|
+
registry: TypeRegistry,
|
|
221
|
+
exit_stack: AsyncExitStack,
|
|
222
|
+
inner_builder: 'ChildBuilder',
|
|
223
|
+
llms: dict[str, LLMProviderInfo],
|
|
224
|
+
dependencies: dict[str, FunctionDependencies],
|
|
225
|
+
middleware_instances: list[FunctionMiddleware],
|
|
226
|
+
) -> ConfiguredFunction:
|
|
227
|
+
"""
|
|
228
|
+
Helper for core function building logic.
|
|
229
|
+
|
|
230
|
+
Args:
|
|
231
|
+
name: The function name
|
|
232
|
+
config: The function configuration
|
|
233
|
+
registry: Type registry to look up the function registration
|
|
234
|
+
exit_stack: Async exit stack for context management
|
|
235
|
+
inner_builder: ChildBuilder instance for dependency tracking
|
|
236
|
+
llms: Dictionary of LLM instances
|
|
237
|
+
dependencies: Dictionary to store function dependencies
|
|
238
|
+
middleware_instances: Pre-resolved middleware instances
|
|
239
|
+
"""
|
|
240
|
+
registration = registry.get_function(type(config))
|
|
241
|
+
|
|
242
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
243
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
244
|
+
|
|
245
|
+
build_result = await exit_stack.enter_async_context(build_fn(config, inner_builder))
|
|
246
|
+
|
|
247
|
+
dependencies[name] = inner_builder.dependencies
|
|
248
|
+
|
|
249
|
+
# If the build result is a function, wrap it in a FunctionInfo
|
|
250
|
+
if inspect.isfunction(build_result):
|
|
251
|
+
build_result = FunctionInfo.from_fn(build_result)
|
|
252
|
+
|
|
253
|
+
if isinstance(build_result, FunctionInfo):
|
|
254
|
+
build_result = LambdaFunction.from_info(config=config, info=build_result, instance_name=name)
|
|
255
|
+
|
|
256
|
+
if not isinstance(build_result, Function):
|
|
257
|
+
raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
|
|
258
|
+
f"returned from the function builder. Got {type(build_result)}")
|
|
259
|
+
|
|
260
|
+
build_result.configure_middleware(middleware_instances)
|
|
261
|
+
|
|
262
|
+
return ConfiguredFunction(config=config, instance=build_result)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
async def _build_function_group_impl(
|
|
266
|
+
*,
|
|
267
|
+
name: str,
|
|
268
|
+
config: FunctionGroupBaseConfig,
|
|
269
|
+
registry: TypeRegistry,
|
|
270
|
+
exit_stack: AsyncExitStack,
|
|
271
|
+
inner_builder: 'ChildBuilder',
|
|
272
|
+
llms: dict[str, LLMProviderInfo],
|
|
273
|
+
dependencies: dict[str, FunctionDependencies],
|
|
274
|
+
middleware_instances: list[FunctionMiddleware],
|
|
275
|
+
) -> ConfiguredFunctionGroup:
|
|
276
|
+
"""
|
|
277
|
+
Core function group building logic shared between WorkflowBuilder and PerUserWorkflowBuilder.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
name: The function group name
|
|
281
|
+
config: The function group configuration
|
|
282
|
+
registry: Type registry to look up the function group registration
|
|
283
|
+
exit_stack: Async exit stack for context management
|
|
284
|
+
inner_builder: ChildBuilder instance for dependency tracking
|
|
285
|
+
llms: Dictionary of LLM instances
|
|
286
|
+
dependencies: Dictionary to store function group dependencies
|
|
287
|
+
middleware_instances: Pre-resolved middleware instances
|
|
288
|
+
"""
|
|
289
|
+
registration = registry.get_function_group(type(config))
|
|
290
|
+
|
|
291
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
292
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
293
|
+
|
|
294
|
+
build_result = await exit_stack.enter_async_context(build_fn(config, inner_builder))
|
|
295
|
+
|
|
296
|
+
dependencies[name] = inner_builder.dependencies
|
|
297
|
+
|
|
298
|
+
if not isinstance(build_result, FunctionGroup):
|
|
299
|
+
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
300
|
+
f"Got {type(build_result)}")
|
|
301
|
+
|
|
302
|
+
# Set the instance name BEFORE configuring middleware
|
|
303
|
+
build_result.set_instance_name(name)
|
|
304
|
+
build_result.configure_middleware(middleware_instances)
|
|
305
|
+
|
|
306
|
+
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
307
|
+
|
|
308
|
+
|
|
144
309
|
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
145
310
|
|
|
146
311
|
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
|
|
@@ -170,6 +335,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
170
335
|
self._object_stores: dict[str, ConfiguredObjectStore] = {}
|
|
171
336
|
self._retrievers: dict[str, ConfiguredRetriever] = {}
|
|
172
337
|
self._ttc_strategies: dict[str, ConfiguredTTCStrategy] = {}
|
|
338
|
+
self._middleware: dict[str, ConfiguredMiddleware] = {}
|
|
339
|
+
self._trainers: dict[str, ConfiguredTrainer] = {}
|
|
340
|
+
self._trainer_adapters: dict[str, ConfiguredTrainerAdapter] = {}
|
|
341
|
+
self._trajectory_builders: dict[str, ConfiguredTrajectoryBuilder] = {}
|
|
173
342
|
|
|
174
343
|
self._context_state = ContextState.get()
|
|
175
344
|
|
|
@@ -178,8 +347,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
178
347
|
# Create a mapping to track function name -> other function names it depends on
|
|
179
348
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
180
349
|
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
181
|
-
|
|
182
|
-
|
|
350
|
+
|
|
351
|
+
# List of completed built components
|
|
352
|
+
self.completed_components: list[tuple[str, str]] = []
|
|
353
|
+
# List of remaining components to be built
|
|
354
|
+
self.remaining_components: list[tuple[str, str]] = []
|
|
183
355
|
|
|
184
356
|
async def __aenter__(self):
|
|
185
357
|
|
|
@@ -260,6 +432,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
260
432
|
|
|
261
433
|
await self._exit_stack.__aexit__(*exc_details)
|
|
262
434
|
|
|
435
|
+
@override
|
|
436
|
+
@property
|
|
437
|
+
def sync_builder(self) -> SyncBuilder:
|
|
438
|
+
return SyncBuilder(self)
|
|
439
|
+
|
|
263
440
|
async def build(self, entry_function: str | None = None) -> Workflow:
|
|
264
441
|
"""
|
|
265
442
|
Creates an instance of a workflow object using the added components and the desired entry function.
|
|
@@ -334,6 +511,18 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
334
511
|
ttc_strategies={
|
|
335
512
|
k: v.config
|
|
336
513
|
for k, v in self._ttc_strategies.items()
|
|
514
|
+
},
|
|
515
|
+
trainers={
|
|
516
|
+
k: v.config
|
|
517
|
+
for k, v in self._trainers.items()
|
|
518
|
+
},
|
|
519
|
+
trainer_adapters={
|
|
520
|
+
k: v.config
|
|
521
|
+
for k, v in self._trainer_adapters.items()
|
|
522
|
+
},
|
|
523
|
+
trajectory_builders={
|
|
524
|
+
k: v.config
|
|
525
|
+
for k, v in self._trajectory_builders.items()
|
|
337
526
|
})
|
|
338
527
|
|
|
339
528
|
if (entry_function is None):
|
|
@@ -385,45 +574,49 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
385
574
|
|
|
386
575
|
return self._exit_stack
|
|
387
576
|
|
|
388
|
-
async def
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
# We need to do this for every function because we don't know
|
|
394
|
-
# Where LLama Index Agents are Instantiated and Settings need to
|
|
395
|
-
# be set before the function is built
|
|
396
|
-
# It's only slower the first time because of the import
|
|
397
|
-
# So we can afford to do this for every function
|
|
398
|
-
|
|
399
|
-
llms = {k: v.instance for k, v in self._llms.items()}
|
|
400
|
-
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
401
|
-
|
|
402
|
-
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
403
|
-
|
|
404
|
-
# Set the currently building function so the ChildBuilder can track dependencies
|
|
405
|
-
self.current_function_building = config.type
|
|
406
|
-
# Empty set of dependencies for the current function
|
|
407
|
-
self.function_dependencies[config.type] = FunctionDependencies()
|
|
408
|
-
|
|
409
|
-
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
410
|
-
|
|
411
|
-
self.function_dependencies[name] = inner_builder.dependencies
|
|
412
|
-
|
|
413
|
-
# If the build result is a function, wrap it in a FunctionInfo
|
|
414
|
-
if inspect.isfunction(build_result):
|
|
415
|
-
|
|
416
|
-
build_result = FunctionInfo.from_fn(build_result)
|
|
417
|
-
|
|
418
|
-
if (isinstance(build_result, FunctionInfo)):
|
|
419
|
-
# Create the function object
|
|
420
|
-
build_result = LambdaFunction.from_info(config=config, info=build_result, instance_name=name)
|
|
577
|
+
async def _resolve_middleware_instances(self, middleware_names: list[str], component_name: str,
|
|
578
|
+
component_type: str) -> list[FunctionMiddleware]:
|
|
579
|
+
"""
|
|
580
|
+
Resolve middleware names to FunctionMiddleware instances.
|
|
581
|
+
"""
|
|
421
582
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
583
|
+
middleware_instances: list[FunctionMiddleware] = []
|
|
584
|
+
for middleware_name in middleware_names:
|
|
585
|
+
if middleware_name not in self._middleware:
|
|
586
|
+
raise ValueError(f"Middleware `{middleware_name}` not found for {component_type} `{component_name}`. "
|
|
587
|
+
f"It must be configured in the `middleware` section of the YAML configuration.")
|
|
588
|
+
middleware_obj = self._middleware[middleware_name].instance
|
|
589
|
+
if not isinstance(middleware_obj, FunctionMiddleware):
|
|
590
|
+
raise TypeError(f"Middleware `{middleware_name}` is not a FunctionMiddleware and cannot be used"
|
|
591
|
+
f"with {component_type}s. "
|
|
592
|
+
f"Only FunctionMiddleware types support function-specific wrapping.")
|
|
593
|
+
middleware_instances.append(middleware_obj)
|
|
594
|
+
return middleware_instances
|
|
425
595
|
|
|
426
|
-
|
|
596
|
+
async def _build_function(self, name: str, config: FunctionBaseConfig) -> ConfiguredFunction:
|
|
597
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
598
|
+
|
|
599
|
+
# We need to do this for every function because we don't know
|
|
600
|
+
# Where LLama Index Agents are Instantiated and Settings need to
|
|
601
|
+
# be set before the function is built
|
|
602
|
+
# It's only slower the first time because of the import
|
|
603
|
+
# So we can afford to do this for every function
|
|
604
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
605
|
+
|
|
606
|
+
# Resolve middleware names from config to middleware instances
|
|
607
|
+
# Only FunctionMiddleware types can be used with functions
|
|
608
|
+
middleware_instances = await self._resolve_middleware_instances(config.middleware, name, "function")
|
|
609
|
+
|
|
610
|
+
return await _build_function_impl(
|
|
611
|
+
name=name,
|
|
612
|
+
config=config,
|
|
613
|
+
registry=self._registry,
|
|
614
|
+
exit_stack=self._get_exit_stack(),
|
|
615
|
+
inner_builder=inner_builder,
|
|
616
|
+
llms=llms,
|
|
617
|
+
dependencies=self.function_dependencies,
|
|
618
|
+
middleware_instances=middleware_instances,
|
|
619
|
+
)
|
|
427
620
|
|
|
428
621
|
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
429
622
|
"""Build a function group from the provided configuration.
|
|
@@ -438,32 +631,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
438
631
|
Raises:
|
|
439
632
|
ValueError: If the function group builder returns invalid results
|
|
440
633
|
"""
|
|
441
|
-
registration = self._registry.get_function_group(type(config))
|
|
442
|
-
|
|
443
|
-
inner_builder = ChildBuilder(self)
|
|
444
634
|
|
|
445
|
-
|
|
446
|
-
llms = {k: v.instance for k, v in self._llms.items()}
|
|
447
|
-
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
635
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
448
636
|
|
|
449
|
-
|
|
637
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
638
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
639
|
+
# Resolve middleware names from config to middleware instances
|
|
640
|
+
# Only FunctionMiddleware types can be used with function groups
|
|
641
|
+
middleware_instances = await self._resolve_middleware_instances(config.middleware, name, "function group")
|
|
450
642
|
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
if not isinstance(build_result, FunctionGroup):
|
|
461
|
-
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
462
|
-
f"Got {type(build_result)}")
|
|
463
|
-
|
|
464
|
-
# set the instance name for the function group based on the workflow-provided name
|
|
465
|
-
build_result.set_instance_name(name)
|
|
466
|
-
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
643
|
+
return await _build_function_group_impl(name=name,
|
|
644
|
+
config=config,
|
|
645
|
+
registry=self._registry,
|
|
646
|
+
exit_stack=self._get_exit_stack(),
|
|
647
|
+
inner_builder=inner_builder,
|
|
648
|
+
llms=llms,
|
|
649
|
+
dependencies=self.function_group_dependencies,
|
|
650
|
+
middleware_instances=middleware_instances)
|
|
467
651
|
|
|
468
652
|
@override
|
|
469
653
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
@@ -472,6 +656,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
472
656
|
|
|
473
657
|
if (name in self._functions or name in self._function_groups):
|
|
474
658
|
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
659
|
+
if any(name.startswith(k + FunctionGroup.SEPARATOR) for k in self._function_groups.keys()):
|
|
660
|
+
raise ValueError(f"A Function name starts with a Function Group name: `{name}`")
|
|
475
661
|
|
|
476
662
|
build_result = await self._build_function(name=name, config=config)
|
|
477
663
|
|
|
@@ -486,6 +672,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
486
672
|
|
|
487
673
|
if (name in self._function_groups or name in self._functions):
|
|
488
674
|
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
675
|
+
if any(k.startswith(name + FunctionGroup.SEPARATOR) for k in self._functions.keys()):
|
|
676
|
+
raise ValueError(f"A Function name starts with a Function Group name: `{name}`")
|
|
489
677
|
|
|
490
678
|
# Build the function group
|
|
491
679
|
build_result = await self._build_function_group(name=name, config=config)
|
|
@@ -505,10 +693,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
505
693
|
|
|
506
694
|
return build_result.instance
|
|
507
695
|
|
|
696
|
+
def _check_backwards_compatibility_function_name(self, name: str) -> str:
|
|
697
|
+
if name in self._functions:
|
|
698
|
+
return name
|
|
699
|
+
new_name = name.replace(FunctionGroup.LEGACY_SEPARATOR, FunctionGroup.SEPARATOR)
|
|
700
|
+
if new_name in self._functions:
|
|
701
|
+
logger.warning(
|
|
702
|
+
f"Function `{name}` is deprecated and will be removed in a future release. Use `{new_name}` instead.")
|
|
703
|
+
return new_name
|
|
704
|
+
return name
|
|
705
|
+
|
|
508
706
|
@override
|
|
509
707
|
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
510
708
|
if isinstance(name, FunctionRef):
|
|
511
709
|
name = str(name)
|
|
710
|
+
|
|
711
|
+
name = self._check_backwards_compatibility_function_name(name)
|
|
712
|
+
|
|
512
713
|
if name not in self._functions:
|
|
513
714
|
raise ValueError(f"Function `{name}` not found")
|
|
514
715
|
|
|
@@ -527,6 +728,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
527
728
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
528
729
|
if isinstance(name, FunctionRef):
|
|
529
730
|
name = str(name)
|
|
731
|
+
name = self._check_backwards_compatibility_function_name(name)
|
|
530
732
|
if name not in self._functions:
|
|
531
733
|
raise ValueError(f"Function `{name}` not found")
|
|
532
734
|
|
|
@@ -547,7 +749,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
547
749
|
if self._workflow is not None:
|
|
548
750
|
warnings.warn("Overwriting existing workflow")
|
|
549
751
|
|
|
550
|
-
build_result = await self._build_function(name=
|
|
752
|
+
build_result = await self._build_function(name=WORKFLOW_COMPONENT_NAME, config=config)
|
|
551
753
|
|
|
552
754
|
self._workflow = build_result
|
|
553
755
|
|
|
@@ -618,6 +820,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
618
820
|
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
619
821
|
if isinstance(fn_name, FunctionRef):
|
|
620
822
|
fn_name = str(fn_name)
|
|
823
|
+
|
|
824
|
+
fn_name = self._check_backwards_compatibility_function_name(fn_name)
|
|
825
|
+
|
|
621
826
|
if fn_name not in self._functions:
|
|
622
827
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
623
828
|
fn = self._functions[fn_name]
|
|
@@ -640,9 +845,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
640
845
|
try:
|
|
641
846
|
llm_info = self._registry.get_llm_provider(type(config))
|
|
642
847
|
|
|
643
|
-
|
|
848
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
849
|
+
info_obj = await self._get_exit_stack().enter_async_context(llm_info.build_fn(config, inner_builder))
|
|
644
850
|
|
|
645
|
-
|
|
851
|
+
self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
|
|
646
852
|
except Exception as e:
|
|
647
853
|
logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
|
|
648
854
|
raise
|
|
@@ -660,7 +866,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
660
866
|
# Generate wrapped client from registered client info
|
|
661
867
|
client_info = self._registry.get_llm_client(config_type=type(llm_info.config), wrapper_type=wrapper_type)
|
|
662
868
|
|
|
663
|
-
|
|
869
|
+
with ChildBuilder.use(llm_info.config, self) as inner_builder:
|
|
870
|
+
client = await self._get_exit_stack().enter_async_context(
|
|
871
|
+
client_info.build_fn(llm_info.config, inner_builder))
|
|
664
872
|
|
|
665
873
|
# Return a frameworks specific client
|
|
666
874
|
return client
|
|
@@ -710,7 +918,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
710
918
|
try:
|
|
711
919
|
authentication_info = self._registry.get_auth_provider(type(config))
|
|
712
920
|
|
|
713
|
-
|
|
921
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
922
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
923
|
+
authentication_info.build_fn(config, inner_builder))
|
|
714
924
|
|
|
715
925
|
self._auth_providers[name] = ConfiguredAuthProvider(config=config, instance=info_obj)
|
|
716
926
|
|
|
@@ -756,7 +966,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
756
966
|
try:
|
|
757
967
|
embedder_info = self._registry.get_embedder_provider(type(config))
|
|
758
968
|
|
|
759
|
-
|
|
969
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
970
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
971
|
+
embedder_info.build_fn(config, inner_builder))
|
|
760
972
|
|
|
761
973
|
self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
|
|
762
974
|
except Exception as e:
|
|
@@ -776,7 +988,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
776
988
|
# Generate wrapped client from registered client info
|
|
777
989
|
client_info = self._registry.get_embedder_client(config_type=type(embedder_info.config),
|
|
778
990
|
wrapper_type=wrapper_type)
|
|
779
|
-
|
|
991
|
+
|
|
992
|
+
with ChildBuilder.use(embedder_info.config, self) as inner_builder:
|
|
993
|
+
client = await self._get_exit_stack().enter_async_context(
|
|
994
|
+
client_info.build_fn(embedder_info.config, inner_builder))
|
|
780
995
|
|
|
781
996
|
# Return a frameworks specific client
|
|
782
997
|
return client
|
|
@@ -801,7 +1016,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
801
1016
|
|
|
802
1017
|
memory_info = self._registry.get_memory(type(config))
|
|
803
1018
|
|
|
804
|
-
|
|
1019
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1020
|
+
info_obj = await self._get_exit_stack().enter_async_context(memory_info.build_fn(config, inner_builder))
|
|
805
1021
|
|
|
806
1022
|
self._memory_clients[name] = ConfiguredMemory(config=config, instance=info_obj)
|
|
807
1023
|
|
|
@@ -833,7 +1049,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
833
1049
|
|
|
834
1050
|
object_store_info = self._registry.get_object_store(type(config))
|
|
835
1051
|
|
|
836
|
-
|
|
1052
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1053
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1054
|
+
object_store_info.build_fn(config, inner_builder))
|
|
837
1055
|
|
|
838
1056
|
self._object_stores[name] = ConfiguredObjectStore(config=config, instance=info_obj)
|
|
839
1057
|
|
|
@@ -862,7 +1080,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
862
1080
|
try:
|
|
863
1081
|
retriever_info = self._registry.get_retriever_provider(type(config))
|
|
864
1082
|
|
|
865
|
-
|
|
1083
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1084
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1085
|
+
retriever_info.build_fn(config, inner_builder))
|
|
866
1086
|
|
|
867
1087
|
self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
|
|
868
1088
|
|
|
@@ -886,7 +1106,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
886
1106
|
client_info = self._registry.get_retriever_client(config_type=type(retriever_info.config),
|
|
887
1107
|
wrapper_type=wrapper_type)
|
|
888
1108
|
|
|
889
|
-
|
|
1109
|
+
with ChildBuilder.use(retriever_info.config, self) as inner_builder:
|
|
1110
|
+
client = await self._get_exit_stack().enter_async_context(
|
|
1111
|
+
client_info.build_fn(retriever_info.config, inner_builder))
|
|
890
1112
|
|
|
891
1113
|
# Return a frameworks specific client
|
|
892
1114
|
return client
|
|
@@ -902,6 +1124,120 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
902
1124
|
|
|
903
1125
|
return self._retrievers[retriever_name].config
|
|
904
1126
|
|
|
1127
|
+
@override
|
|
1128
|
+
@experimental(feature_name="Finetuning")
|
|
1129
|
+
async def add_trainer(self, name: str | TrainerRef, config: TrainerConfig) -> Trainer:
|
|
1130
|
+
if (name in self._trainers):
|
|
1131
|
+
raise ValueError(f"Trainer '{name}' already exists in the list of trainers")
|
|
1132
|
+
|
|
1133
|
+
try:
|
|
1134
|
+
trainer_info = self._registry.get_trainer(type(config))
|
|
1135
|
+
|
|
1136
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1137
|
+
info_obj = await self._get_exit_stack().enter_async_context(trainer_info.build_fn(
|
|
1138
|
+
config, inner_builder))
|
|
1139
|
+
|
|
1140
|
+
self._trainers[name] = ConfiguredTrainer(config=config, instance=info_obj)
|
|
1141
|
+
|
|
1142
|
+
return info_obj
|
|
1143
|
+
|
|
1144
|
+
except Exception as e:
|
|
1145
|
+
logger.error("Error adding trainer `%s` with config `%s`: %s", name, config, e)
|
|
1146
|
+
raise
|
|
1147
|
+
|
|
1148
|
+
@override
|
|
1149
|
+
@experimental(feature_name="Finetuning")
|
|
1150
|
+
async def add_trainer_adapter(self, name: str | TrainerAdapterRef, config: TrainerAdapterConfig) -> TrainerAdapter:
|
|
1151
|
+
if (name in self._trainer_adapters):
|
|
1152
|
+
raise ValueError(f"Trainer adapter '{name}' already exists in the list of trainer adapters")
|
|
1153
|
+
|
|
1154
|
+
try:
|
|
1155
|
+
trainer_adapter_info = self._registry.get_trainer_adapter(type(config))
|
|
1156
|
+
|
|
1157
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1158
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1159
|
+
trainer_adapter_info.build_fn(config, inner_builder))
|
|
1160
|
+
|
|
1161
|
+
self._trainer_adapters[name] = ConfiguredTrainerAdapter(config=config, instance=info_obj)
|
|
1162
|
+
|
|
1163
|
+
return info_obj
|
|
1164
|
+
|
|
1165
|
+
except Exception as e:
|
|
1166
|
+
logger.error("Error adding trainer adapter `%s` with config `%s`: %s", name, config, e)
|
|
1167
|
+
raise
|
|
1168
|
+
|
|
1169
|
+
@override
|
|
1170
|
+
@experimental(feature_name="Finetuning")
|
|
1171
|
+
async def add_trajectory_builder(self, name: str | TrajectoryBuilderRef,
|
|
1172
|
+
config: TrajectoryBuilderConfig) -> TrajectoryBuilder:
|
|
1173
|
+
if (name in self._trajectory_builders):
|
|
1174
|
+
raise ValueError(f"Trajectory builder '{name}' already exists in the list of trajectory builders")
|
|
1175
|
+
|
|
1176
|
+
try:
|
|
1177
|
+
trajectory_builder_info = self._registry.get_trajectory_builder(type(config))
|
|
1178
|
+
|
|
1179
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1180
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1181
|
+
trajectory_builder_info.build_fn(config, inner_builder))
|
|
1182
|
+
|
|
1183
|
+
self._trajectory_builders[name] = ConfiguredTrajectoryBuilder(config=config, instance=info_obj)
|
|
1184
|
+
|
|
1185
|
+
return info_obj
|
|
1186
|
+
|
|
1187
|
+
except Exception as e:
|
|
1188
|
+
logger.error("Error adding trajectory builder `%s` with config `%s`: %s", name, config, e)
|
|
1189
|
+
raise
|
|
1190
|
+
|
|
1191
|
+
@override
|
|
1192
|
+
async def get_trainer(self,
|
|
1193
|
+
trainer_name: str | TrainerRef,
|
|
1194
|
+
trajectory_builder: TrajectoryBuilder,
|
|
1195
|
+
trainer_adapter: TrainerAdapter) -> Trainer:
|
|
1196
|
+
|
|
1197
|
+
if trainer_name not in self._trainers:
|
|
1198
|
+
raise ValueError(f"Trainer '{trainer_name}' not found")
|
|
1199
|
+
|
|
1200
|
+
trainer_instance = self._trainers[trainer_name].instance
|
|
1201
|
+
await trainer_instance.bind_components(trainer_adapter=trainer_adapter, trajectory_builder=trajectory_builder)
|
|
1202
|
+
|
|
1203
|
+
return trainer_instance
|
|
1204
|
+
|
|
1205
|
+
@override
|
|
1206
|
+
async def get_trainer_config(self, trainer_name: str | TrainerRef) -> TrainerConfig:
|
|
1207
|
+
if trainer_name not in self._trainers:
|
|
1208
|
+
raise ValueError(f"Trainer '{trainer_name}' not found")
|
|
1209
|
+
|
|
1210
|
+
return self._trainers[trainer_name].config
|
|
1211
|
+
|
|
1212
|
+
@override
|
|
1213
|
+
async def get_trainer_adapter_config(self, trainer_adapter_name: str | TrainerAdapterRef) -> TrainerAdapterConfig:
|
|
1214
|
+
if trainer_adapter_name not in self._trainer_adapters:
|
|
1215
|
+
raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not found")
|
|
1216
|
+
|
|
1217
|
+
return self._trainer_adapters[trainer_adapter_name].config
|
|
1218
|
+
|
|
1219
|
+
@override
|
|
1220
|
+
async def get_trajectory_builder_config(
|
|
1221
|
+
self, trajectory_builder_name: str | TrajectoryBuilderRef) -> (TrajectoryBuilderConfig):
|
|
1222
|
+
if trajectory_builder_name not in self._trajectory_builders:
|
|
1223
|
+
raise ValueError(f"Trajectory builder '{trajectory_builder_name}' not found")
|
|
1224
|
+
|
|
1225
|
+
return self._trajectory_builders[trajectory_builder_name].config
|
|
1226
|
+
|
|
1227
|
+
@override
|
|
1228
|
+
async def get_trainer_adapter(self, trainer_adapter_name: str | TrainerAdapterRef) -> TrainerAdapter:
|
|
1229
|
+
if trainer_adapter_name not in self._trainer_adapters:
|
|
1230
|
+
raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not found")
|
|
1231
|
+
|
|
1232
|
+
return self._trainer_adapters[trainer_adapter_name].instance
|
|
1233
|
+
|
|
1234
|
+
@override
|
|
1235
|
+
async def get_trajectory_builder(self, trajectory_builder_name: str | TrajectoryBuilderRef) -> TrajectoryBuilder:
|
|
1236
|
+
if trajectory_builder_name not in self._trajectory_builders:
|
|
1237
|
+
raise ValueError(f"Trajectory builder '{trajectory_builder_name}' not found")
|
|
1238
|
+
|
|
1239
|
+
return self._trajectory_builders[trajectory_builder_name].instance
|
|
1240
|
+
|
|
905
1241
|
@override
|
|
906
1242
|
@experimental(feature_name="TTC")
|
|
907
1243
|
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
|
|
@@ -911,7 +1247,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
911
1247
|
try:
|
|
912
1248
|
ttc_strategy_info = self._registry.get_ttc_strategy(type(config))
|
|
913
1249
|
|
|
914
|
-
|
|
1250
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1251
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1252
|
+
ttc_strategy_info.build_fn(config, inner_builder))
|
|
915
1253
|
|
|
916
1254
|
self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
|
|
917
1255
|
|
|
@@ -969,98 +1307,96 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
969
1307
|
return config
|
|
970
1308
|
|
|
971
1309
|
@override
|
|
972
|
-
def
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
async def add_telemetry_exporter(self, name: str, config: TelemetryExporterBaseConfig) -> None:
|
|
976
|
-
"""Add an configured telemetry exporter to the builder.
|
|
1310
|
+
async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
|
|
1311
|
+
"""Add middleware to the builder.
|
|
977
1312
|
|
|
978
1313
|
Args:
|
|
979
|
-
name
|
|
980
|
-
config
|
|
1314
|
+
name: The name or reference for the middleware
|
|
1315
|
+
config: The configuration for the middleware
|
|
1316
|
+
|
|
1317
|
+
Returns:
|
|
1318
|
+
The built middleware instance
|
|
1319
|
+
|
|
1320
|
+
Raises:
|
|
1321
|
+
ValueError: If the middleware already exists
|
|
981
1322
|
"""
|
|
982
|
-
if
|
|
983
|
-
raise ValueError(f"
|
|
1323
|
+
if name in self._middleware:
|
|
1324
|
+
raise ValueError(f"Middleware `{name}` already exists in the list of middleware")
|
|
984
1325
|
|
|
985
|
-
|
|
1326
|
+
try:
|
|
1327
|
+
middleware_info = self._registry.get_middleware(type(config))
|
|
986
1328
|
|
|
987
|
-
|
|
988
|
-
|
|
1329
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1330
|
+
middleware_instance = await self._get_exit_stack().enter_async_context(
|
|
1331
|
+
middleware_info.build_fn(config, inner_builder))
|
|
989
1332
|
|
|
990
|
-
|
|
991
|
-
exporter = await self._get_exit_stack().enter_async_context(exporter_context_manager)
|
|
992
|
-
self._telemetry_exporters[name] = ConfiguredTelemetryExporter(config=config, instance=exporter)
|
|
1333
|
+
self._middleware[name] = ConfiguredMiddleware(config=config, instance=middleware_instance)
|
|
993
1334
|
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1335
|
+
return middleware_instance
|
|
1336
|
+
except Exception as e:
|
|
1337
|
+
logger.error("Error adding function middleware `%s` with config `%s`: %s", name, config, e)
|
|
1338
|
+
raise
|
|
1339
|
+
|
|
1340
|
+
@override
|
|
1341
|
+
async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
|
|
1342
|
+
"""Get built middleware by name.
|
|
1002
1343
|
|
|
1003
1344
|
Args:
|
|
1004
|
-
|
|
1005
|
-
component_type (str): The type of the component that failed to build
|
|
1006
|
-
completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
|
|
1007
|
-
remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
|
|
1008
|
-
original_error (Exception): The original exception that caused the failure
|
|
1009
|
-
"""
|
|
1010
|
-
logger.error("Failed to initialize component %s (%s)", component_name, component_type)
|
|
1345
|
+
middleware_name: The name or reference of the middleware
|
|
1011
1346
|
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
for name, comp_type in completed_components:
|
|
1015
|
-
logger.error("- %s (%s)", name, comp_type)
|
|
1016
|
-
else:
|
|
1017
|
-
logger.error("No components were successfully built before this failure")
|
|
1347
|
+
Returns:
|
|
1348
|
+
The built middleware instance
|
|
1018
1349
|
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
logger.error("No remaining components to build")
|
|
1350
|
+
Raises:
|
|
1351
|
+
ValueError: If the middleware is not found
|
|
1352
|
+
"""
|
|
1353
|
+
if middleware_name not in self._middleware:
|
|
1354
|
+
raise ValueError(f"Middleware `{middleware_name}` not found")
|
|
1025
1355
|
|
|
1026
|
-
|
|
1356
|
+
return self._middleware[middleware_name].instance
|
|
1027
1357
|
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
remaining_components: list[tuple[str, str]],
|
|
1032
|
-
original_error: Exception) -> None:
|
|
1033
|
-
"""
|
|
1034
|
-
Log comprehensive component build failure information.
|
|
1358
|
+
@override
|
|
1359
|
+
def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
|
|
1360
|
+
"""Get the configuration for middleware.
|
|
1035
1361
|
|
|
1036
1362
|
Args:
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
self._log_build_failure(component_name,
|
|
1046
|
-
component_type,
|
|
1047
|
-
completed_components,
|
|
1048
|
-
remaining_components,
|
|
1049
|
-
original_error)
|
|
1050
|
-
|
|
1051
|
-
def _log_build_failure_workflow(self,
|
|
1052
|
-
completed_components: list[tuple[str, str]],
|
|
1053
|
-
remaining_components: list[tuple[str, str]],
|
|
1054
|
-
original_error: Exception) -> None:
|
|
1363
|
+
middleware_name: The name or reference of the middleware
|
|
1364
|
+
|
|
1365
|
+
Returns:
|
|
1366
|
+
The configuration for the middleware
|
|
1367
|
+
|
|
1368
|
+
Raises:
|
|
1369
|
+
ValueError: If the middleware is not found
|
|
1055
1370
|
"""
|
|
1056
|
-
|
|
1371
|
+
if middleware_name not in self._middleware:
|
|
1372
|
+
raise ValueError(f"Middleware `{middleware_name}` not found")
|
|
1373
|
+
|
|
1374
|
+
return self._middleware[middleware_name].config
|
|
1375
|
+
|
|
1376
|
+
@override
|
|
1377
|
+
def get_user_manager(self):
|
|
1378
|
+
return UserManagerHolder(context=Context(self._context_state))
|
|
1379
|
+
|
|
1380
|
+
async def add_telemetry_exporter(self, name: str, config: TelemetryExporterBaseConfig) -> None:
|
|
1381
|
+
"""Add an configured telemetry exporter to the builder.
|
|
1057
1382
|
|
|
1058
1383
|
Args:
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
original_error (Exception): The original exception that caused the failure
|
|
1384
|
+
name (str): The name of the telemetry exporter
|
|
1385
|
+
config (TelemetryExporterBaseConfig): The configuration for the exporter
|
|
1062
1386
|
"""
|
|
1063
|
-
|
|
1387
|
+
if (name in self._telemetry_exporters):
|
|
1388
|
+
raise ValueError(f"Telemetry exporter '{name}' already exists in the list of telemetry exporters")
|
|
1389
|
+
|
|
1390
|
+
exporter_info = self._registry.get_telemetry_exporter(type(config))
|
|
1391
|
+
|
|
1392
|
+
# Build the exporter outside the lock (parallel)
|
|
1393
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1394
|
+
exporter_context_manager = exporter_info.build_fn(config, inner_builder)
|
|
1395
|
+
|
|
1396
|
+
# Only protect the shared state modifications (serialized)
|
|
1397
|
+
exporter = await self._get_exit_stack().enter_async_context(exporter_context_manager)
|
|
1398
|
+
|
|
1399
|
+
self._telemetry_exporters[name] = ConfiguredTelemetryExporter(config=config, instance=exporter)
|
|
1064
1400
|
|
|
1065
1401
|
async def populate_builder(self, config: Config, skip_workflow: bool = False):
|
|
1066
1402
|
"""
|
|
@@ -1074,21 +1410,14 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1074
1410
|
# Generate the build sequence
|
|
1075
1411
|
build_sequence = build_dependency_sequence(config)
|
|
1076
1412
|
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
remaining_components = [(str(comp.name), comp.component_group.value) for comp in build_sequence
|
|
1080
|
-
if not comp.is_root]
|
|
1413
|
+
self.remaining_components = [(str(comp.name), comp.component_group.value) for comp in build_sequence
|
|
1414
|
+
if not comp.is_root]
|
|
1081
1415
|
if not skip_workflow:
|
|
1082
|
-
remaining_components.append((
|
|
1416
|
+
self.remaining_components.append((WORKFLOW_COMPONENT_NAME, "workflow"))
|
|
1083
1417
|
|
|
1084
|
-
# Loop over all
|
|
1418
|
+
# Loop over all components and add to the workflow builder
|
|
1085
1419
|
for component_instance in build_sequence:
|
|
1086
1420
|
try:
|
|
1087
|
-
# Remove from remaining as we start building (if not root)
|
|
1088
|
-
if not component_instance.is_root:
|
|
1089
|
-
remaining_components.remove(
|
|
1090
|
-
(str(component_instance.name), component_instance.component_group.value))
|
|
1091
|
-
|
|
1092
1421
|
# Instantiate a the llm
|
|
1093
1422
|
if component_instance.component_group == ComponentGroup.LLMS:
|
|
1094
1423
|
await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
|
|
@@ -1108,16 +1437,29 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1108
1437
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
1109
1438
|
await self.add_retriever(component_instance.name,
|
|
1110
1439
|
cast(RetrieverBaseConfig, component_instance.config))
|
|
1440
|
+
# Instantiate middleware
|
|
1441
|
+
elif component_instance.component_group == ComponentGroup.MIDDLEWARE:
|
|
1442
|
+
await self.add_middleware(component_instance.name,
|
|
1443
|
+
cast(MiddlewareBaseConfig, component_instance.config))
|
|
1111
1444
|
# Instantiate a function group
|
|
1112
1445
|
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1446
|
+
config_obj = cast(FunctionGroupBaseConfig, component_instance.config)
|
|
1447
|
+
registration = self._registry.get_function_group(type(config_obj))
|
|
1448
|
+
if registration.is_per_user:
|
|
1449
|
+
# Skip per-user function groups as they will be built lazily by PerUserWorkflowBuilder
|
|
1450
|
+
continue
|
|
1113
1451
|
await self.add_function_group(component_instance.name,
|
|
1114
1452
|
cast(FunctionGroupBaseConfig, component_instance.config))
|
|
1115
1453
|
# Instantiate a function
|
|
1116
1454
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1455
|
+
config_obj = cast(FunctionBaseConfig, component_instance.config)
|
|
1456
|
+
registration = self._registry.get_function(type(config_obj))
|
|
1457
|
+
if registration.is_per_user:
|
|
1458
|
+
# Skip per-user functions as they will be built lazily by PerUserWorkflowBuilder
|
|
1459
|
+
continue
|
|
1460
|
+
elif not component_instance.is_root:
|
|
1461
|
+
# If the function is not the root, add it to the workflow builder
|
|
1462
|
+
await self.add_function(component_instance.name, config_obj)
|
|
1121
1463
|
elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
|
|
1122
1464
|
await self.add_ttc_strategy(component_instance.name,
|
|
1123
1465
|
cast(TTCStrategyBaseConfig, component_instance.config))
|
|
@@ -1125,241 +1467,140 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1125
1467
|
elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
|
|
1126
1468
|
await self.add_auth_provider(component_instance.name,
|
|
1127
1469
|
cast(AuthProviderBaseConfig, component_instance.config))
|
|
1470
|
+
|
|
1471
|
+
elif component_instance.component_group == ComponentGroup.TRAINERS:
|
|
1472
|
+
await self.add_trainer(component_instance.name, cast(TrainerConfig, component_instance.config))
|
|
1473
|
+
|
|
1474
|
+
elif component_instance.component_group == ComponentGroup.TRAINER_ADAPTERS:
|
|
1475
|
+
await self.add_trainer_adapter(component_instance.name,
|
|
1476
|
+
cast(TrainerAdapterConfig, component_instance.config))
|
|
1477
|
+
|
|
1478
|
+
elif component_instance.component_group == ComponentGroup.TRAJECTORY_BUILDERS:
|
|
1479
|
+
await self.add_trajectory_builder(component_instance.name,
|
|
1480
|
+
cast(TrajectoryBuilderConfig, component_instance.config))
|
|
1128
1481
|
else:
|
|
1129
1482
|
raise ValueError(f"Unknown component group {component_instance.component_group}")
|
|
1130
1483
|
|
|
1131
|
-
#
|
|
1484
|
+
# Remove from remaining and add to completed after successful build (if not root)
|
|
1132
1485
|
if not component_instance.is_root:
|
|
1133
|
-
|
|
1486
|
+
self.remaining_components.remove(
|
|
1487
|
+
(str(component_instance.name), component_instance.component_group.value))
|
|
1488
|
+
self.completed_components.append(
|
|
1134
1489
|
(str(component_instance.name), component_instance.component_group.value))
|
|
1135
1490
|
|
|
1136
1491
|
except Exception as e:
|
|
1137
|
-
|
|
1492
|
+
_log_build_failure(str(component_instance.name),
|
|
1493
|
+
component_instance.component_group.value,
|
|
1494
|
+
self.completed_components,
|
|
1495
|
+
self.remaining_components,
|
|
1496
|
+
e)
|
|
1138
1497
|
raise
|
|
1139
1498
|
|
|
1140
1499
|
# Instantiate the workflow
|
|
1141
1500
|
if not skip_workflow:
|
|
1142
1501
|
try:
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1502
|
+
workflow_registration = self._registry.get_function(type(config.workflow))
|
|
1503
|
+
# If the workflow is shared (not per-user), build it
|
|
1504
|
+
# Otherwise, build it lazily by PerUserWorkflowBuilder
|
|
1505
|
+
if not workflow_registration.is_per_user:
|
|
1506
|
+
# Remove workflow from remaining as we start building
|
|
1507
|
+
self.remaining_components.remove((WORKFLOW_COMPONENT_NAME, "workflow"))
|
|
1508
|
+
await self.set_workflow(config.workflow)
|
|
1509
|
+
self.completed_components.append((WORKFLOW_COMPONENT_NAME, "workflow"))
|
|
1147
1510
|
except Exception as e:
|
|
1148
|
-
|
|
1511
|
+
_log_build_failure(WORKFLOW_COMPONENT_NAME,
|
|
1512
|
+
"workflow",
|
|
1513
|
+
self.completed_components,
|
|
1514
|
+
self.remaining_components,
|
|
1515
|
+
e)
|
|
1149
1516
|
raise
|
|
1150
1517
|
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
async def from_config(cls, config: Config):
|
|
1154
|
-
|
|
1155
|
-
async with cls(general_config=config.general) as builder:
|
|
1156
|
-
await builder.populate_builder(config)
|
|
1157
|
-
yield builder
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
class ChildBuilder(Builder):
|
|
1161
|
-
|
|
1162
|
-
def __init__(self, workflow_builder: WorkflowBuilder) -> None:
|
|
1163
|
-
|
|
1164
|
-
self._workflow_builder = workflow_builder
|
|
1165
|
-
|
|
1166
|
-
self._dependencies = FunctionDependencies()
|
|
1167
|
-
|
|
1168
|
-
@property
|
|
1169
|
-
def dependencies(self) -> FunctionDependencies:
|
|
1170
|
-
return self._dependencies
|
|
1171
|
-
|
|
1172
|
-
@override
|
|
1173
|
-
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
|
1174
|
-
return await self._workflow_builder.add_function(name, config)
|
|
1175
|
-
|
|
1176
|
-
@override
|
|
1177
|
-
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1178
|
-
return await self._workflow_builder.add_function_group(name, config)
|
|
1179
|
-
|
|
1180
|
-
@override
|
|
1181
|
-
async def get_function(self, name: str) -> Function:
|
|
1182
|
-
# If a function tries to get another function, we assume it uses it
|
|
1183
|
-
fn = await self._workflow_builder.get_function(name)
|
|
1184
|
-
|
|
1185
|
-
self._dependencies.add_function(name)
|
|
1186
|
-
|
|
1187
|
-
return fn
|
|
1188
|
-
|
|
1189
|
-
@override
|
|
1190
|
-
async def get_function_group(self, name: str) -> FunctionGroup:
|
|
1191
|
-
# If a function tries to get a function group, we assume it uses it
|
|
1192
|
-
function_group = await self._workflow_builder.get_function_group(name)
|
|
1193
|
-
|
|
1194
|
-
self._dependencies.add_function_group(name)
|
|
1195
|
-
|
|
1196
|
-
return function_group
|
|
1197
|
-
|
|
1198
|
-
@override
|
|
1199
|
-
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
1200
|
-
return self._workflow_builder.get_function_config(name)
|
|
1201
|
-
|
|
1202
|
-
@override
|
|
1203
|
-
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1204
|
-
return self._workflow_builder.get_function_group_config(name)
|
|
1205
|
-
|
|
1206
|
-
@override
|
|
1207
|
-
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
1208
|
-
return await self._workflow_builder.set_workflow(config)
|
|
1209
|
-
|
|
1210
|
-
@override
|
|
1211
|
-
def get_workflow(self) -> Function:
|
|
1212
|
-
return self._workflow_builder.get_workflow()
|
|
1213
|
-
|
|
1214
|
-
@override
|
|
1215
|
-
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
1216
|
-
return self._workflow_builder.get_workflow_config()
|
|
1217
|
-
|
|
1218
|
-
@override
|
|
1219
|
-
async def get_tools(self,
|
|
1220
|
-
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1221
|
-
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1222
|
-
tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1223
|
-
for tool_name in tool_names:
|
|
1224
|
-
if tool_name in self._workflow_builder._function_groups:
|
|
1225
|
-
self._dependencies.add_function_group(tool_name)
|
|
1226
|
-
else:
|
|
1227
|
-
self._dependencies.add_function(tool_name)
|
|
1228
|
-
return tools
|
|
1229
|
-
|
|
1230
|
-
@override
|
|
1231
|
-
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
1232
|
-
# If a function tries to get another function as a tool, we assume it uses it
|
|
1233
|
-
fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
1234
|
-
|
|
1235
|
-
self._dependencies.add_function(fn_name)
|
|
1236
|
-
|
|
1237
|
-
return fn
|
|
1238
|
-
|
|
1239
|
-
@override
|
|
1240
|
-
async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
|
|
1241
|
-
return await self._workflow_builder.add_llm(name, config)
|
|
1242
|
-
|
|
1243
|
-
@experimental(feature_name="Authentication")
|
|
1244
|
-
@override
|
|
1245
|
-
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
1246
|
-
return await self._workflow_builder.add_auth_provider(name, config)
|
|
1247
|
-
|
|
1248
|
-
@override
|
|
1249
|
-
async def get_auth_provider(self, auth_provider_name: str):
|
|
1250
|
-
return await self._workflow_builder.get_auth_provider(auth_provider_name)
|
|
1251
|
-
|
|
1252
|
-
@override
|
|
1253
|
-
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1254
|
-
llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
|
|
1255
|
-
|
|
1256
|
-
self._dependencies.add_llm(llm_name)
|
|
1257
|
-
|
|
1258
|
-
return llm
|
|
1518
|
+
# Check if any shared components have dependencies on per-user components
|
|
1519
|
+
self._validate_dependencies(config)
|
|
1259
1520
|
|
|
1260
|
-
|
|
1261
|
-
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
|
|
1262
|
-
return self._workflow_builder.get_llm_config(llm_name)
|
|
1263
|
-
|
|
1264
|
-
@override
|
|
1265
|
-
async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
|
|
1266
|
-
await self._workflow_builder.add_embedder(name, config)
|
|
1267
|
-
|
|
1268
|
-
@override
|
|
1269
|
-
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1270
|
-
embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
|
|
1271
|
-
|
|
1272
|
-
self._dependencies.add_embedder(embedder_name)
|
|
1273
|
-
|
|
1274
|
-
return embedder
|
|
1275
|
-
|
|
1276
|
-
@override
|
|
1277
|
-
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
|
|
1278
|
-
return self._workflow_builder.get_embedder_config(embedder_name)
|
|
1279
|
-
|
|
1280
|
-
@override
|
|
1281
|
-
async def add_memory_client(self, name: str, config: MemoryBaseConfig) -> MemoryEditor:
|
|
1282
|
-
return await self._workflow_builder.add_memory_client(name, config)
|
|
1283
|
-
|
|
1284
|
-
@override
|
|
1285
|
-
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1286
|
-
"""
|
|
1287
|
-
Return the instantiated memory client for the given name.
|
|
1521
|
+
def _validate_dependencies(self, config: Config):
|
|
1288
1522
|
"""
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
self._dependencies.add_memory_client(memory_name)
|
|
1292
|
-
|
|
1293
|
-
return memory_client
|
|
1294
|
-
|
|
1295
|
-
@override
|
|
1296
|
-
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
|
|
1297
|
-
return self._workflow_builder.get_memory_client_config(memory_name=memory_name)
|
|
1523
|
+
Validate no shared component has dependencies on any per-user components.
|
|
1298
1524
|
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
return await self._workflow_builder.add_object_store(name, config)
|
|
1302
|
-
|
|
1303
|
-
@override
|
|
1304
|
-
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
|
|
1525
|
+
This prevents invalid configurations where shared components try to use per-user functions that do not exist
|
|
1526
|
+
at shared builder initialization time.
|
|
1305
1527
|
"""
|
|
1306
|
-
Return the instantiated object store client for the given name.
|
|
1307
|
-
"""
|
|
1308
|
-
object_store_client = await self._workflow_builder.get_object_store_client(object_store_name)
|
|
1309
|
-
|
|
1310
|
-
self._dependencies.add_object_store(object_store_name)
|
|
1311
1528
|
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1529
|
+
# Check shared functions do not depend on per-user functions or function_groups
|
|
1530
|
+
for fn_name, fn_deps in self.function_dependencies.items():
|
|
1531
|
+
if fn_name == WORKFLOW_COMPONENT_NAME:
|
|
1532
|
+
continue
|
|
1533
|
+
|
|
1534
|
+
fn_config = self.get_function_config(fn_name)
|
|
1535
|
+
fn_registration = self._registry.get_function(type(fn_config))
|
|
1536
|
+
|
|
1537
|
+
if not fn_registration.is_per_user:
|
|
1538
|
+
for dep_fn_name in fn_deps.functions:
|
|
1539
|
+
dep_config = config.functions.get(dep_fn_name)
|
|
1540
|
+
if dep_config is not None:
|
|
1541
|
+
dep_registration = self._registry.get_function(type(dep_config))
|
|
1542
|
+
if dep_registration.is_per_user:
|
|
1543
|
+
raise ValueError(f"Function `{fn_name}` depends on per-user function `{dep_fn_name}`")
|
|
1544
|
+
|
|
1545
|
+
for dep_fg_name in fn_deps.function_groups:
|
|
1546
|
+
dep_config = config.function_groups.get(dep_fg_name)
|
|
1547
|
+
if dep_config is not None:
|
|
1548
|
+
dep_registration = self._registry.get_function_group(type(dep_config))
|
|
1549
|
+
if dep_registration.is_per_user:
|
|
1550
|
+
raise ValueError(f"Function `{fn_name}` depends on per-user function_group `{dep_fg_name}`")
|
|
1551
|
+
|
|
1552
|
+
# Check shared function_groups do not depend on per-user functions or function_groups
|
|
1553
|
+
for fg_name, fg_deps in self.function_group_dependencies.items():
|
|
1554
|
+
fg_config = self.get_function_group_config(fg_name)
|
|
1555
|
+
fg_registration = self._registry.get_function_group(type(fg_config))
|
|
1556
|
+
|
|
1557
|
+
if not fg_registration.is_per_user:
|
|
1558
|
+
for dep_fn_name in fg_deps.functions:
|
|
1559
|
+
dep_config = config.functions.get(dep_fn_name)
|
|
1560
|
+
if dep_config is not None:
|
|
1561
|
+
dep_registration = self._registry.get_function(type(dep_config))
|
|
1562
|
+
if dep_registration.is_per_user:
|
|
1563
|
+
raise ValueError(f"FunctionGroup `{fg_name}` depends on per-user function `{dep_fn_name}`")
|
|
1564
|
+
|
|
1565
|
+
for dep_fg_name in fg_deps.function_groups:
|
|
1566
|
+
dep_config = config.function_groups.get(dep_fg_name)
|
|
1567
|
+
if dep_config is not None:
|
|
1568
|
+
dep_registration = self._registry.get_function_group(type(dep_config))
|
|
1569
|
+
if dep_registration.is_per_user:
|
|
1570
|
+
raise ValueError(
|
|
1571
|
+
f"FunctionGroup `{fg_name}` depends on per-user function_group `{dep_fg_name}`")
|
|
1322
1572
|
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
pipeline_type: PipelineTypeEnum,
|
|
1327
|
-
stage_type: StageTypeEnum) -> StrategyBase:
|
|
1328
|
-
return await self._workflow_builder.get_ttc_strategy(strategy_name=strategy_name,
|
|
1329
|
-
pipeline_type=pipeline_type,
|
|
1330
|
-
stage_type=stage_type)
|
|
1331
|
-
|
|
1332
|
-
@override
|
|
1333
|
-
async def get_ttc_strategy_config(self,
|
|
1334
|
-
strategy_name: str | TTCStrategyRef,
|
|
1335
|
-
pipeline_type: PipelineTypeEnum,
|
|
1336
|
-
stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
|
|
1337
|
-
return await self._workflow_builder.get_ttc_strategy_config(strategy_name=strategy_name,
|
|
1338
|
-
pipeline_type=pipeline_type,
|
|
1339
|
-
stage_type=stage_type)
|
|
1340
|
-
|
|
1341
|
-
@override
|
|
1342
|
-
async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
|
|
1343
|
-
await self._workflow_builder.add_retriever(name, config)
|
|
1344
|
-
|
|
1345
|
-
@override
|
|
1346
|
-
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
|
|
1347
|
-
if not wrapper_type:
|
|
1348
|
-
return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
|
|
1349
|
-
return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
|
|
1573
|
+
if self._workflow is not None:
|
|
1574
|
+
workflow_config = self.get_workflow_config()
|
|
1575
|
+
workflow_registration = self._registry.get_function(type(workflow_config))
|
|
1350
1576
|
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1577
|
+
# Per-user workflow must be owned by PerUserWorkflowBuilder
|
|
1578
|
+
if workflow_registration.is_per_user:
|
|
1579
|
+
raise ValueError("Workflow is a per-user function, but it is owned by a shared WorkflowBuilder")
|
|
1354
1580
|
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1581
|
+
else:
|
|
1582
|
+
workflow_deps = self.function_dependencies.get(WORKFLOW_COMPONENT_NAME, FunctionDependencies())
|
|
1583
|
+
|
|
1584
|
+
for dep_fn_name in workflow_deps.functions:
|
|
1585
|
+
if dep_fn_name in config.functions:
|
|
1586
|
+
dep_config = config.functions[dep_fn_name]
|
|
1587
|
+
if dep_config is not None:
|
|
1588
|
+
dep_registration = self._registry.get_function(type(dep_config))
|
|
1589
|
+
if dep_registration.is_per_user:
|
|
1590
|
+
raise ValueError(f"Shared Workflow depends on per-user function `{dep_fn_name}`")
|
|
1591
|
+
|
|
1592
|
+
for dep_fg_name in workflow_deps.function_groups:
|
|
1593
|
+
if dep_fg_name in config.function_groups:
|
|
1594
|
+
dep_config = config.function_groups[dep_fg_name]
|
|
1595
|
+
if dep_config is not None:
|
|
1596
|
+
dep_registration = self._registry.get_function_group(type(dep_config))
|
|
1597
|
+
if dep_registration.is_per_user:
|
|
1598
|
+
raise ValueError(f"Shared Workflow depends on per-user function_group `{dep_fg_name}`")
|
|
1358
1599
|
|
|
1359
|
-
@
|
|
1360
|
-
|
|
1361
|
-
|
|
1600
|
+
@classmethod
|
|
1601
|
+
@asynccontextmanager
|
|
1602
|
+
async def from_config(cls, config: Config):
|
|
1362
1603
|
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1604
|
+
async with cls(general_config=config.general) as builder:
|
|
1605
|
+
await builder.populate_builder(config)
|
|
1606
|
+
yield builder
|