nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -13,13 +13,17 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import dataclasses
17
18
  import inspect
18
19
  import logging
20
+ import typing
19
21
  import warnings
22
+ from collections.abc import Sequence
20
23
  from contextlib import AbstractAsyncContextManager
21
24
  from contextlib import AsyncExitStack
22
25
  from contextlib import asynccontextmanager
26
+ from typing import cast
23
27
 
24
28
  from nat.authentication.interfaces import AuthProviderBase
25
29
  from nat.builder.builder import Builder
@@ -31,6 +35,7 @@ from nat.builder.context import ContextState
31
35
  from nat.builder.embedder import EmbedderProviderInfo
32
36
  from nat.builder.framework_enum import LLMFrameworkEnum
33
37
  from nat.builder.function import Function
38
+ from nat.builder.function import FunctionGroup
34
39
  from nat.builder.function import LambdaFunction
35
40
  from nat.builder.function_info import FunctionInfo
36
41
  from nat.builder.llm import LLMProviderInfo
@@ -42,6 +47,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
42
47
  from nat.data_models.component import ComponentGroup
43
48
  from nat.data_models.component_ref import AuthenticationRef
44
49
  from nat.data_models.component_ref import EmbedderRef
50
+ from nat.data_models.component_ref import FunctionGroupRef
45
51
  from nat.data_models.component_ref import FunctionRef
46
52
  from nat.data_models.component_ref import LLMRef
47
53
  from nat.data_models.component_ref import MemoryRef
@@ -52,6 +58,7 @@ from nat.data_models.config import Config
52
58
  from nat.data_models.config import GeneralConfig
53
59
  from nat.data_models.embedder import EmbedderBaseConfig
54
60
  from nat.data_models.function import FunctionBaseConfig
61
+ from nat.data_models.function import FunctionGroupBaseConfig
55
62
  from nat.data_models.function_dependencies import FunctionDependencies
56
63
  from nat.data_models.llm import LLMBaseConfig
57
64
  from nat.data_models.memory import MemoryBaseConfig
@@ -68,6 +75,7 @@ from nat.object_store.interfaces import ObjectStore
68
75
  from nat.observability.exporter.base_exporter import BaseExporter
69
76
  from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
70
77
  from nat.profiler.utils import detect_llm_frameworks_in_build_fn
78
+ from nat.retriever.interface import Retriever
71
79
  from nat.utils.type_utils import override
72
80
 
73
81
  logger = logging.getLogger(__name__)
@@ -85,6 +93,12 @@ class ConfiguredFunction:
85
93
  instance: Function
86
94
 
87
95
 
96
+ @dataclasses.dataclass
97
+ class ConfiguredFunctionGroup:
98
+ config: FunctionGroupBaseConfig
99
+ instance: FunctionGroup
100
+
101
+
88
102
  @dataclasses.dataclass
89
103
  class ConfiguredLLM:
90
104
  config: LLMBaseConfig
@@ -142,9 +156,11 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
142
156
  self._registry = registry
143
157
 
144
158
  self._logging_handlers: dict[str, logging.Handler] = {}
159
+ self._removed_root_handlers: list[tuple[logging.Handler, int]] = []
145
160
  self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
146
161
 
147
162
  self._functions: dict[str, ConfiguredFunction] = {}
163
+ self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
148
164
  self._workflow: ConfiguredFunction | None = None
149
165
 
150
166
  self._llms: dict[str, ConfiguredLLM] = {}
@@ -161,7 +177,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
161
177
 
162
178
  # Create a mapping to track function name -> other function names it depends on
163
179
  self.function_dependencies: dict[str, FunctionDependencies] = {}
180
+ self.function_group_dependencies: dict[str, FunctionDependencies] = {}
164
181
  self.current_function_building: str | None = None
182
+ self.current_function_group_building: str | None = None
165
183
 
166
184
  async def __aenter__(self):
167
185
 
@@ -170,6 +188,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
170
188
  # Get the telemetry info from the config
171
189
  telemetry_config = self.general_config.telemetry
172
190
 
