nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. nat/builder/builder.py +52 -0
  2. nat/builder/component_utils.py +7 -1
  3. nat/builder/context.py +17 -0
  4. nat/builder/framework_enum.py +1 -0
  5. nat/builder/function.py +74 -3
  6. nat/builder/workflow.py +4 -2
  7. nat/builder/workflow_builder.py +129 -0
  8. nat/cli/register_workflow.py +50 -0
  9. nat/cli/type_registry.py +68 -0
  10. nat/data_models/component.py +2 -0
  11. nat/data_models/component_ref.py +11 -0
  12. nat/data_models/config.py +16 -0
  13. nat/data_models/function.py +14 -1
  14. nat/data_models/middleware.py +35 -0
  15. nat/data_models/runtime_enum.py +26 -0
  16. nat/eval/evaluate.py +10 -2
  17. nat/front_ends/fastapi/fastapi_front_end_config.py +22 -0
  18. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +124 -0
  19. nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
  20. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
  21. nat/middleware/__init__.py +35 -0
  22. nat/middleware/cache_middleware.py +256 -0
  23. nat/middleware/function_middleware.py +186 -0
  24. nat/middleware/middleware.py +184 -0
  25. nat/middleware/register.py +35 -0
  26. nat/profiler/decorators/framework_wrapper.py +16 -0
  27. nat/retriever/milvus/register.py +11 -3
  28. nat/retriever/milvus/retriever.py +102 -40
  29. nat/runtime/runner.py +12 -1
  30. nat/runtime/session.py +10 -3
  31. nat/tool/code_execution/code_sandbox.py +1 -1
  32. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +9 -3
  33. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +38 -31
  34. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
  35. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
  36. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  37. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
  38. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/builder/builder.py CHANGED
@@ -32,6 +32,7 @@ from nat.data_models.component_ref import FunctionGroupRef
32
32
  from nat.data_models.component_ref import FunctionRef
33
33
  from nat.data_models.component_ref import LLMRef
34
34
  from nat.data_models.component_ref import MemoryRef
35
+ from nat.data_models.component_ref import MiddlewareRef
35
36
  from nat.data_models.component_ref import ObjectStoreRef
36
37
  from nat.data_models.component_ref import RetrieverRef
37
38
  from nat.data_models.component_ref import TTCStrategyRef
@@ -42,6 +43,7 @@ from nat.data_models.function import FunctionGroupBaseConfig
42
43
  from nat.data_models.function_dependencies import FunctionDependencies
43
44
  from nat.data_models.llm import LLMBaseConfig
44
45
  from nat.data_models.memory import MemoryBaseConfig
46
+ from nat.data_models.middleware import MiddlewareBaseConfig
45
47
  from nat.data_models.object_store import ObjectStoreBaseConfig
46
48
  from nat.data_models.retriever import RetrieverBaseConfig
47
49
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
@@ -49,6 +51,7 @@ from nat.experimental.decorators.experimental_warning_decorator import experimen
49
51
  from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
50
52
  from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
51
53
  from nat.memory.interfaces import MemoryEditor
54
+ from nat.middleware.middleware import Middleware
52
55
  from nat.object_store.interfaces import ObjectStore
53
56
  from nat.retriever.interface import Retriever
54
57
 
@@ -289,6 +292,55 @@ class Builder(ABC):
289
292
  def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
290
293
  pass
291
294
 
295
+ @abstractmethod
296
+ async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
297
+ """Add middleware to the builder.
298
+
299
+ Args:
300
+ name: The name or reference for the middleware
301
+ config: The configuration for the middleware
302
+
303
+ Returns:
304
+ The built middleware instance
305
+ """
306
+ pass
307
+
308
+ @abstractmethod
309
+ async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
310
+ """Get built middleware by name.
311
+
312
+ Args:
313
+ middleware_name: The name or reference of the middleware
314
+
315
+ Returns:
316
+ The built middleware instance
317
+ """
318
+ pass
319
+
320
+ @abstractmethod
321
+ def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
322
+ """Get the configuration for middleware.
323
+
324
+ Args:
325
+ middleware_name: The name or reference of the middleware
326
+
327
+ Returns:
328
+ The configuration for the middleware
329
+ """
330
+ pass
331
+
332
+ async def get_middleware_list(self, middleware_names: Sequence[str | MiddlewareRef]) -> list[Middleware]:
333
+ """Get multiple middleware by name.
334
+
335
+ Args:
336
+ middleware_names: The names or references of the middleware
337
+
338
+ Returns:
339
+ List of built middleware instances
340
+ """
341
+ tasks = [self.get_middleware(name) for name in middleware_names]
342
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
343
+
292
344
 
