nvidia-nat-autogen 1.4.0a20260120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/autogen/__init__.py +14 -0
- nat/plugins/autogen/callback_handler.py +627 -0
- nat/plugins/autogen/llm.py +459 -0
- nat/plugins/autogen/register.py +22 -0
- nat/plugins/autogen/tool_wrapper.py +181 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/METADATA +44 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/RECORD +13 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/WHEEL +5 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/entry_points.txt +2 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_autogen-1.4.0a20260120.dist-info/top_level.txt +1 -0
+++ nat/plugins/autogen/llm.py
@@ -0,0 +1,459 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoGen LLM client registrations for NAT.

This module provides AutoGen-compatible LLM client wrappers for the following providers:

Supported Providers
-------------------
- **OpenAI**: Direct OpenAI API integration via ``OpenAIChatCompletionClient``
- **Azure OpenAI**: Azure-hosted OpenAI models via ``AzureOpenAIChatCompletionClient``
- **NVIDIA NIM**: OpenAI-compatible endpoints for NVIDIA models
- **LiteLLM**: Unified interface to multiple LLM providers via OpenAI-compatible client
- **AWS Bedrock**: Amazon Bedrock models (Claude/Anthropic) via ``AnthropicBedrockChatCompletionClient``

Each wrapper:
- Patches clients with NAT retry logic from ``RetryMixin``
- Injects chain-of-thought prompts when ``ThinkingMixin`` is configured
- Removes NAT-specific config keys before instantiating AutoGen clients
"""

import logging
from collections.abc import AsyncGenerator
from typing import Any
from typing import TypeVar

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.cli.register_workflow import register_llm_client
from nat.data_models.common import get_secret_value
from nat.data_models.llm import LLMBaseConfig
from nat.data_models.retry_mixin import RetryMixin
from nat.data_models.thinking_mixin import ThinkingMixin
from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig
from nat.llm.azure_openai_llm import AzureOpenAIModelConfig
from nat.llm.litellm_llm import LiteLlmModelConfig
from nat.llm.nim_llm import NIMModelConfig
from nat.llm.openai_llm import OpenAIModelConfig
from nat.llm.utils.thinking import BaseThinkingInjector
from nat.llm.utils.thinking import FunctionArgumentWrapper
from nat.llm.utils.thinking import patch_with_thinking
from nat.utils.exception_handlers.automatic_retries import patch_with_retry
from nat.utils.type_utils import override

logger = logging.getLogger(__name__)

ModelType = TypeVar("ModelType")


def _patch_autogen_client_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
    """Patch AutoGen client with NAT mixins (retry, thinking).

    Args:
        client (ModelType): The AutoGen LLM client to patch.
        llm_config (LLMBaseConfig): The LLM configuration containing mixin settings.

    Returns:
        ModelType: The patched AutoGen LLM client.
    """
    from autogen_core.models import SystemMessage

    class AutoGenThinkingInjector(BaseThinkingInjector):
        """Thinking injector for AutoGen message format.

        Injects a system message at the start of the message list to enable
        chain-of-thought prompting for supported models (e.g., Nemotron).
        """

        @override
        def inject(self, messages: list, *args: Any, **kwargs: Any) -> FunctionArgumentWrapper:
            """Inject thinking system prompt into AutoGen messages.

            Args:
                messages (list): List of AutoGen messages (UserMessage, AssistantMessage, SystemMessage)
                *args (Any): Additional positional arguments
                **kwargs (Any): Additional keyword arguments

            Returns:
                FunctionArgumentWrapper: Wrapper containing modified args and kwargs
            """
            system_message = SystemMessage(content=self.system_prompt)
            new_messages = [system_message] + messages
            return FunctionArgumentWrapper(new_messages, *args, **kwargs)

    # Apply retry mixin if configured
    if isinstance(llm_config, RetryMixin):
        client = patch_with_retry(client,
                                  retries=llm_config.num_retries,
                                  retry_codes=llm_config.retry_on_status_codes,
                                  retry_on_messages=llm_config.retry_on_errors)

    # Apply thinking mixin if configured
    if isinstance(llm_config, ThinkingMixin) and llm_config.thinking_system_prompt is not None:
        client = patch_with_thinking(
            client,
            AutoGenThinkingInjector(system_prompt=llm_config.thinking_system_prompt,
                                    function_names=[
                                        "create",
                                        "create_stream",
                                    ]))

    return client
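
# Illustrative sketch (not part of this module): what AutoGenThinkingInjector.inject
# does to a "create"/"create_stream" call. The prompt text and message list below are
# placeholders; the real values come from ThinkingMixin.thinking_system_prompt and the
# caller's messages.
from autogen_core.models import SystemMessage as _ExampleSystemMessage

_example_messages: list = []  # whatever the agent was about to send
_example_injected = [_ExampleSystemMessage(content="Reason step by step.")] + _example_messages
# `_example_injected` is the message list the patched client actually receives.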


async def _close_autogen_client(client: Any) -> None:
    """Close an AutoGen client if it has a close method.

    Args:
        client: The AutoGen client to close
    """
    try:
        if hasattr(client, "close"):
            await client.close()
        elif hasattr(client, "_client") and hasattr(client._client, "close"):
            await client._client.close()
    except Exception:
        logger.debug("Error closing AutoGen client", exc_info=True)


@register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
async def openai_autogen(llm_config: OpenAIModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
    """Create OpenAI client for AutoGen integration.

    Args:
        llm_config (OpenAIModelConfig): OpenAI model configuration
        _builder (Builder): NAT builder instance

    Yields:
        AsyncGenerator[ModelType, None]: Configured AutoGen OpenAI client
    """
    from autogen_core.models import ModelFamily
    from autogen_core.models import ModelInfo
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    # Extract AutoGen-compatible configuration
    config_obj = {
        **llm_config.model_dump(
            exclude={"type", "model_name", "thinking"},
            by_alias=True,
            exclude_none=True,
        ),
    }

    # Define model info for AutoGen 0.7.4 (replaces model_capabilities)
    model_info = ModelInfo(vision=False,
                           function_calling=True,
                           json_output=True,
                           family=ModelFamily.UNKNOWN,
                           structured_output=True,
                           multiple_system_messages=True)

    # Add required AutoGen 0.7.4 parameters
    config_obj.update({"model_info": model_info})
    config_obj.pop("model", None)

    # Create AutoGen OpenAI client
    client = OpenAIChatCompletionClient(model=llm_config.model_name, **config_obj)

    try:
        # Apply NAT mixins and yield patched client
        yield _patch_autogen_client_based_on_config(client, llm_config)
    finally:
        await _close_autogen_client(client)


@register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
async def azure_openai_autogen(llm_config: AzureOpenAIModelConfig,
                               _builder: Builder) -> AsyncGenerator[ModelType, None]:
    """Create Azure OpenAI client for AutoGen integration.

    Args:
        llm_config (AzureOpenAIModelConfig): Azure OpenAI model configuration
        _builder (Builder): NAT builder instance

    Yields:
        AsyncGenerator[ModelType, None]: Configured AutoGen Azure OpenAI client
    """
    from autogen_core.models import ModelFamily
    from autogen_core.models import ModelInfo
    from autogen_ext.models.openai import AzureOpenAIChatCompletionClient

    config_obj = {
        "api_key":
            llm_config.api_key,
        "base_url":
            f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}",
        "api_version":
            llm_config.api_version,
        **llm_config.model_dump(
            exclude={"type", "azure_deployment", "thinking", "azure_endpoint", "api_version"},
            by_alias=True,
            exclude_none=True,
        ),
    }

    model_info = ModelInfo(vision=False,
                           function_calling=True,
                           json_output=True,
                           family=ModelFamily.UNKNOWN,
                           structured_output=True,
                           multiple_system_messages=True)

    config_obj.update({"model_info": model_info})

    client = AzureOpenAIChatCompletionClient(
        model=llm_config.azure_deployment,  # Use deployment name for Azure
        **config_obj)

    try:
        # Apply NAT mixins and yield patched client
        yield _patch_autogen_client_based_on_config(client, llm_config)
    finally:
        await _close_autogen_client(client)
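
# Illustrative example (not part of this module): how the Azure base_url above is
# assembled, using hypothetical config values.
_example_azure_endpoint = "https://my-resource.openai.azure.com"
_example_azure_deployment = "gpt-4o"
_example_base_url = f"{_example_azure_endpoint}/openai/deployments/{_example_azure_deployment}"
# -> "https://my-resource.openai.azure.com/openai/deployments/gpt-4o",
# and the client is created with model="gpt-4o", the deployment name.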


def _strip_strict_from_tools_deep(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Remove 'strict' field from tool definitions in request kwargs for NIM compatibility.

    NIM's API doesn't support OpenAI's 'strict' parameter in tool/function definitions.
    AutoGen adds this field automatically, so we strip it before sending to NIM.

    Args:
        kwargs: The request keyword arguments dictionary

    Returns:
        kwargs with 'strict' field removed from tool function definitions
    """
    tools = kwargs.get("tools")

    # Handle NotGiven sentinel or None - just return unchanged
    if tools is None or not isinstance(tools, list | tuple):
        return kwargs

    kwargs = kwargs.copy()
    cleaned_tools = []
    for tool in tools:
        if isinstance(tool, dict):
            tool_copy = tool.copy()
            if "function" in tool_copy and isinstance(tool_copy["function"], dict):
                func_copy = tool_copy["function"].copy()
                func_copy.pop("strict", None)
                tool_copy["function"] = func_copy
            cleaned_tools.append(tool_copy)
        else:
            cleaned_tools.append(tool)
    kwargs["tools"] = cleaned_tools
    return kwargs
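
# Illustrative example (not part of this module): _strip_strict_from_tools_deep only
# drops the "strict" key inside each tool's "function" block; everything else passes
# through untouched, and the caller's dict is not mutated because copies are made above.
_example_request = {
    "model": "example-model",  # hypothetical request payload
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {"type": "object", "properties": {}},
            "strict": True,
        },
    }],
}
_example_cleaned = _strip_strict_from_tools_deep(_example_request)
# _example_cleaned["tools"][0]["function"] no longer contains "strict",
# while _example_request still does.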


def _patch_nim_client_for_tools(client: ModelType) -> ModelType:
    """Patch AutoGen client's underlying OpenAI client to strip 'strict' from tools for NIM.

    This patches at the lowest level (the actual OpenAI AsyncClient) to ensure
    the 'strict' field is removed after AutoGen's internal processing.

    Args:
        client: The AutoGen OpenAI client to patch

    Returns:
        The patched client (unmodified if patching fails)
    """
    try:
        # Access the underlying OpenAI AsyncClient (protected member)
        openai_client = getattr(client, "_client", None)
        if openai_client is None:
            logger.warning("Unable to patch NIM client for tools - _client attribute not found")
            return client

        # Verify the expected structure exists
        if not hasattr(openai_client, "chat") or not hasattr(openai_client.chat, "completions"):
            logger.warning("Unable to patch NIM client for tools - unexpected client structure")
            return client

        # Patch the chat.completions.create method
        original_create = openai_client.chat.completions.create

        async def patched_create(*args: Any, **kwargs: Any) -> Any:
            # Strip 'strict' from tools before sending to NIM
            kwargs = _strip_strict_from_tools_deep(kwargs)
            return await original_create(*args, **kwargs)

        openai_client.chat.completions.create = patched_create
        return client

    except AttributeError as e:
        logger.warning("Unable to patch NIM client for tools - AutoGen internal structure changed: %s", e)
        return client


@register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
async def nim_autogen(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
    """Create NVIDIA NIM client for AutoGen integration.

    Args:
        llm_config (NIMModelConfig): NIM model configuration
        _builder (Builder): NAT builder instance

    Yields:
        Configured AutoGen NIM client (via OpenAI compatibility)
    """
    from autogen_core.models import ModelFamily
    from autogen_core.models import ModelInfo
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    # Extract NIM configuration for OpenAI-compatible client
    config_obj = {
        **llm_config.model_dump(
            exclude={"type", "model_name", "thinking"},
            by_alias=True,
            exclude_none=True,
        ),
    }

    # Define model info for AutoGen 0.7.4 (replaces model_capabilities)
    # Note: structured_output=False because NIM doesn't support OpenAI's 'strict' parameter
    model_info = ModelInfo(vision=False,
                           function_calling=True,
                           json_output=True,
                           family=ModelFamily.UNKNOWN,
                           structured_output=False,
                           multiple_system_messages=True)

    # Add required AutoGen 0.7.4 parameters
    config_obj.update({"model_info": model_info})
    config_obj.pop("model", None)

    # NIM uses OpenAI-compatible API
    client = OpenAIChatCompletionClient(model=llm_config.model_name, **config_obj)

    # Patch to remove 'strict' field from tools (NIM doesn't support it)
    client = _patch_nim_client_for_tools(client)

    try:
        # Apply NAT mixins and yield patched client
        yield _patch_autogen_client_based_on_config(client, llm_config)
    finally:
        await _close_autogen_client(client)


@register_llm_client(config_type=LiteLlmModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
async def litellm_autogen(llm_config: LiteLlmModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
    """Create LiteLLM client for AutoGen integration.

    LiteLLM provides a unified interface to multiple LLM providers. This integration
    uses AutoGen's OpenAI-compatible client since LiteLLM exposes an OpenAI-compatible
    API endpoint.

    Args:
        llm_config (LiteLlmModelConfig): LiteLLM model configuration
        _builder (Builder): NAT builder instance

    Yields:
        AsyncGenerator[ModelType, None]: Configured AutoGen client via LiteLLM
    """
    from autogen_core.models import ModelFamily
    from autogen_core.models import ModelInfo
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    # Extract LiteLLM configuration for OpenAI-compatible client
    config_obj = {
        **llm_config.model_dump(
            exclude={"type", "model_name", "thinking"},
            by_alias=True,
            exclude_none=True,
        ),
    }

    # Resolve API key from secret if provided
    if llm_config.api_key is not None:
        config_obj["api_key"] = get_secret_value(llm_config.api_key)

    # Define model info for AutoGen
    model_info = ModelInfo(vision=False,
                           function_calling=True,
                           json_output=True,
                           family=ModelFamily.UNKNOWN,
                           structured_output=True,
                           multiple_system_messages=True)

    config_obj.update({"model_info": model_info})
    config_obj.pop("model", None)

    # LiteLLM uses OpenAI-compatible API
    client = OpenAIChatCompletionClient(model=llm_config.model_name, **config_obj)

    try:
        # Apply NAT mixins and yield patched client
        yield _patch_autogen_client_based_on_config(client, llm_config)
    finally:
        await _close_autogen_client(client)


@register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
async def bedrock_autogen(llm_config: AWSBedrockModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
    """Create AWS Bedrock client for AutoGen integration.

    Uses AutoGen's ``AnthropicBedrockChatCompletionClient`` which supports
    Anthropic Claude models hosted on AWS Bedrock. Credentials are loaded in
    the following priority:

    1. Explicit values from ``credentials_profile_name`` in the AWS profile.
    2. Standard environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
       ``AWS_SESSION_TOKEN``).
    3. Ambient credentials provided by the compute environment (IAM role).

    Args:
        llm_config (AWSBedrockModelConfig): AWS Bedrock model configuration
        _builder (Builder): NAT builder instance

    Yields:
        AsyncGenerator[ModelType, None]: Configured AutoGen Bedrock client
    """
    from autogen_ext.models.anthropic import AnthropicBedrockChatCompletionClient

    # Build Bedrock-specific configuration
    bedrock_config: dict[str, Any] = {
        "model": llm_config.model_name,
    }

    # Handle region - None or "None" string should use AWS default
    if llm_config.region_name not in (None, "None"):
        bedrock_config["aws_region"] = llm_config.region_name

    # Add optional parameters if provided
    if llm_config.credentials_profile_name is not None:
        bedrock_config["aws_profile"] = llm_config.credentials_profile_name

    if llm_config.base_url is not None:
        bedrock_config["base_url"] = llm_config.base_url

    # Add model parameters
    if llm_config.max_tokens is not None:
        bedrock_config["max_tokens"] = llm_config.max_tokens

    if llm_config.temperature is not None:
        bedrock_config["temperature"] = llm_config.temperature

    if llm_config.top_p is not None:
        bedrock_config["top_p"] = llm_config.top_p

    # Create AutoGen Bedrock client
    client = AnthropicBedrockChatCompletionClient(**bedrock_config)

    try:
        # Apply NAT mixins and yield patched client
        yield _patch_autogen_client_based_on_config(client, llm_config)
    finally:
        await _close_autogen_client(client)
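The registrations above are keyed on a NAT LLM config type plus LLMFrameworkEnum.AUTOGEN, so a workflow component can ask the NAT builder for an AutoGen-ready client with retry and thinking patches already applied. A minimal sketch of that call, assuming the Builder exposes the get_llm accessor used by NAT's other framework plugins; the LLM name "my_nim" and the accessor signature are assumptions, not shown in this diff:

from nat.builder.framework_enum import LLMFrameworkEnum


async def get_autogen_model_client(builder):
    # Assumed accessor: resolves the LLM named "my_nim" from the workflow config and
    # returns the client produced by nim_autogen() above (retry/thinking already patched).
    return await builder.get_llm("my_nim", wrapper_type=LLMFrameworkEnum.AUTOGEN)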
+++ nat/plugins/autogen/register.py
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# isort:skip_file
"""AutoGen plugin registration for NAT components."""

from . import llm
from . import tool_wrapper
from . import callback_handler
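register.py only needs to import its sibling modules: the @register_llm_client and @register_tool_wrapper decorators run at import time, which is what makes the AutoGen wrappers visible to NAT. The plugin is normally pulled in through the wheel's entry point, but an explicit import has the same effect, as in this sketch:

# Importing the registration module executes the registration decorators above.
import nat.plugins.autogen.register  # noqa: F401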
+++ nat/plugins/autogen/tool_wrapper.py
@@ -0,0 +1,181 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool wrapper for AutoGen integration with NAT."""

import logging
from collections.abc import AsyncIterator
from collections.abc import Callable
from dataclasses import is_dataclass

# PythonType not available in AutoGen 0.7.4, using Any instead
from typing import Any

from autogen_core.tools import FunctionTool
from pydantic import BaseModel
from pydantic.dataclasses import dataclass as pydantic_dataclass

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.builder.function import Function
from nat.cli.register_workflow import register_tool_wrapper
from nat.utils.type_utils import DecomposedType

logger = logging.getLogger(__name__)


def resolve_type(t: Any) -> Any:
    """Return the non-None member of a Union/PEP 604 union;
    otherwise return the type unchanged.

    Args:
        t (Any): The type to resolve.

    Returns:
        Any: The resolved type.
    """
    resolved = DecomposedType(t)
    if resolved.is_optional:
        return resolved.get_optional_type().type
    return resolved.type
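
# Illustrative example (not part of this module): resolve_type unwraps Optional /
# PEP 604 unions so AutoGen sees the concrete annotation, and leaves other types alone.
_example_unwrapped = resolve_type(int | None)  # -> int
_example_unchanged = resolve_type(str)         # -> str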


@register_tool_wrapper(wrapper_type=LLMFrameworkEnum.AUTOGEN)
def autogen_tool_wrapper(
        name: str,
        fn: Function,
        _builder: Builder  # pylint: disable=W0613
) -> Any:  # Changed from Callable[..., Any] to Any to allow FunctionTool return
    """Wrap a NAT `Function` as an AutoGen `FunctionTool`.

    Args:
        name (str): The name of the tool.
        fn (Function): The NAT function to wrap.
        _builder (Builder): The NAT workflow builder to access registered components.

    Returns:
        Any: The AutoGen FunctionTool wrapping the NAT function.
    """

    import inspect

    async def callable_ainvoke(*args: Any, **kwargs: Any) -> Any:
        """Async function to invoke the NAT function.

        Args:
            *args: Positional arguments to pass to the NAT function.
            **kwargs: Keyword arguments to pass to the NAT function.
        Returns:
            Any: The result of invoking the NAT function.
        """
        return await fn.acall_invoke(*args, **kwargs)

    async def callable_astream(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        """Async generator to stream results from the NAT function.

        Args:
            *args (Any): Positional arguments to pass to the NAT function.
            **kwargs (Any): Keyword arguments to pass to the NAT function.
        Yields:
            Any: Streamed items from the NAT function.
        """
        async for item in fn.acall_stream(*args, **kwargs):
            yield item

    def nat_function(
            func: Callable[..., Any] | None = None,
            *,
            name: str = name,
            description: str | None = fn.description,
            input_schema: Any = fn.input_schema,
    ) -> Callable[..., Any]:
        """
        Decorator to wrap a function as a NAT function.

        Args:
            func (Callable): The function to wrap.
            name (str): The name of the function.
            description (str): The description of the function.
            input_schema (BaseModel): The Pydantic model defining the input schema.

        Returns:
            Callable[..., Any]: The wrapped function.
        """
        if func is None:
            raise ValueError("'func' must be provided.")

        # If input_schema is a dataclass, convert it to a Pydantic model
        if input_schema is not None and is_dataclass(input_schema):
            input_schema = pydantic_dataclass(input_schema)

        def decorator(func_to_wrap: Callable[..., Any]) -> Callable[..., Any]:
            """
            Decorator to set metadata on the function.
            """
            # Set the function's metadata
            if name is not None:
                func_to_wrap.__name__ = name
            if description is not None:
                func_to_wrap.__doc__ = description

            # Set signature only if input_schema is provided
            params: list[inspect.Parameter] = []
            annotations: dict[str, Any] = {}

            if input_schema is not None:
                annotations = {}
                params = []
                model_fields = getattr(input_schema, "model_fields", {})
                for param_name, model_field in model_fields.items():
                    resolved_type = resolve_type(model_field.annotation)

                    # Warn about nested Pydantic models or dataclasses that may not serialize properly
                    # Note: If autogen is updated to support nested models, this warning can be removed - or
                    # if autogen adds a mechanism to remove the tool from the function choices we can add that later.
                    if isinstance(resolved_type, type) and (issubclass(resolved_type, BaseModel)
                                                            or is_dataclass(resolved_type)):
                        logger.warning(
                            "Nested model detected in input schema for parameter '%s' in tool '%s'. "
                            "AutoGen may not properly serialize complex nested types for function calling. "
                            "Consider flattening the schema or using primitive types.",
                            param_name,
                            name,
                        )

                    default = inspect.Parameter.empty if model_field.is_required() else model_field.default
                    params.append(
                        inspect.Parameter(param_name,
                                          inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                          annotation=resolved_type,
                                          default=default))
                    annotations[param_name] = resolved_type
            func_to_wrap.__signature__ = inspect.Signature(parameters=params)
            func_to_wrap.__annotations__ = annotations

            return func_to_wrap

        # Apply the decorator to the provided function
        return decorator(func)

    if fn.has_streaming_output and not fn.has_single_output:
        logger.debug("Creating streaming FunctionTool for: %s", name)
        callable_tool = nat_function(func=callable_astream)
    else:
        logger.debug("Creating non-streaming FunctionTool for: %s", name)
        callable_tool = nat_function(func=callable_ainvoke)
    return FunctionTool(
        func=callable_tool,
        name=name,
        description=fn.description or "No description provided.",
    )
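The reason nat_function rebuilds __signature__ and __annotations__ is that AutoGen's FunctionTool derives a tool's argument schema from the wrapped callable's signature rather than from a Pydantic model. A minimal sketch of that mechanic, using a hypothetical input schema (ExampleInput and example_tool are illustrations, not part of the package):

import inspect

from pydantic import BaseModel


class ExampleInput(BaseModel):
    """Hypothetical input schema for a NAT function."""
    city: str
    units: str = "metric"


def example_tool(*args, **kwargs):
    """Stand-in for callable_ainvoke / callable_astream."""


# Mirror of the loop in nat_function's decorator: one positional-or-keyword
# parameter per model field, with defaults preserved for optional fields.
params = []
for field_name, field in ExampleInput.model_fields.items():
    default = inspect.Parameter.empty if field.is_required() else field.default
    params.append(
        inspect.Parameter(field_name,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD,
                          annotation=field.annotation,
                          default=default))

example_tool.__signature__ = inspect.Signature(parameters=params)
example_tool.__annotations__ = {n: f.annotation for n, f in ExampleInput.model_fields.items()}
# FunctionTool(func=example_tool, name="example", description="...") would now
# expose the parameters (city: str, units: str = "metric") to the model.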