191
+ # If we have logging configuration, we need to manage the root logger properly
192
+ root_logger = logging.getLogger()
193
+
194
+ # Collect configured handler types to determine if we need to adjust existing handlers
195
+ # This is somewhat of a hack by inspecting the class name of the config object
196
+ has_console_handler = any(
197
+ hasattr(config, "__class__") and "console" in config.__class__.__name__.lower()
198
+ for config in telemetry_config.logging.values())
199
+
173
200
  for key, logging_config in telemetry_config.logging.items():
174
201
  # Use the same pattern as tracing, but for logging
175
202
  logging_info = self._registry.get_logging_method(type(logging_config))
@@ -183,7 +210,31 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
183
210
  self._logging_handlers[key] = handler
184
211
 
185
212
  # Now attach to NAT's root logger
186
- logging.getLogger().addHandler(handler)
213
+ root_logger.addHandler(handler)
214
+
215
+ # If we added logging handlers, manage existing handlers appropriately
216
+ if self._logging_handlers:
217
+ min_handler_level = min((handler.level for handler in root_logger.handlers), default=logging.CRITICAL)
218
+
219
+ # Ensure the root logger level allows messages through
220
+ root_logger.level = max(root_logger.level, min_handler_level)
221
+
222
+ # If a console handler is configured, adjust or remove default CLI handlers
223
+ # to avoid duplicate output while preserving workflow visibility
224
+ if has_console_handler:
225
+ # Remove existing StreamHandlers that are not the newly configured ones
226
+ for handler in root_logger.handlers[:]:
227
+ if type(handler) is logging.StreamHandler and handler not in self._logging_handlers.values():
228
+ self._removed_root_handlers.append((handler, handler.level))
229
+ root_logger.removeHandler(handler)
230
+ else:
231
+ # No console handler configured, but adjust existing handler levels
232
+ # to respect the minimum configured level for file/other handlers
233
+ for handler in root_logger.handlers[:]:
234
+ if type(handler) is logging.StreamHandler:
235
+ old_level = handler.level
236
+ handler.setLevel(min_handler_level)
237
+ self._removed_root_handlers.append((handler, old_level))
187
238
 
188
239
  # Add the telemetry exporters
189
240
  for key, telemetry_exporter_config in telemetry_config.tracing.items():
@@ -195,12 +246,21 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
195
246
 
196
247
  assert self._exit_stack is not None, "Exit stack not initialized"
197
248
 
198
- for _, handler in self._logging_handlers.items():
199
- logging.getLogger().removeHandler(handler)
249
+ root_logger = logging.getLogger()
250
+
251
+ # Remove custom logging handlers
252
+ for handler in self._logging_handlers.values():
253
+ root_logger.removeHandler(handler)
254
+
255
+ # Restore original handlers and their levels
256
+ for handler, old_level in self._removed_root_handlers:
257
+ if handler not in root_logger.handlers:
258
+ root_logger.addHandler(handler)
259
+ handler.setLevel(old_level)
200
260
 
201
261
  await self._exit_stack.__aexit__(*exc_details)
202
262
 
