nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +17 -14
  6. nat/agent/reasoning_agent/reasoning_agent.py +9 -7
  7. nat/agent/register.py +1 -0
  8. nat/agent/rewoo_agent/agent.py +9 -2
  9. nat/agent/rewoo_agent/register.py +16 -12
  10. nat/agent/tool_calling_agent/agent.py +69 -7
  11. nat/agent/tool_calling_agent/register.py +14 -13
  12. nat/authentication/credential_validator/__init__.py +14 -0
  13. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  14. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  15. nat/builder/builder.py +27 -4
  16. nat/builder/component_utils.py +7 -3
  17. nat/builder/context.py +28 -6
  18. nat/builder/function.py +313 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +215 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +4 -9
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/control_flow/__init__.py +0 -0
  29. nat/control_flow/register.py +20 -0
  30. nat/control_flow/router_agent/__init__.py +0 -0
  31. nat/control_flow/router_agent/agent.py +329 -0
  32. nat/control_flow/router_agent/prompt.py +48 -0
  33. nat/control_flow/router_agent/register.py +91 -0
  34. nat/control_flow/sequential_executor.py +167 -0
  35. nat/data_models/agent.py +34 -0
  36. nat/data_models/authentication.py +38 -0
  37. nat/data_models/component.py +2 -0
  38. nat/data_models/component_ref.py +11 -0
  39. nat/data_models/config.py +40 -16
  40. nat/data_models/function.py +34 -0
  41. nat/data_models/function_dependencies.py +8 -0
  42. nat/data_models/optimizable.py +119 -0
  43. nat/data_models/optimizer.py +149 -0
  44. nat/data_models/temperature_mixin.py +4 -3
  45. nat/data_models/top_p_mixin.py +4 -3
  46. nat/embedder/nim_embedder.py +1 -1
  47. nat/embedder/openai_embedder.py +1 -1
  48. nat/eval/config.py +1 -1
  49. nat/eval/evaluate.py +5 -1
  50. nat/eval/register.py +4 -0
  51. nat/eval/runtime_evaluator/__init__.py +14 -0
  52. nat/eval/runtime_evaluator/evaluate.py +123 -0
  53. nat/eval/runtime_evaluator/register.py +100 -0
  54. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  55. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  56. nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
  57. nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
  58. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  59. nat/front_ends/fastapi/job_store.py +518 -99
  60. nat/front_ends/fastapi/main.py +11 -19
  61. nat/front_ends/fastapi/utils.py +57 -0
  62. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  63. nat/front_ends/mcp/mcp_front_end_config.py +5 -1
  64. nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
  65. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
  66. nat/front_ends/mcp/tool_converter.py +3 -0
  67. nat/llm/aws_bedrock_llm.py +14 -3
  68. nat/llm/nim_llm.py +14 -3
  69. nat/llm/openai_llm.py +8 -1
  70. nat/observability/exporter/processing_exporter.py +29 -55
  71. nat/observability/mixin/redaction_config_mixin.py +5 -4
  72. nat/observability/mixin/tagging_config_mixin.py +26 -14
  73. nat/observability/mixin/type_introspection_mixin.py +420 -107
  74. nat/observability/processor/processor.py +3 -0
  75. nat/observability/processor/redaction/__init__.py +24 -0
  76. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  77. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  78. nat/observability/processor/redaction/redaction_processor.py +177 -0
  79. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  80. nat/observability/processor/span_tagging_processor.py +21 -14
  81. nat/profiler/decorators/framework_wrapper.py +9 -6
  82. nat/profiler/parameter_optimization/__init__.py +0 -0
  83. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  84. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  85. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  86. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  87. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  88. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  89. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  90. nat/profiler/utils.py +3 -1
  91. nat/tool/chat_completion.py +4 -1
  92. nat/tool/github_tools.py +450 -0
  93. nat/tool/register.py +2 -7
  94. nat/utils/callable_utils.py +70 -0
  95. nat/utils/exception_handlers/automatic_retries.py +103 -48
  96. nat/utils/log_levels.py +25 -0
  97. nat/utils/type_utils.py +4 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
  101. nat/observability/processor/header_redaction_processor.py +0 -123
  102. nat/observability/processor/redaction_processor.py +0 -77
  103. nat/tool/github_tools/create_github_commit.py +0 -133
  104. nat/tool/github_tools/create_github_issue.py +0 -87
  105. nat/tool/github_tools/create_github_pr.py +0 -106
  106. nat/tool/github_tools/get_github_file.py +0 -106
  107. nat/tool/github_tools/get_github_issue.py +0 -166
  108. nat/tool/github_tools/get_github_pr.py +0 -256
  109. nat/tool/github_tools/update_github_issue.py +0 -100
  110. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  111. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
  112. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  113. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
  114. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,167 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+
