nvidia-nat 1.3.0a20250922__py3-none-any.whl → 1.3.0a20250924__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. nat/agent/react_agent/register.py +12 -1
  2. nat/agent/reasoning_agent/reasoning_agent.py +2 -2
  3. nat/agent/rewoo_agent/register.py +12 -1
  4. nat/agent/tool_calling_agent/register.py +28 -8
  5. nat/builder/builder.py +33 -24
  6. nat/builder/eval_builder.py +14 -9
  7. nat/builder/function.py +108 -52
  8. nat/builder/workflow_builder.py +89 -79
  9. nat/cli/commands/info/info.py +16 -6
  10. nat/cli/commands/mcp/__init__.py +14 -0
  11. nat/cli/commands/mcp/mcp.py +786 -0
  12. nat/cli/entrypoint.py +2 -1
  13. nat/control_flow/router_agent/register.py +1 -1
  14. nat/control_flow/sequential_executor.py +6 -7
  15. nat/eval/evaluate.py +2 -1
  16. nat/eval/trajectory_evaluator/register.py +1 -1
  17. nat/experimental/decorators/experimental_warning_decorator.py +26 -5
  18. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +2 -2
  19. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  20. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  21. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  22. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
  23. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +4 -4
  24. nat/front_ends/simple_base/simple_front_end_plugin_base.py +1 -1
  25. nat/profiler/decorators/function_tracking.py +33 -1
  26. nat/profiler/parameter_optimization/prompt_optimizer.py +2 -2
  27. nat/runtime/loader.py +1 -1
  28. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/METADATA +1 -1
  29. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/RECORD +34 -33
  30. nat/cli/commands/info/list_mcp.py +0 -461
  31. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/WHEEL +0 -0
  32. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/entry_points.txt +0 -0
  33. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  34. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/licenses/LICENSE.md +0 -0
  35. {nvidia_nat-1.3.0a20250922.dist-info → nvidia_nat-1.3.0a20250924.dist-info}/top_level.txt +0 -0
nat/agent/react_agent/register.py CHANGED
@@ -99,7 +99,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
99
99
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
100
100
  # the agent can run any installed tool, simply install the tool and add it to the config file
101
101
  # the sample tool provided can easily be copied or changed
102
- tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
102
+ tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
103
103
  if not tools:
104
104
  raise ValueError(f"No tools specified for ReAct Agent '{config.llm_name}'")
105
105
  # configure callbacks, for sending intermediate steps
@@ -118,6 +118,17 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
118
118
  normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()
119
119
 
120
120
  async def _response_fn(input_message: ChatRequest) -> ChatResponse:
121
+ """
122
+ Main workflow entry function for the ReAct Agent.
123
+
124
+ This function invokes the ReAct Agent Graph and returns the response.
125
+
126
+ Args:
127
+ input_message (ChatRequest): The input message to process
128
+
129
+ Returns:
130
+ ChatResponse: The response from the agent or error message
131
+ """
121
132
  try:
122
133
  # initialize the starting state with the user query
123
134
  messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
nat/agent/reasoning_agent/reasoning_agent.py CHANGED
@@ -99,7 +99,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
99
99
  llm: BaseChatModel = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
100
100
 
101
101
  # Get the augmented function's description
102
- augmented_function = builder.get_function(config.augmented_fn)
102
+ augmented_function = await builder.get_function(config.augmented_fn)
103
103
 
104
104
  # For now, we rely on runtime checking for type conversion
105
105
 