293
345
  class EvalBuilder(ABC):
294
346
 
nat/builder/component_utils.py CHANGED
@@ -33,6 +33,7 @@ from nat.data_models.function import FunctionBaseConfig
33
33
  from nat.data_models.function import FunctionGroupBaseConfig
34
34
  from nat.data_models.llm import LLMBaseConfig
35
35
  from nat.data_models.memory import MemoryBaseConfig
36
+ from nat.data_models.middleware import MiddlewareBaseConfig
36
37
  from nat.data_models.object_store import ObjectStoreBaseConfig
37
38
  from nat.data_models.retriever import RetrieverBaseConfig
38
39
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
@@ -41,6 +42,7 @@ from nat.utils.type_utils import DecomposedType
41
42
  logger = logging.getLogger(__name__)
42
43
 
43
44
  # Order in which we want to process the component groups
45
+ # IMPORTANT: MIDDLEWARE must be built before FUNCTIONS
44
46
  _component_group_order = [
45
47
  ComponentGroup.AUTHENTICATION,
46
48
  ComponentGroup.EMBEDDERS,
@@ -49,6 +51,7 @@ _component_group_order = [
49
51
  ComponentGroup.OBJECT_STORES,
50
52
  ComponentGroup.RETRIEVERS,
51
53
  ComponentGroup.TTC_STRATEGIES,
54
+ ComponentGroup.MIDDLEWARE,
52
55
  ComponentGroup.FUNCTION_GROUPS,
53
56
  ComponentGroup.FUNCTIONS,
54
57
  ]
@@ -111,6 +114,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
111
114
  return ComponentGroup.FUNCTIONS
112
115
  if (isinstance(component, FunctionGroupBaseConfig)):
113
116
  return ComponentGroup.FUNCTION_GROUPS
117
+ if (isinstance(component, MiddlewareBaseConfig)):
118
+ return ComponentGroup.MIDDLEWARE
114
119
  if (isinstance(component, LLMBaseConfig)):
115
120
  return ComponentGroup.LLMS
116
121
  if (isinstance(component, MemoryBaseConfig)):
@@ -260,7 +265,8 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
260
265
 
261
266
  total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
262
267
  len(config.memory) + len(config.object_stores) + len(config.retrievers) +
263
- len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
268
+ len(config.ttc_strategies) + len(config.authentication) + len(config.middleware) + 1
269
+ ) # +1 for the workflow
264
270
 
265
271
  dependency_map: dict
266
272
  dependency_graph: nx.DiGraph
nat/builder/context.py CHANGED
@@ -34,6 +34,7 @@ from nat.data_models.intermediate_step import IntermediateStepType
34
34
  from nat.data_models.intermediate_step import StreamEventData
35
35
  from nat.data_models.intermediate_step import TraceMetadata
36
36
  from nat.data_models.invocation_node import InvocationNode
37
+ from nat.data_models.runtime_enum import RuntimeTypeEnum
37
38
  from nat.runtime.user_metadata import RequestAttributes
38
39
  from nat.utils.reactive.subject import Subject
39
40
 
@@ -72,6 +73,8 @@ class ContextState(metaclass=Singleton):
72
73
  self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
73
74
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
74
75
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
76
+ self.runtime_type: ContextVar[RuntimeTypeEnum] = ContextVar("runtime_type",
77
+ default=RuntimeTypeEnum.RUN_OR_SERVE)
75
78
  self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
76
79
  self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
77
80
  self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
@@ -302,6 +305,20 @@ class Context:
302
305
  raise RuntimeError("User authentication callback is not set in the context.")
303
306
  return callback
304
307
 