19
+ from langchain_core.tools.base import BaseTool
20
+ from pydantic import BaseModel
21
+ from pydantic import Field
22
+
23
+ from nat.builder.builder import Builder
24
+ from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.builder.function import Function
26
+ from nat.builder.function_info import FunctionInfo
27
+ from nat.cli.register_workflow import register_function
28
+ from nat.data_models.component_ref import FunctionRef
29
+ from nat.data_models.function import FunctionBaseConfig
30
+ from nat.utils.type_utils import DecomposedType
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class ToolExecutionConfig(BaseModel):
    """Per-tool execution options used when a tool runs as part of a sequential execution."""

    # Selects the tool's streaming output instead of its single (non-streaming) output.
    use_streaming: bool = Field(
        default=False,
        description="Whether to use streaming output for the tool.",
    )
39
+
40
+
41
class SequentialExecutorConfig(FunctionBaseConfig, name="sequential_executor"):
    """Configuration for sequential execution of a list of functions.

    The functions named in ``tool_list`` are executed in order; per-tool options are
    looked up in ``tool_execution_config`` by tool name.
    """

    # Ordered references to the functions that make up the chain.
    tool_list: list[FunctionRef] = Field(default_factory=list,
                                         description="A list of functions to execute sequentially.")

    # Optional per-tool options. Tools without an entry use the non-streaming path.
    # NOTE: each description fragment ends with a space so the implicitly concatenated
    # literals form readable sentences.
    tool_execution_config: dict[str, ToolExecutionConfig] = Field(
        default_factory=dict,
        description=("Optional configuration for each "
                     "tool in the sequential execution tool list. "
                     "Keys must match the tool names from the "
                     "tool_list."))

    # Controls whether a type mismatch between adjacent tools is fatal or only logged.
    raise_type_incompatibility: bool = Field(
        default=False,
        description=("Default to False. Check if the adjacent tools are type compatible, "
                     "which means the output type of the previous function is compatible with the input type "
                     "of the next function. "
                     "If set to True, any incompatibility will raise an exception. If set to false, the "
                     "incompatibility will only "
                     "generate a warning message and the sequential execution will continue."))
57
+
58
+
59
+ def _get_function_output_type(function: Function, tool_execution_config: dict[str, ToolExecutionConfig]) -> type:
60
+ function_config = tool_execution_config.get(function.instance_name, None)
61
+ if function_config:
62
+ return function.streaming_output_type if function_config.use_streaming else function.single_output_type
63
+ else:
64
+ return function.single_output_type
65
+
66
+
67
def _validate_function_type_compatibility(src_fn: Function,
                                          target_fn: Function,
                                          tool_execution_config: dict[str, ToolExecutionConfig]) -> None:
    """Check that the output of ``src_fn`` can be fed into ``target_fn``.

    Args:
        src_fn: The function whose output is passed on.
        target_fn: The function that receives that output as its input.
        tool_execution_config: Per-tool options used to decide which output type of
            ``src_fn`` (streaming or single) applies.

    Raises:
        ValueError: If ``DecomposedType.is_type_compatible`` reports that the source
            output type is not compatible with the target input type.
    """
    src_output_type = _get_function_output_type(src_fn, tool_execution_config)
    target_input_type = target_fn.input_type

    # Log the types being compared; the result of the check is not known yet, so the
    # message must not claim an incompatibility.
    logger.debug(f"Checking type compatibility: the output type of the {src_fn.instance_name} function is "
                 f"{str(src_output_type)}, the input type of the {target_fn.instance_name} function is "
                 f"{str(target_input_type)}")

    if not DecomposedType.is_type_compatible(src_output_type, target_input_type):
        raise ValueError(
            f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible with "
            f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
81
+
82
+
83
def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
                                           builder: Builder) -> tuple[type, type]:
    """Validate that every adjacent pair of tools in the configured list is type compatible.

    Args:
        sequential_executor_config: The executor configuration holding the tool list and
            the per-tool execution options.
        builder: Used to resolve each function reference via ``builder.get_function``.

    Returns:
        A ``(input_type, output_type)`` tuple: the input type of the first function and
        the output type of the last function.

    Raises:
        RuntimeError: If the tool list is empty.
        ValueError: If any adjacent pair is incompatible; the original error is chained
            as the cause.
    """
    tool_execution_config = sequential_executor_config.tool_execution_config

    function_list: list[Function] = [
        builder.get_function(function_ref) for function_ref in sequential_executor_config.tool_list
    ]
    if not function_list:
        raise RuntimeError("The function list is empty")
    input_type = function_list[0].input_type

    # zip() stops at the shorter sequence, so a single-element list yields no pairs.
    for src_fn, target_fn in zip(function_list, function_list[1:]):
        try:
            _validate_function_type_compatibility(src_fn, target_fn, tool_execution_config)
        except ValueError as e:
            # Chain the cause so the traceback still points at the failing pair.
            raise ValueError(f"The sequential tool list has incompatible types: {e}") from e

    output_type = _get_function_output_type(function_list[-1], tool_execution_config)
    logger.debug(f"The input type of the sequential executor tool list is {str(input_type)}, "
                 f"the output type is {str(output_type)}")

    return (input_type, output_type)