@@ -119,7 +119,7 @@ async def build_reasoning_function(config: ReasoningFunctionConfig, builder: Bui
119
119
  tool_names_with_desc: list[tuple[str, str]] = []
120
120
 
121
121
  for tool in function_used_tools:
122
- tool_impl = builder.get_function(tool)
122
+ tool_impl = await builder.get_function(tool)
123
123
  tool_names_with_desc.append((tool, tool_impl.description if hasattr(tool_impl, "description") else ""))
124
124
 
125
125
  # Draft the reasoning prompt for the augmented function
nat/agent/rewoo_agent/register.py CHANGED
@@ -108,7 +108,7 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
108
108
 
109
109
  # the agent can run any installed tool, simply install the tool and add it to the config file
110
110
  # the sample tool provided can easily be copied or changed
111
- tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
111
+ tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
112
112
  if not tools:
113
113
  raise ValueError(f"No tools specified for ReWOO Agent '{config.llm_name}'")
114
114
 
@@ -125,6 +125,17 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
125
125
  raise_tool_call_error=config.raise_tool_call_error).build_graph()
126
126
 
127
127
  async def _response_fn(input_message: ChatRequest) -> ChatResponse:
128
+ """
129
+ Main workflow entry function for the ReWOO Agent.
130
+
131
+ This function invokes the ReWOO Agent Graph and returns the response.
132
+
133
+ Args:
134
+ input_message (ChatRequest): The input message to process
135
+
136
+ Returns:
137
+ ChatResponse: The response from the agent or error message
138
+ """
128
139
  try:
129
140
  # initialize the starting state with the user query
130
141
  messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
nat/agent/tool_calling_agent/register.py CHANGED
@@ -22,6 +22,7 @@ from nat.builder.framework_enum import LLMFrameworkEnum
22
22
  from nat.builder.function_info import FunctionInfo
23
23
  from nat.cli.register_workflow import register_function
24
24
  from nat.data_models.agent import AgentBaseConfig
25
+ from nat.data_models.api_server import ChatRequest
25
26
  from nat.data_models.component_ref import FunctionGroupRef
26
27
  from nat.data_models.component_ref import FunctionRef
27
28
 
@@ -38,6 +39,8 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
38
39
  default_factory=list, description="The list of tools to provide to the tool calling agent.")
39
40
  handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
40
41
  max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
42
+ max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
43
+
41
44
  system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
42
45
  additional_instructions: str | None = Field(default=None,
43
46
  description="Additional instructions appended to the system prompt.")
@@ -47,7 +50,8 @@ class ToolCallAgentWorkflowConfig(AgentBaseConfig, name="tool_calling_agent"):
47
50
 
48
51
  @register_function(config_type=ToolCallAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
49
52
  async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, builder: Builder):
50
- from langchain_core.messages.human import HumanMessage
53
+ from langchain_core.messages import trim_messages
54
+ from langchain_core.messages.base import BaseMessage
51
55
  from langgraph.graph.state import CompiledStateGraph
52
56
 
53
57
  from nat.agent.base import AGENT_LOG_PREFIX
@@ -60,13 +64,13 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
60
64
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
61
65
  # the agent can run any installed tool, simply install the tool and add it to the config file
62
66
  # the sample tools provided can easily be copied or changed
63
- tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
67
+ tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
64
68
  if not tools:
65
69
  raise ValueError(f"No tools specified for Tool Calling Agent '{config.llm_name}'")
66
70
 
67
71
  # convert return_direct FunctionRef objects to BaseTool objects
68
- return_direct_tools = builder.get_tools(tool_names=config.return_direct,
69
- wrapper_type=LLMFrameworkEnum.LANGCHAIN) if config.return_direct else None
72
+ return_direct_tools = await builder.get_tools(
73
+ tool_names=config.return_direct, wrapper_type=LLMFrameworkEnum.LANGCHAIN) if config.return_direct else None
70
74
 
71
75
  # construct the Tool Calling Agent Graph from the configured llm, and tools
72
76
  graph: CompiledStateGraph = await ToolCallAgentGraph(llm=llm,
@@ -77,11 +81,27 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
77
81
  handle_tool_errors=config.handle_tool_errors,
78
82
  return_direct=return_direct_tools).build_graph()
79
83
 