203
- def build(self, entry_function: str | None = None) -> Workflow:
263
+ async def build(self, entry_function: str | None = None) -> Workflow:
204
264
  """
205
265
  Creates an instance of a workflow object using the added components and the desired entry function.
206
266
 
@@ -224,12 +284,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
224
284
  if (self._workflow is None):
225
285
  raise ValueError("Must set a workflow before building")
226
286
 
287
+ # Set of all functions which are "included" by function groups
288
+ included_functions = set()
289
+ # Dictionary of function configs
290
+ function_configs = dict()
291
+ # Dictionary of function group configs
292
+ function_group_configs = dict()
293
+ # Dictionary of function instances
294
+ function_instances = dict()
295
+ # Dictionary of function group instances
296
+ function_group_instances = dict()
297
+
298
+ for k, v in self._function_groups.items():
299
+ included_functions.update((await v.instance.get_included_functions()).keys())
300
+ function_group_configs[k] = v.config
301
+ function_group_instances[k] = v.instance
302
+
303
+ # Function configs need to be restricted to only the functions that are not in a function group
304
+ for k, v in self._functions.items():
305
+ if k not in included_functions:
306
+ function_configs[k] = v.config
307
+ function_instances[k] = v.instance
308
+
227
309
  # Build the config from the added objects
228
310
  config = Config(general=self.general_config,
229
- functions={
230
- k: v.config
231
- for k, v in self._functions.items()
232
- },
311
+ functions=function_configs,
312
+ function_groups=function_group_configs,
233
313
  workflow=self._workflow.config,
234
314
  llms={
235
315
  k: v.config
@@ -259,14 +339,12 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
259
339
  if (entry_function is None):
260
340
  entry_fn_obj = self.get_workflow()
261
341
  else:
262
- entry_fn_obj = self.get_function(entry_function)
342
+ entry_fn_obj = await self.get_function(entry_function)
263
343
 
264
344
  workflow = Workflow.from_entry_fn(config=config,
265
345
  entry_fn=entry_fn_obj,
266
- functions={
267
- k: v.instance
268
- for k, v in self._functions.items()
269
- },
346
+ functions=function_instances,
347
+ function_groups=function_group_instances,
270
348
  llms={
271
349
  k: v.instance
272
350
  for k, v in self._llms.items()
@@ -347,11 +425,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
347
425
 
348
426
  return ConfiguredFunction(config=config, instance=build_result)
349
427
 
428
+ async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
429
+ """Build a function group from the provided configuration.
430
+
431
+ Args:
432
+ name: The name of the function group
433
+ config: The function group configuration
434
+
435
+ Returns:
436
+ ConfiguredFunctionGroup: The built function group
437
+
438
+ Raises:
439
+ ValueError: If the function group builder returns invalid results
440
+ """
441
+ registration = self._registry.get_function_group(type(config))
442
+
443
+ inner_builder = ChildBuilder(self)
444
+
445
+ # Build the function group - use the same wrapping pattern as _build_function
446
+ llms = {k: v.instance for k, v in self._llms.items()}
447
+ function_frameworks = detect_llm_frameworks_in_build_fn(registration)
448
+
449
+ build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
450
+
451
+ # Set the currently building function group so the ChildBuilder can track dependencies
452
+ self.current_function_group_building = config.type
453
+ # Empty set of dependencies for the current function group
454
+ self.function_group_dependencies[config.type] = FunctionDependencies()
455
+
456
+ build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
457
+
458
+ self.function_group_dependencies[name] = inner_builder.dependencies
459
+
460
+ if not isinstance(build_result, FunctionGroup):
461
+ raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
462
+ f"Got {type(build_result)}")
463
+
464
+ # set the instance name for the function group based on the workflow-provided name
465
+ build_result.set_instance_name(name)
466
+ return ConfiguredFunctionGroup(config=config, instance=build_result)
467
+
350
468
  @override
351
469
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
470
+ if isinstance(name, FunctionRef):
471
+ name = str(name)
352
472
 
353
- if (name in self._functions):
354
- raise ValueError(f"Function `{name}` already exists in the list of functions")
473
+ if (name in self._functions or name in self._function_groups):
474
+ raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
355
475
 
356
476
  build_result = await self._build_function(name=name, config=config)
357
477
 
@@ -360,20 +480,67 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
360
480
  return build_result.instance
361
481
 
362
482
  @override
363
- def get_function(self, name: str | FunctionRef) -> Function:
483
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
484
+ if isinstance(name, FunctionGroupRef):
485
+ name = str(name)
486
+
487
+ if (name in self._function_groups or name in self._functions):
488
+ raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
489
+
490
+ # Build the function group
491
+ build_result = await self._build_function_group(name=name, config=config)
364
492
 
493
+ self._function_groups[name] = build_result
494
+
495
+ # If the function group exposes functions, add them to the global function registry
496
+ # If the function group exposes functions, record and add them to the registry
497
+ included_functions = await build_result.instance.get_included_functions()
498
+ for k in included_functions:
499
+ if k in self._functions:
500
+ raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
501
+ self._functions.update({
502
+ k: ConfiguredFunction(config=v.config, instance=v)
503
+ for k, v in included_functions.items()
504
+ })
505
+
506
+ return build_result.instance
507
+
508
+ @override
509
+ async def get_function(self, name: str | FunctionRef) -> Function:
510
+ if isinstance(name, FunctionRef):
511
+ name = str(name)
365
512
  if name not in self._functions:
366
513
  raise ValueError(f"Function `{name}` not found")
367
514
 
368
515
  return self._functions[name].instance
369
516
 
517
+ @override
518
+ async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
519
+ if isinstance(name, FunctionGroupRef):
520
+ name = str(name)
521
+ if name not in self._function_groups:
522
+ raise ValueError(f"Function group `{name}` not found")
523
+
524
+ return self._function_groups[name].instance
525
+
370
526
  @override
371
527
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
528
+ if isinstance(name, FunctionRef):
529
+ name = str(name)
372
530
  if name not in self._functions:
373
531
  raise ValueError(f"Function `{name}` not found")
374
532
 
375
533
  return self._functions[name].config
376
534
 
535
+ @override
536
+ def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
537
+ if isinstance(name, FunctionGroupRef):
538
+ name = str(name)
539
+ if name not in self._function_groups:
540
+ raise ValueError(f"Function group `{name}` not found")
541
+
542
+ return self._function_groups[name].config
543
+
377
544
  @override
378
545
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
379
546
 
@@ -403,16 +570,57 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
403
570
 
404
571
  @override
405
572
  def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
573
+ if isinstance(fn_name, FunctionRef):
574
+ fn_name = str(fn_name)
406
575
  return self.function_dependencies[fn_name]
407
576
 
408
577
  @override
409
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
578
+ def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
579
+ if isinstance(fn_name, FunctionGroupRef):
580
+ fn_name = str(fn_name)
581
+ return self.function_group_dependencies[fn_name]
582
+
583
+ @override
584
+ async def get_tools(self,
585
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
586
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
587
+
588
+ unique = set(tool_names)
589
+ if len(unique) != len(tool_names):
590
+ raise ValueError("Tool names must be unique")
591
+
592
+ async def _get_tools(n: str | FunctionRef | FunctionGroupRef):
593
+ tools = []
594
+ is_function_group_ref = isinstance(n, FunctionGroupRef)
595
+ if isinstance(n, FunctionRef) or is_function_group_ref:
596
+ n = str(n)
597
+ if n not in self._function_groups:
598
+ # the passed tool name is probably a function, but first check if it's a function group
599
+ if is_function_group_ref:
600
+ raise ValueError(f"Function group `{n}` not found in the list of function groups")
601
+ tools.append(await self.get_tool(n, wrapper_type))
602
+ else:
603
+ tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
604
+ current_function_group = self._function_groups[n]
605
+ for fn_name, fn_instance in (await current_function_group.instance.get_accessible_functions()).items():
606
+ try:
607
+ tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
608
+ except Exception:
609
+ logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
610
+ raise
611
+ return tools
612
+
613
+ tool_lists = await asyncio.gather(*[_get_tools(n) for n in tool_names])
614
+ # Flatten the list of lists into a single list
615
+ return [tool for tools in tool_lists for tool in tools]
410
616
 
617
+ @override
618
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
619
+ if isinstance(fn_name, FunctionRef):
620
+ fn_name = str(fn_name)
411
621
  if fn_name not in self._functions:
412
622
  raise ValueError(f"Function `{fn_name}` not found in list of functions")
413
-
414
623
  fn = self._functions[fn_name]
415
-
416
624
  try:
417
625
  # Using the registry, get the tool wrapper for the requested framework
418
626
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
@@ -424,7 +632,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
424
632
  raise
425
633
 
426
634
  @override
427
- async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
635
+ async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> None:
428
636
 
429
637
  if (name in self._llms):
430
638
  raise ValueError(f"LLM `{name}` already exists in the list of LLMs")
@@ -440,7 +648,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
440
648
  raise
441
649
 
442
650
  @override
443
- async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
651
+ async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
444
652
 
445
653
  if (llm_name not in self._llms):
446
654
  raise ValueError(f"LLM `{llm_name}` not found")
@@ -540,7 +748,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
540
748
  return self._auth_providers[auth_provider_name].instance
541
749
 
542
750
  @override
543
- async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
751
+ async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
544
752
 
545
753
  if (name in self._embedders):
546
754
  raise ValueError(f"Embedder `{name}` already exists in the list of embedders")
@@ -600,7 +808,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
600
808
  return info_obj
601
809
 
602
810
  @override
603
- def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
811
+ async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
604
812
  """
