nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -13,13 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import inspect
|
|
18
19
|
import logging
|
|
20
|
+
import typing
|
|
19
21
|
import warnings
|
|
22
|
+
from collections.abc import Sequence
|
|
20
23
|
from contextlib import AbstractAsyncContextManager
|
|
21
24
|
from contextlib import AsyncExitStack
|
|
22
25
|
from contextlib import asynccontextmanager
|
|
26
|
+
from typing import cast
|
|
23
27
|
|
|
24
28
|
from nat.authentication.interfaces import AuthProviderBase
|
|
25
29
|
from nat.builder.builder import Builder
|
|
@@ -31,6 +35,7 @@ from nat.builder.context import ContextState
|
|
|
31
35
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
32
36
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
33
37
|
from nat.builder.function import Function
|
|
38
|
+
from nat.builder.function import FunctionGroup
|
|
34
39
|
from nat.builder.function import LambdaFunction
|
|
35
40
|
from nat.builder.function_info import FunctionInfo
|
|
36
41
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -42,6 +47,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
|
|
|
42
47
|
from nat.data_models.component import ComponentGroup
|
|
43
48
|
from nat.data_models.component_ref import AuthenticationRef
|
|
44
49
|
from nat.data_models.component_ref import EmbedderRef
|
|
50
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
45
51
|
from nat.data_models.component_ref import FunctionRef
|
|
46
52
|
from nat.data_models.component_ref import LLMRef
|
|
47
53
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -52,6 +58,7 @@ from nat.data_models.config import Config
|
|
|
52
58
|
from nat.data_models.config import GeneralConfig
|
|
53
59
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
54
60
|
from nat.data_models.function import FunctionBaseConfig
|
|
61
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
55
62
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
56
63
|
from nat.data_models.llm import LLMBaseConfig
|
|
57
64
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -68,6 +75,7 @@ from nat.object_store.interfaces import ObjectStore
|
|
|
68
75
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
69
76
|
from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
70
77
|
from nat.profiler.utils import detect_llm_frameworks_in_build_fn
|
|
78
|
+
from nat.retriever.interface import Retriever
|
|
71
79
|
from nat.utils.type_utils import override
|
|
72
80
|
|
|
73
81
|
logger = logging.getLogger(__name__)
|
|
@@ -85,6 +93,12 @@ class ConfiguredFunction:
|
|
|
85
93
|
instance: Function
|
|
86
94
|
|
|
87
95
|
|
|
96
|
+
@dataclasses.dataclass
|
|
97
|
+
class ConfiguredFunctionGroup:
|
|
98
|
+
config: FunctionGroupBaseConfig
|
|
99
|
+
instance: FunctionGroup
|
|
100
|
+
|
|
101
|
+
|
|
88
102
|
@dataclasses.dataclass
|
|
89
103
|
class ConfiguredLLM:
|
|
90
104
|
config: LLMBaseConfig
|
|
@@ -127,7 +141,6 @@ class ConfiguredTTCStrategy:
|
|
|
127
141
|
instance: StrategyBase
|
|
128
142
|
|
|
129
143
|
|
|
130
|
-
# pylint: disable=too-many-public-methods
|
|
131
144
|
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
132
145
|
|
|
133
146
|
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
|
|
@@ -146,6 +159,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
146
159
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
147
160
|
|
|
148
161
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
162
|
+
self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
|
|
149
163
|
self._workflow: ConfiguredFunction | None = None
|
|
150
164
|
|
|
151
165
|
self._llms: dict[str, ConfiguredLLM] = {}
|
|
@@ -162,7 +176,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
162
176
|
|
|
163
177
|
# Create a mapping to track function name -> other function names it depends on
|
|
164
178
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
179
|
+
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
165
180
|
self.current_function_building: str | None = None
|
|
181
|
+
self.current_function_group_building: str | None = None
|
|
166
182
|
|
|
167
183
|
async def __aenter__(self):
|
|
168
184
|
|
|
@@ -201,7 +217,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
201
217
|
|
|
202
218
|
await self._exit_stack.__aexit__(*exc_details)
|
|
203
219
|
|
|
204
|
-
def build(self, entry_function: str | None = None) -> Workflow:
|
|
220
|
+
async def build(self, entry_function: str | None = None) -> Workflow:
|
|
205
221
|
"""
|
|
206
222
|
Creates an instance of a workflow object using the added components and the desired entry function.
|
|
207
223
|
|
|
@@ -225,12 +241,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
225
241
|
if (self._workflow is None):
|
|
226
242
|
raise ValueError("Must set a workflow before building")
|
|
227
243
|
|
|
244
|
+
# Set of all functions which are "included" by function groups
|
|
245
|
+
included_functions = set()
|
|
246
|
+
# Dictionary of function configs
|
|
247
|
+
function_configs = dict()
|
|
248
|
+
# Dictionary of function group configs
|
|
249
|
+
function_group_configs = dict()
|
|
250
|
+
# Dictionary of function instances
|
|
251
|
+
function_instances = dict()
|
|
252
|
+
# Dictionary of function group instances
|
|
253
|
+
function_group_instances = dict()
|
|
254
|
+
|
|
255
|
+
for k, v in self._function_groups.items():
|
|
256
|
+
included_functions.update((await v.instance.get_included_functions()).keys())
|
|
257
|
+
function_group_configs[k] = v.config
|
|
258
|
+
function_group_instances[k] = v.instance
|
|
259
|
+
|
|
260
|
+
# Function configs need to be restricted to only the functions that are not in a function group
|
|
261
|
+
for k, v in self._functions.items():
|
|
262
|
+
if k not in included_functions:
|
|
263
|
+
function_configs[k] = v.config
|
|
264
|
+
function_instances[k] = v.instance
|
|
265
|
+
|
|
228
266
|
# Build the config from the added objects
|
|
229
267
|
config = Config(general=self.general_config,
|
|
230
|
-
functions=
|
|
231
|
-
|
|
232
|
-
for k, v in self._functions.items()
|
|
233
|
-
},
|
|
268
|
+
functions=function_configs,
|
|
269
|
+
function_groups=function_group_configs,
|
|
234
270
|
workflow=self._workflow.config,
|
|
235
271
|
llms={
|
|
236
272
|
k: v.config
|
|
@@ -260,14 +296,12 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
260
296
|
if (entry_function is None):
|
|
261
297
|
entry_fn_obj = self.get_workflow()
|
|
262
298
|
else:
|
|
263
|
-
entry_fn_obj = self.get_function(entry_function)
|
|
299
|
+
entry_fn_obj = await self.get_function(entry_function)
|
|
264
300
|
|
|
265
301
|
workflow = Workflow.from_entry_fn(config=config,
|
|
266
302
|
entry_fn=entry_fn_obj,
|
|
267
|
-
functions=
|
|
268
|
-
|
|
269
|
-
for k, v in self._functions.items()
|
|
270
|
-
},
|
|
303
|
+
functions=function_instances,
|
|
304
|
+
function_groups=function_group_instances,
|
|
271
305
|
llms={
|
|
272
306
|
k: v.instance
|
|
273
307
|
for k, v in self._llms.items()
|
|
@@ -348,11 +382,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
348
382
|
|
|
349
383
|
return ConfiguredFunction(config=config, instance=build_result)
|
|
350
384
|
|
|
385
|
+
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
386
|
+
"""Build a function group from the provided configuration.
|
|
387
|
+
|
|
388
|
+
Args:
|
|
389
|
+
name: The name of the function group
|
|
390
|
+
config: The function group configuration
|
|
391
|
+
|
|
392
|
+
Returns:
|
|
393
|
+
ConfiguredFunctionGroup: The built function group
|
|
394
|
+
|
|
395
|
+
Raises:
|
|
396
|
+
ValueError: If the function group builder returns invalid results
|
|
397
|
+
"""
|
|
398
|
+
registration = self._registry.get_function_group(type(config))
|
|
399
|
+
|
|
400
|
+
inner_builder = ChildBuilder(self)
|
|
401
|
+
|
|
402
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
403
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
404
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
405
|
+
|
|
406
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
407
|
+
|
|
408
|
+
# Set the currently building function group so the ChildBuilder can track dependencies
|
|
409
|
+
self.current_function_group_building = config.type
|
|
410
|
+
# Empty set of dependencies for the current function group
|
|
411
|
+
self.function_group_dependencies[config.type] = FunctionDependencies()
|
|
412
|
+
|
|
413
|
+
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
414
|
+
|
|
415
|
+
self.function_group_dependencies[name] = inner_builder.dependencies
|
|
416
|
+
|
|
417
|
+
if not isinstance(build_result, FunctionGroup):
|
|
418
|
+
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
419
|
+
f"Got {type(build_result)}")
|
|
420
|
+
|
|
421
|
+
# set the instance name for the function group based on the workflow-provided name
|
|
422
|
+
build_result.set_instance_name(name)
|
|
423
|
+
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
424
|
+
|
|
351
425
|
@override
|
|
352
426
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
427
|
+
if isinstance(name, FunctionRef):
|
|
428
|
+
name = str(name)
|
|
353
429
|
|
|
354
|
-
if (name in self._functions):
|
|
355
|
-
raise ValueError(f"Function `{name}` already exists in the list of functions")
|
|
430
|
+
if (name in self._functions or name in self._function_groups):
|
|
431
|
+
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
356
432
|
|
|
357
433
|
build_result = await self._build_function(name=name, config=config)
|
|
358
434
|
|
|
@@ -361,20 +437,67 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
361
437
|
return build_result.instance
|
|
362
438
|
|
|
363
439
|
@override
|
|
364
|
-
def
|
|
440
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
441
|
+
if isinstance(name, FunctionGroupRef):
|
|
442
|
+
name = str(name)
|
|
443
|
+
|
|
444
|
+
if (name in self._function_groups or name in self._functions):
|
|
445
|
+
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
446
|
+
|
|
447
|
+
# Build the function group
|
|
448
|
+
build_result = await self._build_function_group(name=name, config=config)
|
|
449
|
+
|
|
450
|
+
self._function_groups[name] = build_result
|
|
451
|
+
|
|
452
|
+
# If the function group exposes functions, add them to the global function registry
|
|
453
|
+
# If the function group exposes functions, record and add them to the registry
|
|
454
|
+
included_functions = await build_result.instance.get_included_functions()
|
|
455
|
+
for k in included_functions:
|
|
456
|
+
if k in self._functions:
|
|
457
|
+
raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
|
|
458
|
+
self._functions.update({
|
|
459
|
+
k: ConfiguredFunction(config=v.config, instance=v)
|
|
460
|
+
for k, v in included_functions.items()
|
|
461
|
+
})
|
|
462
|
+
|
|
463
|
+
return build_result.instance
|
|
365
464
|
|
|
465
|
+
@override
|
|
466
|
+
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
467
|
+
if isinstance(name, FunctionRef):
|
|
468
|
+
name = str(name)
|
|
366
469
|
if name not in self._functions:
|
|
367
470
|
raise ValueError(f"Function `{name}` not found")
|
|
368
471
|
|
|
369
472
|
return self._functions[name].instance
|
|
370
473
|
|
|
474
|
+
@override
|
|
475
|
+
async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
476
|
+
if isinstance(name, FunctionGroupRef):
|
|
477
|
+
name = str(name)
|
|
478
|
+
if name not in self._function_groups:
|
|
479
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
480
|
+
|
|
481
|
+
return self._function_groups[name].instance
|
|
482
|
+
|
|
371
483
|
@override
|
|
372
484
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
485
|
+
if isinstance(name, FunctionRef):
|
|
486
|
+
name = str(name)
|
|
373
487
|
if name not in self._functions:
|
|
374
488
|
raise ValueError(f"Function `{name}` not found")
|
|
375
489
|
|
|
376
490
|
return self._functions[name].config
|
|
377
491
|
|
|
492
|
+
@override
|
|
493
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
494
|
+
if isinstance(name, FunctionGroupRef):
|
|
495
|
+
name = str(name)
|
|
496
|
+
if name not in self._function_groups:
|
|
497
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
498
|
+
|
|
499
|
+
return self._function_groups[name].config
|
|
500
|
+
|
|
378
501
|
@override
|
|
379
502
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
380
503
|
|
|
@@ -404,16 +527,57 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
404
527
|
|
|
405
528
|
@override
|
|
406
529
|
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
|
|
530
|
+
if isinstance(fn_name, FunctionRef):
|
|
531
|
+
fn_name = str(fn_name)
|
|
407
532
|
return self.function_dependencies[fn_name]
|
|
408
533
|
|
|
409
534
|
@override
|
|
410
|
-
def
|
|
535
|
+
def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
|
|
536
|
+
if isinstance(fn_name, FunctionGroupRef):
|
|
537
|
+
fn_name = str(fn_name)
|
|
538
|
+
return self.function_group_dependencies[fn_name]
|
|
539
|
+
|
|
540
|
+
@override
|
|
541
|
+
async def get_tools(self,
|
|
542
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
543
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
544
|
+
|
|
545
|
+
unique = set(tool_names)
|
|
546
|
+
if len(unique) != len(tool_names):
|
|
547
|
+
raise ValueError("Tool names must be unique")
|
|
548
|
+
|
|
549
|
+
async def _get_tools(n: str | FunctionRef | FunctionGroupRef):
|
|
550
|
+
tools = []
|
|
551
|
+
is_function_group_ref = isinstance(n, FunctionGroupRef)
|
|
552
|
+
if isinstance(n, FunctionRef) or is_function_group_ref:
|
|
553
|
+
n = str(n)
|
|
554
|
+
if n not in self._function_groups:
|
|
555
|
+
# the passed tool name is probably a function, but first check if it's a function group
|
|
556
|
+
if is_function_group_ref:
|
|
557
|
+
raise ValueError(f"Function group `{n}` not found in the list of function groups")
|
|
558
|
+
tools.append(await self.get_tool(n, wrapper_type))
|
|
559
|
+
else:
|
|
560
|
+
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
561
|
+
current_function_group = self._function_groups[n]
|
|
562
|
+
for fn_name, fn_instance in (await current_function_group.instance.get_accessible_functions()).items():
|
|
563
|
+
try:
|
|
564
|
+
tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
|
|
565
|
+
except Exception:
|
|
566
|
+
logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
|
|
567
|
+
raise
|
|
568
|
+
return tools
|
|
569
|
+
|
|
570
|
+
tool_lists = await asyncio.gather(*[_get_tools(n) for n in tool_names])
|
|
571
|
+
# Flatten the list of lists into a single list
|
|
572
|
+
return [tool for tools in tool_lists for tool in tools]
|
|
411
573
|
|
|
574
|
+
@override
|
|
575
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
576
|
+
if isinstance(fn_name, FunctionRef):
|
|
577
|
+
fn_name = str(fn_name)
|
|
412
578
|
if fn_name not in self._functions:
|
|
413
579
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
414
|
-
|
|
415
580
|
fn = self._functions[fn_name]
|
|
416
|
-
|
|
417
581
|
try:
|
|
418
582
|
# Using the registry, get the tool wrapper for the requested framework
|
|
419
583
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
@@ -421,11 +585,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
421
585
|
# Wrap in the correct wrapper
|
|
422
586
|
return tool_wrapper_reg.build_fn(fn_name, fn.instance, self)
|
|
423
587
|
except Exception as e:
|
|
424
|
-
logger.error("Error fetching tool `%s
|
|
425
|
-
raise
|
|
588
|
+
logger.error("Error fetching tool `%s`: %s", fn_name, e)
|
|
589
|
+
raise
|
|
426
590
|
|
|
427
591
|
@override
|
|
428
|
-
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
592
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> None:
|
|
429
593
|
|
|
430
594
|
if (name in self._llms):
|
|
431
595
|
raise ValueError(f"LLM `{name}` already exists in the list of LLMs")
|
|
@@ -437,11 +601,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
437
601
|
|
|
438
602
|
self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
|
|
439
603
|
except Exception as e:
|
|
440
|
-
logger.error("Error adding llm `%s` with config `%s
|
|
441
|
-
raise
|
|
604
|
+
logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
|
|
605
|
+
raise
|
|
442
606
|
|
|
443
607
|
@override
|
|
444
|
-
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
|
|
608
|
+
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
445
609
|
|
|
446
610
|
if (llm_name not in self._llms):
|
|
447
611
|
raise ValueError(f"LLM `{llm_name}` not found")
|
|
@@ -458,8 +622,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
458
622
|
# Return a frameworks specific client
|
|
459
623
|
return client
|
|
460
624
|
except Exception as e:
|
|
461
|
-
logger.error("Error getting llm `%s` with wrapper `%s
|
|
462
|
-
raise
|
|
625
|
+
logger.error("Error getting llm `%s` with wrapper `%s`: %s", llm_name, wrapper_type, e)
|
|
626
|
+
raise
|
|
463
627
|
|
|
464
628
|
@override
|
|
465
629
|
def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
|
|
@@ -509,8 +673,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
509
673
|
|
|
510
674
|
return info_obj
|
|
511
675
|
except Exception as e:
|
|
512
|
-
logger.error("Error adding authentication `%s` with config `%s
|
|
513
|
-
raise
|
|
676
|
+
logger.error("Error adding authentication `%s` with config `%s`: %s", name, config, e)
|
|
677
|
+
raise
|
|
514
678
|
|
|
515
679
|
@override
|
|
516
680
|
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
|
|
@@ -541,7 +705,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
541
705
|
return self._auth_providers[auth_provider_name].instance
|
|
542
706
|
|
|
543
707
|
@override
|
|
544
|
-
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
708
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
|
|
545
709
|
|
|
546
710
|
if (name in self._embedders):
|
|
547
711
|
raise ValueError(f"Embedder `{name}` already exists in the list of embedders")
|
|
@@ -553,9 +717,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
553
717
|
|
|
554
718
|
self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
|
|
555
719
|
except Exception as e:
|
|
556
|
-
logger.error("Error adding embedder `%s` with config `%s
|
|
557
|
-
|
|
558
|
-
raise e
|
|
720
|
+
logger.error("Error adding embedder `%s` with config `%s`: %s", name, config, e)
|
|
721
|
+
raise
|
|
559
722
|
|
|
560
723
|
@override
|
|
561
724
|
async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str):
|
|
@@ -575,8 +738,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
575
738
|
# Return a frameworks specific client
|
|
576
739
|
return client
|
|
577
740
|
except Exception as e:
|
|
578
|
-
logger.error("Error getting embedder `%s` with wrapper `%s
|
|
579
|
-
raise
|
|
741
|
+
logger.error("Error getting embedder `%s` with wrapper `%s`: %s", embedder_name, wrapper_type, e)
|
|
742
|
+
raise
|
|
580
743
|
|
|
581
744
|
@override
|
|
582
745
|
def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
|
|
@@ -602,7 +765,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
602
765
|
return info_obj
|
|
603
766
|
|
|
604
767
|
@override
|
|
605
|
-
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
768
|
+
async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
606
769
|
"""
|
|
607
770
|
Return the instantiated memory client for the given name.
|
|
608
771
|
"""
|
|
@@ -648,7 +811,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
648
811
|
return self._object_stores[object_store_name].config
|
|
649
812
|
|
|
650
813
|
@override
|
|
651
|
-
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
814
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
|
|
652
815
|
|
|
653
816
|
if (name in self._retrievers):
|
|
654
817
|
raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")
|
|
@@ -661,11 +824,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
661
824
|
self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
|
|
662
825
|
|
|
663
826
|
except Exception as e:
|
|
664
|
-
logger.error("Error adding retriever `%s` with config `%s
|
|
665
|
-
|
|
666
|
-
raise e
|
|
667
|
-
|
|
668
|
-
# return info_obj
|
|
827
|
+
logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
|
|
828
|
+
raise
|
|
669
829
|
|
|
670
830
|
@override
|
|
671
831
|
async def get_retriever(self,
|
|
@@ -688,8 +848,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
688
848
|
# Return a frameworks specific client
|
|
689
849
|
return client
|
|
690
850
|
except Exception as e:
|
|
691
|
-
logger.error("Error getting retriever `%s` with wrapper `%s
|
|
692
|
-
raise
|
|
851
|
+
logger.error("Error getting retriever `%s` with wrapper `%s`: %s", retriever_name, wrapper_type, e)
|
|
852
|
+
raise
|
|
693
853
|
|
|
694
854
|
@override
|
|
695
855
|
async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
|
|
@@ -699,9 +859,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
699
859
|
|
|
700
860
|
return self._retrievers[retriever_name].config
|
|
701
861
|
|
|
702
|
-
@experimental(feature_name="TTC")
|
|
703
862
|
@override
|
|
704
|
-
|
|
863
|
+
@experimental(feature_name="TTC")
|
|
864
|
+
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
|
|
705
865
|
if (name in self._ttc_strategies):
|
|
706
866
|
raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")
|
|
707
867
|
|
|
@@ -713,9 +873,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
713
873
|
self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
|
|
714
874
|
|
|
715
875
|
except Exception as e:
|
|
716
|
-
logger.error("Error adding TTC strategy `%s` with config `%s
|
|
717
|
-
|
|
718
|
-
raise e
|
|
876
|
+
logger.error("Error adding TTC strategy `%s` with config `%s`: %s", name, config, e)
|
|
877
|
+
raise
|
|
719
878
|
|
|
720
879
|
@override
|
|
721
880
|
async def get_ttc_strategy(self,
|
|
@@ -743,8 +902,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
743
902
|
|
|
744
903
|
return instance
|
|
745
904
|
except Exception as e:
|
|
746
|
-
logger.error("Error getting TTC strategy `%s
|
|
747
|
-
raise
|
|
905
|
+
logger.error("Error getting TTC strategy `%s`: %s", strategy_name, e)
|
|
906
|
+
raise
|
|
748
907
|
|
|
749
908
|
@override
|
|
750
909
|
async def get_ttc_strategy_config(self,
|
|
@@ -821,7 +980,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
821
980
|
else:
|
|
822
981
|
logger.error("No remaining components to build")
|
|
823
982
|
|
|
824
|
-
logger.error("Original error:", exc_info=
|
|
983
|
+
logger.error("Original error: %s", original_error, exc_info=True)
|
|
825
984
|
|
|
826
985
|
def _log_build_failure_component(self,
|
|
827
986
|
failing_component: ComponentInstanceData,
|
|
@@ -889,29 +1048,40 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
889
1048
|
|
|
890
1049
|
# Instantiate a the llm
|
|
891
1050
|
if component_instance.component_group == ComponentGroup.LLMS:
|
|
892
|
-
await self.add_llm(component_instance.name, component_instance.config)
|
|
1051
|
+
await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
|
|
893
1052
|
# Instantiate a the embedder
|
|
894
1053
|
elif component_instance.component_group == ComponentGroup.EMBEDDERS:
|
|
895
|
-
await self.add_embedder(component_instance.name,
|
|
1054
|
+
await self.add_embedder(component_instance.name,
|
|
1055
|
+
cast(EmbedderBaseConfig, component_instance.config))
|
|
896
1056
|
# Instantiate a memory client
|
|
897
1057
|
elif component_instance.component_group == ComponentGroup.MEMORY:
|
|
898
|
-
await self.add_memory_client(component_instance.name,
|
|
899
|
-
|
|
1058
|
+
await self.add_memory_client(component_instance.name,
|
|
1059
|
+
cast(MemoryBaseConfig, component_instance.config))
|
|
1060
|
+
# Instantiate a object store client
|
|
900
1061
|
elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
|
|
901
|
-
await self.add_object_store(component_instance.name,
|
|
1062
|
+
await self.add_object_store(component_instance.name,
|
|
1063
|
+
cast(ObjectStoreBaseConfig, component_instance.config))
|
|
902
1064
|
# Instantiate a retriever client
|
|
903
1065
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
904
|
-
await self.add_retriever(component_instance.name,
|
|
1066
|
+
await self.add_retriever(component_instance.name,
|
|
1067
|
+
cast(RetrieverBaseConfig, component_instance.config))
|
|
1068
|
+
# Instantiate a function group
|
|
1069
|
+
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1070
|
+
await self.add_function_group(component_instance.name,
|
|
1071
|
+
cast(FunctionGroupBaseConfig, component_instance.config))
|
|
905
1072
|
# Instantiate a function
|
|
906
1073
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
907
1074
|
# If the function is the root, set it as the workflow later
|
|
908
1075
|
if (not component_instance.is_root):
|
|
909
|
-
await self.add_function(component_instance.name,
|
|
1076
|
+
await self.add_function(component_instance.name,
|
|
1077
|
+
cast(FunctionBaseConfig, component_instance.config))
|
|
910
1078
|
elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
|
|
911
|
-
await self.add_ttc_strategy(component_instance.name,
|
|
1079
|
+
await self.add_ttc_strategy(component_instance.name,
|
|
1080
|
+
cast(TTCStrategyBaseConfig, component_instance.config))
|
|
912
1081
|
|
|
913
1082
|
elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
|
|
914
|
-
await self.add_auth_provider(component_instance.name,
|
|
1083
|
+
await self.add_auth_provider(component_instance.name,
|
|
1084
|
+
cast(AuthProviderBaseConfig, component_instance.config))
|
|
915
1085
|
else:
|
|
916
1086
|
raise ValueError(f"Unknown component group {component_instance.component_group}")
|
|
917
1087
|
|
|
@@ -961,18 +1131,35 @@ class ChildBuilder(Builder):
|
|
|
961
1131
|
return await self._workflow_builder.add_function(name, config)
|
|
962
1132
|
|
|
963
1133
|
@override
|
|
964
|
-
def
|
|
1134
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1135
|
+
return await self._workflow_builder.add_function_group(name, config)
|
|
1136
|
+
|
|
1137
|
+
@override
|
|
1138
|
+
async def get_function(self, name: str) -> Function:
|
|
965
1139
|
# If a function tries to get another function, we assume it uses it
|
|
966
|
-
fn = self._workflow_builder.get_function(name)
|
|
1140
|
+
fn = await self._workflow_builder.get_function(name)
|
|
967
1141
|
|
|
968
1142
|
self._dependencies.add_function(name)
|
|
969
1143
|
|
|
970
1144
|
return fn
|
|
971
1145
|
|
|
1146
|
+
@override
|
|
1147
|
+
async def get_function_group(self, name: str) -> FunctionGroup:
|
|
1148
|
+
# If a function tries to get a function group, we assume it uses it
|
|
1149
|
+
function_group = await self._workflow_builder.get_function_group(name)
|
|
1150
|
+
|
|
1151
|
+
self._dependencies.add_function_group(name)
|
|
1152
|
+
|
|
1153
|
+
return function_group
|
|
1154
|
+
|
|
972
1155
|
@override
|
|
973
1156
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
974
1157
|
return self._workflow_builder.get_function_config(name)
|
|
975
1158
|
|
|
1159
|
+
@override
|
|
1160
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1161
|
+
return self._workflow_builder.get_function_group_config(name)
|
|
1162
|
+
|
|
976
1163
|
@override
|
|
977
1164
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
978
1165
|
return await self._workflow_builder.set_workflow(config)
|
|
@@ -986,20 +1173,33 @@ class ChildBuilder(Builder):
|
|
|
986
1173
|
return self._workflow_builder.get_workflow_config()
|
|
987
1174
|
|
|
988
1175
|
@override
|
|
989
|
-
def
|
|
1176
|
+
async def get_tools(self,
|
|
1177
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1178
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1179
|
+
tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1180
|
+
for tool_name in tool_names:
|
|
1181
|
+
if tool_name in self._workflow_builder._function_groups:
|
|
1182
|
+
self._dependencies.add_function_group(tool_name)
|
|
1183
|
+
else:
|
|
1184
|
+
self._dependencies.add_function(tool_name)
|
|
1185
|
+
return tools
|
|
1186
|
+
|
|
1187
|
+
@override
|
|
1188
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
990
1189
|
# If a function tries to get another function as a tool, we assume it uses it
|
|
991
|
-
fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
1190
|
+
fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
992
1191
|
|
|
993
1192
|
self._dependencies.add_function(fn_name)
|
|
994
1193
|
|
|
995
1194
|
return fn
|
|
996
1195
|
|
|
997
1196
|
@override
|
|
998
|
-
async def add_llm(self, name: str, config: LLMBaseConfig):
|
|
1197
|
+
async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
|
|
999
1198
|
return await self._workflow_builder.add_llm(name, config)
|
|
1000
1199
|
|
|
1200
|
+
@experimental(feature_name="Authentication")
|
|
1001
1201
|
@override
|
|
1002
|
-
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig):
|
|
1202
|
+
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
1003
1203
|
return await self._workflow_builder.add_auth_provider(name, config)
|
|
1004
1204
|
|
|
1005
1205
|
@override
|
|
@@ -1007,7 +1207,7 @@ class ChildBuilder(Builder):
|
|
|
1007
1207
|
return await self._workflow_builder.get_auth_provider(auth_provider_name)
|
|
1008
1208
|
|
|
1009
1209
|
@override
|
|
1010
|
-
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str):
|
|
1210
|
+
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1011
1211
|
llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
|
|
1012
1212
|
|
|
1013
1213
|
self._dependencies.add_llm(llm_name)
|
|
@@ -1019,11 +1219,11 @@ class ChildBuilder(Builder):
|
|
|
1019
1219
|
return self._workflow_builder.get_llm_config(llm_name)
|
|
1020
1220
|
|
|
1021
1221
|
@override
|
|
1022
|
-
async def add_embedder(self, name: str, config: EmbedderBaseConfig):
|
|
1023
|
-
|
|
1222
|
+
async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
|
|
1223
|
+
await self._workflow_builder.add_embedder(name, config)
|
|
1024
1224
|
|
|
1025
1225
|
@override
|
|
1026
|
-
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str):
|
|
1226
|
+
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1027
1227
|
embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
|
|
1028
1228
|
|
|
1029
1229
|
self._dependencies.add_embedder(embedder_name)
|
|
@@ -1039,11 +1239,11 @@ class ChildBuilder(Builder):
|
|
|
1039
1239
|
return await self._workflow_builder.add_memory_client(name, config)
|
|
1040
1240
|
|
|
1041
1241
|
@override
|
|
1042
|
-
def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1242
|
+
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1043
1243
|
"""
|
|
1044
1244
|
Return the instantiated memory client for the given name.
|
|
1045
1245
|
"""
|
|
1046
|
-
memory_client = self._workflow_builder.get_memory_client(memory_name)
|
|
1246
|
+
memory_client = await self._workflow_builder.get_memory_client(memory_name)
|
|
1047
1247
|
|
|
1048
1248
|
self._dependencies.add_memory_client(memory_name)
|
|
1049
1249
|
|
|
@@ -1073,8 +1273,9 @@ class ChildBuilder(Builder):
|
|
|
1073
1273
|
return self._workflow_builder.get_object_store_config(object_store_name)
|
|
1074
1274
|
|
|
1075
1275
|
@override
|
|
1076
|
-
|
|
1077
|
-
|
|
1276
|
+
@experimental(feature_name="TTC")
|
|
1277
|
+
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
|
1278
|
+
await self._workflow_builder.add_ttc_strategy(name, config)
|
|
1078
1279
|
|
|
1079
1280
|
@override
|
|
1080
1281
|
async def get_ttc_strategy(self,
|
|
@@ -1095,11 +1296,11 @@ class ChildBuilder(Builder):
|
|
|
1095
1296
|
stage_type=stage_type)
|
|
1096
1297
|
|
|
1097
1298
|
@override
|
|
1098
|
-
async def add_retriever(self, name: str, config: RetrieverBaseConfig):
|
|
1099
|
-
|
|
1299
|
+
async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
|
|
1300
|
+
await self._workflow_builder.add_retriever(name, config)
|
|
1100
1301
|
|
|
1101
1302
|
@override
|
|
1102
|
-
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
1303
|
+
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
|
|
1103
1304
|
if not wrapper_type:
|
|
1104
1305
|
return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
|
|
1105
1306
|
return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
|
|
@@ -1115,3 +1316,7 @@ class ChildBuilder(Builder):
|
|
|
1115
1316
|
@override
|
|
1116
1317
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1117
1318
|
return self._workflow_builder.get_function_dependencies(fn_name)
|
|
1319
|
+
|
|
1320
|
+
@override
|
|
1321
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1322
|
+
return self._workflow_builder.get_function_group_dependencies(fn_name)
|