nvidia-nat 1.3.0a20250923__py3-none-any.whl → 1.3.0a20250924__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +12 -1
- nat/agent/reasoning_agent/reasoning_agent.py +2 -2
- nat/agent/rewoo_agent/register.py +12 -1
- nat/agent/tool_calling_agent/register.py +28 -8
- nat/builder/builder.py +33 -24
- nat/builder/eval_builder.py +14 -9
- nat/builder/function.py +108 -52
- nat/builder/workflow_builder.py +89 -79
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +786 -0
- nat/cli/entrypoint.py +2 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +6 -7
- nat/eval/evaluate.py +2 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +26 -5
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +2 -2
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +4 -4
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +1 -1
- nat/profiler/decorators/function_tracking.py +33 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +2 -2
- nat/runtime/loader.py +1 -1
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/METADATA +1 -1
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/RECORD +34 -33
- nat/cli/commands/info/list_mcp.py +0 -461
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/top_level.txt +0 -0
|
@@ -99,7 +99,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
99
99
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
100
100
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
101
101
|
# the sample tool provided can easily be copied or changed
|
|
102
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
102
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
103
103
|
if not tools:
|
|
104
104
|
raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
|
|
105
105
|
# configure callbacks, for sending intermediate steps
|
|
@@ -118,6 +118,17 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
118
118
|
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
119
119
|
|
|
120
120
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
121
|
+
"""
|
|
122
|
+
Main workflow entry function for the ReAct Agent.
|
|
123
|
+
|
|
124
|
+
This function invokes the ReAct Agent Graph and returns the response.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
input_message (ChatRequest): The input message to process
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
ChatResponse: The response from the agent or error message
|
|
131
|
+
"""
|
|
121
132
|
try:
|
|
122
133
|
# initialize the starting state with the user query
|
|
123
134
|
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
@@ -99,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
99
99
|
llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
100
100
|
|
|
101
101
|
# Get the augmented function's description
|
|
102
|
-
augmented_function = builder.get_function(config.augmented_fn)
|
|
102
|
+
augmented_function = await builder.get_function(config.augmented_fn)
|
|
103
103
|
|
|
104
104
|
# For now, we rely on runtime checking for type conversion
|
|
105
105
|
|
|
@@ -119,7 +119,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
119
119
|
tool_names_with_desc: list[tuple[str, str]] = []
|
|
120
120
|
|
|
121
121
|
for tool in function_used_tools:
|
|
122
|
-
tool_impl = builder.get_function(tool)
|
|
122
|
+
tool_impl = await builder.get_function(tool)
|
|
123
123
|
tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
|
|
124
124
|
|
|
125
125
|
# Draft the reasoning prompt for the augmented function
|
|
@@ -108,7 +108,7 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
108
108
|
|
|
109
109
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
110
110
|
# the sample tool provided can easily be copied or changed
|
|
111
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
111
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
112
112
|
if not tools:
|
|
113
113
|
raise ValueError(f"No tools specified for ReWOO Agent '{config.llm_name}'")
|
|
114
114
|
|
|
@@ -125,6 +125,17 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
125
125
|
raise_tool_call_error=config.raise_tool_call_error).build_graph()
|
|
126
126
|
|
|
127
127
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
128
|
+
"""
|
|
129
|
+
Main workflow entry function for the ReWOO Agent.
|
|
130
|
+
|
|
131
|
+
This function invokes the ReWOO Agent Graph and returns the response.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
input_message (ChatRequest): The input message to process
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
ChatResponse: The response from the agent or error message
|
|
138
|
+
"""
|
|
128
139
|
try:
|
|
129
140
|
# initialize the starting state with the user query
|
|
130
141
|
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
@@ -22,6 +22,7 @@ from nat.builder.framework_enum import LLMFrameworkEnum
|
|
|
22
22
|
from nat.builder.function_info import FunctionInfo
|
|
23
23
|
from nat.cli.register_workflow import register_function
|
|
24
24
|
from nat.data_models.agent import AgentBaseConfig
|
|
25
|
+
from nat.data_models.api_server import ChatRequest
|
|
25
26
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
26
27
|
from nat.data_models.component_ref import FunctionRef
|
|
27
28
|
|
|
@@ -38,6 +39,8 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
|
|
|
38
39
|
default_factory=list, description="The list of tools to provide to the tool calling agent.")
|
|
39
40
|
handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
|
|
40
41
|
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
|
|
42
|
+
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
43
|
+
|
|
41
44
|
system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
|
|
42
45
|
additional_instructions: str | None = Field(default=None,
|
|
43
46
|
description="Additional instructions appended to the system prompt.")
|
|
@@ -47,7 +50,8 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
|
|
|
47
50
|
|
|
48
51
|
@register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
49
52
|
async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, builder: Builder):
|
|
50
|
-
from langchain_core.messages
|
|
53
|
+
from langchain_core.messages import trim_messages
|
|
54
|
+
from langchain_core.messages.base import BaseMessage
|
|
51
55
|
from langgraph.graph.state import CompiledStateGraph
|
|
52
56
|
|
|
53
57
|
from nat.agent.base import AGENT_LOG_PREFIX
|
|
@@ -60,13 +64,13 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
60
64
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
61
65
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
62
66
|
# the sample tools provided can easily be copied or changed
|
|
63
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
67
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
64
68
|
if not tools:
|
|
65
69
|
raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")
|
|
66
70
|
|
|
67
71
|
# convert return_direct FunctionRef objects to BaseTool objects
|
|
68
|
-
return_direct_tools = builder.get_tools(
|
|
69
|
-
|
|
72
|
+
return_direct_tools = await builder.get_tools(
|
|
73
|
+
tool_names=config.return_direct, wrapper_type=LLMFrameworkEnum.LANGCHAIN) if config.return_direct else None
|
|
70
74
|
|
|
71
75
|
# construct the Tool Calling Agent Graph from the configured llm, and tools
|
|
72
76
|
graph: CompiledStateGraph = await ToolCallAgentGraph(llm=llm,
|
|
@@ -77,11 +81,27 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
77
81
|
handle_tool_errors=config.handle_tool_errors,
|
|
78
82
|
return_direct=return_direct_tools).build_graph()
|
|
79
83
|
|
|
80
|
-
async def _response_fn(input_message:
|
|
84
|
+
async def _response_fn(input_message: ChatRequest) -> str:
|
|
85
|
+
"""
|
|
86
|
+
Main workflow entry function for the Tool Calling Agent.
|
|
87
|
+
|
|
88
|
+
This function invokes the Tool Calling Agent Graph and returns the response.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
input_message (ChatRequest): The input message to process
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
str: The response from the agent or error message
|
|
95
|
+
"""
|
|
81
96
|
try:
|
|
82
97
|
# initialize the starting state with the user query
|
|
83
|
-
|
|
84
|
-
|
|
98
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
99
|
+
max_tokens=config.max_history,
|
|
100
|
+
strategy="last",
|
|
101
|
+
token_counter=len,
|
|
102
|
+
start_on="human",
|
|
103
|
+
include_system=True)
|
|
104
|
+
state = ToolCallAgentGraphState(messages=messages)
|
|
85
105
|
|
|
86
106
|
# run the Tool Calling Agent Graph
|
|
87
107
|
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_iterations + 1) * 2})
|
|
@@ -92,7 +112,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
92
112
|
# get and return the output from the state
|
|
93
113
|
state = ToolCallAgentGraphState(**state)
|
|
94
114
|
output_message = state.messages[-1]
|
|
95
|
-
return output_message.content
|
|
115
|
+
return str(output_message.content)
|
|
96
116
|
except Exception as ex:
|
|
97
117
|
logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
|
|
98
118
|
if config.verbose:
|
nat/builder/builder.py
CHANGED
|
@@ -45,12 +45,16 @@ from nat.data_models.memory import MemoryBaseConfig
|
|
|
45
45
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
46
46
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
47
47
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
48
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
48
49
|
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
49
50
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
50
51
|
from nat.memory.interfaces import MemoryEditor
|
|
51
52
|
from nat.object_store.interfaces import ObjectStore
|
|
52
53
|
from nat.retriever.interface import Retriever
|
|
53
54
|
|
|
55
|
+
if typing.TYPE_CHECKING:
|
|
56
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
57
|
+
|
|
54
58
|
|
|
55
59
|
class UserManagerHolder():
|
|
56
60
|
|
|
@@ -72,19 +76,20 @@ class Builder(ABC):
|
|
|
72
76
|
pass
|
|
73
77
|
|
|
74
78
|
@abstractmethod
|
|
75
|
-
def get_function(self, name: str | FunctionRef) -> Function:
|
|
79
|
+
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
76
80
|
pass
|
|
77
81
|
|
|
78
82
|
@abstractmethod
|
|
79
|
-
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
83
|
+
async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
80
84
|
pass
|
|
81
85
|
|
|
82
|
-
def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
83
|
-
|
|
84
|
-
return
|
|
86
|
+
async def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
87
|
+
tasks = [self.get_function(name) for name in function_names]
|
|
88
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
85
89
|
|
|
86
|
-
def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
|
|
87
|
-
|
|
90
|
+
async def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
|
|
91
|
+
tasks = [self.get_function_group(name) for name in function_group_names]
|
|
92
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
88
93
|
|
|
89
94
|
@abstractmethod
|
|
90
95
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
@@ -107,17 +112,17 @@ class Builder(ABC):
|
|
|
107
112
|
pass
|
|
108
113
|
|
|
109
114
|
@abstractmethod
|
|
110
|
-
def get_tools(self,
|
|
111
|
-
|
|
112
|
-
|
|
115
|
+
async def get_tools(self,
|
|
116
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
117
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
113
118
|
pass
|
|
114
119
|
|
|
115
120
|
@abstractmethod
|
|
116
|
-
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
121
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
117
122
|
pass
|
|
118
123
|
|
|
119
124
|
@abstractmethod
|
|
120
|
-
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
125
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> typing.Any:
|
|
121
126
|
pass
|
|
122
127
|
|
|
123
128
|
@abstractmethod
|
|
@@ -138,7 +143,9 @@ class Builder(ABC):
|
|
|
138
143
|
pass
|
|
139
144
|
|
|
140
145
|
@abstractmethod
|
|
141
|
-
|
|
146
|
+
@experimental(feature_name="Authentication")
|
|
147
|
+
async def add_auth_provider(self, name: str | AuthenticationRef,
|
|
148
|
+
config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
142
149
|
pass
|
|
143
150
|
|
|
144
151
|
@abstractmethod
|
|
@@ -154,7 +161,7 @@ class Builder(ABC):
|
|
|
154
161
|
return list(auth_providers)
|
|
155
162
|
|
|
156
163
|
@abstractmethod
|
|
157
|
-
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
|
|
164
|
+
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
|
|
158
165
|
pass
|
|
159
166
|
|
|
160
167
|
async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
|
|
@@ -172,7 +179,7 @@ class Builder(ABC):
|
|
|
172
179
|
pass
|
|
173
180
|
|
|
174
181
|
@abstractmethod
|
|
175
|
-
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
182
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
|
|
176
183
|
pass
|
|
177
184
|
|
|
178
185
|
async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
|
|
@@ -193,17 +200,18 @@ class Builder(ABC):
|
|
|
193
200
|
pass
|
|
194
201
|
|
|
195
202
|
@abstractmethod
|
|
196
|
-
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
|
|
203
|
+
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
|
|
197
204
|
pass
|
|
198
205
|
|
|
199
|
-
def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
206
|
+
async def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
200
207
|
"""
|
|
201
208
|
Return a list of memory clients for the specified names.
|
|
202
209
|
"""
|
|
203
|
-
|
|
210
|
+
tasks = [self.get_memory_client(n) for n in memory_names]
|
|
211
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
204
212
|
|
|
205
213
|
@abstractmethod
|
|
206
|
-
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
214
|
+
async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
207
215
|
"""
|
|
208
216
|
Return the instantiated memory client for the given name.
|
|
209
217
|
"""
|
|
@@ -214,12 +222,12 @@ class Builder(ABC):
|
|
|
214
222
|
pass
|
|
215
223
|
|
|
216
224
|
@abstractmethod
|
|
217
|
-
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
225
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
|
|
218
226
|
pass
|
|
219
227
|
|
|
220
228
|
async def get_retrievers(self,
|
|
221
229
|
retriever_names: Sequence[str | RetrieverRef],
|
|
222
|
-
wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
230
|
+
wrapper_type: LLMFrameworkEnum | str | None = None) -> list[Retriever]:
|
|
223
231
|
|
|
224
232
|
tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
|
|
225
233
|
|
|
@@ -251,14 +259,15 @@ class Builder(ABC):
|
|
|
251
259
|
pass
|
|
252
260
|
|
|
253
261
|
@abstractmethod
|
|
254
|
-
|
|
262
|
+
@experimental(feature_name="TTC")
|
|
263
|
+
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
|
|
255
264
|
pass
|
|
256
265
|
|
|
257
266
|
@abstractmethod
|
|
258
267
|
async def get_ttc_strategy(self,
|
|
259
268
|
strategy_name: str | TTCStrategyRef,
|
|
260
269
|
pipeline_type: PipelineTypeEnum,
|
|
261
|
-
stage_type: StageTypeEnum):
|
|
270
|
+
stage_type: StageTypeEnum) -> "StrategyBase":
|
|
262
271
|
pass
|
|
263
272
|
|
|
264
273
|
@abstractmethod
|
|
@@ -304,5 +313,5 @@ class EvalBuilder(ABC):
|
|
|
304
313
|
pass
|
|
305
314
|
|
|
306
315
|
@abstractmethod
|
|
307
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
316
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
308
317
|
pass
|
nat/builder/eval_builder.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import logging
|
|
18
19
|
from contextlib import asynccontextmanager
|
|
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
90
91
|
return self.eval_general_config.output_dir
|
|
91
92
|
|
|
92
93
|
@override
|
|
93
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
94
|
-
tools = []
|
|
94
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
95
95
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
|
|
97
|
+
async def get_tool(fn_name: str):
|
|
98
|
+
fn = await self.get_function(fn_name)
|
|
98
99
|
try:
|
|
99
|
-
|
|
100
|
+
return tool_wrapper_reg.build_fn(fn_name, fn, self)
|
|
100
101
|
except Exception:
|
|
101
102
|
logger.exception("Error fetching tool `%s`", fn_name)
|
|
103
|
+
return None
|
|
102
104
|
|
|
103
|
-
|
|
105
|
+
tasks = [get_tool(fn_name) for fn_name in self._functions]
|
|
106
|
+
tools = await asyncio.gather(*tasks, return_exceptions=False)
|
|
107
|
+
return [tool for tool in tools if tool is not None]
|
|
104
108
|
|
|
105
109
|
def _log_build_failure_evaluator(self,
|
|
106
110
|
failing_evaluator_name: str,
|
|
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
127
131
|
remaining_components,
|
|
128
132
|
original_error)
|
|
129
133
|
|
|
130
|
-
|
|
134
|
+
@override
|
|
135
|
+
async def populate_builder(self, config: Config, skip_workflow: bool = False):
|
|
131
136
|
# Skip setting workflow if workflow config is EmptyFunctionConfig
|
|
132
|
-
skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
|
|
137
|
+
skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
|
|
133
138
|
|
|
134
|
-
await super().populate_builder(config, skip_workflow)
|
|
139
|
+
await super().populate_builder(config, skip_workflow=skip_workflow)
|
|
135
140
|
|
|
136
141
|
# Initialize progress tracking for evaluators
|
|
137
142
|
completed_evaluators = []
|
nat/builder/function.py
CHANGED
|
@@ -357,7 +357,7 @@ class FunctionGroup:
|
|
|
357
357
|
*,
|
|
358
358
|
config: FunctionGroupBaseConfig,
|
|
359
359
|
instance_name: str | None = None,
|
|
360
|
-
filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None):
|
|
360
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None):
|
|
361
361
|
"""
|
|
362
362
|
Creates a new function group.
|
|
363
363
|
|
|
@@ -367,7 +367,7 @@ class FunctionGroup:
|
|
|
367
367
|
The configuration for the function group.
|
|
368
368
|
instance_name : str | None, optional
|
|
369
369
|
The name of the function group. If not provided, the type of the function group will be used.
|
|
370
|
-
filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
|
|
370
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
371
371
|
A callback function to additionally filter the functions in the function group dynamically when
|
|
372
372
|
the functions are accessed via any accessor method.
|
|
373
373
|
"""
|
|
@@ -375,7 +375,7 @@ class FunctionGroup:
|
|
|
375
375
|
self._instance_name = instance_name or config.type
|
|
376
376
|
self._functions: dict[str, Function] = dict()
|
|
377
377
|
self._filter_fn = filter_fn
|
|
378
|
-
self._per_function_filter_fn: dict[str, Callable[[str], bool]] = dict()
|
|
378
|
+
self._per_function_filter_fn: dict[str, Callable[[str], Awaitable[bool]]] = dict()
|
|
379
379
|
|
|
380
380
|
def add_function(self,
|
|
381
381
|
name: str,
|
|
@@ -384,7 +384,7 @@ class FunctionGroup:
|
|
|
384
384
|
input_schema: type[BaseModel] | None = None,
|
|
385
385
|
description: str | None = None,
|
|
386
386
|
converters: list[Callable] | None = None,
|
|
387
|
-
filter_fn: Callable[[str], bool] | None = None):
|
|
387
|
+
filter_fn: Callable[[str], Awaitable[bool]] | None = None):
|
|
388
388
|
"""
|
|
389
389
|
Adds a function to the function group.
|
|
390
390
|
|
|
@@ -400,7 +400,7 @@ class FunctionGroup:
|
|
|
400
400
|
The description of the function.
|
|
401
401
|
converters : list[Callable] | None, optional
|
|
402
402
|
The converters to use for the function.
|
|
403
|
-
filter_fn : Callable[[str], bool] | None, optional
|
|
403
|
+
filter_fn : Callable[[str], Awaitable[bool]] | None, optional
|
|
404
404
|
A callback to determine if the function should be included in the function group. The
|
|
405
405
|
callback will be called with the function name. The callback is invoked dynamically when
|
|
406
406
|
the functions are accessed via any accessor method such as `get_accessible_functions`,
|
|
@@ -441,12 +441,14 @@ class FunctionGroup:
|
|
|
441
441
|
def _get_fn_name(self, name: str) -> str:
|
|
442
442
|
return f"{self._instance_name}.{name}"
|
|
443
443
|
|
|
444
|
-
def _fn_should_be_included(self, name: str) -> bool:
|
|
445
|
-
|
|
444
|
+
async def _fn_should_be_included(self, name: str) -> bool:
|
|
445
|
+
if name not in self._per_function_filter_fn:
|
|
446
|
+
return True
|
|
447
|
+
return await self._per_function_filter_fn[name](name)
|
|
446
448
|
|
|
447
|
-
def _get_all_but_excluded_functions(
|
|
449
|
+
async def _get_all_but_excluded_functions(
|
|
448
450
|
self,
|
|
449
|
-
filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
|
|
451
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
|
|
450
452
|
) -> dict[str, Function]:
|
|
451
453
|
"""
|
|
452
454
|
Returns a dictionary of all functions in the function group except the excluded functions.
|
|
@@ -454,22 +456,35 @@ class FunctionGroup:
|
|
|
454
456
|
missing = set(self._config.exclude) - set(self._functions.keys())
|
|
455
457
|
if missing:
|
|
456
458
|
raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
|
|
457
|
-
|
|
459
|
+
|
|
460
|
+
if filter_fn is None:
|
|
461
|
+
if self._filter_fn is None:
|
|
462
|
+
|
|
463
|
+
async def identity_filter(x: Sequence[str]) -> Sequence[str]:
|
|
464
|
+
return x
|
|
465
|
+
|
|
466
|
+
filter_fn = identity_filter
|
|
467
|
+
else:
|
|
468
|
+
filter_fn = self._filter_fn
|
|
469
|
+
|
|
458
470
|
excluded = set(self._config.exclude)
|
|
459
|
-
included = set(filter_fn(list(self._functions.keys())))
|
|
471
|
+
included = set(await filter_fn(list(self._functions.keys())))
|
|
460
472
|
|
|
461
|
-
|
|
473
|
+
result = {}
|
|
474
|
+
for name in self._functions:
|
|
462
475
|
if name in excluded:
|
|
463
|
-
|
|
464
|
-
if not self._fn_should_be_included(name):
|
|
465
|
-
|
|
466
|
-
|
|
476
|
+
continue
|
|
477
|
+
if not await self._fn_should_be_included(name):
|
|
478
|
+
continue
|
|
479
|
+
if name not in included:
|
|
480
|
+
continue
|
|
481
|
+
result[self._get_fn_name(name)] = self._functions[name]
|
|
467
482
|
|
|
468
|
-
return
|
|
483
|
+
return result
|
|
469
484
|
|
|
470
|
-
def get_accessible_functions(
|
|
485
|
+
async def get_accessible_functions(
|
|
471
486
|
self,
|
|
472
|
-
filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
|
|
487
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
|
|
473
488
|
) -> dict[str, Function]:
|
|
474
489
|
"""
|
|
475
490
|
Returns a dictionary of all accessible functions in the function group.
|
|
@@ -484,7 +499,7 @@ class FunctionGroup:
|
|
|
484
499
|
|
|
485
500
|
Parameters
|
|
486
501
|
----------
|
|
487
|
-
filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
|
|
502
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
488
503
|
A callback function to additionally filter the functions in the function group dynamically. If not provided
|
|
489
504
|
then fall back to the function group's filter function. If no filter function is set for the function group
|
|
490
505
|
all functions will be returned.
|
|
@@ -500,14 +515,14 @@ class FunctionGroup:
|
|
|
500
515
|
When the function group is configured to include functions that are not found in the group.
|
|
501
516
|
"""
|
|
502
517
|
if self._config.include:
|
|
503
|
-
return self.get_included_functions(filter_fn=filter_fn)
|
|
518
|
+
return await self.get_included_functions(filter_fn=filter_fn)
|
|
504
519
|
if self._config.exclude:
|
|
505
|
-
return self._get_all_but_excluded_functions(filter_fn=filter_fn)
|
|
506
|
-
return self.get_all_functions(filter_fn=filter_fn)
|
|
520
|
+
return await self._get_all_but_excluded_functions(filter_fn=filter_fn)
|
|
521
|
+
return await self.get_all_functions(filter_fn=filter_fn)
|
|
507
522
|
|
|
508
|
-
def get_excluded_functions(
|
|
523
|
+
async def get_excluded_functions(
|
|
509
524
|
self,
|
|
510
|
-
filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
|
|
525
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
|
|
511
526
|
) -> dict[str, Function]:
|
|
512
527
|
"""
|
|
513
528
|
Returns a dictionary of all functions in the function group which are configured to be excluded or filtered
|
|
@@ -515,7 +530,7 @@ class FunctionGroup:
|
|
|
515
530
|
|
|
516
531
|
Parameters
|
|
517
532
|
----------
|
|
518
|
-
filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
|
|
533
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
519
534
|
A callback function to additionally filter the functions in the function group dynamically. If not provided
|
|
520
535
|
then fall back to the function group's filter function. If no filter function is set for the function group
|
|
521
536
|
then no functions will be added to the returned dictionary.
|
|
@@ -533,22 +548,38 @@ class FunctionGroup:
|
|
|
533
548
|
missing = set(self._config.exclude) - set(self._functions.keys())
|
|
534
549
|
if missing:
|
|
535
550
|
raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
|
|
536
|
-
|
|
551
|
+
|
|
552
|
+
if filter_fn is None:
|
|
553
|
+
if self._filter_fn is None:
|
|
554
|
+
|
|
555
|
+
async def identity_filter(x: Sequence[str]) -> Sequence[str]:
|
|
556
|
+
return x
|
|
557
|
+
|
|
558
|
+
filter_fn = identity_filter
|
|
559
|
+
else:
|
|
560
|
+
filter_fn = self._filter_fn
|
|
561
|
+
|
|
537
562
|
excluded = set(self._config.exclude)
|
|
538
|
-
included = set(filter_fn(list(self._functions.keys())))
|
|
563
|
+
included = set(await filter_fn(list(self._functions.keys())))
|
|
539
564
|
|
|
540
|
-
|
|
565
|
+
result = {}
|
|
566
|
+
for name in self._functions:
|
|
567
|
+
is_excluded = False
|
|
541
568
|
if name in excluded:
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
569
|
+
is_excluded = True
|
|
570
|
+
elif not await self._fn_should_be_included(name):
|
|
571
|
+
is_excluded = True
|
|
572
|
+
elif name not in included:
|
|
573
|
+
is_excluded = True
|
|
546
574
|
|
|
547
|
-
|
|
575
|
+
if is_excluded:
|
|
576
|
+
result[self._get_fn_name(name)] = self._functions[name]
|
|
548
577
|
|
|
549
|
-
|
|
578
|
+
return result
|
|
579
|
+
|
|
580
|
+
async def get_included_functions(
|
|
550
581
|
self,
|
|
551
|
-
filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
|
|
582
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
|
|
552
583
|
) -> dict[str, Function]:
|
|
553
584
|
"""
|
|
554
585
|
Returns a dictionary of all functions in the function group which are:
|
|
@@ -558,7 +589,7 @@ class FunctionGroup:
|
|
|
558
589
|
|
|
559
590
|
Parameters
|
|
560
591
|
----------
|
|
561
|
-
filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
|
|
592
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
562
593
|
A callback function to additionally filter the functions in the function group dynamically. If not provided
|
|
563
594
|
then fall back to the function group's filter function. If no filter function is set for the function group
|
|
564
595
|
all functions will be returned.
|
|
@@ -576,14 +607,27 @@ class FunctionGroup:
|
|
|
576
607
|
missing = set(self._config.include) - set(self._functions.keys())
|
|
577
608
|
if missing:
|
|
578
609
|
raise ValueError(f"Unknown included functions: {sorted(missing)}")
|
|
579
|
-
filter_fn = filter_fn or self._filter_fn or (lambda x: x)
|
|
580
|
-
included = set(filter_fn(list(self._config.include)))
|
|
581
|
-
included = {name for name in included if self._fn_should_be_included(name)}
|
|
582
|
-
return {self._get_fn_name(name): self._functions[name] for name in included}
|
|
583
610
|
|
|
584
|
-
|
|
611
|
+
if filter_fn is None:
|
|
612
|
+
if self._filter_fn is None:
|
|
613
|
+
|
|
614
|
+
async def identity_filter(x: Sequence[str]) -> Sequence[str]:
|
|
615
|
+
return x
|
|
616
|
+
|
|
617
|
+
filter_fn = identity_filter
|
|
618
|
+
else:
|
|
619
|
+
filter_fn = self._filter_fn
|
|
620
|
+
|
|
621
|
+
included = set(await filter_fn(list(self._config.include)))
|
|
622
|
+
result = {}
|
|
623
|
+
for name in included:
|
|
624
|
+
if await self._fn_should_be_included(name):
|
|
625
|
+
result[self._get_fn_name(name)] = self._functions[name]
|
|
626
|
+
return result
|
|
627
|
+
|
|
628
|
+
async def get_all_functions(
|
|
585
629
|
self,
|
|
586
|
-
filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
|
|
630
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
|
|
587
631
|
) -> dict[str, Function]:
|
|
588
632
|
"""
|
|
589
633
|
Returns a dictionary of all functions in the function group, regardless if they are included or excluded.
|
|
@@ -592,7 +636,7 @@ class FunctionGroup:
|
|
|
592
636
|
|
|
593
637
|
Parameters
|
|
594
638
|
----------
|
|
595
|
-
filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
|
|
639
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
596
640
|
A callback function to additionally filter the functions in the function group dynamically. If not provided
|
|
597
641
|
then fall back to the function group's filter function. If no filter function is set for the function group
|
|
598
642
|
all functions will be returned.
|
|
@@ -602,23 +646,35 @@ class FunctionGroup:
|
|
|
602
646
|
dict[str, Function]
|
|
603
647
|
A dictionary of all functions in the function group.
|
|
604
648
|
"""
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
649
|
+
if filter_fn is None:
|
|
650
|
+
if self._filter_fn is None:
|
|
651
|
+
|
|
652
|
+
async def identity_filter(x: Sequence[str]) -> Sequence[str]:
|
|
653
|
+
return x
|
|
654
|
+
|
|
655
|
+
filter_fn = identity_filter
|
|
656
|
+
else:
|
|
657
|
+
filter_fn = self._filter_fn
|
|
658
|
+
|
|
659
|
+
included = set(await filter_fn(list(self._functions.keys())))
|
|
660
|
+
result = {}
|
|
661
|
+
for name in included:
|
|
662
|
+
if await self._fn_should_be_included(name):
|
|
663
|
+
result[self._get_fn_name(name)] = self._functions[name]
|
|
664
|
+
return result
|
|
609
665
|
|
|
610
|
-
def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Sequence[str]]):
|
|
666
|
+
def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]]):
|
|
611
667
|
"""
|
|
612
668
|
Sets the filter function for the function group.
|
|
613
669
|
|
|
614
670
|
Parameters
|
|
615
671
|
----------
|
|
616
|
-
filter_fn : Callable[[Sequence[str]], Sequence[str]]
|
|
672
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]]
|
|
617
673
|
The filter function to set for the function group.
|
|
618
674
|
"""
|
|
619
675
|
self._filter_fn = filter_fn
|
|
620
676
|
|
|
621
|
-
def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], bool]):
|
|
677
|
+
def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], Awaitable[bool]]):
|
|
622
678
|
"""
|
|
623
679
|
Sets the a per-function filter function for the a function within the function group.
|
|
624
680
|
|
|
@@ -626,7 +682,7 @@ class FunctionGroup:
|
|
|
626
682
|
----------
|
|
627
683
|
name : str
|
|
628
684
|
The name of the function.
|
|
629
|
-
filter_fn : Callable[[str], bool]
|
|
685
|
+
filter_fn : Callable[[str], Awaitable[bool]]
|
|
630
686
|
The per-function filter function to set for the function group.
|
|
631
687
|
|
|
632
688
|
Raises
|