nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/builder/builder.py +52 -0
- nat/builder/component_utils.py +7 -1
- nat/builder/context.py +17 -0
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +74 -3
- nat/builder/workflow.py +4 -2
- nat/builder/workflow_builder.py +129 -0
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/register_workflow.py +50 -0
- nat/cli/type_registry.py +68 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +16 -0
- nat/data_models/function.py +14 -1
- nat/data_models/middleware.py +35 -0
- nat/data_models/runtime_enum.py +26 -0
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +11 -3
- nat/eval/utils/weave_eval.py +17 -3
- nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
- nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
- nat/llm/aws_bedrock_llm.py +11 -9
- nat/llm/azure_openai_llm.py +12 -4
- nat/llm/litellm_llm.py +11 -4
- nat/llm/nim_llm.py +11 -9
- nat/llm/openai_llm.py +12 -9
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache_middleware.py +256 -0
- nat/middleware/function_middleware.py +186 -0
- nat/middleware/middleware.py +184 -0
- nat/middleware/register.py +35 -0
- nat/profiler/decorators/framework_wrapper.py +16 -0
- nat/retriever/milvus/register.py +11 -3
- nat/retriever/milvus/retriever.py +102 -40
- nat/runtime/runner.py +12 -1
- nat/runtime/session.py +10 -3
- nat/tool/code_execution/code_sandbox.py +4 -7
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +8 -4
- nat/utils/io/yaml_tools.py +73 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
- nat/data_models/temperature_mixin.py +0 -44
- nat/data_models/top_p_mixin.py +0 -44
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Cache middleware for function memoization with similarity matching.
|
|
16
|
+
|
|
17
|
+
This module provides a cache middleware that memoizes function calls based on
|
|
18
|
+
input similarity. It demonstrates the middleware pattern by:
|
|
19
|
+
|
|
20
|
+
1. Preprocessing: Serializing and checking the cache for similar inputs
|
|
21
|
+
2. Calling next: Delegating to the next middleware/function if no cache hit
|
|
22
|
+
3. Postprocessing: Caching the result for future use
|
|
23
|
+
4. Continuing: Returning the result (cached or fresh)
|
|
24
|
+
|
|
25
|
+
The cache supports exact matching for maximum performance and fuzzy matching
|
|
26
|
+
using Python's built-in difflib for similarity computation.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import json
|
|
32
|
+
import logging
|
|
33
|
+
from collections.abc import AsyncIterator
|
|
34
|
+
from typing import Any
|
|
35
|
+
from typing import Literal
|
|
36
|
+
|
|
37
|
+
from pydantic import Field
|
|
38
|
+
|
|
39
|
+
from nat.builder.context import Context
|
|
40
|
+
from nat.builder.context import ContextState
|
|
41
|
+
from nat.data_models.middleware import FunctionMiddlewareBaseConfig
|
|
42
|
+
from nat.middleware.function_middleware import CallNext
|
|
43
|
+
from nat.middleware.function_middleware import CallNextStream
|
|
44
|
+
from nat.middleware.function_middleware import FunctionMiddleware
|
|
45
|
+
from nat.middleware.function_middleware import FunctionMiddlewareContext
|
|
46
|
+
|
|
47
|
+
logger = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class CacheMiddleware(FunctionMiddleware):
    """Cache middleware that memoizes function outputs based on input similarity.

    This middleware demonstrates the four-phase middleware pattern:

    1. **Preprocess**: Serialize input and check cache for similar entries
    2. **Call Next**: Delegate to next middleware/function if cache miss
    3. **Postprocess**: Store the result in cache for future use
    4. **Continue**: Return the result (from cache or fresh)

    The cache serializes function inputs to strings and performs similarity
    matching against previously seen inputs *of the same function*. Entries are
    partitioned by ``FunctionMiddlewareContext.name``. A single middleware
    instance attached to several functions therefore never returns one
    function's output for another function's call. If a similar input is found
    at or above the configured threshold, the cached output is returned without
    calling the next middleware or function.

    The cache lives in memory for the lifetime of the instance and has no size
    bound. Cached outputs are returned by reference; they are not copied.

    Args:
        enabled_mode: Either "always" to always cache, or "eval" to only
            cache when Context.is_evaluating is True.
        similarity_threshold: Float between 0 and 1. If 1.0, performs
            exact string matching. Otherwise uses difflib for similarity
            computation.
    """

    def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None:
        """Initialize the cache middleware.

        Args:
            enabled_mode: Either "always" or "eval". If "eval", only caches
                when Context.is_evaluating is True.
            similarity_threshold: Similarity threshold between 0 and 1.
                If 1.0, performs exact matching. Otherwise uses fuzzy matching.
        """
        super().__init__(is_final=True)
        self._enabled_mode = enabled_mode
        self._similarity_threshold = similarity_threshold
        # Entries are partitioned per wrapped function:
        # function name -> {serialized input -> cached output}.
        self._cache: dict[str, dict[str, Any]] = {}

    def _should_cache(self) -> bool:
        """Check if caching should be enabled based on the current context."""
        if self._enabled_mode == "always":
            return True

        # Get the current context and check if we're in evaluation mode
        try:
            context_state = ContextState.get()
            context = Context(context_state)
            return context.is_evaluating
        except Exception:
            logger.warning("Failed to get context for cache decision", exc_info=True)
            return False

    def _serialize_input(self, value: Any) -> str | None:
        """Serialize the input value to a string for caching.

        Args:
            value: The input value to serialize.

        Returns:
            String representation of the input, or None if serialization
            fails.
        """
        try:
            # sort_keys makes dict inputs order-independent; default=str
            # stringifies objects json cannot encode natively.
            return json.dumps(value, sort_keys=True, default=str)
        except Exception:
            logger.debug("Failed to serialize input for caching", exc_info=True)
            return None

    def _find_best_match(self, input_str: str, entries: dict[str, Any]) -> tuple[str | None, float]:
        """Find the cached key most similar to ``input_str`` within one bucket.

        Args:
            input_str: The serialized input string to match.
            entries: The cache bucket (serialized input -> output) to search.

        Returns:
            A ``(key, similarity)`` pair. ``key`` is None when no entry reaches
            the configured threshold, in which case ``similarity`` is 0.0.
        """
        # An exact hit is the best possible match; return it without scanning.
        if input_str in entries:
            return input_str, 1.0

        if self._similarity_threshold >= 1.0:
            # Exact-matching mode and no exact hit.
            return None, 0.0

        # Fuzzy matching using difflib
        import difflib

        best_match: str | None = None
        # Start below any achievable ratio so that a threshold of 0.0 also
        # accepts entries whose similarity is exactly 0.0.
        best_ratio = -1.0
        for cached_key in entries:
            ratio = difflib.SequenceMatcher(None, input_str, cached_key).ratio()
            if ratio >= self._similarity_threshold and ratio > best_ratio:
                best_ratio = ratio
                best_match = cached_key

        if best_match is None:
            return None, 0.0
        return best_match, best_ratio

    async def function_middleware_invoke(self, value: Any, call_next: CallNext,
                                         context: FunctionMiddlewareContext) -> Any:
        """Cache middleware for single-output invocations.

        Implements the four-phase middleware pattern:

        1. **Preprocess**: Check if caching is enabled and serialize input
        2. **Call Next**: Delegate to next middleware/function if cache miss
        3. **Postprocess**: Store the result in cache
        4. **Continue**: Return the result (cached or fresh)

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function
            context: Metadata about the function being wrapped; its ``name``
                selects the cache bucket.

        Returns:
            The cached output if found, otherwise the fresh output
        """
        # Phase 1: Preprocess - check if caching should be enabled
        if not self._should_cache():
            return await call_next(value)

        # Phase 1: Preprocess - serialize the input
        input_str = self._serialize_input(value)
        if input_str is None:
            # Can't serialize, pass through to next middleware/function
            logger.debug("Could not serialize input for function %s, bypassing cache", context.name)
            return await call_next(value)

        # Phase 1: Preprocess - look for a similar cached input of *this* function
        entries = self._cache.setdefault(context.name, {})
        similar_key, similarity = self._find_best_match(input_str, entries)
        if similar_key is not None:
            # Cache hit - short-circuit and return cached output
            logger.debug("Cache hit for function %s with similarity %.2f", context.name, similarity)
            # Phase 4: Continue - return cached result
            return entries[similar_key]

        # Phase 2: Call next - no cache hit, call next middleware/function
        logger.debug("Cache miss for function %s", context.name)
        result = await call_next(value)

        # Phase 3: Postprocess - cache the result for future use
        entries[input_str] = result
        logger.debug("Cached result for function %s", context.name)

        # Phase 4: Continue - return the fresh result
        return result

    async def function_middleware_stream(self,
                                         value: Any,
                                         call_next: CallNextStream,
                                         context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Cache middleware for streaming invocations - bypasses caching.

        Streaming results are not cached. Caching them would require buffering
        the whole stream in memory, which defeats the purpose of streaming.

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function stream
            context: Metadata about the function being wrapped

        Yields:
            Chunks from the stream (unmodified)
        """
        # Phase 1: Preprocess - log that we're bypassing cache for streams
        logger.debug("Streaming call for function %s, bypassing cache", context.name)

        # Phase 2-3: Call next and process chunks - yield chunks as they arrive
        async for chunk in call_next(value):
            yield chunk

        # Phase 4: Continue - stream is complete (implicit)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
    """Configuration for the ``cache`` function middleware.

    The middleware memoizes function outputs keyed on the serialized input.
    It can require an exact match or accept a sufficiently similar input.

    Attributes:
        enabled_mode: When caching is active. ``"always"`` caches on every
            call; ``"eval"`` (the default) caches only while
            ``Context.is_evaluating`` is True.
        similarity_threshold: Minimum similarity, constrained to ``[0, 1]``,
            for a cached entry to be reused. ``1.0`` (the default) means the
            serialized input must match exactly; lower values enable fuzzy
            matching with difflib.
    """

    enabled_mode: Literal["always", "eval"] = Field(
        default="eval",
        description="When caching is enabled: 'always' or 'eval' (only during evaluation)",
    )

    similarity_threshold: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Similarity threshold between 0 and 1. Use 1.0 for exact matching",
    )
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
__all__ = [
    "CacheMiddleware",
    "CacheMiddlewareConfig",
]
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Function-specific middleware for the NeMo Agent toolkit.
|
|
16
|
+
|
|
17
|
+
This module provides function-specific middleware implementations that extend
|
|
18
|
+
the base Middleware class. FunctionMiddleware is a specialized middleware type
|
|
19
|
+
designed specifically for wrapping function calls with dedicated methods
|
|
20
|
+
for function-specific preprocessing and postprocessing.
|
|
21
|
+
|
|
22
|
+
Middleware is configured at registration time and is bound to instances when they
|
|
23
|
+
are constructed by the workflow builder.
|
|
24
|
+
|
|
25
|
+
Middleware executes in the order provided and can optionally be marked as *final*.
|
|
26
|
+
A final middleware terminates the chain, preventing subsequent middleware or the
|
|
27
|
+
wrapped target from running unless the final middleware explicitly delegates to
|
|
28
|
+
the next callable.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
from collections.abc import AsyncIterator
|
|
34
|
+
from collections.abc import Sequence
|
|
35
|
+
from typing import Any
|
|
36
|
+
|
|
37
|
+
from nat.middleware.middleware import CallNext
|
|
38
|
+
from nat.middleware.middleware import CallNextStream
|
|
39
|
+
from nat.middleware.middleware import FunctionMiddlewareContext
|
|
40
|
+
from nat.middleware.middleware import Middleware
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class FunctionMiddleware(Middleware):
    """Middleware specialised for wrapping functions.

    The generic ``middleware_invoke`` / ``middleware_stream`` hooks inherited
    from the base class are overridden here to forward to the
    function-specific hooks ``function_middleware_invoke`` and
    ``function_middleware_stream``. Subclasses customise behaviour by
    overriding those two hooks.
    """

    async def middleware_invoke(self, value: Any, call_next: CallNext, context: FunctionMiddlewareContext) -> Any:
        """Forward the generic single-output hook to ``function_middleware_invoke``."""
        outcome = await self.function_middleware_invoke(value, call_next, context)
        return outcome

    async def middleware_stream(self, value: Any, call_next: CallNextStream,
                                context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Forward the generic streaming hook to ``function_middleware_stream``."""
        delegated = self.function_middleware_stream(value, call_next, context)
        async for item in delegated:
            yield item

    async def function_middleware_invoke(self, value: Any, call_next: CallNext,
                                         context: FunctionMiddlewareContext) -> Any:
        """Hook for single-output function calls.

        The default is a pass-through: it hands ``value`` to ``call_next`` and
        returns whatever comes back. Override it to add function-specific
        preprocessing or postprocessing.

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function
            context: Metadata about the function being wrapped

        Returns:
            The (potentially modified) output from the function
        """
        outcome = await call_next(value)
        return outcome

    async def function_middleware_stream(self,
                                         value: Any,
                                         call_next: CallNextStream,
                                         context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Hook for streaming function calls.

        The default re-yields every chunk from ``call_next`` unchanged.
        Override it to add function-specific preprocessing or to transform
        chunks.

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function stream
            context: Metadata about the function being wrapped

        Yields:
            Chunks from the stream (potentially modified)
        """
        upstream = call_next(value)
        async for item in upstream:
            yield item
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class FunctionMiddlewareChain:
    """Compose a sequence of middleware around a final callable.

    Middleware run in the order given: the first element is outermost. Each
    one receives the next layer (or, for the last, the final callable) as its
    ``call_next`` continuation, plus the shared context.
    """

    def __init__(self, *, middleware: Sequence[Middleware], context: FunctionMiddlewareContext) -> None:
        self._middleware = tuple(middleware)
        self._context = context

    def _wrap_single(self, layer: Middleware, downstream: CallNext) -> CallNext:
        """Bind ``layer`` to ``downstream`` as an async single-output callable."""

        async def wrapped(value: Any) -> Any:
            return await layer.middleware_invoke(value, downstream, self._context)

        return wrapped

    def _wrap_stream(self, layer: Middleware, downstream: CallNextStream) -> CallNextStream:
        """Bind ``layer`` to ``downstream`` as an async-generator callable."""

        async def wrapped(value: Any) -> AsyncIterator[Any]:
            async for piece in layer.middleware_stream(value, downstream, self._context):
                yield piece

        return wrapped

    def build_single(self, final_call: CallNext) -> CallNext:
        """Build the middleware chain for single-output invocations.

        Args:
            final_call: The final function to call (the actual function implementation)

        Returns:
            A callable that executes the entire middleware chain. It is
            ``final_call`` itself when there is no middleware.
        """
        chained = final_call
        # Wrap from the innermost layer outwards so the first middleware runs first.
        for layer in self._middleware[::-1]:
            chained = self._wrap_single(layer, chained)
        return chained

    def build_stream(self, final_call: CallNextStream) -> CallNextStream:
        """Build the middleware chain for streaming invocations.

        Args:
            final_call: The final function to call (the actual function implementation)

        Returns:
            A callable that executes the entire middleware chain. It is
            ``final_call`` itself when there is no middleware.
        """
        chained = final_call
        # Wrap from the innermost layer outwards so the first middleware runs first.
        for layer in self._middleware[::-1]:
            chained = self._wrap_stream(layer, chained)
        return chained
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def validate_middleware(middleware: Sequence[Middleware] | None) -> tuple[Middleware, ...]:
    """Validate a sequence of middleware, enforcing ordering guarantees.

    Args:
        middleware: The middleware to validate, in execution order. ``None``
            or an empty sequence is accepted.

    Returns:
        The middleware as a tuple, in the original order.

    Raises:
        TypeError: If an element is not a ``Middleware`` instance.
        ValueError: If a final middleware is not the last element, or if more
            than one final middleware is specified.
    """
    if not middleware:
        return ()

    items = tuple(middleware)
    last_index = len(items) - 1

    for idx, mw in enumerate(items):
        if not isinstance(mw, Middleware):
            raise TypeError("All middleware must be instances of Middleware")

        if mw.is_final and idx != last_index:
            # This branch raises before later elements are visited, so any
            # additional final middleware must be detected here. Otherwise the
            # "only one final" error could never be reported.
            has_later_final = any(
                isinstance(other, Middleware) and other.is_final for other in items[idx + 1:])
            if has_later_final:
                raise ValueError("Only one final Middleware may be specified per function. "
                                 "A final Middleware must be the last middleware in the chain")
            raise ValueError("A final Middleware must be the last middleware in the chain")

    return items
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
__all__ = ["FunctionMiddleware", "FunctionMiddlewareChain", "validate_middleware"]
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Base middleware class for the NeMo Agent toolkit.
|
|
16
|
+
|
|
17
|
+
This module provides the base Middleware class that defines the middleware pattern
|
|
18
|
+
for wrapping and modifying function calls. Middleware works like middleware in
|
|
19
|
+
web frameworks - they can modify inputs, call the next middleware in the chain,
|
|
20
|
+
process outputs, and continue.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import dataclasses
|
|
26
|
+
from abc import ABC
|
|
27
|
+
from collections.abc import AsyncIterator
|
|
28
|
+
from collections.abc import Awaitable
|
|
29
|
+
from collections.abc import Callable
|
|
30
|
+
from typing import Any
|
|
31
|
+
|
|
32
|
+
from pydantic import BaseModel
|
|
33
|
+
|
|
34
|
+
#: Continuation type for single-output invocations: an async callable that
#: receives the (possibly preprocessed) input value and resolves to the output.
CallNext = Callable[[Any], Awaitable[Any]]

#: Continuation type for streaming invocations: a callable that receives the
#: input value and returns an async iterator of output chunks.
CallNextStream = Callable[[Any], AsyncIterator[Any]]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
42
|
+
class FunctionMiddlewareContext:
|
|
43
|
+
"""Context information about the function being wrapped by middleware.
|
|
44
|
+
|
|
45
|
+
Middleware receives this context object which describes the function they
|
|
46
|
+
are wrapping. This allows middleware to make decisions based on the
|
|
47
|
+
function's name, configuration, schema, etc.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
name: str
|
|
51
|
+
"""Name of the function being wrapped."""
|
|
52
|
+
|
|
53
|
+
config: Any
|
|
54
|
+
"""Configuration object for the function."""
|
|
55
|
+
|
|
56
|
+
description: str | None
|
|
57
|
+
"""Optional description of the function."""
|
|
58
|
+
|
|
59
|
+
input_schema: type[BaseModel] | None
|
|
60
|
+
"""Schema describing expected inputs or :class:`NoneType` when absent."""
|
|
61
|
+
|
|
62
|
+
single_output_schema: type[BaseModel] | type[None]
|
|
63
|
+
"""Schema describing single outputs or :class:`types.NoneType` when absent."""
|
|
64
|
+
|
|
65
|
+
stream_output_schema: type[BaseModel] | type[None]
|
|
66
|
+
"""Schema describing streaming outputs or :class:`types.NoneType` when absent."""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Middleware(ABC):
    """Base class for middleware that wraps a callable target.

    Each middleware is one link in a chain, much like request middleware in
    a web framework. A link may:

    1. **Preprocess**: inspect or rewrite the incoming value
    2. **Call Next**: hand the value to the next middleware or the target itself
    3. **Postprocess**: inspect, transform, or augment what comes back
    4. **Continue**: return or yield the final result

    Example::

        class LoggingMiddleware(Middleware):
            async def middleware_invoke(self, value, call_next, context):
                print(f"Input: {value}")          # 1. Preprocess
                result = await call_next(value)   # 2. Call next middleware/target
                print(f"Output: {result}")        # 3. Postprocess
                return result                     # 4. Continue

    Attributes:
        is_final: When True, this middleware ends the chain. Neither later
            middleware nor the target run unless it calls ``call_next``
            itself.
    """

    def __init__(self, *, is_final: bool = False) -> None:
        self._is_final = is_final

    @property
    def is_final(self) -> bool:
        """True if this middleware terminates the chain.

        A final middleware keeps subsequent middleware and the target from
        running unless it delegates to ``call_next`` explicitly.
        """
        return self._is_final

    async def middleware_invoke(self, value: Any, call_next: CallNext, context: FunctionMiddlewareContext) -> Any:
        """Hook for single-output invocations.

        The default is a pass-through that returns ``await call_next(value)``.
        Override it to preprocess, postprocess, or short-circuit::

            async def middleware_invoke(self, value, call_next, context):
                modified_input = transform(value)              # preprocess
                result = await call_next(modified_input)       # call next
                modified_result = transform_output(result)     # postprocess
                return modified_result                         # continue

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or target
            context: Metadata about the target being wrapped

        Returns:
            The (potentially modified) output from the target
        """
        _ = context  # The pass-through default has no use for the context.
        outcome = await call_next(value)
        return outcome

    async def middleware_stream(self, value: Any, call_next: CallNextStream,
                                context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Hook for streaming invocations.

        The default re-yields every chunk from ``call_next(value)`` untouched.
        Override it to preprocess, transform chunks, or clean up::

            async def middleware_stream(self, value, call_next, context):
                modified_input = transform(value)              # preprocess
                async for chunk in call_next(modified_input):  # call next
                    yield transform_chunk(chunk)               # per-chunk processing
                await cleanup()                                # postprocess

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or target stream
            context: Metadata about the target being wrapped

        Yields:
            Chunks from the stream (potentially modified)
        """
        _ = context  # The pass-through default has no use for the context.
        upstream = call_next(value)
        async for piece in upstream:
            yield piece
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
__all__ = ["CallNext", "CallNextStream", "Middleware", "FunctionMiddlewareContext"]
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Registration module for built-in middleware."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from nat.cli.register_workflow import register_middleware
|
|
20
|
+
from nat.middleware.cache_middleware import CacheMiddleware
|
|
21
|
+
from nat.middleware.cache_middleware import CacheMiddlewareConfig
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@register_middleware(config_type=CacheMiddlewareConfig)
async def cache_middleware(config: CacheMiddlewareConfig, builder):
    """Construct a ``CacheMiddleware`` from its configuration.

    Args:
        config: Settings for the cache (``enabled_mode`` and
            ``similarity_threshold``).
        builder: The workflow builder. This factory does not use it; it is
            part of the component factory signature.

    Yields:
        The configured cache middleware instance
    """
    middleware = CacheMiddleware(
        enabled_mode=config.enabled_mode,
        similarity_threshold=config.similarity_threshold,
    )
    yield middleware
|
|
@@ -34,6 +34,7 @@ _library_instrumented = {
|
|
|
34
34
|
"semantic_kernel": False,
|
|
35
35
|
"agno": False,
|
|
36
36
|
"adk": False,
|
|
37
|
+
"strands": False,
|
|
37
38
|
}
|
|
38
39
|
|
|
39
40
|
callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)
|
|
@@ -131,6 +132,21 @@ def set_framework_profiler_handler(
|
|
|
131
132
|
_library_instrumented["adk"] = True
|
|
132
133
|
logger.debug("ADK callback handler registered")
|
|
133
134
|
|
|
135
|
+
if (LLMFrameworkEnum.STRANDS in frameworks and not _library_instrumented["strands"]):
|
|
136
|
+
try:
|
|
137
|
+
from nat.plugins.strands.strands_callback_handler import StrandsProfilerHandler
|
|
138
|
+
except ImportError as e:
|
|
139
|
+
logger.warning(
|
|
140
|
+
"Strands profiler not available. Install NAT with Strands extras: "
|
|
141
|
+
"pip install \"nvidia-nat[strands]\". Error: %s",
|
|
142
|
+
e,
|
|
143
|
+
)
|
|
144
|
+
else:
|
|
145
|
+
handler = StrandsProfilerHandler()
|
|
146
|
+
handler.instrument()
|
|
147
|
+
_library_instrumented["strands"] = True
|
|
148
|
+
logger.debug("Strands callback handler registered")
|
|
149
|
+
|
|
134
150
|
# IMPORTANT: actually call the wrapped function as an async context manager
|
|
135
151
|
async with func(workflow_config, builder) as result:
|
|
136
152
|
yield result
|