nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +13 -8
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +6 -5
- nat/agent/react_agent/register.py +49 -39
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +2 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/rewoo_agent/agent.py +304 -117
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +51 -38
- nat/agent/tool_calling_agent/agent.py +75 -17
- nat/agent/tool_calling_agent/register.py +46 -23
- nat/authentication/api_key/api_key_auth_provider.py +6 -11
- nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
- nat/builder/builder.py +55 -23
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +54 -15
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +370 -0
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +38 -2
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +306 -54
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +60 -18
- nat/cli/entrypoint.py +15 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +72 -1
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +199 -69
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +47 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +4 -3
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +208 -0
- nat/data_models/optimizer.py +161 -0
- nat/data_models/span.py +41 -3
- nat/data_models/thinking_mixin.py +2 -2
- nat/embedder/azure_openai_embedder.py +2 -1
- nat/embedder/nim_embedder.py +3 -2
- nat/embedder/openai_embedder.py +3 -2
- nat/eval/config.py +1 -1
- nat/eval/dataset_handler/dataset_downloader.py +3 -2
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +10 -3
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +7 -4
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/eval/usage_stats.py +2 -0
- nat/eval/utils/output_uploader.py +3 -2
- nat/eval/utils/weave_eval.py +17 -3
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +69 -44
- nat/front_ends/fastapi/message_validator.py +8 -7
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +71 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +78 -25
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +21 -8
- nat/llm/azure_openai_llm.py +14 -5
- nat/llm/litellm_llm.py +80 -0
- nat/llm/nim_llm.py +23 -9
- nat/llm/openai_llm.py +19 -7
- nat/llm/register.py +4 -0
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/exporter/span_exporter.py +43 -15
- nat/observability/exporter_manager.py +2 -2
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/observability/register.py +16 -0
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/profiler/decorators/framework_wrapper.py +61 -9
- nat/profiler/decorators/function_tracking.py +35 -3
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/registry_handlers/pypi/register_pypi.py +5 -3
- nat/registry_handlers/rest/register_rest.py +5 -3
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/register.py +2 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +111 -6
- nat/runtime/session.py +49 -3
- nat/settings/global_settings.py +2 -2
- nat/tool/chat_completion.py +4 -1
- nat/tool/code_execution/code_sandbox.py +3 -6
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/datetime_tools.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/register.py +2 -7
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +76 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +278 -72
- nat/utils/io/yaml_tools.py +73 -3
- nat/utils/log_levels.py +25 -0
- nat/utils/responses_api.py +26 -0
- nat/utils/string_utils.py +16 -0
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +6 -2
- nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -461
- nat/data_models/temperature_mixin.py +0 -43
- nat/data_models/top_p_mixin.py +0 -43
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -13,13 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import inspect
|
|
18
19
|
import logging
|
|
20
|
+
import typing
|
|
19
21
|
import warnings
|
|
22
|
+
from collections.abc import Sequence
|
|
20
23
|
from contextlib import AbstractAsyncContextManager
|
|
21
24
|
from contextlib import AsyncExitStack
|
|
22
25
|
from contextlib import asynccontextmanager
|
|
26
|
+
from typing import cast
|
|
23
27
|
|
|
24
28
|
from nat.authentication.interfaces import AuthProviderBase
|
|
25
29
|
from nat.builder.builder import Builder
|
|
@@ -31,6 +35,7 @@ from nat.builder.context import ContextState
|
|
|
31
35
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
32
36
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
33
37
|
from nat.builder.function import Function
|
|
38
|
+
from nat.builder.function import FunctionGroup
|
|
34
39
|
from nat.builder.function import LambdaFunction
|
|
35
40
|
from nat.builder.function_info import FunctionInfo
|
|
36
41
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -42,6 +47,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
|
|
|
42
47
|
from nat.data_models.component import ComponentGroup
|
|
43
48
|
from nat.data_models.component_ref import AuthenticationRef
|
|
44
49
|
from nat.data_models.component_ref import EmbedderRef
|
|
50
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
45
51
|
from nat.data_models.component_ref import FunctionRef
|
|
46
52
|
from nat.data_models.component_ref import LLMRef
|
|
47
53
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -52,6 +58,7 @@ from nat.data_models.config import Config
|
|
|
52
58
|
from nat.data_models.config import GeneralConfig
|
|
53
59
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
54
60
|
from nat.data_models.function import FunctionBaseConfig
|
|
61
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
55
62
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
56
63
|
from nat.data_models.llm import LLMBaseConfig
|
|
57
64
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -68,6 +75,7 @@ from nat.object_store.interfaces import ObjectStore
|
|
|
68
75
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
69
76
|
from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
70
77
|
from nat.profiler.utils import detect_llm_frameworks_in_build_fn
|
|
78
|
+
from nat.retriever.interface import Retriever
|
|
71
79
|
from nat.utils.type_utils import override
|
|
72
80
|
|
|
73
81
|
logger = logging.getLogger(__name__)
|
|
@@ -85,6 +93,12 @@ class ConfiguredFunction:
|
|
|
85
93
|
instance: Function
|
|
86
94
|
|
|
87
95
|
|
|
96
|
+
@dataclasses.dataclass
|
|
97
|
+
class ConfiguredFunctionGroup:
|
|
98
|
+
config: FunctionGroupBaseConfig
|
|
99
|
+
instance: FunctionGroup
|
|
100
|
+
|
|
101
|
+
|
|
88
102
|
@dataclasses.dataclass
|
|
89
103
|
class ConfiguredLLM:
|
|
90
104
|
config: LLMBaseConfig
|
|
@@ -142,9 +156,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
142
156
|
self._registry = registry
|
|
143
157
|
|
|
144
158
|
self._logging_handlers: dict[str, logging.Handler] = {}
|
|
159
|
+
self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
|
|
145
160
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
146
161
|
|
|
147
162
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
163
|
+
self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
|
|
148
164
|
self._workflow: ConfiguredFunction | None = None
|
|
149
165
|
|
|
150
166
|
self._llms: dict[str, ConfiguredLLM] = {}
|
|
@@ -161,7 +177,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
161
177
|
|
|
162
178
|
# Create a mapping to track function name -> other function names it depends on
|
|
163
179
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
180
|
+
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
164
181
|
self.current_function_building: str | None = None
|
|
182
|
+
self.current_function_group_building: str | None = None
|
|
165
183
|
|
|
166
184
|
async def __aenter__(self):
|
|
167
185
|
|
|
@@ -170,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
170
188
|
# Get the telemetry info from the config
|
|
171
189
|
telemetry_config = self.general_config.telemetry
|
|
172
190
|
|
|
191
|
+
# If we have logging configuration, we need to manage the root logger properly
|
|
192
|
+
root_logger = logging.getLogger()
|
|
193
|
+
|
|
194
|
+
# Collect configured handler types to determine if we need to adjust existing handlers
|
|
195
|
+
# This is somewhat of a hack by inspecting the class name of the config object
|
|
196
|
+
has_console_handler = any(
|
|
197
|
+
hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
|
|
198
|
+
for config in telemetry_config.logging.values())
|
|
199
|
+
|
|
173
200
|
for key, logging_config in telemetry_config.logging.items():
|
|
174
201
|
# Use the same pattern as tracing, but for logging
|
|
175
202
|
logging_info = self._registry.get_logging_method(type(logging_config))
|
|
@@ -183,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
183
210
|
self._logging_handlers[key] = handler
|
|
184
211
|
|
|
185
212
|
# Now attach to NAT's root logger
|
|
186
|
-
|
|
213
|
+
root_logger.addHandler(handler)
|
|
214
|
+
|
|
215
|
+
# If we added logging handlers, manage existing handlers appropriately
|
|
216
|
+
if self._logging_handlers:
|
|
217
|
+
min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
|
|
218
|
+
|
|
219
|
+
# Ensure the root logger level allows messages through
|
|
220
|
+
root_logger.level = max(root_logger.level, min_handler_level)
|
|
221
|
+
|
|
222
|
+
# If a console handler is configured, adjust or remove default CLI handlers
|
|
223
|
+
# to avoid duplicate output while preserving workflow visibility
|
|
224
|
+
if has_console_handler:
|
|
225
|
+
# Remove existing StreamHandlers that are not the newly configured ones
|
|
226
|
+
for handler in root_logger.handlers[:]:
|
|
227
|
+
if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
|
|
228
|
+
self._removed_root_handlers.append((handler, handler.level))
|
|
229
|
+
root_logger.removeHandler(handler)
|
|
230
|
+
else:
|
|
231
|
+
# No console handler configured, but adjust existing handler levels
|
|
232
|
+
# to respect the minimum configured level for file/other handlers
|
|
233
|
+
for handler in root_logger.handlers[:]:
|
|
234
|
+
if type(handler) is logging.StreamHandler:
|
|
235
|
+
old_level = handler.level
|
|
236
|
+
handler.setLevel(min_handler_level)
|
|
237
|
+
self._removed_root_handlers.append((handler, old_level))
|
|
187
238
|
|
|
188
239
|
# Add the telemetry exporters
|
|
189
240
|
for key, telemetry_exporter_config in telemetry_config.tracing.items():
|
|
@@ -195,12 +246,21 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
195
246
|
|
|
196
247
|
assert self._exit_stack is not None, "Exit stack not initialized"
|
|
197
248
|
|
|
198
|
-
|
|
199
|
-
|
|
249
|
+
root_logger = logging.getLogger()
|
|
250
|
+
|
|
251
|
+
# Remove custom logging handlers
|
|
252
|
+
for handler in self._logging_handlers.values():
|
|
253
|
+
root_logger.removeHandler(handler)
|
|
254
|
+
|
|
255
|
+
# Restore original handlers and their levels
|
|
256
|
+
for handler, old_level in self._removed_root_handlers:
|
|
257
|
+
if handler not in root_logger.handlers:
|
|
258
|
+
root_logger.addHandler(handler)
|
|
259
|
+
handler.setLevel(old_level)
|
|
200
260
|
|
|
201
261
|
await self._exit_stack.__aexit__(*exc_details)
|
|
202
262
|
|
|
203
|
-
def build(self, entry_function: str | None = None) -> Workflow:
|
|
263
|
+
async def build(self, entry_function: str | None = None) -> Workflow:
|
|
204
264
|
"""
|
|
205
265
|
Creates an instance of a workflow object using the added components and the desired entry function.
|
|
206
266
|
|
|
@@ -224,12 +284,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
224
284
|
if (self._workflow is None):
|
|
225
285
|
raise ValueError("Must set a workflow before building")
|
|
226
286
|
|
|
287
|
+
# Set of all functions which are "included" by function groups
|
|
288
|
+
included_functions = set()
|
|
289
|
+
# Dictionary of function configs
|
|
290
|
+
function_configs = dict()
|
|
291
|
+
# Dictionary of function group configs
|
|
292
|
+
function_group_configs = dict()
|
|
293
|
+
# Dictionary of function instances
|
|
294
|
+
function_instances = dict()
|
|
295
|
+
# Dictionary of function group instances
|
|
296
|
+
function_group_instances = dict()
|
|
297
|
+
|
|
298
|
+
for k, v in self._function_groups.items():
|
|
299
|
+
included_functions.update((await v.instance.get_included_functions()).keys())
|
|
300
|
+
function_group_configs[k] = v.config
|
|
301
|
+
function_group_instances[k] = v.instance
|
|
302
|
+
|
|
303
|
+
# Function configs need to be restricted to only the functions that are not in a function group
|
|
304
|
+
for k, v in self._functions.items():
|
|
305
|
+
if k not in included_functions:
|
|
306
|
+
function_configs[k] = v.config
|
|
307
|
+
function_instances[k] = v.instance
|
|
308
|
+
|
|
227
309
|
# Build the config from the added objects
|
|
228
310
|
config = Config(general=self.general_config,
|
|
229
|
-
functions=
|
|
230
|
-
|
|
231
|
-
for k, v in self._functions.items()
|
|
232
|
-
},
|
|
311
|
+
functions=function_configs,
|
|
312
|
+
function_groups=function_group_configs,
|
|
233
313
|
workflow=self._workflow.config,
|
|
234
314
|
llms={
|
|
235
315
|
k: v.config
|
|
@@ -259,14 +339,12 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
259
339
|
if (entry_function is None):
|
|
260
340
|
entry_fn_obj = self.get_workflow()
|
|
261
341
|
else:
|
|
262
|
-
entry_fn_obj = self.get_function(entry_function)
|
|
342
|
+
entry_fn_obj = await self.get_function(entry_function)
|
|
263
343
|
|
|
264
344
|
workflow = Workflow.from_entry_fn(config=config,
|
|
265
345
|
entry_fn=entry_fn_obj,
|
|
266
|
-
functions=
|
|
267
|
-
|
|
268
|
-
for k, v in self._functions.items()
|
|
269
|
-
},
|
|
346
|
+
functions=function_instances,
|
|
347
|
+
function_groups=function_group_instances,
|
|
270
348
|
llms={
|
|
271
349
|
k: v.instance
|
|
272
350
|
for k, v in self._llms.items()
|
|
@@ -347,11 +425,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
347
425
|
|
|
348
426
|
return ConfiguredFunction(config=config, instance=build_result)
|
|
349
427
|
|
|
428
|
+
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
429
|
+
"""Build a function group from the provided configuration.
|
|
430
|
+
|
|
431
|
+
Args:
|
|
432
|
+
name: The name of the function group
|
|
433
|
+
config: The function group configuration
|
|
434
|
+
|
|
435
|
+
Returns:
|
|
436
|
+
ConfiguredFunctionGroup: The built function group
|
|
437
|
+
|
|
438
|
+
Raises:
|
|
439
|
+
ValueError: If the function group builder returns invalid results
|
|
440
|
+
"""
|
|
441
|
+
registration = self._registry.get_function_group(type(config))
|
|
442
|
+
|
|
443
|
+
inner_builder = ChildBuilder(self)
|
|
444
|
+
|
|
445
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
446
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
447
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
448
|
+
|
|
449
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
450
|
+
|
|
451
|
+
# Set the currently building function group so the ChildBuilder can track dependencies
|
|
452
|
+
self.current_function_group_building = config.type
|
|
453
|
+
# Empty set of dependencies for the current function group
|
|
454
|
+
self.function_group_dependencies[config.type] = FunctionDependencies()
|
|
455
|
+
|
|
456
|
+
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
457
|
+
|
|
458
|
+
self.function_group_dependencies[name] = inner_builder.dependencies
|
|
459
|
+
|
|
460
|
+
if not isinstance(build_result, FunctionGroup):
|
|
461
|
+
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
462
|
+
f"Got {type(build_result)}")
|
|
463
|
+
|
|
464
|
+
# set the instance name for the function group based on the workflow-provided name
|
|
465
|
+
build_result.set_instance_name(name)
|
|
466
|
+
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
467
|
+
|
|
350
468
|
@override
|
|
351
469
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
470
|
+
if isinstance(name, FunctionRef):
|
|
471
|
+
name = str(name)
|
|
352
472
|
|
|
353
|
-
if (name in self._functions):
|
|
354
|
-
raise ValueError(f"Function `{name}` already exists in the list of functions")
|
|
473
|
+
if (name in self._functions or name in self._function_groups):
|
|
474
|
+
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
355
475
|
|
|
356
476
|
build_result = await self._build_function(name=name, config=config)
|
|
357
477
|
|
|
@@ -360,20 +480,67 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
360
480
|
return build_result.instance
|
|
361
481
|
|
|
362
482
|
@override
|
|
363
|
-
def
|
|
483
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
484
|
+
if isinstance(name, FunctionGroupRef):
|
|
485
|
+
name = str(name)
|
|
486
|
+
|
|
487
|
+
if (name in self._function_groups or name in self._functions):
|
|
488
|
+
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
489
|
+
|
|
490
|
+
# Build the function group
|
|
491
|
+
build_result = await self._build_function_group(name=name, config=config)
|
|
364
492
|
|
|
493
|
+
self._function_groups[name] = build_result
|
|
494
|
+
|
|
495
|
+
# If the function group exposes functions, add them to the global function registry
|
|
496
|
+
# If the function group exposes functions, record and add them to the registry
|
|
497
|
+
included_functions = await build_result.instance.get_included_functions()
|
|
498
|
+
for k in included_functions:
|
|
499
|
+
if k in self._functions:
|
|
500
|
+
raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
|
|
501
|
+
self._functions.update({
|
|
502
|
+
k: ConfiguredFunction(config=v.config, instance=v)
|
|
503
|
+
for k, v in included_functions.items()
|
|
504
|
+
})
|
|
505
|
+
|
|
506
|
+
return build_result.instance
|
|
507
|
+
|
|
508
|
+
@override
|
|
509
|
+
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
510
|
+
if isinstance(name, FunctionRef):
|
|
511
|
+
name = str(name)
|
|
365
512
|
if name not in self._functions:
|
|
366
513
|
raise ValueError(f"Function `{name}` not found")
|
|
367
514
|
|
|
368
515
|
return self._functions[name].instance
|
|
369
516
|
|
|
517
|
+
@override
|
|
518
|
+
async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
519
|
+
if isinstance(name, FunctionGroupRef):
|
|
520
|
+
name = str(name)
|
|
521
|
+
if name not in self._function_groups:
|
|
522
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
523
|
+
|
|
524
|
+
return self._function_groups[name].instance
|
|
525
|
+
|
|
370
526
|
@override
|
|
371
527
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
528
|
+
if isinstance(name, FunctionRef):
|
|
529
|
+
name = str(name)
|
|
372
530
|
if name not in self._functions:
|
|
373
531
|
raise ValueError(f"Function `{name}` not found")
|
|
374
532
|
|
|
375
533
|
return self._functions[name].config
|
|
376
534
|
|
|
535
|
+
@override
|
|
536
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
537
|
+
if isinstance(name, FunctionGroupRef):
|
|
538
|
+
name = str(name)
|
|
539
|
+
if name not in self._function_groups:
|
|
540
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
541
|
+
|
|
542
|
+
return self._function_groups[name].config
|
|
543
|
+
|
|
377
544
|
@override
|
|
378
545
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
379
546
|
|
|
@@ -403,16 +570,57 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
403
570
|
|
|
404
571
|
@override
|
|
405
572
|
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
|
|
573
|
+
if isinstance(fn_name, FunctionRef):
|
|
574
|
+
fn_name = str(fn_name)
|
|
406
575
|
return self.function_dependencies[fn_name]
|
|
407
576
|
|
|
408
577
|
@override
|
|
409
|
-
def
|
|
578
|
+
def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
|
|
579
|
+
if isinstance(fn_name, FunctionGroupRef):
|
|
580
|
+
fn_name = str(fn_name)
|
|
581
|
+
return self.function_group_dependencies[fn_name]
|
|
582
|
+
|
|
583
|
+
@override
|
|
584
|
+
async def get_tools(self,
|
|
585
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
586
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
587
|
+
|
|
588
|
+
unique = set(tool_names)
|
|
589
|
+
if len(unique) != len(tool_names):
|
|
590
|
+
raise ValueError("Tool names must be unique")
|
|
591
|
+
|
|
592
|
+
async def _get_tools(n: str | FunctionRef | FunctionGroupRef):
|
|
593
|
+
tools = []
|
|
594
|
+
is_function_group_ref = isinstance(n, FunctionGroupRef)
|
|
595
|
+
if isinstance(n, FunctionRef) or is_function_group_ref:
|
|
596
|
+
n = str(n)
|
|
597
|
+
if n not in self._function_groups:
|
|
598
|
+
# the passed tool name is probably a function, but first check if it's a function group
|
|
599
|
+
if is_function_group_ref:
|
|
600
|
+
raise ValueError(f"Function group `{n}` not found in the list of function groups")
|
|
601
|
+
tools.append(await self.get_tool(n, wrapper_type))
|
|
602
|
+
else:
|
|
603
|
+
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
604
|
+
current_function_group = self._function_groups[n]
|
|
605
|
+
for fn_name, fn_instance in (await current_function_group.instance.get_accessible_functions()).items():
|
|
606
|
+
try:
|
|
607
|
+
tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
|
|
608
|
+
except Exception:
|
|
609
|
+
logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
|
|
610
|
+
raise
|
|
611
|
+
return tools
|
|
612
|
+
|
|
613
|
+
tool_lists = await asyncio.gather(*[_get_tools(n) for n in tool_names])
|
|
614
|
+
# Flatten the list of lists into a single list
|
|
615
|
+
return [tool for tools in tool_lists for tool in tools]
|
|
410
616
|
|
|
617
|
+
@override
|
|
618
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
619
|
+
if isinstance(fn_name, FunctionRef):
|
|
620
|
+
fn_name = str(fn_name)
|
|
411
621
|
if fn_name not in self._functions:
|
|
412
622
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
413
|
-
|
|
414
623
|
fn = self._functions[fn_name]
|
|
415
|
-
|
|
416
624
|
try:
|
|
417
625
|
# Using the registry, get the tool wrapper for the requested framework
|
|
418
626
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
@@ -424,7 +632,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
424
632
|
raise
|
|
425
633
|
|
|
426
634
|
@override
|
|
427
|
-
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
635
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> None:
|
|
428
636
|
|
|
429
637
|
if (name in self._llms):
|
|
430
638
|
raise ValueError(f"LLM `{name}` already exists in the list of LLMs")
|
|
@@ -440,7 +648,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
440
648
|
raise
|
|
441
649
|
|
|
442
650
|
@override
|
|
443
|
-
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
|
|
651
|
+
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
444
652
|
|
|
445
653
|
if (llm_name not in self._llms):
|
|
446
654
|
raise ValueError(f"LLM `{llm_name}` not found")
|
|
@@ -540,7 +748,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
540
748
|
return self._auth_providers[auth_provider_name].instance
|
|
541
749
|
|
|
542
750
|
@override
|
|
543
|
-
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
751
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
|
|
544
752
|
|
|
545
753
|
if (name in self._embedders):
|
|
546
754
|
raise ValueError(f"Embedder `{name}` already exists in the list of embedders")
|
|
@@ -600,7 +808,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
600
808
|
return info_obj
|
|
601
809
|
|
|
602
810
|
@override
|
|
603
|
-
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
811
|
+
async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
604
812
|
"""
|
|
605
813
|
Return the instantiated memory client for the given name.
|
|
606
814
|
"""
|
|
@@ -646,7 +854,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
646
854
|
return self._object_stores[object_store_name].config
|
|
647
855
|
|
|
648
856
|
@override
|
|
649
|
-
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
857
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
|
|
650
858
|
|
|
651
859
|
if (name in self._retrievers):
|
|
652
860
|
raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")
|
|
@@ -662,8 +870,6 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
662
870
|
logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
|
|
663
871
|
raise
|
|
664
872
|
|
|
665
|
-
# return info_obj
|
|
666
|
-
|
|
667
873
|
@override
|
|
668
874
|
async def get_retriever(self,
|
|
669
875
|
retriever_name: str | RetrieverRef,
|
|
@@ -696,9 +902,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
696
902
|
|
|
697
903
|
return self._retrievers[retriever_name].config
|
|
698
904
|
|
|
699
|
-
@experimental(feature_name="TTC")
|
|
700
905
|
@override
|
|
701
|
-
|
|
906
|
+
@experimental(feature_name="TTC")
|
|
907
|
+
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
|
|
702
908
|
if (name in self._ttc_strategies):
|
|
703
909
|
raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")
|
|
704
910
|
|
|
@@ -885,29 +1091,40 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
885
1091
|
|
|
886
1092
|
# Instantiate a the llm
|
|
887
1093
|
if component_instance.component_group == ComponentGroup.LLMS:
|
|
888
|
-
await self.add_llm(component_instance.name, component_instance.config)
|
|
1094
|
+
await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
|
|
889
1095
|
# Instantiate a the embedder
|
|
890
1096
|
elif component_instance.component_group == ComponentGroup.EMBEDDERS:
|
|
891
|
-
await self.add_embedder(component_instance.name,
|
|
1097
|
+
await self.add_embedder(component_instance.name,
|
|
1098
|
+
cast(EmbedderBaseConfig, component_instance.config))
|
|
892
1099
|
# Instantiate a memory client
|
|
893
1100
|
elif component_instance.component_group == ComponentGroup.MEMORY:
|
|
894
|
-
await self.add_memory_client(component_instance.name,
|
|
895
|
-
|
|
1101
|
+
await self.add_memory_client(component_instance.name,
|
|
1102
|
+
cast(MemoryBaseConfig, component_instance.config))
|
|
1103
|
+
# Instantiate a object store client
|
|
896
1104
|
elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
|
|
897
|
-
await self.add_object_store(component_instance.name,
|
|
1105
|
+
await self.add_object_store(component_instance.name,
|
|
1106
|
+
cast(ObjectStoreBaseConfig, component_instance.config))
|
|
898
1107
|
# Instantiate a retriever client
|
|
899
1108
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
900
|
-
await self.add_retriever(component_instance.name,
|
|
1109
|
+
await self.add_retriever(component_instance.name,
|
|
1110
|
+
cast(RetrieverBaseConfig, component_instance.config))
|
|
1111
|
+
# Instantiate a function group
|
|
1112
|
+
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1113
|
+
await self.add_function_group(component_instance.name,
|
|
1114
|
+
cast(FunctionGroupBaseConfig, component_instance.config))
|
|
901
1115
|
# Instantiate a function
|
|
902
1116
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
903
1117
|
# If the function is the root, set it as the workflow later
|
|
904
1118
|
if (not component_instance.is_root):
|
|
905
|
-
await self.add_function(component_instance.name,
|
|
1119
|
+
await self.add_function(component_instance.name,
|
|
1120
|
+
cast(FunctionBaseConfig, component_instance.config))
|
|
906
1121
|
elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
|
|
907
|
-
await self.add_ttc_strategy(component_instance.name,
|
|
1122
|
+
await self.add_ttc_strategy(component_instance.name,
|
|
1123
|
+
cast(TTCStrategyBaseConfig, component_instance.config))
|
|
908
1124
|
|
|
909
1125
|
elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
|
|
910
|
-
await self.add_auth_provider(component_instance.name,
|
|
1126
|
+
await self.add_auth_provider(component_instance.name,
|
|
1127
|
+
cast(AuthProviderBaseConfig, component_instance.config))
|
|
911
1128
|
else:
|
|
912
1129
|
raise ValueError(f"Unknown component group {component_instance.component_group}")
|
|
913
1130
|
|
|
@@ -957,18 +1174,35 @@ class ChildBuilder(Builder):
|
|
|
957
1174
|
return await self._workflow_builder.add_function(name, config)
|
|
958
1175
|
|
|
959
1176
|
@override
|
|
960
|
-
def
|
|
1177
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1178
|
+
return await self._workflow_builder.add_function_group(name, config)
|
|
1179
|
+
|
|
1180
|
+
@override
|
|
1181
|
+
async def get_function(self, name: str) -> Function:
|
|
961
1182
|
# If a function tries to get another function, we assume it uses it
|
|
962
|
-
fn = self._workflow_builder.get_function(name)
|
|
1183
|
+
fn = await self._workflow_builder.get_function(name)
|
|
963
1184
|
|
|
964
1185
|
self._dependencies.add_function(name)
|
|
965
1186
|
|
|
966
1187
|
return fn
|
|
967
1188
|
|
|
1189
|
+
@override
|
|
1190
|
+
async def get_function_group(self, name: str) -> FunctionGroup:
|
|
1191
|
+
# If a function tries to get a function group, we assume it uses it
|
|
1192
|
+
function_group = await self._workflow_builder.get_function_group(name)
|
|
1193
|
+
|
|
1194
|
+
self._dependencies.add_function_group(name)
|
|
1195
|
+
|
|
1196
|
+
return function_group
|
|
1197
|
+
|
|
968
1198
|
@override
|
|
969
1199
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
970
1200
|
return self._workflow_builder.get_function_config(name)
|
|
971
1201
|
|
|
1202
|
+
@override
|
|
1203
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1204
|
+
return self._workflow_builder.get_function_group_config(name)
|
|
1205
|
+
|
|
972
1206
|
@override
|
|
973
1207
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
974
1208
|
return await self._workflow_builder.set_workflow(config)
|
|
@@ -982,20 +1216,33 @@ class ChildBuilder(Builder):
|
|
|
982
1216
|
return self._workflow_builder.get_workflow_config()
|
|
983
1217
|
|
|
984
1218
|
@override
|
|
985
|
-
def
|
|
1219
|
+
async def get_tools(self,
|
|
1220
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1221
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1222
|
+
tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1223
|
+
for tool_name in tool_names:
|
|
1224
|
+
if tool_name in self._workflow_builder._function_groups:
|
|
1225
|
+
self._dependencies.add_function_group(tool_name)
|
|
1226
|
+
else:
|
|
1227
|
+
self._dependencies.add_function(tool_name)
|
|
1228
|
+
return tools
|
|
1229
|
+
|
|
1230
|
+
@override
|
|
1231
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
986
1232
|
# If a function tries to get another function as a tool, we assume it uses it
|
|
987
|
-
fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
1233
|
+
fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
988
1234
|
|
|
989
1235
|
self._dependencies.add_function(fn_name)
|
|
990
1236
|
|
|
991
1237
|
return fn
|
|
992
1238
|
|
|
993
1239
|
@override
|
|
994
|
-
async def add_llm(self, name: str, config: LLMBaseConfig):
|
|
1240
|
+
async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
|
|
995
1241
|
return await self._workflow_builder.add_llm(name, config)
|
|
996
1242
|
|
|
1243
|
+
@experimental(feature_name="Authentication")
|
|
997
1244
|
@override
|
|
998
|
-
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig):
|
|
1245
|
+
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
999
1246
|
return await self._workflow_builder.add_auth_provider(name, config)
|
|
1000
1247
|
|
|
1001
1248
|
@override
|
|
@@ -1003,7 +1250,7 @@ class ChildBuilder(Builder):
|
|
|
1003
1250
|
return await self._workflow_builder.get_auth_provider(auth_provider_name)
|
|
1004
1251
|
|
|
1005
1252
|
@override
|
|
1006
|
-
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str):
|
|
1253
|
+
async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1007
1254
|
llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
|
|
1008
1255
|
|
|
1009
1256
|
self._dependencies.add_llm(llm_name)
|
|
@@ -1015,11 +1262,11 @@ class ChildBuilder(Builder):
|
|
|
1015
1262
|
return self._workflow_builder.get_llm_config(llm_name)
|
|
1016
1263
|
|
|
1017
1264
|
@override
|
|
1018
|
-
async def add_embedder(self, name: str, config: EmbedderBaseConfig):
|
|
1019
|
-
|
|
1265
|
+
async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
|
|
1266
|
+
await self._workflow_builder.add_embedder(name, config)
|
|
1020
1267
|
|
|
1021
1268
|
@override
|
|
1022
|
-
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str):
|
|
1269
|
+
async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
1023
1270
|
embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
|
|
1024
1271
|
|
|
1025
1272
|
self._dependencies.add_embedder(embedder_name)
|
|
@@ -1035,11 +1282,11 @@ class ChildBuilder(Builder):
|
|
|
1035
1282
|
return await self._workflow_builder.add_memory_client(name, config)
|
|
1036
1283
|
|
|
1037
1284
|
@override
|
|
1038
|
-
def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1285
|
+
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
1039
1286
|
"""
|
|
1040
1287
|
Return the instantiated memory client for the given name.
|
|
1041
1288
|
"""
|
|
1042
|
-
memory_client = self._workflow_builder.get_memory_client(memory_name)
|
|
1289
|
+
memory_client = await self._workflow_builder.get_memory_client(memory_name)
|
|
1043
1290
|
|
|
1044
1291
|
self._dependencies.add_memory_client(memory_name)
|
|
1045
1292
|
|
|
@@ -1069,8 +1316,9 @@ class ChildBuilder(Builder):
|
|
|
1069
1316
|
return self._workflow_builder.get_object_store_config(object_store_name)
|
|
1070
1317
|
|
|
1071
1318
|
@override
|
|
1072
|
-
|
|
1073
|
-
|
|
1319
|
+
@experimental(feature_name="TTC")
|
|
1320
|
+
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
|
1321
|
+
await self._workflow_builder.add_ttc_strategy(name, config)
|
|
1074
1322
|
|
|
1075
1323
|
@override
|
|
1076
1324
|
async def get_ttc_strategy(self,
|
|
@@ -1091,11 +1339,11 @@ class ChildBuilder(Builder):
|
|
|
1091
1339
|
stage_type=stage_type)
|
|
1092
1340
|
|
|
1093
1341
|
@override
|
|
1094
|
-
async def add_retriever(self, name: str, config: RetrieverBaseConfig):
|
|
1095
|
-
|
|
1342
|
+
async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
|
|
1343
|
+
await self._workflow_builder.add_retriever(name, config)
|
|
1096
1344
|
|
|
1097
1345
|
@override
|
|
1098
|
-
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
1346
|
+
async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
|
|
1099
1347
|
if not wrapper_type:
|
|
1100
1348
|
return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
|
|
1101
1349
|
return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
|
|
@@ -1111,3 +1359,7 @@ class ChildBuilder(Builder):
|
|
|
1111
1359
|
@override
|
|
1112
1360
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1113
1361
|
return self._workflow_builder.get_function_dependencies(fn_name)
|
|
1362
|
+
|
|
1363
|
+
@override
|
|
1364
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1365
|
+
return self._workflow_builder.get_function_group_dependencies(fn_name)
|
|
@@ -84,7 +84,7 @@ class LayeredConfig:
|
|
|
84
84
|
if lower_value not in ['true', 'false']:
|
|
85
85
|
raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
|
|
86
86
|
value = lower_value == 'true'
|
|
87
|
-
elif isinstance(original_value,
|
|
87
|
+
elif isinstance(original_value, int | float):
|
|
88
88
|
value = type(original_value)(value)
|
|
89
89
|
elif isinstance(original_value, list):
|
|
90
90
|
value = [v.strip() for v in value.split(',')]
|