nvidia-nat 1.3.0a20250923__py3-none-any.whl → 1.3.0a20250925__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/agent.py +5 -4
- nat/agent/react_agent/register.py +12 -1
- nat/agent/reasoning_agent/reasoning_agent.py +2 -2
- nat/agent/rewoo_agent/register.py +12 -1
- nat/agent/tool_calling_agent/register.py +28 -8
- nat/builder/builder.py +33 -24
- nat/builder/component_utils.py +1 -1
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +108 -52
- nat/builder/workflow_builder.py +89 -79
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +786 -0
- nat/cli/entrypoint.py +2 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +6 -7
- nat/eval/evaluate.py +2 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +26 -5
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +2 -2
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +4 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +4 -4
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/litellm_llm.py +69 -0
- nat/llm/register.py +4 -0
- nat/profiler/decorators/framework_wrapper.py +52 -3
- nat/profiler/decorators/function_tracking.py +33 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +2 -2
- nat/runtime/loader.py +1 -1
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/METADATA +6 -3
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/RECORD +43 -41
- nat/cli/commands/info/list_mcp.py +0 -461
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/agent.py
CHANGED
|
@@ -59,6 +59,7 @@ class ReActGraphState(BaseModel):
|
|
|
59
59
|
messages: list[BaseMessage] = Field(default_factory=list) # input and output of the ReAct Agent
|
|
60
60
|
agent_scratchpad: list[AgentAction] = Field(default_factory=list) # agent thoughts / intermediate steps
|
|
61
61
|
tool_responses: list[BaseMessage] = Field(default_factory=list) # the responses from any tool calls
|
|
62
|
+
final_answer: str | None = Field(default=None) # the final answer from the ReAct Agent
|
|
62
63
|
|
|
63
64
|
|
|
64
65
|
class ReActAgentGraph(DualNodeAgent):
|
|
@@ -204,6 +205,7 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
204
205
|
# this is where we handle the final output of the Agent, we can clean-up/format/postprocess here
|
|
205
206
|
# the final answer goes in the "messages" state channel
|
|
206
207
|
state.messages += [AIMessage(content=final_answer)]
|
|
208
|
+
state.final_answer = final_answer
|
|
207
209
|
else:
|
|
208
210
|
# the agent wants to call a tool, ensure the thoughts are preserved for the next agentic cycle
|
|
209
211
|
agent_output.log = output_message.content
|
|
@@ -242,10 +244,9 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
242
244
|
async def conditional_edge(self, state: ReActGraphState):
|
|
243
245
|
try:
|
|
244
246
|
logger.debug("%s Starting the ReAct Conditional Edge", AGENT_LOG_PREFIX)
|
|
245
|
-
if
|
|
246
|
-
# the ReAct Agent has finished executing
|
|
247
|
-
|
|
248
|
-
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, last_message_content)
|
|
247
|
+
if state.final_answer:
|
|
248
|
+
# the ReAct Agent has finished executing
|
|
249
|
+
logger.debug("%s Final answer:\n%s", AGENT_LOG_PREFIX, state.final_answer)
|
|
249
250
|
return AgentDecision.END
|
|
250
251
|
# else the agent wants to call a tool
|
|
251
252
|
agent_output = state.agent_scratchpad[-1]
|
|
@@ -99,7 +99,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
99
99
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
100
100
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
101
101
|
# the sample tool provided can easily be copied or changed
|
|
102
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
102
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
103
103
|
if not tools:
|
|
104
104
|
raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
|
|
105
105
|
# configure callbacks, for sending intermediate steps
|
|
@@ -118,6 +118,17 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
118
118
|
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
|
|
119
119
|
|
|
120
120
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
121
|
+
"""
|
|
122
|
+
Main workflow entry function for the ReAct Agent.
|
|
123
|
+
|
|
124
|
+
This function invokes the ReAct Agent Graph and returns the response.
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
input_message (ChatRequest): The input message to process
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
ChatResponse: The response from the agent or error message
|
|
131
|
+
"""
|
|
121
132
|
try:
|
|
122
133
|
# initialize the starting state with the user query
|
|
123
134
|
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
@@ -99,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
99
99
|
llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
100
100
|
|
|
101
101
|
# Get the augmented function's description
|
|
102
|
-
augmented_function = builder.get_function(config.augmented_fn)
|
|
102
|
+
augmented_function = await builder.get_function(config.augmented_fn)
|
|
103
103
|
|
|
104
104
|
# For now, we rely on runtime checking for type conversion
|
|
105
105
|
|
|
@@ -119,7 +119,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
|
|
|
119
119
|
tool_names_with_desc: list[tuple[str, str]] = []
|
|
120
120
|
|
|
121
121
|
for tool in function_used_tools:
|
|
122
|
-
tool_impl = builder.get_function(tool)
|
|
122
|
+
tool_impl = await builder.get_function(tool)
|
|
123
123
|
tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
|
|
124
124
|
|
|
125
125
|
# Draft the reasoning prompt for the augmented function
|
|
@@ -108,7 +108,7 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
108
108
|
|
|
109
109
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
110
110
|
# the sample tool provided can easily be copied or changed
|
|
111
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
111
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
112
112
|
if not tools:
|
|
113
113
|
raise ValueError(f"No tools specified for ReWOO Agent '{config.llm_name}'")
|
|
114
114
|
|
|
@@ -125,6 +125,17 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
125
125
|
raise_tool_call_error=config.raise_tool_call_error).build_graph()
|
|
126
126
|
|
|
127
127
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
128
|
+
"""
|
|
129
|
+
Main workflow entry function for the ReWOO Agent.
|
|
130
|
+
|
|
131
|
+
This function invokes the ReWOO Agent Graph and returns the response.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
input_message (ChatRequest): The input message to process
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
ChatResponse: The response from the agent or error message
|
|
138
|
+
"""
|
|
128
139
|
try:
|
|
129
140
|
# initialize the starting state with the user query
|
|
130
141
|
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
@@ -22,6 +22,7 @@ from nat.builder.framework_enum import LLMFrameworkEnum
|
|
|
22
22
|
from nat.builder.function_info import FunctionInfo
|
|
23
23
|
from nat.cli.register_workflow import register_function
|
|
24
24
|
from nat.data_models.agent import AgentBaseConfig
|
|
25
|
+
from nat.data_models.api_server import ChatRequest
|
|
25
26
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
26
27
|
from nat.data_models.component_ref import FunctionRef
|
|
27
28
|
|
|
@@ -38,6 +39,8 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
|
|
|
38
39
|
default_factory=list, description="The list of tools to provide to the tool calling agent.")
|
|
39
40
|
handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
|
|
40
41
|
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
|
|
42
|
+
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
43
|
+
|
|
41
44
|
system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
|
|
42
45
|
additional_instructions: str | None = Field(default=None,
|
|
43
46
|
description="Additional instructions appended to the system prompt.")
|
|
@@ -47,7 +50,8 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
|
|
|
47
50
|
|
|
48
51
|
@register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
49
52
|
async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, builder: Builder):
|
|
50
|
-
from langchain_core.messages
|
|
53
|
+
from langchain_core.messages import trim_messages
|
|
54
|
+
from langchain_core.messages.base import BaseMessage
|
|
51
55
|
from langgraph.graph.state import CompiledStateGraph
|
|
52
56
|
|
|
53
57
|
from nat.agent.base import AGENT_LOG_PREFIX
|
|
@@ -60,13 +64,13 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
60
64
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
61
65
|
# the agent can run any installed tool, simply install the tool and add it to the config file
|
|
62
66
|
# the sample tools provided can easily be copied or changed
|
|
63
|
-
tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
67
|
+
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
64
68
|
if not tools:
|
|
65
69
|
raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")
|
|
66
70
|
|
|
67
71
|
# convert return_direct FunctionRef objects to BaseTool objects
|
|
68
|
-
return_direct_tools = builder.get_tools(
|
|
69
|
-
|
|
72
|
+
return_direct_tools = await builder.get_tools(
|
|
73
|
+
tool_names=config.return_direct, wrapper_type=LLMFrameworkEnum.LANGCHAIN) if config.return_direct else None
|
|
70
74
|
|
|
71
75
|
# construct the Tool Calling Agent Graph from the configured llm, and tools
|
|
72
76
|
graph: CompiledStateGraph = await ToolCallAgentGraph(llm=llm,
|
|
@@ -77,11 +81,27 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
77
81
|
handle_tool_errors=config.handle_tool_errors,
|
|
78
82
|
return_direct=return_direct_tools).build_graph()
|
|
79
83
|
|
|
80
|
-
async def _response_fn(input_message:
|
|
84
|
+
async def _response_fn(input_message: ChatRequest) -> str:
|
|
85
|
+
"""
|
|
86
|
+
Main workflow entry function for the Tool Calling Agent.
|
|
87
|
+
|
|
88
|
+
This function invokes the Tool Calling Agent Graph and returns the response.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
input_message (ChatRequest): The input message to process
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
str: The response from the agent or error message
|
|
95
|
+
"""
|
|
81
96
|
try:
|
|
82
97
|
# initialize the starting state with the user query
|
|
83
|
-
|
|
84
|
-
|
|
98
|
+
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
|
|
99
|
+
max_tokens=config.max_history,
|
|
100
|
+
strategy="last",
|
|
101
|
+
token_counter=len,
|
|
102
|
+
start_on="human",
|
|
103
|
+
include_system=True)
|
|
104
|
+
state = ToolCallAgentGraphState(messages=messages)
|
|
85
105
|
|
|
86
106
|
# run the Tool Calling Agent Graph
|
|
87
107
|
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_iterations + 1) * 2})
|
|
@@ -92,7 +112,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
92
112
|
# get and return the output from the state
|
|
93
113
|
state = ToolCallAgentGraphState(**state)
|
|
94
114
|
output_message = state.messages[-1]
|
|
95
|
-
return output_message.content
|
|
115
|
+
return str(output_message.content)
|
|
96
116
|
except Exception as ex:
|
|
97
117
|
logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
|
|
98
118
|
if config.verbose:
|
nat/builder/builder.py
CHANGED
|
@@ -45,12 +45,16 @@ from nat.data_models.memory import MemoryBaseConfig
|
|
|
45
45
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
46
46
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
47
47
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
48
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
48
49
|
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
49
50
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
50
51
|
from nat.memory.interfaces import MemoryEditor
|
|
51
52
|
from nat.object_store.interfaces import ObjectStore
|
|
52
53
|
from nat.retriever.interface import Retriever
|
|
53
54
|
|
|
55
|
+
if typing.TYPE_CHECKING:
|
|
56
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
57
|
+
|
|
54
58
|
|
|
55
59
|
class UserManagerHolder():
|
|
56
60
|
|
|
@@ -72,19 +76,20 @@ class Builder(ABC):
|
|
|
72
76
|
pass
|
|
73
77
|
|
|
74
78
|
@abstractmethod
|
|
75
|
-
def get_function(self, name: str | FunctionRef) -> Function:
|
|
79
|
+
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
76
80
|
pass
|
|
77
81
|
|
|
78
82
|
@abstractmethod
|
|
79
|
-
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
83
|
+
async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
80
84
|
pass
|
|
81
85
|
|
|
82
|
-
def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
83
|
-
|
|
84
|
-
return
|
|
86
|
+
async def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
87
|
+
tasks = [self.get_function(name) for name in function_names]
|
|
88
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
85
89
|
|
|
86
|
-
def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
|
|
87
|
-
|
|
90
|
+
async def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
|
|
91
|
+
tasks = [self.get_function_group(name) for name in function_group_names]
|
|
92
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
88
93
|
|
|
89
94
|
@abstractmethod
|
|
90
95
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
@@ -107,17 +112,17 @@ class Builder(ABC):
|
|
|
107
112
|
pass
|
|
108
113
|
|
|
109
114
|
@abstractmethod
|
|
110
|
-
def get_tools(self,
|
|
111
|
-
|
|
112
|
-
|
|
115
|
+
async def get_tools(self,
|
|
116
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
117
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
113
118
|
pass
|
|
114
119
|
|
|
115
120
|
@abstractmethod
|
|
116
|
-
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
121
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
117
122
|
pass
|
|
118
123
|
|
|
119
124
|
@abstractmethod
|
|
120
|
-
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
125
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> typing.Any:
|
|
121
126
|
pass
|
|
122
127
|
|
|
123
128
|
@abstractmethod
|
|
@@ -138,7 +143,9 @@ class Builder(ABC):
|
|
|
138
143
|
pass
|
|
139
144
|
|
|
140
145
|
@abstractmethod
|
|
141
|
-
|
|
146
|
+
@experimental(feature_name="Authentication")
|
|
147
|
+
async def add_auth_provider(self, name: str | AuthenticationRef,
|
|
148
|
+
config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
142
149
|
pass
|
|
143
150
|
|
|
144
151
|
@abstractmethod
|
|
@@ -154,7 +161,7 @@ class Builder(ABC):
|
|
|
154
161
|
return list(auth_providers)
|
|
155
162
|
|
|
156
163
|
@abstractmethod
|
|
157
|
-
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
|
|
164
|
+
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
|
|
158
165
|
pass
|
|
159
166
|
|
|
160
167
|
async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
|
|
@@ -172,7 +179,7 @@ class Builder(ABC):
|
|
|
172
179
|
pass
|
|
173
180
|
|
|
174
181
|
@abstractmethod
|
|
175
|
-
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
182
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
|
|
176
183
|
pass
|
|
177
184
|
|
|
178
185
|
async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
|
|
@@ -193,17 +200,18 @@ class Builder(ABC):
|
|
|
193
200
|
pass
|
|
194
201
|
|
|
195
202
|
@abstractmethod
|
|
196
|
-
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
|
|
203
|
+
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
|
|
197
204
|
pass
|
|
198
205
|
|
|
199
|
-
def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
206
|
+
async def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
200
207
|
"""
|
|
201
208
|
Return a list of memory clients for the specified names.
|
|
202
209
|
"""
|
|
203
|
-
|
|
210
|
+
tasks = [self.get_memory_client(n) for n in memory_names]
|
|
211
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
204
212
|
|
|
205
213
|
@abstractmethod
|
|
206
|
-
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
214
|
+
async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
207
215
|
"""
|
|
208
216
|
Return the instantiated memory client for the given name.
|
|
209
217
|
"""
|
|
@@ -214,12 +222,12 @@ class Builder(ABC):
|
|
|
214
222
|
pass
|
|
215
223
|
|
|
216
224
|
@abstractmethod
|
|
217
|
-
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
225
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
|
|
218
226
|
pass
|
|
219
227
|
|
|
220
228
|
async def get_retrievers(self,
|
|
221
229
|
retriever_names: Sequence[str | RetrieverRef],
|
|
222
|
-
wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
230
|
+
wrapper_type: LLMFrameworkEnum | str | None = None) -> list[Retriever]:
|
|
223
231
|
|
|
224
232
|
tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
|
|
225
233
|
|
|
@@ -251,14 +259,15 @@ class Builder(ABC):
|
|
|
251
259
|
pass
|
|
252
260
|
|
|
253
261
|
@abstractmethod
|
|
254
|
-
|
|
262
|
+
@experimental(feature_name="TTC")
|
|
263
|
+
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
|
|
255
264
|
pass
|
|
256
265
|
|
|
257
266
|
@abstractmethod
|
|
258
267
|
async def get_ttc_strategy(self,
|
|
259
268
|
strategy_name: str | TTCStrategyRef,
|
|
260
269
|
pipeline_type: PipelineTypeEnum,
|
|
261
|
-
stage_type: StageTypeEnum):
|
|
270
|
+
stage_type: StageTypeEnum) -> "StrategyBase":
|
|
262
271
|
pass
|
|
263
272
|
|
|
264
273
|
@abstractmethod
|
|
@@ -304,5 +313,5 @@ class EvalBuilder(ABC):
|
|
|
304
313
|
pass
|
|
305
314
|
|
|
306
315
|
@abstractmethod
|
|
307
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
316
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
308
317
|
pass
|
nat/builder/component_utils.py
CHANGED
|
@@ -158,7 +158,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
|
|
|
158
158
|
yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
|
|
159
159
|
if (decomposed_type.is_union):
|
|
160
160
|
for arg in decomposed_type.args:
|
|
161
|
-
if arg is typing.Any or
|
|
161
|
+
if arg is typing.Any or DecomposedType(arg).is_instance(value):
|
|
162
162
|
yield from recursive_componentref_discovery(cls, value, arg)
|
|
163
163
|
else:
|
|
164
164
|
for arg in decomposed_type.args:
|
nat/builder/eval_builder.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import logging
|
|
18
19
|
from contextlib import asynccontextmanager
|
|
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
90
91
|
return self.eval_general_config.output_dir
|
|
91
92
|
|
|
92
93
|
@override
|
|
93
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
94
|
-
tools = []
|
|
94
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
95
95
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
|
|
97
|
+
async def get_tool(fn_name: str):
|
|
98
|
+
fn = await self.get_function(fn_name)
|
|
98
99
|
try:
|
|
99
|
-
|
|
100
|
+
return tool_wrapper_reg.build_fn(fn_name, fn, self)
|
|
100
101
|
except Exception:
|
|
101
102
|
logger.exception("Error fetching tool `%s`", fn_name)
|
|
103
|
+
return None
|
|
102
104
|
|
|
103
|
-
|
|
105
|
+
tasks = [get_tool(fn_name) for fn_name in self._functions]
|
|
106
|
+
tools = await asyncio.gather(*tasks, return_exceptions=False)
|
|
107
|
+
return [tool for tool in tools if tool is not None]
|
|
104
108
|
|
|
105
109
|
def _log_build_failure_evaluator(self,
|
|
106
110
|
failing_evaluator_name: str,
|
|
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
127
131
|
remaining_components,
|
|
128
132
|
original_error)
|
|
129
133
|
|
|
130
|
-
|
|
134
|
+
@override
|
|
135
|
+
async def populate_builder(self, config: Config, skip_workflow: bool = False):
|
|
131
136
|
# Skip setting workflow if workflow config is EmptyFunctionConfig
|
|
132
|
-
skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
|
|
137
|
+
skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
|
|
133
138
|
|
|
134
|
-
await super().populate_builder(config, skip_workflow)
|
|
139
|
+
await super().populate_builder(config, skip_workflow=skip_workflow)
|
|
135
140
|
|
|
136
141
|
# Initialize progress tracking for evaluators
|
|
137
142
|
completed_evaluators = []
|
nat/builder/framework_enum.py
CHANGED