nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. nat/builder/builder.py +52 -0
  2. nat/builder/component_utils.py +7 -1
  3. nat/builder/context.py +17 -0
  4. nat/builder/framework_enum.py +1 -0
  5. nat/builder/function.py +74 -3
  6. nat/builder/workflow.py +4 -2
  7. nat/builder/workflow_builder.py +129 -0
  8. nat/cli/commands/workflow/workflow_commands.py +3 -2
  9. nat/cli/register_workflow.py +50 -0
  10. nat/cli/type_registry.py +68 -0
  11. nat/data_models/component.py +2 -0
  12. nat/data_models/component_ref.py +11 -0
  13. nat/data_models/config.py +16 -0
  14. nat/data_models/function.py +14 -1
  15. nat/data_models/middleware.py +35 -0
  16. nat/data_models/runtime_enum.py +26 -0
  17. nat/eval/dataset_handler/dataset_filter.py +34 -2
  18. nat/eval/evaluate.py +11 -3
  19. nat/eval/utils/weave_eval.py +17 -3
  20. nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
  21. nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
  22. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
  23. nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
  24. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
  25. nat/llm/aws_bedrock_llm.py +11 -9
  26. nat/llm/azure_openai_llm.py +12 -4
  27. nat/llm/litellm_llm.py +11 -4
  28. nat/llm/nim_llm.py +11 -9
  29. nat/llm/openai_llm.py +12 -9
  30. nat/middleware/__init__.py +35 -0
  31. nat/middleware/cache_middleware.py +256 -0
  32. nat/middleware/function_middleware.py +186 -0
  33. nat/middleware/middleware.py +184 -0
  34. nat/middleware/register.py +35 -0
  35. nat/profiler/decorators/framework_wrapper.py +16 -0
  36. nat/retriever/milvus/register.py +11 -3
  37. nat/retriever/milvus/retriever.py +102 -40
  38. nat/runtime/runner.py +12 -1
  39. nat/runtime/session.py +10 -3
  40. nat/tool/code_execution/code_sandbox.py +4 -7
  41. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  42. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
  43. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  44. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  45. nat/tool/server_tools.py +15 -2
  46. nat/utils/__init__.py +8 -4
  47. nat/utils/io/yaml_tools.py +73 -3
  48. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
  49. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
  50. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
  51. nat/data_models/temperature_mixin.py +0 -44
  52. nat/data_models/top_p_mixin.py +0 -44
  53. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  54. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
  55. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  56. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
  57. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,256 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Cache middleware for function memoization with similarity matching.