107
+
108
+
109
@register_function(config_type=SequentialExecutorConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
    """Build the sequential executor function and yield it as a ``FunctionInfo``.

    The yielded function runs every tool in ``config.tool_list`` in order, passing the
    response of each tool as the input of the next one, and returns the last response.

    Args:
        config: The sequential executor configuration.
        builder: Used to obtain the LangChain-wrapped tools and to resolve the functions
            for type validation.

    Raises:
        ValueError: If the tool list has incompatible types and
            ``config.raise_type_incompatibility`` is set, or if any other error occurs
            while validating the tool list (the original error is chained as the cause).
    """
    logger.debug(f"Initializing sequential executor with tool list: {config.tool_list}")

    tools: list[BaseTool] = builder.get_tools(tool_names=config.tool_list, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}

    try:
        input_type, output_type = _validate_tool_list_type_compatibility(config, builder)
    except ValueError as e:
        if config.raise_type_incompatibility:
            logger.error(f"The sequential executor tool list has incompatible types: {e}")
            raise
        # Incompatibility is tolerated: fall back to untyped input/output.
        logger.warning(f"The sequential executor tool list has incompatible types: {e}")
        input_type = typing.Any
        output_type = typing.Any
    except Exception as e:
        raise ValueError(f"Error with the sequential executor tool list: {e}") from e

    # The type annotation of _sequential_function_execution is dynamically set according to the tool list
    async def _sequential_function_execution(initial_tool_input):
        logger.debug(f"Executing sequential executor with tool list: {config.tool_list}")

        tool_input = initial_tool_input
        tool_response = None

        for tool_name in config.tool_list:
            tool = tools_dict[tool_name]
            tool_execution_config = config.tool_execution_config.get(tool_name, None)
            # Tools without an explicit configuration use the non-streaming path.
            use_streaming = bool(tool_execution_config) and tool_execution_config.use_streaming
            logger.debug(f"Executing tool {tool_name} with input: {tool_input}")
            try:
                if use_streaming:
                    # NOTE(review): assumes every streamed chunk exposes a string `.content` - confirm
                    # against the wrapped tool's streaming output type.
                    output = ""
                    async for chunk in tool.astream(tool_input):
                        output += chunk.content
                    tool_response = output
                else:
                    tool_response = await tool.ainvoke(tool_input)
            except Exception as e:
                logger.error(f"Error with tool {tool_name}: {e}")
                raise

            # The input of the next tool is the response of the previous tool
            tool_input = tool_response

        return tool_response

    # Dynamically set the annotations for the function
    _sequential_function_execution.__annotations__ = {"initial_tool_input": input_type, "return": output_type}
    logger.debug(f"Sequential executor function annotations: {_sequential_function_execution.__annotations__}")

    yield FunctionInfo.from_fn(_sequential_function_execution,
                               description="Executes a list of functions sequentially. "
                               "The input of the next tool is the response of the previous tool.")
@@ -0,0 +1,34 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+ from pydantic import PositiveInt
18
+
19
+ from nat.data_models.component_ref import LLMRef
20
+ from nat.data_models.function import FunctionBaseConfig
21
+
22
+
23
class AgentBaseConfig(FunctionBaseConfig):
    """Common configuration fields shared by all NAT agent configurations."""

    # Optional alternate name for the agent when it is configured as a workflow.
    workflow_alias: str | None = Field(
        default=None,
        description=("The alias of the workflow. Useful when the agent is configured as a workflow "
                     "and needs to expose a customized name as a tool."),
    )

    # Required reference to the LLM the agent uses.
    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")

    # Logging verbosity switch; off by default.
    verbose: bool = Field(
        default=False,
        description="Set the verbosity of the agent's logging.",
    )

    # Required free-text description of what the function is for.
    description: str = Field(description="The description of this function's use.")

    # Upper bound (strictly positive) on the number of response characters written to logs.
    log_response_max_chars: PositiveInt = Field(
        default=1000,
        description="Maximum number of characters to display in logs when logging responses.",
    )
@@ -177,6 +177,26 @@ Credential = typing.Annotated[
177
177
  ]
178
178
 
179
179
 
180
class TokenValidationResult(BaseModel):
    """
    Standard result returned by bearer token validation.

    Unknown fields are rejected (``extra="forbid"``). ``client_id`` and ``token_type``
    have no default and must be supplied; all other fields are optional.
    """
    model_config = ConfigDict(extra="forbid")

    # Required, but may be explicitly None.
    client_id: str | None = Field(description="OAuth2 client identifier")
    scopes: list[str] | None = Field(
        default=None,
        description="List of granted scopes (introspection only)",
    )
    expires_at: int | None = Field(
        default=None,
        description="Token expiration time (Unix timestamp)",
    )
    audience: list[str] | None = Field(
        default=None,
        description="Token audiences (aud claim)",
    )
    subject: str | None = Field(
        default=None,
        description="Token subject (sub claim)",
    )
    issuer: str | None = Field(
        default=None,
        description="Token issuer (iss claim)",
    )
    # Required.
    token_type: str = Field(description="Token type")
    # Defaults to True when not provided.
    active: bool | None = Field(
        default=True,
        description="Token active status",
    )
    nbf: int | None = Field(
        default=None,
        description="Not before time (Unix timestamp)",
    )
    iat: int | None = Field(
        default=None,
        description="Issued at time (Unix timestamp)",
    )
    jti: str | None = Field(
        default=None,
        description="JWT ID",
    )
    username: str | None = Field(
        default=None,
        description="Username (introspection only)",
    )
198
+
199
+
180
200
  class AuthResult(BaseModel):
181
201
  """
182
202
  Represents the result of an authentication process.
@@ -229,3 +249,21 @@ class AuthResult(BaseModel):
229
249
  target_kwargs.setdefault(k, {}).update(v)
230
250
  else:
231
251
  target_kwargs[k] = v
252
+
253
+
254
class AuthReason(str, Enum):
    """
    Reason the caller is requesting authentication at this point.

    Members are also ``str`` instances, so they compare equal to their string values.
    """
    # Regular authentication request.
    NORMAL = "normal"
    # Authentication requested again after a 401 response (per the member name).
    RETRY_AFTER_401 = "retry_after_401"
260
+
261
+
262
class AuthRequest(BaseModel):
    """
    Payload describing an authentication request passed to provider.authenticate(...).

    Unknown fields are rejected (``extra="forbid"``).
    """
    model_config = ConfigDict(extra="forbid")

    # Why authentication is being requested; defaults to a normal request.
    reason: AuthReason = Field(
        default=AuthReason.NORMAL,
        description="Purpose of this auth attempt.",
    )
    # Unparsed header value, when one is available.
    www_authenticate: str | None = Field(
        default=None,
        description="Raw WWW-Authenticate header from a 401 response.",
    )
@@ -27,6 +27,7 @@ class ComponentEnum(StrEnum):
27
27
  EVALUATOR = "evaluator"
28
28
  FRONT_END = "front_end"
29
29
  FUNCTION = "function"
30
+ FUNCTION_GROUP = "function_group"
30
31
  TTC_STRATEGY = "ttc_strategy"
31
32
  LLM_CLIENT = "llm_client"
32
33
  LLM_PROVIDER = "llm_provider"
@@ -47,6 +48,7 @@ class ComponentGroup(StrEnum):
47
48
  AUTHENTICATION = "authentication"
48
49
  EMBEDDERS = "embedders"
49
50
  FUNCTIONS = "functions"
51
+ FUNCTION_GROUPS = "function_groups"
50
52
  TTC_STRATEGIES = "ttc_strategies"
51
53
  LLMS = "llms"
52
54
  MEMORY = "memory"
@@ -102,6 +102,17 @@ class FunctionRef(ComponentRef):
102
102
  return ComponentGroup.FUNCTIONS
103
103
 
104
104
 
105
class FunctionGroupRef(ComponentRef):
    """
    Reference to a function group defined in a NAT configuration object.
    """

    @property
    @override
    def component_group(self):
        # Function group references belong to the function-groups component group.
        return ComponentGroup.FUNCTION_GROUPS
114
+
115
+
105
116
  class LLMRef(ComponentRef):
106
117
  """
107
118
  A reference to an LLM in a NAT configuration object.
nat/data_models/config.py CHANGED
@@ -20,6 +20,7 @@ import typing
20
20
  from pydantic import BaseModel
21
21
  from pydantic import ConfigDict
22
22
  from pydantic import Discriminator
23
+ from pydantic import Field
23
24
  from pydantic import ValidationError
24
25
  from pydantic import ValidationInfo
25
26
  from pydantic import ValidatorFunctionWrapHandler
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
29
30
  from nat.data_models.front_end import FrontEndBaseConfig
30
31
  from nat.data_models.function import EmptyFunctionConfig
31
32
  from nat.data_models.function import FunctionBaseConfig
33
+ from nat.data_models.function import FunctionGroupBaseConfig
32
34
  from nat.data_models.logging import LoggingBaseConfig
35
+ from nat.data_models.optimizer import OptimizerConfig
33
36
  from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
34
37
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
35
38
  from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
57
60
  error_type = e['type']
58
61
  if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
59
62
  requested_type = e["ctx"]["tag"]
60
-
61
63
  if (info.field_name in ('workflow', 'functions')):
62
64
  registered_keys = GlobalTypeRegistry.get().get_registered_functions()
65
+ elif (info.field_name == "function_groups"):
66
+ registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
63
67
  elif (info.field_name == "authentication"):
64
68
  registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
65
69
  elif (info.field_name == "llms"):
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
135
139
 
136
140
  class TelemetryConfig(BaseModel):
137
141
 
138
- logging: dict[str, LoggingBaseConfig] = {}
139
- tracing: dict[str, TelemetryExporterBaseConfig] = {}
142
+ logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
143
+ tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
140
144
 
141
145
  @field_validator("logging", "tracing", mode="wrap")
142
146
  @classmethod
@@ -185,10 +189,14 @@ class GeneralConfig(BaseModel):
185
189
 
186
190
  model_config = ConfigDict(protected_namespaces=())
187
191
 
188
- use_uvloop: bool = True
192
+ use_uvloop: bool | None = Field(
193
+ default=None,
194
+ deprecated=
195
+ "`use_uvloop` field is deprecated and will be removed in a future release. The use of `uv_loop` is now" +
196
+ "automatically determined based on platform")
189
197
  """
190
- Whether to use uvloop for the event loop. This can provide a significant speedup in some cases. Disable to provide
191
- better error messages when debugging.
198
+ This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
199
+ usage is now determined automatically based on the platform.
192
200
  """
193
201
 
194
202
  telemetry: TelemetryConfig = TelemetryConfig()
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
240
248
  general: GeneralConfig = GeneralConfig()
241
249
 
242
250
  # Functions Configuration
243
- functions: dict[str, FunctionBaseConfig] = {}
251
+ functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
252
+
253
+ # Function Groups Configuration
254
+ function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
244
255
 
245
256
  # LLMs Configuration
246
- llms: dict[str, LLMBaseConfig] = {}
257
+ llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
247
258
 
248
259
  # Embedders Configuration
249
- embedders: dict[str, EmbedderBaseConfig] = {}
260
+ embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
250
261
 
251
262
  # Memory Configuration
252
- memory: dict[str, MemoryBaseConfig] = {}
263
+ memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
253
264
 
254
265
  # Object Stores Configuration
255
- object_stores: dict[str, ObjectStoreBaseConfig] = {}
266
+ object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
267
+
268
+ # Optimizer Configuration
269
+ optimizer: OptimizerConfig = OptimizerConfig()
256
270
 
257
271
  # Retriever Configuration
258
- retrievers: dict[str, RetrieverBaseConfig] = {}
272
+ retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
259
273
 
260
274
  # TTC Strategies
261
- ttc_strategies: dict[str, TTCStrategyBaseConfig] = {}
275
+ ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
262
276
 
263
277
  # Workflow Configuration
264
278
  workflow: FunctionBaseConfig = EmptyFunctionConfig()
265
279
 
266
280
  # Authentication Configuration
267
- authentication: dict[str, AuthProviderBaseConfig] = {}
281
+ authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
268
282
 
269
283
  # Evaluation Options
270
284
  eval: EvalConfig = EvalConfig()
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
278
292
  stream.write(f"Workflow Type: {self.workflow.type}\n")
279
293
 
280
294
  stream.write(f"Number of Functions: {len(self.functions)}\n")
295
+ stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
281
296
  stream.write(f"Number of LLMs: {len(self.llms)}\n")
282
297
  stream.write(f"Number of Embedders: {len(self.embedders)}\n")
283
298
  stream.write(f"Number of Memory: {len(self.memory)}\n")
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
287
302
  stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
288
303
 
289
304
  @field_validator("functions",
305
+ "function_groups",
290
306
  "llms",
291
307
  "embedders",
292
308
  "memory",
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
328
344
  typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
329
345
  Discriminator(TypedBaseModel.discriminator)]]