308
+ @property
309
+ def is_evaluating(self) -> bool:
310
+ """
311
+ Indicates whether the current context is in evaluation mode.
312
+
313
+ This property checks the context state to determine if the current
314
+ operation is being performed in evaluation mode. It returns a boolean
315
+ value indicating the evaluation status.
316
+
317
+ Returns:
318
+ bool: True if in evaluation mode, False otherwise.
319
+ """
320
+ return self._context_state.runtime_type.get() == RuntimeTypeEnum.EVALUATE
321
+
305
322
  @staticmethod
306
323
  def get() -> "Context":
307
324
  """
nat/builder/framework_enum.py CHANGED
@@ -23,3 +23,4 @@ class LLMFrameworkEnum(str, Enum):
23
23
  SEMANTIC_KERNEL = "semantic_kernel"
24
24
  AGNO = "agno"
25
25
  ADK = "adk"
26
+ STRANDS = "strands"
nat/builder/function.py CHANGED
@@ -34,6 +34,9 @@ from nat.builder.function_info import FunctionInfo
34
34
  from nat.data_models.function import EmptyFunctionConfig
35
35
  from nat.data_models.function import FunctionBaseConfig
36
36
  from nat.data_models.function import FunctionGroupBaseConfig
37
+ from nat.middleware.function_middleware import FunctionMiddlewareChain
38
+ from nat.middleware.middleware import FunctionMiddlewareContext
39
+ from nat.middleware.middleware import Middleware
37
40
 
38
41
  _InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
39
42
  _StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
@@ -64,6 +67,9 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
64
67
  self.description = description
65
68
  self.instance_name = instance_name or config.type
66
69
  self._context = Context.get()
70
+ self._configured_middleware: tuple[Middleware, ...] = tuple()
71
+ self._middlewared_single: _InvokeFnT | None = None
72
+ self._middlewared_stream: _StreamFnT | None = None
67
73
 
68
74
  def convert(self, value: typing.Any, to_type: type[_T]) -> _T:
69
75
  """
@@ -108,6 +114,38 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
108
114
  """
109
115
  return self._converter.try_convert(value, to_type=to_type)
110
116
 
117
+ @property
118
+ def middleware(self) -> tuple[Middleware, ...]:
119
+ """Return the currently configured middleware chain."""
120
+
121
+ return self._configured_middleware
122
+
123
+ def configure_middleware(self, middleware: Sequence[Middleware] | None = None) -> None:
124
+ """Attach an ordered list of middleware to this function instance."""
125
+
126
+ middleware_tuple: tuple[Middleware, ...] = tuple(middleware or ())
127
+
128
+ self._configured_middleware = middleware_tuple
129
+
130
+ if not middleware_tuple:
131
+ self._middlewared_single = None
132
+ self._middlewared_stream = None
133
+ return
134
+
135
+ logger.info(f"Building middleware for function '{self.instance_name}' in order of: {middleware_tuple}")
136
+
137
+ context = FunctionMiddlewareContext(name=self.instance_name,
138
+ config=self.config,
139
+ description=self.description,
140
+ input_schema=self.input_schema,
141
+ single_output_schema=self.single_output_schema,
142
+ stream_output_schema=self.streaming_output_schema)
143
+
144
+ chain = FunctionMiddlewareChain(middleware=middleware_tuple, context=context)
145
+
146
+ self._middlewared_single = chain.build_single(self._ainvoke) if self.has_single_output else None
147
+ self._middlewared_stream = chain.build_stream(self._astream) if self.has_streaming_output else None
148
+
111
149
  @abstractmethod
112
150
  async def _ainvoke(self, value: InputT) -> SingleOutputT:
113
151
  pass
@@ -150,7 +188,9 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
150
188
  try:
151
189
  converted_input: InputT = self._convert_input(value)
152
190
 
153
- result = await self._ainvoke(converted_input)
191
+ invoke_callable = self._middlewared_single or self._ainvoke
192
+
193
+ result = await invoke_callable(converted_input)
154
194
 
155
195
  if to_type is not None and not isinstance(result, to_type):
156
196
  result = self.convert(result, to_type)
@@ -243,7 +283,9 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
243
283
  # Collect streaming outputs to capture the final result
244
284
  final_output: list[typing.Any] = []
245
285
 
246
- async for data in self._astream(converted_input):
286
+ stream_callable = self._middlewared_stream or self._astream
287
+
288
+ async for data in stream_callable(converted_input):
247
289
  if to_type is not None and not isinstance(data, to_type):
248
290
  converted_data = self.convert(data, to_type=to_type)
249
291
  final_output.append(converted_data)
@@ -357,7 +399,8 @@ class FunctionGroup:
357
399
  *,
358
400
  config: FunctionGroupBaseConfig,
359
401
  instance_name: str | None = None,
360
- filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None):
402
+ filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
403
+ middleware: Sequence[Middleware] | None = None):
361
404
  """