80
- async def _response_fn(input_message: str) -> str:
84
+ async def _response_fn(input_message: ChatRequest) -> str:
85
+ """
86
+ Main workflow entry function for the Tool Calling Agent.
87
+
88
+ This function invokes the Tool Calling Agent Graph and returns the response.
89
+
90
+ Args:
91
+ input_message (ChatRequest): The input message to process
92
+
93
+ Returns:
94
+ str: The response from the agent or error message
95
+ """
81
96
  try:
82
97
  # initialize the starting state with the user query
83
- input_message = HumanMessage(content=input_message)
84
- state = ToolCallAgentGraphState(messages=[input_message])
98
+ messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
99
+ max_tokens=config.max_history,
100
+ strategy="last",
101
+ token_counter=len,
102
+ start_on="human",
103
+ include_system=True)
104
+ state = ToolCallAgentGraphState(messages=messages)
85
105
 
86
106
  # run the Tool Calling Agent Graph
87
107
  state = await graph.ainvoke(state, config={'recursion_limit': (config.max_iterations + 1) * 2})
@@ -92,7 +112,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
92
112
  # get and return the output from the state
93
113
  state = ToolCallAgentGraphState(**state)
94
114
  output_message = state.messages[-1]
95
- return output_message.content
115
+ return str(output_message.content)
96
116
  except Exception as ex:
97
117
  logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
98
118
  if config.verbose:
nat/builder/builder.py CHANGED
@@ -45,12 +45,16 @@ from nat.data_models.memory import MemoryBaseConfig
45
45
  from nat.data_models.object_store import ObjectStoreBaseConfig
46
46
  from nat.data_models.retriever import RetrieverBaseConfig
47
47
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
48
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
48
49
  from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
49
50
  from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
50
51
  from nat.memory.interfaces import MemoryEditor
51
52
  from nat.object_store.interfaces import ObjectStore
52
53
  from nat.retriever.interface import Retriever
53
54
 
55
+ if typing.TYPE_CHECKING:
56
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
57
+
54
58
 
55
59
  class UserManagerHolder():
56
60
 
@@ -72,19 +76,20 @@ class Builder(ABC):
72
76
  pass
73
77
 
74
78
  @abstractmethod
75
- def get_function(self, name: str | FunctionRef) -> Function:
79
+ async def get_function(self, name: str | FunctionRef) -> Function:
76
80
  pass
77
81
 
78
82
  @abstractmethod
79
- def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
83
+ async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
80
84
  pass
81
85
 
82
- def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
83
-
84
- return [self.get_function(name) for name in function_names]
86
+ async def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
87
+ tasks = [self.get_function(name) for name in function_names]
88
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
85
89
 
86
- def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
87
- return [self.get_function_group(name) for name in function_group_names]
90
+ async def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
91
+ tasks = [self.get_function_group(name) for name in function_group_names]
92
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
88
93
 
89
94
  @abstractmethod
90
95
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
@@ -107,17 +112,17 @@ class Builder(ABC):
107
112
  pass
108
113
 
109
114
  @abstractmethod