330
346
 
347
+ FunctionGroupsAnnotation = dict[str,
348
+ typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
349
+ Discriminator(TypedBaseModel.discriminator)]]
350
+
331
351
  MemoryAnnotation = dict[str,
332
352
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
333
353
  Discriminator(TypedBaseModel.discriminator)]]
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
335
355
  ObjectStoreAnnotation = dict[str,
336
356
  typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
337
357
  Discriminator(TypedBaseModel.discriminator)]]
338
-
339
358
  RetrieverAnnotation = dict[str,
340
359
  typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
341
360
  Discriminator(TypedBaseModel.discriminator)]]
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
344
363
  typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
345
364
  Discriminator(TypedBaseModel.discriminator)]]
346
365
 
347
- WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
366
+ WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
348
367
  Discriminator(TypedBaseModel.discriminator)]
349
368
 
350
369
  should_rebuild = False
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
369
388
  functions_field.annotation = FunctionsAnnotation
370
389
  should_rebuild = True
371
390
 
391
+ function_groups_field = cls.model_fields.get("function_groups")
392
+ if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
393
+ function_groups_field.annotation = FunctionGroupsAnnotation
394
+ should_rebuild = True
395
+
372
396
  memory_field = cls.model_fields.get("memory")
373
397
  if memory_field is not None and memory_field.annotation != MemoryAnnotation:
