nvidia-nat 1.4.0a20251120__py3-none-any.whl → 1.4.0a20260113__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +1 -1
- nat/{front_ends/mcp → agent/auto_memory_wrapper}/__init__.py +1 -1
- nat/agent/auto_memory_wrapper/agent.py +278 -0
- nat/agent/auto_memory_wrapper/register.py +227 -0
- nat/agent/auto_memory_wrapper/state.py +30 -0
- nat/agent/base.py +1 -1
- nat/agent/dual_node.py +1 -1
- nat/agent/prompt_optimizer/prompt.py +1 -1
- nat/agent/prompt_optimizer/register.py +1 -1
- nat/agent/react_agent/agent.py +16 -9
- nat/agent/react_agent/output_parser.py +2 -2
- nat/agent/react_agent/prompt.py +3 -2
- nat/agent/react_agent/register.py +2 -2
- nat/agent/react_agent/register_per_user_agent.py +104 -0
- nat/agent/reasoning_agent/reasoning_agent.py +1 -1
- nat/agent/register.py +3 -1
- nat/agent/responses_api_agent/__init__.py +1 -1
- nat/agent/responses_api_agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +9 -4
- nat/agent/rewoo_agent/prompt.py +1 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +5 -4
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/__init__.py +1 -1
- nat/authentication/api_key/__init__.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/api_key/api_key_auth_provider_config.py +22 -7
- nat/authentication/api_key/register.py +1 -1
- nat/authentication/credential_validator/__init__.py +1 -1
- nat/authentication/credential_validator/bearer_token_validator.py +1 -1
- nat/authentication/exceptions/__init__.py +1 -1
- nat/authentication/exceptions/api_key_exceptions.py +1 -1
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/http_basic_auth/register.py +1 -1
- nat/authentication/interfaces.py +1 -1
- nat/authentication/oauth2/__init__.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +1 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +1 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +1 -1
- nat/authentication/oauth2/register.py +1 -1
- nat/authentication/register.py +1 -1
- nat/builder/builder.py +511 -1
- nat/builder/child_builder.py +385 -0
- nat/builder/component_utils.py +28 -4
- nat/builder/context.py +17 -1
- nat/builder/embedder.py +1 -1
- nat/builder/eval_builder.py +19 -7
- nat/builder/evaluator.py +1 -1
- nat/builder/framework_enum.py +2 -1
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +40 -3
- nat/builder/function_base.py +1 -1
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +1 -1
- nat/builder/llm.py +1 -1
- nat/builder/per_user_workflow_builder.py +843 -0
- nat/builder/retriever.py +1 -1
- nat/builder/sync_builder.py +571 -0
- nat/builder/user_interaction_manager.py +1 -1
- nat/builder/workflow.py +1 -1
- nat/builder/workflow_builder.py +536 -424
- nat/cli/__init__.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/cli_utils/validation.py +32 -1
- nat/cli/commands/configure/channel/add.py +1 -1
- nat/cli/commands/configure/channel/channel.py +1 -1
- nat/cli/commands/configure/channel/remove.py +1 -1
- nat/cli/commands/configure/channel/update.py +1 -1
- nat/cli/commands/configure/configure.py +1 -1
- nat/cli/commands/evaluate.py +87 -13
- nat/cli/commands/finetune.py +132 -0
- nat/cli/commands/info/__init__.py +1 -1
- nat/cli/commands/info/info.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +1 -1
- nat/cli/commands/object_store/__init__.py +1 -1
- nat/cli/commands/object_store/object_store.py +1 -1
- nat/cli/commands/optimize.py +1 -1
- nat/cli/commands/{mcp → red_teaming}/__init__.py +1 -1
- nat/cli/commands/red_teaming/red_teaming.py +138 -0
- nat/cli/commands/red_teaming/red_teaming_utils.py +73 -0
- nat/cli/commands/registry/__init__.py +1 -1
- nat/cli/commands/registry/publish.py +1 -1
- nat/cli/commands/registry/pull.py +1 -1
- nat/cli/commands/registry/registry.py +1 -1
- nat/cli/commands/registry/remove.py +1 -1
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/sizing/__init__.py +1 -1
- nat/cli/commands/sizing/calc.py +1 -1
- nat/cli/commands/sizing/sizing.py +1 -1
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/validate.py +1 -1
- nat/cli/commands/workflow/__init__.py +1 -1
- nat/cli/commands/workflow/workflow.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/entrypoint.py +15 -37
- nat/cli/main.py +2 -2
- nat/cli/plugin_loader.py +69 -0
- nat/cli/register_workflow.py +183 -5
- nat/cli/type_registry.py +169 -3
- nat/control_flow/register.py +1 -1
- nat/control_flow/router_agent/agent.py +1 -1
- nat/control_flow/router_agent/prompt.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +28 -7
- nat/data_models/__init__.py +1 -1
- nat/data_models/agent.py +1 -1
- nat/data_models/api_server.py +38 -3
- nat/data_models/authentication.py +1 -1
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +7 -1
- nat/data_models/component_ref.py +34 -1
- nat/data_models/config.py +62 -1
- nat/data_models/dataset_handler.py +15 -2
- nat/data_models/discovery_metadata.py +1 -1
- nat/data_models/embedder.py +1 -1
- nat/data_models/evaluate.py +6 -1
- nat/data_models/evaluator.py +1 -1
- nat/data_models/finetuning.py +260 -0
- nat/data_models/front_end.py +1 -1
- nat/data_models/function.py +1 -1
- nat/data_models/function_dependencies.py +1 -1
- nat/data_models/gated_field_mixin.py +1 -1
- nat/data_models/interactive.py +1 -1
- nat/data_models/intermediate_step.py +29 -2
- nat/data_models/invocation_node.py +1 -1
- nat/data_models/llm.py +1 -1
- nat/data_models/logging.py +1 -1
- nat/data_models/memory.py +1 -1
- nat/data_models/middleware.py +3 -1
- nat/data_models/object_store.py +1 -1
- nat/data_models/openai_mcp.py +1 -1
- nat/data_models/optimizable.py +1 -1
- nat/data_models/optimizer.py +1 -1
- nat/data_models/profiler.py +1 -1
- nat/data_models/registry_handler.py +1 -1
- nat/data_models/retriever.py +1 -1
- nat/data_models/retry_mixin.py +1 -1
- nat/data_models/runtime_enum.py +1 -1
- nat/data_models/span.py +1 -1
- nat/data_models/step_adaptor.py +1 -1
- nat/data_models/streaming.py +1 -1
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/telemetry_exporter.py +1 -1
- nat/data_models/thinking_mixin.py +1 -1
- nat/data_models/ttc_strategy.py +1 -1
- nat/embedder/azure_openai_embedder.py +1 -1
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +1 -1
- nat/eval/__init__.py +1 -1
- nat/eval/config.py +8 -1
- nat/eval/dataset_handler/dataset_downloader.py +1 -1
- nat/eval/dataset_handler/dataset_filter.py +1 -1
- nat/eval/dataset_handler/dataset_handler.py +4 -2
- nat/eval/evaluate.py +217 -80
- nat/eval/evaluator/__init__.py +1 -1
- nat/eval/evaluator/base_evaluator.py +2 -2
- nat/eval/evaluator/evaluator_model.py +3 -2
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/llm_validator.py +336 -0
- nat/eval/rag_evaluator/evaluate.py +17 -10
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/red_teaming_evaluator/__init__.py +14 -0
- nat/eval/red_teaming_evaluator/data_models.py +66 -0
- nat/eval/red_teaming_evaluator/evaluate.py +327 -0
- nat/eval/red_teaming_evaluator/filter_conditions.py +75 -0
- nat/eval/red_teaming_evaluator/register.py +55 -0
- nat/eval/register.py +2 -1
- nat/eval/remote_workflow.py +1 -1
- nat/eval/runners/__init__.py +1 -1
- nat/eval/runners/config.py +1 -1
- nat/eval/runners/multi_eval_runner.py +1 -1
- nat/eval/runners/red_teaming_runner/__init__.py +24 -0
- nat/eval/runners/red_teaming_runner/config.py +282 -0
- nat/eval/runners/red_teaming_runner/report_utils.py +707 -0
- nat/eval/runners/red_teaming_runner/runner.py +867 -0
- nat/eval/runtime_evaluator/__init__.py +1 -1
- nat/eval/runtime_evaluator/evaluate.py +1 -1
- nat/eval/runtime_evaluator/register.py +1 -1
- nat/eval/runtime_event_subscriber.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/swe_bench_evaluator/register.py +1 -1
- nat/eval/trajectory_evaluator/evaluate.py +2 -2
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +5 -5
- nat/eval/tunable_rag_evaluator/register.py +1 -1
- nat/eval/usage_stats.py +1 -1
- nat/eval/utils/eval_trace_ctx.py +1 -1
- nat/eval/utils/output_uploader.py +1 -1
- nat/eval/utils/tqdm_position_registry.py +1 -1
- nat/eval/utils/weave_eval.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -1
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +1 -1
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +1 -1
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +1 -1
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/multi_llm_judge_function.py +88 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/editor_config.py +1 -1
- nat/experimental/test_time_compute/models/scoring_config.py +1 -1
- nat/experimental/test_time_compute/models/search_config.py +20 -2
- nat/experimental/test_time_compute/models/selection_config.py +33 -2
- nat/experimental/test_time_compute/models/stage_enums.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +1 -1
- nat/experimental/test_time_compute/models/tool_use_config.py +1 -1
- nat/experimental/test_time_compute/models/ttc_item.py +1 -1
- nat/experimental/test_time_compute/register.py +4 -1
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +1 -1
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +1 -1
- nat/experimental/test_time_compute/search/multi_llm_generation.py +115 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +1 -1
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +1 -1
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +1 -1
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +1 -1
- nat/experimental/test_time_compute/selection/llm_judge_selection.py +127 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +1 -1
- nat/finetuning/__init__.py +24 -0
- nat/finetuning/finetuning_runtime.py +143 -0
- nat/finetuning/interfaces/__init__.py +24 -0
- nat/finetuning/interfaces/finetuning_runner.py +261 -0
- nat/finetuning/interfaces/trainer_adapter.py +103 -0
- nat/finetuning/interfaces/trajectory_builder.py +115 -0
- nat/finetuning/utils/__init__.py +15 -0
- nat/finetuning/utils/parsers/__init__.py +15 -0
- nat/finetuning/utils/parsers/adk_parser.py +141 -0
- nat/finetuning/utils/parsers/base_parser.py +238 -0
- nat/finetuning/utils/parsers/common.py +91 -0
- nat/finetuning/utils/parsers/langchain_parser.py +267 -0
- nat/finetuning/utils/parsers/llama_index_parser.py +218 -0
- nat/front_ends/__init__.py +1 -1
- nat/front_ends/console/__init__.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +1 -1
- nat/front_ends/console/console_front_end_config.py +4 -1
- nat/front_ends/console/console_front_end_plugin.py +5 -4
- nat/front_ends/console/register.py +1 -1
- nat/front_ends/cron/__init__.py +1 -1
- nat/front_ends/fastapi/__init__.py +1 -1
- nat/front_ends/fastapi/async_job.py +128 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +13 -9
- nat/front_ends/fastapi/dask_client_mixin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_config.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_controller.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +25 -30
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +195 -60
- nat/front_ends/fastapi/html_snippets/__init__.py +1 -1
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +1 -1
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +12 -1
- nat/front_ends/fastapi/job_store.py +23 -11
- nat/front_ends/fastapi/main.py +1 -1
- nat/front_ends/fastapi/message_handler.py +27 -4
- nat/front_ends/fastapi/message_validator.py +54 -2
- nat/front_ends/fastapi/register.py +1 -1
- nat/front_ends/fastapi/response_helpers.py +16 -15
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/fastapi/utils.py +1 -1
- nat/front_ends/register.py +1 -2
- nat/front_ends/simple_base/__init__.py +1 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +6 -4
- nat/llm/aws_bedrock_llm.py +1 -1
- nat/llm/azure_openai_llm.py +10 -1
- nat/llm/dynamo_llm.py +363 -0
- nat/llm/huggingface_llm.py +177 -0
- nat/llm/litellm_llm.py +1 -1
- nat/llm/nim_llm.py +1 -1
- nat/llm/openai_llm.py +1 -1
- nat/llm/register.py +3 -1
- nat/llm/utils/__init__.py +1 -1
- nat/llm/utils/env_config_value.py +1 -1
- nat/llm/utils/error.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/memory/__init__.py +1 -1
- nat/memory/interfaces.py +1 -1
- nat/memory/models.py +1 -1
- nat/meta/pypi.md +1 -1
- nat/middleware/__init__.py +5 -5
- nat/middleware/cache/__init__.py +14 -0
- nat/middleware/{cache_middleware.py → cache/cache_middleware.py} +39 -42
- nat/middleware/cache/cache_middleware_config.py +44 -0
- nat/middleware/cache/register.py +33 -0
- nat/middleware/defense/__init__.py +14 -0
- nat/middleware/defense/defense_middleware.py +362 -0
- nat/middleware/defense/defense_middleware_content_guard.py +455 -0
- nat/middleware/defense/defense_middleware_data_models.py +91 -0
- nat/middleware/defense/defense_middleware_output_verifier.py +440 -0
- nat/middleware/defense/defense_middleware_pii.py +356 -0
- nat/middleware/defense/register.py +82 -0
- nat/middleware/dynamic/__init__.py +14 -0
- nat/middleware/dynamic/dynamic_function_middleware.py +962 -0
- nat/middleware/dynamic/dynamic_middleware_config.py +132 -0
- nat/middleware/dynamic/register.py +34 -0
- nat/middleware/function_middleware.py +236 -52
- nat/middleware/logging/__init__.py +14 -0
- nat/middleware/logging/logging_middleware.py +67 -0
- nat/middleware/logging/logging_middleware_config.py +28 -0
- nat/middleware/logging/register.py +33 -0
- nat/middleware/middleware.py +142 -28
- nat/middleware/red_teaming/__init__.py +14 -0
- nat/middleware/red_teaming/red_teaming_middleware.py +344 -0
- nat/middleware/red_teaming/red_teaming_middleware_config.py +112 -0
- nat/middleware/red_teaming/register.py +47 -0
- nat/middleware/register.py +7 -20
- nat/middleware/utils/__init__.py +14 -0
- nat/middleware/utils/workflow_inventory.py +155 -0
- nat/object_store/__init__.py +1 -1
- nat/object_store/in_memory_object_store.py +1 -1
- nat/object_store/interfaces.py +1 -1
- nat/object_store/models.py +1 -1
- nat/object_store/register.py +1 -1
- nat/observability/__init__.py +1 -1
- nat/observability/exporter/__init__.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +1 -1
- nat/observability/exporter/raw_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +7 -1
- nat/observability/exporter_manager.py +1 -1
- nat/observability/mixin/__init__.py +1 -1
- nat/observability/mixin/batch_config_mixin.py +1 -1
- nat/observability/mixin/collector_config_mixin.py +1 -1
- nat/observability/mixin/file_mixin.py +1 -1
- nat/observability/mixin/file_mode.py +1 -1
- nat/observability/mixin/redaction_config_mixin.py +1 -1
- nat/observability/mixin/resource_conflict_mixin.py +1 -1
- nat/observability/mixin/serialize_mixin.py +1 -1
- nat/observability/mixin/tagging_config_mixin.py +1 -1
- nat/observability/mixin/type_introspection_mixin.py +1 -1
- nat/observability/processor/__init__.py +1 -1
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/callback_processor.py +1 -1
- nat/observability/processor/falsy_batch_filter_processor.py +1 -1
- nat/observability/processor/intermediate_step_serializer.py +1 -1
- nat/observability/processor/processor.py +1 -1
- nat/observability/processor/processor_factory.py +1 -1
- nat/observability/processor/redaction/__init__.py +1 -1
- nat/observability/processor/redaction/contextual_redaction_processor.py +1 -1
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +1 -1
- nat/observability/processor/redaction/redaction_processor.py +1 -1
- nat/observability/processor/redaction/span_header_redaction_processor.py +1 -1
- nat/observability/processor/span_tagging_processor.py +1 -1
- nat/observability/register.py +1 -1
- nat/observability/utils/__init__.py +1 -1
- nat/observability/utils/dict_utils.py +1 -1
- nat/observability/utils/time_utils.py +1 -1
- nat/profiler/calc/__init__.py +1 -1
- nat/profiler/calc/calc_runner.py +3 -3
- nat/profiler/calc/calculations.py +1 -1
- nat/profiler/calc/data_models.py +1 -1
- nat/profiler/calc/plot.py +30 -3
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/base_callback_class.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +33 -3
- nat/profiler/callbacks/llama_index_callback_handler.py +13 -10
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/callbacks/token_usage_base_model.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/data_models.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +16 -1
- nat/profiler/decorators/function_tracking.py +1 -1
- nat/profiler/forecasting/config.py +1 -1
- nat/profiler/forecasting/model_trainer.py +1 -1
- nat/profiler/forecasting/models/__init__.py +1 -1
- nat/profiler/forecasting/models/forecasting_base_model.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_metrics_model.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +1 -1
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/inference_optimization/llm_metrics.py +1 -1
- nat/profiler/inference_optimization/prompt_caching.py +1 -1
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/inference_optimization/workflow_runtimes.py +1 -1
- nat/profiler/intermediate_property_adapter.py +1 -1
- nat/profiler/parameter_optimization/optimizable_utils.py +1 -1
- nat/profiler/parameter_optimization/optimizer_runtime.py +1 -1
- nat/profiler/parameter_optimization/parameter_optimizer.py +1 -1
- nat/profiler/parameter_optimization/parameter_selection.py +1 -1
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +1 -1
- nat/profiler/parameter_optimization/update_helpers.py +1 -1
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +1 -1
- nat/registry_handlers/local/register_local.py +1 -1
- nat/registry_handlers/metadata_factory.py +1 -1
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +1 -1
- nat/registry_handlers/pypi/register_pypi.py +1 -1
- nat/registry_handlers/register.py +1 -1
- nat/registry_handlers/registry_handler_base.py +1 -1
- nat/registry_handlers/rest/register_rest.py +1 -1
- nat/registry_handlers/rest/rest_handler.py +1 -1
- nat/registry_handlers/schemas/headers.py +1 -1
- nat/registry_handlers/schemas/package.py +1 -1
- nat/registry_handlers/schemas/publish.py +1 -1
- nat/registry_handlers/schemas/pull.py +1 -1
- nat/registry_handlers/schemas/remove.py +1 -1
- nat/registry_handlers/schemas/search.py +1 -1
- nat/registry_handlers/schemas/status.py +1 -1
- nat/retriever/interface.py +1 -1
- nat/retriever/milvus/__init__.py +1 -1
- nat/retriever/milvus/register.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/models.py +1 -1
- nat/retriever/nemo_retriever/__init__.py +1 -1
- nat/retriever/nemo_retriever/register.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +5 -5
- nat/retriever/register.py +1 -1
- nat/runtime/__init__.py +1 -1
- nat/runtime/loader.py +10 -3
- nat/runtime/metrics.py +180 -0
- nat/runtime/runner.py +1 -5
- nat/runtime/session.py +451 -32
- nat/runtime/user_metadata.py +1 -1
- nat/settings/global_settings.py +1 -1
- nat/tool/chat_completion.py +1 -1
- nat/tool/code_execution/README.md +1 -1
- nat/tool/code_execution/code_sandbox.py +1 -1
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +1 -1
- nat/tool/code_execution/local_sandbox/__init__.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +1 -1
- nat/tool/code_execution/register.py +1 -1
- nat/tool/code_execution/utils.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +1 -1
- nat/tool/memory_tools/add_memory_tool.py +1 -1
- nat/tool/memory_tools/delete_memory_tool.py +1 -1
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +2 -2
- nat/tool/register.py +1 -1
- nat/tool/retriever.py +1 -1
- nat/tool/server_tools.py +1 -1
- nat/utils/__init__.py +8 -5
- nat/utils/callable_utils.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/debugging_utils.py +1 -1
- nat/utils/decorators.py +1 -1
- nat/utils/dump_distro_mapping.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +3 -3
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/model_processing.py +1 -1
- nat/utils/io/supress_logs.py +33 -0
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/log_levels.py +1 -1
- nat/utils/log_utils.py +13 -1
- nat/utils/metadata_utils.py +1 -1
- nat/utils/optional_imports.py +1 -1
- nat/utils/producer_consumer_queue.py +1 -1
- nat/utils/reactive/base/observable_base.py +1 -1
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/base/subject_base.py +1 -1
- nat/utils/reactive/observable.py +1 -1
- nat/utils/reactive/observer.py +1 -1
- nat/utils/reactive/subject.py +1 -1
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/responses_api.py +1 -1
- nat/utils/settings/global_settings.py +1 -1
- nat/utils/string_utils.py +1 -1
- nat/utils/type_converter.py +18 -5
- nat/utils/type_utils.py +1 -1
- nat/utils/url_utils.py +1 -1
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/METADATA +39 -14
- nvidia_nat-1.4.0a20260113.dist-info/RECORD +547 -0
- nvidia_nat-1.4.0a20260113.dist-info/entry_points.txt +38 -0
- nat/cli/commands/mcp/mcp.py +0 -986
- nat/front_ends/mcp/introspection_token_verifier.py +0 -73
- nat/front_ends/mcp/mcp_front_end_config.py +0 -109
- nat/front_ends/mcp/mcp_front_end_plugin.py +0 -155
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +0 -388
- nat/front_ends/mcp/memory_profiler.py +0 -320
- nat/front_ends/mcp/register.py +0 -27
- nat/front_ends/mcp/tool_converter.py +0 -321
- nvidia_nat-1.4.0a20251120.dist-info/RECORD +0 -488
- nvidia_nat-1.4.0a20251120.dist-info/entry_points.txt +0 -23
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251120.dist-info → nvidia_nat-1.4.0a20260113.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2024-
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -28,7 +28,8 @@ from typing import cast
|
|
|
28
28
|
from nat.authentication.interfaces import AuthProviderBase
|
|
29
29
|
from nat.builder.builder import Builder
|
|
30
30
|
from nat.builder.builder import UserManagerHolder
|
|
31
|
-
from nat.builder.
|
|
31
|
+
from nat.builder.child_builder import ChildBuilder
|
|
32
|
+
from nat.builder.component_utils import WORKFLOW_COMPONENT_NAME
|
|
32
33
|
from nat.builder.component_utils import build_dependency_sequence
|
|
33
34
|
from nat.builder.context import Context
|
|
34
35
|
from nat.builder.context import ContextState
|
|
@@ -40,6 +41,7 @@ from nat.builder.function import LambdaFunction
|
|
|
40
41
|
from nat.builder.function_info import FunctionInfo
|
|
41
42
|
from nat.builder.llm import LLMProviderInfo
|
|
42
43
|
from nat.builder.retriever import RetrieverProviderInfo
|
|
44
|
+
from nat.builder.sync_builder import SyncBuilder
|
|
43
45
|
from nat.builder.workflow import Workflow
|
|
44
46
|
from nat.cli.type_registry import GlobalTypeRegistry
|
|
45
47
|
from nat.cli.type_registry import TypeRegistry
|
|
@@ -54,10 +56,16 @@ from nat.data_models.component_ref import MemoryRef
|
|
|
54
56
|
from nat.data_models.component_ref import MiddlewareRef
|
|
55
57
|
from nat.data_models.component_ref import ObjectStoreRef
|
|
56
58
|
from nat.data_models.component_ref import RetrieverRef
|
|
59
|
+
from nat.data_models.component_ref import TrainerAdapterRef
|
|
60
|
+
from nat.data_models.component_ref import TrainerRef
|
|
61
|
+
from nat.data_models.component_ref import TrajectoryBuilderRef
|
|
57
62
|
from nat.data_models.component_ref import TTCStrategyRef
|
|
58
63
|
from nat.data_models.config import Config
|
|
59
64
|
from nat.data_models.config import GeneralConfig
|
|
60
65
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
66
|
+
from nat.data_models.finetuning import TrainerAdapterConfig
|
|
67
|
+
from nat.data_models.finetuning import TrainerConfig
|
|
68
|
+
from nat.data_models.finetuning import TrajectoryBuilderConfig
|
|
61
69
|
from nat.data_models.function import FunctionBaseConfig
|
|
62
70
|
from nat.data_models.function import FunctionGroupBaseConfig
|
|
63
71
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
@@ -72,6 +80,9 @@ from nat.experimental.decorators.experimental_warning_decorator import experimen
|
|
|
72
80
|
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
73
81
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
74
82
|
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
83
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
84
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
85
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
75
86
|
from nat.memory.interfaces import MemoryEditor
|
|
76
87
|
from nat.middleware.function_middleware import FunctionMiddleware
|
|
77
88
|
from nat.middleware.middleware import Middleware
|
|
@@ -79,7 +90,6 @@ from nat.object_store.interfaces import ObjectStore
|
|
|
79
90
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
80
91
|
from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
81
92
|
from nat.profiler.utils import detect_llm_frameworks_in_build_fn
|
|
82
|
-
from nat.retriever.interface import Retriever
|
|
83
93
|
from nat.utils.type_utils import override
|
|
84
94
|
|
|
85
95
|
logger = logging.getLogger(__name__)
|
|
@@ -151,6 +161,151 @@ class ConfiguredMiddleware:
|
|
|
151
161
|
instance: Middleware
|
|
152
162
|
|
|
153
163
|
|
|
164
|
+
@dataclasses.dataclass
|
|
165
|
+
class ConfiguredTrainer:
|
|
166
|
+
config: TrainerConfig
|
|
167
|
+
instance: Trainer
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclasses.dataclass
|
|
171
|
+
class ConfiguredTrainerAdapter:
|
|
172
|
+
config: TrainerAdapterConfig
|
|
173
|
+
instance: TrainerAdapter
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@dataclasses.dataclass
|
|
177
|
+
class ConfiguredTrajectoryBuilder:
|
|
178
|
+
config: TrajectoryBuilderConfig
|
|
179
|
+
instance: TrajectoryBuilder
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _log_build_failure(component_name: str,
|
|
183
|
+
component_type: str,
|
|
184
|
+
completed_components: list[tuple[str, str]],
|
|
185
|
+
remaining_components: list[tuple[str, str]],
|
|
186
|
+
original_error: Exception) -> None:
|
|
187
|
+
"""
|
|
188
|
+
Common method to log comprehensive build failure information.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
component_name (str): The name of the component that failed to build
|
|
192
|
+
component_type (str): The type of the component that failed to build
|
|
193
|
+
completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
|
|
194
|
+
remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
|
|
195
|
+
original_error (Exception): The original exception that caused the failure
|
|
196
|
+
"""
|
|
197
|
+
logger.error("Failed to initialize component %s (%s)", component_name, component_type)
|
|
198
|
+
|
|
199
|
+
if completed_components:
|
|
200
|
+
logger.error("Successfully built components:")
|
|
201
|
+
for name, comp_type in completed_components:
|
|
202
|
+
logger.error("- %s (%s)", name, comp_type)
|
|
203
|
+
else:
|
|
204
|
+
logger.error("No components were successfully built before this failure")
|
|
205
|
+
|
|
206
|
+
if remaining_components:
|
|
207
|
+
logger.error("Remaining components to build:")
|
|
208
|
+
for name, comp_type in remaining_components:
|
|
209
|
+
logger.error("- %s (%s)", name, comp_type)
|
|
210
|
+
else:
|
|
211
|
+
logger.error("No remaining components to build")
|
|
212
|
+
|
|
213
|
+
logger.error("Original error: %s", original_error, exc_info=True)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
async def _build_function_impl(
|
|
217
|
+
*,
|
|
218
|
+
name: str,
|
|
219
|
+
config: FunctionBaseConfig,
|
|
220
|
+
registry: TypeRegistry,
|
|
221
|
+
exit_stack: AsyncExitStack,
|
|
222
|
+
inner_builder: 'ChildBuilder',
|
|
223
|
+
llms: dict[str, LLMProviderInfo],
|
|
224
|
+
dependencies: dict[str, FunctionDependencies],
|
|
225
|
+
middleware_instances: list[FunctionMiddleware],
|
|
226
|
+
) -> ConfiguredFunction:
|
|
227
|
+
"""
|
|
228
|
+
Helper for core function building logic.
|
|
229
|
+
|
|
230
|
+
Args:
|
|
231
|
+
name: The function name
|
|
232
|
+
config: The function configuration
|
|
233
|
+
registry: Type registry to look up the function registration
|
|
234
|
+
exit_stack: Async exit stack for context management
|
|
235
|
+
inner_builder: ChildBuilder instance for dependency tracking
|
|
236
|
+
llms: Dictionary of LLM instances
|
|
237
|
+
dependencies: Dictionary to store function dependencies
|
|
238
|
+
middleware_instances: Pre-resolved middleware instances
|
|
239
|
+
"""
|
|
240
|
+
registration = registry.get_function(type(config))
|
|
241
|
+
|
|
242
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
243
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
244
|
+
|
|
245
|
+
build_result = await exit_stack.enter_async_context(build_fn(config, inner_builder))
|
|
246
|
+
|
|
247
|
+
dependencies[name] = inner_builder.dependencies
|
|
248
|
+
|
|
249
|
+
# If the build result is a function, wrap it in a FunctionInfo
|
|
250
|
+
if inspect.isfunction(build_result):
|
|
251
|
+
build_result = FunctionInfo.from_fn(build_result)
|
|
252
|
+
|
|
253
|
+
if isinstance(build_result, FunctionInfo):
|
|
254
|
+
build_result = LambdaFunction.from_info(config=config, info=build_result, instance_name=name)
|
|
255
|
+
|
|
256
|
+
if not isinstance(build_result, Function):
|
|
257
|
+
raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
|
|
258
|
+
f"returned from the function builder. Got {type(build_result)}")
|
|
259
|
+
|
|
260
|
+
build_result.configure_middleware(middleware_instances)
|
|
261
|
+
|
|
262
|
+
return ConfiguredFunction(config=config, instance=build_result)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
async def _build_function_group_impl(
    *,
    name: str,
    config: FunctionGroupBaseConfig,
    registry: TypeRegistry,
    exit_stack: AsyncExitStack,
    inner_builder: 'ChildBuilder',
    llms: dict[str, LLMProviderInfo],
    dependencies: dict[str, FunctionDependencies],
    middleware_instances: list[FunctionMiddleware],
) -> ConfiguredFunctionGroup:
    """
    Build a function group; shared by WorkflowBuilder and PerUserWorkflowBuilder.

    The registered build function for ``type(config)`` is wrapped for the LLM
    frameworks it uses, entered on ``exit_stack``, and its result validated,
    named and given its middleware.

    Args:
        name: Instance name for the function group; also the key under which
            its dependencies are recorded.
        config: The function group configuration.
        registry: Type registry used to look up the function group registration.
        exit_stack: Async exit stack that owns the build function's context.
        inner_builder: ChildBuilder instance used for dependency tracking.
        llms: Dictionary of LLM instances passed to the framework wrappers.
        dependencies: Mapping updated in place with ``name -> inner_builder.dependencies``.
        middleware_instances: Pre-resolved middleware instances.

    Returns:
        A ConfiguredFunctionGroup pairing ``config`` with the built group.

    Raises:
        ValueError: If the build function does not yield a FunctionGroup.
    """
    group_registration = registry.get_function_group(type(config))

    frameworks = detect_llm_frameworks_in_build_fn(group_registration)
    wrapped_build_fn = chain_wrapped_build_fn(group_registration.build_fn, llms, frameworks)

    group = await exit_stack.enter_async_context(wrapped_build_fn(config, inner_builder))

    dependencies[name] = inner_builder.dependencies

    if not isinstance(group, FunctionGroup):
        raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
                         f"Got {type(group)}")

    # The instance name has to be in place before middleware is attached.
    group.set_instance_name(name)
    group.configure_middleware(middleware_instances)

    return ConfiguredFunctionGroup(config=config, instance=group)
