nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -14
- nat/agent/reasoning_agent/reasoning_agent.py +9 -7
- nat/agent/register.py +1 -0
- nat/agent/rewoo_agent/agent.py +9 -2
- nat/agent/rewoo_agent/register.py +16 -12
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +14 -13
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/context.py +28 -6
- nat/builder/function.py +313 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +215 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +4 -9
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/log_levels.py +25 -0
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
|
|
19
|
+
from langchain_core.tools.base import BaseTool
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
from pydantic import Field
|
|
22
|
+
|
|
23
|
+
from nat.builder.builder import Builder
|
|
24
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
25
|
+
from nat.builder.function import Function
|
|
26
|
+
from nat.builder.function_info import FunctionInfo
|
|
27
|
+
from nat.cli.register_workflow import register_function
|
|
28
|
+
from nat.data_models.component_ref import FunctionRef
|
|
29
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
30
|
+
from nat.utils.type_utils import DecomposedType
|
|
31
|
+
|
|
32
|
+
# Module-level logger named after this module; every diagnostic below is emitted through it.
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ToolExecutionConfig(BaseModel):
    """Per-tool execution options consulted by the sequential executor.

    Attributes:
        use_streaming: When True the executor drives the tool through its streaming
            interface and concatenates the chunks; otherwise the tool is invoked once.
    """

    use_streaming: bool = Field(False, description="Whether to use streaming output for the tool.")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SequentialExecutorConfig(FunctionBaseConfig, name="sequential_executor"):
    """Configuration for sequential execution of a list of functions.

    The functions in ``tool_list`` run in order: the output of each function becomes the
    input of the next one, and the output of the last function is the overall result.
    """

    tool_list: list[FunctionRef] = Field(default_factory=list,
                                         description="A list of functions to execute sequentially.")
    # Adjacent string literals are concatenated verbatim, so each fragment below ends with a
    # space; otherwise words run together in the generated schema description.
    tool_execution_config: dict[str, ToolExecutionConfig] = Field(
        default_factory=dict,
        description="Optional configuration for each "
        "tool in the sequential execution tool list. "
        "Keys must match the tool names from the "
        "tool_list.")
    raise_type_incompatibility: bool = Field(
        default=False,
        description="Default to False. Check if the adjacent tools are type compatible, "
        "which means the output type of the previous function is compatible with the input type of the next function. "
        "If set to True, any incompatibility will raise an exception. If set to false, the incompatibility will only "
        "generate a warning message and the sequential execution will continue.")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _get_function_output_type(function: Function, tool_execution_config: dict[str, ToolExecutionConfig]) -> type:
    """Return the type ``function`` produces when run inside the sequential executor.

    Args:
        function: The function whose output type is requested.
        tool_execution_config: Per-tool options keyed by function instance name.

    Returns:
        ``function.streaming_output_type`` if the function has an options entry with
        streaming enabled, otherwise ``function.single_output_type``.
    """
    exec_options = tool_execution_config.get(function.instance_name)
    if exec_options and exec_options.use_streaming:
        return function.streaming_output_type
    return function.single_output_type
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _validate_function_type_compatibility(src_fn: Function,
                                          target_fn: Function,
                                          tool_execution_config: dict[str, ToolExecutionConfig]) -> None:
    """Check that the output of ``src_fn`` can be passed as the input of ``target_fn``.

    Args:
        src_fn: Upstream function whose output is forwarded.
        target_fn: Downstream function that receives ``src_fn``'s output.
        tool_execution_config: Per-tool options; decides whether ``src_fn``'s streaming or
            single output type is compared.

    Raises:
        ValueError: If ``DecomposedType.is_type_compatible`` reports the two types as incompatible.
    """
    src_output_type = _get_function_output_type(src_fn, tool_execution_config)
    target_input_type = target_fn.input_type

    is_compatible = DecomposedType.is_type_compatible(src_output_type, target_input_type)
    # Log the actual outcome of the check, so debug output does not report an
    # incompatibility for pairs that are in fact compatible.
    logger.debug("Type compatibility %s -> %s: output type %s, input type %s, compatible: %s",
                 src_fn.instance_name,
                 target_fn.instance_name,
                 src_output_type,
                 target_input_type,
                 is_compatible)

    if not is_compatible:
        raise ValueError(
            f"The output type of the {src_fn.instance_name} function is {src_output_type}, is not compatible with "
            f"the input type of the {target_fn.instance_name} function, which is {target_input_type}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
                                           builder: Builder) -> tuple[type, type]:
    """Validate that every function's output type fits the next function's input type.

    Args:
        sequential_executor_config: Executor configuration with the ordered tool list and the
            per-tool execution options.
        builder: Builder used to resolve each ``FunctionRef`` into a ``Function``.

    Returns:
        ``(input_type, output_type)`` for the whole chain. ``input_type`` is the first
        function's input type. ``output_type`` is the last function's (possibly streaming)
        output type.

    Raises:
        RuntimeError: If the tool list is empty.
        ValueError: If any adjacent pair of functions has incompatible types.
    """
    tool_execution_config = sequential_executor_config.tool_execution_config

    function_list: list[Function] = [builder.get_function(function_ref)
                                     for function_ref in sequential_executor_config.tool_list]
    if not function_list:
        raise RuntimeError("The function list is empty")
    input_type = function_list[0].input_type

    # Pairing the list with its one-step shift yields every adjacent (src, target) pair;
    # a single-function list yields no pairs, so no explicit length check is needed.
    for src_fn, target_fn in zip(function_list, function_list[1:]):
        try:
            _validate_function_type_compatibility(src_fn, target_fn, tool_execution_config)
        except ValueError as e:
            # Chain the cause so the original traceback is kept.
            raise ValueError(f"The sequential tool list has incompatible types: {e}") from e

    output_type = _get_function_output_type(function_list[-1], tool_execution_config)
    logger.debug(f"The input type of the sequential executor tool list is {input_type}, "
                 f"the output type is {output_type}")

    return (input_type, output_type)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@register_function(config_type=SequentialExecutorConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
    """Register a function that runs ``config.tool_list`` in order, piping each output into the next tool.

    The function's input and return annotations come from the first and last tools when
    adjacent tool types are compatible. On incompatibility, the error is re-raised if
    ``config.raise_type_incompatibility`` is True; otherwise a warning is logged and both
    annotations fall back to ``typing.Any``.

    Args:
        config: The sequential executor configuration.
        builder: Builder used to resolve the configured tools and functions.

    Yields:
        A ``FunctionInfo`` that wraps the sequential execution coroutine.

    Raises:
        ValueError: If the tool list is type-incompatible and ``raise_type_incompatibility`` is
            True, or if validating the tool list fails for any other reason.
    """
    logger.debug(f"Initializing sequential executor with tool list: {config.tool_list}")

    tools: list[BaseTool] = builder.get_tools(tool_names=config.tool_list, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}

    try:
        input_type, output_type = _validate_tool_list_type_compatibility(config, builder)
    except ValueError as e:
        if config.raise_type_incompatibility:
            logger.error(f"The sequential executor tool list has incompatible types: {e}")
            raise
        logger.warning(f"The sequential executor tool list has incompatible types: {e}")
        input_type = typing.Any
        output_type = typing.Any
    except Exception as e:
        # Chain the cause so the underlying failure (e.g. an empty tool list) stays visible.
        raise ValueError(f"Error with the sequential executor tool list: {e}") from e

    # The type annotation of _sequential_function_execution is dynamically set according to the tool list
    async def _sequential_function_execution(initial_tool_input):
        logger.debug(f"Executing sequential executor with tool list: {config.tool_list}")

        tool_input = initial_tool_input
        tool_response = None

        for tool_name in config.tool_list:
            tool = tools_dict[tool_name]
            exec_options = config.tool_execution_config.get(tool_name, None)
            logger.debug(f"Executing tool {tool_name} with input: {tool_input}")
            try:
                if exec_options and exec_options.use_streaming:
                    # NOTE(review): assumes every streamed chunk exposes a string ``content``
                    # attribute; confirm this holds for the wrapped LangChain tools.
                    output = ""
                    async for chunk in tool.astream(tool_input):
                        output += chunk.content
                    tool_response = output
                else:
                    tool_response = await tool.ainvoke(tool_input)
            except Exception as e:
                logger.error(f"Error with tool {tool_name}: {e}")
                raise

            # The input of the next tool is the response of the previous tool
            tool_input = tool_response

        return tool_response

    # Dynamically set the annotations for the function
    _sequential_function_execution.__annotations__ = {"initial_tool_input": input_type, "return": output_type}
    logger.debug(f"Sequential executor function annotations: {_sequential_function_execution.__annotations__}")

    yield FunctionInfo.from_fn(_sequential_function_execution,
                               description="Executes a list of functions sequentially. "
                               "The input of the next tool is the response of the previous tool.")
|
nat/data_models/agent.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
from pydantic import PositiveInt
|
|
18
|
+
|
|
19
|
+
from nat.data_models.component_ref import LLMRef
|
|
20
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AgentBaseConfig(FunctionBaseConfig):
    """Base configuration class for all NAT agents with common fields.

    Holds the settings every agent shares: an optional workflow alias, the LLM reference,
    verbosity, the tool description, and the cap on characters logged for responses.
    ``llm_name`` and ``description`` have no default and must always be supplied.
    """

    workflow_alias: str | None = Field(
        None,
        description="The alias of the workflow. Useful when the agent is configured as a workflow "
        "and needs to expose a customized name as a tool.",
    )
    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")
    verbose: bool = Field(False, description="Set the verbosity of the agent's logging.")
    description: str = Field(description="The description of this function's use.")
    log_response_max_chars: PositiveInt = Field(
        1000,
        description="Maximum number of characters to display in logs when logging responses.",
    )
|
|
@@ -177,6 +177,26 @@ Credential = typing.Annotated[
|
|
|
177
177
|
]
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
class TokenValidationResult(BaseModel):
    """Standard result for Bearer Token Validation.

    Most fields mirror standard token claims (``sub``, ``iss``, ``aud``, ``nbf``, ``iat``, ``jti``).
    ``scopes`` and ``username`` are described as introspection-only. ``client_id`` and
    ``token_type`` have no default and must be provided; unknown fields are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    client_id: str | None = Field(description="OAuth2 client identifier")
    scopes: list[str] | None = Field(None, description="List of granted scopes (introspection only)")
    expires_at: int | None = Field(None, description="Token expiration time (Unix timestamp)")
    audience: list[str] | None = Field(None, description="Token audiences (aud claim)")
    subject: str | None = Field(None, description="Token subject (sub claim)")
    issuer: str | None = Field(None, description="Token issuer (iss claim)")
    token_type: str = Field(description="Token type")
    active: bool | None = Field(True, description="Token active status")
    nbf: int | None = Field(None, description="Not before time (Unix timestamp)")
    iat: int | None = Field(None, description="Issued at time (Unix timestamp)")
    jti: str | None = Field(None, description="JWT ID")
    username: str | None = Field(None, description="Username (introspection only)")
|
|
198
|
+
|
|
199
|
+
|
|
180
200
|
class AuthResult(BaseModel):
|
|
181
201
|
"""
|
|
182
202
|
Represents the result of an authentication process.
|
|
@@ -229,3 +249,21 @@ class AuthResult(BaseModel):
|
|
|
229
249
|
target_kwargs.setdefault(k, {}).update(v)
|
|
230
250
|
else:
|
|
231
251
|
target_kwargs[k] = v
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class AuthReason(str, Enum):
    """Why the caller is asking for auth now.

    Members:
        NORMAL: The default reason for an authentication attempt.
        RETRY_AFTER_401: Authentication is requested again after a 401 response.
    """

    NORMAL = "normal"
    RETRY_AFTER_401 = "retry_after_401"
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class AuthRequest(BaseModel):
    """Authentication request payload for provider.authenticate(...).

    Unknown fields are rejected. By default the request carries ``AuthReason.NORMAL`` and no
    ``WWW-Authenticate`` header.
    """

    model_config = ConfigDict(extra="forbid")

    reason: AuthReason = Field(AuthReason.NORMAL, description="Purpose of this auth attempt.")
    www_authenticate: str | None = Field(None, description="Raw WWW-Authenticate header from a 401 response.")
|
nat/data_models/component.py
CHANGED
|
@@ -27,6 +27,7 @@ class ComponentEnum(StrEnum):
|
|
|
27
27
|
EVALUATOR = "evaluator"
|
|
28
28
|
FRONT_END = "front_end"
|
|
29
29
|
FUNCTION = "function"
|
|
30
|
+
FUNCTION_GROUP = "function_group"
|
|
30
31
|
TTC_STRATEGY = "ttc_strategy"
|
|
31
32
|
LLM_CLIENT = "llm_client"
|
|
32
33
|
LLM_PROVIDER = "llm_provider"
|
|
@@ -47,6 +48,7 @@ class ComponentGroup(StrEnum):
|
|
|
47
48
|
AUTHENTICATION = "authentication"
|
|
48
49
|
EMBEDDERS = "embedders"
|
|
49
50
|
FUNCTIONS = "functions"
|
|
51
|
+
FUNCTION_GROUPS = "function_groups"
|
|
50
52
|
TTC_STRATEGIES = "ttc_strategies"
|
|
51
53
|
LLMS = "llms"
|
|
52
54
|
MEMORY = "memory"
|
nat/data_models/component_ref.py
CHANGED
|
@@ -102,6 +102,17 @@ class FunctionRef(ComponentRef):
|
|
|
102
102
|
return ComponentGroup.FUNCTIONS
|
|
103
103
|
|
|
104
104
|
|
|
105
|
+
class FunctionGroupRef(ComponentRef):
    """
    A reference to a function group in a NAT configuration object.

    The reference belongs to the ``ComponentGroup.FUNCTION_GROUPS`` component group.
    """

    @property
    @override
    def component_group(self):
        """Component group of this reference: always ``ComponentGroup.FUNCTION_GROUPS``."""
        group = ComponentGroup.FUNCTION_GROUPS
        return group
|
|
114
|
+
|
|
115
|
+
|
|
105
116
|
class LLMRef(ComponentRef):
|
|
106
117
|
"""
|
|
107
118
|
A reference to an LLM in a NAT configuration object.
|
nat/data_models/config.py
CHANGED
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
from pydantic import BaseModel
|
|
21
21
|
from pydantic import ConfigDict
|
|
22
22
|
from pydantic import Discriminator
|
|
23
|
+
from pydantic import Field
|
|
23
24
|
from pydantic import ValidationError
|
|
24
25
|
from pydantic import ValidationInfo
|
|
25
26
|
from pydantic import ValidatorFunctionWrapHandler
|
|
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
|
|
|
29
30
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
30
31
|
from nat.data_models.function import EmptyFunctionConfig
|
|
31
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
32
34
|
from nat.data_models.logging import LoggingBaseConfig
|
|
35
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
33
36
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
37
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
35
38
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
57
60
|
error_type = e['type']
|
|
58
61
|
if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
|
|
59
62
|
requested_type = e["ctx"]["tag"]
|
|
60
|
-
|
|
61
63
|
if (info.field_name in ('workflow', 'functions')):
|
|
62
64
|
registered_keys = GlobalTypeRegistry.get().get_registered_functions()
|
|
65
|
+
elif (info.field_name == "function_groups"):
|
|
66
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
|
|
63
67
|
elif (info.field_name == "authentication"):
|
|
64
68
|
registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
|
|
65
69
|
elif (info.field_name == "llms"):
|
|
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
135
139
|
|
|
136
140
|
class TelemetryConfig(BaseModel):
|
|
137
141
|
|
|
138
|
-
logging: dict[str, LoggingBaseConfig] =
|
|
139
|
-
tracing: dict[str, TelemetryExporterBaseConfig] =
|
|
142
|
+
logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
|
|
143
|
+
tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
|
|
140
144
|
|
|
141
145
|
@field_validator("logging", "tracing", mode="wrap")
|
|
142
146
|
@classmethod
|
|
@@ -185,10 +189,14 @@ class GeneralConfig(BaseModel):
|
|
|
185
189
|
|
|
186
190
|
model_config = ConfigDict(protected_namespaces=())
|
|
187
191
|
|
|
188
|
-
use_uvloop: bool =
|
|
192
|
+
use_uvloop: bool | None = Field(
|
|
193
|
+
default=None,
|
|
194
|
+
deprecated=
|
|
195
|
+
"`use_uvloop` field is deprecated and will be removed in a future release. The use of `uv_loop` is now" +
|
|
196
|
+
"automatically determined based on platform")
|
|
189
197
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
198
|
+
This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
|
|
199
|
+
usage is now determined automatically based on the platform.
|
|
192
200
|
"""
|
|
193
201
|
|
|
194
202
|
telemetry: TelemetryConfig = TelemetryConfig()
|
|
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
|
|
|
240
248
|
general: GeneralConfig = GeneralConfig()
|
|
241
249
|
|
|
242
250
|
# Functions Configuration
|
|
243
|
-
functions: dict[str, FunctionBaseConfig] =
|
|
251
|
+
functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
|
|
252
|
+
|
|
253
|
+
# Function Groups Configuration
|
|
254
|
+
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
244
255
|
|
|
245
256
|
# LLMs Configuration
|
|
246
|
-
llms: dict[str, LLMBaseConfig] =
|
|
257
|
+
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
247
258
|
|
|
248
259
|
# Embedders Configuration
|
|
249
|
-
embedders: dict[str, EmbedderBaseConfig] =
|
|
260
|
+
embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
|
|
250
261
|
|
|
251
262
|
# Memory Configuration
|
|
252
|
-
memory: dict[str, MemoryBaseConfig] =
|
|
263
|
+
memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
|
|
253
264
|
|
|
254
265
|
# Object Stores Configuration
|
|
255
|
-
object_stores: dict[str, ObjectStoreBaseConfig] =
|
|
266
|
+
object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
|
|
267
|
+
|
|
268
|
+
# Optimizer Configuration
|
|
269
|
+
optimizer: OptimizerConfig = OptimizerConfig()
|
|
256
270
|
|
|
257
271
|
# Retriever Configuration
|
|
258
|
-
retrievers: dict[str, RetrieverBaseConfig] =
|
|
272
|
+
retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
|
|
259
273
|
|
|
260
274
|
# TTC Strategies
|
|
261
|
-
ttc_strategies: dict[str, TTCStrategyBaseConfig] =
|
|
275
|
+
ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
|
|
262
276
|
|
|
263
277
|
# Workflow Configuration
|
|
264
278
|
workflow: FunctionBaseConfig = EmptyFunctionConfig()
|
|
265
279
|
|
|
266
280
|
# Authentication Configuration
|
|
267
|
-
authentication: dict[str, AuthProviderBaseConfig] =
|
|
281
|
+
authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
|
|
268
282
|
|
|
269
283
|
# Evaluation Options
|
|
270
284
|
eval: EvalConfig = EvalConfig()
|
|
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
|
|
|
278
292
|
stream.write(f"Workflow Type: {self.workflow.type}\n")
|
|
279
293
|
|
|
280
294
|
stream.write(f"Number of Functions: {len(self.functions)}\n")
|
|
295
|
+
stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
|
|
281
296
|
stream.write(f"Number of LLMs: {len(self.llms)}\n")
|
|
282
297
|
stream.write(f"Number of Embedders: {len(self.embedders)}\n")
|
|
283
298
|
stream.write(f"Number of Memory: {len(self.memory)}\n")
|
|
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
|
|
|
287
302
|
stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
|
|
288
303
|
|
|
289
304
|
@field_validator("functions",
|
|
305
|
+
"function_groups",
|
|
290
306
|
"llms",
|
|
291
307
|
"embedders",
|
|
292
308
|
"memory",
|
|
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
|
|
|
328
344
|
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
329
345
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
330
346
|
|
|
347
|
+
FunctionGroupsAnnotation = dict[str,
|
|
348
|
+
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
|
+
|
|
331
351
|
MemoryAnnotation = dict[str,
|
|
332
352
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
333
353
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
|
|
|
335
355
|
ObjectStoreAnnotation = dict[str,
|
|
336
356
|
typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
|
|
337
357
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
338
|
-
|
|
339
358
|
RetrieverAnnotation = dict[str,
|
|
340
359
|
typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
|
|
341
360
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
|
|
|
344
363
|
typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
|
|
345
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
346
365
|
|
|
347
|
-
WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
366
|
+
WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
|
|
348
367
|
Discriminator(TypedBaseModel.discriminator)]
|
|
349
368
|
|
|
350
369
|
should_rebuild = False
|
|
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
|
|
|
369
388
|
functions_field.annotation = FunctionsAnnotation
|
|
370
389
|
should_rebuild = True
|
|
371
390
|
|
|
391
|
+
function_groups_field = cls.model_fields.get("function_groups")
|
|
392
|
+
if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
|
|
393
|
+
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
|
+
should_rebuild = True
|
|
395
|
+
|
|
372
396
|
memory_field = cls.model_fields.get("memory")
|
|
373
397
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
374
398
|
memory_field.annotation = MemoryAnnotation
|
nat/data_models/function.py
CHANGED
|
@@ -15,6 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
17
|
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
from pydantic import field_validator
|
|
20
|
+
from pydantic import model_validator
|
|
21
|
+
|
|
18
22
|
from .common import BaseModelRegistryTag
|
|
19
23
|
from .common import TypedBaseModel
|
|
20
24
|
|
|
@@ -23,8 +27,38 @@ class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
|
23
27
|
pass
|
|
24
28
|
|
|
25
29
|
|
|
30
|
+
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """Base configuration for function groups.

    Function groups enable sharing of configurations and resources across multiple functions.
    ``include`` and ``exclude`` are mutually exclusive name lists. Each must hold unique names
    and is stored sorted after validation.
    """

    include: list[str] = Field(
        default_factory=list,
        description="The list of function names which should be added to the global Function registry",
    )
    exclude: list[str] = Field(
        default_factory=list,
        description="The list of function names which should be excluded from default access to the group",
    )

    @field_validator("include", "exclude")
    @classmethod
    def _validate_fields_include_exclude(cls, value: list[str]) -> list[str]:
        """Reject lists with duplicate names; return the names in sorted order."""
        has_duplicates = len(value) != len(set(value))
        if has_duplicates:
            raise ValueError("Function names must be unique")
        return sorted(value)

    @model_validator(mode="after")
    def _validate_include_exclude(self):
        """Reject configurations where ``include`` and ``exclude`` are both non-empty."""
        if not self.include or not self.exclude:
            return self
        raise ValueError("include and exclude cannot be used together")
|
|
56
|
+
|
|
57
|
+
|
|
26
58
|
class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
    """Function configuration with no fields, registered under the name ``EmptyFunctionConfig``."""
|
|
28
60
|
|
|
29
61
|
|
|
30
62
|
# Type variable restricted (via ``bound``) to FunctionBaseConfig and its subclasses.
FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
|
|
63
|
+
|
|
64
|
+
# Type variable restricted (via ``bound``) to FunctionGroupBaseConfig and its subclasses.
FunctionGroupConfigT = typing.TypeVar("FunctionGroupConfigT", bound=FunctionGroupBaseConfig)
|
|
@@ -23,6 +23,7 @@ class FunctionDependencies(BaseModel):
|
|
|
23
23
|
A class to represent the dependencies of a function.
|
|
24
24
|
"""
|
|
25
25
|
functions: set[str] = Field(default_factory=set)
|
|
26
|
+
function_groups: set[str] = Field(default_factory=set)
|
|
26
27
|
llms: set[str] = Field(default_factory=set)
|
|
27
28
|
embedders: set[str] = Field(default_factory=set)
|
|
28
29
|
memory_clients: set[str] = Field(default_factory=set)
|
|
@@ -33,6 +34,10 @@ class FunctionDependencies(BaseModel):
|
|
|
33
34
|
def serialize_functions(self, v: set[str]) -> list[str]:
    """Convert the ``functions`` set into a list for JSON output."""
    return [*v]
|
|
35
36
|
|
|
37
|
+
@field_serializer("function_groups", when_used="json")
def serialize_function_groups(self, v: set[str]) -> list[str]:
    """Convert the ``function_groups`` set into a list when dumping to JSON."""
    return [*v]
|
|
40
|
+
|
|
36
41
|
@field_serializer("llms", when_used="json")
def serialize_llms(self, v: set[str]) -> list[str]:
    """Convert the ``llms`` set into a list when dumping to JSON."""
    return [*v]
|
|
@@ -56,6 +61,9 @@ class FunctionDependencies(BaseModel):
|
|
|
56
61
|
def add_function(self, function: str):
    """Record ``function`` in the set of function dependencies."""
    dependencies = self.functions
    dependencies.add(function)
|
|
58
63
|
|
|
64
|
+
def add_function_group(self, function_group: str):
    """Record ``function_group`` in the set of function-group dependencies."""
    dependencies = self.function_groups  # pylint: disable=no-member
    dependencies.add(function_group)
|
|
66
|
+
|
|
59
67
|
def add_llm(self, llm: str):
    """Record ``llm`` in the set of LLM dependencies."""
    dependencies = self.llms
    dependencies.add(llm)
|
|
61
69
|
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
from typing import Any
|
|
18
|
+
from typing import Generic
|
|
19
|
+
from typing import TypeVar
|
|
20
|
+
|
|
21
|
+
from optuna import Trial
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
from pydantic import ConfigDict
|
|
24
|
+
from pydantic import Field
|
|
25
|
+
from pydantic import model_validator
|
|
26
|
+
|
|
27
|
+
# Value type of a SearchSpace: constrained to exactly one of int, float, bool or str.
T = TypeVar("T", int, float, bool, str)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# --------------------------------------------------------------------- #
|
|
31
|
+
# 1. Hyper‑parameter metadata container #
|
|
32
|
+
# --------------------------------------------------------------------- #
|
|
33
|
+
class SearchSpace(BaseModel, Generic[T]):
    """Search-space metadata for a single optimizable parameter.

    A space is described in one of three ways:

    * ``values`` -- an explicit, non-empty list of categorical choices;
    * ``low``/``high`` (optionally ``log`` and ``step``) -- a numeric range;
    * ``is_prompt=True`` -- a prompt to optimize, seeded from ``prompt``.
    """

    values: Sequence[T] | None = None
    low: T | None = None
    high: T | None = None
    log: bool = False  # log scale
    step: float | None = None
    is_prompt: bool = False
    prompt: str | None = None  # prompt to optimize
    prompt_purpose: str | None = None  # purpose of the prompt

    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

    @model_validator(mode="after")
    def validate_search_space_parameters(self):
        """Validate that the space is either categorical (``values``) or a range (``low``/``high``).

        A space with neither (e.g. a prompt space) is accepted here; ``suggest`` rejects it
        when numeric sampling is requested.

        Raises:
            ValueError: if ``values`` is combined with ``low``/``high``, if ``values`` is empty,
                if only one of ``low``/``high`` is given, or if ``low`` exceeds ``high``.
        """
        if self.values is not None:
            # If values is provided, we don't need high/low
            if self.high is not None or self.low is not None:
                raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
            if len(self.values) == 0:
                raise ValueError("SearchSpace 'values' must not be empty")
            return self

        # A range needs both bounds; a single bound cannot be sampled.
        if (self.low is None) != (self.high is None):
            raise ValueError("SearchSpace range requires both 'low' and 'high'; "
                             f"got low={self.low!r}, high={self.high!r}")

        low, high = self.low, self.high
        # Only order-check numeric bounds; comparing other types would not be meaningful.
        if isinstance(low, (int, float)) and isinstance(high, (int, float)) and low > high:
            raise ValueError(f"SearchSpace 'low' must not exceed 'high'; got low={low!r}, high={high!r}")

        return self

    # Helper for Optuna Trials
    def suggest(self, trial: Trial, name: str):
        """Sample a value for ``name`` from this space using an Optuna trial.

        Args:
            trial: The Optuna trial to sample from.
            name: The parameter name to register with the trial.

        Returns:
            A categorical choice when ``values`` is set; otherwise an int (when ``low`` is an
            int) or a float drawn from ``[low, high]``.

        Raises:
            ValueError: for prompt spaces, or when the space has neither ``values`` nor a
                complete ``low``/``high`` range.
        """
        if self.is_prompt:
            raise ValueError("Prompt optimization not currently supported using Optuna. "
                             "Use the genetic algorithm implementation instead.")
        if self.values is not None:
            return trial.suggest_categorical(name, self.values)
        if self.low is None or self.high is None:
            raise ValueError(f"SearchSpace for '{name}' defines neither 'values' nor a 'low'/'high' range")
        if isinstance(self.low, int):
            # Optuna's suggest_int requires an integer step (its default is 1); passing
            # step=None makes the distribution constructor fail.
            int_step = 1 if self.step is None else int(self.step)
            return trial.suggest_int(name, self.low, self.high, log=self.log, step=int_step)
        return trial.suggest_float(name, self.low, self.high, log=self.log, step=self.step)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def OptimizableField(
    default: Any,
    *,
    space: SearchSpace | None = None,
    merge_conflict: str = "overwrite",
    **fld_kw,
):
    """Create a Pydantic ``Field`` marked as optimizable.

    The returned field's ``json_schema_extra`` contains ``"optimizable": True`` and, when
    ``space`` is given, ``"search_space": space``, merged with any user-supplied extras.

    Args:
        default: The field's default value. For a prompt space without ``space.prompt``,
            it also serves as the base prompt.
        space: Optional search space for the parameter. It is never mutated; when a base
            prompt must be filled in, a copy is stored instead.
        merge_conflict: How to resolve keys set both by the user's ``json_schema_extra`` and
            by this helper: ``"overwrite"`` (ours win), ``"keep"`` (user's win) or
            ``"error"`` (raise).
        **fld_kw: Forwarded to ``pydantic.Field``; ``json_schema_extra`` must be a dict.

    Returns:
        The result of ``pydantic.Field`` with the merged ``json_schema_extra``.

    Raises:
        TypeError: if ``json_schema_extra`` is not a dict.
        ValueError: if ``merge_conflict`` is not a known policy, if a prompt space has no
            base prompt, or if reserved keys collide under ``merge_conflict="error"``.
    """
    if merge_conflict not in ("overwrite", "keep", "error"):
        raise ValueError(f"Unknown merge_conflict policy {merge_conflict!r}; "
                         "expected 'overwrite', 'keep' or 'error'.")

    # 1. Pull out any user‑supplied extras (must be a dict)
    user_extra = fld_kw.pop("json_schema_extra", None) or {}
    if not isinstance(user_extra, dict):
        raise TypeError("`json_schema_extra` must be a mapping.")

    # 2. If the space is a prompt, ensure a concrete base prompt exists
    if space is not None and getattr(space, "is_prompt", False):
        if getattr(space, "prompt", None) is None:
            if default is None:
                raise ValueError("Prompt-optimized fields require a base prompt: provide a "
                                 "non-None field default or set space.prompt.")
            # Fall back to the field's default, on a copy: the caller's SearchSpace may be
            # shared by several fields, and mutating it would leak this default into them.
            space = space.model_copy(update={"prompt": default})

    # 3. Prepare our own metadata
    ours: dict[str, Any] = {"optimizable": True}
    if space is not None:
        ours["search_space"] = space

    # 4. Merge with user extras according to merge_conflict policy
    intersect = ours.keys() & user_extra.keys()
    if intersect:
        if merge_conflict == "error":
            raise ValueError("`json_schema_extra` already contains reserved key(s): "
                             f"{', '.join(sorted(intersect))}")
        if merge_conflict == "keep":
            # remove the ones the user already set so we don't overwrite them
            ours = {k: v for k, v in ours.items() if k not in intersect}

    merged_extra = {**user_extra, **ours}  # ours wins if 'overwrite'

    # 5. Return a normal Pydantic Field with merged extras
    return Field(default, json_schema_extra=merged_extra, **fld_kw)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class OptimizableMixin(BaseModel):
    # Names of parameters that may be tuned; exclude=True keeps this out of model dumps.
    optimizable_params: list[str] = Field(
        default_factory=list,
        description="List of parameters that can be optimized.",
        exclude=True,
    )

    # SearchSpace overrides (presumably keyed by parameter name — confirm with the optimizer);
    # exclude=True keeps this out of model dumps.
    search_space: dict[str, SearchSpace] = Field(
        exclude=True,
        default_factory=dict,
        description="Optional search space overrides for optimizable parameters.",
    )
|