110
- def get_tools(self,
111
- tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
112
- wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
115
+ async def get_tools(self,
116
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
117
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
113
118
  pass
114
119
 
115
120
  @abstractmethod
116
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
121
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
117
122
  pass
118
123
 
119
124
  @abstractmethod
120
- async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
125
+ async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> typing.Any:
121
126
  pass
122
127
 
123
128
  @abstractmethod
@@ -138,7 +143,9 @@ class Builder(ABC):
138
143
  pass
139
144
 
140
145
  @abstractmethod
141
- async def add_auth_provider(self, name: str | AuthenticationRef, config: AuthProviderBaseConfig):
146
+ @experimental(feature_name="Authentication")
147
+ async def add_auth_provider(self, name: str | AuthenticationRef,
148
+ config: AuthProviderBaseConfig) -> AuthProviderBase:
142
149
  pass
143
150
 
144
151
  @abstractmethod
@@ -154,7 +161,7 @@ class Builder(ABC):
154
161
  return list(auth_providers)
155
162
 
156
163
  @abstractmethod
157
- async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
164
+ async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
158
165
  pass
159
166
 
160
167
  async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
@@ -172,7 +179,7 @@ class Builder(ABC):
172
179
  pass
173
180
 
174
181
  @abstractmethod
175
- async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
182
+ async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
176
183
  pass
177
184
 
178
185
  async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
@@ -193,17 +200,18 @@ class Builder(ABC):
193
200
  pass
194
201
 
195
202
  @abstractmethod
196
- async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
203
+ async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
197
204
  pass
198
205
 
199
- def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
206
+ async def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
200
207
  """
201
208
  Return a list of memory clients for the specified names.
202
209
  """
203
- return [self.get_memory_client(n) for n in memory_names]
210
+ tasks = [self.get_memory_client(n) for n in memory_names]
211
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
204
212
 
205
213
  @abstractmethod
206
- def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
214
+ async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
207
215
  """
208
216
  Return the instantiated memory client for the given name.
209
217
  """
@@ -214,12 +222,12 @@ class Builder(ABC):
214
222
  pass
215
223
 
216
224
  @abstractmethod
217
- async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
225
+ async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
218
226
  pass
219
227
 
220
228
  async def get_retrievers(self,
221
229
  retriever_names: Sequence[str | RetrieverRef],
222
- wrapper_type: LLMFrameworkEnum | str | None = None):
230
+ wrapper_type: LLMFrameworkEnum | str | None = None) -> list[Retriever]:
223
231
 
224
232
  tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
225
233
 
@@ -251,14 +259,15 @@ class Builder(ABC):
251
259
  pass
252
260
 
253
261
  @abstractmethod
254
- async def add_ttc_strategy(self, name: str | str, config: TTCStrategyBaseConfig):
262
+ @experimental(feature_name="TTC")
263
+ async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
255
264
  pass
256
265
 
257
266
  @abstractmethod
258
267
  async def get_ttc_strategy(self,
259
268
  strategy_name: str | TTCStrategyRef,
260
269
  pipeline_type: PipelineTypeEnum,
261
- stage_type: StageTypeEnum):
270
+ stage_type: StageTypeEnum) -> "StrategyBase":
262
271
  pass
263
272
 
264
273
  @abstractmethod
@@ -304,5 +313,5 @@ class EvalBuilder(ABC):
304
313
  pass
305
314
 
306
315
  @abstractmethod
307
- def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
316
+ async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
308
317
  pass
nat/builder/eval_builder.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import dataclasses
17
18
  import logging
18
19
  from contextlib import asynccontextmanager
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
90
91
  return self.eval_general_config.output_dir
91
92
 
92
93
  @override
93
- def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
94
- tools = []
94
+ async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
95
95
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
96
- for fn_name in self._functions:
97
- fn = self.get_function(fn_name)
96
+
97
+ async def get_tool(fn_name: str):
98
+ fn = await self.get_function(fn_name)
98
99
  try:
99
- tools.append(tool_wrapper_reg.build_fn(fn_name, fn, self))
100
+ return tool_wrapper_reg.build_fn(fn_name, fn, self)
100
101
  except Exception:
101
102
  logger.exception("Error fetching tool `%s`", fn_name)
103
+ return None
102
104
 
103
- return tools
105
+ tasks = [get_tool(fn_name) for fn_name in self._functions]
106
+ tools = await asyncio.gather(*tasks, return_exceptions=False)
107
+ return [tool for tool in tools if tool is not None]
104
108
 
105
109
  def _log_build_failure_evaluator(self,
106
110
  failing_evaluator_name: str,
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
127
131
  remaining_components,
128
132
  original_error)
129
133
 
130
- async def populate_builder(self, config: Config):
134
+ @override
135
+ async def populate_builder(self, config: Config, skip_workflow: bool = False):
131
136
  # Skip setting workflow if workflow config is EmptyFunctionConfig
132
- skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
137
+ skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
133
138
 
