nvidia-nat-autogen 1.4.0a20260120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,459 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """AutoGen LLM client registrations for NAT.
+
+ This module provides AutoGen-compatible LLM client wrappers for the following providers:
+
+ Supported Providers
+ -------------------
+ - **OpenAI**: Direct OpenAI API integration via ``OpenAIChatCompletionClient``
+ - **Azure OpenAI**: Azure-hosted OpenAI models via ``AzureOpenAIChatCompletionClient``
+ - **NVIDIA NIM**: OpenAI-compatible endpoints for NVIDIA models
+ - **LiteLLM**: Unified interface to multiple LLM providers via an OpenAI-compatible client
+ - **AWS Bedrock**: Amazon Bedrock models (Claude/Anthropic) via ``AnthropicBedrockChatCompletionClient``
+
+ Each wrapper:
+ - Patches clients with NAT retry logic from ``RetryMixin``
+ - Injects chain-of-thought prompts when ``ThinkingMixin`` is configured
+ - Removes NAT-specific config keys before instantiating AutoGen clients
+ """
+
+ import logging
+ from collections.abc import AsyncGenerator
+ from typing import Any
+ from typing import TypeVar
+
+ from nat.builder.builder import Builder
+ from nat.builder.framework_enum import LLMFrameworkEnum
+ from nat.cli.register_workflow import register_llm_client
+ from nat.data_models.common import get_secret_value
+ from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.retry_mixin import RetryMixin
+ from nat.data_models.thinking_mixin import ThinkingMixin
+ from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig
+ from nat.llm.azure_openai_llm import AzureOpenAIModelConfig
+ from nat.llm.litellm_llm import LiteLlmModelConfig
+ from nat.llm.nim_llm import NIMModelConfig
+ from nat.llm.openai_llm import OpenAIModelConfig
+ from nat.llm.utils.thinking import BaseThinkingInjector
+ from nat.llm.utils.thinking import FunctionArgumentWrapper
+ from nat.llm.utils.thinking import patch_with_thinking
+ from nat.utils.exception_handlers.automatic_retries import patch_with_retry
+ from nat.utils.type_utils import override
+
+ logger = logging.getLogger(__name__)
+
+ ModelType = TypeVar("ModelType")
+
+
+ def _patch_autogen_client_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
+     """Patch AutoGen client with NAT mixins (retry, thinking).
+
+     Args:
+         client (ModelType): The AutoGen LLM client to patch.
+         llm_config (LLMBaseConfig): The LLM configuration containing mixin settings.
+
+     Returns:
+         ModelType: The patched AutoGen LLM client.
+     """
+     from autogen_core.models import SystemMessage
+
+     class AutoGenThinkingInjector(BaseThinkingInjector):
+         """Thinking injector for AutoGen message format.
+
+         Injects a system message at the start of the message list to enable
+         chain-of-thought prompting for supported models (e.g., Nemotron).
+         """
+
+         @override
+         def inject(self, messages: list, *args: Any, **kwargs: Any) -> FunctionArgumentWrapper:
+             """Inject thinking system prompt into AutoGen messages.
+
+             Args:
+                 messages (list): List of AutoGen messages (UserMessage, AssistantMessage, SystemMessage)
+                 *args (Any): Additional positional arguments
+                 **kwargs (Any): Additional keyword arguments
+
+             Returns:
+                 FunctionArgumentWrapper: Wrapper containing modified args and kwargs
+             """
+             system_message = SystemMessage(content=self.system_prompt)
+             new_messages = [system_message] + messages
+             return FunctionArgumentWrapper(new_messages, *args, **kwargs)
+
+     # Apply retry mixin if configured
+     if isinstance(llm_config, RetryMixin):
+         client = patch_with_retry(client,
+                                   retries=llm_config.num_retries,
+                                   retry_codes=llm_config.retry_on_status_codes,
+                                   retry_on_messages=llm_config.retry_on_errors)
+
+     # Apply thinking mixin if configured
+     if isinstance(llm_config, ThinkingMixin) and llm_config.thinking_system_prompt is not None:
+         client = patch_with_thinking(
+             client,
+             AutoGenThinkingInjector(system_prompt=llm_config.thinking_system_prompt,
+                                     function_names=[
+                                         "create",
+                                         "create_stream",
+                                     ]))
+
+     return client
+
+
+ async def _close_autogen_client(client: Any) -> None:
+     """Close an AutoGen client if it has a close method.
+
+     Args:
+         client: The AutoGen client to close
+     """
+     try:
+         if hasattr(client, "close"):
+             await client.close()
+         elif hasattr(client, "_client") and hasattr(client._client, "close"):
+             await client._client.close()
+     except Exception:
+         logger.debug("Error closing AutoGen client", exc_info=True)
+
+
+ @register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ async def openai_autogen(llm_config: OpenAIModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
+     """Create OpenAI client for AutoGen integration.
+
+     Args:
+         llm_config (OpenAIModelConfig): OpenAI model configuration
+         _builder (Builder): NAT builder instance
+
+     Yields:
+         AsyncGenerator[ModelType, None]: Configured AutoGen OpenAI client
+     """
+     from autogen_core.models import ModelFamily
+     from autogen_core.models import ModelInfo
+     from autogen_ext.models.openai import OpenAIChatCompletionClient
+
+     # Extract AutoGen-compatible configuration
+     config_obj = {
+         **llm_config.model_dump(
+             exclude={"type", "model_name", "thinking"},
+             by_alias=True,
+             exclude_none=True,
+         ),
+     }
+
+     # Define model info for AutoGen 0.7.4 (replaces model_capabilities)
+     model_info = ModelInfo(vision=False,
+                            function_calling=True,
+                            json_output=True,
+                            family=ModelFamily.UNKNOWN,
+                            structured_output=True,
+                            multiple_system_messages=True)
+
+     # Add required AutoGen 0.7.4 parameters
+     config_obj.update({"model_info": model_info})
+     config_obj.pop("model", None)
+
+     # Create AutoGen OpenAI client
+     client = OpenAIChatCompletionClient(model=llm_config.model_name, **config_obj)
+
+     try:
+         # Apply NAT mixins and yield patched client
+         yield _patch_autogen_client_based_on_config(client, llm_config)
+     finally:
+         await _close_autogen_client(client)
+
+
+ @register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ async def azure_openai_autogen(llm_config: AzureOpenAIModelConfig,
+                                _builder: Builder) -> AsyncGenerator[ModelType, None]:
+     """Create Azure OpenAI client for AutoGen integration.
+
+     Args:
+         llm_config (AzureOpenAIModelConfig): Azure OpenAI model configuration
+         _builder (Builder): NAT builder instance
+
+     Yields:
+         AsyncGenerator[ModelType, None]: Configured AutoGen Azure OpenAI client
+     """
+     from autogen_core.models import ModelFamily
+     from autogen_core.models import ModelInfo
+     from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
+
+     config_obj = {
+         "api_key": llm_config.api_key,
+         "base_url": f"{llm_config.azure_endpoint}/openai/deployments/{llm_config.azure_deployment}",
+         "api_version": llm_config.api_version,
+         **llm_config.model_dump(
+             exclude={"type", "azure_deployment", "thinking", "azure_endpoint", "api_version"},
+             by_alias=True,
+             exclude_none=True,
+         ),
+     }
+
+     model_info = ModelInfo(vision=False,
+                            function_calling=True,
+                            json_output=True,
+                            family=ModelFamily.UNKNOWN,
+                            structured_output=True,
+                            multiple_system_messages=True)
+
+     config_obj.update({"model_info": model_info})
+
+     client = AzureOpenAIChatCompletionClient(
+         model=llm_config.azure_deployment,  # Use deployment name for Azure
+         **config_obj)
+
+     try:
+         # Apply NAT mixins and yield patched client
+         yield _patch_autogen_client_based_on_config(client, llm_config)
+     finally:
+         await _close_autogen_client(client)
+
+
+ def _strip_strict_from_tools_deep(kwargs: dict[str, Any]) -> dict[str, Any]:
+     """Remove 'strict' field from tool definitions in request kwargs for NIM compatibility.
+
+     NIM's API doesn't support OpenAI's 'strict' parameter in tool/function definitions.
+     AutoGen adds this field automatically, so we strip it before sending to NIM.
+
+     Args:
+         kwargs: The request keyword arguments dictionary
+
+     Returns:
+         kwargs with 'strict' field removed from tool function definitions
+     """
+     tools = kwargs.get("tools")
+
+     # Handle NotGiven sentinel or None - just return unchanged
+     if tools is None or not isinstance(tools, list | tuple):
+         return kwargs
+
+     kwargs = kwargs.copy()
+     cleaned_tools = []
+     for tool in tools:
+         if isinstance(tool, dict):
+             tool_copy = tool.copy()
+             if "function" in tool_copy and isinstance(tool_copy["function"], dict):
+                 func_copy = tool_copy["function"].copy()
+                 func_copy.pop("strict", None)
+                 tool_copy["function"] = func_copy
+             cleaned_tools.append(tool_copy)
+         else:
+             cleaned_tools.append(tool)
+     kwargs["tools"] = cleaned_tools
+     return kwargs
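+
+ # Illustrative example for the helper above: an OpenAI-style tool definition
+ # such as
+ #
+ #     {"type": "function",
+ #      "function": {"name": "add", "parameters": {...}, "strict": True}}
+ #
+ # in ``kwargs["tools"]`` comes back as a copy with the ``"strict"`` key dropped
+ # from the ``"function"`` entry; other keys and non-dict tools are left as-is.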
+
+
+ def _patch_nim_client_for_tools(client: ModelType) -> ModelType:
+     """Patch AutoGen client's underlying OpenAI client to strip 'strict' from tools for NIM.
+
+     This patches at the lowest level (the actual OpenAI AsyncClient) to ensure
+     the 'strict' field is removed after AutoGen's internal processing.
+
+     Args:
+         client: The AutoGen OpenAI client to patch
+
+     Returns:
+         The patched client (unmodified if patching fails)
+     """
+     try:
+         # Access the underlying OpenAI AsyncClient (protected member)
+         openai_client = getattr(client, "_client", None)
+         if openai_client is None:
+             logger.warning("Unable to patch NIM client for tools - _client attribute not found")
+             return client
+
+         # Verify the expected structure exists
+         if not hasattr(openai_client, "chat") or not hasattr(openai_client.chat, "completions"):
+             logger.warning("Unable to patch NIM client for tools - unexpected client structure")
+             return client
+
+         # Patch the chat.completions.create method
+         original_create = openai_client.chat.completions.create
+
+         async def patched_create(*args: Any, **kwargs: Any) -> Any:
+             # Strip 'strict' from tools before sending to NIM
+             kwargs = _strip_strict_from_tools_deep(kwargs)
+             return await original_create(*args, **kwargs)
+
+         openai_client.chat.completions.create = patched_create
+         return client
+
+     except AttributeError as e:
+         logger.warning("Unable to patch NIM client for tools - AutoGen internal structure changed: %s", e)
+         return client
+
+
+ @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ async def nim_autogen(llm_config: NIMModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
+     """Create NVIDIA NIM client for AutoGen integration.
+
+     Args:
+         llm_config (NIMModelConfig): NIM model configuration
+         _builder (Builder): NAT builder instance
+
+     Yields:
+         Configured AutoGen NIM client (via OpenAI compatibility)
+     """
+     from autogen_core.models import ModelFamily
+     from autogen_core.models import ModelInfo
+     from autogen_ext.models.openai import OpenAIChatCompletionClient
+
+     # Extract NIM configuration for OpenAI-compatible client
+     config_obj = {
+         **llm_config.model_dump(
+             exclude={"type", "model_name", "thinking"},
+             by_alias=True,
+             exclude_none=True,
+         ),
+     }
+
+     # Define model info for AutoGen 0.7.4 (replaces model_capabilities)
+     # Note: structured_output=False because NIM doesn't support OpenAI's 'strict' parameter
+     model_info = ModelInfo(vision=False,
+                            function_calling=True,
+                            json_output=True,
+                            family=ModelFamily.UNKNOWN,
+                            structured_output=False,
+                            multiple_system_messages=True)
+
+     # Add required AutoGen 0.7.4 parameters
+     config_obj.update({"model_info": model_info})
+     config_obj.pop("model", None)
+
+     # NIM uses OpenAI-compatible API
+     client = OpenAIChatCompletionClient(model=llm_config.model_name, **config_obj)
+
+     # Patch to remove 'strict' field from tools (NIM doesn't support it)
+     client = _patch_nim_client_for_tools(client)
+
+     try:
+         # Apply NAT mixins and yield patched client
+         yield _patch_autogen_client_based_on_config(client, llm_config)
+     finally:
+         await _close_autogen_client(client)
+
+
+ @register_llm_client(config_type=LiteLlmModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ async def litellm_autogen(llm_config: LiteLlmModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
+     """Create LiteLLM client for AutoGen integration.
+
+     LiteLLM provides a unified interface to multiple LLM providers. This integration
+     uses AutoGen's OpenAI-compatible client since LiteLLM exposes an OpenAI-compatible
+     API endpoint.
+
+     Args:
+         llm_config (LiteLlmModelConfig): LiteLLM model configuration
+         _builder (Builder): NAT builder instance
+
+     Yields:
+         AsyncGenerator[ModelType, None]: Configured AutoGen client via LiteLLM
+     """
+     from autogen_core.models import ModelFamily
+     from autogen_core.models import ModelInfo
+     from autogen_ext.models.openai import OpenAIChatCompletionClient
+
+     # Extract LiteLLM configuration for OpenAI-compatible client
+     config_obj = {
+         **llm_config.model_dump(
+             exclude={"type", "model_name", "thinking"},
+             by_alias=True,
+             exclude_none=True,
+         ),
+     }
+
+     # Resolve API key from secret if provided
+     if llm_config.api_key is not None:
+         config_obj["api_key"] = get_secret_value(llm_config.api_key)
+
+     # Define model info for AutoGen
+     model_info = ModelInfo(vision=False,
+                            function_calling=True,
+                            json_output=True,
+                            family=ModelFamily.UNKNOWN,
+                            structured_output=True,
+                            multiple_system_messages=True)
+
+     config_obj.update({"model_info": model_info})
+     config_obj.pop("model", None)
+
+     # LiteLLM uses OpenAI-compatible API
+     client = OpenAIChatCompletionClient(model=llm_config.model_name, **config_obj)
+
+     try:
+         # Apply NAT mixins and yield patched client
+         yield _patch_autogen_client_based_on_config(client, llm_config)
+     finally:
+         await _close_autogen_client(client)
+
+
+ @register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ async def bedrock_autogen(llm_config: AWSBedrockModelConfig, _builder: Builder) -> AsyncGenerator[ModelType, None]:
+     """Create AWS Bedrock client for AutoGen integration.
+
+     Uses AutoGen's ``AnthropicBedrockChatCompletionClient``, which supports
+     Anthropic Claude models hosted on AWS Bedrock. Credentials are resolved in
+     the following priority order:
+
+     1. The AWS profile named by ``credentials_profile_name``, if provided.
+     2. Standard environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
+        ``AWS_SESSION_TOKEN``).
+     3. Ambient credentials provided by the compute environment (e.g., an IAM role).
+
+     Args:
+         llm_config (AWSBedrockModelConfig): AWS Bedrock model configuration
+         _builder (Builder): NAT builder instance
+
+     Yields:
+         AsyncGenerator[ModelType, None]: Configured AutoGen Bedrock client
+     """
+     from autogen_ext.models.anthropic import AnthropicBedrockChatCompletionClient
+
+     # Build Bedrock-specific configuration
+     bedrock_config: dict[str, Any] = {
+         "model": llm_config.model_name,
+     }
+
+     # Handle region - None or the string "None" should fall back to the AWS default
+     if llm_config.region_name not in (None, "None"):
+         bedrock_config["aws_region"] = llm_config.region_name
+
+     # Add optional parameters if provided
+     if llm_config.credentials_profile_name is not None:
+         bedrock_config["aws_profile"] = llm_config.credentials_profile_name
+
+     if llm_config.base_url is not None:
+         bedrock_config["base_url"] = llm_config.base_url
+
+     # Add model parameters
+     if llm_config.max_tokens is not None:
+         bedrock_config["max_tokens"] = llm_config.max_tokens
+
+     if llm_config.temperature is not None:
+         bedrock_config["temperature"] = llm_config.temperature
+
+     if llm_config.top_p is not None:
+         bedrock_config["top_p"] = llm_config.top_p
+
+     # Create AutoGen Bedrock client
+     client = AnthropicBedrockChatCompletionClient(**bedrock_config)
+
+     try:
+         # Apply NAT mixins and yield patched client
+         yield _patch_autogen_client_based_on_config(client, llm_config)
+     finally:
+         await _close_autogen_client(client)
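+
+
+ # Illustrative usage sketch: inside a NAT workflow function, the patched AutoGen
+ # client registered above is typically retrieved through the builder, e.g.
+ # ``await builder.get_llm("my_llm", wrapper_type=LLMFrameworkEnum.AUTOGEN)``.
+ # The name ``my_llm`` is hypothetical and the exact Builder accessor may differ
+ # between NAT releases; the returned object is one of the chat-completion
+ # clients constructed by the registrations in this module.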
@@ -0,0 +1,22 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort:skip_file
+ """AutoGen plugin registration for NAT components."""
+
+ from . import llm
+ from . import tool_wrapper
+ from . import callback_handler
@@ -0,0 +1,181 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tool wrapper for AutoGen integration with NAT."""
+
+ import logging
+ from collections.abc import AsyncIterator
+ from collections.abc import Callable
+ from dataclasses import is_dataclass
+
+ # PythonType not available in AutoGen 0.7.4, using Any instead
+ from typing import Any
+
+ from autogen_core.tools import FunctionTool
+ from pydantic import BaseModel
+ from pydantic.dataclasses import dataclass as pydantic_dataclass
+
+ from nat.builder.builder import Builder
+ from nat.builder.framework_enum import LLMFrameworkEnum
+ from nat.builder.function import Function
+ from nat.cli.register_workflow import register_tool_wrapper
+ from nat.utils.type_utils import DecomposedType
+
+ logger = logging.getLogger(__name__)
+
+
+ def resolve_type(t: Any) -> Any:
+     """Return the non-None member of a Union/PEP 604 union;
+     otherwise return the type unchanged.
+
+     Args:
+         t (Any): The type to resolve.
+
+     Returns:
+         Any: The resolved type.
+     """
+     resolved = DecomposedType(t)
+     if resolved.is_optional:
+         return resolved.get_optional_type().type
+     return resolved.type
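+
+ # Illustrative behaviour (a sketch, assuming DecomposedType treats ``T | None``
+ # as an optional):
+ #
+ #     resolve_type(int | None)  -> int
+ #     resolve_type(str)         -> str
+ #     resolve_type(list[int])   -> list[int]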
+
+
+ @register_tool_wrapper(wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ def autogen_tool_wrapper(
+         name: str,
+         fn: Function,
+         _builder: Builder  # pylint: disable=W0613
+ ) -> Any:  # Changed from Callable[..., Any] to Any to allow FunctionTool return
+     """Wrap a NAT `Function` as an AutoGen `FunctionTool`.
+
+     Args:
+         name (str): The name of the tool.
+         fn (Function): The NAT function to wrap.
+         _builder (Builder): The NAT workflow builder to access registered components.
+
+     Returns:
+         Any: The AutoGen FunctionTool wrapping the NAT function.
+     """
+
+     import inspect
+
+     async def callable_ainvoke(*args: Any, **kwargs: Any) -> Any:
+         """Async function to invoke the NAT function.
+
+         Args:
+             *args: Positional arguments to pass to the NAT function.
+             **kwargs: Keyword arguments to pass to the NAT function.
+
+         Returns:
+             Any: The result of invoking the NAT function.
+         """
+         return await fn.acall_invoke(*args, **kwargs)
+
+     async def callable_astream(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
+         """Async generator to stream results from the NAT function.
+
+         Args:
+             *args (Any): Positional arguments to pass to the NAT function.
+             **kwargs (Any): Keyword arguments to pass to the NAT function.
+
+         Yields:
+             Any: Streamed items from the NAT function.
+         """
+         async for item in fn.acall_stream(*args, **kwargs):
+             yield item
+
+     def nat_function(
+             func: Callable[..., Any] | None = None,
+             *,
+             name: str = name,
+             description: str | None = fn.description,
+             input_schema: Any = fn.input_schema,
+     ) -> Callable[..., Any]:
+         """Decorator to wrap a function as a NAT function.
+
+         Args:
+             func (Callable): The function to wrap.
+             name (str): The name of the function.
+             description (str): The description of the function.
+             input_schema (BaseModel): The Pydantic model defining the input schema.
+
+         Returns:
+             Callable[..., Any]: The wrapped function.
+         """
+         if func is None:
+             raise ValueError("'func' must be provided.")
+
+         # If input_schema is a dataclass, convert it to a Pydantic model
+         if input_schema is not None and is_dataclass(input_schema):
+             input_schema = pydantic_dataclass(input_schema)
+
+         def decorator(func_to_wrap: Callable[..., Any]) -> Callable[..., Any]:
+             """Decorator to set metadata on the function."""
+             # Set the function's metadata
+             if name is not None:
+                 func_to_wrap.__name__ = name
+             if description is not None:
+                 func_to_wrap.__doc__ = description
+
+             # Set signature only if input_schema is provided
+             params: list[inspect.Parameter] = []
+             annotations: dict[str, Any] = {}
+
+             if input_schema is not None:
+                 model_fields = getattr(input_schema, "model_fields", {})
+                 for param_name, model_field in model_fields.items():
+                     resolved_type = resolve_type(model_field.annotation)
+
+                     # Warn about nested Pydantic models or dataclasses that may not serialize properly
+                     # Note: If autogen is updated to support nested models, this warning can be removed - or
+                     # if autogen adds a mechanism to remove the tool from the function choices we can add that later.
+                     if isinstance(resolved_type, type) and (issubclass(resolved_type, BaseModel)
+                                                             or is_dataclass(resolved_type)):
+                         logger.warning(
+                             "Nested model detected in input schema for parameter '%s' in tool '%s'. "
+                             "AutoGen may not properly serialize complex nested types for function calling. "
+                             "Consider flattening the schema or using primitive types.",
+                             param_name,
+                             name,
+                         )
+
+                     default = inspect.Parameter.empty if model_field.is_required() else model_field.default
+                     params.append(
+                         inspect.Parameter(param_name,
+                                           inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                                           annotation=resolved_type,
+                                           default=default))
+                     annotations[param_name] = resolved_type
+                 func_to_wrap.__signature__ = inspect.Signature(parameters=params)
+                 func_to_wrap.__annotations__ = annotations
+
+             return func_to_wrap
+
+         # Apply the decorator to the provided function
+         return decorator(func)
+
+     if fn.has_streaming_output and not fn.has_single_output:
+         logger.debug("Creating streaming FunctionTool for: %s", name)
+         callable_tool = nat_function(func=callable_astream)
+     else:
+         logger.debug("Creating non-streaming FunctionTool for: %s", name)
+         callable_tool = nat_function(func=callable_ainvoke)
+     return FunctionTool(
+         func=callable_tool,
+         name=name,
+         description=fn.description or "No description provided.",
+     )
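+
+
+ # Illustrative usage sketch (names are hypothetical and the exact retrieval API
+ # may differ between NAT releases):
+ #
+ #     tool = builder.get_tool("my_tool", wrapper_type=LLMFrameworkEnum.AUTOGEN)
+ #     result = await tool.run_json({"query": "hello"}, cancellation_token)
+ #
+ # ``run_json`` is part of AutoGen's tool interface; the wrapped callable then
+ # forwards the call to the NAT ``Function`` via ``acall_invoke`` (single output)
+ # or ``acall_stream`` (streaming output).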