nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/builder/builder.py +52 -0
- nat/builder/component_utils.py +7 -1
- nat/builder/context.py +17 -0
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +74 -3
- nat/builder/workflow.py +4 -2
- nat/builder/workflow_builder.py +129 -0
- nat/cli/register_workflow.py +50 -0
- nat/cli/type_registry.py +68 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +16 -0
- nat/data_models/function.py +14 -1
- nat/data_models/middleware.py +35 -0
- nat/data_models/runtime_enum.py +26 -0
- nat/eval/evaluate.py +10 -2
- nat/front_ends/fastapi/fastapi_front_end_config.py +22 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +124 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache_middleware.py +256 -0
- nat/middleware/function_middleware.py +186 -0
- nat/middleware/middleware.py +184 -0
- nat/middleware/register.py +35 -0
- nat/profiler/decorators/framework_wrapper.py +16 -0
- nat/retriever/milvus/register.py +11 -3
- nat/retriever/milvus/retriever.py +102 -40
- nat/runtime/runner.py +12 -1
- nat/runtime/session.py +10 -3
- nat/tool/code_execution/code_sandbox.py +1 -1
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +9 -3
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +38 -31
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/cli/type_registry.py
CHANGED
|
@@ -64,6 +64,8 @@ from nat.data_models.logging import LoggingBaseConfig
|
|
|
64
64
|
from nat.data_models.logging import LoggingMethodConfigT
|
|
65
65
|
from nat.data_models.memory import MemoryBaseConfig
|
|
66
66
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
67
|
+
from nat.data_models.middleware import MiddlewareBaseConfig
|
|
68
|
+
from nat.data_models.middleware import MiddlewareBaseConfigT
|
|
67
69
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
68
70
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
69
71
|
from nat.data_models.registry_handler import RegistryHandlerBaseConfig
|
|
@@ -76,6 +78,7 @@ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
|
76
78
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfigT
|
|
77
79
|
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
78
80
|
from nat.memory.interfaces import MemoryEditor
|
|
81
|
+
from nat.middleware.middleware import Middleware
|
|
79
82
|
from nat.object_store.interfaces import ObjectStore
|
|
80
83
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
81
84
|
from nat.registry_handlers.registry_handler_base import AbstractRegistryHandler
|
|
@@ -89,6 +92,7 @@ EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIte
|
|
|
89
92
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
90
93
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
94
|
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
95
|
+
MiddlewareBuildCallableT = Callable[[MiddlewareBaseConfigT, Builder], AsyncIterator[Middleware]]
|
|
92
96
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
93
97
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
94
98
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -111,6 +115,7 @@ FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncC
|
|
|
111
115
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
112
116
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
117
|
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
118
|
+
MiddlewareRegisteredCallableT = Callable[[MiddlewareBaseConfigT, Builder], AbstractAsyncContextManager[Middleware]]
|
|
114
119
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
115
120
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
116
121
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -179,6 +184,8 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
179
184
|
and a description.
|
|
180
185
|
"""
|
|
181
186
|
|
|
187
|
+
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
|
188
|
+
|
|
182
189
|
build_fn: FunctionRegisteredCallableT = Field(repr=False)
|
|
183
190
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
184
191
|
|
|
@@ -193,6 +200,15 @@ class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
|
193
200
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
201
|
|
|
195
202
|
|
|
203
|
+
class RegisteredMiddlewareInfo(RegisteredInfo[MiddlewareBaseConfig]):
|
|
204
|
+
"""
|
|
205
|
+
Represents registered middleware. Middleware provides middleware-style wrapping of
|
|
206
|
+
calls with preprocessing and postprocessing logic.
|
|
207
|
+
"""
|
|
208
|
+
|
|
209
|
+
build_fn: MiddlewareRegisteredCallableT = Field(repr=False)
|
|
210
|
+
|
|
211
|
+
|
|
196
212
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
197
213
|
"""
|
|
198
214
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -331,6 +347,9 @@ class TypeRegistry:
|
|
|
331
347
|
# Function Groups
|
|
332
348
|
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
349
|
|
|
350
|
+
# Middleware
|
|
351
|
+
self._registered_middleware: dict[type[MiddlewareBaseConfig], RegisteredMiddlewareInfo] = {}
|
|
352
|
+
|
|
334
353
|
# LLMs
|
|
335
354
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
336
355
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -540,6 +559,49 @@ class TypeRegistry:
|
|
|
540
559
|
"""
|
|
541
560
|
return list(self._registered_function_groups.values())
|
|
542
561
|
|
|
562
|
+
def register_middleware(self, registration: RegisteredMiddlewareInfo):
|
|
563
|
+
"""Register middleware with the type registry.
|
|
564
|
+
|
|
565
|
+
Args:
|
|
566
|
+
registration: The middleware registration information
|
|
567
|
+
|
|
568
|
+
Raises:
|
|
569
|
+
ValueError: If middleware with the same config type is already registered
|
|
570
|
+
"""
|
|
571
|
+
if (registration.config_type in self._registered_middleware):
|
|
572
|
+
raise ValueError(f"Middleware with the same config type `{registration.config_type}` has already been "
|
|
573
|
+
"registered.")
|
|
574
|
+
|
|
575
|
+
self._registered_middleware[registration.config_type] = registration
|
|
576
|
+
|
|
577
|
+
self._registration_changed()
|
|
578
|
+
|
|
579
|
+
def get_middleware(self, config_type: type[MiddlewareBaseConfig]) -> RegisteredMiddlewareInfo:
|
|
580
|
+
"""Get registered middleware by its config type.
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
config_type: The middleware configuration type
|
|
584
|
+
|
|
585
|
+
Returns:
|
|
586
|
+
RegisteredMiddlewareInfo: The registered middleware information
|
|
587
|
+
|
|
588
|
+
Raises:
|
|
589
|
+
KeyError: If no middleware is registered for the given config type
|
|
590
|
+
"""
|
|
591
|
+
try:
|
|
592
|
+
return self._registered_middleware[config_type]
|
|
593
|
+
except KeyError as err:
|
|
594
|
+
raise KeyError(f"Could not find registered middleware for config `{config_type}`. "
|
|
595
|
+
f"Registered configs: {set(self._registered_middleware.keys())}") from err
|
|
596
|
+
|
|
597
|
+
def get_registered_middleware(self) -> list[RegisteredInfo[MiddlewareBaseConfig]]:
|
|
598
|
+
"""Get all registered middleware.
|
|
599
|
+
|
|
600
|
+
Returns:
|
|
601
|
+
list[RegisteredInfo[MiddlewareBaseConfig]]: List of all registered middleware
|
|
602
|
+
"""
|
|
603
|
+
return list(self._registered_middleware.values())
|
|
604
|
+
|
|
543
605
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
544
606
|
|
|
545
607
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -912,6 +974,9 @@ class TypeRegistry:
|
|
|
912
974
|
if component_type == ComponentEnum.TTC_STRATEGY:
|
|
913
975
|
return self._registered_ttc_strategies
|
|
914
976
|
|
|
977
|
+
if component_type == ComponentEnum.MIDDLEWARE:
|
|
978
|
+
return self._registered_middleware
|
|
979
|
+
|
|
915
980
|
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
916
981
|
|
|
917
982
|
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
|
|
@@ -1038,6 +1103,9 @@ class TypeRegistry:
|
|
|
1038
1103
|
if issubclass(cls, TTCStrategyBaseConfig):
|
|
1039
1104
|
return self._do_compute_annotation(cls, self.get_registered_ttc_strategies())
|
|
1040
1105
|
|
|
1106
|
+
if issubclass(cls, MiddlewareBaseConfig):
|
|
1107
|
+
return self._do_compute_annotation(cls, self.get_registered_middleware())
|
|
1108
|
+
|
|
1041
1109
|
raise ValueError(f"Supplied an unsupported component type {cls}")
|
|
1042
1110
|
|
|
1043
1111
|
|
nat/data_models/component.py
CHANGED
|
@@ -28,6 +28,7 @@ class ComponentEnum(StrEnum):
|
|
|
28
28
|
FRONT_END = "front_end"
|
|
29
29
|
FUNCTION = "function"
|
|
30
30
|
FUNCTION_GROUP = "function_group"
|
|
31
|
+
MIDDLEWARE = "middleware"
|
|
31
32
|
TTC_STRATEGY = "ttc_strategy"
|
|
32
33
|
LLM_CLIENT = "llm_client"
|
|
33
34
|
LLM_PROVIDER = "llm_provider"
|
|
@@ -49,6 +50,7 @@ class ComponentGroup(StrEnum):
|
|
|
49
50
|
EMBEDDERS = "embedders"
|
|
50
51
|
FUNCTIONS = "functions"
|
|
51
52
|
FUNCTION_GROUPS = "function_groups"
|
|
53
|
+
MIDDLEWARE = "middleware"
|
|
52
54
|
TTC_STRATEGIES = "ttc_strategies"
|
|
53
55
|
LLMS = "llms"
|
|
54
56
|
MEMORY = "memory"
|
nat/data_models/component_ref.py
CHANGED
|
@@ -177,3 +177,14 @@ class TTCStrategyRef(ComponentRef):
|
|
|
177
177
|
@override
|
|
178
178
|
def component_group(self):
|
|
179
179
|
return ComponentGroup.TTC_STRATEGIES
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class MiddlewareRef(ComponentRef):
|
|
183
|
+
"""
|
|
184
|
+
A reference to middleware in a NAT configuration object.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
@property
|
|
188
|
+
@override
|
|
189
|
+
def component_group(self):
|
|
190
|
+
return ComponentGroup.MIDDLEWARE
|
nat/data_models/config.py
CHANGED
|
@@ -43,6 +43,7 @@ from .common import TypedBaseModel
|
|
|
43
43
|
from .embedder import EmbedderBaseConfig
|
|
44
44
|
from .llm import LLMBaseConfig
|
|
45
45
|
from .memory import MemoryBaseConfig
|
|
46
|
+
from .middleware import FunctionMiddlewareBaseConfig
|
|
46
47
|
from .object_store import ObjectStoreBaseConfig
|
|
47
48
|
from .retriever import RetrieverBaseConfig
|
|
48
49
|
|
|
@@ -86,6 +87,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
86
87
|
registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
|
|
87
88
|
elif (info.field_name == "ttc_strategies"):
|
|
88
89
|
registered_keys = GlobalTypeRegistry.get().get_registered_ttc_strategies()
|
|
90
|
+
elif (info.field_name == "middleware"):
|
|
91
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_middleware()
|
|
89
92
|
|
|
90
93
|
else:
|
|
91
94
|
assert False, f"Unknown field name {info.field_name} in validator"
|
|
@@ -253,6 +256,9 @@ class Config(HashableBaseModel):
|
|
|
253
256
|
# Function Groups Configuration
|
|
254
257
|
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
255
258
|
|
|
259
|
+
# Middleware Configuration
|
|
260
|
+
middleware: dict[str, FunctionMiddlewareBaseConfig] = Field(default_factory=dict)
|
|
261
|
+
|
|
256
262
|
# LLMs Configuration
|
|
257
263
|
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
258
264
|
|
|
@@ -303,6 +309,7 @@ class Config(HashableBaseModel):
|
|
|
303
309
|
|
|
304
310
|
@field_validator("functions",
|
|
305
311
|
"function_groups",
|
|
312
|
+
"middleware",
|
|
306
313
|
"llms",
|
|
307
314
|
"embedders",
|
|
308
315
|
"memory",
|
|
@@ -348,6 +355,10 @@ class Config(HashableBaseModel):
|
|
|
348
355
|
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
356
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
357
|
|
|
358
|
+
MiddlewareAnnotation = dict[str,
|
|
359
|
+
typing.Annotated[type_registry.compute_annotation(FunctionMiddlewareBaseConfig),
|
|
360
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
361
|
+
|
|
351
362
|
MemoryAnnotation = dict[str,
|
|
352
363
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
353
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -393,6 +404,11 @@ class Config(HashableBaseModel):
|
|
|
393
404
|
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
405
|
should_rebuild = True
|
|
395
406
|
|
|
407
|
+
middleware_field = cls.model_fields.get("middleware")
|
|
408
|
+
if (middleware_field is not None and middleware_field.annotation != MiddlewareAnnotation):
|
|
409
|
+
middleware_field.annotation = MiddlewareAnnotation
|
|
410
|
+
should_rebuild = True
|
|
411
|
+
|
|
396
412
|
memory_field = cls.model_fields.get("memory")
|
|
397
413
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
398
414
|
memory_field.annotation = MemoryAnnotation
|
nat/data_models/function.py
CHANGED
|
@@ -24,7 +24,16 @@ from .common import TypedBaseModel
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
27
|
-
|
|
27
|
+
"""Base configuration for functions.
|
|
28
|
+
|
|
29
|
+
Attributes:
|
|
30
|
+
middleware: List of function middleware names to apply to this function.
|
|
31
|
+
These must match names defined in the `middleware` section of the YAML configuration.
|
|
32
|
+
"""
|
|
33
|
+
middleware: list[str] = Field(
|
|
34
|
+
default_factory=list,
|
|
35
|
+
description="List of function middleware names to apply to this function in order",
|
|
36
|
+
)
|
|
28
37
|
|
|
29
38
|
|
|
30
39
|
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
@@ -40,6 +49,10 @@ class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
|
40
49
|
default_factory=list,
|
|
41
50
|
description="The list of function names which should be excluded from default access to the group",
|
|
42
51
|
)
|
|
52
|
+
middleware: list[str] = Field(
|
|
53
|
+
default_factory=list,
|
|
54
|
+
description="List of function middleware names to apply to all functions in this group",
|
|
55
|
+
)
|
|
43
56
|
|
|
44
57
|
@field_validator("include", "exclude")
|
|
45
58
|
@classmethod
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MiddlewareBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
23
|
+
"""The base level config object for middleware.
|
|
24
|
+
|
|
25
|
+
Middleware provides middleware-style wrapping of calls with
|
|
26
|
+
preprocessing and postprocessing logic.
|
|
27
|
+
"""
|
|
28
|
+
pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
MiddlewareBaseConfigT = typing.TypeVar("MiddlewareBaseConfigT", bound=MiddlewareBaseConfig)
|
|
32
|
+
|
|
33
|
+
# Specialized type for function-specific middleware
|
|
34
|
+
FunctionMiddlewareBaseConfig = MiddlewareBaseConfig
|
|
35
|
+
FunctionMiddlewareBaseConfigT = MiddlewareBaseConfigT
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import enum
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RuntimeTypeEnum(str, enum.Enum):
|
|
20
|
+
"""
|
|
21
|
+
Enum representing different runtime types.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
RUN_OR_SERVE = "run_or_serve"
|
|
25
|
+
EVALUATE = "evaluate"
|
|
26
|
+
OTHER = "other"
|
nat/eval/evaluate.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import logging
|
|
18
18
|
import shutil
|
|
19
|
+
import warnings
|
|
19
20
|
from pathlib import Path
|
|
20
21
|
from typing import Any
|
|
21
22
|
from uuid import uuid4
|
|
@@ -25,6 +26,7 @@ from tqdm import tqdm
|
|
|
25
26
|
|
|
26
27
|
from nat.data_models.evaluate import EvalConfig
|
|
27
28
|
from nat.data_models.evaluate import JobEvictionPolicy
|
|
29
|
+
from nat.data_models.runtime_enum import RuntimeTypeEnum
|
|
28
30
|
from nat.eval.config import EvaluationRunConfig
|
|
29
31
|
from nat.eval.config import EvaluationRunOutput
|
|
30
32
|
from nat.eval.dataset_handler.dataset_handler import DatasetHandler
|
|
@@ -67,7 +69,13 @@ class EvaluationRun:
|
|
|
67
69
|
# Create evaluation trace context
|
|
68
70
|
try:
|
|
69
71
|
from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
|
|
70
|
-
|
|
72
|
+
with warnings.catch_warnings():
|
|
73
|
+
# Ignore deprecation warnings being triggered by weave. https://github.com/wandb/weave/issues/3666
|
|
74
|
+
warnings.filterwarnings("ignore",
|
|
75
|
+
category=DeprecationWarning,
|
|
76
|
+
message=r"`sentry_sdk\.Hub` is deprecated")
|
|
77
|
+
|
|
78
|
+
self.eval_trace_context = WeaveEvalTraceContext()
|
|
71
79
|
except Exception:
|
|
72
80
|
from nat.eval.utils.eval_trace_ctx import EvalTraceContext
|
|
73
81
|
self.eval_trace_context = EvalTraceContext()
|
|
@@ -161,7 +169,7 @@ class EvaluationRun:
|
|
|
161
169
|
if stop_event.is_set():
|
|
162
170
|
return "", []
|
|
163
171
|
|
|
164
|
-
async with session_manager.run(item.input_obj) as runner:
|
|
172
|
+
async with session_manager.run(item.input_obj, runtime_type=RuntimeTypeEnum.EVALUATE) as runner:
|
|
165
173
|
if not session_manager.workflow.has_single_output:
|
|
166
174
|
# raise an error if the workflow has multiple outputs
|
|
167
175
|
raise NotImplementedError("Multiple outputs are not supported")
|
|
@@ -27,6 +27,8 @@ from pydantic import field_validator
|
|
|
27
27
|
from nat.data_models.component_ref import ObjectStoreRef
|
|
28
28
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
29
29
|
from nat.data_models.step_adaptor import StepAdaptorConfig
|
|
30
|
+
from nat.eval.evaluator.evaluator_model import EvalInputItem
|
|
31
|
+
from nat.eval.evaluator.evaluator_model import EvalOutputItem
|
|
30
32
|
|
|
31
33
|
logger = logging.getLogger(__name__)
|
|
32
34
|
|
|
@@ -133,6 +135,19 @@ class AsyncGenerationStatusResponse(BaseAsyncStatusResponse):
|
|
|
133
135
|
description="Output of the generate request, this is only available if the job completed successfully.")
|
|
134
136
|
|
|
135
137
|
|
|
138
|
+
class EvaluateItemRequest(BaseModel):
|
|
139
|
+
"""Request model for single-item evaluation endpoint."""
|
|
140
|
+
item: EvalInputItem = Field(description="Single evaluation input item to evaluate")
|
|
141
|
+
evaluator_name: str = Field(description="Name of the evaluator to use (must match config)")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class EvaluateItemResponse(BaseModel):
|
|
145
|
+
"""Response model for single-item evaluation endpoint."""
|
|
146
|
+
success: bool = Field(description="Whether the evaluation completed successfully")
|
|
147
|
+
result: EvalOutputItem | None = Field(default=None, description="Evaluation result if successful")
|
|
148
|
+
error: str | None = Field(default=None, description="Error message if evaluation failed")
|
|
149
|
+
|
|
150
|
+
|
|
136
151
|
class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
137
152
|
"""
|
|
138
153
|
A FastAPI based front end that allows a NAT workflow to be served as a microservice.
|
|
@@ -239,6 +254,13 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
|
239
254
|
description="Evaluates the performance and accuracy of the workflow on a dataset",
|
|
240
255
|
)
|
|
241
256
|
|
|
257
|
+
evaluate_item: typing.Annotated[EndpointBase,
|
|
258
|
+
Field(description="Endpoint for evaluating a single item.")] = EndpointBase(
|
|
259
|
+
method="POST",
|
|
260
|
+
path="/evaluate/item",
|
|
261
|
+
description="Evaluate a single item with a specified evaluator",
|
|
262
|
+
)
|
|
263
|
+
|
|
242
264
|
oauth2_callback_path: str | None = Field(
|
|
243
265
|
default="/auth/redirect",
|
|
244
266
|
description="OAuth2.0 authentication callback endpoint. If None, no OAuth2 callback endpoint is created.")
|
|
@@ -39,6 +39,8 @@ from pydantic import BaseModel
|
|
|
39
39
|
from pydantic import Field
|
|
40
40
|
from starlette.websockets import WebSocket
|
|
41
41
|
|
|
42
|
+
from nat.builder.eval_builder import WorkflowEvalBuilder
|
|
43
|
+
from nat.builder.evaluator import EvaluatorInfo
|
|
42
44
|
from nat.builder.function import Function
|
|
43
45
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
44
46
|
from nat.data_models.api_server import ChatRequest
|
|
@@ -51,11 +53,14 @@ from nat.data_models.object_store import NoSuchKeyError
|
|
|
51
53
|
from nat.eval.config import EvaluationRunOutput
|
|
52
54
|
from nat.eval.evaluate import EvaluationRun
|
|
53
55
|
from nat.eval.evaluate import EvaluationRunConfig
|
|
56
|
+
from nat.eval.evaluator.evaluator_model import EvalInput
|
|
54
57
|
from nat.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
|
|
55
58
|
from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
|
|
56
59
|
from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
|
|
57
60
|
from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerateResponse
|
|
58
61
|
from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerationStatusResponse
|
|
62
|
+
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemRequest
|
|
63
|
+
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemResponse
|
|
59
64
|
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
|
|
60
65
|
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
|
|
61
66
|
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
|
|
@@ -227,6 +232,54 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
227
232
|
self._outstanding_flows: dict[str, FlowState] = {}
|
|
228
233
|
self._outstanding_flows_lock = asyncio.Lock()
|
|
229
234
|
|
|
235
|
+
# Evaluator storage for single-item evaluation
|
|
236
|
+
self._evaluators: dict[str, EvaluatorInfo] = {}
|
|
237
|
+
self._eval_builder: WorkflowEvalBuilder | None = None
|
|
238
|
+
|
|
239
|
+
async def initialize_evaluators(self, config: Config):
|
|
240
|
+
"""Initialize and store evaluators from config for single-item evaluation."""
|
|
241
|
+
if not config.eval or not config.eval.evaluators:
|
|
242
|
+
logger.info("No evaluators configured, skipping evaluator initialization")
|
|
243
|
+
return
|
|
244
|
+
|
|
245
|
+
try:
|
|
246
|
+
# Build evaluators using WorkflowEvalBuilder (same pattern as nat eval)
|
|
247
|
+
# Start with registry=None and let populate_builder set everything up
|
|
248
|
+
self._eval_builder = WorkflowEvalBuilder(general_config=config.general,
|
|
249
|
+
eval_general_config=config.eval.general,
|
|
250
|
+
registry=None)
|
|
251
|
+
|
|
252
|
+
# Enter the async context and keep it alive
|
|
253
|
+
await self._eval_builder.__aenter__()
|
|
254
|
+
|
|
255
|
+
# Populate builder with config (this sets up LLMs, functions, etc.)
|
|
256
|
+
# Skip workflow build since we already have it from the main builder
|
|
257
|
+
await self._eval_builder.populate_builder(config, skip_workflow=True)
|
|
258
|
+
|
|
259
|
+
# Now evaluators should be populated by populate_builder
|
|
260
|
+
for name in config.eval.evaluators.keys():
|
|
261
|
+
self._evaluators[name] = self._eval_builder.get_evaluator(name)
|
|
262
|
+
logger.info(f"Initialized evaluator: {name}")
|
|
263
|
+
|
|
264
|
+
logger.info(f"Successfully initialized {len(self._evaluators)} evaluators")
|
|
265
|
+
|
|
266
|
+
except Exception as e:
|
|
267
|
+
logger.error(f"Failed to initialize evaluators: {e}")
|
|
268
|
+
# Don't fail startup, just log the error
|
|
269
|
+
self._evaluators = {}
|
|
270
|
+
|
|
271
|
+
async def cleanup_evaluators(self):
|
|
272
|
+
"""Clean up evaluator resources on shutdown."""
|
|
273
|
+
if self._eval_builder:
|
|
274
|
+
try:
|
|
275
|
+
await self._eval_builder.__aexit__(None, None, None)
|
|
276
|
+
logger.info("Evaluator builder context cleaned up")
|
|
277
|
+
except Exception as e:
|
|
278
|
+
logger.error(f"Error cleaning up evaluator builder: {e}")
|
|
279
|
+
finally:
|
|
280
|
+
self._eval_builder = None
|
|
281
|
+
self._evaluators.clear()
|
|
282
|
+
|
|
230
283
|
def get_step_adaptor(self) -> StepAdaptor:
|
|
231
284
|
|
|
232
285
|
return StepAdaptor(self.front_end_config.step_adaptor)
|
|
@@ -236,12 +289,20 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
236
289
|
# Do things like setting the base URL and global configuration options
|
|
237
290
|
app.root_path = self.front_end_config.root_path
|
|
238
291
|
|
|
292
|
+
# Initialize evaluators for single-item evaluation
|
|
293
|
+
# TODO: we need config control over this as it's not always needed
|
|
294
|
+
await self.initialize_evaluators(self._config)
|
|
295
|
+
|
|
296
|
+
# Ensure evaluator resources are cleaned up when the app shuts down
|
|
297
|
+
app.add_event_handler("shutdown", self.cleanup_evaluators)
|
|
298
|
+
|
|
239
299
|
await self.add_routes(app, builder)
|
|
240
300
|
|
|
241
301
|
async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
|
|
242
302
|
|
|
243
303
|
await self.add_default_route(app, SessionManager(await builder.build()))
|
|
244
304
|
await self.add_evaluate_route(app, SessionManager(await builder.build()))
|
|
305
|
+
await self.add_evaluate_item_route(app, SessionManager(await builder.build()))
|
|
245
306
|
await self.add_static_files_route(app, builder)
|
|
246
307
|
await self.add_authorization_route(app)
|
|
247
308
|
await self.add_mcp_client_tool_list_route(app, builder)
|
|
@@ -439,6 +500,69 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
439
500
|
else:
|
|
440
501
|
logger.warning("Dask is not available, evaluation endpoints will not be added.")
|
|
441
502
|
|
|
503
|
+
async def add_evaluate_item_route(self, app: FastAPI, session_manager: SessionManager):
|
|
504
|
+
"""Add the single-item evaluation endpoint to the FastAPI app."""
|
|
505
|
+
|
|
506
|
+
async def evaluate_single_item(request: EvaluateItemRequest, http_request: Request) -> EvaluateItemResponse:
|
|
507
|
+
"""Handle single-item evaluation requests."""
|
|
508
|
+
|
|
509
|
+
async with session_manager.session(http_connection=http_request):
|
|
510
|
+
|
|
511
|
+
# Check if evaluator exists
|
|
512
|
+
if request.evaluator_name not in self._evaluators:
|
|
513
|
+
raise HTTPException(status_code=404,
|
|
514
|
+
detail=f"Evaluator '{request.evaluator_name}' not found. "
|
|
515
|
+
f"Available evaluators: {list(self._evaluators.keys())}")
|
|
516
|
+
|
|
517
|
+
try:
|
|
518
|
+
# Get the evaluator
|
|
519
|
+
evaluator = self._evaluators[request.evaluator_name]
|
|
520
|
+
|
|
521
|
+
# Run evaluation on single item
|
|
522
|
+
result = await evaluator.evaluate_fn(EvalInput(eval_input_items=[request.item]))
|
|
523
|
+
|
|
524
|
+
# Extract the single output item
|
|
525
|
+
if result.eval_output_items:
|
|
526
|
+
output_item = result.eval_output_items[0]
|
|
527
|
+
return EvaluateItemResponse(success=True, result=output_item, error=None)
|
|
528
|
+
else:
|
|
529
|
+
return EvaluateItemResponse(success=False, result=None, error="Evaluator returned no results")
|
|
530
|
+
|
|
531
|
+
except Exception as e:
|
|
532
|
+
logger.exception(f"Error evaluating item with {request.evaluator_name}")
|
|
533
|
+
return EvaluateItemResponse(success=False, result=None, error=f"Evaluation failed: {str(e)}")
|
|
534
|
+
|
|
535
|
+
# Register the route
|
|
536
|
+
if self.front_end_config.evaluate_item.path:
|
|
537
|
+
app.add_api_route(path=self.front_end_config.evaluate_item.path,
|
|
538
|
+
endpoint=evaluate_single_item,
|
|
539
|
+
methods=[self.front_end_config.evaluate_item.method],
|
|
540
|
+
response_model=EvaluateItemResponse,
|
|
541
|
+
description=self.front_end_config.evaluate_item.description,
|
|
542
|
+
responses={
|
|
543
|
+
404: {
|
|
544
|
+
"description": "Evaluator not found",
|
|
545
|
+
"content": {
|
|
546
|
+
"application/json": {
|
|
547
|
+
"example": {
|
|
548
|
+
"detail": "Evaluator 'unknown' not found"
|
|
549
|
+
}
|
|
550
|
+
}
|
|
551
|
+
}
|
|
552
|
+
},
|
|
553
|
+
500: {
|
|
554
|
+
"description": "Internal Server Error",
|
|
555
|
+
"content": {
|
|
556
|
+
"application/json": {
|
|
557
|
+
"example": {
|
|
558
|
+
"detail": "Internal server error occurred"
|
|
559
|
+
}
|
|
560
|
+
}
|
|
561
|
+
}
|
|
562
|
+
}
|
|
563
|
+
})
|
|
564
|
+
logger.info(f"Added evaluate_item route at {self.front_end_config.evaluate_item.path}")
|
|
565
|
+
|
|
442
566
|
async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
|
|
443
567
|
|
|
444
568
|
if not self.front_end_config.object_store:
|
|
@@ -140,6 +140,10 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
|
|
|
140
140
|
# Mount the MCP server's ASGI app at the configured base_path
|
|
141
141
|
app.mount(self.front_end_config.base_path, mcp.streamable_http_app())
|
|
142
142
|
|
|
143
|
+
# Allow plugins to add routes to the wrapper app (e.g., OAuth discovery endpoints)
|
|
144
|
+
worker = self._get_worker_instance()
|
|
145
|
+
await worker.add_root_level_routes(app, mcp)
|
|
146
|
+
|
|
143
147
|
# Configure and start uvicorn server
|
|
144
148
|
config = uvicorn.Config(
|
|
145
149
|
app,
|
|
@@ -17,12 +17,16 @@ import logging
|
|
|
17
17
|
from abc import ABC
|
|
18
18
|
from abc import abstractmethod
|
|
19
19
|
from collections.abc import Mapping
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
20
21
|
from typing import Any
|
|
21
22
|
|
|
22
23
|
from mcp.server.fastmcp import FastMCP
|
|
23
24
|
from starlette.exceptions import HTTPException
|
|
24
25
|
from starlette.requests import Request
|
|
25
26
|
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from fastapi import FastAPI
|
|
29
|
+
|
|
26
30
|
from nat.builder.function import Function
|
|
27
31
|
from nat.builder.function_base import FunctionBase
|
|
28
32
|
from nat.builder.workflow import Workflow
|
|
@@ -192,6 +196,28 @@ class MCPFrontEndPluginWorkerBase(ABC):
|
|
|
192
196
|
|
|
193
197
|
return functions
|
|
194
198
|
|
|
199
|
+
async def add_root_level_routes(self, wrapper_app: "FastAPI", mcp: FastMCP) -> None:
    """Optional hook for registering routes directly on the wrapper FastAPI app.

    When ``base_path`` is configured, a wrapper FastAPI application is created
    and the MCP server is mounted on it. This hook receives that wrapper so a
    plugin can attach routes at the root level, i.e. outside the path the MCP
    server is mounted under.

    Typical things a subclass might register here:
        - OAuth discovery endpoints (e.g., /.well-known/oauth-protected-resource)
        - Root-level health checks
        - Static file serving
        - Custom authentication/authorization endpoints

    The base implementation registers nothing, so overriding it is optional.

    Args:
        wrapper_app: The FastAPI wrapper application that mounts the MCP server.
        mcp: The FastMCP server instance (already mounted at base_path).
    """
    # Intentionally a no-op: subclasses override this to add root-level routes.
    return None
|
|
220
|
+
|
|
195
221
|
def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
|
|
196
222
|
"""Set up HTTP debug endpoints for introspecting tools and schemas.
|
|
197
223
|
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Middleware implementations for NeMo Agent Toolkit."""
|
|
16
|
+
|
|
17
|
+
from nat.middleware.cache_middleware import CacheMiddleware
|
|
18
|
+
from nat.middleware.function_middleware import FunctionMiddleware
|
|
19
|
+
from nat.middleware.function_middleware import FunctionMiddlewareChain
|
|
20
|
+
from nat.middleware.function_middleware import validate_middleware
|
|
21
|
+
from nat.middleware.middleware import CallNext
|
|
22
|
+
from nat.middleware.middleware import CallNextStream
|
|
23
|
+
from nat.middleware.middleware import FunctionMiddlewareContext
|
|
24
|
+
from nat.middleware.middleware import Middleware
|
|
25
|
+
|
|
26
|
+
# Public API of the ``nat.middleware`` package: every name imported above is
# re-exported here.
__all__ = [
    # from nat.middleware.cache_middleware
    "CacheMiddleware",
    # from nat.middleware.middleware
    "CallNext",
    "CallNextStream",
    "FunctionMiddlewareContext",
    "Middleware",
    # from nat.middleware.function_middleware
    "FunctionMiddleware",
    "FunctionMiddlewareChain",
    "validate_middleware",
]
|