362
405
  Creates a new function group.
363
406
 
@@ -370,12 +413,15 @@ class FunctionGroup:
370
413
  filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
371
414
  A callback function to additionally filter the functions in the function group dynamically when
372
415
  the functions are accessed via any accessor method.
416
+ middleware : Sequence[Middleware] | None, optional
417
+ The middleware instances to apply to all functions in this group.
373
418
  """
374
419
  self._config = config
375
420
  self._instance_name = instance_name or config.type
376
421
  self._functions: dict[str, Function] = dict()
377
422
  self._filter_fn = filter_fn
378
423
  self._per_function_filter_fn: dict[str, Callable[[str], Awaitable[bool]]] = dict()
424
+ self._middleware: tuple[Middleware, ...] = tuple(middleware or ())
379
425
 
380
426
  def add_function(self,
381
427
  name: str,
@@ -424,6 +470,9 @@ class FunctionGroup:
424
470
  info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
425
471
  full_name = self._get_fn_name(name)
426
472
  lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
473
+ # Configure middleware from the function group if any
474
+ if self._middleware:
475
+ lambda_fn.configure_middleware(self._middleware)
427
476
  self._functions[name] = lambda_fn
428
477
  if filter_fn:
429
478
  self._per_function_filter_fn[name] = filter_fn
@@ -712,3 +761,25 @@ class FunctionGroup:
712
761
  Returns the instance name for the function group.
713
762
  """
714
763
  return self._instance_name
764
+
765
+ @property
766
+ def middleware(self) -> tuple[Middleware, ...]:
767
+ """
768
+ Returns the middleware configured for this function group.
769
+ """
770
+ return self._middleware
771
+
772
+ def configure_middleware(self, middleware: Sequence[Middleware] | None = None) -> None:
773
+ """
774
+ Configure the middleware for this function group.
775
+ These middleware will be applied to all functions added to the group.
776
+
777
+ Parameters
778
+ ----------
779
+ middleware : Sequence[Middleware] | None
780
+ The middleware to configure for the function group.
781
+ """
782
+ self._middleware = tuple(middleware or ())
783
+ # Update existing functions with the new middleware
784
+ for func in self._functions.values():
785
+ func.configure_middleware(self._middleware)
nat/builder/workflow.py CHANGED
@@ -28,6 +28,7 @@ from nat.builder.function_base import StreamingOutputT
28
28
  from nat.builder.llm import LLMProviderInfo
29
29
  from nat.builder.retriever import RetrieverProviderInfo
30
30
  from nat.data_models.config import Config
31
+ from nat.data_models.runtime_enum import RuntimeTypeEnum
31
32
  from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
32
33
  from nat.memory.interfaces import MemoryEditor
33
34
  from nat.object_store.interfaces import ObjectStore
@@ -94,7 +95,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
94
95
  return self._exporter_manager.get()
95
96
 
96
97
  @asynccontextmanager
