nvidia-nat 1.3.0a20250923__py3-none-any.whl → 1.3.0a20250925__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/agent.py +5 -4
- nat/agent/react_agent/register.py +12 -1
- nat/agent/reasoning_agent/reasoning_agent.py +2 -2
- nat/agent/rewoo_agent/register.py +12 -1
- nat/agent/tool_calling_agent/register.py +28 -8
- nat/builder/builder.py +33 -24
- nat/builder/component_utils.py +1 -1
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +108 -52
- nat/builder/workflow_builder.py +89 -79
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +786 -0
- nat/cli/entrypoint.py +2 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/control_flow/sequential_executor.py +6 -7
- nat/eval/evaluate.py +2 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +26 -5
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +2 -2
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/front_ends/console/console_front_end_plugin.py +4 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +4 -4
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/litellm_llm.py +69 -0
- nat/llm/register.py +4 -0
- nat/profiler/decorators/framework_wrapper.py +52 -3
- nat/profiler/decorators/function_tracking.py +33 -1
- nat/profiler/parameter_optimization/prompt_optimizer.py +2 -2
- nat/runtime/loader.py +1 -1
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/METADATA +6 -3
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/RECORD +43 -41
- nat/cli/commands/info/list_mcp.py +0 -461
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250923.dist-info → nvidia_nat-1.3.0a20250925.dist-info}/top_level.txt +0 -0
nat/cli/entrypoint.py
CHANGED
|
@@ -35,6 +35,7 @@ from nat.utils.log_levels import LOG_LEVELS
|
|
|
35
35
|
from .commands.configure.configure import configure_command
|
|
36
36
|
from .commands.evaluate import eval_command
|
|
37
37
|
from .commands.info.info import info_command
|
|
38
|
+
from .commands.mcp.mcp import mcp_command
|
|
38
39
|
from .commands.object_store.object_store import object_store_command
|
|
39
40
|
from .commands.optimize import optimizer_command
|
|
40
41
|
from .commands.registry.registry import registry_command
|
|
@@ -104,11 +105,11 @@ cli.add_command(workflow_command, name="workflow")
|
|
|
104
105
|
cli.add_command(sizing, name="sizing")
|
|
105
106
|
cli.add_command(optimizer_command, name="optimize")
|
|
106
107
|
cli.add_command(object_store_command, name="object-store")
|
|
108
|
+
cli.add_command(mcp_command, name="mcp")
|
|
107
109
|
|
|
108
110
|
# Aliases
|
|
109
111
|
cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
|
|
110
112
|
cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
|
|
111
|
-
cli.add_command(start_command.get_command(None, "mcp"), name="mcp") # type: ignore
|
|
112
113
|
|
|
113
114
|
|
|
114
115
|
@cli.result_callback()
|
|
@@ -53,7 +53,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
|
|
|
53
53
|
|
|
54
54
|
prompt = create_router_agent_prompt(config)
|
|
55
55
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
56
|
-
branches = builder.get_tools(tool_names=config.branches, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
56
|
+
branches = await builder.get_tools(tool_names=config.branches, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
57
57
|
if not branches:
|
|
58
58
|
raise ValueError(f"No branches specified for Router Agent '{config.llm_name}'")
|
|
59
59
|
|
|
@@ -80,14 +80,12 @@ def _validate_function_type_compatibility(src_fn: Function,
|
|
|
80
80
|
f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
|
|
84
|
-
|
|
83
|
+
async def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
|
|
84
|
+
builder: Builder) -> tuple[type, type]:
|
|
85
85
|
tool_list = sequential_executor_config.tool_list
|
|
86
86
|
tool_execution_config = sequential_executor_config.tool_execution_config
|
|
87
87
|
|
|
88
|
-
function_list
|
|
89
|
-
for function_ref in tool_list:
|
|
90
|
-
function_list.append(builder.get_function(function_ref))
|
|
88
|
+
function_list = await builder.get_functions(tool_list)
|
|
91
89
|
if not function_list:
|
|
92
90
|
raise RuntimeError("The function list is empty")
|
|
93
91
|
input_type = function_list[0].input_type
|
|
@@ -110,11 +108,12 @@ def _validate_tool_list_type_compatibility(sequential_executor_config: Sequentia
|
|
|
110
108
|
async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
|
|
111
109
|
logger.debug(f"Initializing sequential executor with tool list: {config.tool_list}")
|
|
112
110
|
|
|
113
|
-
tools: list[BaseTool] = builder.get_tools(tool_names=config.tool_list,
|
|
111
|
+
tools: list[BaseTool] = await builder.get_tools(tool_names=config.tool_list,
|
|
112
|
+
wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
114
113
|
tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}
|
|
115
114
|
|
|
116
115
|
try:
|
|
117
|
-
input_type, output_type = _validate_tool_list_type_compatibility(config, builder)
|
|
116
|
+
input_type, output_type = await _validate_tool_list_type_compatibility(config, builder)
|
|
118
117
|
except ValueError as e:
|
|
119
118
|
if config.raise_type_incompatibility:
|
|
120
119
|
logger.error(f"The sequential executor tool list has incompatible types: {e}")
|
nat/eval/evaluate.py
CHANGED
|
@@ -520,7 +520,8 @@ class EvaluationRun:
|
|
|
520
520
|
await self.run_workflow_remote()
|
|
521
521
|
elif not self.config.skip_workflow:
|
|
522
522
|
if session_manager is None:
|
|
523
|
-
|
|
523
|
+
workflow = await eval_workflow.build()
|
|
524
|
+
session_manager = SessionManager(workflow,
|
|
524
525
|
max_concurrency=self.eval_config.general.max_concurrency)
|
|
525
526
|
await self.run_workflow_local(session_manager)
|
|
526
527
|
|
|
@@ -33,7 +33,7 @@ async def register_trajectory_evaluator(config: TrajectoryEvaluatorConfig, build
|
|
|
33
33
|
|
|
34
34
|
from .evaluate import TrajectoryEvaluator
|
|
35
35
|
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
36
|
-
tools = builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
36
|
+
tools = await builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
37
37
|
|
|
38
38
|
_evaluator = TrajectoryEvaluator(llm, tools, builder.get_max_concurrency())
|
|
39
39
|
|
|
@@ -16,7 +16,12 @@
|
|
|
16
16
|
import functools
|
|
17
17
|
import inspect
|
|
18
18
|
import logging
|
|
19
|
+
from collections.abc import AsyncGenerator
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
from collections.abc import Generator
|
|
19
22
|
from typing import Any
|
|
23
|
+
from typing import TypeVar
|
|
24
|
+
from typing import overload
|
|
20
25
|
|
|
21
26
|
logger = logging.getLogger(__name__)
|
|
22
27
|
|
|
@@ -25,6 +30,9 @@ BASE_WARNING_MESSAGE = ("is experimental and the API may change in future releas
|
|
|
25
30
|
|
|
26
31
|
_warning_issued = set()
|
|
27
32
|
|
|
33
|
+
# Type variables for overloads
|
|
34
|
+
F = TypeVar('F', bound=Callable[..., Any])
|
|
35
|
+
|
|
28
36
|
|
|
29
37
|
def issue_experimental_warning(function_name: str,
|
|
30
38
|
feature_name: str | None = None,
|
|
@@ -53,7 +61,20 @@ def issue_experimental_warning(function_name: str,
|
|
|
53
61
|
_warning_issued.add(function_name)
|
|
54
62
|
|
|
55
63
|
|
|
56
|
-
|
|
64
|
+
# Overloads for different function types
|
|
65
|
+
@overload
|
|
66
|
+
def experimental(func: F, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
|
|
67
|
+
"""Overload for when a function is passed directly."""
|
|
68
|
+
...
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@overload
|
|
72
|
+
def experimental(*, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
73
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
74
|
+
...
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Any:
|
|
57
78
|
"""
|
|
58
79
|
Decorator that can wrap any type of function (sync, async, generator,
|
|
59
80
|
async generator) and logs a warning that the function is experimental.
|
|
@@ -90,7 +111,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
90
111
|
# ---------------------
|
|
91
112
|
|
|
92
113
|
@functools.wraps(func)
|
|
93
|
-
async def async_gen_wrapper(*args, **kwargs):
|
|
114
|
+
async def async_gen_wrapper(*args, **kwargs) -> AsyncGenerator[Any, Any]:
|
|
94
115
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
95
116
|
async for item in func(*args, **kwargs):
|
|
96
117
|
yield item # yield the original item
|
|
@@ -102,7 +123,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
102
123
|
# ASYNC FUNCTION
|
|
103
124
|
# ---------------------
|
|
104
125
|
@functools.wraps(func)
|
|
105
|
-
async def async_wrapper(*args, **kwargs):
|
|
126
|
+
async def async_wrapper(*args, **kwargs) -> Any:
|
|
106
127
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
107
128
|
result = await func(*args, **kwargs)
|
|
108
129
|
return result
|
|
@@ -114,7 +135,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
114
135
|
# SYNC GENERATOR
|
|
115
136
|
# ---------------------
|
|
116
137
|
@functools.wraps(func)
|
|
117
|
-
def sync_gen_wrapper(*args, **kwargs):
|
|
138
|
+
def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
|
|
118
139
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
119
140
|
for item in func(*args, **kwargs):
|
|
120
141
|
yield item # yield the original item
|
|
@@ -122,7 +143,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
122
143
|
return sync_gen_wrapper
|
|
123
144
|
|
|
124
145
|
@functools.wraps(func)
|
|
125
|
-
def sync_wrapper(*args, **kwargs):
|
|
146
|
+
def sync_wrapper(*args, **kwargs) -> Any:
|
|
126
147
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
127
148
|
result = func(*args, **kwargs)
|
|
128
149
|
return result
|
|
@@ -86,7 +86,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
|
|
|
86
86
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
87
87
|
|
|
88
88
|
# Get the augmented function's description
|
|
89
|
-
augmented_function = builder.get_function(config.augmented_fn)
|
|
89
|
+
augmented_function = await builder.get_function(config.augmented_fn)
|
|
90
90
|
|
|
91
91
|
# For now, we rely on runtime checking for type conversion
|
|
92
92
|
|
|
@@ -105,7 +105,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
|
|
|
105
105
|
tool_list = "Tool: Description\n"
|
|
106
106
|
|
|
107
107
|
for tool in function_used_tools:
|
|
108
|
-
tool_impl = builder.get_function(tool)
|
|
108
|
+
tool_impl = await builder.get_function(tool)
|
|
109
109
|
tool_list += f"- {tool}: {tool_impl.description if hasattr(tool_impl, 'description') else ''}\n"
|
|
110
110
|
|
|
111
111
|
# Draft the reasoning prompt for the augmented function
|
|
@@ -82,7 +82,7 @@ async def register_ttc_tool_orchestration_function(
|
|
|
82
82
|
function_map = {}
|
|
83
83
|
for fn_ref in config.augmented_fns:
|
|
84
84
|
# Retrieve the actual function from the builder
|
|
85
|
-
fn_obj = builder.get_function(fn_ref)
|
|
85
|
+
fn_obj = await builder.get_function(fn_ref)
|
|
86
86
|
function_map[fn_ref] = fn_obj
|
|
87
87
|
|
|
88
88
|
# 2) Instantiate search, editing, scoring, selection strategies (if any)
|
|
@@ -80,7 +80,7 @@ async def register_ttc_tool_wrapper_function(
|
|
|
80
80
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
81
81
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
82
82
|
|
|
83
|
-
augmented_function: Function = builder.get_function(config.augmented_fn)
|
|
83
|
+
augmented_function: Function = await builder.get_function(config.augmented_fn)
|
|
84
84
|
input_llm: BaseChatModel = await builder.get_llm(config.input_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
85
85
|
|
|
86
86
|
if not augmented_function.has_single_output:
|
|
@@ -46,11 +46,11 @@ class StrategyBase(ABC):
|
|
|
46
46
|
items: list[TTCItem],
|
|
47
47
|
original_prompt: str | None = None,
|
|
48
48
|
agent_context: str | None = None,
|
|
49
|
-
**kwargs) -> [TTCItem]:
|
|
49
|
+
**kwargs) -> list[TTCItem]:
|
|
50
50
|
pass
|
|
51
51
|
|
|
52
52
|
@abstractmethod
|
|
53
|
-
def supported_pipeline_types(self) -> [PipelineTypeEnum]:
|
|
53
|
+
def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
|
|
54
54
|
"""Return the stage types supported by this selector."""
|
|
55
55
|
pass
|
|
56
56
|
|
|
@@ -55,9 +55,10 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
55
55
|
self.auth_flow_handler = ConsoleAuthenticationFlowHandler()
|
|
56
56
|
|
|
57
57
|
async def pre_run(self):
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
58
|
+
if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
|
|
59
|
+
raise click.UsageError("Must specify either --input or --input_file, not both")
|
|
60
|
+
if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
|
|
61
|
+
raise click.UsageError("Must specify either --input or --input_file")
|
|
61
62
|
|
|
62
63
|
async def run_workflow(self, session_manager: SessionManager):
|
|
63
64
|
|
|
@@ -237,14 +237,14 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
237
237
|
|
|
238
238
|
async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
|
|
239
239
|
|
|
240
|
-
await self.add_default_route(app, SessionManager(builder.build()))
|
|
241
|
-
await self.add_evaluate_route(app, SessionManager(builder.build()))
|
|
240
|
+
await self.add_default_route(app, SessionManager(await builder.build()))
|
|
241
|
+
await self.add_evaluate_route(app, SessionManager(await builder.build()))
|
|
242
242
|
await self.add_static_files_route(app, builder)
|
|
243
243
|
await self.add_authorization_route(app)
|
|
244
244
|
|
|
245
245
|
for ep in self.front_end_config.endpoints:
|
|
246
246
|
|
|
247
|
-
entry_workflow = builder.build(entry_function=ep.function_name)
|
|
247
|
+
entry_workflow = await builder.build(entry_function=ep.function_name)
|
|
248
248
|
|
|
249
249
|
await self.add_route(app, endpoint=ep, session_manager=SessionManager(entry_workflow))
|
|
250
250
|
|
|
@@ -86,7 +86,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
86
86
|
"""
|
|
87
87
|
pass
|
|
88
88
|
|
|
89
|
-
def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
|
|
89
|
+
async def _get_all_functions(self, workflow: Workflow) -> dict[str, Function]:
|
|
90
90
|
"""Get all functions from the workflow.
|
|
91
91
|
|
|
92
92
|
Args:
|
|
@@ -100,7 +100,7 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
100
100
|
# Extract all functions from the workflow
|
|
101
101
|
functions.update(workflow.functions)
|
|
102
102
|
for function_group in workflow.function_groups.values():
|
|
103
|
-
functions.update(function_group.get_accessible_functions())
|
|
103
|
+
functions.update(await function_group.get_accessible_functions())
|
|
104
104
|
|
|
105
105
|
if workflow.config.workflow.workflow_alias:
|
|
106
106
|
functions[workflow.config.workflow.workflow_alias] = workflow
|
|
@@ -223,10 +223,10 @@ class MCPFrontEndPluginWorker(MCPFrontEndPluginWorkerBase):
|
|
|
223
223
|
self._setup_health_endpoint(mcp)
|
|
224
224
|
|
|
225
225
|
# Build the workflow and register all functions with MCP
|
|
226
|
-
workflow = builder.build()
|
|
226
|
+
workflow = await builder.build()
|
|
227
227
|
|
|
228
228
|
# Get all functions from the workflow
|
|
229
|
-
functions = self._get_all_functions(workflow)
|
|
229
|
+
functions = await self._get_all_functions(workflow)
|
|
230
230
|
|
|
231
231
|
# Filter functions based on tool_names if provided
|
|
232
232
|
if self.front_end_config.tool_names:
|
|
@@ -35,6 +35,8 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
|
|
|
35
35
|
|
|
36
36
|
async def run(self):
|
|
37
37
|
|
|
38
|
+
await self.pre_run()
|
|
39
|
+
|
|
38
40
|
# Must yield the workflow function otherwise it cleans up
|
|
39
41
|
async with WorkflowBuilder.from_config(config=self.full_config) as builder:
|
|
40
42
|
|
|
@@ -45,7 +47,7 @@ class SimpleFrontEndPluginBase(FrontEndBase[FrontEndConfigT], ABC):
|
|
|
45
47
|
|
|
46
48
|
click.echo(stream.getvalue())
|
|
47
49
|
|
|
48
|
-
workflow = builder.build()
|
|
50
|
+
workflow = await builder.build()
|
|
49
51
|
session_manager = SessionManager(workflow)
|
|
50
52
|
await self.run_workflow(session_manager)
|
|
51
53
|
|
nat/llm/litellm_llm.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import AsyncIterator
|
|
17
|
+
|
|
18
|
+
from pydantic import AliasChoices
|
|
19
|
+
from pydantic import ConfigDict
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
|
|
22
|
+
from nat.builder.builder import Builder
|
|
23
|
+
from nat.builder.llm import LLMProviderInfo
|
|
24
|
+
from nat.cli.register_workflow import register_llm_provider
|
|
25
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
26
|
+
from nat.data_models.retry_mixin import RetryMixin
|
|
27
|
+
from nat.data_models.temperature_mixin import TemperatureMixin
|
|
28
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
|
29
|
+
from nat.data_models.top_p_mixin import TopPMixin
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class LiteLlmModelConfig(
|
|
33
|
+
LLMBaseConfig,
|
|
34
|
+
RetryMixin,
|
|
35
|
+
TemperatureMixin,
|
|
36
|
+
TopPMixin,
|
|
37
|
+
ThinkingMixin,
|
|
38
|
+
name="litellm",
|
|
39
|
+
):
|
|
40
|
+
"""A LiteLlm provider to be used with an LLM client."""
|
|
41
|
+
|
|
42
|
+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
|
|
43
|
+
|
|
44
|
+
api_key: str | None = Field(default=None, description="API key to interact with hosted model.")
|
|
45
|
+
base_url: str | None = Field(default=None,
|
|
46
|
+
description="Base url to the hosted model.",
|
|
47
|
+
validation_alias=AliasChoices("base_url", "api_base"),
|
|
48
|
+
serialization_alias="api_base")
|
|
49
|
+
model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
|
|
50
|
+
serialization_alias="model",
|
|
51
|
+
description="The LiteLlm hosted model name.")
|
|
52
|
+
seed: int | None = Field(default=None, description="Random seed to set for generation.")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@register_llm_provider(config_type=LiteLlmModelConfig)
|
|
56
|
+
async def litellm_model(
|
|
57
|
+
config: LiteLlmModelConfig,
|
|
58
|
+
_builder: Builder,
|
|
59
|
+
) -> AsyncIterator[LLMProviderInfo]:
|
|
60
|
+
"""Litellm model provider.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
config (LiteLlmModelConfig): The LiteLlm model configuration.
|
|
64
|
+
_builder (Builder): The NAT builder instance.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
AsyncIterator[LLMProviderInfo]: An async iterator that yields an LLMProviderInfo object.
|
|
68
|
+
"""
|
|
69
|
+
yield LLMProviderInfo(config=config, description="A LiteLlm model for use with an LLM client.")
|
nat/llm/register.py
CHANGED
|
@@ -15,9 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
# flake8: noqa
|
|
17
17
|
# isort:skip_file
|
|
18
|
+
"""Register LLM providers via import side effects.
|
|
18
19
|
|
|
20
|
+
This module is imported by the NeMo Agent Toolkit runtime to ensure providers are registered and discoverable.
|
|
21
|
+
"""
|
|
19
22
|
# Import any providers which need to be automatically registered here
|
|
20
23
|
from . import aws_bedrock_llm
|
|
21
24
|
from . import azure_openai_llm
|
|
22
25
|
from . import nim_llm
|
|
23
26
|
from . import openai_llm
|
|
27
|
+
from . import litellm_llm
|
|
@@ -17,6 +17,7 @@ from __future__ import annotations
|
|
|
17
17
|
|
|
18
18
|
import functools
|
|
19
19
|
import logging
|
|
20
|
+
from collections.abc import AsyncIterator
|
|
20
21
|
from collections.abc import Callable
|
|
21
22
|
from contextlib import AbstractAsyncContextManager as AsyncContextManager
|
|
22
23
|
from contextlib import asynccontextmanager
|
|
@@ -32,20 +33,37 @@ _library_instrumented = {
|
|
|
32
33
|
"crewai": False,
|
|
33
34
|
"semantic_kernel": False,
|
|
34
35
|
"agno": False,
|
|
36
|
+
"adk": False,
|
|
35
37
|
}
|
|
36
38
|
|
|
37
39
|
callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)
|
|
38
40
|
|
|
39
41
|
|
|
40
42
|
def set_framework_profiler_handler(
|
|
41
|
-
workflow_llms: dict = None,
|
|
42
|
-
frameworks: list[LLMFrameworkEnum] = None,
|
|
43
|
+
workflow_llms: dict | None = None,
|
|
44
|
+
frameworks: list[LLMFrameworkEnum] | None = None,
|
|
43
45
|
) -> Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
|
|
44
46
|
"""
|
|
45
47
|
Decorator that wraps an async context manager function to set up framework-specific profiling.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
workflow_llms (dict | None): A dictionary of workflow LLM configurations.
|
|
51
|
+
frameworks (list[LLMFrameworkEnum] | None): A list of LLM frameworks used in the workflow functions.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
|
|
55
|
+
A decorator that wraps the original function with profiling setup.
|
|
46
56
|
"""
|
|
47
57
|
|
|
48
58
|
def decorator(func: Callable[..., AsyncContextManager[Any]]) -> Callable[..., AsyncContextManager[Any]]:
|
|
59
|
+
"""The actual decorator that wraps the function.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
func (Callable[..., AsyncContextManager[Any]]): The function to wrap.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
Callable[..., AsyncContextManager[Any]]: The wrapped function.
|
|
66
|
+
"""
|
|
49
67
|
|
|
50
68
|
@functools.wraps(func)
|
|
51
69
|
@asynccontextmanager
|
|
@@ -99,6 +117,20 @@ def set_framework_profiler_handler(
|
|
|
99
117
|
_library_instrumented["agno"] = True
|
|
100
118
|
logger.info("Agno callback handler registered")
|
|
101
119
|
|
|
120
|
+
if LLMFrameworkEnum.ADK in frameworks and not _library_instrumented["adk"]:
|
|
121
|
+
try:
|
|
122
|
+
from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler
|
|
123
|
+
except ImportError as e:
|
|
124
|
+
logger.warning(
|
|
125
|
+
"ADK profiler not available. " +
|
|
126
|
+
"Install NAT with ADK extras: pip install 'nvidia-nat[adk]'. Error: %s",
|
|
127
|
+
e)
|
|
128
|
+
else:
|
|
129
|
+
handler = ADKProfilerHandler()
|
|
130
|
+
handler.instrument()
|
|
131
|
+
_library_instrumented["adk"] = True
|
|
132
|
+
logger.debug("ADK callback handler registered")
|
|
133
|
+
|
|
102
134
|
# IMPORTANT: actually call the wrapped function as an async context manager
|
|
103
135
|
async with func(workflow_config, builder) as result:
|
|
104
136
|
yield result
|
|
@@ -117,11 +149,28 @@ def chain_wrapped_build_fn(
|
|
|
117
149
|
Convert an original build function into an async context manager that
|
|
118
150
|
wraps it with a single call to set_framework_profiler_handler, passing
|
|
119
151
|
all frameworks at once.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
original_build_fn (Callable[..., AsyncContextManager]): The original build function to wrap.
|
|
155
|
+
workflow_llms (dict): A dictionary of workflow LLM configurations.
|
|
156
|
+
function_frameworks (list[LLMFrameworkEnum]): A list of LLM frameworks used in the workflow functions.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
Callable[..., AsyncContextManager]: The wrapped build function.
|
|
120
160
|
"""
|
|
121
161
|
|
|
122
162
|
# Define a base async context manager that simply calls the original build function.
|
|
123
163
|
@asynccontextmanager
|
|
124
|
-
async def base_fn(*args, **kwargs):
|
|
164
|
+
async def base_fn(*args, **kwargs) -> AsyncIterator[Any]:
|
|
165
|
+
"""Base async context manager that calls the original build function.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
*args: Positional arguments to pass to the original build function.
|
|
169
|
+
**kwargs: Keyword arguments to pass to the original build function.
|
|
170
|
+
|
|
171
|
+
Yields:
|
|
172
|
+
The result of the original build function.
|
|
173
|
+
"""
|
|
125
174
|
async with original_build_fn(*args, **kwargs) as w:
|
|
126
175
|
yield w
|
|
127
176
|
|
|
@@ -18,7 +18,9 @@ import inspect
|
|
|
18
18
|
import uuid
|
|
19
19
|
from collections.abc import Callable
|
|
20
20
|
from typing import Any
|
|
21
|
+
from typing import TypeVar
|
|
21
22
|
from typing import cast
|
|
23
|
+
from typing import overload
|
|
22
24
|
|
|
23
25
|
from pydantic import BaseModel
|
|
24
26
|
|
|
@@ -77,7 +79,24 @@ def push_intermediate_step(step_manager: IntermediateStepManager,
|
|
|
77
79
|
step_manager.push_intermediate_step(payload)
|
|
78
80
|
|
|
79
81
|
|
|
80
|
-
|
|
82
|
+
# Type variable for overloads
|
|
83
|
+
F = TypeVar('F', bound=Callable[..., Any])
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Overloads for different function types
|
|
87
|
+
@overload
|
|
88
|
+
def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
|
|
89
|
+
"""Overload for when a function is passed directly."""
|
|
90
|
+
...
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@overload
|
|
94
|
+
def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
95
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
96
|
+
...
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
|
|
81
100
|
"""
|
|
82
101
|
Decorator that can wrap any type of function (sync, async, generator,
|
|
83
102
|
async generator) and executes "tracking logic" around it.
|
|
@@ -256,6 +275,19 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
|
|
|
256
275
|
return sync_wrapper
|
|
257
276
|
|
|
258
277
|
|
|
278
|
+
# Overloads for track_unregistered_function
|
|
279
|
+
@overload
|
|
280
|
+
def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
|
|
281
|
+
"""Overload for when a function is passed directly."""
|
|
282
|
+
...
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@overload
|
|
286
|
+
def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
287
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
288
|
+
...
|
|
289
|
+
|
|
290
|
+
|
|
259
291
|
def track_unregistered_function(func: Callable[..., Any] | None = None,
|
|
260
292
|
*,
|
|
261
293
|
name: str | None = None,
|
|
@@ -151,11 +151,11 @@ async def optimize_prompts(
|
|
|
151
151
|
if not init_fn_name:
|
|
152
152
|
raise ValueError(
|
|
153
153
|
"No prompt optimization function configured. Set optimizer.prompt_population_init_function")
|
|
154
|
-
init_fn = builder.get_function(init_fn_name)
|
|
154
|
+
init_fn = await builder.get_function(init_fn_name)
|
|
155
155
|
|
|
156
156
|
recombine_fn = None
|
|
157
157
|
if optimizer_config.prompt.prompt_recombination_function:
|
|
158
|
-
recombine_fn = builder.get_function(optimizer_config.prompt.prompt_recombination_function)
|
|
158
|
+
recombine_fn = await builder.get_function(optimizer_config.prompt.prompt_recombination_function)
|
|
159
159
|
|
|
160
160
|
logger.info(
|
|
161
161
|
"GA Prompt optimization ready: init_fn=%s, recombine_fn=%s",
|
nat/runtime/loader.py
CHANGED
|
@@ -114,7 +114,7 @@ async def load_workflow(config_file: StrPath, max_concurrency: int = -1):
|
|
|
114
114
|
# Must yield the workflow function otherwise it cleans up
|
|
115
115
|
async with WorkflowBuilder.from_config(config=config) as workflow:
|
|
116
116
|
|
|
117
|
-
yield SessionManager(workflow.build(), max_concurrency=max_concurrency)
|
|
117
|
+
yield SessionManager(await workflow.build(), max_concurrency=max_concurrency)
|
|
118
118
|
|
|
119
119
|
|
|
120
120
|
@lru_cache
|
nat/utils/type_converter.py
CHANGED
|
@@ -90,7 +90,7 @@ class TypeConverter:
|
|
|
90
90
|
decomposed = DecomposedType(to_type)
|
|
91
91
|
|
|
92
92
|
# 1) If data is already correct type, return it
|
|
93
|
-
if to_type is None or decomposed.is_instance(
|
|
93
|
+
if to_type is None or decomposed.is_instance(data):
|
|
94
94
|
return data
|
|
95
95
|
|
|
96
96
|
root = decomposed.root
|
|
@@ -198,16 +198,17 @@ class TypeConverter:
|
|
|
198
198
|
"""
|
|
199
199
|
visited = set()
|
|
200
200
|
final = self._try_indirect_conversion(data, to_type, visited)
|
|
201
|
+
src_type = type(data)
|
|
201
202
|
if final is not None:
|
|
202
203
|
# Warn once if found a chain
|
|
203
|
-
self._maybe_warn_indirect(
|
|
204
|
+
self._maybe_warn_indirect(src_type, to_type)
|
|
204
205
|
return final
|
|
205
206
|
|
|
206
207
|
# If no success, try parent's indirect
|
|
207
208
|
if self._parent is not None:
|
|
208
209
|
parent_final = self._parent._try_indirect_convert(data, to_type)
|
|
209
210
|
if parent_final is not None:
|
|
210
|
-
self._maybe_warn_indirect(
|
|
211
|
+
self._maybe_warn_indirect(src_type, to_type)
|
|
211
212
|
return parent_final
|
|
212
213
|
|
|
213
214
|
return None
|
nat/utils/type_utils.py
CHANGED
|
@@ -353,7 +353,7 @@ class DecomposedType:
|
|
|
353
353
|
True if the current type is an instance of the specified instance, False otherwise
|
|
354
354
|
"""
|
|
355
355
|
|
|
356
|
-
return isinstance(instance, self.root)
|
|
356
|
+
return isinstance(instance, self.get_base_type().root)
|
|
357
357
|
|
|
358
358
|
def get_pydantic_schema(self,
|
|
359
359
|
converters: list[collections.abc.Callable] | None = None) -> type[BaseModel] | type[None]:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nvidia-nat
|
|
3
|
-
Version: 1.3.
|
|
3
|
+
Version: 1.3.0a20250925
|
|
4
4
|
Summary: NVIDIA NeMo Agent toolkit
|
|
5
5
|
Author: NVIDIA Corporation
|
|
6
6
|
Maintainer: NVIDIA Corporation
|
|
@@ -215,7 +215,7 @@ Description-Content-Type: text/markdown
|
|
|
215
215
|
License-File: LICENSE-3rd-party.txt
|
|
216
216
|
License-File: LICENSE.md
|
|
217
217
|
Requires-Dist: aioboto3>=11.0.0
|
|
218
|
-
Requires-Dist: authlib~=1.
|
|
218
|
+
Requires-Dist: authlib~=1.5
|
|
219
219
|
Requires-Dist: click~=8.1
|
|
220
220
|
Requires-Dist: colorama~=0.4.6
|
|
221
221
|
Requires-Dist: datasets~=4.0
|
|
@@ -241,10 +241,12 @@ Requires-Dist: PyYAML~=6.0
|
|
|
241
241
|
Requires-Dist: ragas~=0.2.14
|
|
242
242
|
Requires-Dist: rich~=13.9
|
|
243
243
|
Requires-Dist: tabulate~=0.9
|
|
244
|
-
Requires-Dist: uvicorn[standard]~=0.
|
|
244
|
+
Requires-Dist: uvicorn[standard]~=0.34
|
|
245
245
|
Requires-Dist: wikipedia~=1.4
|
|
246
246
|
Provides-Extra: all
|
|
247
247
|
Requires-Dist: nvidia-nat-all; extra == "all"
|
|
248
|
+
Provides-Extra: adk
|
|
249
|
+
Requires-Dist: nvidia-nat-adk; extra == "adk"
|
|
248
250
|
Provides-Extra: agno
|
|
249
251
|
Requires-Dist: nvidia-nat-agno; extra == "agno"
|
|
250
252
|
Provides-Extra: crewai
|
|
@@ -287,6 +289,7 @@ Requires-Dist: nvidia-nat-weave; extra == "weave"
|
|
|
287
289
|
Provides-Extra: zep-cloud
|
|
288
290
|
Requires-Dist: nvidia-nat-zep-cloud; extra == "zep-cloud"
|
|
289
291
|
Provides-Extra: examples
|
|
292
|
+
Requires-Dist: nat_adk_demo; extra == "examples"
|
|
290
293
|
Requires-Dist: nat_agno_personal_finance; extra == "examples"
|
|
291
294
|
Requires-Dist: nat_alert_triage_agent; extra == "examples"
|
|
292
295
|
Requires-Dist: nat_automated_description_generation; extra == "examples"
|