nvidia-nat 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (257)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +27 -18
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +81 -50
  7. nat/agent/react_agent/register.py +59 -40
  8. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +327 -149
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +64 -46
  13. nat/agent/tool_calling_agent/agent.py +152 -29
  14. nat/agent/tool_calling_agent/register.py +61 -38
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +10 -6
  24. nat/builder/context.py +70 -18
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/intermediate_step_manager.py +6 -2
  32. nat/builder/user_interaction_manager.py +2 -2
  33. nat/builder/workflow.py +13 -1
  34. nat/builder/workflow_builder.py +327 -79
  35. nat/cli/cli_utils/config_override.py +2 -2
  36. nat/cli/commands/evaluate.py +1 -1
  37. nat/cli/commands/info/info.py +16 -6
  38. nat/cli/commands/info/list_channels.py +1 -1
  39. nat/cli/commands/info/list_components.py +7 -8
  40. nat/cli/commands/mcp/__init__.py +14 -0
  41. nat/cli/commands/mcp/mcp.py +986 -0
  42. nat/cli/commands/object_store/__init__.py +14 -0
  43. nat/cli/commands/object_store/object_store.py +227 -0
  44. nat/cli/commands/optimize.py +90 -0
  45. nat/cli/commands/registry/publish.py +2 -2
  46. nat/cli/commands/registry/pull.py +2 -2
  47. nat/cli/commands/registry/remove.py +2 -2
  48. nat/cli/commands/registry/search.py +15 -17
  49. nat/cli/commands/start.py +16 -5
  50. nat/cli/commands/uninstall.py +1 -1
  51. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  52. nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
  53. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  54. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  55. nat/cli/commands/workflow/workflow_commands.py +105 -19
  56. nat/cli/entrypoint.py +17 -11
  57. nat/cli/main.py +3 -0
  58. nat/cli/register_workflow.py +38 -4
  59. nat/cli/type_registry.py +79 -10
  60. nat/control_flow/__init__.py +0 -0
  61. nat/control_flow/register.py +20 -0
  62. nat/control_flow/router_agent/__init__.py +0 -0
  63. nat/control_flow/router_agent/agent.py +329 -0
  64. nat/control_flow/router_agent/prompt.py +48 -0
  65. nat/control_flow/router_agent/register.py +91 -0
  66. nat/control_flow/sequential_executor.py +166 -0
  67. nat/data_models/agent.py +34 -0
  68. nat/data_models/api_server.py +196 -67
  69. nat/data_models/authentication.py +23 -9
  70. nat/data_models/common.py +1 -1
  71. nat/data_models/component.py +2 -0
  72. nat/data_models/component_ref.py +11 -0
  73. nat/data_models/config.py +42 -18
  74. nat/data_models/dataset_handler.py +1 -1
  75. nat/data_models/discovery_metadata.py +4 -4
  76. nat/data_models/evaluate.py +4 -1
  77. nat/data_models/function.py +34 -0
  78. nat/data_models/function_dependencies.py +14 -6
  79. nat/data_models/gated_field_mixin.py +242 -0
  80. nat/data_models/intermediate_step.py +3 -3
  81. nat/data_models/optimizable.py +119 -0
  82. nat/data_models/optimizer.py +149 -0
  83. nat/data_models/span.py +41 -3
  84. nat/data_models/swe_bench_model.py +1 -1
  85. nat/data_models/temperature_mixin.py +44 -0
  86. nat/data_models/thinking_mixin.py +86 -0
  87. nat/data_models/top_p_mixin.py +44 -0
  88. nat/embedder/azure_openai_embedder.py +46 -0
  89. nat/embedder/nim_embedder.py +1 -1
  90. nat/embedder/openai_embedder.py +2 -3
  91. nat/embedder/register.py +1 -1
  92. nat/eval/config.py +3 -1
  93. nat/eval/dataset_handler/dataset_handler.py +71 -7
  94. nat/eval/evaluate.py +86 -31
  95. nat/eval/evaluator/base_evaluator.py +1 -1
  96. nat/eval/evaluator/evaluator_model.py +13 -0
  97. nat/eval/intermediate_step_adapter.py +1 -1
  98. nat/eval/rag_evaluator/evaluate.py +9 -6
  99. nat/eval/rag_evaluator/register.py +3 -3
  100. nat/eval/register.py +4 -1
  101. nat/eval/remote_workflow.py +3 -3
  102. nat/eval/runtime_evaluator/__init__.py +14 -0
  103. nat/eval/runtime_evaluator/evaluate.py +123 -0
  104. nat/eval/runtime_evaluator/register.py +100 -0
  105. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  106. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  107. nat/eval/trajectory_evaluator/register.py +1 -1
  108. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  109. nat/eval/utils/eval_trace_ctx.py +89 -0
  110. nat/eval/utils/weave_eval.py +18 -9
  111. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  112. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  113. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  114. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  115. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  116. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  117. nat/experimental/test_time_compute/register.py +0 -1
  118. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  119. nat/front_ends/console/authentication_flow_handler.py +82 -30
  120. nat/front_ends/console/console_front_end_plugin.py +19 -7
  121. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  122. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  123. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  124. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  125. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  126. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  127. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
  128. nat/front_ends/fastapi/job_store.py +518 -99
  129. nat/front_ends/fastapi/main.py +11 -19
  130. nat/front_ends/fastapi/message_handler.py +74 -50
  131. nat/front_ends/fastapi/message_validator.py +20 -21
  132. nat/front_ends/fastapi/response_helpers.py +4 -4
  133. nat/front_ends/fastapi/step_adaptor.py +2 -2
  134. nat/front_ends/fastapi/utils.py +57 -0
  135. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  136. nat/front_ends/mcp/mcp_front_end_config.py +47 -3
  137. nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
  138. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
  139. nat/front_ends/mcp/tool_converter.py +44 -14
  140. nat/front_ends/register.py +0 -1
  141. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  142. nat/llm/aws_bedrock_llm.py +24 -12
  143. nat/llm/azure_openai_llm.py +57 -0
  144. nat/llm/litellm_llm.py +69 -0
  145. nat/llm/nim_llm.py +20 -8
  146. nat/llm/openai_llm.py +14 -6
  147. nat/llm/register.py +5 -1
  148. nat/llm/utils/env_config_value.py +2 -3
  149. nat/llm/utils/thinking.py +215 -0
  150. nat/meta/pypi.md +9 -9
  151. nat/object_store/register.py +0 -1
  152. nat/observability/exporter/base_exporter.py +3 -3
  153. nat/observability/exporter/file_exporter.py +1 -1
  154. nat/observability/exporter/processing_exporter.py +309 -81
  155. nat/observability/exporter/span_exporter.py +35 -15
  156. nat/observability/exporter_manager.py +7 -7
  157. nat/observability/mixin/file_mixin.py +7 -7
  158. nat/observability/mixin/redaction_config_mixin.py +42 -0
  159. nat/observability/mixin/tagging_config_mixin.py +62 -0
  160. nat/observability/mixin/type_introspection_mixin.py +420 -107
  161. nat/observability/processor/batching_processor.py +5 -7
  162. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  163. nat/observability/processor/processor.py +3 -0
  164. nat/observability/processor/processor_factory.py +70 -0
  165. nat/observability/processor/redaction/__init__.py +24 -0
  166. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  167. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  168. nat/observability/processor/redaction/redaction_processor.py +177 -0
  169. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  170. nat/observability/processor/span_tagging_processor.py +68 -0
  171. nat/observability/register.py +22 -4
  172. nat/profiler/calc/calc_runner.py +3 -4
  173. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  174. nat/profiler/callbacks/langchain_callback_handler.py +14 -7
  175. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  176. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  177. nat/profiler/data_frame_row.py +1 -1
  178. nat/profiler/decorators/framework_wrapper.py +62 -13
  179. nat/profiler/decorators/function_tracking.py +160 -3
  180. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  181. nat/profiler/forecasting/models/linear_model.py +1 -1
  182. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  183. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  184. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  185. nat/profiler/inference_optimization/data_models.py +3 -3
  186. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  187. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  188. nat/profiler/parameter_optimization/__init__.py +0 -0
  189. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  190. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  191. nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
  192. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  193. nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
  194. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  195. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  196. nat/profiler/profile_runner.py +14 -9
  197. nat/profiler/utils.py +4 -2
  198. nat/registry_handlers/local/local_handler.py +2 -2
  199. nat/registry_handlers/package_utils.py +1 -2
  200. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  201. nat/registry_handlers/register.py +3 -4
  202. nat/registry_handlers/rest/rest_handler.py +12 -13
  203. nat/retriever/milvus/retriever.py +2 -2
  204. nat/retriever/nemo_retriever/retriever.py +1 -1
  205. nat/retriever/register.py +0 -1
  206. nat/runtime/loader.py +2 -2
  207. nat/runtime/runner.py +105 -8
  208. nat/runtime/session.py +69 -8
  209. nat/settings/global_settings.py +16 -5
  210. nat/tool/chat_completion.py +5 -2
  211. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  212. nat/tool/datetime_tools.py +49 -9
  213. nat/tool/document_search.py +2 -2
  214. nat/tool/github_tools.py +450 -0
  215. nat/tool/memory_tools/add_memory_tool.py +3 -3
  216. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  217. nat/tool/memory_tools/get_memory_tool.py +4 -4
  218. nat/tool/nvidia_rag.py +1 -1
  219. nat/tool/register.py +2 -9
  220. nat/tool/retriever.py +3 -2
  221. nat/utils/callable_utils.py +70 -0
  222. nat/utils/data_models/schema_validator.py +3 -3
  223. nat/utils/decorators.py +210 -0
  224. nat/utils/exception_handlers/automatic_retries.py +104 -51
  225. nat/utils/exception_handlers/schemas.py +1 -1
  226. nat/utils/io/yaml_tools.py +2 -2
  227. nat/utils/log_levels.py +25 -0
  228. nat/utils/reactive/base/observable_base.py +2 -2
  229. nat/utils/reactive/base/observer_base.py +1 -1
  230. nat/utils/reactive/observable.py +2 -2
  231. nat/utils/reactive/observer.py +4 -4
  232. nat/utils/reactive/subscription.py +1 -1
  233. nat/utils/settings/global_settings.py +6 -8
  234. nat/utils/type_converter.py +12 -3
  235. nat/utils/type_utils.py +9 -5
  236. nvidia_nat-1.3.0.dist-info/METADATA +195 -0
  237. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
  238. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
  239. nat/cli/commands/info/list_mcp.py +0 -304
  240. nat/tool/github_tools/create_github_commit.py +0 -133
  241. nat/tool/github_tools/create_github_issue.py +0 -87
  242. nat/tool/github_tools/create_github_pr.py +0 -106
  243. nat/tool/github_tools/get_github_file.py +0 -106
  244. nat/tool/github_tools/get_github_issue.py +0 -166
  245. nat/tool/github_tools/get_github_pr.py +0 -256
  246. nat/tool/github_tools/update_github_issue.py +0 -100
  247. nat/tool/mcp/exceptions.py +0 -142
  248. nat/tool/mcp/mcp_client.py +0 -255
  249. nat/tool/mcp/mcp_tool.py +0 -96
  250. nat/utils/exception_handlers/mcp.py +0 -211
  251. nvidia_nat-1.2.1.dist-info/METADATA +0 -365
  252. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  253. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  254. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
  255. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  256. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
  257. {nvidia_nat-1.2.1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
@@ -13,13 +13,17 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import dataclasses
17
18
  import inspect
18
19
  import logging
20
+ import typing
19
21
  import warnings
22
+ from collections.abc import Sequence
20
23
  from contextlib import AbstractAsyncContextManager
21
24
  from contextlib import AsyncExitStack
22
25
  from contextlib import asynccontextmanager
26
+ from typing import cast
23
27
 
24
28
  from nat.authentication.interfaces import AuthProviderBase
25
29
  from nat.builder.builder import Builder
@@ -31,6 +35,7 @@ from nat.builder.context import ContextState
31
35
  from nat.builder.embedder import EmbedderProviderInfo
32
36
  from nat.builder.framework_enum import LLMFrameworkEnum
33
37
  from nat.builder.function import Function
38
+ from nat.builder.function import FunctionGroup
34
39
  from nat.builder.function import LambdaFunction
35
40
  from nat.builder.function_info import FunctionInfo
36
41
  from nat.builder.llm import LLMProviderInfo
@@ -42,6 +47,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
42
47
  from nat.data_models.component import ComponentGroup
43
48
  from nat.data_models.component_ref import AuthenticationRef
44
49
  from nat.data_models.component_ref import EmbedderRef
50
+ from nat.data_models.component_ref import FunctionGroupRef
45
51
  from nat.data_models.component_ref import FunctionRef
46
52
  from nat.data_models.component_ref import LLMRef
47
53
  from nat.data_models.component_ref import MemoryRef
@@ -52,6 +58,7 @@ from nat.data_models.config import Config
52
58
  from nat.data_models.config import GeneralConfig
53
59
  from nat.data_models.embedder import EmbedderBaseConfig
54
60
  from nat.data_models.function import FunctionBaseConfig
61
+ from nat.data_models.function import FunctionGroupBaseConfig
55
62
  from nat.data_models.function_dependencies import FunctionDependencies
56
63
  from nat.data_models.llm import LLMBaseConfig
57
64
  from nat.data_models.memory import MemoryBaseConfig
@@ -68,6 +75,7 @@ from nat.object_store.interfaces import ObjectStore
68
75
  from nat.observability.exporter.base_exporter import BaseExporter
69
76
  from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
70
77
  from nat.profiler.utils import detect_llm_frameworks_in_build_fn
78
+ from nat.retriever.interface import Retriever
71
79
  from nat.utils.type_utils import override
72
80
 
73
81
  logger = logging.getLogger(__name__)
@@ -85,6 +93,12 @@ class ConfiguredFunction:
85
93
  instance: Function
86
94
 
87
95
 
96
+ @dataclasses.dataclass
97
+ class ConfiguredFunctionGroup:
98
+ config: FunctionGroupBaseConfig
99
+ instance: FunctionGroup
100
+
101
+
88
102
  @dataclasses.dataclass
89
103
  class ConfiguredLLM:
90
104
  config: LLMBaseConfig
@@ -127,7 +141,6 @@ class ConfiguredTTCStrategy:
127
141
  instance: StrategyBase
128
142
 
129
143
 
130
- # pylint: disable=too-many-public-methods
131
144
  class WorkflowBuilder(Builder, AbstractAsyncContextManager):
132
145
 
133
146
  def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
@@ -143,9 +156,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
143
156
  self._registry = registry
144
157
 
145
158
  self._logging_handlers: dict[str, logging.Handler] = {}
159
+ self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
146
160
  self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
147
161
 
148
162
  self._functions: dict[str, ConfiguredFunction] = {}
163
+ self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
149
164
  self._workflow: ConfiguredFunction | None = None
150
165
 
151
166
  self._llms: dict[str, ConfiguredLLM] = {}
@@ -162,7 +177,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
162
177
 
163
178
  # Create a mapping to track function name -> other function names it depends on
164
179
  self.function_dependencies: dict[str, FunctionDependencies] = {}
180
+ self.function_group_dependencies: dict[str, FunctionDependencies] = {}
165
181
  self.current_function_building: str | None = None
182
+ self.current_function_group_building: str | None = None
166
183
 
167
184
  async def __aenter__(self):
168
185
 
@@ -171,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
171
188
  # Get the telemetry info from the config
172
189
  telemetry_config = self.general_config.telemetry
173
190
 
191
+ # If we have logging configuration, we need to manage the root logger properly
192
+ root_logger = logging.getLogger()
193
+
194
+ # Collect configured handler types to determine if we need to adjust existing handlers
195
+ # This is somewhat of a hack by inspecting the class name of the config object
196
+ has_console_handler = any(
197
+ hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
198
+ for config in telemetry_config.logging.values())
199
+
174
200
  for key, logging_config in telemetry_config.logging.items():
175
201
  # Use the same pattern as tracing, but for logging
176
202
  logging_info = self._registry.get_logging_method(type(logging_config))
@@ -184,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
184
210
  self._logging_handlers[key] = handler
185
211
 
186
212
  # Now attach to NAT's root logger
187
- logging.getLogger().addHandler(handler)
213
+ root_logger.addHandler(handler)
214
+
215
+ # If we added logging handlers, manage existing handlers appropriately
216
+ if self._logging_handlers:
217
+ min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
218
+
219
+ # Ensure the root logger level allows messages through
220
+ root_logger.level = max(root_logger.level, min_handler_level)
221
+
222
+ # If a console handler is configured, adjust or remove default CLI handlers
223
+ # to avoid duplicate output while preserving workflow visibility
224
+ if has_console_handler:
225
+ # Remove existing StreamHandlers that are not the newly configured ones
226
+ for handler in root_logger.handlers[:]:
227
+ if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
228
+ self._removed_root_handlers.append((handler, handler.level))
229
+ root_logger.removeHandler(handler)
230
+ else:
231
+ # No console handler configured, but adjust existing handler levels
232
+ # to respect the minimum configured level for file/other handlers
233
+ for handler in root_logger.handlers[:]:
234
+ if type(handler) is logging.StreamHandler:
235
+ old_level = handler.level
236
+ handler.setLevel(min_handler_level)
237
+ self._removed_root_handlers.append((handler, old_level))
188
238
 
189
239
  # Add the telemetry exporters
190
240
  for key, telemetry_exporter_config in telemetry_config.tracing.items():
@@ -196,12 +246,21 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
196
246
 
197
247
  assert self._exit_stack is not None, "Exit stack not initialized"
198
248
 
199
- for _, handler in self._logging_handlers.items():
200
- logging.getLogger().removeHandler(handler)
249
+ root_logger = logging.getLogger()
250
+
251
+ # Remove custom logging handlers
252
+ for handler in self._logging_handlers.values():
253
+ root_logger.removeHandler(handler)
254
+
255
+ # Restore original handlers and their levels
256
+ for handler, old_level in self._removed_root_handlers:
257
+ if handler not in root_logger.handlers:
258
+ root_logger.addHandler(handler)
259
+ handler.setLevel(old_level)
201
260
 
202
261
  await self._exit_stack.__aexit__(*exc_details)
203
262
 
204
- def build(self, entry_function: str | None = None) -> Workflow:
263
+ async def build(self, entry_function: str | None = None) -> Workflow:
205
264
  """
206
265
  Creates an instance of a workflow object using the added components and the desired entry function.
207
266
 
@@ -225,12 +284,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
225
284
  if (self._workflow is None):
226
285
  raise ValueError("Must set a workflow before building")
227
286
 
287
+ # Set of all functions which are "included" by function groups
288
+ included_functions = set()
289
+ # Dictionary of function configs
290
+ function_configs = dict()
291
+ # Dictionary of function group configs
292
+ function_group_configs = dict()
293
+ # Dictionary of function instances
294
+ function_instances = dict()
295
+ # Dictionary of function group instances
296
+ function_group_instances = dict()
297
+
298
+ for k, v in self._function_groups.items():
299
+ included_functions.update((await v.instance.get_included_functions()).keys())
300
+ function_group_configs[k] = v.config
301
+ function_group_instances[k] = v.instance
302
+
303
+ # Function configs need to be restricted to only the functions that are not in a function group
304
+ for k, v in self._functions.items():
305
+ if k not in included_functions:
306
+ function_configs[k] = v.config
307
+ function_instances[k] = v.instance
308
+
228
309
  # Build the config from the added objects
229
310
  config = Config(general=self.general_config,
230
- functions={
231
- k: v.config
232
- for k, v in self._functions.items()
233
- },
311
+ functions=function_configs,
312
+ function_groups=function_group_configs,
234
313
  workflow=self._workflow.config,
235
314
  llms={
236
315
  k: v.config
@@ -260,14 +339,12 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
260
339
  if (entry_function is None):
261
340
  entry_fn_obj = self.get_workflow()
262
341
  else:
263
- entry_fn_obj = self.get_function(entry_function)
342
+ entry_fn_obj = await self.get_function(entry_function)
264
343
 
265
344
  workflow = Workflow.from_entry_fn(config=config,
266
345
  entry_fn=entry_fn_obj,
267
- functions={
268
- k: v.instance
269
- for k, v in self._functions.items()
270
- },
346
+ functions=function_instances,
347
+ function_groups=function_group_instances,
271
348
  llms={
272
349
  k: v.instance
273
350
  for k, v in self._llms.items()
@@ -348,11 +425,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
348
425
 
349
426
  return ConfiguredFunction(config=config, instance=build_result)
350
427
 
428
+ async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
429
+ """Build a function group from the provided configuration.
430
+
431
+ Args:
432
+ name: The name of the function group
433
+ config: The function group configuration
434
+
435
+ Returns:
436
+ ConfiguredFunctionGroup: The built function group
437
+
438
+ Raises:
439
+ ValueError: If the function group builder returns invalid results
440
+ """
441
+ registration = self._registry.get_function_group(type(config))
442
+
443
+ inner_builder = ChildBuilder(self)
444
+
445
+ # Build the function group - use the same wrapping pattern as _build_function
446
+ llms = {k: v.instance for k, v in self._llms.items()}
447
+ function_frameworks = detect_llm_frameworks_in_build_fn(registration)
448
+
449
+ build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
450
+
451
+ # Set the currently building function group so the ChildBuilder can track dependencies
452
+ self.current_function_group_building = config.type
453
+ # Empty set of dependencies for the current function group
454
+ self.function_group_dependencies[config.type] = FunctionDependencies()
455
+
456
+ build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
457
+
458
+ self.function_group_dependencies[name] = inner_builder.dependencies
459
+
460
+ if not isinstance(build_result, FunctionGroup):
461
+ raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
462
+ f"Got {type(build_result)}")
463
+
464
+ # set the instance name for the function group based on the workflow-provided name
465
+ build_result.set_instance_name(name)
466
+ return ConfiguredFunctionGroup(config=config, instance=build_result)
467
+
351
468
  @override
352
469
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
470
+ if isinstance(name, FunctionRef):
471
+ name = str(name)
353
472
 
354
- if (name in self._functions):
355
- raise ValueError(f"Function `{name}` already exists in the list of functions")
473
+ if (name in self._functions or name in self._function_groups):
474
+ raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
356
475
 
357
476
  build_result = await self._build_function(name=name, config=config)
358
477
 
@@ -361,20 +480,67 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
361
480
  return build_result.instance
362
481
 
363
482
  @override
364
- def get_function(self, name: str | FunctionRef) -> Function:
483
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
484
+ if isinstance(name, FunctionGroupRef):
485
+ name = str(name)
486
+
487
+ if (name in self._function_groups or name in self._functions):
488
+ raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
489
+
490
+ # Build the function group
491
+ build_result = await self._build_function_group(name=name, config=config)
492
+
493
+ self._function_groups[name] = build_result
494
+
495
+ # If the function group exposes functions, add them to the global function registry
496
+ # If the function group exposes functions, record and add them to the registry
497
+ included_functions = await build_result.instance.get_included_functions()
498
+ for k in included_functions:
499
+ if k in self._functions:
500
+ raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
501
+ self._functions.update({
502
+ k: ConfiguredFunction(config=v.config, instance=v)
503
+ for k, v in included_functions.items()
504
+ })
365
505
 
506
+ return build_result.instance
507
+
508
+ @override
509
+ async def get_function(self, name: str | FunctionRef) -> Function:
510
+ if isinstance(name, FunctionRef):
511
+ name = str(name)
366
512
  if name not in self._functions:
367
513
  raise ValueError(f"Function `{name}` not found")
368
514
 
369
515
  return self._functions[name].instance
370
516
 
517
+ @override
518
+ async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
519
+ if isinstance(name, FunctionGroupRef):
520
+ name = str(name)
521
+ if name not in self._function_groups:
522
+ raise ValueError(f"Function group `{name}` not found")
523
+
524
+ return self._function_groups[name].instance
525
+
371
526
  @override
372
527
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
528
+ if isinstance(name, FunctionRef):
529
+ name = str(name)
373
530
  if name not in self._functions:
374
531
  raise ValueError(f"Function `{name}` not found")
375
532
 
376
533
  return self._functions[name].config
377
534
 
535
+ @override
536
+ def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
537
+ if isinstance(name, FunctionGroupRef):
538
+ name = str(name)
539
+ if name not in self._function_groups:
540
+ raise ValueError(f"Function group `{name}` not found")
541
+
542
+ return self._function_groups[name].config
543
+
378
544
  @override
379
545
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
380
546
 
@@ -404,16 +570,57 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
404
570
 
405
571
  @override
406
572
  def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
573
+ if isinstance(fn_name, FunctionRef):
574
+ fn_name = str(fn_name)
407
575
  return self.function_dependencies[fn_name]
408
576
 
409
577
  @override
410
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
578
+ def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
579
+ if isinstance(fn_name, FunctionGroupRef):
580
+ fn_name = str(fn_name)
581
+ return self.function_group_dependencies[fn_name]
411
582
 
583
+ @override
584
+ async def get_tools(self,
585
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
586
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
587
+
588
+ unique = set(tool_names)
589
+ if len(unique) != len(tool_names):
590
+ raise ValueError("Tool names must be unique")
591
+
592
+ async def _get_tools(n: str | FunctionRef | FunctionGroupRef):
593
+ tools = []
594
+ is_function_group_ref = isinstance(n, FunctionGroupRef)
595
+ if isinstance(n, FunctionRef) or is_function_group_ref:
596
+ n = str(n)
597
+ if n not in self._function_groups:
598
+ # the passed tool name is probably a function, but first check if it's a function group
599
+ if is_function_group_ref:
600
+ raise ValueError(f"Function group `{n}` not found in the list of function groups")
601
+ tools.append(await self.get_tool(n, wrapper_type))
602
+ else:
603
+ tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
604
+ current_function_group = self._function_groups[n]
605
+ for fn_name, fn_instance in (await current_function_group.instance.get_accessible_functions()).items():
606
+ try:
607
+ tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
608
+ except Exception:
609
+ logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
610
+ raise
611
+ return tools
612
+
613
+ tool_lists = await asyncio.gather(*[_get_tools(n) for n in tool_names])
614
+ # Flatten the list of lists into a single list
615
+ return [tool for tools in tool_lists for tool in tools]
616
+
617
+ @override
618
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
619
+ if isinstance(fn_name, FunctionRef):
620
+ fn_name = str(fn_name)
412
621
  if fn_name not in self._functions:
413
622
  raise ValueError(f"Function `{fn_name}` not found in list of functions")
414
-
415
623
  fn = self._functions[fn_name]
416
-
417
624
  try:
418
625
  # Using the registry, get the tool wrapper for the requested framework
419
626
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
@@ -421,11 +628,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
421
628
  # Wrap in the correct wrapper
422
629
  return tool_wrapper_reg.build_fn(fn_name, fn.instance, self)
423
630
  except Exception as e:
424
- logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
425
- raise e
631
+ logger.error("Error fetching tool `%s`: %s", fn_name, e)
632
+ raise
426
633
 
427
634
  @override
428
- async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
635
+ async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> None:
429
636
 
430
637
  if (name in self._llms):
431
638
  raise ValueError(f"LLM `{name}` already exists in the list of LLMs")
@@ -437,11 +644,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
437
644
 
438
645
  self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
439
646
  except Exception as e:
440
- logger.error("Error adding llm `%s` with config `%s`", name, config, exc_info=True)
441
- raise e
647
+ logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
648
+ raise
442
649
 
443
650
  @override
444
- async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
651
+ async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
445
652
 
446
653
  if (llm_name not in self._llms):
447
654
  raise ValueError(f"LLM `{llm_name}` not found")
@@ -458,8 +665,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
458
665
  # Return a frameworks specific client
459
666
  return client
460
667
  except Exception as e:
461
- logger.error("Error getting llm `%s` with wrapper `%s`", llm_name, wrapper_type, exc_info=True)
462
- raise e
668
+ logger.error("Error getting llm `%s` with wrapper `%s`: %s", llm_name, wrapper_type, e)
669
+ raise
463
670
 
464
671
  @override
465
672
  def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
@@ -509,8 +716,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
509
716
 
510
717
  return info_obj
511
718
  except Exception as e:
512
- logger.error("Error adding authentication `%s` with config `%s`", name, config, exc_info=True)
513
- raise e
719
+ logger.error("Error adding authentication `%s` with config `%s`: %s", name, config, e)
720
+ raise
514
721
 
515
722
  @override
516
723
  async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
@@ -541,7 +748,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
541
748
  return self._auth_providers[auth_provider_name].instance
542
749
 
543
750
  @override
544
- async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
751
+ async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
545
752
 
546
753
  if (name in self._embedders):
547
754
  raise ValueError(f"Embedder `{name}` already exists in the list of embedders")
@@ -553,9 +760,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
553
760
 
554
761
  self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
555
762
  except Exception as e:
556
- logger.error("Error adding embedder `%s` with config `%s`", name, config, exc_info=True)
557
-
558
- raise e
763
+ logger.error("Error adding embedder `%s` with config `%s`: %s", name, config, e)
764
+ raise
559
765
 
560
766
  @override
561
767
  async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str):
@@ -575,8 +781,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
575
781
  # Return a frameworks specific client
576
782
  return client
577
783
  except Exception as e:
578
- logger.error("Error getting embedder `%s` with wrapper `%s`", embedder_name, wrapper_type, exc_info=True)
579
- raise e
784
+ logger.error("Error getting embedder `%s` with wrapper `%s`: %s", embedder_name, wrapper_type, e)
785
+ raise
580
786
 
581
787
  @override
582
788
  def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
@@ -602,7 +808,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
602
808
  return info_obj
603
809
 
604
810
  @override
605
- def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
811
+ async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
606
812
  """
607
813
  Return the instantiated memory client for the given name.
608
814
  """
@@ -648,7 +854,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
648
854
  return self._object_stores[object_store_name].config
649
855
 
650
856
  @override
651
- async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
857
+ async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
652
858
 
653
859
  if (name in self._retrievers):
654
860
  raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")
@@ -661,11 +867,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
661
867
  self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
662
868
 
663
869
  except Exception as e:
664
- logger.error("Error adding retriever `%s` with config `%s`", name, config, exc_info=True)
665
-
666
- raise e
667
-
668
- # return info_obj
870
+ logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
871
+ raise
669
872
 
670
873
  @override
671
874
  async def get_retriever(self,
@@ -688,8 +891,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
688
891
  # Return a frameworks specific client
689
892
  return client
690
893
  except Exception as e:
691
- logger.error("Error getting retriever `%s` with wrapper `%s`", retriever_name, wrapper_type, exc_info=True)
692
- raise e
894
+ logger.error("Error getting retriever `%s` with wrapper `%s`: %s", retriever_name, wrapper_type, e)
895
+ raise
693
896
 
694
897
  @override
695
898
  async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
@@ -699,9 +902,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
699
902
 
700
903
  return self._retrievers[retriever_name].config
701
904
 
702
- @experimental(feature_name="TTC")
703
905
  @override
704
- async def add_ttc_strategy(self, name: str | str, config: TTCStrategyBaseConfig):
906
+ @experimental(feature_name="TTC")
907
+ async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
705
908
  if (name in self._ttc_strategies):
706
909
  raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")
707
910
 
@@ -713,9 +916,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
713
916
  self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
714
917
 
715
918
  except Exception as e:
716
- logger.error("Error adding TTC strategy `%s` with config `%s`", name, config, exc_info=True)
717
-
718
- raise e
919
+ logger.error("Error adding TTC strategy `%s` with config `%s`: %s", name, config, e)
920
+ raise
719
921
 
720
922
  @override
721
923
  async def get_ttc_strategy(self,
@@ -743,8 +945,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
743
945
 
744
946
  return instance
745
947
  except Exception as e:
746
- logger.error("Error getting TTC strategy `%s`", strategy_name, exc_info=True)
747
- raise e
948
+ logger.error("Error getting TTC strategy `%s`: %s", strategy_name, e)
949
+ raise
748
950
 
749
951
  @override
750
952
  async def get_ttc_strategy_config(self,
@@ -821,7 +1023,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
821
1023
  else:
822
1024
  logger.error("No remaining components to build")
823
1025
 
824
- logger.error("Original error:", exc_info=original_error)
1026
+ logger.error("Original error: %s", original_error, exc_info=True)
825
1027
 
826
1028
  def _log_build_failure_component(self,
827
1029
  failing_component: ComponentInstanceData,
@@ -889,29 +1091,40 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
889
1091
 
890
1092
  # Instantiate a the llm
891
1093
  if component_instance.component_group == ComponentGroup.LLMS:
892
- await self.add_llm(component_instance.name, component_instance.config)
1094
+ await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
893
1095
  # Instantiate a the embedder
894
1096
  elif component_instance.component_group == ComponentGroup.EMBEDDERS:
895
- await self.add_embedder(component_instance.name, component_instance.config)
1097
+ await self.add_embedder(component_instance.name,
1098
+ cast(EmbedderBaseConfig, component_instance.config))
896
1099
  # Instantiate a memory client
897
1100
  elif component_instance.component_group == ComponentGroup.MEMORY:
898
- await self.add_memory_client(component_instance.name, component_instance.config)
899
- # Instantiate a object store client
1101
+ await self.add_memory_client(component_instance.name,
1102
+ cast(MemoryBaseConfig, component_instance.config))
1103
+ # Instantiate a object store client
900
1104
  elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
901
- await self.add_object_store(component_instance.name, component_instance.config)
1105
+ await self.add_object_store(component_instance.name,
1106
+ cast(ObjectStoreBaseConfig, component_instance.config))
902
1107
  # Instantiate a retriever client
903
1108
  elif component_instance.component_group == ComponentGroup.RETRIEVERS:
904
- await self.add_retriever(component_instance.name, component_instance.config)
1109
+ await self.add_retriever(component_instance.name,
1110
+ cast(RetrieverBaseConfig, component_instance.config))
1111
+ # Instantiate a function group
1112
+ elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
1113
+ await self.add_function_group(component_instance.name,
1114
+ cast(FunctionGroupBaseConfig, component_instance.config))
905
1115
  # Instantiate a function
906
1116
  elif component_instance.component_group == ComponentGroup.FUNCTIONS:
907
1117
  # If the function is the root, set it as the workflow later
908
1118
  if (not component_instance.is_root):
909
- await self.add_function(component_instance.name, component_instance.config)
1119
+ await self.add_function(component_instance.name,
1120
+ cast(FunctionBaseConfig, component_instance.config))
910
1121
  elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
911
- await self.add_ttc_strategy(component_instance.name, component_instance.config)
1122
+ await self.add_ttc_strategy(component_instance.name,
1123
+ cast(TTCStrategyBaseConfig, component_instance.config))
912
1124
 
913
1125
  elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
914
- await self.add_auth_provider(component_instance.name, component_instance.config)
1126
+ await self.add_auth_provider(component_instance.name,
1127
+ cast(AuthProviderBaseConfig, component_instance.config))
915
1128
  else:
916
1129
  raise ValueError(f"Unknown component group {component_instance.component_group}")
917
1130
 
@@ -961,18 +1174,35 @@ class ChildBuilder(Builder):
961
1174
  return await self._workflow_builder.add_function(name, config)
962
1175
 
963
1176
  @override
964
- def get_function(self, name: str) -> Function:
1177
+ async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
1178
+ return await self._workflow_builder.add_function_group(name, config)
1179
+
1180
+ @override
1181
+ async def get_function(self, name: str) -> Function:
965
1182
  # If a function tries to get another function, we assume it uses it
966
- fn = self._workflow_builder.get_function(name)
1183
+ fn = await self._workflow_builder.get_function(name)
967
1184
 
968
1185
  self._dependencies.add_function(name)
969
1186
 
970
1187
  return fn
971
1188
 
1189
+ @override
1190
+ async def get_function_group(self, name: str) -> FunctionGroup:
1191
+ # If a function tries to get a function group, we assume it uses it
1192
+ function_group = await self._workflow_builder.get_function_group(name)
1193
+
1194
+ self._dependencies.add_function_group(name)
1195
+
1196
+ return function_group
1197
+
972
1198
  @override
973
1199
  def get_function_config(self, name: str) -> FunctionBaseConfig:
974
1200
  return self._workflow_builder.get_function_config(name)
975
1201
 
1202
+ @override
1203
+ def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
1204
+ return self._workflow_builder.get_function_group_config(name)
1205
+
976
1206
  @override
977
1207
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
978
1208
  return await self._workflow_builder.set_workflow(config)
@@ -986,20 +1216,33 @@ class ChildBuilder(Builder):
986
1216
  return self._workflow_builder.get_workflow_config()
987
1217
 
988
1218
  @override
989
- def get_tool(self, fn_name: str, wrapper_type: LLMFrameworkEnum | str):
1219
+ async def get_tools(self,
1220
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
1221
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
1222
+ tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
1223
+ for tool_name in tool_names:
1224
+ if tool_name in self._workflow_builder._function_groups:
1225
+ self._dependencies.add_function_group(tool_name)
1226
+ else:
1227
+ self._dependencies.add_function(tool_name)
1228
+ return tools
1229
+
1230
+ @override
1231
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
990
1232
  # If a function tries to get another function as a tool, we assume it uses it
991
- fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
1233
+ fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
992
1234
 
993
1235
  self._dependencies.add_function(fn_name)
994
1236
 
995
1237
  return fn
996
1238
 
997
1239
  @override
998
- async def add_llm(self, name: str, config: LLMBaseConfig):
1240
+ async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
999
1241
  return await self._workflow_builder.add_llm(name, config)
1000
1242
 
1243
+ @experimental(feature_name="Authentication")
1001
1244
  @override
1002
- async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig):
1245
+ async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
1003
1246
  return await self._workflow_builder.add_auth_provider(name, config)
1004
1247
 
1005
1248
  @override
@@ -1007,7 +1250,7 @@ class ChildBuilder(Builder):
1007
1250
  return await self._workflow_builder.get_auth_provider(auth_provider_name)
1008
1251
 
1009
1252
  @override
1010
- async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str):
1253
+ async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
1011
1254
  llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
1012
1255
 
1013
1256
  self._dependencies.add_llm(llm_name)
@@ -1019,11 +1262,11 @@ class ChildBuilder(Builder):
1019
1262
  return self._workflow_builder.get_llm_config(llm_name)
1020
1263
 
1021
1264
  @override
1022
- async def add_embedder(self, name: str, config: EmbedderBaseConfig):
1023
- return await self._workflow_builder.add_embedder(name, config)
1265
+ async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
1266
+ await self._workflow_builder.add_embedder(name, config)
1024
1267
 
1025
1268
  @override
1026
- async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str):
1269
+ async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
1027
1270
  embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
1028
1271
 
1029
1272
  self._dependencies.add_embedder(embedder_name)
@@ -1039,11 +1282,11 @@ class ChildBuilder(Builder):
1039
1282
  return await self._workflow_builder.add_memory_client(name, config)
1040
1283
 
1041
1284
  @override
1042
- def get_memory_client(self, memory_name: str) -> MemoryEditor:
1285
+ async def get_memory_client(self, memory_name: str) -> MemoryEditor:
1043
1286
  """
1044
1287
  Return the instantiated memory client for the given name.
1045
1288
  """
1046
- memory_client = self._workflow_builder.get_memory_client(memory_name)
1289
+ memory_client = await self._workflow_builder.get_memory_client(memory_name)
1047
1290
 
1048
1291
  self._dependencies.add_memory_client(memory_name)
1049
1292
 
@@ -1073,8 +1316,9 @@ class ChildBuilder(Builder):
1073
1316
  return self._workflow_builder.get_object_store_config(object_store_name)
1074
1317
 
1075
1318
  @override
1076
- async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig):
1077
- return await self._workflow_builder.add_ttc_strategy(name, config)
1319
+ @experimental(feature_name="TTC")
1320
+ async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
1321
+ await self._workflow_builder.add_ttc_strategy(name, config)
1078
1322
 
1079
1323
  @override
1080
1324
  async def get_ttc_strategy(self,
@@ -1095,11 +1339,11 @@ class ChildBuilder(Builder):
1095
1339
  stage_type=stage_type)
1096
1340
 
1097
1341
  @override
1098
- async def add_retriever(self, name: str, config: RetrieverBaseConfig):
1099
- return await self._workflow_builder.add_retriever(name, config)
1342
+ async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
1343
+ await self._workflow_builder.add_retriever(name, config)
1100
1344
 
1101
1345
  @override
1102
- async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None):
1346
+ async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
1103
1347
  if not wrapper_type:
1104
1348
  return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
1105
1349
  return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
@@ -1115,3 +1359,7 @@ class ChildBuilder(Builder):
1115
1359
  @override
1116
1360
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
1117
1361
  return self._workflow_builder.get_function_dependencies(fn_name)
1362
+
1363
+ @override
1364
+ def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
1365
+ return self._workflow_builder.get_function_group_dependencies(fn_name)