374
398
  memory_field.annotation = MemoryAnnotation
@@ -15,6 +15,10 @@
15
15
 
16
16
  import typing
17
17
 
18
+ from pydantic import Field
19
+ from pydantic import field_validator
20
+ from pydantic import model_validator
21
+
18
22
  from .common import BaseModelRegistryTag
19
23
  from .common import TypedBaseModel
20
24
 
@@ -23,8 +27,38 @@ class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
23
27
  pass
24
28
 
25
29
 
30
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """Base configuration for function groups.

    Function groups enable sharing of configurations and resources across multiple functions.
    ``include`` and ``exclude`` are mutually exclusive, and each list must contain unique names.
    """
    include: list[str] = Field(
        default_factory=list,
        description="The list of function names which should be added to the global Function registry",
    )
    exclude: list[str] = Field(
        default_factory=list,
        description="The list of function names which should be excluded from default access to the group",
    )

    @field_validator("include", "exclude")
    @classmethod
    def _validate_fields_include_exclude(cls, value: list[str]) -> list[str]:
        """Reject duplicate names and return the names in sorted order."""
        distinct_names = set(value)
        if len(distinct_names) != len(value):
            raise ValueError("Function names must be unique")
        ordered_names = list(value)
        ordered_names.sort()
        return ordered_names

    @model_validator(mode="after")
    def _validate_include_exclude(self):
        """Ensure that at most one of ``include`` and ``exclude`` is non-empty."""
        both_given = bool(self.include) and bool(self.exclude)
        if both_given:
            raise ValueError("include and exclude cannot be used together")
        return self