605
813
  Return the instantiated memory client for the given name.
606
814
  """
@@ -646,7 +854,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
646
854
  return self._object_stores[object_store_name].config
647
855
 
648
856
  @override
649
- async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
857
+ async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
650
858
 
651
859
  if (name in self._retrievers):
652
860
  raise ValueError(f"Retriever '{name}' already exists in the list of retrievers")
@@ -662,8 +870,6 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
662
870
  logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
663
871
  raise
664
872
 
665
- # return info_obj
666
-
667
873
  @override
668
874
  async def get_retriever(self,
669
875
  retriever_name: str | RetrieverRef,
@@ -696,9 +902,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
696
902
 
697
903
  return self._retrievers[retriever_name].config
698
904
 
699
- @experimental(feature_name="TTC")
700
905
  @override
701
- async def add_ttc_strategy(self, name: str | str, config: TTCStrategyBaseConfig):
906
+ @experimental(feature_name="TTC")
907
+ async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig) -> None:
702
908
  if (name in self._ttc_strategies):
703
909
  raise ValueError(f"TTC strategy '{name}' already exists in the list of TTC strategies")
704
910
 
@@ -885,29 +1091,40 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
885
1091
 
886
1092
  # Instantiate a the llm
887
1093
  if component_instance.component_group == ComponentGroup.LLMS:
888
- await self.add_llm(component_instance.name, component_instance.config)
1094
+ await self.add_llm(component_instance.name, cast(LLMBaseConfig, component_instance.config))
889
1095
  # Instantiate a the embedder
890
1096
  elif component_instance.component_group == ComponentGroup.EMBEDDERS:
891
- await self.add_embedder(component_instance.name, component_instance.config)
1097
+ await self.add_embedder(component_instance.name,
1098
+ cast(EmbedderBaseConfig, component_instance.config))
892
1099
  # Instantiate a memory client
893
1100
  elif component_instance.component_group == ComponentGroup.MEMORY:
894
- await self.add_memory_client(component_instance.name, component_instance.config)
895
- # Instantiate a object store client
1101
+ await self.add_memory_client(component_instance.name,
1102
+ cast(MemoryBaseConfig, component_instance.config))
1103
+ # Instantiate a object store client
896
1104
  elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
897
- await self.add_object_store(component_instance.name, component_instance.config)
1105
+ await self.add_object_store(component_instance.name,
1106
+ cast(ObjectStoreBaseConfig, component_instance.config))
898
1107
  # Instantiate a retriever client
899
1108
  elif component_instance.component_group == ComponentGroup.RETRIEVERS:
900
- await self.add_retriever(component_instance.name, component_instance.config)
1109
+ await self.add_retriever(component_instance.name,
1110
+ cast(RetrieverBaseConfig, component_instance.config))
1111
+ # Instantiate a function group
1112
+ elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
1113
+ await self.add_function_group(component_instance.name,
1114
+ cast(FunctionGroupBaseConfig, component_instance.config))
901
1115
  # Instantiate a function
902
1116
  elif component_instance.component_group == ComponentGroup.FUNCTIONS:
903
1117
  # If the function is the root, set it as the workflow later
904
1118
  if (not component_instance.is_root):
905
- await self.add_function(component_instance.name, component_instance.config)
1119
+ await self.add_function(component_instance.name,
1120
+ cast(FunctionBaseConfig, component_instance.config))
906
1121
  elif component_instance.component_group == ComponentGroup.TTC_STRATEGIES:
907
- await self.add_ttc_strategy(component_instance.name, component_instance.config)
1122
+ await self.add_ttc_strategy(component_instance.name,
1123
+ cast(TTCStrategyBaseConfig, component_instance.config))
908
1124
 
909
1125
  elif component_instance.component_group == ComponentGroup.AUTHENTICATION:
910
- await self.add_auth_provider(component_instance.name, component_instance.config)
1126
+ await self.add_auth_provider(component_instance.name,
1127
+ cast(AuthProviderBaseConfig, component_instance.config))
911
1128
  else:
912
1129
  raise ValueError(f"Unknown component group {component_instance.component_group}")
913
1130
 
@@ -957,18 +1174,35 @@ class ChildBuilder(Builder):
957
1174
  return await self._workflow_builder.add_function(name, config)
958
1175
 
959
1176
  @override
960
- def get_function(self, name: str) -> Function:
1177
+ async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
1178
+ return await self._workflow_builder.add_function_group(name, config)
1179
+
1180
+ @override
1181
+ async def get_function(self, name: str) -> Function:
961
1182
  # If a function tries to get another function, we assume it uses it
962
- fn = self._workflow_builder.get_function(name)
1183
+ fn = await self._workflow_builder.get_function(name)
963
1184
 
964
1185
  self._dependencies.add_function(name)
965
1186
 
966
1187
  return fn
967
1188
 
1189
+ @override
1190
+ async def get_function_group(self, name: str) -> FunctionGroup:
1191
+ # If a function tries to get a function group, we assume it uses it
1192
+ function_group = await self._workflow_builder.get_function_group(name)
1193
+
1194
+ self._dependencies.add_function_group(name)
1195
+
1196
+ return function_group
1197
+
968
1198
  @override
969
1199
  def get_function_config(self, name: str) -> FunctionBaseConfig:
970
1200
  return self._workflow_builder.get_function_config(name)
971
1201
 
1202
+ @override
1203
+ def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
1204
+ return self._workflow_builder.get_function_group_config(name)
1205
+
972
1206
  @override
973
1207
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
974
1208
  return await self._workflow_builder.set_workflow(config)
@@ -982,20 +1216,33 @@ class ChildBuilder(Builder):
982
1216
  return self._workflow_builder.get_workflow_config()
983
1217
 
984
1218
  @override
985
- def get_tool(self, fn_name: str, wrapper_type: LLMFrameworkEnum | str):
1219
+ async def get_tools(self,
1220
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
1221
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
1222
+ tools = await self._workflow_builder.get_tools(tool_names, wrapper_type)
1223
+ for tool_name in tool_names:
1224
+ if tool_name in self._workflow_builder._function_groups:
1225
+ self._dependencies.add_function_group(tool_name)
1226
+ else:
1227
+ self._dependencies.add_function(tool_name)
1228
+ return tools
1229
+
1230
+ @override
1231
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
986
1232
  # If a function tries to get another function as a tool, we assume it uses it
987
- fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
1233
+ fn = await self._workflow_builder.get_tool(fn_name, wrapper_type)
988
1234
 
989
1235
  self._dependencies.add_function(fn_name)
990
1236
 
991
1237
  return fn
992
1238
 
993
1239
  @override
994
- async def add_llm(self, name: str, config: LLMBaseConfig):
1240
+ async def add_llm(self, name: str, config: LLMBaseConfig) -> None:
995
1241
  return await self._workflow_builder.add_llm(name, config)
996
1242
 
1243
+ @experimental(feature_name="Authentication")
997
1244
  @override
998
- async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig):
1245
+ async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> AuthProviderBase:
999
1246
  return await self._workflow_builder.add_auth_provider(name, config)
1000
1247
 
1001
1248
  @override
@@ -1003,7 +1250,7 @@ class ChildBuilder(Builder):
1003
1250
  return await self._workflow_builder.get_auth_provider(auth_provider_name)
1004
1251
 
1005
1252
  @override
1006
- async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str):
1253
+ async def get_llm(self, llm_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
1007
1254
  llm = await self._workflow_builder.get_llm(llm_name, wrapper_type)
1008
1255
 
1009
1256
  self._dependencies.add_llm(llm_name)
@@ -1015,11 +1262,11 @@ class ChildBuilder(Builder):
1015
1262
  return self._workflow_builder.get_llm_config(llm_name)
1016
1263
 
1017
1264
  @override
1018
- async def add_embedder(self, name: str, config: EmbedderBaseConfig):
1019
- return await self._workflow_builder.add_embedder(name, config)
1265
+ async def add_embedder(self, name: str, config: EmbedderBaseConfig) -> None:
1266
+ await self._workflow_builder.add_embedder(name, config)
1020
1267
 
1021
1268
  @override
1022
- async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str):
1269
+ async def get_embedder(self, embedder_name: str, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
1023
1270
  embedder = await self._workflow_builder.get_embedder(embedder_name, wrapper_type)
1024
1271
 
1025
1272
  self._dependencies.add_embedder(embedder_name)
@@ -1035,11 +1282,11 @@ class ChildBuilder(Builder):
1035
1282
  return await self._workflow_builder.add_memory_client(name, config)
1036
1283
 
1037
1284
  @override
1038
- def get_memory_client(self, memory_name: str) -> MemoryEditor:
1285
+ async def get_memory_client(self, memory_name: str) -> MemoryEditor:
1039
1286
  """