134
- await super().populate_builder(config, skip_workflow)
139
+ await super().populate_builder(config, skip_workflow=skip_workflow)
135
140
 
136
141
  # Initialize progress tracking for evaluators
137
142
  completed_evaluators = []
nat/builder/function.py CHANGED
@@ -357,7 +357,7 @@ class FunctionGroup:
357
357
  *,
358
358
  config: FunctionGroupBaseConfig,
359
359
  instance_name: str | None = None,
360
- filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None):
360
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None):
361
361
  """
362
362
  Creates a new function group.
363
363
 
@@ -367,7 +367,7 @@ class FunctionGroup:
367
367
  The configuration for the function group.
368
368
  instance_name : str | None, optional
369
369
  The name of the function group. If not provided, the type of the function group will be used.
370
- filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
370
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
371
371
  A callback function to additionally filter the functions in the function group dynamically when
372
372
  the functions are accessed via any accessor method.
373
373
  """
@@ -375,7 +375,7 @@ class FunctionGroup:
375
375
  self._instance_name = instance_name or config.type
376
376
  self._functions: dict[str, Function] = dict()
377
377
  self._filter_fn = filter_fn
378
- self._per_function_filter_fn: dict[str, Callable[[str], bool]] = dict()
378
+ self._per_function_filter_fn: dict[str, Callable[[str], Awaitable[bool]]] = dict()
379
379
 
380
380
  def add_function(self,
381
381
  name: str,
@@ -384,7 +384,7 @@ class FunctionGroup:
384
384
  input_schema: type[BaseModel] | None = None,
385
385
  description: str | None = None,
386
386
  converters: list[Callable] | None = None,
387
- filter_fn: Callable[[str], bool] | None = None):
387
+ filter_fn: Callable[[str], Awaitable[bool]] | None = None):
388
388
  """
389
389
  Adds a function to the function group.
390
390
 
@@ -400,7 +400,7 @@ class FunctionGroup:
400
400
  The description of the function.
401
401
  converters : list[Callable] | None, optional
402
402
  The converters to use for the function.
403
- filter_fn : Callable[[str], bool] | None, optional
403
+ filter_fn : Callable[[str], Awaitable[bool]] | None, optional
404
404
  A callback to determine if the function should be included in the function group. The
405
405
  callback will be called with the function name. The callback is invoked dynamically when
406
406
  the functions are accessed via any accessor method such as `get_accessible_functions`,
@@ -441,12 +441,14 @@ class FunctionGroup:
441
441
  def _get_fn_name(self, name: str) -> str:
442
442
  return f"{self._instance_name}.{name}"
443
443
 
444
- def _fn_should_be_included(self, name: str) -> bool:
445
- return (name not in self._per_function_filter_fn or self._per_function_filter_fn[name](name))
444
+ async def _fn_should_be_included(self, name: str) -> bool:
445
+ if name not in self._per_function_filter_fn:
446
+ return True
447
+ return await self._per_function_filter_fn[name](name)
446
448
 
447
- def _get_all_but_excluded_functions(
449
+ async def _get_all_but_excluded_functions(
448
450
  self,
449
- filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
451
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
450
452
  ) -> dict[str, Function]:
451
453
  """
452
454
  Returns a dictionary of all functions in the function group except the excluded functions.
@@ -454,22 +456,35 @@ class FunctionGroup:
454
456
  missing = set(self._config.exclude) - set(self._functions.keys())
455
457
  if missing:
456
458
  raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
457
- filter_fn = filter_fn or self._filter_fn or (lambda x: x)
459
+
460
+ if filter_fn is None:
461
+ if self._filter_fn is None:
462
+
463
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
464
+ return x
465
+
466
+ filter_fn = identity_filter
467
+ else:
468
+ filter_fn = self._filter_fn
469
+
458
470
  excluded = set(self._config.exclude)
459
- included = set(filter_fn(list(self._functions.keys())))
471
+ included = set(await filter_fn(list(self._functions.keys())))
460
472
 