|
|
307
|
+
|
|
308
|
+
|
|
154
309
|
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
155
310
|
|
|
156
311
|
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
|
|
@@ -181,6 +336,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
181
336
|
self._retrievers: dict[str, ConfiguredRetriever] = {}
|
|
182
337
|
self._ttc_strategies: dict[str, ConfiguredTTCStrategy] = {}
|
|
183
338
|
self._middleware: dict[str, ConfiguredMiddleware] = {}
|
|
339
|
+
self._trainers: dict[str, ConfiguredTrainer] = {}
|
|
340
|
+
self._trainer_adapters: dict[str, ConfiguredTrainerAdapter] = {}
|
|
341
|
+
self._trajectory_builders: dict[str, ConfiguredTrajectoryBuilder] = {}
|
|
184
342
|
|
|
185
343
|
self._context_state = ContextState.get()
|
|
186
344
|
|
|
@@ -189,8 +347,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
189
347
|
# Create a mapping to track function name -> other function names it depends on
|
|
190
348
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
191
349
|
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
192
|
-
|
|
193
|
-
|
|
350
|
+
|
|
351
|
+
# List of completed built components
|
|
352
|
+
self.completed_components: list[tuple[str, str]] = []
|
|
353
|
+
# List of remaining components to be built
|
|
354
|
+
self.remaining_components: list[tuple[str, str]] = []
|
|
194
355
|
|
|
195
356
|
async def __aenter__(self):
|
|
196
357
|
|
|
@@ -271,6 +432,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
271
432
|
|
|
272
433
|
await self._exit_stack.__aexit__(*exc_details)
|
|
273
434
|
|
|
435
|
+
@override
@property
def sync_builder(self) -> SyncBuilder:
    """Return a ``SyncBuilder`` wrapping this builder.

    A new ``SyncBuilder`` is constructed on every access.
    """
    wrapper = SyncBuilder(self)
    return wrapper