56
+
57
+
26
58
  class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
27
59
  pass
28
60
 
29
61
 
30
62
  FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
63
+
64
+ FunctionGroupConfigT = typing.TypeVar("FunctionGroupConfigT", bound=FunctionGroupBaseConfig)
@@ -23,6 +23,7 @@ class FunctionDependencies(BaseModel):
23
23
  A class to represent the dependencies of a function.
24
24
  """
25
25
  functions: set[str] = Field(default_factory=set)
26
+ function_groups: set[str] = Field(default_factory=set)
26
27
  llms: set[str] = Field(default_factory=set)
27
28
  embedders: set[str] = Field(default_factory=set)
28
29
  memory_clients: set[str] = Field(default_factory=set)
@@ -33,6 +34,10 @@ class FunctionDependencies(BaseModel):
33
34
  def serialize_functions(self, v: set[str]) -> list[str]:
34
35
  return list(v)
35
36
 
37
+ @field_serializer("function_groups", when_used="json")
38
+ def serialize_function_groups(self, v: set[str]) -> list[str]:
39
+ return list(v)
40
+
36
41
  @field_serializer("llms", when_used="json")
37
42
  def serialize_llms(self, v: set[str]) -> list[str]:
38
43
  return list(v)
@@ -56,6 +61,9 @@ class FunctionDependencies(BaseModel):
56
61
  def add_function(self, function: str):
57
62
  self.functions.add(function)
58
63
 
64
+ def add_function_group(self, function_group: str):
65
+ self.function_groups.add(function_group) # pylint: disable=no-member
66
+
59
67
  def add_llm(self, llm: str):
60
68
  self.llms.add(llm)
61
69
 
@@ -0,0 +1,119 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import Sequence
17
+ from typing import Any
18
+ from typing import Generic
19
+ from typing import TypeVar
20
+
21
+ from optuna import Trial
22
+ from pydantic import BaseModel
23
+ from pydantic import ConfigDict
24
+ from pydantic import Field
25
+ from pydantic import model_validator
26
+
27
+ T = TypeVar("T", int, float, bool, str)
28
+
29
+
30
+ # --------------------------------------------------------------------- #
31
+ # 1. Hyper‑parameter metadata container #
32
+ # --------------------------------------------------------------------- #
33
class SearchSpace(BaseModel, Generic[T]):
    """Describes the search space of one optimizable parameter.

    A space is either categorical (``values``), a numeric range (``low``/``high`` with
    optional ``log`` scaling and ``step``), or a prompt (``is_prompt``). Unknown fields
    are rejected (``extra="forbid"``).
    """
    values: Sequence[T] | None = None
    low: T | None = None
    high: T | None = None
    log: bool = False  # log scale
    step: float | None = None
    is_prompt: bool = False
    prompt: str | None = None  # prompt to optimize
    prompt_purpose: str | None = None  # purpose of the prompt

    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

    @model_validator(mode="after")
    def validate_search_space_parameters(self):
        """Reject spaces that combine ``values`` with ``low`` and/or ``high``.

        Raises:
            ValueError: If ``values`` is set together with ``low`` or ``high``.
        """
        if self.values is not None and (self.high is not None or self.low is not None):
            raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
        return self

    # Helper for Optuna Trials
    def suggest(self, trial: Trial, name: str):
        """Ask an Optuna trial for a value of the parameter ``name`` from this space.

        Args:
            trial: The Optuna trial to draw the suggestion from.
            name: The parameter name registered with the trial.

        Returns:
            The value suggested by the trial (categorical, int, or float).

        Raises:
            ValueError: If this is a prompt space, or if neither ``values`` nor both
                ``low`` and ``high`` are set.
        """
        if self.is_prompt:
            raise ValueError("Prompt optimization not currently supported using Optuna. "
                             "Use the genetic algorithm implementation instead.")
        if self.values is not None:
            return trial.suggest_categorical(name, self.values)
        if self.low is None or self.high is None:
            # Fail with a clear message instead of passing None bounds to Optuna.
            raise ValueError(f"SearchSpace for '{name}' requires either 'values' or both 'low' and 'high'")
        if isinstance(self.low, int):
            # suggest_int documents an integer step defaulting to 1 (None is not a documented
            # value), so substitute that default when no step was configured.
            int_step = 1 if self.step is None else self.step
            return trial.suggest_int(name, self.low, self.high, log=self.log, step=int_step)
        # suggest_float documents step=None as "no discretization", so pass it through.
        return trial.suggest_float(name, self.low, self.high, log=self.log, step=self.step)
66
+
67
+
68
def OptimizableField(
    default: Any,
    *,
    space: SearchSpace | None = None,
    merge_conflict: str = "overwrite",
    **fld_kw,
):
    """Create a Pydantic ``Field`` whose ``json_schema_extra`` marks it as optimizable.

    Args:
        default: The field default, forwarded to ``Field``. Also used as the base prompt
            when ``space`` is a prompt space without its own ``prompt``.
        space: Optional search space stored under the ``search_space`` key.
        merge_conflict: What to do when the caller's ``json_schema_extra`` already holds a
            reserved key: ``"error"`` raises, ``"keep"`` keeps the caller's value, and any
            other value (including the default ``"overwrite"``) lets this function's value win.
        **fld_kw: Remaining keyword arguments forwarded to ``Field``.

    Returns:
        The result of ``Field(default, json_schema_extra=..., **fld_kw)``.

    Raises:
        TypeError: If ``json_schema_extra`` is given and is not a dict.
        ValueError: If a prompt space has no base prompt and ``default`` is None, or if
            reserved keys collide and ``merge_conflict`` is ``"error"``.
    """
    # Caller-provided extras are optional but must be a dict when present.
    caller_extra = fld_kw.pop("json_schema_extra", None) or {}
    if not isinstance(caller_extra, dict):
        raise TypeError("`json_schema_extra` must be a mapping.")

    # A prompt space needs a concrete base prompt; borrow the field default if it has none.
    # Note that this writes the default into the passed-in space object.
    lacks_base_prompt = (space is not None and getattr(space, "is_prompt", False)
                         and getattr(space, "prompt", None) is None)
    if lacks_base_prompt:
        if default is None:
            raise ValueError("Prompt-optimized fields require a base prompt: provide a "
                             "non-None field default or set space.prompt.")
        space.prompt = default

    # Metadata contributed by this helper.
    reserved = {"optimizable": True}
    if space is not None:
        reserved["search_space"] = space

    # Resolve collisions between the caller's keys and the reserved keys.
    clashes = reserved.keys() & caller_extra.keys()
    if clashes:
        if merge_conflict == "error":
            raise ValueError("`json_schema_extra` already contains reserved key(s): "
                             f"{', '.join(clashes)}")
        if merge_conflict == "keep":
            # Drop the reserved entries the caller already set so they are not overwritten.
            reserved = {key: val for key, val in reserved.items() if key not in clashes}

    # Later entries win, so whatever remains in `reserved` overrides the caller's keys.
    return Field(default, json_schema_extra={**caller_extra, **reserved}, **fld_kw)
108
+
109
+
110
class OptimizableMixin(BaseModel):
    """Mixin adding optimization metadata fields to a model.

    Both fields are declared with ``exclude=True``.
    """

    # Names of the parameters that can be optimized.
    optimizable_params: list[str] = Field(
        default_factory=list,
        description="List of parameters that can be optimized.",
        exclude=True,
    )

    # Per-parameter search space overrides, keyed by parameter name.
    search_space: dict[str, SearchSpace] = Field(
        default_factory=dict,
        description="Optional search space overrides for optimizable parameters.",
        exclude=True,
    )