97
- async def run(self, message: InputT):
98
+ async def run(self, message: InputT, runtime_type: RuntimeTypeEnum = RuntimeTypeEnum.RUN_OR_SERVE):
98
99
  """
99
100
  Called each time we start a new workflow run. We'll create
100
101
  a new top-level workflow span here.
@@ -103,7 +104,8 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
103
104
  async with Runner(input_message=message,
104
105
  entry_fn=self._entry_fn,
105
106
  context_state=self._context_state,
106
- exporter_manager=self.exporter_manager) as runner:
107
+ exporter_manager=self.exporter_manager,
108
+ runtime_type=runtime_type) as runner:
107
109
 
108
110
  # The caller can `yield runner` so they can do `runner.result()` or `runner.result_stream()`
109
111
  yield runner
nat/builder/workflow_builder.py CHANGED
@@ -51,6 +51,7 @@ from nat.data_models.component_ref import FunctionGroupRef
51
51
  from nat.data_models.component_ref import FunctionRef
52
52
  from nat.data_models.component_ref import LLMRef
53
53
  from nat.data_models.component_ref import MemoryRef
54
+ from nat.data_models.component_ref import MiddlewareRef
54
55
  from nat.data_models.component_ref import ObjectStoreRef
55
56
  from nat.data_models.component_ref import RetrieverRef
56
57
  from nat.data_models.component_ref import TTCStrategyRef
@@ -62,6 +63,7 @@ from nat.data_models.function import FunctionGroupBaseConfig
62
63
  from nat.data_models.function_dependencies import FunctionDependencies
63
64
  from nat.data_models.llm import LLMBaseConfig
64
65
  from nat.data_models.memory import MemoryBaseConfig
66
+ from nat.data_models.middleware import MiddlewareBaseConfig
65
67
  from nat.data_models.object_store import ObjectStoreBaseConfig
66
68
  from nat.data_models.retriever import RetrieverBaseConfig
67
69
  from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
@@ -71,6 +73,8 @@ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEn
71
73
  from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
72
74
  from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
73
75
  from nat.memory.interfaces import MemoryEditor
76
+ from nat.middleware.function_middleware import FunctionMiddleware
77
+ from nat.middleware.middleware import Middleware
74
78
  from nat.object_store.interfaces import ObjectStore
75
79
  from nat.observability.exporter.base_exporter import BaseExporter
76
80
  from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
@@ -141,6 +145,12 @@ class ConfiguredTTCStrategy:
141
145
  instance: StrategyBase
142
146
 
143
147
 
148
+ @dataclasses.dataclass
149
+ class ConfiguredMiddleware:
150
+ config: MiddlewareBaseConfig
151
+ instance: Middleware
152
+
153
+
144
154
  class WorkflowBuilder(Builder, AbstractAsyncContextManager):
145
155
 
146
156
  def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
@@ -170,6 +180,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
170
180
  self._object_stores: dict[str, ConfiguredObjectStore] = {}
171
181
  self._retrievers: dict[str, ConfiguredRetriever] = {}
172
182
  self._ttc_strategies: dict[str, ConfiguredTTCStrategy] = {}
183
+ self._middleware: dict[str, ConfiguredMiddleware] = {}
173
184
 
174
185
  self._context_state = ContextState.get()
175
186
 
@@ -423,6 +434,22 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
423
434
  raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
424
435
  f"returned from the function builder. Got {type(build_result)}")
425
436
 
437
+ # Resolve middleware names from config to middleware instances
438
+ # Only FunctionMiddleware types can be used with functions
439
+ middleware_instances = []
440
+ for middleware_name in config.middleware:
441
+ if middleware_name not in self._middleware:
442
+ raise ValueError(f"Middleware `{middleware_name}` not found for function `{name}`. "
443
+ f"It must be configured in the `middleware` section of the YAML configuration.")
444
+ middleware_obj = self._middleware[middleware_name].instance
445
+ if not isinstance(middleware_obj, FunctionMiddleware):
446
+ raise TypeError(
447
+ f"Middleware `{middleware_name}` is not a FunctionMiddleware and cannot be used with functions. "
448
+ f"Only FunctionMiddleware types support function-specific wrapping.")
449
+ middleware_instances.append(middleware_obj)
450
+
451
+ build_result.configure_middleware(middleware_instances)
452
+
426
453
  return ConfiguredFunction(config=config, instance=build_result)
427
454
 
428
455
  async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
@@ -461,6 +488,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
461
488
  raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
462
489
  f"Got {type(build_result)}")
463
490
 
491
+ # Resolve middleware names from config to middleware instances
492
+ # Only FunctionMiddleware types can be used with function groups
493
+ middleware_instances = []
494
+ for middleware_name in config.middleware:
495
+ if middleware_name not in self._middleware:
496
+ raise ValueError(f"Middleware `{middleware_name}` not found for function group `{name}`. "
497
+ f"It must be configured in the `middleware` section of the YAML configuration.")
498
+ middleware_obj = self._middleware[middleware_name].instance
499
+ if not isinstance(middleware_obj, FunctionMiddleware):
500
+ raise TypeError(f"Middleware `{middleware_name}` is not a FunctionMiddleware and "
501
+ f"cannot be used with function groups. "
502
+ f"Only FunctionMiddleware types support function-specific wrapping.")
503
+ middleware_instances.append(middleware_obj)
504
+
505
+ # Configure middleware for the function group
506
+ build_result.configure_middleware(middleware_instances)
507
+
464
508
  # set the instance name for the function group based on the workflow-provided name
465
509
  build_result.set_instance_name(name)
466
510
  return ConfiguredFunctionGroup(config=config, instance=build_result)
@@ -968,6 +1012,72 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
968
1012
 
969
1013
  return config
970
1014
 
1015
+ @override
1016
+ async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
1017
+ """Add middleware to the builder.
1018
+
1019
+ Args:
1020
+ name: The name or reference for the middleware
1021
+ config: The configuration for the middleware
1022
+
1023
+ Returns:
1024
+ The built middleware instance
1025
+
1026
+ Raises:
1027
+ ValueError: If the middleware already exists
1028
+ """
1029
+ if name in self._middleware:
1030
+ raise ValueError(f"Middleware `{name}` already exists in the list of middleware")
1031
+
1032
+ try:
1033
+ middleware_info = self._registry.get_middleware(type(config))
1034
+
1035
+ middleware_instance = await self._get_exit_stack().enter_async_context(
1036
+ middleware_info.build_fn(config, self))
1037
+
1038
+ self._middleware[name] = ConfiguredMiddleware(config=config, instance=middleware_instance)
1039
+
1040
+ return middleware_instance
1041
+ except Exception as e:
1042
+ logger.error("Error adding function middleware `%s` with config `%s`: %s", name, config, e)
1043
+ raise
1044
+
1045
+ @override
1046
+ async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
1047
+ """Get built middleware by name.
1048
+
1049
+ Args:
1050
+ middleware_name: The name or reference of the middleware
1051
+
1052
+ Returns:
1053
+ The built middleware instance
1054
+
1055
+ Raises:
1056
+ ValueError: If the middleware is not found
1057
+ """
1058
+ if middleware_name not in self._middleware:
1059
+ raise ValueError(f"Middleware `{middleware_name}` not found")
1060
+
1061
+ return self._middleware[middleware_name].instance
1062
+
1063
+ @override
1064
+ def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
1065
+ """Get the configuration for middleware.
1066
+
1067
+ Args:
1068
+ middleware_name: The name or reference of the middleware
1069
+
1070
+ Returns:
1071
+ The configuration for the middleware
1072
+
1073
+ Raises:
1074
+ ValueError: If the middleware is not found
1075
+ """
1076
+ if middleware_name not in self._middleware:
1077
+ raise ValueError(f"Middleware `{middleware_name}` not found")
1078
+
1079
+ return self._middleware[middleware_name].config
1080
+
971
1081
  @override
972
1082
  def get_user_manager(self):
973
1083
  return UserManagerHolder(context=Context(self._context_state))
@@ -1108,6 +1218,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
1108
1218
  elif component_instance.component_group == ComponentGroup.RETRIEVERS:
1109
1219
  await self.add_retriever(component_instance.name,
1110
1220
  cast(RetrieverBaseConfig, component_instance.config))
1221
+ # Instantiate middleware
1222
+ elif component_instance.component_group == ComponentGroup.MIDDLEWARE:
1223
+ await self.add_middleware(component_instance.name,
1224
+ cast(MiddlewareBaseConfig, component_instance.config))
1111
1225
  # Instantiate a function group
1112
1226
  elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
1113
1227
  await self.add_function_group(component_instance.name,
@@ -1363,3 +1477,18 @@ class ChildBuilder(Builder):
1363
1477
  @override
1364
1478
  def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
1365
1479
  return self._workflow_builder.get_function_group_dependencies(fn_name)
1480
+
1481
+ @override
1482
+ async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
1483
+ """Add middleware to the builder."""
1484
+ return await self._workflow_builder.add_middleware(name, config)
1485
+
1486
+ @override
1487
+ async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
1488
+ """Get built middleware by name."""
1489
+ return await self._workflow_builder.get_middleware(middleware_name)
1490
+
1491
+ @override
1492
+ def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
1493
+ """Get the configuration for middleware."""
1494
+ return self._workflow_builder.get_middleware_config(middleware_name)
nat/cli/register_workflow.py CHANGED
@@ -38,6 +38,8 @@ from nat.cli.type_registry import LoggingMethodConfigT
38
38
  from nat.cli.type_registry import LoggingMethodRegisteredCallableT
39
39
  from nat.cli.type_registry import MemoryBuildCallableT
40
40
  from nat.cli.type_registry import MemoryRegisteredCallableT
41
+ from nat.cli.type_registry import MiddlewareBuildCallableT
42
+ from nat.cli.type_registry import MiddlewareRegisteredCallableT
41
43
  from nat.cli.type_registry import ObjectStoreBuildCallableT
42
44
  from nat.cli.type_registry import ObjectStoreRegisteredCallableT
43
45
  from nat.cli.type_registry import RegisteredLoggingMethod
@@ -65,6 +67,7 @@ from nat.data_models.function import FunctionConfigT
65
67
  from nat.data_models.function import FunctionGroupConfigT
66
68
  from nat.data_models.llm import LLMBaseConfigT
67
69
  from nat.data_models.memory import MemoryBaseConfigT
70
+ from nat.data_models.middleware import MiddlewareBaseConfigT
68
71
  from nat.data_models.object_store import ObjectStoreBaseConfigT
69
72
  from nat.data_models.registry_handler import RegistryHandlerBaseConfigT
70
73
  from nat.data_models.retriever import RetrieverBaseConfigT
@@ -149,6 +152,10 @@ def register_function(config_type: type[FunctionConfigT],
149
152
  framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
150
153
  """