|
|
439
|
+
|
|
274
440
|
async def build(self, entry_function: str | None = None) -> Workflow:
|
|
275
441
|
"""
|
|
276
442
|
Creates an instance of a workflow object using the added components and the desired entry function.
|
|
@@ -345,6 +511,18 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
345
511
|
ttc_strategies={
|
|
346
512
|
k: v.config
|
|
347
513
|
for k, v in self._ttc_strategies.items()
|
|
514
|
+
},
|
|
515
|
+
trainers={
|
|
516
|
+
k: v.config
|
|
517
|
+
for k, v in self._trainers.items()
|
|
518
|
+
},
|
|
519
|
+
trainer_adapters={
|
|
520
|
+
k: v.config
|
|
521
|
+
for k, v in self._trainer_adapters.items()
|
|
522
|
+
},
|
|
523
|
+
trajectory_builders={
|
|
524
|
+
k: v.config
|
|
525
|
+
for k, v in self._trajectory_builders.items()
|
|
348
526
|
})
|
|
349
527
|
|
|
350
528
|
if (entry_function is None):
|
|
@@ -396,61 +574,49 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
396
574
|
|
|
397
575
|
return self._exit_stack
|
|
398
576
|
|
|
399
|
-
async def
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
# We need to do this for every function because we don't know
|
|
405
|
-
# Where LLama Index Agents are Instantiated and Settings need to
|
|
406
|
-
# be set before the function is built
|
|
407
|
-
# It's only slower the first time because of the import
|
|
408
|
-
# So we can afford to do this for every function
|
|
409
|
-
|
|
410
|
-
llms = {k: v.instance for k, v in self._llms.items()}
|
|
411
|
-
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
412
|
-
|
|
413
|
-
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
414
|
-
|
|
415
|
-
# Set the currently building function so the ChildBuilder can track dependencies
|
|
416
|
-
self.current_function_building = config.type
|
|
417
|
-
# Empty set of dependencies for the current function
|
|
418
|
-
self.function_dependencies[config.type] = FunctionDependencies()
|
|
419
|
-
|
|
420
|
-
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
421
|
-
|
|
422
|
-
self.function_dependencies[name] = inner_builder.dependencies
|
|
423
|
-
|
|
424
|
-
# If the build result is a function, wrap it in a FunctionInfo
|
|
425
|
-
if inspect.isfunction(build_result):
|
|
426
|
-
|
|
427
|
-
build_result = FunctionInfo.from_fn(build_result)
|
|
428
|
-
|
|
429
|
-
if (isinstance(build_result, FunctionInfo)):
|
|
430
|
-
# Create the function object
|
|
431
|
-
build_result = LambdaFunction.from_info(config=config, info=build_result, instance_name=name)
|
|
432
|
-
|
|
433
|
-
if (not isinstance(build_result, Function)):
|
|
434
|
-
raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
|
|
435
|
-
f"returned from the function builder. Got {type(build_result)}")
|
|
577
|
+
async def _resolve_middleware_instances(self, middleware_names: list[str], component_name: str,
                                        component_type: str) -> list[FunctionMiddleware]:
    """
    Resolve middleware names to FunctionMiddleware instances.

    Args:
        middleware_names: Names of middleware to look up in ``self._middleware``, in order.
        component_name: Name of the component requesting the middleware (used in error messages).
        component_type: Human-readable component kind, e.g. ``"function"`` or
            ``"function group"`` (used in error messages).

    Returns:
        The middleware instances, in the same order as ``middleware_names``.

    Raises:
        ValueError: If a name is not a configured middleware.
        TypeError: If a configured middleware is not a FunctionMiddleware.
    """
    middleware_instances: list[FunctionMiddleware] = []
    for middleware_name in middleware_names:
        if middleware_name not in self._middleware:
            raise ValueError(f"Middleware `{middleware_name}` not found for {component_type} `{component_name}`. "
                             f"It must be configured in the `middleware` section of the YAML configuration.")
        middleware_obj = self._middleware[middleware_name].instance
        if not isinstance(middleware_obj, FunctionMiddleware):
            # Note the trailing space in the first fragment: adjacent f-strings are
            # concatenated verbatim, so without it the message reads "usedwith".
            raise TypeError(f"Middleware `{middleware_name}` is not a FunctionMiddleware and cannot be used "
                            f"with {component_type}s. "
                            f"Only FunctionMiddleware types support function-specific wrapping.")
        middleware_instances.append(middleware_obj)
    return middleware_instances
|
|
450
595
|
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
596
|
+
async def _build_function(self, name: str, config: FunctionBaseConfig) -> ConfiguredFunction:
    """Build the function ``name`` from ``config``.

    The build runs inside a ``ChildBuilder.use(config, self)`` context. The
    names in ``config.middleware`` are resolved as function middleware, and the
    work is delegated to ``_build_function_impl``, which records the function's
    dependencies in ``self.function_dependencies``.
    """
    with ChildBuilder.use(config, self) as child_builder:
        # The LLM map is rebuilt for every function: we can't know in advance where
        # LlamaIndex agents get instantiated, and its Settings must be in place
        # before the function is built. Only the first call pays the import cost,
        # so doing this every time is affordable.
        llm_instances = {llm_name: configured.instance for llm_name, configured in self._llms.items()}

        # Only FunctionMiddleware types are accepted for functions.
        resolved_middleware = await self._resolve_middleware_instances(config.middleware, name, "function")

        configured_function = await _build_function_impl(
            name=name,
            config=config,
            registry=self._registry,
            exit_stack=self._get_exit_stack(),
            inner_builder=child_builder,
            llms=llm_instances,
            dependencies=self.function_dependencies,
            middleware_instances=resolved_middleware,
        )
        return configured_function
|
|
454
620
|
|
|
455
621
|
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
456
622
|
"""Build a function group from the provided configuration.
|
|
@@ -465,49 +631,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
465
631
|
Raises:
|
|
466
632
|
ValueError: If the function group builder returns invalid results
|
|
467
633
|
"""
|
|
468
|
-
registration = self._registry.get_function_group(type(config))
|
|
469
|
-
|
|
470
|
-
inner_builder = ChildBuilder(self)
|
|
471
|
-
|
|
472
|
-
# Build the function group - use the same wrapping pattern as _build_function
|
|
473
|
-
llms = {k: v.instance for k, v in self._llms.items()}
|
|
474
|
-
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
475
|
-
|
|
476
|
-
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
477
634
|
|
|
478
|
-
|
|
479
|
-
self.current_function_group_building = config.type
|
|
480
|
-
# Empty set of dependencies for the current function group
|
|
481
|
-
self.function_group_dependencies[config.type] = FunctionDependencies()
|
|
635
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
482
636
|
|
|
483
|
-
|
|
637
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
638
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
639
|
+
# Resolve middleware names from config to middleware instances
|
|
640
|
+
# Only FunctionMiddleware types can be used with function groups
|
|
641
|
+
middleware_instances = await self._resolve_middleware_instances(config.middleware, name, "function group")
|
|
484
642
|
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
middleware_instances = []
|
|
494
|
-
for middleware_name in config.middleware:
|
|
495
|
-
if middleware_name not in self._middleware:
|
|
496
|
-
raise ValueError(f"Middleware `{middleware_name}` not found for function group `{name}`. "
|
|
497
|
-
f"It must be configured in the `middleware` section of the YAML configuration.")
|
|
498
|
-
middleware_obj = self._middleware[middleware_name].instance
|
|
499
|
-
if not isinstance(middleware_obj, FunctionMiddleware):
|
|
500
|
-
raise TypeError(f"Middleware `{middleware_name}` is not a FunctionMiddleware and "
|
|
501
|
-
f"cannot be used with function groups. "
|
|
502
|
-
f"Only FunctionMiddleware types support function-specific wrapping.")
|
|
503
|
-
middleware_instances.append(middleware_obj)
|
|
504
|
-
|
|
505
|
-
# Configure middleware for the function group
|
|
506
|
-
build_result.configure_middleware(middleware_instances)
|
|
507
|
-
|
|
508
|
-
# set the instance name for the function group based on the workflow-provided name
|
|
509
|
-
build_result.set_instance_name(name)
|
|
510
|
-
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
643
|
+
return await _build_function_group_impl(name=name,
|
|
644
|
+
config=config,
|
|
645
|
+
registry=self._registry,
|
|
646
|
+
exit_stack=self._get_exit_stack(),
|
|
647
|
+
inner_builder=inner_builder,
|
|
648
|
+
llms=llms,
|
|
649
|
+
dependencies=self.function_group_dependencies,
|
|
650
|
+
middleware_instances=middleware_instances)
|
|
511
651
|
|
|
512
652
|
@override
|
|
513
653
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
@@ -516,6 +656,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
516
656
|
|
|
517
657
|
if (name in self._functions or name in self._function_groups):
|
|
518
658
|
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
659
|
+
if any(name.startswith(k + FunctionGroup.SEPARATOR) for k in self._function_groups.keys()):
|
|
660
|
+
raise ValueError(f"A Function name starts with a Function Group name: `{name}`")
|
|
519
661
|
|
|
520
662
|
build_result = await self._build_function(name=name, config=config)
|
|
521
663
|
|
|
@@ -530,6 +672,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
530
672
|
|
|
531
673
|
if (name in self._function_groups or name in self._functions):
|
|
532
674
|
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
675
|
+
if any(k.startswith(name + FunctionGroup.SEPARATOR) for k in self._functions.keys()):
|
|
676
|
+
raise ValueError(f"A Function name starts with a Function Group name: `{name}`")
|
|
533
677
|
|
|
534
678
|
# Build the function group
|
|
535
679
|
build_result = await self._build_function_group(name=name, config=config)
|
|
@@ -549,10 +693,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
549
693
|
|
|
550
694
|
return build_result.instance
|
|
551
695
|
|
|
696
|
+
def _check_backwards_compatibility_function_name(self, name: str) -> str:
    """Translate a legacy-separator function name to the current separator.

    The name is returned unchanged when it is already registered. Otherwise,
    every ``FunctionGroup.LEGACY_SEPARATOR`` in it is replaced with
    ``FunctionGroup.SEPARATOR``. If that rewritten name is registered, a
    deprecation warning is logged and the rewritten name is returned. In all
    other cases the input is returned as-is.
    """
    if name not in self._functions:
        candidate = name.replace(FunctionGroup.LEGACY_SEPARATOR, FunctionGroup.SEPARATOR)
        if candidate in self._functions:
            logger.warning(
                f"Function `{name}` is deprecated and will be removed in a future release. Use `{candidate}` instead.")
            return candidate
    return name
|
|
705
|
+
|
|
552
706
|
@override
|
|
553
707
|
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
554
708
|
if isinstance(name, FunctionRef):
|
|
555
709
|
name = str(name)
|
|
710
|
+
|
|
711
|
+
name = self._check_backwards_compatibility_function_name(name)
|
|
712
|
+
|
|
556
713
|
if name not in self._functions:
|
|
557
714
|
raise ValueError(f"Function `{name}` not found")
|
|
558
715
|
|
|
@@ -571,6 +728,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
571
728
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
572
729
|
if isinstance(name, FunctionRef):
|
|
573
730
|
name = str(name)
|
|
731
|
+
name = self._check_backwards_compatibility_function_name(name)
|
|
574
732
|
if name not in self._functions:
|
|
575
733
|
raise ValueError(f"Function `{name}` not found")
|
|
576
734
|
|
|
@@ -591,7 +749,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
591
749
|
if self._workflow is not None:
|
|
592
750
|
warnings.warn("Overwriting existing workflow")
|
|
593
751
|
|
|
594
|
-
build_result = await self._build_function(name=
|
|
752
|
+
build_result = await self._build_function(name=WORKFLOW_COMPONENT_NAME, config=config)
|
|
595
753
|
|
|
596
754
|
self._workflow = build_result
|
|
597
755
|
|
|
@@ -662,6 +820,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
662
820
|
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
663
821
|
if isinstance(fn_name, FunctionRef):
|
|
664
822
|
fn_name = str(fn_name)
|
|
823
|
+
|
|
824
|
+
fn_name = self._check_backwards_compatibility_function_name(fn_name)
|
|
825
|
+
|
|
665
826
|
if fn_name not in self._functions:
|
|
666
827
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
667
828
|
fn = self._functions[fn_name]
|
|
@@ -684,9 +845,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
684
845
|
try:
|
|
685
846
|
llm_info = self._registry.get_llm_provider(type(config))
|
|
686
847
|
|
|
687
|
-
|
|
848
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
849
|
+
info_obj = await self._get_exit_stack().enter_async_context(llm_info.build_fn(config, inner_builder))
|
|
688
850
|
|
|
689
|
-
|
|
851
|
+
self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
|
|
690
852
|
except Exception as e:
|
|
691
853
|
logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
|
|
692
854
|
raise
|
|
@@ -704,7 +866,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
704
866
|
# Generate wrapped client from registered client info
|
|
705
867
|
client_info = self._registry.get_llm_client(config_type=type(llm_info.config), wrapper_type=wrapper_type)
|
|
706
868
|
|
|
707
|
-
|
|
869
|
+
with ChildBuilder.use(llm_info.config, self) as inner_builder:
|
|
870
|
+
client = await self._get_exit_stack().enter_async_context(
|
|
871
|
+
client_info.build_fn(llm_info.config, inner_builder))
|
|
708
872
|
|
|
709
873
|
# Return a frameworks specific client
|
|
710
874
|
return client
|
|
@@ -754,7 +918,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
754
918
|
try:
|
|
755
919
|
authentication_info = self._registry.get_auth_provider(type(config))
|
|
756
920
|
|
|
757
|
-
|
|
921
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
922
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
923
|
+
authentication_info.build_fn(config, inner_builder))
|
|
758
924
|
|
|
759
925
|
self._auth_providers[name] = ConfiguredAuthProvider(config=config, instance=info_obj)
|
|
760
926
|
|
|
@@ -800,7 +966,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
800
966
|
try:
|
|
801
967
|
embedder_info = self._registry.get_embedder_provider(type(config))
|
|
802
968
|
|
|
803
|
-
|
|
969
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
970
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
971
|
+
embedder_info.build_fn(config, inner_builder))
|
|
804
972
|
|
|
805
973
|
self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
|
|
806
974
|
except Exception as e:
|
|
@@ -820,7 +988,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
820
988
|
# Generate wrapped client from registered client info
|
|
821
989
|
client_info = self._registry.get_embedder_client(config_type=type(embedder_info.config),
|
|
822
990
|
wrapper_type=wrapper_type)
|
|
823
|
-
|
|
991
|
+
|
|
992
|
+
with ChildBuilder.use(embedder_info.config, self) as inner_builder:
|
|
993
|
+
client = await self._get_exit_stack().enter_async_context(
|
|
994
|
+
client_info.build_fn(embedder_info.config, inner_builder))
|
|
824
995
|
|
|
825
996
|
# Return a frameworks specific client
|
|
826
997
|
return client
|
|
@@ -845,7 +1016,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
845
1016
|
|
|
846
1017
|
memory_info = self._registry.get_memory(type(config))
|
|
847
1018
|
|
|
848
|
-
|
|
1019
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1020
|
+
info_obj = await self._get_exit_stack().enter_async_context(memory_info.build_fn(config, inner_builder))
|
|
849
1021
|
|
|
850
1022
|
self._memory_clients[name] = ConfiguredMemory(config=config, instance=info_obj)
|
|
851
1023
|
|
|
@@ -877,7 +1049,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
877
1049
|
|
|
878
1050
|
object_store_info = self._registry.get_object_store(type(config))
|
|
879
1051
|
|
|
880
|
-
|
|
1052
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1053
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1054
|
+
object_store_info.build_fn(config, inner_builder))
|
|
881
1055
|
|
|
882
1056
|
self._object_stores[name] = ConfiguredObjectStore(config=config, instance=info_obj)
|
|
883
1057
|
|
|
@@ -906,7 +1080,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
906
1080
|
try:
|
|
907
1081
|
retriever_info = self._registry.get_retriever_provider(type(config))
|
|
908
1082
|
|
|
909
|
-
|
|
1083
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1084
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1085
|
+
retriever_info.build_fn(config, inner_builder))
|
|
910
1086
|
|
|
911
1087
|
self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
|
|
912
1088
|
|
|
@@ -930,7 +1106,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
930
1106
|
client_info = self._registry.get_retriever_client(config_type=type(retriever_info.config),
|
|
931
1107
|
wrapper_type=wrapper_type)
|
|
932
1108
|
|
|
933
|
-
|
|
1109
|
+
with ChildBuilder.use(retriever_info.config, self) as inner_builder:
|
|
1110
|
+
client = await self._get_exit_stack().enter_async_context(
|
|
1111
|
+
client_info.build_fn(retriever_info.config, inner_builder))
|
|
934
1112
|
|
|
935
1113
|
# Return a frameworks specific client
|
|
936
1114
|
return client
|
|
@@ -946,6 +1124,120 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
946
1124
|
|
|
947
1125
|
return self._retrievers[retriever_name].config
|
|
948
1126
|
|
|
1127
|
+
@override
@experimental(feature_name="Finetuning")
async def add_trainer(self, name: str | TrainerRef, config: TrainerConfig) -> Trainer:
    """Build a trainer from ``config`` and register it under ``name``.

    Raises:
        ValueError: If a trainer named ``name`` is already registered.

    Any error raised while looking up or building the trainer is logged and re-raised.
    """
    if name in self._trainers:
        raise ValueError(f"Trainer '{name}' already exists in the list of trainers")

    try:
        registration = self._registry.get_trainer(type(config))

        with ChildBuilder.use(config, self) as child_builder:
            trainer = await self._get_exit_stack().enter_async_context(registration.build_fn(
                config, child_builder))

        self._trainers[name] = ConfiguredTrainer(config=config, instance=trainer)
        return trainer
    except Exception as err:
        logger.error("Error adding trainer `%s` with config `%s`: %s", name, config, err)
        raise
|
|
1147
|
+
|
|
1148
|
+
@override
@experimental(feature_name="Finetuning")
async def add_trainer_adapter(self, name: str | TrainerAdapterRef, config: TrainerAdapterConfig) -> TrainerAdapter:
    """Build a trainer adapter from ``config`` and register it under ``name``.

    Raises:
        ValueError: If a trainer adapter named ``name`` is already registered.

    Any error raised while looking up or building the adapter is logged and re-raised.
    """
    if name in self._trainer_adapters:
        raise ValueError(f"Trainer adapter '{name}' already exists in the list of trainer adapters")

    try:
        registration = self._registry.get_trainer_adapter(type(config))

        with ChildBuilder.use(config, self) as child_builder:
            adapter = await self._get_exit_stack().enter_async_context(
                registration.build_fn(config, child_builder))

        self._trainer_adapters[name] = ConfiguredTrainerAdapter(config=config, instance=adapter)
        return adapter
    except Exception as err:
        logger.error("Error adding trainer adapter `%s` with config `%s`: %s", name, config, err)
        raise
|
|
1168
|
+
|
|
1169
|
+
@override
@experimental(feature_name="Finetuning")
async def add_trajectory_builder(self, name: str | TrajectoryBuilderRef,
                                 config: TrajectoryBuilderConfig) -> TrajectoryBuilder:
    """Build a trajectory builder from ``config`` and register it under ``name``.

    Raises:
        ValueError: If a trajectory builder named ``name`` is already registered.

    Any error raised while looking up or building it is logged and re-raised.
    """
    if name in self._trajectory_builders:
        raise ValueError(f"Trajectory builder '{name}' already exists in the list of trajectory builders")

    try:
        registration = self._registry.get_trajectory_builder(type(config))

        with ChildBuilder.use(config, self) as child_builder:
            traj_builder = await self._get_exit_stack().enter_async_context(
                registration.build_fn(config, child_builder))

        self._trajectory_builders[name] = ConfiguredTrajectoryBuilder(config=config, instance=traj_builder)
        return traj_builder
    except Exception as err:
        logger.error("Error adding trajectory builder `%s` with config `%s`: %s", name, config, err)
        raise
|
|
1190
|
+
|
|
1191
|
+
@override
async def get_trainer(self,
                      trainer_name: str | TrainerRef,
                      trajectory_builder: TrajectoryBuilder,
                      trainer_adapter: TrainerAdapter) -> Trainer:
    """Return the registered trainer after binding it to the given components.

    ``bind_components`` is awaited on the stored trainer instance on every call.

    Raises:
        ValueError: If no trainer named ``trainer_name`` is registered.
    """
    if trainer_name not in self._trainers:
        raise ValueError(f"Trainer '{trainer_name}' not found")

    trainer = self._trainers[trainer_name].instance
    await trainer.bind_components(trainer_adapter=trainer_adapter, trajectory_builder=trajectory_builder)
    return trainer
|
|
1204
|
+
|
|
1205
|
+
@override
async def get_trainer_config(self, trainer_name: str | TrainerRef) -> TrainerConfig:
    """Return the config a registered trainer was built from.

    Raises:
        ValueError: If no trainer named ``trainer_name`` is registered.
    """
    if trainer_name in self._trainers:
        return self._trainers[trainer_name].config
    raise ValueError(f"Trainer '{trainer_name}' not found")
|
|
1211
|
+
|
|
1212
|
+
@override
async def get_trainer_adapter_config(self, trainer_adapter_name: str | TrainerAdapterRef) -> TrainerAdapterConfig:
    """Return the config a registered trainer adapter was built from.

    Raises:
        ValueError: If no trainer adapter named ``trainer_adapter_name`` is registered.
    """
    if trainer_adapter_name in self._trainer_adapters:
        return self._trainer_adapters[trainer_adapter_name].config
    raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not found")
|
|
1218
|
+
|
|
1219
|
+
@override
async def get_trajectory_builder_config(
        self, trajectory_builder_name: str | TrajectoryBuilderRef) -> TrajectoryBuilderConfig:
    """Return the config a registered trajectory builder was built from.

    Raises:
        ValueError: If no trajectory builder named ``trajectory_builder_name`` is registered.
    """
    if trajectory_builder_name in self._trajectory_builders:
        return self._trajectory_builders[trajectory_builder_name].config
    raise ValueError(f"Trajectory builder '{trajectory_builder_name}' not found")
|
|
1226
|
+
|
|
1227
|
+
@override
async def get_trainer_adapter(self, trainer_adapter_name: str | TrainerAdapterRef) -> TrainerAdapter:
    """Return the built instance of a registered trainer adapter.

    Raises:
        ValueError: If no trainer adapter named ``trainer_adapter_name`` is registered.
    """
    if trainer_adapter_name in self._trainer_adapters:
        return self._trainer_adapters[trainer_adapter_name].instance
    raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not found")
|
|
1233
|
+
|
|
1234
|
+
@override
async def get_trajectory_builder(self, trajectory_builder_name: str | TrajectoryBuilderRef) -> TrajectoryBuilder:
    """Return the built instance of a registered trajectory builder.

    Raises:
        ValueError: If no trajectory builder named ``trajectory_builder_name`` is registered.
    """
    if trajectory_builder_name in self._trajectory_builders:
        return self._trajectory_builders[trajectory_builder_name].instance
    raise ValueError(f"Trajectory builder '{trajectory_builder_name}' not found")
|
|
1240
|
+
|
|
949
1241
|
@override
|
|
950
1242
|
@experimental(feature_name="TTC")
|
|
951
1243
|
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
|
|
@@ -955,7 +1247,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
955
1247
|
try:
|
|
956
1248
|
ttc_strategy_info = self._registry.get_ttc_strategy(type(config))
|
|
957
1249
|
|
|
958
|
-
|
|
1250
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1251
|
+
info_obj = await self._get_exit_stack().enter_async_context(
|
|
1252
|
+
ttc_strategy_info.build_fn(config, inner_builder))
|
|
959
1253
|
|
|
960
1254
|
self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
|
|
961
1255
|
|
|
@@ -1032,8 +1326,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1032
1326
|
try:
|
|
1033
1327
|
middleware_info = self._registry.get_middleware(type(config))
|
|
1034
1328
|
|
|
1035
|
-
|
|
1036
|
-
|
|
1329
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1330
|
+
middleware_instance = await self._get_exit_stack().enter_async_context(
|
|
1331
|
+
middleware_info.build_fn(config, inner_builder))
|
|
1037
1332
|
|
|
1038
1333
|
self._middleware[name] = ConfiguredMiddleware(config=config, instance=middleware_instance)
|
|
1039
1334
|
|
|
@@ -1095,82 +1390,13 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1095
1390
|
exporter_info = self._registry.get_telemetry_exporter(type(config))
|
|
1096
1391
|
|
|
1097
1392
|
# Build the exporter outside the lock (parallel)
|
|
1098
|
-
|
|
1393
|
+
with ChildBuilder.use(config, self) as inner_builder:
|
|
1394
|
+
exporter_context_manager = exporter_info.build_fn(config, inner_builder)
|
|
1099
1395
|
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
self._telemetry_exporters[name] = ConfiguredTelemetryExporter(config=config, instance=exporter)
|
|
1396
|
+
# Only protect the shared state modifications (serialized)
|
|
1397
|
+
exporter = await self._get_exit_stack().enter_async_context(exporter_context_manager)
|
|
1103
1398
|
|
|
1104
|
-
|
|
1105
|
-
component_name: str,
|
|
1106
|
-
component_type: str,
|
|
1107
|
-
completed_components: list[tuple[str, str]],
|
|
1108
|
-
remaining_components: list[tuple[str, str]],
|
|
1109
|
-
original_error: Exception) -> None:
|
|
1110
|
-
"""
|
|
1111
|
-
Common method to log comprehensive build failure information.
|
|
1112
|
-
|
|
1113
|
-
Args:
|
|
1114
|
-
component_name (str): The name of the component that failed to build
|
|
1115
|
-
component_type (str): The type of the component that failed to build
|
|
1116
|
-
completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
|
|
1117
|
-
remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
|
|
1118
|
-
original_error (Exception): The original exception that caused the failure
|
|
1119
|
-
"""
|
|
1120
|
-
logger.error("Failed to initialize component %s (%s)", component_name, component_type)
|
|
1121
|
-
|
|
1122
|
-
if completed_components:
|
|
1123
|
-
logger.error("Successfully built components:")
|
|
1124
|
-
for name, comp_type in completed_components:
|
|
1125
|
-
logger.error("- %s (%s)", name, comp_type)
|
|
1126
|
-
else:
|
|
1127
|
-
logger.error("No components were successfully built before this failure")
|
|
1128
|
-
|
|
1129
|
-
if remaining_components:
|
|
1130
|
-
logger.error("Remaining components to build:")
|
|
1131
|
-
for name, comp_type in remaining_components:
|
|
1132
|
-
logger.error("- %s (%s)", name, comp_type)
|
|
1133
|
-
else:
|
|
1134
|
-
logger.error("No remaining components to build")
|
|
1135
|
-
|
|
1136
|
-
logger.error("Original error: %s", original_error, exc_info=True)
|
|
1137
|
-
|
|
1138
|
-
def _log_build_failure_component(self,
|
|
1139
|
-
failing_component: ComponentInstanceData,
|
|
1140
|
-
completed_components: list[tuple[str, str]],
|
|
1141
|
-
remaining_components: list[tuple[str, str]],
|
|
1142
|
-
original_error: Exception) -> None:
|
|
1143
|
-
"""
|
|
1144
|
-
Log comprehensive component build failure information.
|
|
1145
|
-
|
|
1146
|
-
Args:
|
|
1147
|
-
failing_component (ComponentInstanceData): The ComponentInstanceData that failed to build
|
|
1148
|
-
completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
|
|
1149
|
-
remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
|
|
1150
|
-
original_error (Exception): The original exception that caused the failure
|
|
1151
|
-
"""
|
|
1152
|
-
component_name = failing_component.name
|
|
1153
|
-
component_type = failing_component.component_group.value
|
|
1154
|
-
|
|
1155
|
-
self._log_build_failure(component_name,
|
|
1156
|
-
component_type,
|
|
1157
|
-
completed_components,
|
|
1158
|
-
remaining_components,
|
|
1159
|
-
original_error)
|
|
1160
|
-
|
|
1161
|
-
def _log_build_failure_workflow(self,
|
|
1162
|
-
completed_components: list[tuple[str, str]],
|
|
1163
|
-
remaining_components: list[tuple[str, str]],
|
|
1164
|
-
original_error: Exception) -> None:
|
|
1165
|
-
"""
|
|
1166
|
-
Log comprehensive workflow build failure information.
|
|
1167
|
-
|
|
1168
|
-
Args:
|
|
1169
|
-
completed_components (list[tuple[str, str]]): List of (name, type) tuples for successfully built components
|
|
1170
|
-
remaining_components (list[tuple[str, str]]): List of (name, type) tuples for components still to be built
|
|
1171
|
-
original_error (Exception): The original exception that caused the failure
|
|
1172
|
-
"""
|
|
1173
|
-
self._log_build_failure("<workflow>", "workflow", completed_components, remaining_components, original_error)
|
|
1399
|
+
self._telemetry_exporters[name] = ConfiguredTelemetryExporter(config=config, instance=exporter)
|
|
1174
1400
|
|
|
1175
1401
|
async def populate_builder(self, config: Config, skip_workflow: bool = False):
|
|
1176
1402
|
"""
|
|
@@ -1184,21 +1410,14 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1184
1410
|
# Generate the build sequence
|
|
1185
1411
|
build_sequence = build_dependency_sequence(config)
|
|
1186
1412
|
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
remaining_components = [(str(comp.name), comp.component_group.value) for comp in build_sequence
|
|
1190
|
-
if not comp.is_root]
|
|
1413
|
+
self.remaining_components = [(str(comp.name), comp.component_group.value) for comp in build_sequence
|
|
1414
|
+
if not comp.is_root]
|
|
1191
1415
|
if not skip_workflow:
|
|
1192
|
-
remaining_components.append((
|
|
1416
|
+
self.remaining_components.append((WORKFLOW_COMPONENT_NAME, "workflow"))
|
|
1193
1417
|
|
|
1194
|
-
# Loop over all
|
|
1418
|
+
# Loop over all components and add to the workflow builder
|
|
1195
1419
|
for component_instance in build_sequence:
|
|
1196
1420
|
try:
|
|
1197
|
-
# Remove from remaining as we start building (if not root)
|
|
1198
|
-
if not component_instance.is_root:
|
|
1199
|
-
remaining_components.remove(
|
|
1200
|
-
(str(component_instance.name), component_instance.component_group.value))
|
|
1201
|
-
|
|
1202
1421
|
# Instantiate a the llm
|
|
1203
1422
|
if component_instance.component_group == ComponentGroup.LLMS:
|
|
1204
1423
|
await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
|
|
@@ -1224,14 +1443,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1224
1443
|
cast(MiddlewareBaseConfig, component_instance.config))
|
|
1225
1444
|
# Instantiate a function group
|
|
1226
1445
|
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1446
|
+
config_obj = cast(FunctionGroupBaseConfig, component_instance.config)
|
|
1447
|
+
registration = self._registry.get_function_group(type(config_obj))
|
|
1448
|
+
if registration.is_per_user:
|
|
1449
|
+
# Skip per-user function groups as they will be built lazily by PerUserWorkflowBuilder
|
|
1450
|
+
continue
|
|
1227
1451
|
await self.add_function_group(component_instance.name,
|
|
1228
1452
|
cast(FunctionGroupBaseConfig, component_instance.config))
|
|
1229
1453
|
# Instantiate a function
|
|
1230
1454
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1455
|
+
config_obj = cast(FunctionBaseConfig, component_instance.config)
|
|
1456
|
+
registration = self._registry.get_function(type(config_obj))
|
|
1457
|
+
if registration.is_per_user:
|
|
1458
|
+
# Skip per-user functions as they will be built lazily by PerUserWorkflowBuilder
|
|
1459
|
+
continue
|
|
1460
|
+
elif not component_instance.is_root:
|
|
1461
|
+
# If the function is not the root, add it to the workflow builder
|
|
1462
|
+
await self.add_function(component_instance.name, config_obj)
|
|
1235
1463
|
elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
|
|
1236
1464
|
await self.add_ttc_strategy(component_instance.name,
|
|
1237
1465
|
cast(TTCStrategyBaseConfig, component_instance.config))
|
|
@@ -1239,256 +1467,140 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1239
1467
|
elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
|
|
1240
1468
|
await self.add_auth_provider(component_instance.name,
|
|
1241
1469
|
cast(AuthProviderBaseConfig, component_instance.config))
|
|
1470
|
+
|
|
1471
|
+
elif component_instance.component_group == ComponentGroup.TRAINERS:
|
|
1472
|
+
await self.add_trainer(component_instance.name, cast(TrainerConfig, component_instance.config))
|
|
1473
|
+
|
|
1474
|
+
elif component_instance.component_group == ComponentGroup.TRAINER_ADAPTERS:
|
|
1475
|
+
await self.add_trainer_adapter(component_instance.name,
|
|
1476
|
+
cast(TrainerAdapterConfig, component_instance.config))
|
|
1477
|
+
|
|
1478
|
+
elif component_instance.component_group == ComponentGroup.TRAJECTORY_BUILDERS:
|
|
1479
|
+
await self.add_trajectory_builder(component_instance.name,
|
|
1480
|
+
cast(TrajectoryBuilderConfig, component_instance.config))
|
|
1242
1481
|
else:
|
|
1243
1482
|
raise ValueError(f"Unknown component group {component_instance.component_group}")
|
|
1244
1483
|
|
|
1245
|
-
#
|
|
1484
|
+
# Remove from remaining and add to completed after successful build (if not root)
|
|
1246
1485
|
if not component_instance.is_root:
|
|
1247
|
-
|
|
1486
|
+
self.remaining_components.remove(
|
|
1487
|
+
(str(component_instance.name), component_instance.component_group.value))
|
|
1488
|
+
self.completed_components.append(
|
|
1248
1489
|
(str(component_instance.name), component_instance.component_group.value))
|
|
1249
1490
|
|
|
1250
1491
|
except Exception as e:
|
|
1251
|
-
|
|
1492
|
+
_log_build_failure(str(component_instance.name),
|
|
1493
|
+
component_instance.component_group.value,
|
|
1494
|
+
self.completed_components,
|
|
1495
|
+
self.remaining_components,
|
|
1496
|
+
e)
|
|
1252
1497
|
raise
|
|
1253
1498
|
|
|
1254
1499
|
# Instantiate the workflow
|
|
1255
1500
|
if not skip_workflow:
|
|
1256
1501
|
try:
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1502
|
+
workflow_registration = self._registry.get_function(type(config.workflow))
|
|
1503
|
+
# If the workflow is shared (not per-user), build it
|
|
1504
|
+
# Otherwise, build it lazily by PerUserWorkflowBuilder
|
|
1505
|
+
if not workflow_registration.is_per_user:
|
|
1506
|
+
# Remove workflow from remaining as we start building
|
|
1507
|
+
self.remaining_components.remove((WORKFLOW_COMPONENT_NAME, "workflow"))
|
|
1508
|
+
await self.set_workflow(config.workflow)
|
|
1509
|
+
self.completed_components.append((WORKFLOW_COMPONENT_NAME, "workflow"))
|
|
1261
1510
|
except Exception as e:
|
|
1262
|
-
|
|
1511
|
+
_log_build_failure(WORKFLOW_COMPONENT_NAME,
|
|
1512
|
+
"workflow",
|
|
1513
|
+
self.completed_components,
|
|
1514
|
+
self.remaining_components,
|
|
1515
|
+
e)
|
|
1263
1516
|
raise
|
|
1264
1517
|
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
async def from_config(cls, config: Config):
|
|
1268
|
-
|
|
1269
|
-
async with cls(general_config=config.general) as builder:
|
|
1270
|
-
await builder.populate_builder(config)
|
|
1271
|
-
yield builder
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
class ChildBuilder(Builder):
|
|
1275
|
-
|
|
1276
|
-
def __init__(self, workflow_builder: WorkflowBuilder) -> None:
|
|
1277
|
-
|
|
1278
|
-
self._workflow_builder = workflow_builder
|
|
1279
|
-
|
|
1280
|
-
self._dependencies = FunctionDependencies()
|
|
1281
|
-
|
|
1282
|
-
@property
|
|
1283
|
-
def dependencies(self) -> FunctionDependencies:
|
|
1284
|
-
return self._dependencies
|
|
1285
|
-
|
|
1286
|
-
@override
|
|
1287
|
-
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
|
1288
|
-
return await self._workflow_builder.add_function(name, config)
|
|
1289
|
-
|
|
1290
|
-
@override
|
|
1291
|
-
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1292
|
-
return await self._workflow_builder.add_function_group(name, config)
|
|
1293
|
-
|
|
1294
|
-
@override
|
|
1295
|
-
async def get_function(self, name: str) -> Function:
|
|
1296
|
-
# If a function tries to get another function, we assume it uses it
|
|
1297
|
-
fn = await self._workflow_builder.get_function(name)
|
|
1518
|
+
# Check if any shared components have dependencies on per-user components
|
|
1519
|
+
self._validate_dependencies(config)
|
|
1298
1520
|
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
return fn
|
|
1302
|
-
|
|
1303
|
-
@override
|
|
1304
|
-
async def get_function_group(self, name: str) -> FunctionGroup:
|
|
1305
|
-
# If a function tries to get a function group, we assume it uses it
|
|
1306
|
-
function_group = await self._workflow_builder.get_function_group(name)
|
|
1307
|
-
|
|
1308
|
-
self._dependencies.add_function_group(name)
|
|
1309
|
-
|
|
1310
|
-
return function_group
|
|
1311
|
-
|
|
1312
|
-
@override
|
|
1313
|
-
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
1314
|
-
return self._workflow_builder.get_function_config(name)
|
|
1315
|
-
|
|
1316
|
-
@override
|
|
1317
|
-
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1318
|
-
return self._workflow_builder.get_function_group_config(name)
|
|
1319
|
-
|
|
1320
|
-
@override
|
|
1321
|
-
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
1322
|
-
return await self._workflow_builder.set_workflow(config)
|
|
1323
|
-
|
|
1324
|
-
@override
|
|
1325
|
-
def get_workflow(self) -> Function:
|
|
1326
|
-
return self._workflow_builder.get_workflow()
|
|
1327
|
-
|
|
1328
|
-
@override
|
|
1329
|
-
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
1330
|
-
return self._workflow_builder.get_workflow_config()
|
|
1331
|
-
|
|
1332
|
-
@override
|
|
1333
|
-
async def get_tools(self,
|
|
1334
|
-
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1335
|
-
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1336
|
-
tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1337
|
-
for tool_name in tool_names:
|
|
1338
|
-
if tool_name in self._workflow_builder._function_groups:
|
|
1339
|
-
self._dependencies.add_function_group(tool_name)
|
|
1340
|
-
else:
|
|
1341
|
-
self._dependencies.add_function(tool_name)
|
|
1342
|
-
return tools
|
|
1343
|
-
|
|
1344
|
-
@override
|
|
1345
|
-
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
1346
|
-
# If a function tries to get another function as a tool, we assume it uses it
|
|
1347
|
-
fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
1348
|
-
|
|
1349
|
-
self._dependencies.add_function(fn_name)
|
|
1350
|
-
|
|
1351
|
-
return fn
|
|
1352
|
-
|
|
1353
|
-
@override
|
|
1354
|
-
async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
|
|
1355
|
-
return await self._workflow_builder.add_llm(name, config)
|
|
1356
|
-
|
|
1357
|
-
@experimental(feature_name="Authentication")
|
|
1358
|
-
@override
|
|
1359
|
-
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
1360
|
-
return await self._workflow_builder.add_auth_provider(name, config)
|
|
1361
|
-
|
|
1362
|
-
@override
|
|
1363
|
-
async def get_auth_provider(self, auth_provider_name: str):
|
|
1364
|
-
return await self._workflow_builder.get_auth_provider(auth_provider_name)
|
|
1365
|
-
|
|
1366
|
-
@override
|
|
1367
|
-
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1368
|
-
llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
|
|
1369
|
-
|
|
1370
|
-
self._dependencies.add_llm(llm_name)
|
|
1371
|
-
|
|
1372
|
-
return llm
|
|
1373
|
-
|
|
1374
|
-
@override
|
|
1375
|
-
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
|
|
1376
|
-
return self._workflow_builder.get_llm_config(llm_name)
|
|
1377
|
-
|
|
1378
|
-
@override
|
|
1379
|
-
async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
|
|
1380
|
-
await self._workflow_builder.add_embedder(name, config)
|
|
1381
|
-
|
|
1382
|
-
@override
|
|
1383
|
-
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1384
|
-
embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
|
|
1385
|
-
|
|
1386
|
-
self._dependencies.add_embedder(embedder_name)
|
|
1387
|
-
|
|
1388
|
-
return embedder
|
|
1389
|
-
|
|
1390
|
-
@override
|
|
1391
|
-
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
|
|
1392
|
-
return self._workflow_builder.get_embedder_config(embedder_name)
|
|
1393
|
-
|
|
1394
|
-
@override
|
|
1395
|
-
async def add_memory_client(self, name: str, config: MemoryBaseConfig) -> MemoryEditor:
|
|
1396
|
-
return await self._workflow_builder.add_memory_client(name, config)
|
|
1397
|
-
|
|
1398
|
-
@override
|
|
1399
|
-
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1521
|
+
def _validate_dependencies(self, config: Config):
|
|
1400
1522
|
"""
|
|
1401
|
-
|
|
1402
|
-
"""
|
|
1403
|
-
memory_client = await self._workflow_builder.get_memory_client(memory_name)
|
|
1404
|
-
|
|
1405
|
-
self._dependencies.add_memory_client(memory_name)
|
|
1523
|
+
Validate no shared component has dependencies on any per-user components.
|
|
1406
1524
|
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
@override
|
|
1410
|
-
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
|
|
1411
|
-
return self._workflow_builder.get_memory_client_config(memory_name=memory_name)
|
|
1412
|
-
|
|
1413
|
-
@override
|
|
1414
|
-
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
|
|
1415
|
-
return await self._workflow_builder.add_object_store(name, config)
|
|
1416
|
-
|
|
1417
|
-
@override
|
|
1418
|
-
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
|
|
1525
|
+
This prevents invalid configurations where shared components try to use per-user functions that do not exist
|
|
1526
|
+
at shared builder initialization time.
|
|
1419
1527
|
"""
|
|
1420
|
-
Return the instantiated object store client for the given name.
|
|
1421
|
-
"""
|
|
1422
|
-
object_store_client = await self._workflow_builder.get_object_store_client(object_store_name)
|
|
1423
|
-
|
|
1424
|
-
self._dependencies.add_object_store(object_store_name)
|
|
1425
|
-
|
|
1426
|
-
return object_store_client
|
|
1427
|
-
|
|
1428
|
-
@override
|
|
1429
|
-
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
|
|
1430
|
-
return self._workflow_builder.get_object_store_config(object_store_name)
|
|
1431
|
-
|
|
1432
|
-
@override
|
|
1433
|
-
@experimental(feature_name="TTC")
|
|
1434
|
-
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
|
1435
|
-
await self._workflow_builder.add_ttc_strategy(name, config)
|
|
1436
1528
|
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1529
|
+
# Check shared functions do not depend on per-user functions or function_groups
|
|
1530
|
+
for fn_name, fn_deps in self.function_dependencies.items():
|
|
1531
|
+
if fn_name == WORKFLOW_COMPONENT_NAME:
|
|
1532
|
+
continue
|
|
1533
|
+
|
|
1534
|
+
fn_config = self.get_function_config(fn_name)
|
|
1535
|
+
fn_registration = self._registry.get_function(type(fn_config))
|
|
1536
|
+
|
|
1537
|
+
if not fn_registration.is_per_user:
|
|
1538
|
+
for dep_fn_name in fn_deps.functions:
|
|
1539
|
+
dep_config = config.functions.get(dep_fn_name)
|
|
1540
|
+
if dep_config is not None:
|
|
1541
|
+
dep_registration = self._registry.get_function(type(dep_config))
|
|
1542
|
+
if dep_registration.is_per_user:
|
|
1543
|
+
raise ValueError(f"Function `{fn_name}` depends on per-user function `{dep_fn_name}`")
|
|
1544
|
+
|
|
1545
|
+
for dep_fg_name in fn_deps.function_groups:
|
|
1546
|
+
dep_config = config.function_groups.get(dep_fg_name)
|
|
1547
|
+
if dep_config is not None:
|
|
1548
|
+
dep_registration = self._registry.get_function_group(type(dep_config))
|
|
1549
|
+
if dep_registration.is_per_user:
|
|
1550
|
+
raise ValueError(f"Function `{fn_name}` depends on per-user function_group `{dep_fg_name}`")
|
|
1551
|
+
|
|
1552
|
+
# Check shared function_groups do not depend on per-user functions or function_groups
|
|
1553
|
+
for fg_name, fg_deps in self.function_group_dependencies.items():
|
|
1554
|
+
fg_config = self.get_function_group_config(fg_name)
|
|
1555
|
+
fg_registration = self._registry.get_function_group(type(fg_config))
|
|
1556
|
+
|
|
1557
|
+
if not fg_registration.is_per_user:
|
|
1558
|
+
for dep_fn_name in fg_deps.functions:
|
|
1559
|
+
dep_config = config.functions.get(dep_fn_name)
|
|
1560
|
+
if dep_config is not None:
|
|
1561
|
+
dep_registration = self._registry.get_function(type(dep_config))
|
|
1562
|
+
if dep_registration.is_per_user:
|
|
1563
|
+
raise ValueError(f"FunctionGroup `{fg_name}` depends on per-user function `{dep_fn_name}`")
|
|
1564
|
+
|
|
1565
|
+
for dep_fg_name in fg_deps.function_groups:
|
|
1566
|
+
dep_config = config.function_groups.get(dep_fg_name)
|
|
1567
|
+
if dep_config is not None:
|
|
1568
|
+
dep_registration = self._registry.get_function_group(type(dep_config))
|
|
1569
|
+
if dep_registration.is_per_user:
|
|
1570
|
+
raise ValueError(
|
|
1571
|
+
f"FunctionGroup `{fg_name}` depends on per-user function_group `{dep_fg_name}`")
|
|
1454
1572
|
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
@override
|
|
1460
|
-
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
|
|
1461
|
-
if not wrapper_type:
|
|
1462
|
-
return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
|
|
1463
|
-
return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
|
|
1464
|
-
|
|
1465
|
-
@override
|
|
1466
|
-
async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
|
|
1467
|
-
return await self._workflow_builder.get_retriever_config(retriever_name=retriever_name)
|
|
1468
|
-
|
|
1469
|
-
@override
|
|
1470
|
-
def get_user_manager(self) -> UserManagerHolder:
|
|
1471
|
-
return self._workflow_builder.get_user_manager()
|
|
1573
|
+
if self._workflow is not None:
|
|
1574
|
+
workflow_config = self.get_workflow_config()
|
|
1575
|
+
workflow_registration = self._registry.get_function(type(workflow_config))
|
|
1472
1576
|
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1577
|
+
# Per-user workflow must be owned by PerUserWorkflowBuilder
|
|
1578
|
+
if workflow_registration.is_per_user:
|
|
1579
|
+
raise ValueError("Workflow is a per-user function, but it is owned by a shared WorkflowBuilder")
|
|
1476
1580
|
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1581
|
+
else:
|
|
1582
|
+
workflow_deps = self.function_dependencies.get(WORKFLOW_COMPONENT_NAME, FunctionDependencies())
|
|
1583
|
+
|
|
1584
|
+
for dep_fn_name in workflow_deps.functions:
|
|
1585
|
+
if dep_fn_name in config.functions:
|
|
1586
|
+
dep_config = config.functions[dep_fn_name]
|
|
1587
|
+
if dep_config is not None:
|
|
1588
|
+
dep_registration = self._registry.get_function(type(dep_config))
|
|
1589
|
+
if dep_registration.is_per_user:
|
|
1590
|
+
raise ValueError(f"Shared Workflow depends on per-user function `{dep_fn_name}`")
|
|
1591
|
+
|
|
1592
|
+
for dep_fg_name in workflow_deps.function_groups:
|
|
1593
|
+
if dep_fg_name in config.function_groups:
|
|
1594
|
+
dep_config = config.function_groups[dep_fg_name]
|
|
1595
|
+
if dep_config is not None:
|
|
1596
|
+
dep_registration = self._registry.get_function_group(type(dep_config))
|
|
1597
|
+
if dep_registration.is_per_user:
|
|
1598
|
+
raise ValueError(f"Shared Workflow depends on per-user function_group `{dep_fg_name}`")
|
|
1485
1599
|
|
|
1486
|
-
@
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
return await self._workflow_builder.get_middleware(middleware_name)
|
|
1600
|
+
@classmethod
|
|
1601
|
+
@asynccontextmanager
|
|
1602
|
+
async def from_config(cls, config: Config):
|
|
1490
1603
|
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
return self._workflow_builder.get_middleware_config(middleware_name)
|
|
1604
|
+
async with cls(general_config=config.general) as builder:
|
|
1605
|
+
await builder.populate_builder(config)
|
|
1606
|
+
yield builder
|