461
- def predicate(name: str) -> bool:
473
+ result = {}
474
+ for name in self._functions:
462
475
  if name in excluded:
463
- return False
464
- if not self._fn_should_be_included(name):
465
- return False
466
- return name in included
476
+ continue
477
+ if not await self._fn_should_be_included(name):
478
+ continue
479
+ if name not in included:
480
+ continue
481
+ result[self._get_fn_name(name)] = self._functions[name]
467
482
 
468
- return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
483
+ return result
469
484
 
470
- def get_accessible_functions(
485
+ async def get_accessible_functions(
471
486
  self,
472
- filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
487
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
473
488
  ) -> dict[str, Function]:
474
489
  """
475
490
  Returns a dictionary of all accessible functions in the function group.
@@ -484,7 +499,7 @@ class FunctionGroup:
484
499
 
485
500
  Parameters
486
501
  ----------
487
- filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
502
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
488
503
  A callback function to additionally filter the functions in the function group dynamically. If not provided
489
504
  then fall back to the function group's filter function. If no filter function is set for the function group
490
505
  all functions will be returned.
@@ -500,14 +515,14 @@ class FunctionGroup:
500
515
  When the function group is configured to include functions that are not found in the group.
501
516
  """
502
517
  if self._config.include:
503
- return self.get_included_functions(filter_fn=filter_fn)
518
+ return await self.get_included_functions(filter_fn=filter_fn)
504
519
  if self._config.exclude:
505
- return self._get_all_but_excluded_functions(filter_fn=filter_fn)
506
- return self.get_all_functions(filter_fn=filter_fn)
520
+ return await self._get_all_but_excluded_functions(filter_fn=filter_fn)
521
+ return await self.get_all_functions(filter_fn=filter_fn)
507
522
 
508
- def get_excluded_functions(
523
+ async def get_excluded_functions(
509
524
  self,
510
- filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
525
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
511
526
  ) -> dict[str, Function]:
512
527
  """
513
528
  Returns a dictionary of all functions in the function group which are configured to be excluded or filtered
@@ -515,7 +530,7 @@ class FunctionGroup:
515
530
 
516
531
  Parameters
517
532
  ----------
518
- filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
533
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
519
534
  A callback function to additionally filter the functions in the function group dynamically. If not provided
520
535
  then fall back to the function group's filter function. If no filter function is set for the function group
521
536
  then no functions will be added to the returned dictionary.
@@ -533,22 +548,38 @@ class FunctionGroup:
533
548
  missing = set(self._config.exclude) - set(self._functions.keys())
534
549
  if missing:
535
550
  raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
536
- filter_fn = filter_fn or self._filter_fn or (lambda x: x)
551
+
552
+ if filter_fn is None:
553
+ if self._filter_fn is None:
554
+
555
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
556
+ return x
557
+
558
+ filter_fn = identity_filter
559
+ else:
560
+ filter_fn = self._filter_fn
561
+
537
562
  excluded = set(self._config.exclude)
538
- included = set(filter_fn(list(self._functions.keys())))
563
+ included = set(await filter_fn(list(self._functions.keys())))
539
564
 
540
- def predicate(name: str) -> bool:
565
+ result = {}
566
+ for name in self._functions:
567
+ is_excluded = False
541
568
  if name in excluded:
542
- return True
543
- if not self._fn_should_be_included(name):
544
- return True
545
- return name not in included
569
+ is_excluded = True
570
+ elif not await self._fn_should_be_included(name):
571
+ is_excluded = True
572
+ elif name not in included:
573
+ is_excluded = True
546
574
 
547
- return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
575
+ if is_excluded:
576
+ result[self._get_fn_name(name)] = self._functions[name]
548
577
 
549
- def get_included_functions(
578
+ return result
579
+
580
+ async def get_included_functions(
550
581
  self,
551
- filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
582
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
552
583
  ) -> dict[str, Function]:
553
584
  """
554
585
  Returns a dictionary of all functions in the function group which are:
@@ -558,7 +589,7 @@ class FunctionGroup:
558
589
 
559
590
  Parameters
560
591
  ----------
561
- filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
592
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
562
593
  A callback function to additionally filter the functions in the function group dynamically. If not provided
563
594
  then fall back to the function group's filter function. If no filter function is set for the function group
564
595
  all functions will be returned.
@@ -576,14 +607,27 @@ class FunctionGroup:
576
607
  missing = set(self._config.include) - set(self._functions.keys())
577
608
  if missing:
578
609
  raise ValueError(f"Unknown included functions: {sorted(missing)}")
579
- filter_fn = filter_fn or self._filter_fn or (lambda x: x)
580
- included = set(filter_fn(list(self._config.include)))
581
- included = {name for name in included if self._fn_should_be_included(name)}
582
- return {self._get_fn_name(name): self._functions[name] for name in included}
583
610
 
584
- def get_all_functions(
611
+ if filter_fn is None:
612
+ if self._filter_fn is None:
613
+
614
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
615
+ return x
616
+
617
+ filter_fn = identity_filter
618
+ else:
619
+ filter_fn = self._filter_fn
620
+
621
+ included = set(await filter_fn(list(self._config.include)))
622
+ result = {}
623
+ for name in included:
624
+ if await self._fn_should_be_included(name):
625
+ result[self._get_fn_name(name)] = self._functions[name]
626
+ return result
627
+
628
+ async def get_all_functions(
585
629
  self,
586
- filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
630
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
587
631
  ) -> dict[str, Function]:
588
632
  """
589
633
  Returns a dictionary of all functions in the function group, regardless if they are included or excluded.
@@ -592,7 +636,7 @@ class FunctionGroup:
592
636
 
593
637
  Parameters
594
638
  ----------
595
- filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
639
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
596
640
  A callback function to additionally filter the functions in the function group dynamically. If not provided
597
641
  then fall back to the function group's filter function. If no filter function is set for the function group
598
642
  all functions will be returned.
@@ -602,23 +646,35 @@ class FunctionGroup:
602
646
  dict[str, Function]
603
647
  A dictionary of all functions in the function group.
604
648
  """
605
- filter_fn = filter_fn or self._filter_fn or (lambda x: x)
606
- included = set(filter_fn(list(self._functions.keys())))
607
- included = {name for name in included if self._fn_should_be_included(name)}
608
- return {self._get_fn_name(name): self._functions[name] for name in included}
649
+ if filter_fn is None:
650
+ if self._filter_fn is None:
651
+
652
+ async def identity_filter(x: Sequence[str]) -> Sequence[str]:
653
+ return x
654
+
655
+ filter_fn = identity_filter
656
+ else:
657
+ filter_fn = self._filter_fn
658
+
659
+ included = set(await filter_fn(list(self._functions.keys())))
660
+ result = {}
661
+ for name in included:
662
+ if await self._fn_should_be_included(name):
663
+ result[self._get_fn_name(name)] = self._functions[name]
664
+ return result
609
665
 
610
- def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Sequence[str]]):
666
+ def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]]):
611
667
  """
612
668
  Sets the filter function for the function group.
613
669
 
614
670
  Parameters
615
671
  ----------
616
- filter_fn : Callable[[Sequence[str]], Sequence[str]]
672
+ filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]]
617
673
  The filter function to set for the function group.
618
674
  """
619
675
  self._filter_fn = filter_fn
620
676
 
621
- def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], bool]):
677
+ def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], Awaitable[bool]]):
622
678
  """
623
679
  Sets the a per-function filter function for the a function within the function group.
624
680
 
@@ -626,7 +682,7 @@ class FunctionGroup:
626
682
  ----------
627
683
  name : str
628
684
  The name of the function.
629
- filter_fn : Callable[[str], bool]
685
+ filter_fn : Callable[[str], Awaitable[bool]]
630
686
  The per-function filter function to set for the function group.
631
687
 
632
688
  Raises