nvidia-nat-strands 1.4.0a20251209__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/strands/__init__.py +14 -0
- nat/plugins/strands/llm.py +339 -0
- nat/plugins/strands/register.py +20 -0
- nat/plugins/strands/strands_callback_handler.py +615 -0
- nat/plugins/strands/tool_wrapper.py +127 -0
- nvidia_nat_strands-1.4.0a20251209.dist-info/METADATA +43 -0
- nvidia_nat_strands-1.4.0a20251209.dist-info/RECORD +11 -0
- nvidia_nat_strands-1.4.0a20251209.dist-info/WHEEL +5 -0
- nvidia_nat_strands-1.4.0a20251209.dist-info/entry_points.txt +2 -0
- nvidia_nat_strands-1.4.0a20251209.dist-info/top_level.txt +1 -0
nat/meta/pypi.md
ADDED
@@ -0,0 +1,23 @@
+<!--
+SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: Apache-2.0
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+
+
+# NVIDIA NeMo Agent Toolkit Subpackage
+This is a subpackage for AWS Strands integration in NeMo Agent toolkit.
+
+For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
nat/plugins/strands/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
nat/plugins/strands/llm.py
ADDED
@@ -0,0 +1,339 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LLM provider wrappers for AWS Strands integration with NVIDIA NeMo Agent toolkit.
+
+This module provides Strands-compatible LLM client wrappers for the following providers:
+
+Supported Providers
+-------------------
+- **OpenAI**: Direct OpenAI API integration through ``OpenAIModelConfig``
+- **NVIDIA NIM**: OpenAI-compatible endpoints for NVIDIA models through ``NIMModelConfig``
+- **AWS Bedrock**: Amazon Bedrock models (such as Claude) through ``AWSBedrockModelConfig``
+
+Each wrapper:
+- Validates that Responses API features are disabled (Strands manages tool execution)
+- Patches clients with NeMo Agent toolkit retry logic from ``RetryMixin``
+- Injects chain-of-thought prompts when ``ThinkingMixin`` is configured
+- Removes NeMo Agent toolkit-specific config keys before instantiating Strands clients
+
+Future Provider Support
+-----------------------
+The following providers are not yet supported but could be contributed:
+
+- **Azure OpenAI**: Would require a Strands Azure OpenAI client wrapper similar to the
+  existing OpenAI integration. Contributors should follow the pattern established in
+  ``openai_strands`` and ensure Azure-specific authentication (endpoint, API version,
+  deployment name) is properly handled.
+
+- **LiteLLM**: The wrapper would need to handle LiteLLM's unified interface across
+  multiple providers while preserving Strands' tool execution semantics.
+
+See the Strands documentation at https://strandsagents.com for model provider details.
+"""
+
+import os
+from collections.abc import AsyncGenerator
+from typing import Any
+from typing import TypeVar
+
+from nat.builder.builder import Builder
+from nat.builder.framework_enum import LLMFrameworkEnum
+from nat.cli.register_workflow import register_llm_client
+from nat.data_models.common import get_secret_value
+from nat.data_models.llm import LLMBaseConfig
+from nat.data_models.retry_mixin import RetryMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
+from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig
+from nat.llm.nim_llm import NIMModelConfig
+from nat.llm.openai_llm import OpenAIModelConfig
+from nat.llm.utils.thinking import BaseThinkingInjector
+from nat.llm.utils.thinking import FunctionArgumentWrapper
+from nat.llm.utils.thinking import patch_with_thinking
+from nat.utils.exception_handlers.automatic_retries import patch_with_retry
+from nat.utils.responses_api import validate_no_responses_api
+from nat.utils.type_utils import override
+
+ModelType = TypeVar("ModelType")
+
+
+def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
+    """Patch a Strands client per NAT config (retries/thinking) and return it.
+
+    Args:
+        client: Concrete Strands model client instance.
+        llm_config: NAT LLM config with Retry/Thinking mixins.
+
+    Returns:
+        The patched client instance.
+    """
+
+    class StrandsThinkingInjector(BaseThinkingInjector):
+
+        @override
+        def inject(self, messages, *args, **kwargs) -> FunctionArgumentWrapper:
+            thinking_prompt = self.system_prompt
+            if not thinking_prompt:
+                return FunctionArgumentWrapper(messages, *args, **kwargs)
+
+            # Strands calls: model.stream(messages, tool_specs, system_prompt)
+            # So system_prompt is the 3rd positional argument (index 1 in *args)
+            new_args = list(args)
+            new_kwargs = dict(kwargs)
+
+            # Check if system_prompt is passed as positional argument
+            if len(new_args) >= 2:  # tool_specs, system_prompt
+                existing_system_prompt = new_args[1] or ""  # system_prompt
+                if existing_system_prompt:
+                    # Prepend thinking prompt to existing system prompt
+                    combined_prompt = f"{thinking_prompt}\n\n{existing_system_prompt}"
+                else:
+                    combined_prompt = thinking_prompt
+                new_args[1] = combined_prompt
+            elif "system_prompt" in new_kwargs:
+                # system_prompt passed as keyword argument
+                existing_system_prompt = new_kwargs["system_prompt"] or ""
+                if existing_system_prompt:
+                    combined_prompt = f"{thinking_prompt}\n\n{existing_system_prompt}"
+                else:
+                    combined_prompt = thinking_prompt
+                new_kwargs["system_prompt"] = combined_prompt
+            else:
+                # No system_prompt provided, add as keyword argument
+                new_kwargs["system_prompt"] = thinking_prompt
+
+            return FunctionArgumentWrapper(messages, *new_args, **new_kwargs)
+
+    if isinstance(llm_config, RetryMixin):
+        client = patch_with_retry(client,
+                                  retries=llm_config.num_retries,
+                                  retry_codes=llm_config.retry_on_status_codes,
+                                  retry_on_messages=llm_config.retry_on_errors)
+
+    if isinstance(llm_config, ThinkingMixin) and llm_config.thinking_system_prompt is not None:
+        client = patch_with_thinking(
+            client,
+            StrandsThinkingInjector(
+                system_prompt=llm_config.thinking_system_prompt,
+                function_names=[
+                    "stream",
+                    "structured_output",
+                ],
+            ))
+
+    return client
+
+
+@register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.STRANDS)
+async def openai_strands(llm_config: OpenAIModelConfig, _builder: Builder) -> AsyncGenerator[Any, None]:
+    """Build a Strands OpenAI client from an NVIDIA NeMo Agent toolkit configuration.
+
+    The wrapper requires the ``nvidia-nat[strands]`` extra and a valid OpenAI-compatible
+    API key. When ``llm_config.api_key`` is empty, the integration falls back to the
+    ``OPENAI_API_KEY`` environment variable. Responses API features are disabled through
+    ``validate_no_responses_api`` because Strands handles tool execution inside the
+    framework runtime. The yielded client is patched with NeMo Agent toolkit retry and
+    thinking hooks so that framework-level policies remain consistent.
+
+    Args:
+        llm_config: OpenAI configuration declared in the workflow.
+        _builder: Builder instance provided by the workflow factory (unused).
+
+    Yields:
+        Strands ``OpenAIModel`` objects ready to stream responses with NeMo Agent toolkit
+        retry/thinking behaviors applied.
+    """
+
+    validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS)
+
+    from strands.models.openai import OpenAIModel
+
+    params = llm_config.model_dump(
+        exclude={"type", "api_type", "api_key", "base_url", "model_name", "max_retries", "thinking"},
+        by_alias=True,
+        exclude_none=True)
+
+    client = OpenAIModel(
+        client_args={
+            "api_key": get_secret_value(llm_config.api_key) or os.getenv("OPENAI_API_KEY"),
+            "base_url": llm_config.base_url,
+        },
+        model_id=llm_config.model_name,
+        params=params,
+    )
+
+    yield _patch_llm_based_on_config(client, llm_config)
+
+
+@register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.STRANDS)
+async def nim_strands(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGenerator[Any, None]:
+    """Build a Strands OpenAI-compatible client for NVIDIA NIM endpoints.
+
+    Install the ``nvidia-nat[strands]`` extra and provide a NIM API key either through
+    ``llm_config.api_key`` or the ``NVIDIA_API_KEY`` environment variable. The wrapper
+    uses the OpenAI-compatible Strands client so Strands can route tool calls while the
+    NeMo Agent toolkit continues to manage retries, timeouts, and optional thinking
+    prompts. Responses API options are blocked to avoid conflicting execution models.
+
+    Args:
+        llm_config: Configuration for calling NVIDIA NIM by way of the OpenAI protocol.
+        _builder: Builder instance supplied during workflow construction (unused).
+
+    Yields:
+        Patched Strands clients that stream responses using the NVIDIA NIM endpoint
+        configured in ``llm_config``.
+    """
+
+    validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS)
+
+    # NIM is OpenAI compatible; use OpenAI model with NIM base_url and api_key
+    from strands.models.openai import OpenAIModel
+
+    # Create a custom OpenAI model that formats text content as strings for NIM compatibility
+    class NIMCompatibleOpenAIModel(OpenAIModel):
+
+        @classmethod
+        def format_request_message_content(cls, content):
+            """Format OpenAI compatible content block with reasoning support.
+
+            Args:
+                content: Message content.
+
+            Returns:
+                OpenAI compatible content block.
+
+            Raises:
+                TypeError: If the content block type cannot be converted to
+                    an OpenAI-compatible format.
+            """
+            # Handle reasoning content by extracting the text
+            if isinstance(content, dict) and "reasoningContent" in content:
+                reasoning_text = content["reasoningContent"].get("reasoningText", {}).get("text", "")
+                return {"text": reasoning_text, "type": "text"}
+
+            # Fall back to parent implementation for other content types
+            return super().format_request_message_content(content)
+
+        @classmethod
+        def format_request_messages(cls, messages, system_prompt=None, *, system_prompt_content=None, **kwargs):
+            # Get the formatted messages from the parent
+            formatted_messages = super().format_request_messages(messages,
+                                                                 system_prompt,
+                                                                 system_prompt_content=system_prompt_content,
+                                                                 **kwargs)
+
+            # Convert content arrays with only text to strings for NIM
+            # compatibility
+            for msg in formatted_messages:
+                content = msg.get("content")
+                if (isinstance(content, list) and len(content) == 1 and isinstance(content[0], str)):
+                    # If content is a single-item list with a string, flatten it
+                    msg["content"] = content[0]
+                elif (isinstance(content, list)
+                      and all(isinstance(item, dict) and item.get("type") == "text" for item in content)):
+                    # If all items are text blocks, join them into a single
+                    # string
+                    text_content = "".join(item["text"] for item in content)
+                    # Ensure we don't send empty strings (NIM rejects them)
+                    msg["content"] = (text_content if text_content.strip() else " ")
+                elif isinstance(content, list) and len(content) == 0:
+                    # Handle empty content arrays
+                    msg["content"] = " "
+                elif isinstance(content, str) and not content.strip():
+                    # Handle empty strings
+                    msg["content"] = " "
+
+            return formatted_messages
+
+    params = llm_config.model_dump(
+        exclude={"type", "api_type", "api_key", "base_url", "model_name", "max_retries", "thinking"},
+        by_alias=True,
+        exclude_none=True)
+
+    # Determine base_url
+    base_url = llm_config.base_url or "https://integrate.api.nvidia.com/v1"
+
+    # Determine api_key; use dummy key for custom NIM endpoints without authentication
+    # If base_url is populated (not None) and no API key is available, use a dummy value
+    api_key = get_secret_value(llm_config.api_key) or os.getenv("NVIDIA_API_KEY")
+    if llm_config.base_url and llm_config.base_url.strip() and api_key is None:
+        api_key = "dummy-api-key"
+
+    client = NIMCompatibleOpenAIModel(
+        client_args={
+            "api_key": api_key,
+            "base_url": base_url,
+        },
+        model_id=llm_config.model_name,
+        params=params,
+    )
+
+    yield _patch_llm_based_on_config(client, llm_config)
+
+
+@register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.STRANDS)
+async def bedrock_strands(llm_config: AWSBedrockModelConfig, _builder: Builder) -> AsyncGenerator[Any, None]:
+    """Build a Strands Bedrock client from an NVIDIA NeMo Agent toolkit configuration.
+
+    The integration expects the ``nvidia-nat[strands]`` extra plus AWS credentials that
+    can be resolved by ``boto3``. Credentials are loaded in the following priority:
+
+    1. Explicit values embedded in the active AWS profile referenced by
+       ``llm_config.credentials_profile_name``.
+    2. Standard environment variables such as ``AWS_ACCESS_KEY_ID``,
+       ``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN``.
+    3. Ambient credentials provided by the compute environment (for example, an IAM role
+       attached to the container or instance).
+
+    When ``llm_config.region_name`` is ``"None"`` or ``None``, Strands uses the regional
+    default configured in AWS. Responses API options remain unsupported so that Strands
+    can own tool execution. Retry and thinking hooks are added automatically before the
+    Bedrock client is yielded.
+
+    Args:
+        llm_config: AWS Bedrock configuration saved in the workflow.
+        _builder: Builder reference supplied by the workflow factory (unused).
+
+    Yields:
+        Strands ``BedrockModel`` instances configured for the selected Bedrock
+        ``model_name`` and patched with NeMo Agent toolkit retry/thinking helpers.
+    """
+
+    validate_no_responses_api(llm_config, LLMFrameworkEnum.STRANDS)
+
+    from strands.models.bedrock import BedrockModel
+
+    params = llm_config.model_dump(
+        exclude={
+            "type",
+            "api_type",
+            "model_name",
+            "region_name",
+            "base_url",
+            "max_retries",
+            "thinking",
+            "context_size",
+            "credentials_profile_name",
+        },
+        by_alias=True,
+        exclude_none=True,
+    )
+
+    region = None if llm_config.region_name in (None, "None") else llm_config.region_name
+    client = BedrockModel(model_id=llm_config.model_name,
+                          region_name=region,
+                          endpoint_url=llm_config.base_url,
+                          **params)
+
+    yield _patch_llm_based_on_config(client, llm_config)
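Note on `llm.py` above: the `StrandsThinkingInjector` works purely by rewriting the arguments of `model.stream(messages, tool_specs, system_prompt)`. The standalone sketch below (illustrative function name only, not part of the package) replays that positional/keyword handling so the prepend behavior is easy to verify:

```python
# Minimal sketch of the injector's argument rewriting in llm.py, assuming
# Strands passes (tool_specs, system_prompt) positionally after messages.
def combine_system_prompt(thinking_prompt: str, args: tuple, kwargs: dict) -> tuple[list, dict]:
    new_args = list(args)
    new_kwargs = dict(kwargs)
    if len(new_args) >= 2:  # (tool_specs, system_prompt) passed positionally
        existing = new_args[1] or ""
        new_args[1] = f"{thinking_prompt}\n\n{existing}" if existing else thinking_prompt
    elif "system_prompt" in new_kwargs:
        existing = new_kwargs["system_prompt"] or ""
        new_kwargs["system_prompt"] = f"{thinking_prompt}\n\n{existing}" if existing else thinking_prompt
    else:
        new_kwargs["system_prompt"] = thinking_prompt  # no prompt supplied at all
    return new_args, new_kwargs

args, kwargs = combine_system_prompt("Think step by step.", (None, "You are helpful."), {})
assert args[1] == "Think step by step.\n\nYou are helpful."
```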
nat/plugins/strands/register.py
ADDED
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort:skip_file
+
+from . import llm
+from . import tool_wrapper
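`register.py` above carries no logic of its own: importing it pulls in `llm` and `tool_wrapper`, whose `@register_llm_client` / `@register_tool_wrapper` decorators run at import time. A minimal sketch of that decorator-registry pattern (the registry below is hypothetical, not NAT's actual implementation):

```python
# Sketch of side-effect registration: importing a module executes its
# decorators, which record factories in a registry keyed by (config, framework).
from typing import Callable

_REGISTRY: dict[tuple[type, str], Callable] = {}

def register_llm_client(config_type: type, wrapper_type: str):
    def decorator(fn: Callable) -> Callable:
        _REGISTRY[(config_type, wrapper_type)] = fn  # recorded at import time
        return fn
    return decorator

class DummyConfig:
    pass

@register_llm_client(config_type=DummyConfig, wrapper_type="strands")
def build_dummy_client(cfg: DummyConfig):
    return object()

assert (DummyConfig, "strands") in _REGISTRY  # populated merely by importing
```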
nat/plugins/strands/strands_callback_handler.py
ADDED
@@ -0,0 +1,615 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import copy
+import importlib
+import json
+import logging
+import time
+import uuid
+from collections.abc import AsyncGenerator
+from collections.abc import Callable
+from typing import Any
+
+from nat.builder.context import Context
+from nat.builder.framework_enum import LLMFrameworkEnum
+from nat.data_models.intermediate_step import IntermediateStepPayload
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.data_models.intermediate_step import StreamEventData
+from nat.data_models.intermediate_step import TraceMetadata
+from nat.data_models.intermediate_step import UsageInfo
+from nat.profiler.callbacks.base_callback_class import BaseProfilerCallback
+from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class StrandsToolInstrumentationHook:
+    """Hook callbacks for instrumenting Strands tool invocations.
+
+    This class provides callbacks for Strands' hooks API to
+    capture tool execution events and emit proper TOOL_START/END spans.
+    """
+
+    def __init__(self, handler: 'StrandsProfilerHandler'):
+        """Initialize the hook with a reference to the profiler handler.
+
+        Args:
+            handler: StrandsProfilerHandler instance that manages this hook
+        """
+        self.handler = handler
+        self._tool_start_times: dict[str, float] = {}
+        self._step_manager = Context.get().intermediate_step_manager
+
+    def on_before_tool_invocation(self, event: Any) -> None:
+        """Handle tool invocation start.
+
+        Called by Strands before a tool is executed.
+        Emits a TOOL_START span.
+
+        Args:
+            event: BeforeToolInvocationEvent from Strands
+        """
+        try:
+            tool_use = event.tool_use
+            selected_tool = event.selected_tool
+
+            if not selected_tool:
+                logger.debug("Tool hook: no selected_tool, skipping")
+                return
+
+            # Extract tool information
+            tool_name, tool_use_id, tool_input = self._extract_tool_info(selected_tool, tool_use)
+
+            # Store start time for duration calculation
+            self._tool_start_times[tool_use_id] = time.time()
+
+            step_manager = self._step_manager
+
+            start_payload = IntermediateStepPayload(
+                event_type=IntermediateStepType.TOOL_START,
+                framework=LLMFrameworkEnum.STRANDS,
+                name=tool_name,
+                UUID=tool_use_id,
+                data=StreamEventData(input=str(tool_input), output=""),
+                metadata=TraceMetadata(
+                    tool_inputs=copy.deepcopy(tool_input),
+                    tool_info=copy.deepcopy(getattr(selected_tool, 'tool_spec', {})),
+                ),
+                usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
+            )
+
+            step_manager.push_intermediate_step(start_payload)
+
+            logger.debug("TOOL_START: %s (ID: %s)", tool_name, tool_use_id)
+        except Exception:  # noqa: BLE001
+            logger.error("Error in before_tool_invocation")
+            raise
+
+    def on_after_tool_invocation(self, event: Any) -> None:
+        """Handle tool invocation end.
+
+        Called by Strands after a tool execution completes.
+        Emits a TOOL_END span.
+
+        Args:
+            event: AfterToolInvocationEvent from Strands
+        """
+        try:
+            tool_use = event.tool_use
+            selected_tool = event.selected_tool
+            result = event.result
+            exception = event.exception
+
+            if not selected_tool:
+                logger.debug("Tool hook: no selected_tool, skipping")
+                return
+
+            # Extract tool information
+            tool_name, tool_use_id, tool_input = self._extract_tool_info(selected_tool, tool_use)
+            start_time = self._tool_start_times.pop(tool_use_id, time.time())
+
+            # Extract output from result
+            tool_output = ""
+            if isinstance(result, dict):
+                content = result.get('content', [])
+                if isinstance(content, list):
+                    for item in content:
+                        if isinstance(item, dict) and 'text' in item:
+                            tool_output += item['text']
+
+            # Handle errors
+            if exception:
+                tool_output = f"Error: {exception}"
+
+            # Use stored step_manager to avoid context isolation issues
+            step_manager = self._step_manager
+
+            end_payload = IntermediateStepPayload(
+                event_type=IntermediateStepType.TOOL_END,
+                span_event_timestamp=start_time,
+                framework=LLMFrameworkEnum.STRANDS,
+                name=tool_name,
+                UUID=tool_use_id,
+                metadata=TraceMetadata(tool_outputs=tool_output),
+                usage_info=UsageInfo(token_usage=TokenUsageBaseModel()),
+                data=StreamEventData(input=str(tool_input), output=tool_output),
+            )
+            step_manager.push_intermediate_step(end_payload)
+
+            logger.debug("TOOL_END: %s (ID: %s)", tool_name, tool_use_id)
+
+        except Exception:  # noqa: BLE001
+            logger.error("Failed to handle after_tool_invocation")
+            raise
+
+    def _extract_tool_info(self, selected_tool: Any, tool_use: dict) -> tuple[str, str, dict]:
+        """Extract tool name, ID, and input from event.
+
+        Args:
+            selected_tool: The tool being invoked
+            tool_use: Tool use dictionary from Strands event
+
+        Returns:
+            Tuple of (tool_name, tool_use_id, tool_input)
+        """
+        tool_name = getattr(selected_tool, 'tool_name', tool_use.get('name', 'unknown_tool'))
+        tool_use_id = tool_use.get('toolUseId')
+        if tool_use_id is None:
+            logger.warning("Missing toolUseId in tool_use event, using 'unknown' fallback")
+            tool_use_id = "unknown"
+        tool_input = tool_use.get('input', {}) or {}
+        return tool_name, tool_use_id, tool_input
+
+
+class StrandsProfilerHandler(BaseProfilerCallback):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._patched: bool = False
+        self.last_call_ts = time.time()
+
+        # Note: tool hooks are now created per-agent-instance in wrapped_init
+        # to avoid shared state in concurrent execution
+
+    def instrument(self) -> None:
+        """
+        Instrument Strands for telemetry capture.
+
+        This patches:
+        1. Model streaming methods (OpenAI/Bedrock) for LLM spans
+        2. Agent.__init__ to auto-register tool hooks on Agent creation
+
+        Tool instrumentation uses Strands' hooks API,
+        which is automatically registered when an Agent is instantiated.
+        """
+        if self._patched:
+            return
+
+        try:
+            # Patch LLM streaming methods
+            OpenAIModel = None
+            BedrockModel = None
+            try:
+                openai_mod = importlib.import_module("strands.models.openai")
+                OpenAIModel = getattr(openai_mod, "OpenAIModel", None)
+            except Exception:  # noqa: BLE001
+                OpenAIModel = None
+
+            try:
+                bedrock_mod = importlib.import_module("strands.models.bedrock")
+                BedrockModel = getattr(bedrock_mod, "BedrockModel", None)
+            except Exception:  # noqa: BLE001
+                BedrockModel = None
+
+            to_patch: list[tuple[type, str]] = []
+            if OpenAIModel is not None:
+                for name in ("stream", "structured_output"):
+                    if hasattr(OpenAIModel, name):
+                        to_patch.append((OpenAIModel, name))
+            if BedrockModel is not None:
+                for name in ("stream", "structured_output"):
+                    if hasattr(BedrockModel, name):
+                        to_patch.append((BedrockModel, name))
+
+            for cls, method_name in to_patch:
+                original = getattr(cls, method_name)
+                wrapped = self._wrap_stream_method(original)
+                setattr(cls, method_name, wrapped)
+
+            debug_targets = [f"{c.__name__}.{m}" for c, m in to_patch]
+            logger.info(
+                "StrandsProfilerHandler LLM instrumentation: %s",
+                debug_targets,
+            )
+
+            # Patch Agent.__init__ to auto-register hooks
+            self._instrument_agent_init()
+
+            self._patched = True
+
+        except Exception:  # noqa: BLE001
+            logger.error("Failed to instrument Strands models")
+            raise
+
+    def _instrument_agent_init(self) -> None:
+        """Patch Agent.__init__ to auto-register hooks on instantiation.
+
+        This ensures that whenever a Strands Agent is created, our tool
+        instrumentation hooks are automatically registered without requiring
+        any user code changes.
+        """
+        try:
+            # Import Agent class
+            agent_mod = importlib.import_module("strands.agent.agent")
+            Agent = getattr(agent_mod, "Agent", None)
+
+            if Agent is None:
+                logger.warning("Agent class not found in strands.agent.agent")
+                return
+
+            # Save reference to handler in closure
+            handler = self
+
+            # Save original __init__
+            original_init = Agent.__init__
+
+            def wrapped_init(agent_self, *args, **kwargs):
+                """Wrapped Agent.__init__ that auto-registers hooks."""
+                # Call original init
+                original_init(agent_self, *args, **kwargs)
+
+                # Auto-register tool hooks on this agent instance
+                try:
+                    # Import hook event types
+                    # pylint: disable=import-outside-toplevel
+                    from strands.hooks import AfterToolCallEvent
+                    from strands.hooks import BeforeToolCallEvent
+
+                    # Create a dedicated hook instance for this agent
+                    agent_tool_hook = StrandsToolInstrumentationHook(handler)
+
+                    # Register tool hooks on this agent instance
+                    agent_self.hooks.add_callback(BeforeToolCallEvent, agent_tool_hook.on_before_tool_invocation)
+                    agent_self.hooks.add_callback(AfterToolCallEvent, agent_tool_hook.on_after_tool_invocation)
+
+                    logger.debug("Strands tool hooks registered on Agent instance")
+
+                except Exception:  # noqa: BLE001
+                    logger.exception("Failed to auto-register hooks")
+
+            # Replace Agent.__init__ with wrapped version
+            Agent.__init__ = wrapped_init
+
+            logger.info("Strands Agent.__init__ instrumentation applied")
+
+        except Exception:  # noqa: BLE001
+            logger.exception("Failed to instrument Agent.__init__")
+
+    def _extract_model_info(self, model_instance: Any) -> tuple[str, dict[str, Any]]:
+        """Extract model name from Strands model instance."""
+        model_name = ""
+
+        for attr_name in ['config', 'client_args']:
+            if hasattr(model_instance, attr_name):
+                attr_value = getattr(model_instance, attr_name, None)
+                if isinstance(attr_value, dict):
+                    for key, val in attr_value.items():
+                        if 'model' in key.lower() and val:
+                            model_name = str(val)
+                            break
+            if model_name:
+                break
+
+        return str(model_name), {}
+
+    def _wrap_stream_method(self, original: Callable[..., Any]) -> Callable[..., Any]:
+        # Capture handler reference in closure
+        handler = self
+
+        async def wrapped(model_self, *args, **kwargs) -> AsyncGenerator[Any, None]:  # type: ignore[override]
+            """
+            Wrapper for Strands model streaming that emits paired
+            LLM_START/END spans with usage and metrics.
+            """
+            context = Context.get()
+            step_manager = context.intermediate_step_manager
+
+            event_uuid = str(uuid.uuid4())
+            start_time = time.time()
+
+            # Extract model info and parameters
+            model_name, _ = handler._extract_model_info(model_self)
+
+            # Extract messages from args (Strands passes as positional args)
+            # Signature: stream(self, messages, tool_specs=None,
+            # system_prompt=None, **kwargs)
+            raw_messages = args[0] if args else []
+            tool_specs = args[1] if len(args) > 1 else kwargs.get("tool_specs")
+            system_prompt = (args[2] if len(args) > 2 else kwargs.get("system_prompt"))
+
+            # Build chat_inputs with system prompt and messages
+            all_messages = []
+            if system_prompt:
+                all_messages.append({"text": system_prompt, "role": "system"})
+            if isinstance(raw_messages, list):
+                all_messages.extend(copy.deepcopy(raw_messages))
+
+            # Extract tools schema for metadata
+            tools_schema = []
+            if tool_specs and isinstance(tool_specs, list):
+                try:
+                    tools_schema = [{
+                        "type": "function",
+                        "function": {
+                            "name": tool_spec.get("name", "unknown"),
+                            "description": tool_spec.get("description", ""),
+                            "parameters": tool_spec.get("inputSchema", {}).get("json", {})
+                        }
+                    } for tool_spec in tool_specs]
+                except Exception:  # noqa: BLE001
+                    logger.debug("Failed to extract tools schema", exc_info=True)
+                    tools_schema = []
+
+            # Extract string representation of last user message for data.input
+            # (full message history is in metadata.chat_inputs)
+            llm_input_str = ""
+            if all_messages:
+                last_msg = all_messages[-1]
+                if isinstance(last_msg, dict) and 'text' in last_msg:
+                    llm_input_str = last_msg['text']
+                elif isinstance(last_msg, dict):
+                    llm_input_str = str(last_msg)
+                else:
+                    llm_input_str = str(last_msg)
+
+            # Always emit START first (before streaming begins)
+            start_payload = IntermediateStepPayload(
+                event_type=IntermediateStepType.LLM_START,
+                framework=LLMFrameworkEnum.STRANDS,
+                name=str(model_name),
+                UUID=event_uuid,
+                data=StreamEventData(input=llm_input_str, output=""),
+                metadata=TraceMetadata(
+                    chat_inputs=copy.deepcopy(all_messages),
+                    tools_schema=copy.deepcopy(tools_schema),
+                ),
+                usage_info=UsageInfo(
+                    token_usage=TokenUsageBaseModel(),
+                    num_llm_calls=1,
+                    seconds_between_calls=int(time.time() - self.last_call_ts),
+                ),
+            )
+            step_manager.push_intermediate_step(start_payload)
+            self.last_call_ts = time.time()
+
+            # Collect output text, tool calls, and token usage while streaming
+            output_text = ""
+            tool_calls = []  # List of tool calls made by the LLM
+            current_tool_call = None  # Currently accumulating tool call
+            token_usage = TokenUsageBaseModel()
+            ended: bool = False
+
+            def _push_end_if_needed() -> None:
+                nonlocal ended
+                if ended:
+                    return
+
+                # Determine the output to show in the span
+                # If there are tool calls, format them as the output
+                # Otherwise, use the text response
+                if tool_calls:
+                    # Format tool calls as readable output
+                    tool_call_strs = []
+                    for tc in tool_calls:
+                        tool_name = tc.get('name', 'unknown')
+                        tool_input = tc.get('input', {})
+                        tool_call_strs.append(f"Tool: {tool_name}\nInput: {tool_input}")
+                    output_content = "\n\n".join(tool_call_strs)
+                else:
+                    output_content = output_text
+
+                chat_responses_list = []
+                if output_content:
+                    chat_responses_list = [output_content]
+
+                # Build metadata with standard NAT structure
+                metadata = TraceMetadata(
+                    chat_responses=chat_responses_list,
+                    chat_inputs=all_messages,
+                    tools_schema=copy.deepcopy(tools_schema),
+                )
+
+                # Push END with input/output and token usage
+                end_payload = IntermediateStepPayload(
+                    event_type=IntermediateStepType.LLM_END,
+                    span_event_timestamp=start_time,
+                    framework=LLMFrameworkEnum.STRANDS,
+                    name=str(model_name),
+                    UUID=event_uuid,
+                    data=StreamEventData(input=llm_input_str, output=output_content),
+                    usage_info=UsageInfo(token_usage=token_usage, num_llm_calls=1),
+                    metadata=metadata,
+                )
+                step_manager.push_intermediate_step(end_payload)
+                ended = True
+
+            try:
+                agen = original(model_self, *args, **kwargs)
+                if hasattr(agen, "__aiter__"):
+                    async for ev in agen:  # type: ignore
+                        try:
+                            # Extract text content
+                            text_content = self._extract_text_from_event(ev)
+                            if text_content:
+                                output_text += text_content
+
+                            # Extract tool call information
+                            tool_call_info = self._extract_tool_call_from_event(ev)
+                            if tool_call_info:
+                                if "name" in tool_call_info:
+                                    # New tool call starting
+                                    if current_tool_call:
+                                        # Finalize and save previous tool call
+                                        self._finalize_tool_call(current_tool_call)
+                                        tool_calls.append(current_tool_call)
+                                    current_tool_call = tool_call_info
+                                elif "input_chunk" in tool_call_info and current_tool_call:
+                                    # Accumulate input JSON string chunks
+                                    current_tool_call["input_str"] += tool_call_info["input_chunk"]
+
+                            # Check for contentBlockStop to finalize current tool call
+                            if "contentBlockStop" in ev and current_tool_call:
+                                self._finalize_tool_call(current_tool_call)
+                                tool_calls.append(current_tool_call)
+                                current_tool_call = None
+
+                            # Extract usage information (but don't push END yet - wait for all text)
+                            usage_info = self._extract_usage_from_event(ev)
+                            if usage_info:
+                                token_usage = TokenUsageBaseModel(**usage_info)
+
+                        except Exception:  # noqa: BLE001
+                            logger.debug("Failed to extract streaming fields from event", exc_info=True)
+                        yield ev
+                else:
+                    # Non-async generator fallback
+                    res = agen
+                    if asyncio.iscoroutine(res):
+                        res = await res  # type: ignore[func-returns-value]
+                    yield res
+            finally:
+                # Ensure END is always pushed
+                _push_end_if_needed()
+
+        return wrapped
+
+    def _extract_text_from_event(self, ev: dict) -> str:
+        """Extract text content from a Strands event.
+
+        Args:
+            ev: Event dictionary from Strands stream
+
+        Returns:
+            Extracted text content or empty string
+        """
+        if not isinstance(ev, dict):
+            return ""
+
+        # Try multiple possible locations for text content
+        if "data" in ev:
+            return str(ev["data"])
+
+        # Check for Strands contentBlockDelta structure (for streaming text responses)
+        if "contentBlockDelta" in ev and isinstance(ev["contentBlockDelta"], dict):
+            delta = ev["contentBlockDelta"].get("delta", {})
+            if isinstance(delta, dict) and "text" in delta:
+                return str(delta["text"])
+
+        # Check for other common text fields
+        if "content" in ev:
+            return str(ev["content"])
+
+        if "text" in ev:
+            return str(ev["text"])
+
+        # Check for nested content
+        if "message" in ev and isinstance(ev["message"], dict):
+            if "content" in ev["message"]:
+                return str(ev["message"]["content"])
+
+        return ""
+
+    def _finalize_tool_call(self, tool_call: dict[str, Any]) -> None:
+        """Parse the accumulated input_str JSON and store in the input field.
+
+        Args:
+            tool_call: Tool call dictionary with input_str to parse
+        """
+        input_str = tool_call.get("input_str", "")
+        if input_str:
+            try:
+                tool_call["input"] = json.loads(input_str)
+            except (json.JSONDecodeError, ValueError):
+                logger.debug("Failed to parse tool input JSON: %s", input_str)
+                tool_call["input"] = {"raw": input_str}
+        # Remove the temporary input_str field
+        tool_call.pop("input_str", None)
+
+    def _extract_tool_call_from_event(self, ev: dict) -> dict[str, Any] | None:
+        """Extract tool call information from a Strands event.
+
+        Args:
+            ev: Event dictionary from Strands stream
+
+        Returns:
+            Dictionary with tool call info (name, input_chunk) or None if not a tool call
+        """
+        if not isinstance(ev, dict):
+            return None
+
+        # Check for contentBlockStart with toolUse
+        if "contentBlockStart" in ev:
+            start = ev["contentBlockStart"].get("start", {})
+            if isinstance(start, dict) and "toolUse" in start:
+                tool_use = start["toolUse"]
+                return {
+                    "name": tool_use.get("name", "unknown"),
+                    "id": tool_use.get("toolUseId", "unknown"),
+                    "input_str": "",  # Will accumulate JSON string chunks
+                    "input": {}  # Will be parsed at the end
+                }
+
+        # Check for contentBlockDelta with toolUse input (streaming chunks)
+        if "contentBlockDelta" in ev:
+            delta = ev["contentBlockDelta"].get("delta", {})
+            if isinstance(delta, dict) and "toolUse" in delta:
+                tool_use_delta = delta["toolUse"]
+                input_chunk = tool_use_delta.get("input", "")
+                if input_chunk:
+                    # Return the chunk to be accumulated
+                    return {"input_chunk": input_chunk}
+
+        return None
+
+    def _extract_usage_from_event(self, ev: dict) -> dict[str, int] | None:
+        """Extract usage information from a Strands event.
+
+        Args:
+            ev: Event dictionary from Strands stream
+
+        Returns:
+            Dictionary with token usage info or None if not found
+        """
+        if not isinstance(ev, dict):
+            return None
+
+        md = ev.get("metadata")
+        if not isinstance(md, dict):
+            return None
+
+        usage = md.get("usage")
+        if not isinstance(usage, dict):
+            return None
+
+        try:
+            return {
+                "prompt_tokens": int(usage.get("inputTokens") or 0),
+                "completion_tokens": int(usage.get("outputTokens") or 0),
+                "total_tokens": int(usage.get("totalTokens") or 0),
+            }
+        except (ValueError, TypeError):
+            return None
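For reference, the extraction helpers in the handler above assume Bedrock-style streaming events: text arrives in `contentBlockDelta` chunks and token counts in a trailing `metadata` event. A small self-contained sketch with invented sample events:

```python
# Sample events (made up for illustration) in the shapes the handler's
# _extract_text_from_event and _extract_usage_from_event expect.
text_event = {"contentBlockDelta": {"delta": {"text": "Hello"}}}
usage_event = {"metadata": {"usage": {"inputTokens": 12, "outputTokens": 5, "totalTokens": 17}}}

# Text extraction path: contentBlockDelta -> delta -> text
delta = text_event["contentBlockDelta"]["delta"]
assert delta["text"] == "Hello"

# Usage extraction path: metadata -> usage, mapped to NAT's token fields
usage = usage_event["metadata"]["usage"]
token_usage = {
    "prompt_tokens": int(usage.get("inputTokens") or 0),
    "completion_tokens": int(usage.get("outputTokens") or 0),
    "total_tokens": int(usage.get("totalTokens") or 0),
}
assert token_usage == {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}
```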
nat/plugins/strands/tool_wrapper.py
ADDED
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
+
+from pydantic import BaseModel
+
+from nat.builder.builder import Builder
+from nat.builder.framework_enum import LLMFrameworkEnum
+from nat.builder.function import Function
+from nat.cli.register_workflow import register_tool_wrapper
+from strands.types.tools import AgentTool  # type: ignore
+from strands.types.tools import ToolSpec  # type: ignore
+from strands.types.tools import ToolUse  # type: ignore
+
+logger = logging.getLogger(__name__)
+
+
+def _json_schema_from_pydantic(model: type[BaseModel]) -> dict[str, Any]:
+    try:
+        schema = model.model_json_schema()
+        for k in ("title", "additionalProperties"):
+            if k in schema:
+                del schema[k]
+        return {"json": schema}
+    except Exception:
+        logger.exception("Failed to generate JSON schema")
+        return {"json": {}}
+
+
+def _to_tool_result(tool_use_id: str, value: Any) -> dict[str, Any]:
+    if isinstance(value, (dict, list, tuple)):  # noqa: UP038
+        content_item = {"json": value}
+    else:
+        content_item = {"text": str(value)}
+    return {
+        "toolUseId": tool_use_id,
+        "status": "success",
+        "content": [content_item],
+    }
+
+
+def _to_error_result(tool_use_id: str, err: Exception) -> dict[str, Any]:
+    return {
+        "toolUseId": tool_use_id,
+        "status": "error",
+        "content": [{
+            "text": f"{type(err).__name__}: {err!s}"
+        }],
+    }
+
+
+class NATFunctionAgentTool(AgentTool):
+    """Concrete Strands AgentTool that wraps a NAT Function."""
+
+    def __init__(self, name: str, description: str | None, input_schema: dict[str, Any], fn: Function) -> None:
+        super().__init__()
+
+        self._tool_name = name
+        self._tool_spec: ToolSpec = {
+            "name": name,
+            "description": description or name,
+            "inputSchema": input_schema,
+        }
+        self._fn = fn
+
+    @property
+    def tool_name(self) -> str:
+        return self._tool_name
+
+    @property
+    def tool_spec(self) -> ToolSpec:
+        return self._tool_spec
+
+    @property
+    def tool_type(self) -> str:
+        return "function"
+
+    async def stream(self, tool_use: ToolUse, _invocation_state: dict[str, Any],
+                     **_kwargs: Any) -> AsyncGenerator[Any, None]:
+        from strands.types._events import ToolResultEvent  # type: ignore
+        from strands.types._events import ToolStreamEvent
+
+        tool_use_id = tool_use.get("toolUseId", "unknown")
+        tool_input = tool_use.get("input", {}) or {}
+
+        try:
+            if (self._fn.has_streaming_output and not self._fn.has_single_output):
+                last_chunk: Any | None = None
+                async for chunk in self._fn.acall_stream(**tool_input):
+                    last_chunk = chunk
+                    yield ToolStreamEvent(tool_use, chunk)
+                final = _to_tool_result(tool_use_id, last_chunk if last_chunk is not None else "")
+                yield ToolResultEvent(final)
+                return
+
+            result = await self._fn.acall_invoke(**tool_input)
+            yield ToolResultEvent(_to_tool_result(tool_use_id, result))
+        except Exception as exc:  # noqa: BLE001
+            logger.exception("Strands tool '%s' failed", self.tool_name)
+            yield ToolResultEvent(_to_error_result(tool_use_id, exc))
+
+
+@register_tool_wrapper(wrapper_type=LLMFrameworkEnum.STRANDS)
+def strands_tool_wrapper(name: str, fn: Function, _builder: Builder) -> NATFunctionAgentTool:
+    """Create a Strands `AgentTool` wrapper for a NAT `Function`."""
+    if fn.input_schema is None:
+        raise ValueError(f"Tool '{name}' must define an input schema")
+
+    input_schema = _json_schema_from_pydantic(fn.input_schema)
+    description = fn.description or name
+    return NATFunctionAgentTool(name=name, description=description, input_schema=input_schema, fn=fn)
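The tool wrapper above converts between NAT `Function` objects and Strands' `toolSpec` / `toolResult` dictionaries. The sketch below shows the `inputSchema` and result shapes involved; the `WeatherInput` model is a made-up example (requires Pydantic v2), not part of the package:

```python
from pydantic import BaseModel

class WeatherInput(BaseModel):
    city: str
    units: str = "metric"

# Mirror of _json_schema_from_pydantic: generate the JSON schema and wrap it
# under a "json" key, dropping the auto-generated "title".
schema = WeatherInput.model_json_schema()
schema.pop("title", None)
input_schema = {"json": schema}
assert "city" in input_schema["json"]["properties"]

# Result shape produced by _to_tool_result: dict/list/tuple values become a
# {"json": ...} content item; anything else becomes {"text": ...}.
result = {
    "toolUseId": "tooluse_123",
    "status": "success",
    "content": [{"json": {"temp_c": 21}}],
}
print(result)
```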
nvidia_nat_strands-1.4.0a20251209.dist-info/METADATA
ADDED
@@ -0,0 +1,43 @@
+Metadata-Version: 2.4
+Name: nvidia-nat-strands
+Version: 1.4.0a20251209
+Summary: Subpackage for AWS Strands integration in NeMo Agent toolkit
+Author: NVIDIA Corporation
+Maintainer: NVIDIA Corporation
+License: Apache-2.0
+Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
+Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
+Keywords: ai,rag,agents
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Requires-Python: <3.14,>=3.11
+Description-Content-Type: text/markdown
+Requires-Dist: nvidia-nat==v1.4.0a20251209
+Requires-Dist: strands-agents~=1.17
+Requires-Dist: strands-agents-tools~=0.2
+
+<!--
+SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: Apache-2.0
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+
+
+# NVIDIA NeMo Agent Toolkit Subpackage
+This is a subpackage for AWS Strands integration in NeMo Agent toolkit.
+
+For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
nvidia_nat_strands-1.4.0a20251209.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
+nat/meta/pypi.md,sha256=i_a-Zt6wbWAPjlFqa6CsvuaMZTjglgKxF7HzX3OZW5g,1112
+nat/plugins/strands/__init__.py,sha256=GUJrgGtpvyMUCjUBvR3faAdv-tZzbU9W-izgx9aMEQg,680
+nat/plugins/strands/llm.py,sha256=0vGhfS-98E0S2G6pPCRJ-Y_G53vCUofNvPwr-9mmL2U,15057
+nat/plugins/strands/register.py,sha256=oaVjGKwLkqzsSvimHHvgxfOLdpjEsEy8-VCtZPep5Bw,760
+nat/plugins/strands/strands_callback_handler.py,sha256=BeCsR4tTSxk7AVSt7Hq3KqEDyXI3Zwz2j4cwr_rAJv0,24992
+nat/plugins/strands/tool_wrapper.py,sha256=uWWEK4zMe3tL6j-Y_FQOAlD4rgw9O1V3iZxpt03JSk0,4604
+nvidia_nat_strands-1.4.0a20251209.dist-info/METADATA,sha256=OHR8MPzLKKyeG3AIUtBAWKZQyPSzfNByHJVF8Ky-4ZY,1887
+nvidia_nat_strands-1.4.0a20251209.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+nvidia_nat_strands-1.4.0a20251209.dist-info/entry_points.txt,sha256=6lerpn7DNmp8gJEPtjQIUA-fF6LlOhCm77lCP6ZOPA4,60
+nvidia_nat_strands-1.4.0a20251209.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
+nvidia_nat_strands-1.4.0a20251209.dist-info/RECORD,,
nvidia_nat_strands-1.4.0a20251209.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+nat