nvidia-nat 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +27 -18
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +81 -50
- nat/agent/react_agent/register.py +59 -40
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +327 -149
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +64 -46
- nat/agent/tool_calling_agent/agent.py +152 -29
- nat/agent/tool_calling_agent/register.py +61 -38
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +10 -6
- nat/builder/context.py +70 -18
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/intermediate_step_manager.py +6 -2
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +327 -79
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +105 -19
- nat/cli/entrypoint.py +17 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +79 -10
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +196 -67
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +42 -18
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +2 -3
- nat/embedder/register.py +1 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +9 -6
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +74 -50
- nat/front_ends/fastapi/message_validator.py +20 -21
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +47 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +5 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +22 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +14 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +105 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +9 -5
- nvidia_nat-1.3.0.dist-info/METADATA +195 -0
- {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
- {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- nvidia_nat-1.2.1.dist-info/METADATA +0 -365
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -13,13 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import inspect
|
|
18
19
|
import logging
|
|
20
|
+
import typing
|
|
19
21
|
import warnings
|
|
22
|
+
from collections.abc import Sequence
|
|
20
23
|
from contextlib import AbstractAsyncContextManager
|
|
21
24
|
from contextlib import AsyncExitStack
|
|
22
25
|
from contextlib import asynccontextmanager
|
|
26
|
+
from typing import cast
|
|
23
27
|
|
|
24
28
|
from nat.authentication.interfaces import AuthProviderBase
|
|
25
29
|
from nat.builder.builder import Builder
|
|
@@ -31,6 +35,7 @@ from nat.builder.context import ContextState
|
|
|
31
35
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
32
36
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
33
37
|
from nat.builder.function import Function
|
|
38
|
+
from nat.builder.function import FunctionGroup
|
|
34
39
|
from nat.builder.function import LambdaFunction
|
|
35
40
|
from nat.builder.function_info import FunctionInfo
|
|
36
41
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -42,6 +47,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
|
|
|
42
47
|
from nat.data_models.component import ComponentGroup
|
|
43
48
|
from nat.data_models.component_ref import AuthenticationRef
|
|
44
49
|
from nat.data_models.component_ref import EmbedderRef
|
|
50
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
45
51
|
from nat.data_models.component_ref import FunctionRef
|
|
46
52
|
from nat.data_models.component_ref import LLMRef
|
|
47
53
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -52,6 +58,7 @@ from nat.data_models.config import Config
|
|
|
52
58
|
from nat.data_models.config import GeneralConfig
|
|
53
59
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
54
60
|
from nat.data_models.function import FunctionBaseConfig
|
|
61
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
55
62
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
56
63
|
from nat.data_models.llm import LLMBaseConfig
|
|
57
64
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -68,6 +75,7 @@ from nat.object_store.interfaces import ObjectStore
|
|
|
68
75
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
69
76
|
from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
70
77
|
from nat.profiler.utils import detect_llm_frameworks_in_build_fn
|
|
78
|
+
from nat.retriever.interface import Retriever
|
|
71
79
|
from nat.utils.type_utils import override
|
|
72
80
|
|
|
73
81
|
logger = logging.getLogger(__name__)
|
|
@@ -85,6 +93,12 @@ class ConfiguredFunction:
|
|
|
85
93
|
instance: Function
|
|
86
94
|
|
|
87
95
|
|
|
96
|
+
@dataclasses.dataclass
|
|
97
|
+
class ConfiguredFunctionGroup:
|
|
98
|
+
config: FunctionGroupBaseConfig
|
|
99
|
+
instance: FunctionGroup
|
|
100
|
+
|
|
101
|
+
|
|
88
102
|
@dataclasses.dataclass
|
|
89
103
|
class ConfiguredLLM:
|
|
90
104
|
config: LLMBaseConfig
|
|
@@ -127,7 +141,6 @@ class ConfiguredTTCStrategy:
|
|
|
127
141
|
instance: StrategyBase
|
|
128
142
|
|
|
129
143
|
|
|
130
|
-
# pylint: disable=too-many-public-methods
|
|
131
144
|
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
132
145
|
|
|
133
146
|
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
|
|
@@ -143,9 +156,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
143
156
|
self._registry = registry
|
|
144
157
|
|
|
145
158
|
self._logging_handlers: dict[str, logging.Handler] = {}
|
|
159
|
+
self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
|
|
146
160
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
147
161
|
|
|
148
162
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
163
|
+
self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
|
|
149
164
|
self._workflow: ConfiguredFunction | None = None
|
|
150
165
|
|
|
151
166
|
self._llms: dict[str, ConfiguredLLM] = {}
|
|
@@ -162,7 +177,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
162
177
|
|
|
163
178
|
# Create a mapping to track function name -> other function names it depends on
|
|
164
179
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
180
|
+
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
165
181
|
self.current_function_building: str | None = None
|
|
182
|
+
self.current_function_group_building: str | None = None
|
|
166
183
|
|
|
167
184
|
async def __aenter__(self):
|
|
168
185
|
|
|
@@ -171,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
171
188
|
# Get the telemetry info from the config
|
|
172
189
|
telemetry_config = self.general_config.telemetry
|
|
173
190
|
|
|
191
|
+
# If we have logging configuration, we need to manage the root logger properly
|
|
192
|
+
root_logger = logging.getLogger()
|
|
193
|
+
|
|
194
|
+
# Collect configured handler types to determine if we need to adjust existing handlers
|
|
195
|
+
# This is somewhat of a hack by inspecting the class name of the config object
|
|
196
|
+
has_console_handler = any(
|
|
197
|
+
hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
|
|
198
|
+
for config in telemetry_config.logging.values())
|
|
199
|
+
|
|
174
200
|
for key, logging_config in telemetry_config.logging.items():
|
|
175
201
|
# Use the same pattern as tracing, but for logging
|
|
176
202
|
logging_info = self._registry.get_logging_method(type(logging_config))
|
|
@@ -184,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
184
210
|
self._logging_handlers[key] = handler
|
|
185
211
|
|
|
186
212
|
# Now attach to NAT's root logger
|
|
187
|
-
|
|
213
|
+
root_logger.addHandler(handler)
|
|
214
|
+
|
|
215
|
+
# If we added logging handlers, manage existing handlers appropriately
|
|
216
|
+
if self._logging_handlers:
|
|
217
|
+
min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
|
|
218
|
+
|
|
219
|
+
# Ensure the root logger level allows messages through
|
|
220
|
+
root_logger.level = max(root_logger.level, min_handler_level)
|
|
221
|
+
|
|
222
|
+
# If a console handler is configured, adjust or remove default CLI handlers
|
|
223
|
+
# to avoid duplicate output while preserving workflow visibility
|
|
224
|
+
if has_console_handler:
|
|
225
|
+
# Remove existing StreamHandlers that are not the newly configured ones
|
|
226
|
+
for handler in root_logger.handlers[:]:
|
|
227
|
+
if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
|
|
228
|
+
self._removed_root_handlers.append((handler, handler.level))
|
|
229
|
+
root_logger.removeHandler(handler)
|
|
230
|
+
else:
|
|
231
|
+
# No console handler configured, but adjust existing handler levels
|
|
232
|
+
# to respect the minimum configured level for file/other handlers
|
|
233
|
+
for handler in root_logger.handlers[:]:
|
|
234
|
+
if type(handler) is logging.StreamHandler:
|
|
235
|
+
old_level = handler.level
|
|
236
|
+
handler.setLevel(min_handler_level)
|
|
237
|
+
self._removed_root_handlers.append((handler, old_level))
|
|
188
238
|
|
|
189
239
|
# Add the telemetry exporters
|
|
190
240
|
for key, telemetry_exporter_config in telemetry_config.tracing.items():
|
|
@@ -196,12 +246,21 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
196
246
|
|
|
197
247
|
assert self._exit_stack is not None, "Exit stack not initialized"
|
|
198
248
|
|
|
199
|
-
|
|
200
|
-
|
|
249
|
+
root_logger = logging.getLogger()
|
|
250
|
+
|
|
251
|
+
# Remove custom logging handlers
|
|
252
|
+
for handler in self._logging_handlers.values():
|
|
253
|
+
root_logger.removeHandler(handler)
|
|
254
|
+
|
|
255
|
+
# Restore original handlers and their levels
|
|
256
|
+
for handler, old_level in self._removed_root_handlers:
|
|
257
|
+
if handler not in root_logger.handlers:
|
|
258
|
+
root_logger.addHandler(handler)
|
|
259
|
+
handler.setLevel(old_level)
|
|
201
260
|
|
|
202
261
|
await self._exit_stack.__aexit__(*exc_details)
|
|
203
262
|
|
|
204
|
-
def build(self, entry_function: str | None = None) -> Workflow:
|
|
263
|
+
async def build(self, entry_function: str | None = None) -> Workflow:
|
|
205
264
|
"""
|
|
206
265
|
Creates an instance of a workflow object using the added components and the desired entry function.
|
|
207
266
|
|
|
@@ -225,12 +284,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
225
284
|
if (self._workflow is None):
|
|
226
285
|
raise ValueError("Must set a workflow before building")
|
|
227
286
|
|
|
287
|
+
# Set of all functions which are "included" by function groups
|
|
288
|
+
included_functions = set()
|
|
289
|
+
# Dictionary of function configs
|
|
290
|
+
function_configs = dict()
|
|
291
|
+
# Dictionary of function group configs
|
|
292
|
+
function_group_configs = dict()
|
|
293
|
+
# Dictionary of function instances
|
|
294
|
+
function_instances = dict()
|
|
295
|
+
# Dictionary of function group instances
|
|
296
|
+
function_group_instances = dict()
|
|
297
|
+
|
|
298
|
+
for k, v in self._function_groups.items():
|
|
299
|
+
included_functions.update((await v.instance.get_included_functions()).keys())
|
|
300
|
+
function_group_configs[k] = v.config
|
|
301
|
+
function_group_instances[k] = v.instance
|
|
302
|
+
|
|
303
|
+
# Function configs need to be restricted to only the functions that are not in a function group
|
|
304
|
+
for k, v in self._functions.items():
|
|
305
|
+
if k not in included_functions:
|
|
306
|
+
function_configs[k] = v.config
|
|
307
|
+
function_instances[k] = v.instance
|
|
308
|
+
|
|
228
309
|
# Build the config from the added objects
|
|
229
310
|
config = Config(general=self.general_config,
|
|
230
|
-
functions=
|
|
231
|
-
|
|
232
|
-
for k, v in self._functions.items()
|
|
233
|
-
},
|
|
311
|
+
functions=function_configs,
|
|
312
|
+
function_groups=function_group_configs,
|
|
234
313
|
workflow=self._workflow.config,
|
|
235
314
|
llms={
|
|
236
315
|
k: v.config
|
|
@@ -260,14 +339,12 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
260
339
|
if (entry_function is None):
|
|
261
340
|
entry_fn_obj = self.get_workflow()
|
|
262
341
|
else:
|
|
263
|
-
entry_fn_obj = self.get_function(entry_function)
|
|
342
|
+
entry_fn_obj = await self.get_function(entry_function)
|
|
264
343
|
|
|
265
344
|
workflow = Workflow.from_entry_fn(config=config,
|
|
266
345
|
entry_fn=entry_fn_obj,
|
|
267
|
-
functions=
|
|
268
|
-
|
|
269
|
-
for k, v in self._functions.items()
|
|
270
|
-
},
|
|
346
|
+
functions=function_instances,
|
|
347
|
+
function_groups=function_group_instances,
|
|
271
348
|
llms={
|
|
272
349
|
k: v.instance
|
|
273
350
|
for k, v in self._llms.items()
|
|
@@ -348,11 +425,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
348
425
|
|
|
349
426
|
return ConfiguredFunction(config=config, instance=build_result)
|
|
350
427
|
|
|
428
|
+
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
429
|
+
"""Build a function group from the provided configuration.
|
|
430
|
+
|
|
431
|
+
Args:
|
|
432
|
+
name: The name of the function group
|
|
433
|
+
config: The function group configuration
|
|
434
|
+
|
|
435
|
+
Returns:
|
|
436
|
+
ConfiguredFunctionGroup: The built function group
|
|
437
|
+
|
|
438
|
+
Raises:
|
|
439
|
+
ValueError: If the function group builder returns invalid results
|
|
440
|
+
"""
|
|
441
|
+
registration = self._registry.get_function_group(type(config))
|
|
442
|
+
|
|
443
|
+
inner_builder = ChildBuilder(self)
|
|
444
|
+
|
|
445
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
446
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
447
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
448
|
+
|
|
449
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
450
|
+
|
|
451
|
+
# Set the currently building function group so the ChildBuilder can track dependencies
|
|
452
|
+
self.current_function_group_building = config.type
|
|
453
|
+
# Empty set of dependencies for the current function group
|
|
454
|
+
self.function_group_dependencies[config.type] = FunctionDependencies()
|
|
455
|
+
|
|
456
|
+
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
457
|
+
|
|
458
|
+
self.function_group_dependencies[name] = inner_builder.dependencies
|
|
459
|
+
|
|
460
|
+
if not isinstance(build_result, FunctionGroup):
|
|
461
|
+
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
462
|
+
f"Got {type(build_result)}")
|
|
463
|
+
|
|
464
|
+
# set the instance name for the function group based on the workflow-provided name
|
|
465
|
+
build_result.set_instance_name(name)
|
|
466
|
+
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
467
|
+
|
|
351
468
|
@override
|
|
352
469
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
470
|
+
if isinstance(name, FunctionRef):
|
|
471
|
+
name = str(name)
|
|
353
472
|
|
|
354
|
-
if (name in self._functions):
|
|
355
|
-
raise ValueError(f"Function `{name}` already exists in the list of functions")
|
|
473
|
+
if (name in self._functions or name in self._function_groups):
|
|
474
|
+
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
356
475
|
|
|
357
476
|
build_result = await self._build_function(name=name, config=config)
|
|
358
477
|
|
|
@@ -361,20 +480,67 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
361
480
|
return build_result.instance
|
|
362
481
|
|
|
363
482
|
@override
|
|
364
|
-
def
|
|
483
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
484
|
+
if isinstance(name, FunctionGroupRef):
|
|
485
|
+
name = str(name)
|
|
486
|
+
|
|
487
|
+
if (name in self._function_groups or name in self._functions):
|
|
488
|
+
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
489
|
+
|
|
490
|
+
# Build the function group
|
|
491
|
+
build_result = await self._build_function_group(name=name, config=config)
|
|
492
|
+
|
|
493
|
+
self._function_groups[name] = build_result
|
|
494
|
+
|
|
495
|
+
# If the function group exposes functions, add them to the global function registry
|
|
496
|
+
# If the function group exposes functions, record and add them to the registry
|
|
497
|
+
included_functions = await build_result.instance.get_included_functions()
|
|
498
|
+
for k in included_functions:
|
|
499
|
+
if k in self._functions:
|
|
500
|
+
raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
|
|
501
|
+
self._functions.update({
|
|
502
|
+
k: ConfiguredFunction(config=v.config, instance=v)
|
|
503
|
+
for k, v in included_functions.items()
|
|
504
|
+
})
|
|
365
505
|
|
|
506
|
+
return build_result.instance
|
|
507
|
+
|
|
508
|
+
@override
|
|
509
|
+
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
510
|
+
if isinstance(name, FunctionRef):
|
|
511
|
+
name = str(name)
|
|
366
512
|
if name not in self._functions:
|
|
367
513
|
raise ValueError(f"Function `{name}` not found")
|
|
368
514
|
|
|
369
515
|
return self._functions[name].instance
|
|
370
516
|
|
|
517
|
+
@override
|
|
518
|
+
async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
519
|
+
if isinstance(name, FunctionGroupRef):
|
|
520
|
+
name = str(name)
|
|
521
|
+
if name not in self._function_groups:
|
|
522
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
523
|
+
|
|
524
|
+
return self._function_groups[name].instance
|
|
525
|
+
|
|
371
526
|
@override
|
|
372
527
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
528
|
+
if isinstance(name, FunctionRef):
|
|
529
|
+
name = str(name)
|
|
373
530
|
if name not in self._functions:
|
|
374
531
|
raise ValueError(f"Function `{name}` not found")
|
|
375
532
|
|
|
376
533
|
return self._functions[name].config
|
|
377
534
|
|
|
535
|
+
@override
|
|
536
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
537
|
+
if isinstance(name, FunctionGroupRef):
|
|
538
|
+
name = str(name)
|
|
539
|
+
if name not in self._function_groups:
|
|
540
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
541
|
+
|
|
542
|
+
return self._function_groups[name].config
|
|
543
|
+
|
|
378
544
|
@override
|
|
379
545
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
380
546
|
|
|
@@ -404,16 +570,57 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
404
570
|
|
|
405
571
|
@override
|
|
406
572
|
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
|
|
573
|
+
if isinstance(fn_name, FunctionRef):
|
|
574
|
+
fn_name = str(fn_name)
|
|
407
575
|
return self.function_dependencies[fn_name]
|
|
408
576
|
|
|
409
577
|
@override
|
|
410
|
-
def
|
|
578
|
+
def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
|
|
579
|
+
if isinstance(fn_name, FunctionGroupRef):
|
|
580
|
+
fn_name = str(fn_name)
|
|
581
|
+
return self.function_group_dependencies[fn_name]
|
|
411
582
|
|
|
583
|
+
@override
|
|
584
|
+
async def get_tools(self,
|
|
585
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
586
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
587
|
+
|
|
588
|
+
unique = set(tool_names)
|
|
589
|
+
if len(unique) != len(tool_names):
|
|
590
|
+
raise ValueError("Tool names must be unique")
|
|
591
|
+
|
|
592
|
+
async def _get_tools(n: str | FunctionRef | FunctionGroupRef):
|
|
593
|
+
tools = []
|
|
594
|
+
is_function_group_ref = isinstance(n, FunctionGroupRef)
|
|
595
|
+
if isinstance(n, FunctionRef) or is_function_group_ref:
|
|
596
|
+
n = str(n)
|
|
597
|
+
if n not in self._function_groups:
|
|
598
|
+
# the passed tool name is probably a function, but first check if it's a function group
|
|
599
|
+
if is_function_group_ref:
|
|
600
|
+
raise ValueError(f"Function group `{n}` not found in the list of function groups")
|
|
601
|
+
tools.append(await self.get_tool(n, wrapper_type))
|
|
602
|
+
else:
|
|
603
|
+
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
604
|
+
current_function_group = self._function_groups[n]
|
|
605
|
+
for fn_name, fn_instance in (await current_function_group.instance.get_accessible_functions()).items():
|
|
606
|
+
try:
|
|
607
|
+
tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
|
|
608
|
+
except Exception:
|
|
609
|
+
logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
|
|
610
|
+
raise
|
|
611
|
+
return tools
|
|
612
|
+
|
|
613
|
+
tool_lists = await asyncio.gather(*[_get_tools(n) for n in tool_names])
|
|
614
|
+
# Flatten the list of lists into a single list
|
|
615
|
+
return [tool for tools in tool_lists for tool in tools]
|
|
616
|
+
|
|
617
|
+
@override
|
|
618
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
619
|
+
if isinstance(fn_name, FunctionRef):
|
|
620
|
+
fn_name = str(fn_name)
|
|
412
621
|
if fn_name not in self._functions:
|
|
413
622
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
414
|
-
|
|
415
623
|
fn = self._functions[fn_name]
|
|
416
|
-
|
|
417
624
|
try:
|
|
418
625
|
# Using the registry, get the tool wrapper for the requested framework
|
|
419
626
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
@@ -421,11 +628,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
421
628
|
# Wrap in the correct wrapper
|
|
422
629
|
return tool_wrapper_reg.build_fn(fn_name, fn.instance, self)
|
|
423
630
|
except Exception as e:
|
|
424
|
-
logger.error("Error fetching tool `%s
|
|
425
|
-
raise
|
|
631
|
+
logger.error("Error fetching tool `%s`: %s", fn_name, e)
|
|
632
|
+
raise
|
|
426
633
|
|
|
427
634
|
@override
|
|
428
|
-
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
635
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> None:
|
|
429
636
|
|
|
430
637
|
if (name in self._llms):
|
|
431
638
|
raise ValueError(f"LLM `{name}` already exists in the list of LLMs")
|
|
@@ -437,11 +644,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
437
644
|
|
|
438
645
|
self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
|
|
439
646
|
except Exception as e:
|
|
440
|
-
logger.error("Error adding llm `%s` with config `%s
|
|
441
|
-
raise
|
|
647
|
+
logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
|
|
648
|
+
raise
|
|
442
649
|
|
|
443
650
|
@override
|
|
444
|
-
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
|
|
651
|
+
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
445
652
|
|
|
446
653
|
if (llm_name not in self._llms):
|
|
447
654
|
raise ValueError(f"LLM `{llm_name}` not found")
|
|
@@ -458,8 +665,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
458
665
|
# Return a frameworks specific client
|
|
459
666
|
return client
|
|
460
667
|
except Exception as e:
|
|
461
|
-
logger.error("Error getting llm `%s` with wrapper `%s
|
|
462
|
-
raise
|
|
668
|
+
logger.error("Error getting llm `%s` with wrapper `%s`: %s", llm_name, wrapper_type, e)
|
|
669
|
+
raise
|
|
463
670
|
|
|
464
671
|
@override
|
|
465
672
|
def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
|
|
@@ -509,8 +716,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
509
716
|
|
|
510
717
|
return info_obj
|
|
511
718
|
except Exception as e:
|
|
512
|
-
logger.error("Error adding authentication `%s` with config `%s
|
|
513
|
-
raise
|
|
719
|
+
logger.error("Error adding authentication `%s` with config `%s`: %s", name, config, e)
|
|
720
|
+
raise
|
|
514
721
|
|
|
515
722
|
@override
|
|
516
723
|
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
|
|
@@ -541,7 +748,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
541
748
|
return self._auth_providers[auth_provider_name].instance
|
|
542
749
|
|
|
543
750
|
@override
|
|
544
|
-
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
751
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
|
|
545
752
|
|
|
546
753
|
if (name in self._embedders):
|
|
547
754
|
raise ValueError(f"Embedder `{name}` already exists in the list of embedders")
|
|
@@ -553,9 +760,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
553
760
|
|
|
554
761
|
self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
|
|
555
762
|
except Exception as e:
|
|
556
|
-
logger.error("Error adding embedder `%s` with config `%s
|
|
557
|
-
|
|
558
|
-
raise e
|
|
763
|
+
logger.error("Error adding embedder `%s` with config `%s`: %s", name, config, e)
|
|
764
|
+
raise
|
|
559
765
|
|
|
560
766
|
@override
|
|
561
767
|
async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str):
|
|
@@ -575,8 +781,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
575
781
|
# Return a frameworks specific client
|
|
576
782
|
return client
|
|
577
783
|
except Exception as e:
|
|
578
|
-
logger.error("Error getting embedder `%s` with wrapper `%s
|
|
579
|
-
raise
|
|
784
|
+
logger.error("Error getting embedder `%s` with wrapper `%s`: %s", embedder_name, wrapper_type, e)
|
|
785
|
+
raise
|
|
580
786
|
|
|
581
787
|
@override
|
|
582
788
|
def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
|
|
@@ -602,7 +808,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
602
808
|
return info_obj
|
|
603
809
|
|
|
604
810
|
@override
|
|
605
|
-
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
811
|
+
async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
606
812
|
"""
|
|
607
813
|
Return the instantiated memory client for the given name.
|
|
608
814
|
"""
|
|
@@ -648,7 +854,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
648
854
|
return self._object_stores[object_store_name].config
|
|
649
855
|
|
|
650
856
|
@override
|
|
651
|
-
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
857
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
|
|
652
858
|
|
|
653
859
|
if (name in self._retrievers):
|
|
654
860
|
raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")
|
|
@@ -661,11 +867,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
661
867
|
self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
|
|
662
868
|
|
|
663
869
|
except Exception as e:
|
|
664
|
-
logger.error("Error adding retriever `%s` with config `%s
|
|
665
|
-
|
|
666
|
-
raise e
|
|
667
|
-
|
|
668
|
-
# return info_obj
|
|
870
|
+
logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
|
|
871
|
+
raise
|
|
669
872
|
|
|
670
873
|
@override
|
|
671
874
|
async def get_retriever(self,
|
|
@@ -688,8 +891,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
688
891
|
# Return a frameworks specific client
|
|
689
892
|
return client
|
|
690
893
|
except Exception as e:
|
|
691
|
-
logger.error("Error getting retriever `%s` with wrapper `%s
|
|
692
|
-
raise
|
|
894
|
+
logger.error("Error getting retriever `%s` with wrapper `%s`: %s", retriever_name, wrapper_type, e)
|
|
895
|
+
raise
|
|
693
896
|
|
|
694
897
|
@override
|
|
695
898
|
async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
|
|
@@ -699,9 +902,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
699
902
|
|
|
700
903
|
return self._retrievers[retriever_name].config
|
|
701
904
|
|
|
702
|
-
@experimental(feature_name="TTC")
|
|
703
905
|
@override
|
|
704
|
-
|
|
906
|
+
@experimental(feature_name="TTC")
|
|
907
|
+
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
|
|
705
908
|
if (name in self._ttc_strategies):
|
|
706
909
|
raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")
|
|
707
910
|
|
|
@@ -713,9 +916,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
713
916
|
self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
|
|
714
917
|
|
|
715
918
|
except Exception as e:
|
|
716
|
-
logger.error("Error adding TTC strategy `%s` with config `%s
|
|
717
|
-
|
|
718
|
-
raise e
|
|
919
|
+
logger.error("Error adding TTC strategy `%s` with config `%s`: %s", name, config, e)
|
|
920
|
+
raise
|
|
719
921
|
|
|
720
922
|
@override
|
|
721
923
|
async def get_ttc_strategy(self,
|
|
@@ -743,8 +945,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
743
945
|
|
|
744
946
|
return instance
|
|
745
947
|
except Exception as e:
|
|
746
|
-
logger.error("Error getting TTC strategy `%s
|
|
747
|
-
raise
|
|
948
|
+
logger.error("Error getting TTC strategy `%s`: %s", strategy_name, e)
|
|
949
|
+
raise
|
|
748
950
|
|
|
749
951
|
@override
|
|
750
952
|
async def get_ttc_strategy_config(self,
|
|
@@ -821,7 +1023,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
821
1023
|
else:
|
|
822
1024
|
logger.error("No remaining components to build")
|
|
823
1025
|
|
|
824
|
-
logger.error("Original error:", exc_info=
|
|
1026
|
+
logger.error("Original error: %s", original_error, exc_info=True)
|
|
825
1027
|
|
|
826
1028
|
def _log_build_failure_component(self,
|
|
827
1029
|
failing_component: ComponentInstanceData,
|
|
@@ -889,29 +1091,40 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
889
1091
|
|
|
890
1092
|
# Instantiate a the llm
|
|
891
1093
|
if component_instance.component_group == ComponentGroup.LLMS:
|
|
892
|
-
await self.add_llm(component_instance.name, component_instance.config)
|
|
1094
|
+
await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
|
|
893
1095
|
# Instantiate a the embedder
|
|
894
1096
|
elif component_instance.component_group == ComponentGroup.EMBEDDERS:
|
|
895
|
-
await self.add_embedder(component_instance.name,
|
|
1097
|
+
await self.add_embedder(component_instance.name,
|
|
1098
|
+
cast(EmbedderBaseConfig, component_instance.config))
|
|
896
1099
|
# Instantiate a memory client
|
|
897
1100
|
elif component_instance.component_group == ComponentGroup.MEMORY:
|
|
898
|
-
await self.add_memory_client(component_instance.name,
|
|
899
|
-
|
|
1101
|
+
await self.add_memory_client(component_instance.name,
|
|
1102
|
+
cast(MemoryBaseConfig, component_instance.config))
|
|
1103
|
+
# Instantiate a object store client
|
|
900
1104
|
elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
|
|
901
|
-
await self.add_object_store(component_instance.name,
|
|
1105
|
+
await self.add_object_store(component_instance.name,
|
|
1106
|
+
cast(ObjectStoreBaseConfig, component_instance.config))
|
|
902
1107
|
# Instantiate a retriever client
|
|
903
1108
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
904
|
-
await self.add_retriever(component_instance.name,
|
|
1109
|
+
await self.add_retriever(component_instance.name,
|
|
1110
|
+
cast(RetrieverBaseConfig, component_instance.config))
|
|
1111
|
+
# Instantiate a function group
|
|
1112
|
+
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1113
|
+
await self.add_function_group(component_instance.name,
|
|
1114
|
+
cast(FunctionGroupBaseConfig, component_instance.config))
|
|
905
1115
|
# Instantiate a function
|
|
906
1116
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
907
1117
|
# If the function is the root, set it as the workflow later
|
|
908
1118
|
if (not component_instance.is_root):
|
|
909
|
-
await self.add_function(component_instance.name,
|
|
1119
|
+
await self.add_function(component_instance.name,
|
|
1120
|
+
cast(FunctionBaseConfig, component_instance.config))
|
|
910
1121
|
elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
|
|
911
|
-
await self.add_ttc_strategy(component_instance.name,
|
|
1122
|
+
await self.add_ttc_strategy(component_instance.name,
|
|
1123
|
+
cast(TTCStrategyBaseConfig, component_instance.config))
|
|
912
1124
|
|
|
913
1125
|
elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
|
|
914
|
-
await self.add_auth_provider(component_instance.name,
|
|
1126
|
+
await self.add_auth_provider(component_instance.name,
|
|
1127
|
+
cast(AuthProviderBaseConfig, component_instance.config))
|
|
915
1128
|
else:
|
|
916
1129
|
raise ValueError(f"Unknown component group {component_instance.component_group}")
|
|
917
1130
|
|
|
@@ -961,18 +1174,35 @@ class ChildBuilder(Builder):
|
|
|
961
1174
|
return await self._workflow_builder.add_function(name, config)
|
|
962
1175
|
|
|
963
1176
|
@override
|
|
964
|
-
def
|
|
1177
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1178
|
+
return await self._workflow_builder.add_function_group(name, config)
|
|
1179
|
+
|
|
1180
|
+
@override
|
|
1181
|
+
async def get_function(self, name: str) -> Function:
|
|
965
1182
|
# If a function tries to get another function, we assume it uses it
|
|
966
|
-
fn = self._workflow_builder.get_function(name)
|
|
1183
|
+
fn = await self._workflow_builder.get_function(name)
|
|
967
1184
|
|
|
968
1185
|
self._dependencies.add_function(name)
|
|
969
1186
|
|
|
970
1187
|
return fn
|
|
971
1188
|
|
|
1189
|
+
@override
|
|
1190
|
+
async def get_function_group(self, name: str) -> FunctionGroup:
|
|
1191
|
+
# If a function tries to get a function group, we assume it uses it
|
|
1192
|
+
function_group = await self._workflow_builder.get_function_group(name)
|
|
1193
|
+
|
|
1194
|
+
self._dependencies.add_function_group(name)
|
|
1195
|
+
|
|
1196
|
+
return function_group
|
|
1197
|
+
|
|
972
1198
|
@override
|
|
973
1199
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
974
1200
|
return self._workflow_builder.get_function_config(name)
|
|
975
1201
|
|
|
1202
|
+
@override
|
|
1203
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1204
|
+
return self._workflow_builder.get_function_group_config(name)
|
|
1205
|
+
|
|
976
1206
|
@override
|
|
977
1207
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
978
1208
|
return await self._workflow_builder.set_workflow(config)
|
|
@@ -986,20 +1216,33 @@ class ChildBuilder(Builder):
|
|
|
986
1216
|
return self._workflow_builder.get_workflow_config()
|
|
987
1217
|
|
|
988
1218
|
@override
|
|
989
|
-
def
|
|
1219
|
+
async def get_tools(self,
|
|
1220
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1221
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1222
|
+
tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1223
|
+
for tool_name in tool_names:
|
|
1224
|
+
if tool_name in self._workflow_builder._function_groups:
|
|
1225
|
+
self._dependencies.add_function_group(tool_name)
|
|
1226
|
+
else:
|
|
1227
|
+
self._dependencies.add_function(tool_name)
|
|
1228
|
+
return tools
|
|
1229
|
+
|
|
1230
|
+
@override
|
|
1231
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
990
1232
|
# If a function tries to get another function as a tool, we assume it uses it
|
|
991
|
-
fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
1233
|
+
fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
992
1234
|
|
|
993
1235
|
self._dependencies.add_function(fn_name)
|
|
994
1236
|
|
|
995
1237
|
return fn
|
|
996
1238
|
|
|
997
1239
|
@override
|
|
998
|
-
async def add_llm(self, name: str, config: LLMBaseConfig):
|
|
1240
|
+
async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
|
|
999
1241
|
return await self._workflow_builder.add_llm(name, config)
|
|
1000
1242
|
|
|
1243
|
+
@experimental(feature_name="Authentication")
|
|
1001
1244
|
@override
|
|
1002
|
-
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig):
|
|
1245
|
+
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
1003
1246
|
return await self._workflow_builder.add_auth_provider(name, config)
|
|
1004
1247
|
|
|
1005
1248
|
@override
|
|
@@ -1007,7 +1250,7 @@ class ChildBuilder(Builder):
|
|
|
1007
1250
|
return await self._workflow_builder.get_auth_provider(auth_provider_name)
|
|
1008
1251
|
|
|
1009
1252
|
@override
|
|
1010
|
-
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str):
|
|
1253
|
+
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1011
1254
|
llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
|
|
1012
1255
|
|
|
1013
1256
|
self._dependencies.add_llm(llm_name)
|
|
@@ -1019,11 +1262,11 @@ class ChildBuilder(Builder):
|
|
|
1019
1262
|
return self._workflow_builder.get_llm_config(llm_name)
|
|
1020
1263
|
|
|
1021
1264
|
@override
|
|
1022
|
-
async def add_embedder(self, name: str, config: EmbedderBaseConfig):
|
|
1023
|
-
|
|
1265
|
+
async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
|
|
1266
|
+
await self._workflow_builder.add_embedder(name, config)
|
|
1024
1267
|
|
|
1025
1268
|
@override
|
|
1026
|
-
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str):
|
|
1269
|
+
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1027
1270
|
embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
|
|
1028
1271
|
|
|
1029
1272
|
self._dependencies.add_embedder(embedder_name)
|
|
@@ -1039,11 +1282,11 @@ class ChildBuilder(Builder):
|
|
|
1039
1282
|
return await self._workflow_builder.add_memory_client(name, config)
|
|
1040
1283
|
|
|
1041
1284
|
@override
|
|
1042
|
-
def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1285
|
+
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1043
1286
|
"""
|
|
1044
1287
|
Return the instantiated memory client for the given name.
|
|
1045
1288
|
"""
|
|
1046
|
-
memory_client = self._workflow_builder.get_memory_client(memory_name)
|
|
1289
|
+
memory_client = await self._workflow_builder.get_memory_client(memory_name)
|
|
1047
1290
|
|
|
1048
1291
|
self._dependencies.add_memory_client(memory_name)
|
|
1049
1292
|
|
|
@@ -1073,8 +1316,9 @@ class ChildBuilder(Builder):
|
|
|
1073
1316
|
return self._workflow_builder.get_object_store_config(object_store_name)
|
|
1074
1317
|
|
|
1075
1318
|
@override
|
|
1076
|
-
|
|
1077
|
-
|
|
1319
|
+
@experimental(feature_name="TTC")
|
|
1320
|
+
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
|
1321
|
+
await self._workflow_builder.add_ttc_strategy(name, config)
|
|
1078
1322
|
|
|
1079
1323
|
@override
|
|
1080
1324
|
async def get_ttc_strategy(self,
|
|
@@ -1095,11 +1339,11 @@ class ChildBuilder(Builder):
|
|
|
1095
1339
|
stage_type=stage_type)
|
|
1096
1340
|
|
|
1097
1341
|
@override
|
|
1098
|
-
async def add_retriever(self, name: str, config: RetrieverBaseConfig):
|
|
1099
|
-
|
|
1342
|
+
async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
|
|
1343
|
+
await self._workflow_builder.add_retriever(name, config)
|
|
1100
1344
|
|
|
1101
1345
|
@override
|
|
1102
|
-
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
1346
|
+
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
|
|
1103
1347
|
if not wrapper_type:
|
|
1104
1348
|
return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
|
|
1105
1349
|
return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
|
|
@@ -1115,3 +1359,7 @@ class ChildBuilder(Builder):
|
|
|
1115
1359
|
@override
|
|
1116
1360
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1117
1361
|
return self._workflow_builder.get_function_dependencies(fn_name)
|
|
1362
|
+
|
|
1363
|
+
@override
|
|
1364
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1365
|
+
return self._workflow_builder.get_function_group_dependencies(fn_name)
|