nvidia-nat 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +6 -6
- nat/agent/dual_node.py +7 -2
- nat/agent/react_agent/agent.py +6 -1
- nat/agent/react_agent/register.py +4 -0
- nat/agent/rewoo_agent/agent.py +7 -2
- nat/agent/rewoo_agent/register.py +5 -1
- nat/agent/tool_calling_agent/agent.py +6 -1
- nat/agent/tool_calling_agent/register.py +4 -0
- nat/builder/context.py +7 -2
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/entrypoint.py +3 -1
- nat/data_models/gated_field_mixin.py +12 -14
- nat/data_models/temperature_mixin.py +1 -1
- nat/data_models/thinking_mixin.py +68 -0
- nat/data_models/top_p_mixin.py +1 -1
- nat/llm/aws_bedrock_llm.py +10 -9
- nat/llm/azure_openai_llm.py +9 -1
- nat/llm/nim_llm.py +2 -1
- nat/llm/openai_llm.py +2 -1
- nat/llm/utils/thinking.py +215 -0
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/profiler/decorators/function_tracking.py +125 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/METADATA +3 -1
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/RECORD +31 -25
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/top_level.txt +0 -0
nat/llm/aws_bedrock_llm.py
CHANGED
@@ -23,9 +23,11 @@ from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
 from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
+from nat.data_models.top_p_mixin import TopPMixin
 
 
-class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, name="aws_bedrock"):
+class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="aws_bedrock"):
     """An AWS Bedrock llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=())
@@ -34,14 +36,13 @@ class AWSBedrockModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, name="a
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The model name for the hosted AWS Bedrock.")
-    max_tokens: int | None = Field(default=1024,
-
-
-
-
-
-
-                                   "This field is ONLY required when using AWS Bedrock with LlamaIndex.")
+    max_tokens: int | None = Field(default=1024, gt=0, description="Maximum number of tokens to generate.")
+    context_size: int | None = Field(
+        default=1024,
+        gt=0,
+        description="The maximum number of tokens available for input. This is only required for LlamaIndex. "
+        "This field is ignored for Langchain.",
+    )
 
     # Client parameters
     region_name: str | None = Field(default="None", description="AWS region to use.")
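A minimal sketch (not part of the diff) of constructing the updated AWS Bedrock config with the fields shown above; the model id is hypothetical, and any fields or defaults contributed by TopPMixin/ThinkingMixin are assumptions since their definitions are not included in this diff.

```python
from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig

# Field names below come from the hunk above; the model id is a placeholder.
config = AWSBedrockModelConfig(
    model_name="anthropic.claude-3-sonnet",  # hypothetical model id
    max_tokens=1024,      # maximum tokens to generate (must be > 0)
    context_size=4096,    # input budget; only consumed by LlamaIndex, ignored for Langchain
    region_name="us-east-1",
)
```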
nat/llm/azure_openai_llm.py
CHANGED
@@ -23,10 +23,18 @@ from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
 from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
 from nat.data_models.top_p_mixin import TopPMixin
 
 
-class AzureOpenAIModelConfig(
+class AzureOpenAIModelConfig(
+        LLMBaseConfig,
+        RetryMixin,
+        TemperatureMixin,
+        TopPMixin,
+        ThinkingMixin,
+        name="azure_openai",
+):
     """An Azure OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
nat/llm/nim_llm.py
CHANGED
@@ -24,10 +24,11 @@ from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
 from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
 from nat.data_models.top_p_mixin import TopPMixin
 
 
-class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="nim"):
+class NIMModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="nim"):
     """An NVIDIA Inference Microservice (NIM) llm provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=())
nat/llm/openai_llm.py
CHANGED
@@ -23,10 +23,11 @@ from nat.cli.register_workflow import register_llm_provider
 from nat.data_models.llm import LLMBaseConfig
 from nat.data_models.retry_mixin import RetryMixin
 from nat.data_models.temperature_mixin import TemperatureMixin
+from nat.data_models.thinking_mixin import ThinkingMixin
 from nat.data_models.top_p_mixin import TopPMixin
 
 
-class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, name="openai"):
+class OpenAIModelConfig(LLMBaseConfig, RetryMixin, TemperatureMixin, TopPMixin, ThinkingMixin, name="openai"):
     """An OpenAI LLM provider to be used with an LLM client."""
 
     model_config = ConfigDict(protected_namespaces=(), extra="allow")
nat/llm/utils/thinking.py
ADDED
@@ -0,0 +1,215 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import inspect
+import logging
+import types
+from abc import abstractmethod
+from collections.abc import AsyncGenerator
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import TypeVar
+
+ModelType = TypeVar("ModelType")
+MessagesType = TypeVar("MessagesType")
+
+logger = logging.getLogger(__name__)
+
+
+class FunctionArgumentWrapper:
+    """
+    Wrapper for the arguments and keyword arguments of a function.
+
+    The arguments and keyword arguments are stored in the args and kwargs attributes, respectively.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        """
+        Initialize the FunctionArgumentWrapper.
+
+        Args:
+            args: The arguments to the function.
+            kwargs: The keyword arguments to the function.
+        """
+        self.args = args
+        self.kwargs = kwargs
+
+    def __repr__(self) -> str:
+        return f"FunctionArgumentWrapper(args={self.args}, kwargs={self.kwargs})"
+
+
+@dataclass
+class BaseThinkingInjector:
+    """
+    Base class for thinking injectors.
+
+    Args:
+        system_prompt: The system prompt to inject.
+        function_names: The function names to inject the system prompt into.
+    """
+
+    system_prompt: str
+    function_names: list[str]
+
+    @abstractmethod
+    def inject(self, *args, **kwargs) -> FunctionArgumentWrapper:
+        """
+        Inject the system prompt into the arguments.
+
+        Args:
+            args: The arguments to inject the system prompt into.
+            kwargs: The keyword arguments to inject the system prompt into.
+
+        Returns:
+            FunctionArgumentWrapper: An object that contains the transformed args and kwargs.
+        """
+        pass
+
+
+def _make_thinking_decorator(injector: BaseThinkingInjector):
+
+    def decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
+
+        async def _call_async(obj: object, *call_args, **call_kwargs) -> Any:
+            new_args = injector.inject(*call_args, **call_kwargs)
+            return await fn(obj, *new_args.args, **new_args.kwargs)
+
+        async def _agen(obj: object, *call_args, **call_kwargs) -> AsyncGenerator[Any, None]:
+            new_args = injector.inject(*call_args, **call_kwargs)
+            async for item in fn(obj, *new_args.args, **new_args.kwargs):
+                yield item
+
+        def _gen(obj: object, *call_args, **call_kwargs) -> Iterable[Any]:
+            new_args = injector.inject(*call_args, **call_kwargs)
+            yield from fn(obj, *new_args.args, **new_args.kwargs)
+            return
+
+        def _sync(obj: object, *call_args, **call_kwargs) -> Any:
+            new_args = injector.inject(*call_args, **call_kwargs)
+            return fn(obj, *new_args.args, **new_args.kwargs)
+
+        # Decide which wrapper to return
+        if inspect.iscoroutinefunction(fn):
+            wrapper = _call_async
+        elif inspect.isasyncgenfunction(fn):
+            wrapper = _agen
+        elif inspect.isgeneratorfunction(fn):
+            wrapper = _gen
+        else:
+            wrapper = _sync
+
+        return functools.wraps(fn)(wrapper)
+
+    return decorate
+
+
+def patch_with_thinking(obj: ModelType, injector: BaseThinkingInjector) -> ModelType:
+    """
+    Patch the given object with a decorator that injects a system prompt into the supplied messages.
+    There is an assumption that the first non-object argument is the messages.
+
+    Args:
+        obj: The object to patch.
+        injector: The injector to use.
+
+    Returns:
+        The patched object.
+
+    Examples:
+        >>> from nat.llm.utils.thinking import BaseThinkingInjector
+        >>> from nat.llm.utils.thinking import FunctionArgumentWrapper
+        >>> from nat.llm.utils.thinking import patch_with_thinking
+        >>>
+        >>> class MockClass:
+        ...     def sync_method(self, *args, **kwargs):
+        ...         return (args, kwargs)
+        ...
+        >>> mock_obj_1 = MockClass()
+        >>> class AddThinking(BaseThinkingInjector):
+        ...     def inject(self, x: str, *args, **kwargs) -> FunctionArgumentWrapper:
+        ...         return FunctionArgumentWrapper(("thinking " + x), *args, **kwargs)
+        >>>
+        >>> patched_obj = patch_with_thinking(mock_obj_1, AddThinking(
+        ...     system_prompt="thinking",
+        ...     function_names=["sync_method"],
+        ... ))
+        >>> patched_obj.sync_method("test", 1, 2, 3, foo="bar")
+        (('thinking test', 1, 2, 3), {'foo': 'bar'})
+        >>>
+        >>> mock_obj_2 = MockClass()
+        >>> class AddThinkingWithArgs(BaseThinkingInjector):
+        ...     def inject(self, *args, **kwargs) -> FunctionArgumentWrapper:
+        ...         return FunctionArgumentWrapper("thinking", *args, **kwargs)
+        >>>
+        >>> patched_obj = patch_with_thinking(mock_obj_2, AddThinkingWithArgs(
+        ...     system_prompt="thinking",
+        ...     function_names=["sync_method"],
+        ... ))
+        >>> patched_obj.sync_method("test", 1, 2, 3, foo="bar")
+        (('thinking', 'test', 1, 2, 3), {'foo': 'bar'})
+        >>>
+        >>> mock_obj_3 = MockClass()
+        >>> class AddThinkingWithKwargs(BaseThinkingInjector):
+        ...     def inject(self, *args, **kwargs) -> FunctionArgumentWrapper:
+        ...         return FunctionArgumentWrapper(*args, thinking=True, **kwargs)
+        >>>
+        >>> patched_obj = patch_with_thinking(mock_obj_3, AddThinkingWithKwargs(
+        ...     system_prompt="thinking",
+        ...     function_names=["sync_method"],
+        ... ))
+        >>> patched_obj.sync_method("test", 1, 2, 3, foo="bar")
+        (('test', 1, 2, 3), {'thinking': True, 'foo': 'bar'})
+    """
+
+    decorator = _make_thinking_decorator(injector)
+
+    cls = obj if inspect.isclass(obj) else type(obj)
+    cls_name = getattr(cls, "__name__", str(cls))
+
+    for name, _ in inspect.getmembers(cls, callable):
+        if name not in injector.function_names:
+            continue
+
+        descriptor = inspect.getattr_static(cls, name)
+        original = descriptor.__func__ if isinstance(descriptor, types.MethodType) else descriptor
+        wrapped = decorator(original)
+
+        try:  # instance‑level first
+            if not inspect.isclass(obj):
+                object.__setattr__(obj, name, types.MethodType(wrapped, obj))
+                continue
+        except Exception as exc:
+            logger.info(
+                "Instance‑level patch failed for %s.%s (%s); "
+                "falling back to class‑level patch.",
+                cls_name,
+                name,
+                exc,
+            )
+
+        try:  # class‑level fallback
+            setattr(cls, name, wrapped)
+        except Exception as exc:
+            logger.info(
+                "Cannot patch method %s.%s with thinking: %s",
+                cls_name,
+                name,
+                exc,
+            )
+
+    return obj
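A short sketch (not part of the diff) of a custom injector built on the new utility. It assumes the patched method takes a list of (role, content) message tuples as its first positional argument; the client object, method names, and prompt string are hypothetical.

```python
from nat.llm.utils.thinking import BaseThinkingInjector
from nat.llm.utils.thinking import FunctionArgumentWrapper
from nat.llm.utils.thinking import patch_with_thinking


class PrependSystemPrompt(BaseThinkingInjector):
    """Prepends the system prompt to a list of (role, content) message tuples."""

    def inject(self, messages, *args, **kwargs) -> FunctionArgumentWrapper:
        # Assumes the first positional argument of the patched method is the message list.
        return FunctionArgumentWrapper([("system", self.system_prompt), *messages], *args, **kwargs)


# Hypothetical usage against a client exposing invoke/stream-style methods:
# client = patch_with_thinking(client, PrependSystemPrompt(
#     system_prompt="detailed thinking on",              # hypothetical prompt string
#     function_names=["invoke", "ainvoke", "stream", "astream"],
# ))
```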
nat/observability/processor/falsy_batch_filter_processor.py
ADDED
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TypeVar
+
+from nat.observability.processor.processor import Processor
+from nat.utils.type_utils import override
+
+logger = logging.getLogger(__name__)
+
+FalsyT = TypeVar("FalsyT")
+
+
+class FalsyBatchFilterProcessor(Processor[list[FalsyT], list[FalsyT]]):
+    """Processor that filters out falsy items from a batch."""
+
+    @override
+    async def process(self, item: list[FalsyT]) -> list[FalsyT]:
+        """Filter out falsy items from a batch.
+
+        Args:
+            item (list[FalsyT]): The batch of items to filter.
+
+        Returns:
+            list[FalsyT]: The filtered batch.
+        """
+        return [i for i in item if i]
+
+
+class DictBatchFilterProcessor(FalsyBatchFilterProcessor[dict]):
+    """Processor that filters out empty dict items from a batch."""
+    pass
+
+
+class ListBatchFilterProcessor(FalsyBatchFilterProcessor[list]):
+    """Processor that filters out empty list items from a batch."""
+    pass
+
+
+class SetBatchFilterProcessor(FalsyBatchFilterProcessor[set]):
+    """Processor that filters out empty set items from a batch."""
+    pass
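A small illustration (not part of the diff) of the falsy-item filter; it assumes the Processor base class requires no constructor arguments and that the coroutine is driven by an event loop.

```python
import asyncio

from nat.observability.processor.falsy_batch_filter_processor import DictBatchFilterProcessor


async def main() -> None:
    processor = DictBatchFilterProcessor()  # assumption: no constructor arguments needed
    batch = [{"span": "a"}, {}, {"span": "b"}, {}]
    filtered = await processor.process(batch)
    print(filtered)  # -> [{'span': 'a'}, {'span': 'b'}]


asyncio.run(main())
```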
nat/observability/processor/processor_factory.py
ADDED
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from nat.observability.processor.processor import Processor
+
+
+def processor_factory(processor_class: type, from_type: type[Any], to_type: type[Any]) -> type[Processor]:
+    """Create a concrete processor class from a processor class and types.
+
+    Args:
+        processor_class (type): The processor class to create a concrete instance of
+        from_type (type[Any]): The type of the input data
+        to_type (type[Any]): The type of the output data
+
+    Returns:
+        type[Processor]: The concrete processor class
+    """
+
+    class ConcreteProcessor(processor_class[from_type, to_type]):  # type: ignore
+        pass
+
+    return ConcreteProcessor
+
+
+def processor_factory_from_type(processor_class: type, from_type: type[Any]) -> type[Processor]:
+    """Create a concrete processor class from a processor class and input type.
+
+    Args:
+        processor_class (type): The processor class to create a concrete instance of
+        from_type (type[Any]): The type of the input data
+
+    Returns:
+        type[Processor]: The concrete processor class
+    """
+
+    class ConcreteProcessor(processor_class[from_type]):  # type: ignore
+        pass
+
+    return ConcreteProcessor
+
+
+def processor_factory_to_type(processor_class: type, to_type: type[Any]) -> type[Processor]:
+    """Create a concrete processor class from a processor class and output type.
+
+    Args:
+        processor_class (type): The processor class to create a concrete instance of
+        to_type (type[Any]): The type of the output data
+
+    Returns:
+        type[Processor]: The concrete processor class
+    """
+
+    class ConcreteProcessor(processor_class[to_type]):  # type: ignore
+        pass
+
+    return ConcreteProcessor
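A sketch (not part of the diff) of how the factory helpers pin a generic processor to concrete types; pairing them with FalsyBatchFilterProcessor here is an assumption chosen for illustration.

```python
from nat.observability.processor.falsy_batch_filter_processor import FalsyBatchFilterProcessor
from nat.observability.processor.processor_factory import processor_factory_from_type

# Pin the generic filter to dict batches; equivalent in spirit to the
# hand-written DictBatchFilterProcessor in the previous file.
DictFilter = processor_factory_from_type(FalsyBatchFilterProcessor, dict)
assert issubclass(DictFilter, FalsyBatchFilterProcessor)
```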
nat/profiler/decorators/function_tracking.py
CHANGED
@@ -16,7 +16,9 @@
 import functools
 import inspect
 import uuid
+from collections.abc import Callable
 from typing import Any
+from typing import cast
 
 from pydantic import BaseModel
 
@@ -252,3 +254,126 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
         return result
 
     return sync_wrapper
+
+
+def track_unregistered_function(func: Callable[..., Any] | None = None,
+                                *,
+                                name: str | None = None,
+                                metadata: dict[str, Any] | None = None) -> Callable[..., Any]:
+    """
+    Decorator that wraps any function with scope management and automatic tracking.
+
+    - Sets active function context using the function name
+    - Leverages Context.push_active_function for built-in tracking
+    - Avoids duplicate tracking entries by relying on the library's built-in systems
+    - Supports sync/async functions and generators
+
+    Args:
+        func: The function to wrap (auto-detected when used without parentheses)
+        name: Custom name to use for tracking instead of func.__name__
+        metadata: Additional metadata to include in tracking
+    """
+
+    # If called with parameters: @track_unregistered_function(name="...", metadata={...})
+    if func is None:
+
+        def decorator_wrapper(actual_func: Callable[..., Any]) -> Callable[..., Any]:
+            # Cast to ensure type checker understands this returns a callable
+            return cast(Callable[..., Any], track_unregistered_function(actual_func, name=name, metadata=metadata))
+
+        return decorator_wrapper
+
+    # Direct decoration: @track_unregistered_function or recursive call with actual function
+    function_name: str = name if name else func.__name__
+
+    # --- Validate metadata ---
+    if metadata is not None:
+        if not isinstance(metadata, dict):
+            raise TypeError("metadata must be a dict[str, Any].")
+        if any(not isinstance(k, str) for k in metadata.keys()):
+            raise TypeError("All metadata keys must be strings.")
+
+    trace_metadata = TraceMetadata(provided_metadata=metadata)
+
+    # --- Now detect the function type and wrap accordingly ---
+    if inspect.isasyncgenfunction(func):
+        # ---------------------
+        # ASYNC GENERATOR
+        # ---------------------
+
+        @functools.wraps(func)
+        async def async_gen_wrapper(*args, **kwargs):
+            context = Context.get()
+            input_data = (
+                *args,
+                kwargs,
+            )
+            # Only do context management - let push_active_function handle tracking
+            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+                final_outputs = []
+                async for item in func(*args, **kwargs):
+                    final_outputs.append(item)
+                    yield item
+
+                manager.set_output(final_outputs)
+
+        return async_gen_wrapper
+
+    if inspect.iscoroutinefunction(func):
+        # ---------------------
+        # ASYNC FUNCTION
+        # ---------------------
+        @functools.wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            context = Context.get()
+            input_data = (
+                *args,
+                kwargs,
+            )
+
+            # Only do context management - let push_active_function handle tracking
+            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+                result = await func(*args, **kwargs)
+                manager.set_output(result)
+                return result
+
+        return async_wrapper
+
+    if inspect.isgeneratorfunction(func):
+        # ---------------------
+        # SYNC GENERATOR
+        # ---------------------
+        @functools.wraps(func)
+        def sync_gen_wrapper(*args, **kwargs):
+            context = Context.get()
+            input_data = (
+                *args,
+                kwargs,
+            )
+
+            # Only do context management - let push_active_function handle tracking
+            with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+                final_outputs = []
+                for item in func(*args, **kwargs):
+                    final_outputs.append(item)
+                    yield item
+
+                manager.set_output(final_outputs)
+
+        return sync_gen_wrapper
+
+    @functools.wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        context = Context.get()
+        input_data = (
+            *args,
+            kwargs,
+        )
+
+        # Only do context management - let push_active_function handle tracking
+        with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
+            result = func(*args, **kwargs)
+            manager.set_output(result)
+            return result
+
+    return sync_wrapper
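A brief usage sketch (not part of the diff) of the new decorator for code that is not a registered NAT function. It assumes a workflow/profiler context is active at call time so the recorded spans have somewhere to go; the function name and metadata are illustrative.

```python
from nat.profiler.decorators.function_tracking import track_unregistered_function


@track_unregistered_function(name="tokenize", metadata={"stage": "preprocessing"})
def tokenize(text: str) -> list[str]:
    return text.split()


# Each call is recorded through Context.push_active_function under the name "tokenize".
tokens = tokenize("profile this call")
```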
{nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nvidia-nat
-Version: 1.3.0a20250827
+Version: 1.3.0a20250828
 Summary: NVIDIA NeMo Agent toolkit
 Author: NVIDIA Corporation
 Maintainer: NVIDIA Corporation
@@ -243,6 +243,8 @@ Provides-Extra: agno
 Requires-Dist: nvidia-nat-agno; extra == "agno"
 Provides-Extra: crewai
 Requires-Dist: nvidia-nat-crewai; extra == "crewai"
+Provides-Extra: data-flywheel
+Requires-Dist: nvidia-nat-data-flywheel; extra == "data-flywheel"
 Provides-Extra: ingestion
 Requires-Dist: nvidia-nat-ingestion; extra == "ingestion"
 Provides-Extra: langchain