151
154
  Register a workflow with optional framework_wrappers for automatic profiler hooking.
155
+
156
+ Args:
157
+ config_type: The function configuration type
158
+ framework_wrappers: Optional list of framework wrappers for automatic profiler hooking
152
159
  """
153
160
 
154
161
  def register_function_inner(
@@ -211,6 +218,49 @@ def register_function_group(config_type: type[FunctionGroupConfigT],
211
218
  return register_function_group_inner
212
219
 
213
220
 
221
+ def register_middleware(config_type: type[MiddlewareBaseConfigT]):
222
+ """
223
+ Register a middleware component.
224
+
225
+ Middleware provides middleware-style wrapping of calls with
226
+ preprocessing and postprocessing logic. They are built as components that can
227
+ be configured in YAML and referenced by name in configurations.
228
+
229
+ Args:
230
+ config_type: The middleware configuration type to register
231
+
232
+ Returns:
233
+ A decorator that wraps the build function as an async context manager
234
+ """
235
+
236
+ def register_middleware_inner(
237
+ fn: MiddlewareBuildCallableT[MiddlewareBaseConfigT]
238
+ ) -> MiddlewareRegisteredCallableT[MiddlewareBaseConfigT]:
239
+ from .type_registry import GlobalTypeRegistry
240
+ from .type_registry import RegisteredMiddlewareInfo
241
+
242
+ context_manager_fn = asynccontextmanager(fn)
243
+
244
+ discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
245
+ component_type=ComponentEnum.MIDDLEWARE)
246
+
247
+ GlobalTypeRegistry.get().register_middleware(
248
+ RegisteredMiddlewareInfo(
249
+ full_type=config_type.full_type,
250
+ config_type=config_type,
251
+ build_fn=context_manager_fn,
252
+ discovery_metadata=discovery_metadata,
253
+ ))
254
+
255
+ return context_manager_fn
256
+
257
+ return register_middleware_inner
258
+
259
+
260
+ # Compatibility alias for backwards compatibility
261
+ register_function_middleware = register_middleware
262
+
263
+
214
264
  def register_llm_provider(config_type: type[LLMBaseConfigT]):
215
265
 
216
266
  def register_llm_provider_inner(