16
+
17
+ This module provides a cache middleware that memoizes function calls based on
18
+ input similarity. It demonstrates the middleware pattern by:
19
+
20
+ 1. Preprocessing: Serializing and checking the cache for similar inputs
21
+ 2. Calling next: Delegating to the next middleware/function if no cache hit
22
+ 3. Postprocessing: Caching the result for future use
23
+ 4. Continuing: Returning the result (cached or fresh)
24
+
25
+ The cache supports exact matching for maximum performance and fuzzy matching
26
+ using Python's built-in difflib for similarity computation.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import logging
33
+ from collections.abc import AsyncIterator
34
+ from typing import Any
35
+ from typing import Literal
36
+
37
+ from pydantic import Field
38
+
39
+ from nat.builder.context import Context
40
+ from nat.builder.context import ContextState
41
+ from nat.data_models.middleware import FunctionMiddlewareBaseConfig
42
+ from nat.middleware.function_middleware import CallNext
43
+ from nat.middleware.function_middleware import CallNextStream
44
+ from nat.middleware.function_middleware import FunctionMiddleware
45
+ from nat.middleware.function_middleware import FunctionMiddlewareContext
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
class CacheMiddleware(FunctionMiddleware):
    """Cache middleware that memoizes function outputs based on input similarity.

    The middleware follows the four-phase middleware pattern:

    1. **Preprocess**: Serialize the input and look for a matching cache entry
    2. **Call Next**: Delegate to the next middleware/function on a cache miss
    3. **Postprocess**: Store the fresh result in the cache
    4. **Continue**: Return the result (cached or fresh)

    Inputs are serialized to JSON strings, which act as cache keys. A lookup
    either requires an identical key (``similarity_threshold == 1.0``) or
    accepts the closest previously seen key whose ``difflib`` similarity ratio
    reaches the threshold.

    Note:
        The cache is an in-memory dict owned by this instance. Entries are
        only ever added, never evicted, and cached objects are returned by
        reference (no copy is made).

    Args:
        enabled_mode: Either "always" to always cache, or "eval" to only
            cache when Context.is_evaluating is True.
        similarity_threshold: Float between 0 and 1. If 1.0, performs
            exact string matching. Otherwise uses difflib for similarity
            computation.
    """

    def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None:
        """Initialize the cache middleware.

        Args:
            enabled_mode: Either "always" or "eval". If "eval", only caches
                when Context.is_evaluating is True.
            similarity_threshold: Similarity threshold between 0 and 1.
                If 1.0, performs exact matching. Otherwise uses fuzzy matching.
        """
        # The cache is registered as a *final* middleware (see Middleware.is_final).
        super().__init__(is_final=True)
        self._enabled_mode = enabled_mode
        self._similarity_threshold = similarity_threshold
        # Serialized input -> output of the wrapped call.
        self._cache: dict[str, Any] = {}

    def _should_cache(self) -> bool:
        """Return True when caching is active for the current call.

        "always" mode is unconditionally active. Any other mode defers to
        ``Context.is_evaluating``; if the context cannot be obtained the
        failure is logged and caching is disabled for this call.
        """
        if self._enabled_mode == "always":
            return True

        try:
            context_state = ContextState.get()
            context = Context(context_state)
            return context.is_evaluating
        except Exception:
            # Best effort: a broken context must not break the wrapped call.
            logger.warning("Failed to get context for cache decision", exc_info=True)
            return False

    def _serialize_input(self, value: Any) -> str | None:
        """Serialize the input value to a string usable as a cache key.

        Args:
            value: The input value to serialize.

        Returns:
            A JSON string with sorted keys (objects that are not JSON-native
            are rendered through ``str``), or None if serialization fails.
        """
        try:
            return json.dumps(value, sort_keys=True, default=str)
        except Exception:
            logger.debug("Failed to serialize input for caching", exc_info=True)
            return None

    def _find_best_match(self, input_str: str) -> tuple[str | None, float]:
        """Locate the best cached key for ``input_str`` together with its ratio.

        Args:
            input_str: The serialized input string to match.

        Returns:
            ``(key, ratio)`` for the most similar cached key whose ratio is at
            least the configured threshold, or ``(None, 0.0)`` when no key
            qualifies. An identical key is reported with ratio 1.0.
        """
        # An identical key is the best possible match in either mode, so a
        # single dict lookup settles it without scanning the cache.
        if input_str in self._cache:
            return input_str, 1.0

        threshold = self._similarity_threshold
        if threshold == 1.0:
            # Exact matching only - nothing else can qualify.
            return None, 0.0

        # Fuzzy matching using difflib (imported lazily; only needed here).
        import difflib

        best_key: str | None = None
        # Start below any possible ratio so that a ratio of 0.0 still counts
        # when the threshold itself is 0.0.
        best_ratio = -1.0

        for cached_key in self._cache:
            matcher = difflib.SequenceMatcher(None, input_str, cached_key)

            # real_quick_ratio() and quick_ratio() are cheap upper bounds on
            # ratio(); skip the expensive computation when they already rule
            # the candidate out.
            if matcher.real_quick_ratio() < threshold or matcher.quick_ratio() < threshold:
                continue

            ratio = matcher.ratio()
            if ratio >= threshold and ratio > best_ratio:
                best_ratio = ratio
                best_key = cached_key

        if best_key is None:
            return None, 0.0
        return best_key, best_ratio

    def _find_similar_key(self, input_str: str) -> str | None:
        """Find a cached key that is similar to the input string.

        Args:
            input_str: The serialized input string to match.

        Returns:
            The most similar cached key if above threshold, None otherwise.
        """
        return self._find_best_match(input_str)[0]

    async def function_middleware_invoke(self, value: Any, call_next: CallNext,
                                         context: FunctionMiddlewareContext) -> Any:
        """Cache middleware for single-output invocations.

        Implements the four-phase middleware pattern:

        1. **Preprocess**: Check if caching is enabled and serialize input
        2. **Call Next**: Delegate to next middleware/function if cache miss
        3. **Postprocess**: Store the result in cache
        4. **Continue**: Return the result (cached or fresh)

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function
            context: Metadata about the function being wrapped

        Returns:
            The cached output if found, otherwise the fresh output
        """
        # Phase 1: Preprocess - check if caching should be enabled
        if not self._should_cache():
            return await call_next(value)

        # Phase 1: Preprocess - serialize the input
        input_str = self._serialize_input(value)
        if input_str is None:
            # Can't serialize, pass through to next middleware/function
            logger.debug("Could not serialize input for function %s, bypassing cache", context.name)
            return await call_next(value)

        # Phase 1: Preprocess - look for a similar cached input
        similar_key, similarity = self._find_best_match(input_str)
        if similar_key is not None:
            # Cache hit - short-circuit; log the ratio that was actually measured.
            logger.debug("Cache hit for function %s with similarity %.2f", context.name, similarity)
            # Phase 4: Continue - return cached result
            return self._cache[similar_key]

        # Phase 2: Call next - no cache hit, call next middleware/function
        logger.debug("Cache miss for function %s", context.name)
        result = await call_next(value)

        # Phase 3: Postprocess - cache the result for future use
        self._cache[input_str] = result
        logger.debug("Cached result for function %s", context.name)

        # Phase 4: Continue - return the fresh result
        return result

    async def function_middleware_stream(self,
                                         value: Any,
                                         call_next: CallNextStream,
                                         context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Cache middleware for streaming invocations - bypasses caching.

        Streaming results are not cached as they would need to be buffered
        entirely in memory, which would defeat the purpose of streaming.

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function stream
            context: Metadata about the function being wrapped

        Yields:
            Chunks from the stream (unmodified)
        """
        # Phase 1: Preprocess - log that we're bypassing cache for streams
        logger.debug("Streaming call for function %s, bypassing cache", context.name)

        # Phase 2-3: Call next and forward each chunk as it arrives
        async for chunk in call_next(value):
            yield chunk

        # Phase 4: Continue - stream is complete (implicit)
231
+
232
+
233
class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"):
    """Settings model for the cache middleware (declared with ``name="cache"``).

    The two fields mirror the keyword arguments accepted by
    ``CacheMiddleware``.

    Args:
        enabled_mode: When the cache is active:
            - "always": caching is always on
            - "eval": caching is on only while Context.is_evaluating is True
        similarity_threshold: Value in ``[0, 1]`` that controls input matching:
            - 1.0: exact string matching (fastest)
            - < 1.0: fuzzy matching using difflib similarity
    """

    enabled_mode: Literal["always", "eval"] = Field(
        description="When caching is enabled: 'always' or 'eval' (only during evaluation)",
        default="eval",
    )
    similarity_threshold: float = Field(
        description="Similarity threshold between 0 and 1. Use 1.0 for exact matching",
        default=1.0,
        ge=0.0,
        le=1.0,
    )
254
+
255
+
256
+ __all__ = ["CacheMiddleware", "CacheMiddlewareConfig"]
@@ -0,0 +1,186 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Function-specific middleware for the NeMo Agent toolkit.
16
+
17
+ This module provides function-specific middleware implementations that extend
18
+ the base Middleware class. FunctionMiddleware is a specialized middleware type
19
+ designed specifically for wrapping function calls with dedicated methods
20
+ for function-specific preprocessing and postprocessing.
21
+
22
+ Middleware is configured at registration time and is bound to instances when they
23
+ are constructed by the workflow builder.
24
+
25
+ Middleware executes in the order provided and can optionally be marked as *final*.
26
+ A final middleware terminates the chain, preventing subsequent middleware or the
27
+ wrapped target from running unless the final middleware explicitly delegates to
28
+ the next callable.
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ from collections.abc import AsyncIterator
34
+ from collections.abc import Sequence
35
+ from typing import Any
36
+
37
+ from nat.middleware.middleware import CallNext
38
+ from nat.middleware.middleware import CallNextStream
39
+ from nat.middleware.middleware import FunctionMiddlewareContext
40
+ from nat.middleware.middleware import Middleware
41
+
42
+
43
class FunctionMiddleware(Middleware):
    """Middleware flavour dedicated to wrapping function calls.

    The generic ``middleware_invoke`` / ``middleware_stream`` entry points are
    routed to ``function_middleware_invoke`` / ``function_middleware_stream``.
    Subclasses customise behaviour by overriding the ``function_*`` hooks.
    """

    async def middleware_invoke(self, value: Any, call_next: CallNext, context: FunctionMiddlewareContext) -> Any:
        """Route a single-output call to ``function_middleware_invoke``."""
        outcome = await self.function_middleware_invoke(value, call_next, context)
        return outcome

    async def middleware_stream(self, value: Any, call_next: CallNextStream,
                                context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Route a streaming call to ``function_middleware_stream``."""
        stream = self.function_middleware_stream(value, call_next, context)
        async for item in stream:
            yield item

    async def function_middleware_invoke(self, value: Any, call_next: CallNext,
                                         context: FunctionMiddlewareContext) -> Any:
        """Hook for single-output function invocations.

        The base implementation awaits ``call_next(value)`` and returns its
        result unchanged; ``context`` is not consulted. Override to add
        preprocessing and postprocessing.

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function
            context: Metadata about the function being wrapped

        Returns:
            The output produced by ``call_next``
        """
        outcome = await call_next(value)
        return outcome

    async def function_middleware_stream(self,
                                         value: Any,
                                         call_next: CallNextStream,
                                         context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Hook for streaming function invocations.

        The base implementation re-yields every chunk of ``call_next(value)``
        unchanged; ``context`` is not consulted. Override to add preprocessing
        or chunk transformations.

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or function stream
            context: Metadata about the function being wrapped

        Yields:
            Chunks produced by ``call_next``
        """
        async for item in call_next(value):
            yield item
98
+
99
+
100
class FunctionMiddlewareChain:
    """Composes a sequence of middleware into a single callable.

    The first middleware in the sequence becomes the outermost layer, so the
    layers run in the order given; each receives a ``call_next`` that leads to
    the following layer and, finally, to the wrapped function.
    """

    def __init__(self, *, middleware: Sequence[Middleware], context: FunctionMiddlewareContext) -> None:
        # Snapshot the sequence so later mutation by the caller has no effect.
        self._context = context
        self._middleware = tuple(middleware)

    def build_single(self, final_call: CallNext) -> CallNext:
        """Build the middleware chain for single-output invocations.

        Args:
            final_call: The innermost callable (the actual function implementation)

        Returns:
            A callable that runs every middleware layer around ``final_call``;
            ``final_call`` itself when there is no middleware.
        """

        def bind(layer: Middleware, downstream: CallNext) -> CallNext:
            # A factory gives each layer its own closure cell, which avoids the
            # late-binding pitfall of closures created in a loop.
            async def wrapped(value: Any) -> Any:
                return await layer.middleware_invoke(value, downstream, self._context)

            return wrapped

        chain = final_call
        # Wrap from the innermost layer outwards so the first middleware ends up outermost.
        for layer in reversed(self._middleware):
            chain = bind(layer, chain)
        return chain

    def build_stream(self, final_call: CallNextStream) -> CallNextStream:
        """Build the middleware chain for streaming invocations.

        Args:
            final_call: The innermost callable (the actual function implementation)

        Returns:
            A callable that runs every middleware layer around ``final_call``;
            ``final_call`` itself when there is no middleware.
        """

        def bind(layer: Middleware, downstream: CallNextStream) -> CallNextStream:
            async def wrapped(value: Any) -> AsyncIterator[Any]:
                async for item in layer.middleware_stream(value, downstream, self._context):
                    yield item

            return wrapped

        chain = final_call
        for layer in reversed(self._middleware):
            chain = bind(layer, chain)
        return chain
157
+
158
+
159
def validate_middleware(middleware: Sequence[Middleware] | None) -> tuple[Middleware, ...]:
    """Validate a sequence of middleware, enforcing ordering guarantees.

    Args:
        middleware: The middleware to validate; ``None`` or an empty sequence
            is accepted and yields an empty tuple.

    Returns:
        The validated middleware as a tuple, in the original order.

    Raises:
        TypeError: If any element is not a ``Middleware`` instance.
        ValueError: If more than one middleware is marked final, or if the
            single final middleware is not the last element.
    """

    if not middleware:
        return ()

    chain = tuple(middleware)

    # Collect the positions of all final middleware first. Checking the count
    # before the position lets the "only one final" error be reported for
    # duplicate finals instead of always being pre-empted by the position check.
    final_positions: list[int] = []
    for idx, mw in enumerate(chain):
        if not isinstance(mw, Middleware):
            raise TypeError("All middleware must be instances of Middleware")

        if mw.is_final:
            final_positions.append(idx)

    if len(final_positions) > 1:
        raise ValueError("Only one final Middleware may be specified per function")

    if final_positions and final_positions[0] != len(chain) - 1:
        raise ValueError("A final Middleware must be the last middleware in the chain")

    return chain
180
+
181
+
182
+ __all__ = [
183
+ "FunctionMiddleware",
184
+ "FunctionMiddlewareChain",
185
+ "validate_middleware",
186
+ ]
@@ -0,0 +1,184 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Base middleware class for the NeMo Agent toolkit.
16
+
17
+ This module provides the base Middleware class that defines the middleware pattern
18
+ for wrapping and modifying function calls. Middleware works like middleware in
19
+ web frameworks - they can modify inputs, call the next middleware in the chain,
20
+ process outputs, and continue.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import dataclasses
26
+ from abc import ABC
27
+ from collections.abc import AsyncIterator
28
+ from collections.abc import Awaitable
29
+ from collections.abc import Callable
30
+ from typing import Any
31
+
32
+ from pydantic import BaseModel
33
+
34
+ #: Type alias for single-output invocation callables.
35
+ CallNext = Callable[[Any], Awaitable[Any]]
36
+
37
+ #: Type alias for streaming invocation callables.
38
+ CallNextStream = Callable[[Any], AsyncIterator[Any]]
39
+
40
+
41
+ @dataclasses.dataclass(frozen=True, kw_only=True)
42
+ class FunctionMiddlewareContext:
43
+ """Context information about the function being wrapped by middleware.
44
+
45
+ Middleware receives this context object which describes the function they
46
+ are wrapping. This allows middleware to make decisions based on the
47
+ function's name, configuration, schema, etc.
48
+ """
49
+
50
+ name: str
51
+ """Name of the function being wrapped."""
52
+
53
+ config: Any
54
+ """Configuration object for the function."""
55
+
56
+ description: str | None
57
+ """Optional description of the function."""
58
+
59
+ input_schema: type[BaseModel] | None
60
+ """Schema describing expected inputs or :class:`NoneType` when absent."""
61
+
62
+ single_output_schema: type[BaseModel] | type[None]
63
+ """Schema describing single outputs or :class:`types.NoneType` when absent."""
64
+
65
+ stream_output_schema: type[BaseModel] | type[None]
66
+ """Schema describing streaming outputs or :class:`types.NoneType` when absent."""
67
+
68
+
69
class Middleware(ABC):
    """Base class for components that wrap a call, web-framework style.

    A middleware typically goes through four steps:

    1. **Preprocess**: Inspect and optionally modify inputs
    2. **Call Next**: Delegate to the next middleware or the target itself
    3. **Postprocess**: Process, transform, or augment the output
    4. **Continue**: Return or yield the final result

    Example::

        class LoggingMiddleware(Middleware):
            async def middleware_invoke(self, value, call_next, context):
                print(f"Input: {value}")           # 1. Preprocess
                result = await call_next(value)    # 2. Call next middleware/target
                print(f"Output: {result}")         # 3. Postprocess
                return result                      # 4. Continue

    Attributes:
        is_final: If True, this middleware terminates the chain. No subsequent
            middleware or the target will be called unless this middleware
            explicitly delegates to ``call_next``.
    """

    def __init__(self, *, is_final: bool = False) -> None:
        # Stored privately and exposed read-only through the ``is_final`` property.
        self._is_final = is_final

    @property
    def is_final(self) -> bool:
        """Whether this middleware terminates the chain.

        A final middleware prevents subsequent middleware and the target
        from running unless it explicitly calls ``call_next``.
        """
        return self._is_final

    async def middleware_invoke(self, value: Any, call_next: CallNext, context: FunctionMiddlewareContext) -> Any:
        """Handle a single-output invocation.

        The base implementation passes ``value`` straight to ``call_next`` and
        returns whatever it produces. Override it to add preprocessing,
        postprocessing, or to short-circuit execution::

            async def middleware_invoke(self, value, call_next, context):
                modified_input = transform(value)                # preprocess
                result = await call_next(modified_input)         # call next
                modified_result = transform_output(result)       # postprocess
                return modified_result                           # continue

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or target
            context: Metadata about the target being wrapped

        Returns:
            The output of ``call_next``
        """
        del context  # Unused by the default implementation.
        outcome = await call_next(value)
        return outcome

    async def middleware_stream(self, value: Any, call_next: CallNextStream,
                                context: FunctionMiddlewareContext) -> AsyncIterator[Any]:
        """Handle a streaming invocation.

        The base implementation re-yields each chunk of ``call_next(value)``
        untouched. Override it to add preprocessing, transform chunks, or
        perform cleanup::

            async def middleware_stream(self, value, call_next, context):
                modified_input = transform(value)                # preprocess
                async for chunk in call_next(modified_input):    # call next
                    yield transform_chunk(chunk)                 # per-chunk work
                await cleanup()                                  # after the stream ends

        Args:
            value: The input value to process
            call_next: Callable to invoke the next middleware or target stream
            context: Metadata about the target being wrapped

        Yields:
            Chunks produced by ``call_next``
        """
        del context  # Unused by the default implementation.
        stream = call_next(value)
        async for item in stream:
            yield item
177
+
178
+
179
+ __all__ = [
180
+ "CallNext",
181
+ "CallNextStream",
182
+ "Middleware",
183
+ "FunctionMiddlewareContext",
184
+ ]
@@ -0,0 +1,35 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Registration module for built-in middleware."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from nat.cli.register_workflow import register_middleware
20
+ from nat.middleware.cache_middleware import CacheMiddleware
21
+ from nat.middleware.cache_middleware import CacheMiddlewareConfig
22
+
23
+
24
@register_middleware(config_type=CacheMiddlewareConfig)
async def cache_middleware(config: CacheMiddlewareConfig, builder):
    """Create the cache middleware described by ``config``.

    Args:
        config: The cache middleware configuration; its ``enabled_mode`` and
            ``similarity_threshold`` are forwarded to ``CacheMiddleware``.
        builder: The workflow builder (not used by this factory).

    Yields:
        A configured cache middleware instance
    """
    options = {
        "enabled_mode": config.enabled_mode,
        "similarity_threshold": config.similarity_threshold,
    }
    instance = CacheMiddleware(**options)
    yield instance
@@ -34,6 +34,7 @@ _library_instrumented = {
34
34
  "semantic_kernel": False,
35
35
  "agno": False,
36
36
  "adk": False,
37
+ "strands": False,
37
38
  }
38
39
 
39
40
  callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)
@@ -131,6 +132,21 @@ def set_framework_profiler_handler(
131
132
  _library_instrumented["adk"] = True
132
133
  logger.debug("ADK callback handler registered")
133
134
 
135
+ if (LLMFrameworkEnum.STRANDS in frameworks and not _library_instrumented["strands"]):
136
+ try:
137
+ from nat.plugins.strands.strands_callback_handler import StrandsProfilerHandler
138
+ except ImportError as e:
139
+ logger.warning(
140
+ "Strands profiler not available. Install NAT with Strands extras: "
141
+ "pip install \"nvidia-nat[strands]\". Error: %s",
142
+ e,
143
+ )
144
+ else:
145
+ handler = StrandsProfilerHandler()
146
+ handler.instrument()
147
+ _library_instrumented["strands"] = True
148
+ logger.debug("Strands callback handler registered")
149
+
134
150
  # IMPORTANT: actually call the wrapped function as an async context manager
135
151
  async with func(workflow_config, builder) as result:
136
152
  yield result