nvidia-nat 1.3.0a20250923__py3-none-any.whl → 1.3.0a20250925__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. nat/agent/react_agent/agent.py +5 -4
  2. nat/agent/react_agent/register.py +12 -1
  3. nat/agent/reasoning_agent/reasoning_agent.py +2 -2
  4. nat/agent/rewoo_agent/register.py +12 -1
  5. nat/agent/tool_calling_agent/register.py +28 -8
  6. nat/builder/builder.py +33 -24
  7. nat/builder/component_utils.py +1 -1
  8. nat/builder/eval_builder.py +14 -9
  9. nat/builder/framework_enum.py +1 -0
  10. nat/builder/function.py +108 -52
  11. nat/builder/workflow_builder.py +89 -79
  12. nat/cli/commands/info/info.py +16 -6
  13. nat/cli/commands/mcp/__init__.py +14 -0
  14. nat/cli/commands/mcp/mcp.py +786 -0
  15. nat/cli/entrypoint.py +2 -1
  16. nat/control_flow/router_agent/register.py +1 -1
  17. nat/control_flow/sequential_executor.py +6 -7
  18. nat/eval/evaluate.py +2 -1
  19. nat/eval/trajectory_evaluator/register.py +1 -1
  20. nat/experimental/decorators/experimental_warning_decorator.py +26 -5
  21. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +2 -2
  22. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  23. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  24. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  25. nat/front_ends/console/console_front_end_plugin.py +4 -3
  26. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
  27. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +4 -4
  28. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  29. nat/llm/litellm_llm.py +69 -0
  30. nat/llm/register.py +4 -0
  31. nat/profiler/decorators/framework_wrapper.py +52 -3
  32. nat/profiler/decorators/function_tracking.py +33 -1
  33. nat/profiler/parameter_optimization/prompt_optimizer.py +2 -2
  34. nat/runtime/loader.py +1 -1
  35. nat/utils/type_converter.py +4 -3
  36. nat/utils/type_utils.py +1 -1
  37. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/METADATA +6 -3
  38. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/RECORD +43 -41
  39. nat/cli/commands/info/list_mcp.py +0 -461
  40. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/WHEEL +0 -0
  41. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/entry_points.txt +0 -0
  42. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  43. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/licenses/LICENSE.md +0 -0
  44. {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/top_level.txt +0 -0
nat/cli/entrypoint.py CHANGED
@@ -35,6 +35,7 @@ from nat.utils.log_levels import LOG_LEVELS
  from .commands.configure.configure import configure_command
  from .commands.evaluate import eval_command
  from .commands.info.info import info_command
+ from .commands.mcp.mcp import mcp_command
  from .commands.object_store.object_store import object_store_command
  from .commands.optimize import optimizer_command
  from .commands.registry.registry import registry_command
@@ -104,11 +105,11 @@ cli.add_command(workflow_command, name="workflow")
  cli.add_command(sizing, name="sizing")
  cli.add_command(optimizer_command, name="optimize")
  cli.add_command(object_store_command, name="object-store")
+ cli.add_command(mcp_command, name="mcp")

  # Aliases
  cli.add_command(start_command.get_command(None, "console"), name="run")  # type: ignore
  cli.add_command(start_command.get_command(None, "fastapi"), name="serve")  # type: ignore
- cli.add_command(start_command.get_command(None, "mcp"), name="mcp")  # type: ignore


  @cli.result_callback()
nat/control_flow/router_agent/register.py CHANGED
@@ -53,7 +53,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil

      prompt = create_router_agent_prompt(config)
      llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
-     branches = builder.get_tools(tool_names=config.branches, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+     branches = await builder.get_tools(tool_names=config.branches, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
      if not branches:
          raise ValueError(f"No branches specified for Router Agent '{config.llm_name}'")

nat/control_flow/sequential_executor.py CHANGED
@@ -80,14 +80,12 @@ def _validate_function_type_compatibility(src_fn: Function,
                       f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")


- def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
-                                            builder: Builder) -> tuple[type, type]:
+ async def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
+                                                  builder: Builder) -> tuple[type, type]:
      tool_list = sequential_executor_config.tool_list
      tool_execution_config = sequential_executor_config.tool_execution_config

-     function_list: list[Function] = []
-     for function_ref in tool_list:
-         function_list.append(builder.get_function(function_ref))
+     function_list = await builder.get_functions(tool_list)
      if not function_list:
          raise RuntimeError("The function list is empty")
      input_type = function_list[0].input_type
@@ -110,11 +108,12 @@ def _validate_tool_list_type_compatibility(sequential_executor_config: Sequentia
  async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
      logger.debug(f"Initializing sequential executor with tool list: {config.tool_list}")

-     tools: list[BaseTool] = builder.get_tools(tool_names=config.tool_list, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+     tools: list[BaseTool] = await builder.get_tools(tool_names=config.tool_list,
+                                                      wrapper_type=LLMFrameworkEnum.LANGCHAIN)
      tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}

      try:
-         input_type, output_type = _validate_tool_list_type_compatibility(config, builder)
+         input_type, output_type = await _validate_tool_list_type_compatibility(config, builder)
      except ValueError as e:
          if config.raise_type_incompatibility:
              logger.error(f"The sequential executor tool list has incompatible types: {e}")
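Note: across this release the Builder accessors used above (get_function, get_functions, get_tools, get_all_tools) become coroutines. A minimal hedged sketch of the new calling convention inside an async registration function follows; the config field names are borrowed from configs in this diff, and the LLMFrameworkEnum import path is assumed from the file list.

# Hedged sketch, not code from this package: awaiting the now-async Builder accessors.
from nat.builder.framework_enum import LLMFrameworkEnum  # assumed module path

async def register_my_tool(config, builder):
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    tools = await builder.get_tools(tool_names=config.tool_list, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    functions = await builder.get_functions(config.tool_list)  # Function objects instead of wrapped tools
    return llm, tools, functions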
nat/eval/evaluate.py CHANGED
@@ -520,7 +520,8 @@ class EvaluationRun:
              await self.run_workflow_remote()
          elif not self.config.skip_workflow:
              if session_manager is None:
-                 session_manager = SessionManager(eval_workflow.build(),
+                 workflow = await eval_workflow.build()
+                 session_manager = SessionManager(workflow,
                                                   max_concurrency=self.eval_config.general.max_concurrency)
              await self.run_workflow_local(session_manager)

nat/eval/trajectory_evaluator/register.py CHANGED
@@ -33,7 +33,7 @@ async def register_trajectory_evaluator(config: TrajectoryEvaluatorConfig, build

      from .evaluate import TrajectoryEvaluator
      llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
-     tools = builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+     tools = await builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)

      _evaluator = TrajectoryEvaluator(llm, tools, builder.get_max_concurrency())

nat/experimental/decorators/experimental_warning_decorator.py CHANGED
@@ -16,7 +16,12 @@
  import functools
  import inspect
  import logging
+ from collections.abc import AsyncGenerator
+ from collections.abc import Callable
+ from collections.abc import Generator
  from typing import Any
+ from typing import TypeVar
+ from typing import overload

  logger = logging.getLogger(__name__)

@@ -25,6 +30,9 @@ BASE_WARNING_MESSAGE = ("is experimental and the API may change in future releas

  _warning_issued = set()

+ # Type variables for overloads
+ F = TypeVar('F', bound=Callable[..., Any])
+

  def issue_experimental_warning(function_name: str,
                                 feature_name: str | None = None,
@@ -53,7 +61,20 @@ def issue_experimental_warning(function_name: str,
      _warning_issued.add(function_name)


- def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None):
+ # Overloads for different function types
+ @overload
+ def experimental(func: F, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
+     """Overload for when a function is passed directly."""
+     ...
+
+
+ @overload
+ def experimental(*, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+     """Overload for decorator factory usage (when called with parentheses)."""
+     ...
+
+
+ def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Any:
      """
      Decorator that can wrap any type of function (sync, async, generator,
      async generator) and logs a warning that the function is experimental.
@@ -90,7 +111,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
      # ---------------------

      @functools.wraps(func)
-     async def async_gen_wrapper(*args, **kwargs):
+     async def async_gen_wrapper(*args, **kwargs) -> AsyncGenerator[Any, Any]:
          issue_experimental_warning(function_name, feature_name, metadata)
          async for item in func(*args, **kwargs):
              yield item  # yield the original item
@@ -102,7 +123,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
      # ASYNC FUNCTION
      # ---------------------
      @functools.wraps(func)
-     async def async_wrapper(*args, **kwargs):
+     async def async_wrapper(*args, **kwargs) -> Any:
          issue_experimental_warning(function_name, feature_name, metadata)
          result = await func(*args, **kwargs)
          return result
@@ -114,7 +135,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
      # SYNC GENERATOR
      # ---------------------
      @functools.wraps(func)
-     def sync_gen_wrapper(*args, **kwargs):
+     def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
          issue_experimental_warning(function_name, feature_name, metadata)
          for item in func(*args, **kwargs):
              yield item  # yield the original item
@@ -122,7 +143,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
          return sync_gen_wrapper

      @functools.wraps(func)
-     def sync_wrapper(*args, **kwargs):
+     def sync_wrapper(*args, **kwargs) -> Any:
          issue_experimental_warning(function_name, feature_name, metadata)
          result = func(*args, **kwargs)
          return result
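Note: the new @overload declarations let type checkers accept both ways of applying the decorator while preserving the wrapped function's type. A hedged sketch of the two invocation forms (the decorated function names and metadata values are illustrative):

# Bare form: the function is passed directly and its type is preserved.
@experimental
def parse_plan(text: str) -> dict:
    return {"plan": text}

# Factory form: called with keyword arguments, returning a decorator.
@experimental(feature_name="plan-parser", metadata={"owner": "agents-team"})
async def parse_plan_async(text: str) -> dict:
    return {"plan": text}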
nat/experimental/test_time_compute/functions/plan_select_execute_function.py CHANGED
@@ -86,7 +86,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
                           "This error can be resolved by installing nvidia-nat-langchain.")

      # Get the augmented function's description
-     augmented_function = builder.get_function(config.augmented_fn)
+     augmented_function = await builder.get_function(config.augmented_fn)

      # For now, we rely on runtime checking for type conversion

@@ -105,7 +105,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
      tool_list = "Tool: Description\n"

      for tool in function_used_tools:
-         tool_impl = builder.get_function(tool)
+         tool_impl = await builder.get_function(tool)
          tool_list += f"- {tool}: {tool_impl.description if hasattr(tool_impl, 'description') else ''}\n"

      # Draft the reasoning prompt for the augmented function
nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py CHANGED
@@ -82,7 +82,7 @@ async def register_ttc_tool_orchestration_function(
      function_map = {}
      for fn_ref in config.augmented_fns:
          # Retrieve the actual function from the builder
-         fn_obj = builder.get_function(fn_ref)
+         fn_obj = await builder.get_function(fn_ref)
          function_map[fn_ref] = fn_obj

      # 2) Instantiate search, editing, scoring, selection strategies (if any)
nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py CHANGED
@@ -80,7 +80,7 @@ async def register_ttc_tool_wrapper_function(
          raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
                            "This error can be resolved by installing nvidia-nat-langchain.")

-     augmented_function: Function = builder.get_function(config.augmented_fn)
+     augmented_function: Function = await builder.get_function(config.augmented_fn)
      input_llm: BaseChatModel = await builder.get_llm(config.input_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

      if not augmented_function.has_single_output:
nat/experimental/test_time_compute/models/strategy_base.py CHANGED
@@ -46,11 +46,11 @@ class StrategyBase(ABC):
                    items: list[TTCItem],
                    original_prompt: str | None = None,
                    agent_context: str | None = None,
-                   **kwargs) -> [TTCItem]:
+                   **kwargs) -> list[TTCItem]:
          pass

      @abstractmethod
-     def supported_pipeline_types(self) -> [PipelineTypeEnum]:
+     def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
          """Return the stage types supported by this selector."""
          pass

nat/front_ends/console/console_front_end_plugin.py CHANGED
@@ -55,9 +55,10 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
          self.auth_flow_handler = ConsoleAuthenticationFlowHandler()

      async def pre_run(self):
-
-         if (not self.front_end_config.input_query and not self.front_end_config.input_file):
-             raise click.UsageError("Must specify either --input_query or --input_file")
+         if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
+             raise click.UsageError("Must specify either --input or --input_file, not both")
+         if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
+             raise click.UsageError("Must specify either --input or --input_file")

      async def run_workflow(self, session_manager: SessionManager):

nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py CHANGED
@@ -237,14 +237,14 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

      async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):

-         await self.add_default_route(app, SessionManager(builder.build()))
-         await self.add_evaluate_route(app, SessionManager(builder.build()))
+         await self.add_default_route(app, SessionManager(await builder.build()))
+         await self.add_evaluate_route(app, SessionManager(await builder.build()))
          await self.add_static_files_route(app, builder)
          await self.add_authorization_route(app)

          for ep in self.front_end_config.endpoints:

-             entry_workflow = builder.build(entry_function=ep.function_name)
+             entry_workflow = await builder.build(entry_function=ep.function_name)

              await self.add_route(app, endpoint=ep, session_manager=SessionManager(entry_workflow))

nat/front_ends/mcp/mcp_front_end_plugin_worker.py CHANGED
@@ -86,7 +86,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
          """
          pass

-     def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
+     async def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
          """Get all functions from the workflow.

          Args:
@@ -100,7 +100,7 @@
          # Extract all functions from the workflow
          functions.update(workflow.functions)
          for function_group in workflow.function_groups.values():
-             functions.update(function_group.get_accessible_functions())
+             functions.update(await function_group.get_accessible_functions())

          if workflow.config.workflow.workflow_alias:
              functions[workflow.config.workflow.workflow_alias] = workflow
@@ -223,10 +223,10 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
          self._setup_health_endpoint(mcp)

          # Build the workflow and register all functions with MCP
-         workflow = builder.build()
+         workflow = await builder.build()

          # Get all functions from the workflow
-         functions = self._get_all_functions(workflow)
+         functions = await self._get_all_functions(workflow)

          # Filter functions based on tool_names if provided
          if self.front_end_config.tool_names:
nat/front_ends/simple_base/simple_front_end_plugin_base.py CHANGED
@@ -35,6 +35,8 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):

      async def run(self):

+         await self.pre_run()
+
          # Must yield the workflow function otherwise it cleans up
          async with WorkflowBuilder.from_config(config=self.full_config) as builder:

@@ -45,7 +47,7 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):

              click.echo(stream.getvalue())

-             workflow = builder.build()
+             workflow = await builder.build()
              session_manager = SessionManager(workflow)
              await self.run_workflow(session_manager)

nat/llm/litellm_llm.py ADDED
@@ -0,0 +1,69 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections.abc import AsyncIterator
+
+ from pydantic import AliasChoices
+ from pydantic import ConfigDict
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.builder.llm import LLMProviderInfo
+ from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.retry_mixin import RetryMixin
+ from nat.data_models.temperature_mixin import TemperatureMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
+ from nat.data_models.top_p_mixin import TopPMixin
+
+
+ class LiteLlmModelConfig(
+         LLMBaseConfig,
+         RetryMixin,
+         TemperatureMixin,
+         TopPMixin,
+         ThinkingMixin,
+         name="litellm",
+ ):
+     """A LiteLlm provider to be used with an LLM client."""
+
+     model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+     api_key: str | None = Field(default=None, description="API key to interact with hosted model.")
+     base_url: str | None = Field(default=None,
+                                  description="Base url to the hosted model.",
+                                  validation_alias=AliasChoices("base_url", "api_base"),
+                                  serialization_alias="api_base")
+     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
+                             serialization_alias="model",
+                             description="The LiteLlm hosted model name.")
+     seed: int | None = Field(default=None, description="Random seed to set for generation.")
+
+
+ @register_llm_provider(config_type=LiteLlmModelConfig)
+ async def litellm_model(
+     config: LiteLlmModelConfig,
+     _builder: Builder,
+ ) -> AsyncIterator[LLMProviderInfo]:
+     """Litellm model provider.
+
+     Args:
+         config (LiteLlmModelConfig): The LiteLlm model configuration.
+         _builder (Builder): The NAT builder instance.
+
+     Returns:
+         AsyncIterator[LLMProviderInfo]: An async iterator that yields an LLMProviderInfo object.
+     """
+     yield LLMProviderInfo(config=config, description="A LiteLlm model for use with an LLM client.")
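Note: the new provider registers under the config name "litellm" and accepts either the field names or their aliases. A hedged sketch of constructing the config directly in Python; the model string, endpoint, and key are placeholders:

# Hedged sketch, not code from this package.
from nat.llm.litellm_llm import LiteLlmModelConfig

config = LiteLlmModelConfig(model="gpt-4o-mini",  # accepted via the "model" alias for model_name
                            api_base="http://localhost:4000",  # alias for base_url; serialized back as "api_base"
                            api_key="<your-key>",
                            seed=42)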
nat/llm/register.py CHANGED
@@ -15,9 +15,13 @@

  # flake8: noqa
  # isort:skip_file
+ """Register LLM providers via import side effects.

+ This module is imported by the NeMo Agent Toolkit runtime to ensure providers are registered and discoverable.
+ """
  # Import any providers which need to be automatically registered here
  from . import aws_bedrock_llm
  from . import azure_openai_llm
  from . import nim_llm
  from . import openai_llm
+ from . import litellm_llm
nat/profiler/decorators/framework_wrapper.py CHANGED
@@ -17,6 +17,7 @@ from __future__ import annotations

  import functools
  import logging
+ from collections.abc import AsyncIterator
  from collections.abc import Callable
  from contextlib import AbstractAsyncContextManager as AsyncContextManager
  from contextlib import asynccontextmanager
@@ -32,20 +33,37 @@ _library_instrumented = {
      "crewai": False,
      "semantic_kernel": False,
      "agno": False,
+     "adk": False,
  }

  callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)


  def set_framework_profiler_handler(
-         workflow_llms: dict = None,
-         frameworks: list[LLMFrameworkEnum] = None,
+         workflow_llms: dict | None = None,
+         frameworks: list[LLMFrameworkEnum] | None = None,
  ) -> Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
      """
      Decorator that wraps an async context manager function to set up framework-specific profiling.
+
+     Args:
+         workflow_llms (dict | None): A dictionary of workflow LLM configurations.
+         frameworks (list[LLMFrameworkEnum] | None): A list of LLM frameworks used in the workflow functions.
+
+     Returns:
+         Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
+             A decorator that wraps the original function with profiling setup.
      """

      def decorator(func: Callable[..., AsyncContextManager[Any]]) -> Callable[..., AsyncContextManager[Any]]:
+         """The actual decorator that wraps the function.
+
+         Args:
+             func (Callable[..., AsyncContextManager[Any]]): The function to wrap.
+
+         Returns:
+             Callable[..., AsyncContextManager[Any]]: The wrapped function.
+         """

          @functools.wraps(func)
          @asynccontextmanager
@@ -99,6 +117,20 @@ def set_framework_profiler_handler(
                  _library_instrumented["agno"] = True
                  logger.info("Agno callback handler registered")

+             if LLMFrameworkEnum.ADK in frameworks and not _library_instrumented["adk"]:
+                 try:
+                     from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler
+                 except ImportError as e:
+                     logger.warning(
+                         "ADK profiler not available. " +
+                         "Install NAT with ADK extras: pip install 'nvidia-nat[adk]'. Error: %s",
+                         e)
+                 else:
+                     handler = ADKProfilerHandler()
+                     handler.instrument()
+                     _library_instrumented["adk"] = True
+                     logger.debug("ADK callback handler registered")
+
              # IMPORTANT: actually call the wrapped function as an async context manager
              async with func(workflow_config, builder) as result:
                  yield result
@@ -117,11 +149,28 @@ def chain_wrapped_build_fn(
      Convert an original build function into an async context manager that
      wraps it with a single call to set_framework_profiler_handler, passing
      all frameworks at once.
+
+     Args:
+         original_build_fn (Callable[..., AsyncContextManager]): The original build function to wrap.
+         workflow_llms (dict): A dictionary of workflow LLM configurations.
+         function_frameworks (list[LLMFrameworkEnum]): A list of LLM frameworks used in the workflow functions.
+
+     Returns:
+         Callable[..., AsyncContextManager]: The wrapped build function.
      """

      # Define a base async context manager that simply calls the original build function.
      @asynccontextmanager
-     async def base_fn(*args, **kwargs):
+     async def base_fn(*args, **kwargs) -> AsyncIterator[Any]:
+         """Base async context manager that calls the original build function.
+
+         Args:
+             *args: Positional arguments to pass to the original build function.
+             **kwargs: Keyword arguments to pass to the original build function.
+
+         Yields:
+             The result of the original build function.
+         """
          async with original_build_fn(*args, **kwargs) as w:
              yield w

nat/profiler/decorators/function_tracking.py CHANGED
@@ -18,7 +18,9 @@ import inspect
  import uuid
  from collections.abc import Callable
  from typing import Any
+ from typing import TypeVar
  from typing import cast
+ from typing import overload

  from pydantic import BaseModel

@@ -77,7 +79,24 @@ def push_intermediate_step(step_manager: IntermediateStepManager,
      step_manager.push_intermediate_step(payload)


- def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
+ # Type variable for overloads
+ F = TypeVar('F', bound=Callable[..., Any])
+
+
+ # Overloads for different function types
+ @overload
+ def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
+     """Overload for when a function is passed directly."""
+     ...
+
+
+ @overload
+ def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+     """Overload for decorator factory usage (when called with parentheses)."""
+     ...
+
+
+ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
      """
      Decorator that can wrap any type of function (sync, async, generator,
      async generator) and executes "tracking logic" around it.
@@ -256,6 +275,19 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
          return sync_wrapper


+ # Overloads for track_unregistered_function
+ @overload
+ def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
+     """Overload for when a function is passed directly."""
+     ...
+
+
+ @overload
+ def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+     """Overload for decorator factory usage (when called with parentheses)."""
+     ...
+
+
  def track_unregistered_function(func: Callable[..., Any] | None = None,
                                  *,
                                  name: str | None = None,
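Note: track_function and track_unregistered_function gain the same pair of overloads as the experimental decorator, so both application styles type-check. A hedged sketch of typical usage; the function names and metadata values are illustrative:

# Hedged sketch, not code from this package.
@track_function
def score(answer: str) -> float:
    return float(len(answer))

@track_unregistered_function(name="normalize-step", metadata={"stage": "post-process"})
async def normalize(answer: str) -> str:
    return answer.strip()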
nat/profiler/parameter_optimization/prompt_optimizer.py CHANGED
@@ -151,11 +151,11 @@ async def optimize_prompts(
          if not init_fn_name:
              raise ValueError(
                  "No prompt optimization function configured. Set optimizer.prompt_population_init_function")
-         init_fn = builder.get_function(init_fn_name)
+         init_fn = await builder.get_function(init_fn_name)

          recombine_fn = None
          if optimizer_config.prompt.prompt_recombination_function:
-             recombine_fn = builder.get_function(optimizer_config.prompt.prompt_recombination_function)
+             recombine_fn = await builder.get_function(optimizer_config.prompt.prompt_recombination_function)

          logger.info(
              "GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
nat/runtime/loader.py CHANGED
@@ -114,7 +114,7 @@ async def load_workflow(config_file: StrPath, max_concurrency: int = -1):
      # Must yield the workflow function otherwise it cleans up
      async with WorkflowBuilder.from_config(config=config) as workflow:

-         yield SessionManager(workflow.build(), max_concurrency=max_concurrency)
+         yield SessionManager(await workflow.build(), max_concurrency=max_concurrency)


  @lru_cache
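Note: builder.build() is now a coroutine at every call site in this release (FastAPI worker, MCP worker, console base, evaluation, and the loader above). A hedged sketch mirroring the load_workflow pattern; the SessionManager import path is assumed, since it is not shown in this diff:

# Hedged sketch, not code from this package.
from contextlib import asynccontextmanager

from nat.builder.workflow_builder import WorkflowBuilder
from nat.runtime.session import SessionManager  # assumed import path

@asynccontextmanager
async def open_session(config):
    # Keep the builder context open while the workflow is in use, as the comment above notes.
    async with WorkflowBuilder.from_config(config=config) as builder:
        workflow = await builder.build()  # build() must now be awaited
        yield SessionManager(workflow)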
nat/utils/type_converter.py CHANGED
@@ -90,7 +90,7 @@ class TypeConverter:
          decomposed = DecomposedType(to_type)

          # 1) If data is already correct type, return it
-         if to_type is None or decomposed.is_instance((data, to_type)):
+         if to_type is None or decomposed.is_instance(data):
              return data

          root = decomposed.root
@@ -198,16 +198,17 @@
          """
          visited = set()
          final = self._try_indirect_conversion(data, to_type, visited)
+         src_type = type(data)
          if final is not None:
              # Warn once if found a chain
-             self._maybe_warn_indirect(type(data), to_type)
+             self._maybe_warn_indirect(src_type, to_type)
              return final

          # If no success, try parent's indirect
          if self._parent is not None:
              parent_final = self._parent._try_indirect_convert(data, to_type)
              if parent_final is not None:
-                 self._maybe_warn_indirect(type(data), to_type)
+                 self._maybe_warn_indirect(src_type, to_type)
                  return parent_final

          return None
nat/utils/type_utils.py CHANGED
@@ -353,7 +353,7 @@ class DecomposedType:
              True if the current type is an instance of the specified instance, False otherwise
          """

-         return isinstance(instance, self.root)
+         return isinstance(instance, self.get_base_type().root)

      def get_pydantic_schema(self,
                              converters: list[collections.abc.Callable] | None = None) -> type[BaseModel] | type[None]:
{nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nvidia-nat
- Version: 1.3.0a20250923
+ Version: 1.3.0a20250925
  Summary: NVIDIA NeMo Agent toolkit
  Author: NVIDIA Corporation
  Maintainer: NVIDIA Corporation
@@ -215,7 +215,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE-3rd-party.txt
  License-File: LICENSE.md
  Requires-Dist: aioboto3>=11.0.0
- Requires-Dist: authlib~=1.3.1
+ Requires-Dist: authlib~=1.5
  Requires-Dist: click~=8.1
  Requires-Dist: colorama~=0.4.6
  Requires-Dist: datasets~=4.0
@@ -241,10 +241,12 @@ Requires-Dist: PyYAML~=6.0
  Requires-Dist: ragas~=0.2.14
  Requires-Dist: rich~=13.9
  Requires-Dist: tabulate~=0.9
- Requires-Dist: uvicorn[standard]~=0.32.0
+ Requires-Dist: uvicorn[standard]~=0.34
  Requires-Dist: wikipedia~=1.4
  Provides-Extra: all
  Requires-Dist: nvidia-nat-all; extra == "all"
+ Provides-Extra: adk
+ Requires-Dist: nvidia-nat-adk; extra == "adk"
  Provides-Extra: agno
  Requires-Dist: nvidia-nat-agno; extra == "agno"
  Provides-Extra: crewai
@@ -287,6 +289,7 @@ Requires-Dist: nvidia-nat-weave; extra == "weave"
  Provides-Extra: zep-cloud
  Requires-Dist: nvidia-nat-zep-cloud; extra == "zep-cloud"
  Provides-Extra: examples
+ Requires-Dist: nat_adk_demo; extra == "examples"
  Requires-Dist: nat_agno_personal_finance; extra == "examples"
  Requires-Dist: nat_alert_triage_agent; extra == "examples"
  Requires-Dist: nat_automated_description_generation; extra == "examples"