1040
1287
  Return the instantiated memory client for the given name.
1041
1288
  """
1042
- memory_client = self._workflow_builder.get_memory_client(memory_name)
1289
+ memory_client = await self._workflow_builder.get_memory_client(memory_name)
1043
1290
 
1044
1291
  self._dependencies.add_memory_client(memory_name)
1045
1292
 
@@ -1069,8 +1316,9 @@ class ChildBuilder(Builder):
1069
1316
  return self._workflow_builder.get_object_store_config(object_store_name)
1070
1317
 
1071
1318
  @override
1072
- async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig):
1073
- return await self._workflow_builder.add_ttc_strategy(name, config)
1319
+ @experimental(feature_name="TTC")
1320
+ async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
1321
+ await self._workflow_builder.add_ttc_strategy(name, config)
1074
1322
 
1075
1323
  @override
1076
1324
  async def get_ttc_strategy(self,
@@ -1091,11 +1339,11 @@ class ChildBuilder(Builder):
1091
1339
  stage_type=stage_type)
1092
1340
 
1093
1341
  @override
1094
- async def add_retriever(self, name: str, config: RetrieverBaseConfig):
1095
- return await self._workflow_builder.add_retriever(name, config)
1342
+ async def add_retriever(self, name: str, config: RetrieverBaseConfig) -> None:
1343
+ await self._workflow_builder.add_retriever(name, config)
1096
1344
 
1097
1345
  @override
1098
- async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None):
1346
+ async def get_retriever(self, retriever_name: str, wrapper_type: LLMFrameworkEnum | str | None = None) -> Retriever:
1099
1347
  if not wrapper_type:
1100
1348
  return await self._workflow_builder.get_retriever(retriever_name=retriever_name)
1101
1349
  return await self._workflow_builder.get_retriever(retriever_name=retriever_name, wrapper_type=wrapper_type)
@@ -1111,3 +1359,7 @@ class ChildBuilder(Builder):
1111
1359
  @override
1112
1360
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
1113
1361
  return self._workflow_builder.get_function_dependencies(fn_name)
1362
+
1363
+ @override
1364
+ def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
1365
+ return self._workflow_builder.get_function_group_dependencies(fn_name)
@@ -84,7 +84,7 @@ class LayeredConfig:
84
84
  if lower_value not in ['true', 'false']:
85
85
  raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
86
86
  value = lower_value == 'true'
87
- elif isinstance(original_value, (int, float)):
87
+ elif isinstance(original_value, int | float):
88
88
  value = type(original_value)(value)
89
89
  elif isinstance(original_value, list):
90
90
  value = [v.strip() for v in value.split(',')]