nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/builder/builder.py +52 -0
- nat/builder/component_utils.py +7 -1
- nat/builder/context.py +17 -0
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +74 -3
- nat/builder/workflow.py +4 -2
- nat/builder/workflow_builder.py +129 -0
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/register_workflow.py +50 -0
- nat/cli/type_registry.py +68 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +16 -0
- nat/data_models/function.py +14 -1
- nat/data_models/middleware.py +35 -0
- nat/data_models/runtime_enum.py +26 -0
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +11 -3
- nat/eval/utils/weave_eval.py +17 -3
- nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
- nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
- nat/llm/aws_bedrock_llm.py +11 -9
- nat/llm/azure_openai_llm.py +12 -4
- nat/llm/litellm_llm.py +11 -4
- nat/llm/nim_llm.py +11 -9
- nat/llm/openai_llm.py +12 -9
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache_middleware.py +256 -0
- nat/middleware/function_middleware.py +186 -0
- nat/middleware/middleware.py +184 -0
- nat/middleware/register.py +35 -0
- nat/profiler/decorators/framework_wrapper.py +16 -0
- nat/retriever/milvus/register.py +11 -3
- nat/retriever/milvus/retriever.py +102 -40
- nat/runtime/runner.py +12 -1
- nat/runtime/session.py +10 -3
- nat/tool/code_execution/code_sandbox.py +4 -7
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +8 -4
- nat/utils/io/yaml_tools.py +73 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
- nat/data_models/temperature_mixin.py +0 -44
- nat/data_models/top_p_mixin.py +0 -44
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/builder/builder.py
CHANGED
|
@@ -32,6 +32,7 @@ from nat.data_models.component_ref import FunctionGroupRef
|
|
|
32
32
|
from nat.data_models.component_ref import FunctionRef
|
|
33
33
|
from nat.data_models.component_ref import LLMRef
|
|
34
34
|
from nat.data_models.component_ref import MemoryRef
|
|
35
|
+
from nat.data_models.component_ref import MiddlewareRef
|
|
35
36
|
from nat.data_models.component_ref import ObjectStoreRef
|
|
36
37
|
from nat.data_models.component_ref import RetrieverRef
|
|
37
38
|
from nat.data_models.component_ref import TTCStrategyRef
|
|
@@ -42,6 +43,7 @@ from nat.data_models.function import FunctionGroupBaseConfig
|
|
|
42
43
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
43
44
|
from nat.data_models.llm import LLMBaseConfig
|
|
44
45
|
from nat.data_models.memory import MemoryBaseConfig
|
|
46
|
+
from nat.data_models.middleware import MiddlewareBaseConfig
|
|
45
47
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
46
48
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
47
49
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
@@ -49,6 +51,7 @@ from nat.experimental.decorators.experimental_warning_decorator import experimen
|
|
|
49
51
|
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
50
52
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
51
53
|
from nat.memory.interfaces import MemoryEditor
|
|
54
|
+
from nat.middleware.middleware import Middleware
|
|
52
55
|
from nat.object_store.interfaces import ObjectStore
|
|
53
56
|
from nat.retriever.interface import Retriever
|
|
54
57
|
|
|
@@ -289,6 +292,55 @@ class Builder(ABC):
|
|
|
289
292
|
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
290
293
|
pass
|
|
291
294
|
|
|
295
|
+
@abstractmethod
|
|
296
|
+
async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
|
|
297
|
+
"""Add middleware to the builder.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
name: The name or reference for the middleware
|
|
301
|
+
config: The configuration for the middleware
|
|
302
|
+
|
|
303
|
+
Returns:
|
|
304
|
+
The built middleware instance
|
|
305
|
+
"""
|
|
306
|
+
pass
|
|
307
|
+
|
|
308
|
+
@abstractmethod
|
|
309
|
+
async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
|
|
310
|
+
"""Get built middleware by name.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
middleware_name: The name or reference of the middleware
|
|
314
|
+
|
|
315
|
+
Returns:
|
|
316
|
+
The built middleware instance
|
|
317
|
+
"""
|
|
318
|
+
pass
|
|
319
|
+
|
|
320
|
+
@abstractmethod
|
|
321
|
+
def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
|
|
322
|
+
"""Get the configuration for middleware.
|
|
323
|
+
|
|
324
|
+
Args:
|
|
325
|
+
middleware_name: The name or reference of the middleware
|
|
326
|
+
|
|
327
|
+
Returns:
|
|
328
|
+
The configuration for the middleware
|
|
329
|
+
"""
|
|
330
|
+
pass
|
|
331
|
+
|
|
332
|
+
async def get_middleware_list(self, middleware_names: Sequence[str | MiddlewareRef]) -> list[Middleware]:
|
|
333
|
+
"""Get multiple middleware by name.
|
|
334
|
+
|
|
335
|
+
Args:
|
|
336
|
+
middleware_names: The names or references of the middleware
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
List of built middleware instances
|
|
340
|
+
"""
|
|
341
|
+
tasks = [self.get_middleware(name) for name in middleware_names]
|
|
342
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
343
|
+
|
|
292
344
|
|
|
293
345
|
class EvalBuilder(ABC):
|
|
294
346
|
|
nat/builder/component_utils.py
CHANGED
|
@@ -33,6 +33,7 @@ from nat.data_models.function import FunctionBaseConfig
|
|
|
33
33
|
from nat.data_models.function import FunctionGroupBaseConfig
|
|
34
34
|
from nat.data_models.llm import LLMBaseConfig
|
|
35
35
|
from nat.data_models.memory import MemoryBaseConfig
|
|
36
|
+
from nat.data_models.middleware import MiddlewareBaseConfig
|
|
36
37
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
37
38
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
38
39
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
@@ -41,6 +42,7 @@ from nat.utils.type_utils import DecomposedType
|
|
|
41
42
|
logger = logging.getLogger(__name__)
|
|
42
43
|
|
|
43
44
|
# Order in which we want to process the component groups
|
|
45
|
+
# IMPORTANT: MIDDLEWARE must be built before FUNCTIONS
|
|
44
46
|
_component_group_order = [
|
|
45
47
|
ComponentGroup.AUTHENTICATION,
|
|
46
48
|
ComponentGroup.EMBEDDERS,
|
|
@@ -49,6 +51,7 @@ _component_group_order = [
|
|
|
49
51
|
ComponentGroup.OBJECT_STORES,
|
|
50
52
|
ComponentGroup.RETRIEVERS,
|
|
51
53
|
ComponentGroup.TTC_STRATEGIES,
|
|
54
|
+
ComponentGroup.MIDDLEWARE,
|
|
52
55
|
ComponentGroup.FUNCTION_GROUPS,
|
|
53
56
|
ComponentGroup.FUNCTIONS,
|
|
54
57
|
]
|
|
@@ -111,6 +114,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
|
|
|
111
114
|
return ComponentGroup.FUNCTIONS
|
|
112
115
|
if (isinstance(component, FunctionGroupBaseConfig)):
|
|
113
116
|
return ComponentGroup.FUNCTION_GROUPS
|
|
117
|
+
if (isinstance(component, MiddlewareBaseConfig)):
|
|
118
|
+
return ComponentGroup.MIDDLEWARE
|
|
114
119
|
if (isinstance(component, LLMBaseConfig)):
|
|
115
120
|
return ComponentGroup.LLMS
|
|
116
121
|
if (isinstance(component, MemoryBaseConfig)):
|
|
@@ -260,7 +265,8 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
|
|
|
260
265
|
|
|
261
266
|
total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
|
|
262
267
|
len(config.memory) + len(config.object_stores) + len(config.retrievers) +
|
|
263
|
-
len(config.ttc_strategies) + len(config.authentication) +
|
|
268
|
+
len(config.ttc_strategies) + len(config.authentication) + len(config.middleware) + 1
|
|
269
|
+
) # +1 for the workflow
|
|
264
270
|
|
|
265
271
|
dependency_map: dict
|
|
266
272
|
dependency_graph: nx.DiGraph
|
nat/builder/context.py
CHANGED
|
@@ -34,6 +34,7 @@ from nat.data_models.intermediate_step import IntermediateStepType
|
|
|
34
34
|
from nat.data_models.intermediate_step import StreamEventData
|
|
35
35
|
from nat.data_models.intermediate_step import TraceMetadata
|
|
36
36
|
from nat.data_models.invocation_node import InvocationNode
|
|
37
|
+
from nat.data_models.runtime_enum import RuntimeTypeEnum
|
|
37
38
|
from nat.runtime.user_metadata import RequestAttributes
|
|
38
39
|
from nat.utils.reactive.subject import Subject
|
|
39
40
|
|
|
@@ -72,6 +73,8 @@ class ContextState(metaclass=Singleton):
|
|
|
72
73
|
self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
|
|
73
74
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
74
75
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
76
|
+
self.runtime_type: ContextVar[RuntimeTypeEnum] = ContextVar("runtime_type",
|
|
77
|
+
default=RuntimeTypeEnum.RUN_OR_SERVE)
|
|
75
78
|
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
|
|
76
79
|
self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
|
|
77
80
|
self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
|
|
@@ -302,6 +305,20 @@ class Context:
|
|
|
302
305
|
raise RuntimeError("User authentication callback is not set in the context.")
|
|
303
306
|
return callback
|
|
304
307
|
|
|
308
|
+
@property
|
|
309
|
+
def is_evaluating(self) -> bool:
|
|
310
|
+
"""
|
|
311
|
+
Indicates whether the current context is in evaluation mode.
|
|
312
|
+
|
|
313
|
+
This property checks the context state to determine if the current
|
|
314
|
+
operation is being performed in evaluation mode. It returns a boolean
|
|
315
|
+
value indicating the evaluation status.
|
|
316
|
+
|
|
317
|
+
Returns:
|
|
318
|
+
bool: True if in evaluation mode, False otherwise.
|
|
319
|
+
"""
|
|
320
|
+
return self._context_state.runtime_type.get() == RuntimeTypeEnum.EVALUATE
|
|
321
|
+
|
|
305
322
|
@staticmethod
|
|
306
323
|
def get() -> "Context":
|
|
307
324
|
"""
|
nat/builder/framework_enum.py
CHANGED
nat/builder/function.py
CHANGED
|
@@ -34,6 +34,9 @@ from nat.builder.function_info import FunctionInfo
|
|
|
34
34
|
from nat.data_models.function import EmptyFunctionConfig
|
|
35
35
|
from nat.data_models.function import FunctionBaseConfig
|
|
36
36
|
from nat.data_models.function import FunctionGroupBaseConfig
|
|
37
|
+
from nat.middleware.function_middleware import FunctionMiddlewareChain
|
|
38
|
+
from nat.middleware.middleware import FunctionMiddlewareContext
|
|
39
|
+
from nat.middleware.middleware import Middleware
|
|
37
40
|
|
|
38
41
|
_InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
|
|
39
42
|
_StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
|
|
@@ -64,6 +67,9 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
64
67
|
self.description = description
|
|
65
68
|
self.instance_name = instance_name or config.type
|
|
66
69
|
self._context = Context.get()
|
|
70
|
+
self._configured_middleware: tuple[Middleware, ...] = tuple()
|
|
71
|
+
self._middlewared_single: _InvokeFnT | None = None
|
|
72
|
+
self._middlewared_stream: _StreamFnT | None = None
|
|
67
73
|
|
|
68
74
|
def convert(self, value: typing.Any, to_type: type[_T]) -> _T:
|
|
69
75
|
"""
|
|
@@ -108,6 +114,38 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
108
114
|
"""
|
|
109
115
|
return self._converter.try_convert(value, to_type=to_type)
|
|
110
116
|
|
|
117
|
+
@property
|
|
118
|
+
def middleware(self) -> tuple[Middleware, ...]:
|
|
119
|
+
"""Return the currently configured middleware chain."""
|
|
120
|
+
|
|
121
|
+
return self._configured_middleware
|
|
122
|
+
|
|
123
|
+
def configure_middleware(self, middleware: Sequence[Middleware] | None = None) -> None:
|
|
124
|
+
"""Attach an ordered list of middleware to this function instance."""
|
|
125
|
+
|
|
126
|
+
middleware_tuple: tuple[Middleware, ...] = tuple(middleware or ())
|
|
127
|
+
|
|
128
|
+
self._configured_middleware = middleware_tuple
|
|
129
|
+
|
|
130
|
+
if not middleware_tuple:
|
|
131
|
+
self._middlewared_single = None
|
|
132
|
+
self._middlewared_stream = None
|
|
133
|
+
return
|
|
134
|
+
|
|
135
|
+
logger.info(f"Building middleware for function '{self.instance_name}' in order of: {middleware_tuple}")
|
|
136
|
+
|
|
137
|
+
context = FunctionMiddlewareContext(name=self.instance_name,
|
|
138
|
+
config=self.config,
|
|
139
|
+
description=self.description,
|
|
140
|
+
input_schema=self.input_schema,
|
|
141
|
+
single_output_schema=self.single_output_schema,
|
|
142
|
+
stream_output_schema=self.streaming_output_schema)
|
|
143
|
+
|
|
144
|
+
chain = FunctionMiddlewareChain(middleware=middleware_tuple, context=context)
|
|
145
|
+
|
|
146
|
+
self._middlewared_single = chain.build_single(self._ainvoke) if self.has_single_output else None
|
|
147
|
+
self._middlewared_stream = chain.build_stream(self._astream) if self.has_streaming_output else None
|
|
148
|
+
|
|
111
149
|
@abstractmethod
|
|
112
150
|
async def _ainvoke(self, value: InputT) -> SingleOutputT:
|
|
113
151
|
pass
|
|
@@ -150,7 +188,9 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
150
188
|
try:
|
|
151
189
|
converted_input: InputT = self._convert_input(value)
|
|
152
190
|
|
|
153
|
-
|
|
191
|
+
invoke_callable = self._middlewared_single or self._ainvoke
|
|
192
|
+
|
|
193
|
+
result = await invoke_callable(converted_input)
|
|
154
194
|
|
|
155
195
|
if to_type is not None and not isinstance(result, to_type):
|
|
156
196
|
result = self.convert(result, to_type)
|
|
@@ -243,7 +283,9 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
243
283
|
# Collect streaming outputs to capture the final result
|
|
244
284
|
final_output: list[typing.Any] = []
|
|
245
285
|
|
|
246
|
-
|
|
286
|
+
stream_callable = self._middlewared_stream or self._astream
|
|
287
|
+
|
|
288
|
+
async for data in stream_callable(converted_input):
|
|
247
289
|
if to_type is not None and not isinstance(data, to_type):
|
|
248
290
|
converted_data = self.convert(data, to_type=to_type)
|
|
249
291
|
final_output.append(converted_data)
|
|
@@ -357,7 +399,8 @@ class FunctionGroup:
|
|
|
357
399
|
*,
|
|
358
400
|
config: FunctionGroupBaseConfig,
|
|
359
401
|
instance_name: str | None = None,
|
|
360
|
-
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None
|
|
402
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
|
|
403
|
+
middleware: Sequence[Middleware] | None = None):
|
|
361
404
|
"""
|
|
362
405
|
Creates a new function group.
|
|
363
406
|
|
|
@@ -370,12 +413,15 @@ class FunctionGroup:
|
|
|
370
413
|
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
371
414
|
A callback function to additionally filter the functions in the function group dynamically when
|
|
372
415
|
the functions are accessed via any accessor method.
|
|
416
|
+
middleware : Sequence[Middleware] | None, optional
|
|
417
|
+
The middleware instances to apply to all functions in this group.
|
|
373
418
|
"""
|
|
374
419
|
self._config = config
|
|
375
420
|
self._instance_name = instance_name or config.type
|
|
376
421
|
self._functions: dict[str, Function] = dict()
|
|
377
422
|
self._filter_fn = filter_fn
|
|
378
423
|
self._per_function_filter_fn: dict[str, Callable[[str], Awaitable[bool]]] = dict()
|
|
424
|
+
self._middleware: tuple[Middleware, ...] = tuple(middleware or ())
|
|
379
425
|
|
|
380
426
|
def add_function(self,
|
|
381
427
|
name: str,
|
|
@@ -424,6 +470,9 @@ class FunctionGroup:
|
|
|
424
470
|
info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
|
|
425
471
|
full_name = self._get_fn_name(name)
|
|
426
472
|
lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
|
|
473
|
+
# Configure middleware from the function group if any
|
|
474
|
+
if self._middleware:
|
|
475
|
+
lambda_fn.configure_middleware(self._middleware)
|
|
427
476
|
self._functions[name] = lambda_fn
|
|
428
477
|
if filter_fn:
|
|
429
478
|
self._per_function_filter_fn[name] = filter_fn
|
|
@@ -712,3 +761,25 @@ class FunctionGroup:
|
|
|
712
761
|
Returns the instance name for the function group.
|
|
713
762
|
"""
|
|
714
763
|
return self._instance_name
|
|
764
|
+
|
|
765
|
+
@property
|
|
766
|
+
def middleware(self) -> tuple[Middleware, ...]:
|
|
767
|
+
"""
|
|
768
|
+
Returns the middleware configured for this function group.
|
|
769
|
+
"""
|
|
770
|
+
return self._middleware
|
|
771
|
+
|
|
772
|
+
def configure_middleware(self, middleware: Sequence[Middleware] | None = None) -> None:
|
|
773
|
+
"""
|
|
774
|
+
Configure the middleware for this function group.
|
|
775
|
+
These middleware will be applied to all functions added to the group.
|
|
776
|
+
|
|
777
|
+
Parameters
|
|
778
|
+
----------
|
|
779
|
+
middleware : Sequence[Middleware] | None
|
|
780
|
+
The middleware to configure for the function group.
|
|
781
|
+
"""
|
|
782
|
+
self._middleware = tuple(middleware or ())
|
|
783
|
+
# Update existing functions with the new middleware
|
|
784
|
+
for func in self._functions.values():
|
|
785
|
+
func.configure_middleware(self._middleware)
|
nat/builder/workflow.py
CHANGED
|
@@ -28,6 +28,7 @@ from nat.builder.function_base import StreamingOutputT
|
|
|
28
28
|
from nat.builder.llm import LLMProviderInfo
|
|
29
29
|
from nat.builder.retriever import RetrieverProviderInfo
|
|
30
30
|
from nat.data_models.config import Config
|
|
31
|
+
from nat.data_models.runtime_enum import RuntimeTypeEnum
|
|
31
32
|
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
32
33
|
from nat.memory.interfaces import MemoryEditor
|
|
33
34
|
from nat.object_store.interfaces import ObjectStore
|
|
@@ -94,7 +95,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
94
95
|
return self._exporter_manager.get()
|
|
95
96
|
|
|
96
97
|
@asynccontextmanager
|
|
97
|
-
async def run(self, message: InputT):
|
|
98
|
+
async def run(self, message: InputT, runtime_type: RuntimeTypeEnum = RuntimeTypeEnum.RUN_OR_SERVE):
|
|
98
99
|
"""
|
|
99
100
|
Called each time we start a new workflow run. We'll create
|
|
100
101
|
a new top-level workflow span here.
|
|
@@ -103,7 +104,8 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
103
104
|
async with Runner(input_message=message,
|
|
104
105
|
entry_fn=self._entry_fn,
|
|
105
106
|
context_state=self._context_state,
|
|
106
|
-
exporter_manager=self.exporter_manager
|
|
107
|
+
exporter_manager=self.exporter_manager,
|
|
108
|
+
runtime_type=runtime_type) as runner:
|
|
107
109
|
|
|
108
110
|
# The caller can `yield runner` so they can do `runner.result()` or `runner.result_stream()`
|
|
109
111
|
yield runner
|
nat/builder/workflow_builder.py
CHANGED
|
@@ -51,6 +51,7 @@ from nat.data_models.component_ref import FunctionGroupRef
|
|
|
51
51
|
from nat.data_models.component_ref import FunctionRef
|
|
52
52
|
from nat.data_models.component_ref import LLMRef
|
|
53
53
|
from nat.data_models.component_ref import MemoryRef
|
|
54
|
+
from nat.data_models.component_ref import MiddlewareRef
|
|
54
55
|
from nat.data_models.component_ref import ObjectStoreRef
|
|
55
56
|
from nat.data_models.component_ref import RetrieverRef
|
|
56
57
|
from nat.data_models.component_ref import TTCStrategyRef
|
|
@@ -62,6 +63,7 @@ from nat.data_models.function import FunctionGroupBaseConfig
|
|
|
62
63
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
63
64
|
from nat.data_models.llm import LLMBaseConfig
|
|
64
65
|
from nat.data_models.memory import MemoryBaseConfig
|
|
66
|
+
from nat.data_models.middleware import MiddlewareBaseConfig
|
|
65
67
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
66
68
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
67
69
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
@@ -71,6 +73,8 @@ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEn
|
|
|
71
73
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
72
74
|
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
73
75
|
from nat.memory.interfaces import MemoryEditor
|
|
76
|
+
from nat.middleware.function_middleware import FunctionMiddleware
|
|
77
|
+
from nat.middleware.middleware import Middleware
|
|
74
78
|
from nat.object_store.interfaces import ObjectStore
|
|
75
79
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
76
80
|
from nat.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
|
|
@@ -141,6 +145,12 @@ class ConfiguredTTCStrategy:
|
|
|
141
145
|
instance: StrategyBase
|
|
142
146
|
|
|
143
147
|
|
|
148
|
+
@dataclasses.dataclass
|
|
149
|
+
class ConfiguredMiddleware:
|
|
150
|
+
config: MiddlewareBaseConfig
|
|
151
|
+
instance: Middleware
|
|
152
|
+
|
|
153
|
+
|
|
144
154
|
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
145
155
|
|
|
146
156
|
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
|
|
@@ -170,6 +180,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
170
180
|
self._object_stores: dict[str, ConfiguredObjectStore] = {}
|
|
171
181
|
self._retrievers: dict[str, ConfiguredRetriever] = {}
|
|
172
182
|
self._ttc_strategies: dict[str, ConfiguredTTCStrategy] = {}
|
|
183
|
+
self._middleware: dict[str, ConfiguredMiddleware] = {}
|
|
173
184
|
|
|
174
185
|
self._context_state = ContextState.get()
|
|
175
186
|
|
|
@@ -423,6 +434,22 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
423
434
|
raise ValueError("Expected a function, FunctionInfo object, or FunctionBase object to be "
|
|
424
435
|
f"returned from the function builder. Got {type(build_result)}")
|
|
425
436
|
|
|
437
|
+
# Resolve middleware names from config to middleware instances
|
|
438
|
+
# Only FunctionMiddleware types can be used with functions
|
|
439
|
+
middleware_instances = []
|
|
440
|
+
for middleware_name in config.middleware:
|
|
441
|
+
if middleware_name not in self._middleware:
|
|
442
|
+
raise ValueError(f"Middleware `{middleware_name}` not found for function `{name}`. "
|
|
443
|
+
f"It must be configured in the `middleware` section of the YAML configuration.")
|
|
444
|
+
middleware_obj = self._middleware[middleware_name].instance
|
|
445
|
+
if not isinstance(middleware_obj, FunctionMiddleware):
|
|
446
|
+
raise TypeError(
|
|
447
|
+
f"Middleware `{middleware_name}` is not a FunctionMiddleware and cannot be used with functions. "
|
|
448
|
+
f"Only FunctionMiddleware types support function-specific wrapping.")
|
|
449
|
+
middleware_instances.append(middleware_obj)
|
|
450
|
+
|
|
451
|
+
build_result.configure_middleware(middleware_instances)
|
|
452
|
+
|
|
426
453
|
return ConfiguredFunction(config=config, instance=build_result)
|
|
427
454
|
|
|
428
455
|
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
@@ -461,6 +488,23 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
461
488
|
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
462
489
|
f"Got {type(build_result)}")
|
|
463
490
|
|
|
491
|
+
# Resolve middleware names from config to middleware instances
|
|
492
|
+
# Only FunctionMiddleware types can be used with function groups
|
|
493
|
+
middleware_instances = []
|
|
494
|
+
for middleware_name in config.middleware:
|
|
495
|
+
if middleware_name not in self._middleware:
|
|
496
|
+
raise ValueError(f"Middleware `{middleware_name}` not found for function group `{name}`. "
|
|
497
|
+
f"It must be configured in the `middleware` section of the YAML configuration.")
|
|
498
|
+
middleware_obj = self._middleware[middleware_name].instance
|
|
499
|
+
if not isinstance(middleware_obj, FunctionMiddleware):
|
|
500
|
+
raise TypeError(f"Middleware `{middleware_name}` is not a FunctionMiddleware and "
|
|
501
|
+
f"cannot be used with function groups. "
|
|
502
|
+
f"Only FunctionMiddleware types support function-specific wrapping.")
|
|
503
|
+
middleware_instances.append(middleware_obj)
|
|
504
|
+
|
|
505
|
+
# Configure middleware for the function group
|
|
506
|
+
build_result.configure_middleware(middleware_instances)
|
|
507
|
+
|
|
464
508
|
# set the instance name for the function group based on the workflow-provided name
|
|
465
509
|
build_result.set_instance_name(name)
|
|
466
510
|
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
@@ -968,6 +1012,72 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
968
1012
|
|
|
969
1013
|
return config
|
|
970
1014
|
|
|
1015
|
+
@override
|
|
1016
|
+
async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
|
|
1017
|
+
"""Add middleware to the builder.
|
|
1018
|
+
|
|
1019
|
+
Args:
|
|
1020
|
+
name: The name or reference for the middleware
|
|
1021
|
+
config: The configuration for the middleware
|
|
1022
|
+
|
|
1023
|
+
Returns:
|
|
1024
|
+
The built middleware instance
|
|
1025
|
+
|
|
1026
|
+
Raises:
|
|
1027
|
+
ValueError: If the middleware already exists
|
|
1028
|
+
"""
|
|
1029
|
+
if name in self._middleware:
|
|
1030
|
+
raise ValueError(f"Middleware `{name}` already exists in the list of middleware")
|
|
1031
|
+
|
|
1032
|
+
try:
|
|
1033
|
+
middleware_info = self._registry.get_middleware(type(config))
|
|
1034
|
+
|
|
1035
|
+
middleware_instance = await self._get_exit_stack().enter_async_context(
|
|
1036
|
+
middleware_info.build_fn(config, self))
|
|
1037
|
+
|
|
1038
|
+
self._middleware[name] = ConfiguredMiddleware(config=config, instance=middleware_instance)
|
|
1039
|
+
|
|
1040
|
+
return middleware_instance
|
|
1041
|
+
except Exception as e:
|
|
1042
|
+
logger.error("Error adding function middleware `%s` with config `%s`: %s", name, config, e)
|
|
1043
|
+
raise
|
|
1044
|
+
|
|
1045
|
+
@override
|
|
1046
|
+
async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
|
|
1047
|
+
"""Get built middleware by name.
|
|
1048
|
+
|
|
1049
|
+
Args:
|
|
1050
|
+
middleware_name: The name or reference of the middleware
|
|
1051
|
+
|
|
1052
|
+
Returns:
|
|
1053
|
+
The built middleware instance
|
|
1054
|
+
|
|
1055
|
+
Raises:
|
|
1056
|
+
ValueError: If the middleware is not found
|
|
1057
|
+
"""
|
|
1058
|
+
if middleware_name not in self._middleware:
|
|
1059
|
+
raise ValueError(f"Middleware `{middleware_name}` not found")
|
|
1060
|
+
|
|
1061
|
+
return self._middleware[middleware_name].instance
|
|
1062
|
+
|
|
1063
|
+
@override
|
|
1064
|
+
def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
|
|
1065
|
+
"""Get the configuration for middleware.
|
|
1066
|
+
|
|
1067
|
+
Args:
|
|
1068
|
+
middleware_name: The name or reference of the middleware
|
|
1069
|
+
|
|
1070
|
+
Returns:
|
|
1071
|
+
The configuration for the middleware
|
|
1072
|
+
|
|
1073
|
+
Raises:
|
|
1074
|
+
ValueError: If the middleware is not found
|
|
1075
|
+
"""
|
|
1076
|
+
if middleware_name not in self._middleware:
|
|
1077
|
+
raise ValueError(f"Middleware `{middleware_name}` not found")
|
|
1078
|
+
|
|
1079
|
+
return self._middleware[middleware_name].config
|
|
1080
|
+
|
|
971
1081
|
@override
|
|
972
1082
|
def get_user_manager(self):
|
|
973
1083
|
return UserManagerHolder(context=Context(self._context_state))
|
|
@@ -1108,6 +1218,10 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
1108
1218
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
1109
1219
|
await self.add_retriever(component_instance.name,
|
|
1110
1220
|
cast(RetrieverBaseConfig, component_instance.config))
|
|
1221
|
+
# Instantiate middleware
|
|
1222
|
+
elif component_instance.component_group == ComponentGroup.MIDDLEWARE:
|
|
1223
|
+
await self.add_middleware(component_instance.name,
|
|
1224
|
+
cast(MiddlewareBaseConfig, component_instance.config))
|
|
1111
1225
|
# Instantiate a function group
|
|
1112
1226
|
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1113
1227
|
await self.add_function_group(component_instance.name,
|
|
@@ -1363,3 +1477,18 @@ class ChildBuilder(Builder):
|
|
|
1363
1477
|
@override
|
|
1364
1478
|
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1365
1479
|
return self._workflow_builder.get_function_group_dependencies(fn_name)
|
|
1480
|
+
|
|
1481
|
+
@override
|
|
1482
|
+
async def add_middleware(self, name: str | MiddlewareRef, config: MiddlewareBaseConfig) -> Middleware:
|
|
1483
|
+
"""Add middleware to the builder."""
|
|
1484
|
+
return await self._workflow_builder.add_middleware(name, config)
|
|
1485
|
+
|
|
1486
|
+
@override
|
|
1487
|
+
async def get_middleware(self, middleware_name: str | MiddlewareRef) -> Middleware:
|
|
1488
|
+
"""Get built middleware by name."""
|
|
1489
|
+
return await self._workflow_builder.get_middleware(middleware_name)
|
|
1490
|
+
|
|
1491
|
+
@override
|
|
1492
|
+
def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> MiddlewareBaseConfig:
|
|
1493
|
+
"""Get the configuration for middleware."""
|
|
1494
|
+
return self._workflow_builder.get_middleware_config(middleware_name)
|
nat/cli/commands/workflow/workflow_commands.py
CHANGED
|
@@ -354,7 +354,8 @@ def reinstall_command(workflow_name):
|
|
|
354
354
|
|
|
355
355
|
@click.command()
|
|
356
356
|
@click.argument('workflow_name')
|
|
357
|
-
|
|
357
|
+
@click.option('-y', '--yes', "yes_flag", is_flag=True, default=False, help='Do not prompt for confirmation.')
|
|
358
|
+
def delete_command(workflow_name: str, yes_flag: bool):
|
|
358
359
|
"""
|
|
359
360
|
Delete a NAT workflow and uninstall its package.
|
|
360
361
|
|
|
@@ -362,7 +363,7 @@ def delete_command(workflow_name: str):
|
|
|
362
363
|
workflow_name (str): The name of the workflow to delete.
|
|
363
364
|
"""
|
|
364
365
|
try:
|
|
365
|
-
if not click.confirm(f"Are you sure you want to delete the workflow '{workflow_name}'?"):
|
|
366
|
+
if not yes_flag and not click.confirm(f"Are you sure you want to delete the workflow '{workflow_name}'?"):
|
|
366
367
|
click.echo("Workflow deletion cancelled.")
|
|
367
368
|
return
|
|
368
369
|
editable = get_repo_root() is not None
|
nat/cli/register_workflow.py
CHANGED
|
@@ -38,6 +38,8 @@ from nat.cli.type_registry import LoggingMethodConfigT
|
|
|
38
38
|
from nat.cli.type_registry import LoggingMethodRegisteredCallableT
|
|
39
39
|
from nat.cli.type_registry import MemoryBuildCallableT
|
|
40
40
|
from nat.cli.type_registry import MemoryRegisteredCallableT
|
|
41
|
+
from nat.cli.type_registry import MiddlewareBuildCallableT
|
|
42
|
+
from nat.cli.type_registry import MiddlewareRegisteredCallableT
|
|
41
43
|
from nat.cli.type_registry import ObjectStoreBuildCallableT
|
|
42
44
|
from nat.cli.type_registry import ObjectStoreRegisteredCallableT
|
|
43
45
|
from nat.cli.type_registry import RegisteredLoggingMethod
|
|
@@ -65,6 +67,7 @@ from nat.data_models.function import FunctionConfigT
|
|
|
65
67
|
from nat.data_models.function import FunctionGroupConfigT
|
|
66
68
|
from nat.data_models.llm import LLMBaseConfigT
|
|
67
69
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
70
|
+
from nat.data_models.middleware import MiddlewareBaseConfigT
|
|
68
71
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
69
72
|
from nat.data_models.registry_handler import RegistryHandlerBaseConfigT
|
|
70
73
|
from nat.data_models.retriever import RetrieverBaseConfigT
|
|
@@ -149,6 +152,10 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
149
152
|
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
150
153
|
"""
|
|
151
154
|
Register a workflow with optional framework_wrappers for automatic profiler hooking.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
config_type: The function configuration type
|
|
158
|
+
framework_wrappers: Optional list of framework wrappers for automatic profiler hooking
|
|
152
159
|
"""
|
|
153
160
|
|
|
154
161
|
def register_function_inner(
|
|
@@ -211,6 +218,49 @@ def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
|
211
218
|
return register_function_group_inner
|
|
212
219
|
|
|
213
220
|
|
|
221
|
+
def register_middleware(config_type: type[MiddlewareBaseConfigT]):
    """
    Build a decorator that registers a middleware build function for ``config_type``.

    The decorated build function is wrapped with ``asynccontextmanager`` and recorded in the
    global type registry together with its configuration type and discovery metadata, so the
    middleware can be looked up by its configuration type.

    Args:
        config_type: The middleware configuration type to register.

    Returns:
        A decorator that takes the build function and returns it wrapped as an async
        context manager.
    """

    def register_middleware_inner(
        fn: MiddlewareBuildCallableT[MiddlewareBaseConfigT]
    ) -> MiddlewareRegisteredCallableT[MiddlewareBaseConfigT]:
        # Imported at call time rather than module level, as in the original block
        # (presumably to avoid a circular import with the type registry; confirm).
        from .type_registry import GlobalTypeRegistry
        from .type_registry import RegisteredMiddlewareInfo

        # The wrapped callable is both what gets registered and what the decorator returns.
        wrapped_build_fn = asynccontextmanager(fn)

        metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
                                                      component_type=ComponentEnum.MIDDLEWARE)

        registry = GlobalTypeRegistry.get()
        registration = RegisteredMiddlewareInfo(
            full_type=config_type.full_type,
            config_type=config_type,
            build_fn=wrapped_build_fn,
            discovery_metadata=metadata,
        )
        registry.register_middleware(registration)

        return wrapped_build_fn

    return register_middleware_inner
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# Alias kept for backwards compatibility
|
|
261
|
+
register_function_middleware = register_middleware
|
|
262
|
+
|
|
263
|
+
|
|
214
264
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
215
265
|
|
|
216
266
|
def register_llm_provider_inner(
|