nvidia-nat-adk 1.1.0a20251020__py3-none-any.whl
This diff represents the content of a publicly available package version released to one of the supported registries. The information is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- nat/meta/pypi.md +21 -0
- nat/plugins/adk/__init__.py +14 -0
- nat/plugins/adk/adk_callback_handler.py +308 -0
- nat/plugins/adk/llm.py +108 -0
- nat/plugins/adk/register.py +21 -0
- nat/plugins/adk/tool_wrapper.py +159 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/METADATA +43 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/RECORD +13 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/WHEEL +5 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/entry_points.txt +2 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_adk-1.1.0a20251020.dist-info/top_level.txt +1 -0
nat/meta/pypi.md
ADDED
@@ -0,0 +1,21 @@
<!--
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->



# NVIDIA NeMo Agent Toolkit — Google ADK Subpackage
Subpackage providing Google ADK integration for the NVIDIA NeMo Agent Toolkit.
nat/plugins/adk/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
nat/plugins/adk/adk_callback_handler.py
ADDED
@@ -0,0 +1,308 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import threading
import time
from collections.abc import Callable
from typing import Any

from nat.builder.context import Context
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.data_models.intermediate_step import IntermediateStepPayload
from nat.data_models.intermediate_step import IntermediateStepType
from nat.data_models.intermediate_step import StreamEventData
from nat.data_models.intermediate_step import TraceMetadata
from nat.data_models.intermediate_step import UsageInfo
from nat.profiler.callbacks.base_callback_class import BaseProfilerCallback
from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel

logger = logging.getLogger(__name__)


class ADKProfilerHandler(BaseProfilerCallback):
    """
    A callback manager/handler for Google ADK that intercepts calls to:
    - Tools
    - LLMs

    to collect usage statistics (tokens, inputs, outputs, time intervals, etc.)
    and store them in NeMo Agent Toolkit's usage_stats queue for subsequent analysis.
    """

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self.last_call_ts = time.time()
        self.step_manager = Context.get().intermediate_step_manager

        # Original references to Google ADK Tool and LLM methods (for uninstrumenting if needed)
        self._original_tool_call = None
        self._original_llm_call = None
        self._instrumented = False

    def instrument(self) -> None:
        """
        Monkey-patch the relevant Google ADK methods with usage-stat collection logic.
        Assumes the 'google-adk' library is installed.
        """
        import litellm

        if getattr(self, "_instrumented", False):
            logger.debug("ADKProfilerHandler already instrumented; skipping.")
            return
        try:
            from google.adk.tools.function_tool import FunctionTool
        except Exception as _e:
            logger.exception("ADK import failed; skipping instrumentation")
            return

        # Save the originals
        self._original_tool_call = getattr(FunctionTool, "run_async", None)
        self._original_llm_call = getattr(litellm, "acompletion", None)

        # Patch if available
        if self._original_tool_call:
            FunctionTool.run_async = self._tool_use_monkey_patch()

        if self._original_llm_call:
            litellm.acompletion = self._llm_call_monkey_patch()

        logger.debug("ADKProfilerHandler instrumentation applied successfully.")
        self._instrumented = True

    def uninstrument(self) -> None:
        """ Restore the original Google ADK methods.
        Add an explicit unpatch to avoid side-effects across tests/process lifetime.
        """
        try:
            import litellm
            from google.adk.tools.function_tool import FunctionTool
            if self._original_tool_call:
                FunctionTool.run_async = self._original_tool_call
            if self._original_llm_call:
                litellm.acompletion = self._original_llm_call
            logger.debug("ADKProfilerHandler uninstrumented successfully.")
        except Exception as _e:
            logger.exception("Failed to uninstrument ADKProfilerHandler")

    def _tool_use_monkey_patch(self) -> Callable[..., Any]:
        """
        Returns a function that wraps calls to BaseTool.run_async with usage-logging.
        """
        original_func = self._original_tool_call

        async def wrapped_tool_use(base_tool_instance, *args, **kwargs) -> Any:
            """
            Replicates _tool_use_wrapper logic without wrapt: collects usage stats,
            calls the original, and captures output stats.

            Args:
                base_tool_instance (FunctionTool): The instance of the tool being called.
                *args: Positional arguments to the tool.
                **kwargs: Keyword arguments to the tool.

            Returns:
                Any: The result of the tool execution.
            """
            now = time.time()
            tool_name = ""

            try:
                tool_name = base_tool_instance.name
            except Exception as _e:
                logger.exception("Error getting tool name")
                tool_name = ""

            try:
                # Pre-call usage event - safely extract kwargs args if present
                kwargs_args = (kwargs.get("args", {}) if isinstance(kwargs.get("args"), dict) else {})
                stats = IntermediateStepPayload(
                    event_type=IntermediateStepType.TOOL_START,
                    framework=LLMFrameworkEnum.ADK,
                    name=tool_name,
                    data=StreamEventData(),
                    metadata=TraceMetadata(tool_inputs={
                        "args": args, "kwargs": dict(kwargs_args)
                    }),
                    usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
                )

                # Store the UUID to ensure the END event uses the same ID
                step_uuid = stats.UUID
                self.step_manager.push_intermediate_step(stats)

                with self._lock:
                    self.last_call_ts = now

                # Call the original _use(...)
                if original_func is None:
                    raise RuntimeError(
                        "Original tool function is None - instrumentation may not have been set up correctly")
                result = await original_func(base_tool_instance, *args, **kwargs)
                now = time.time()
                # Post-call usage stats - safely extract kwargs args if present
                kwargs_args = (kwargs.get("args", {}) if isinstance(kwargs.get("args"), dict) else {})
                usage_stat = IntermediateStepPayload(
                    event_type=IntermediateStepType.TOOL_END,
                    span_event_timestamp=now,
                    framework=LLMFrameworkEnum.ADK,
                    name=tool_name,
                    data=StreamEventData(
                        input={
                            "args": args, "kwargs": dict(kwargs_args)
                        },
                        output=str(result),
                    ),
                    metadata=TraceMetadata(tool_outputs={"result": str(result)}),
                    usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
                    UUID=step_uuid,  # Use the same UUID as the START event
                )

                self.step_manager.push_intermediate_step(usage_stat)

                return result

            except Exception as _e:
                logger.exception("BaseTool error occurred")
                raise

        return wrapped_tool_use

    def _llm_call_monkey_patch(self) -> Callable[..., Any]:
        """
        Returns a function that wraps calls to litellm.acompletion(...) with usage-logging.

        Returns:
            Callable[..., Any]: The wrapped function.
        """
        original_func = self._original_llm_call

        async def wrapped_llm_call(*args, **kwargs) -> Any:
            """
            Replicates _llm_call_wrapper logic without wrapt: collects usage stats,
            calls the original, and captures output stats.

            Args:
                *args: Positional arguments to the LLM call.
                **kwargs: Keyword arguments to the LLM call.

            Returns:
                Any: The result of the LLM call.
            """

            now = time.time()
            with self._lock:
                seconds_between_calls = int(now - self.last_call_ts)
            model_name = kwargs.get("model")
            if not model_name and args:
                first = args[0]
                if isinstance(first, str):
                    model_name = first
            model_name = model_name or ""

            model_input = ""
            try:
                for message in kwargs.get("messages", []):
                    content = message.get("content", "")
                    if isinstance(content, list):
                        for part in content:
                            if isinstance(part, dict):
                                model_input += str(part.get("text", ""))  # text parts
                            else:
                                model_input += str(part)
                    else:
                        model_input += content or ""
            except Exception as _e:
                logger.exception("Error getting model input")

            # Record the start event
            input_stats = IntermediateStepPayload(
                event_type=IntermediateStepType.LLM_START,
                framework=LLMFrameworkEnum.ADK,
                name=model_name,
                data=StreamEventData(input=model_input),
                metadata=TraceMetadata(chat_inputs=copy.deepcopy(kwargs.get("messages", []))),
                usage_info=UsageInfo(
                    token_usage=TokenUsageBaseModel(),
                    num_llm_calls=1,
                    seconds_between_calls=seconds_between_calls,
                ),
            )

            # Store the UUID to ensure the END event uses the same ID
            step_uuid = input_stats.UUID
            self.step_manager.push_intermediate_step(input_stats)

            # Call the original litellm.acompletion(...)
            if original_func is None:
                raise RuntimeError("Original LLM function is None - instrumentation may not have been set up correctly")
            output = await original_func(*args, **kwargs)

            model_output = ""
            try:
                for choice in output.choices:
                    msg = choice.message
                    model_output += msg.content or ""
            except Exception as _e:
                logger.exception("Error getting model output")

            now = time.time()
            # Record the end event
            # Prepare safe metadata and usage
            chat_resp: dict[str, Any] = {}
            try:
                if getattr(output, "choices", []):
                    first_choice = output.choices[0]
                    chat_resp = first_choice.model_dump() if hasattr(
                        first_choice, "model_dump") else getattr(first_choice, "__dict__", {}) or {}
            except Exception as _e:
                logger.exception("Error preparing chat_responses")

            usage_payload: dict[str, Any] = {}
            try:
                usage_obj = getattr(output, "usage", None) or (getattr(output, "model_extra", {}) or {}).get("usage")
                if usage_obj:
                    if hasattr(usage_obj, "model_dump"):
                        usage_payload = usage_obj.model_dump()
                    elif isinstance(usage_obj, dict):
                        usage_payload = usage_obj
            except Exception as _e:
                logger.exception("Error preparing token usage")

            output_stats = IntermediateStepPayload(
                event_type=IntermediateStepType.LLM_END,
                span_event_timestamp=now,
                framework=LLMFrameworkEnum.ADK,
                name=model_name,
                data=StreamEventData(input=model_input, output=model_output),
                metadata=TraceMetadata(chat_responses=chat_resp),
                usage_info=UsageInfo(
                    token_usage=TokenUsageBaseModel(**usage_payload),
                    num_llm_calls=1,
                    seconds_between_calls=seconds_between_calls,
                ),
                UUID=step_uuid,  # Use the same UUID as the START event
            )

            self.step_manager.push_intermediate_step(output_stats)

            with self._lock:
                self.last_call_ts = now

            return output

        return wrapped_llm_call
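A minimal usage sketch for orientation (not part of the package's files): it assumes `nvidia-nat-adk` and its `google-adk`/`litellm` dependencies are installed and that a NAT `Context` with an intermediate step manager is active, since `ADKProfilerHandler.__init__` reads `Context.get().intermediate_step_manager`.

# Hedged sketch: enable the ADK profiling hooks around a run, then restore the originals.
from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler

handler = ADKProfilerHandler()
handler.instrument()        # patches FunctionTool.run_async and litellm.acompletion
try:
    ...                     # run the ADK agent / NAT workflow of interest here
finally:
    handler.uninstrument()  # restores the original methods, avoiding cross-test side effects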
nat/plugins/adk/llm.py
ADDED
@@ -0,0 +1,108 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.cli.register_workflow import register_llm_client
from nat.llm.azure_openai_llm import AzureOpenAIModelConfig
from nat.llm.litellm_llm import LiteLlmModelConfig
from nat.llm.nim_llm import NIMModelConfig
from nat.llm.openai_llm import OpenAIModelConfig

logger = logging.getLogger(__name__)


@register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.ADK)
async def azure_openai_adk(config: AzureOpenAIModelConfig, _builder: Builder):
    """Create and yield a Google ADK `AzureOpenAI` client from a NAT `AzureOpenAIModelConfig`.

    Args:
        config (AzureOpenAIModelConfig): The configuration for the AzureOpenAI model.
        _builder (Builder): The NAT builder instance.
    """
    from google.adk.models.lite_llm import LiteLlm

    config_dict = config.model_dump(
        exclude={"type", "max_retries", "thinking", "azure_endpoint", "azure_deployment", "model_name", "model"},
        by_alias=True,
        exclude_none=True,
    )
    if config.azure_endpoint:
        config_dict["api_base"] = config.azure_endpoint

    yield LiteLlm(f"azure/{config.azure_deployment}", **config_dict)


@register_llm_client(config_type=LiteLlmModelConfig, wrapper_type=LLMFrameworkEnum.ADK)
async def litellm_adk(litellm_config: LiteLlmModelConfig, _builder: Builder):
    from google.adk.models.lite_llm import LiteLlm
    yield LiteLlm(**litellm_config.model_dump(
        exclude={"type", "max_retries", "thinking"},
        by_alias=True,
        exclude_none=True,
    ))


@register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.ADK)
async def nim_adk(config: NIMModelConfig, _builder: Builder):
    """Create and yield a Google ADK `NIM` client from a NAT `NIMModelConfig`.

    Args:
        config (NIMModelConfig): The configuration for the NIM model.
        _builder (Builder): The NAT builder instance.
    """
    import litellm
    from google.adk.models.lite_llm import LiteLlm

    logger.warning("NIMs do not currently support tools with ADK. Tools will be ignored.")
    litellm.add_function_to_prompt = True
    litellm.drop_params = True

    if (api_key := os.getenv("NVIDIA_API_KEY", None)) is not None:
        os.environ["NVIDIA_NIM_API_KEY"] = api_key

    config_dict = config.model_dump(
        exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url"},
        by_alias=True,
        exclude_none=True,
    )
    if config.base_url:
        config_dict["api_base"] = config.base_url

    yield LiteLlm(f"nvidia_nim/{config.model_name}", **config_dict)


@register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.ADK)
async def openai_adk(config: OpenAIModelConfig, _builder: Builder):
    """Create and yield a Google ADK `OpenAI` client from a NAT `OpenAIModelConfig`.

    Args:
        config (OpenAIModelConfig): The configuration for the OpenAI model.
        _builder (Builder): The NAT builder instance.
    """
    from google.adk.models.lite_llm import LiteLlm

    config_dict = config.model_dump(
        exclude={"type", "max_retries", "thinking", "model_name", "model", "base_url"},
        by_alias=True,
        exclude_none=True,
    )
    if config.base_url:
        config_dict["api_base"] = config.base_url

    yield LiteLlm(config.model_name, **config_dict)
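For context, a short consumption sketch; the `builder.get_llm(...)` call and the "my_llm" name are assumptions based on the toolkit's usual Builder pattern, not anything defined in this file, which only registers the clients.

# Hypothetical sketch: resolving the ADK wrapper for an LLM configured in a NAT workflow.
from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum

async def build_with_adk_model(builder: Builder):
    # Resolves one of the clients registered above (e.g. nim_adk or openai_adk) and
    # returns a google.adk LiteLlm instance that an ADK agent can use as its model.
    return await builder.get_llm(llm_name="my_llm", wrapper_type=LLMFrameworkEnum.ADK)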
nat/plugins/adk/register.py
ADDED
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=unused-import
# flake8: noqa

from . import adk_callback_handler
from . import llm
from . import tool_wrapper
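Note that registration happens as an import side effect: the decorators in the modules above run when this module is imported, which is typically wired up through the distribution's entry point. A minimal illustration, assuming the package is installed:

# Importing the register module registers the ADK LLM clients, tool wrapper and callback handler.
import nat.plugins.adk.register  # noqa: F401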
nat/plugins/adk/tool_wrapper.py
ADDED
@@ -0,0 +1,159 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool Wrapper file"""
import logging
import types
from collections.abc import AsyncIterator
from collections.abc import Callable
from dataclasses import is_dataclass
from typing import Any
from typing import get_args
from typing import get_origin

from pydantic import BaseModel

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.builder.function import Function
from nat.cli.register_workflow import register_tool_wrapper

logger = logging.getLogger(__name__)


def resolve_type(t: Any) -> Any:
    """Return the non-None member of a Union/PEP 604 union;
    otherwise return the type unchanged.

    Args:
        t (Any): The type to resolve.

    Returns:
        Any: The resolved type.
    """
    origin = get_origin(t)
    if origin is types.UnionType:
        for arg in get_args(t):
            if arg is not type(None):
                return arg
        return t
    return t


@register_tool_wrapper(wrapper_type=LLMFrameworkEnum.ADK)
def google_adk_tool_wrapper(
        name: str,
        fn: Function,
        _builder: Builder  # pylint: disable=W0613
) -> Any:  # Changed from Callable[..., Any] to Any to allow FunctionTool return
    """Wrap a NAT `Function` as a Google ADK `FunctionTool`.

    Args:
        name (str): The name of the tool.
        fn (Function): The NAT `Function` to wrap.
        _builder (Builder): The NAT `Builder` (not used).
    Returns:
        A Google ADK `FunctionTool` wrapping the NAT `Function`.
    """
    import inspect

    async def callable_ainvoke(*args: Any, **kwargs: Any) -> Any:
        """Async function to invoke the NAT function.

        Args:
            *args: Positional arguments to pass to the NAT function.
            **kwargs: Keyword arguments to pass to the NAT function.
        Returns:
            Any: The result of invoking the NAT function.
        """
        return await fn.acall_invoke(*args, **kwargs)

    async def callable_astream(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        """Async generator to stream results from the NAT function.

        Args:
            *args: Positional arguments to pass to the NAT function.
            **kwargs: Keyword arguments to pass to the NAT function.
        Yields:
            Any: Streamed items from the NAT function.
        """
        async for item in fn.acall_stream(*args, **kwargs):
            yield item

    def nat_function(
            func: Callable[..., Any] | None = None,
            *,
            name: str = name,
            description: str | None = fn.description,
            input_schema: Any = fn.input_schema,
    ) -> Callable[..., Any]:
        """
        Decorator to wrap a function as a NAT function.

        Args:
            func (Callable): The function to wrap.
            name (str): The name of the function.
            description (str): The description of the function.
            input_schema (BaseModel): The Pydantic model defining the input schema.

        Returns:
            Callable[..., Any]: The wrapped function.
        """
        if func is None:
            raise ValueError("'func' must be provided.")

        # If input_schema is a dataclass, convert it to a Pydantic model
        if input_schema is not None and is_dataclass(input_schema):
            input_schema = BaseModel.model_validate(input_schema)

        def decorator(func_to_wrap: Callable[..., Any]) -> Callable[..., Any]:
            """
            Decorator to set metadata on the function.
            """
            # Set the function's metadata
            if name is not None:
                func_to_wrap.__name__ = name
            if description is not None:
                func_to_wrap.__doc__ = description

            # Set signature only if input_schema is provided
            params: list[inspect.Parameter] = []
            if input_schema is not None:
                annotations = getattr(input_schema, "__annotations__", {}) or {}
                for param_name, param_annotation in annotations.items():
                    params.append(
                        inspect.Parameter(
                            param_name,
                            inspect.Parameter.POSITIONAL_OR_KEYWORD,
                            annotation=resolve_type(param_annotation),
                        ))
            setattr(func_to_wrap, "__signature__", inspect.Signature(parameters=params))

            return func_to_wrap

        # If func is None, return the decorator itself to be applied later
        if func is None:
            return decorator
        # Otherwise, apply the decorator to the provided function
        return decorator(func)

    from google.adk.tools.function_tool import FunctionTool

    if fn.has_streaming_output and not fn.has_single_output:
        logger.debug("Creating streaming FunctionTool for: %s", name)
        callable_tool = nat_function(func=callable_astream)
    else:
        logger.debug("Creating non-streaming FunctionTool for: %s", name)
        callable_tool = nat_function(func=callable_ainvoke)
    return FunctionTool(callable_tool)
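A small illustrative check of `resolve_type` (assuming the package is importable) showing the documented behavior for a PEP 604 optional annotation:

# resolve_type unwraps `X | None` to X and leaves plain types untouched.
from nat.plugins.adk.tool_wrapper import resolve_type

assert resolve_type(int | None) is int
assert resolve_type(str) is str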
nvidia_nat_adk-1.1.0a20251020.dist-info/METADATA
ADDED
@@ -0,0 +1,43 @@
Metadata-Version: 2.4
Name: nvidia-nat-adk
Version: 1.1.0a20251020
Summary: Subpackage for Google ADK integration in NeMo Agent Toolkit
Author: NVIDIA Corporation
Maintainer: NVIDIA Corporation
License: Apache-2.0
Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
Keywords: ai,rag,agents
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Python: <3.14,>=3.11
Description-Content-Type: text/markdown
License-File: LICENSE-3rd-party.txt
License-File: LICENSE.md
Requires-Dist: nvidia-nat[litellm]==v1.1.0a20251020
Requires-Dist: google-adk~=1.14.1
Dynamic: license-file

<!--
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->



# NVIDIA NeMo Agent Toolkit — Google ADK Subpackage
Subpackage providing Google ADK integration for the NVIDIA NeMo Agent Toolkit.
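To confirm what an installed environment actually contains, a hedged sketch using only the standard library; nothing here is specific to this package beyond its distribution name.

# Inspect the installed wheel's metadata at runtime.
from importlib.metadata import metadata, requires

md = metadata("nvidia-nat-adk")
print(md["Name"], md["Version"])   # nvidia-nat-adk 1.1.0a20251020, if this build is installed
print(requires("nvidia-nat-adk"))  # should include nvidia-nat[litellm]==v1.1.0a20251020 and google-adk~=1.14.1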