nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/builder/builder.py +52 -0
- nat/builder/component_utils.py +7 -1
- nat/builder/context.py +17 -0
- nat/builder/framework_enum.py +1 -0
- nat/builder/function.py +74 -3
- nat/builder/workflow.py +4 -2
- nat/builder/workflow_builder.py +129 -0
- nat/cli/commands/workflow/workflow_commands.py +3 -2
- nat/cli/register_workflow.py +50 -0
- nat/cli/type_registry.py +68 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +16 -0
- nat/data_models/function.py +14 -1
- nat/data_models/middleware.py +35 -0
- nat/data_models/runtime_enum.py +26 -0
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +11 -3
- nat/eval/utils/weave_eval.py +17 -3
- nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
- nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
- nat/llm/aws_bedrock_llm.py +11 -9
- nat/llm/azure_openai_llm.py +12 -4
- nat/llm/litellm_llm.py +11 -4
- nat/llm/nim_llm.py +11 -9
- nat/llm/openai_llm.py +12 -9
- nat/middleware/__init__.py +35 -0
- nat/middleware/cache_middleware.py +256 -0
- nat/middleware/function_middleware.py +186 -0
- nat/middleware/middleware.py +184 -0
- nat/middleware/register.py +35 -0
- nat/profiler/decorators/framework_wrapper.py +16 -0
- nat/retriever/milvus/register.py +11 -3
- nat/retriever/milvus/retriever.py +102 -40
- nat/runtime/runner.py +12 -1
- nat/runtime/session.py +10 -3
- nat/tool/code_execution/code_sandbox.py +4 -7
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +8 -4
- nat/utils/io/yaml_tools.py +73 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
- nat/data_models/temperature_mixin.py +0 -44
- nat/data_models/top_p_mixin.py +0 -44
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/cli/type_registry.py
CHANGED
|
@@ -64,6 +64,8 @@ from nat.data_models.logging import LoggingBaseConfig
|
|
|
64
64
|
from nat.data_models.logging import LoggingMethodConfigT
|
|
65
65
|
from nat.data_models.memory import MemoryBaseConfig
|
|
66
66
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
67
|
+
from nat.data_models.middleware import MiddlewareBaseConfig
|
|
68
|
+
from nat.data_models.middleware import MiddlewareBaseConfigT
|
|
67
69
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
68
70
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
69
71
|
from nat.data_models.registry_handler import RegistryHandlerBaseConfig
|
|
@@ -76,6 +78,7 @@ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
|
76
78
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfigT
|
|
77
79
|
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
78
80
|
from nat.memory.interfaces import MemoryEditor
|
|
81
|
+
from nat.middleware.middleware import Middleware
|
|
79
82
|
from nat.object_store.interfaces import ObjectStore
|
|
80
83
|
from nat.observability.exporter.base_exporter import BaseExporter
|
|
81
84
|
from nat.registry_handlers.registry_handler_base import AbstractRegistryHandler
|
|
@@ -89,6 +92,7 @@ EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIte
|
|
|
89
92
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
90
93
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
94
|
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
95
|
+
MiddlewareBuildCallableT = Callable[[MiddlewareBaseConfigT, Builder], AsyncIterator[Middleware]]
|
|
92
96
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
93
97
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
94
98
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -111,6 +115,7 @@ FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncC
|
|
|
111
115
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
112
116
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
117
|
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
118
|
+
MiddlewareRegisteredCallableT = Callable[[MiddlewareBaseConfigT, Builder], AbstractAsyncContextManager[Middleware]]
|
|
114
119
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
115
120
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
116
121
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -179,6 +184,8 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
179
184
|
and a description.
|
|
180
185
|
"""
|
|
181
186
|
|
|
187
|
+
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
|
188
|
+
|
|
182
189
|
build_fn: FunctionRegisteredCallableT = Field(repr=False)
|
|
183
190
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
184
191
|
|
|
@@ -193,6 +200,15 @@ class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
|
193
200
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
201
|
|
|
195
202
|
|
|
203
|
+
class RegisteredMiddlewareInfo(RegisteredInfo[MiddlewareBaseConfig]):
|
|
204
|
+
"""
|
|
205
|
+
Represents registered middleware. Middleware provides middleware-style wrapping of
|
|
206
|
+
calls with preprocessing and postprocessing logic.
|
|
207
|
+
"""
|
|
208
|
+
|
|
209
|
+
build_fn: MiddlewareRegisteredCallableT = Field(repr=False)
|
|
210
|
+
|
|
211
|
+
|
|
196
212
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
197
213
|
"""
|
|
198
214
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -331,6 +347,9 @@ class TypeRegistry:
|
|
|
331
347
|
# Function Groups
|
|
332
348
|
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
349
|
|
|
350
|
+
# Middleware
|
|
351
|
+
self._registered_middleware: dict[type[MiddlewareBaseConfig], RegisteredMiddlewareInfo] = {}
|
|
352
|
+
|
|
334
353
|
# LLMs
|
|
335
354
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
336
355
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -540,6 +559,49 @@ class TypeRegistry:
|
|
|
540
559
|
"""
|
|
541
560
|
return list(self._registered_function_groups.values())
|
|
542
561
|
|
|
562
|
+
def register_middleware(self, registration: RegisteredMiddlewareInfo):
|
|
563
|
+
"""Register middleware with the type registry.
|
|
564
|
+
|
|
565
|
+
Args:
|
|
566
|
+
registration: The middleware registration information
|
|
567
|
+
|
|
568
|
+
Raises:
|
|
569
|
+
ValueError: If middleware with the same config type is already registered
|
|
570
|
+
"""
|
|
571
|
+
if (registration.config_type in self._registered_middleware):
|
|
572
|
+
raise ValueError(f"Middleware with the same config type `{registration.config_type}` has already been "
|
|
573
|
+
"registered.")
|
|
574
|
+
|
|
575
|
+
self._registered_middleware[registration.config_type] = registration
|
|
576
|
+
|
|
577
|
+
self._registration_changed()
|
|
578
|
+
|
|
579
|
+
def get_middleware(self, config_type: type[MiddlewareBaseConfig]) -> RegisteredMiddlewareInfo:
|
|
580
|
+
"""Get registered middleware by its config type.
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
config_type: The middleware configuration type
|
|
584
|
+
|
|
585
|
+
Returns:
|
|
586
|
+
RegisteredMiddlewareInfo: The registered middleware information
|
|
587
|
+
|
|
588
|
+
Raises:
|
|
589
|
+
KeyError: If no middleware is registered for the given config type
|
|
590
|
+
"""
|
|
591
|
+
try:
|
|
592
|
+
return self._registered_middleware[config_type]
|
|
593
|
+
except KeyError as err:
|
|
594
|
+
raise KeyError(f"Could not find registered middleware for config `{config_type}`. "
|
|
595
|
+
f"Registered configs: {set(self._registered_middleware.keys())}") from err
|
|
596
|
+
|
|
597
|
+
def get_registered_middleware(self) -> list[RegisteredInfo[MiddlewareBaseConfig]]:
|
|
598
|
+
"""Get all registered middleware.
|
|
599
|
+
|
|
600
|
+
Returns:
|
|
601
|
+
list[RegisteredInfo[MiddlewareBaseConfig]]: List of all registered middleware
|
|
602
|
+
"""
|
|
603
|
+
return list(self._registered_middleware.values())
|
|
604
|
+
|
|
543
605
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
544
606
|
|
|
545
607
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -912,6 +974,9 @@ class TypeRegistry:
|
|
|
912
974
|
if component_type == ComponentEnum.TTC_STRATEGY:
|
|
913
975
|
return self._registered_ttc_strategies
|
|
914
976
|
|
|
977
|
+
if component_type == ComponentEnum.MIDDLEWARE:
|
|
978
|
+
return self._registered_middleware
|
|
979
|
+
|
|
915
980
|
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
916
981
|
|
|
917
982
|
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
|
|
@@ -1038,6 +1103,9 @@ class TypeRegistry:
|
|
|
1038
1103
|
if issubclass(cls, TTCStrategyBaseConfig):
|
|
1039
1104
|
return self._do_compute_annotation(cls, self.get_registered_ttc_strategies())
|
|
1040
1105
|
|
|
1106
|
+
if issubclass(cls, MiddlewareBaseConfig):
|
|
1107
|
+
return self._do_compute_annotation(cls, self.get_registered_middleware())
|
|
1108
|
+
|
|
1041
1109
|
raise ValueError(f"Supplied an unsupported component type {cls}")
|
|
1042
1110
|
|
|
1043
1111
|
|
nat/data_models/component.py
CHANGED
|
@@ -28,6 +28,7 @@ class ComponentEnum(StrEnum):
|
|
|
28
28
|
FRONT_END = "front_end"
|
|
29
29
|
FUNCTION = "function"
|
|
30
30
|
FUNCTION_GROUP = "function_group"
|
|
31
|
+
MIDDLEWARE = "middleware"
|
|
31
32
|
TTC_STRATEGY = "ttc_strategy"
|
|
32
33
|
LLM_CLIENT = "llm_client"
|
|
33
34
|
LLM_PROVIDER = "llm_provider"
|
|
@@ -49,6 +50,7 @@ class ComponentGroup(StrEnum):
|
|
|
49
50
|
EMBEDDERS = "embedders"
|
|
50
51
|
FUNCTIONS = "functions"
|
|
51
52
|
FUNCTION_GROUPS = "function_groups"
|
|
53
|
+
MIDDLEWARE = "middleware"
|
|
52
54
|
TTC_STRATEGIES = "ttc_strategies"
|
|
53
55
|
LLMS = "llms"
|
|
54
56
|
MEMORY = "memory"
|
nat/data_models/component_ref.py
CHANGED
|
@@ -177,3 +177,14 @@ class TTCStrategyRef(ComponentRef):
|
|
|
177
177
|
@override
|
|
178
178
|
def component_group(self):
|
|
179
179
|
return ComponentGroup.TTC_STRATEGIES
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class MiddlewareRef(ComponentRef):
|
|
183
|
+
"""
|
|
184
|
+
A reference to middleware in a NAT configuration object.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
@property
|
|
188
|
+
@override
|
|
189
|
+
def component_group(self):
|
|
190
|
+
return ComponentGroup.MIDDLEWARE
|
nat/data_models/config.py
CHANGED
|
@@ -43,6 +43,7 @@ from .common import TypedBaseModel
|
|
|
43
43
|
from .embedder import EmbedderBaseConfig
|
|
44
44
|
from .llm import LLMBaseConfig
|
|
45
45
|
from .memory import MemoryBaseConfig
|
|
46
|
+
from .middleware import FunctionMiddlewareBaseConfig
|
|
46
47
|
from .object_store import ObjectStoreBaseConfig
|
|
47
48
|
from .retriever import RetrieverBaseConfig
|
|
48
49
|
|
|
@@ -86,6 +87,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
86
87
|
registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
|
|
87
88
|
elif (info.field_name == "ttc_strategies"):
|
|
88
89
|
registered_keys = GlobalTypeRegistry.get().get_registered_ttc_strategies()
|
|
90
|
+
elif (info.field_name == "middleware"):
|
|
91
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_middleware()
|
|
89
92
|
|
|
90
93
|
else:
|
|
91
94
|
assert False, f"Unknown field name {info.field_name} in validator"
|
|
@@ -253,6 +256,9 @@ class Config(HashableBaseModel):
|
|
|
253
256
|
# Function Groups Configuration
|
|
254
257
|
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
255
258
|
|
|
259
|
+
# Middleware Configuration
|
|
260
|
+
middleware: dict[str, FunctionMiddlewareBaseConfig] = Field(default_factory=dict)
|
|
261
|
+
|
|
256
262
|
# LLMs Configuration
|
|
257
263
|
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
258
264
|
|
|
@@ -303,6 +309,7 @@ class Config(HashableBaseModel):
|
|
|
303
309
|
|
|
304
310
|
@field_validator("functions",
|
|
305
311
|
"function_groups",
|
|
312
|
+
"middleware",
|
|
306
313
|
"llms",
|
|
307
314
|
"embedders",
|
|
308
315
|
"memory",
|
|
@@ -348,6 +355,10 @@ class Config(HashableBaseModel):
|
|
|
348
355
|
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
356
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
357
|
|
|
358
|
+
MiddlewareAnnotation = dict[str,
|
|
359
|
+
typing.Annotated[type_registry.compute_annotation(FunctionMiddlewareBaseConfig),
|
|
360
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
361
|
+
|
|
351
362
|
MemoryAnnotation = dict[str,
|
|
352
363
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
353
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -393,6 +404,11 @@ class Config(HashableBaseModel):
|
|
|
393
404
|
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
405
|
should_rebuild = True
|
|
395
406
|
|
|
407
|
+
middleware_field = cls.model_fields.get("middleware")
|
|
408
|
+
if (middleware_field is not None and middleware_field.annotation != MiddlewareAnnotation):
|
|
409
|
+
middleware_field.annotation = MiddlewareAnnotation
|
|
410
|
+
should_rebuild = True
|
|
411
|
+
|
|
396
412
|
memory_field = cls.model_fields.get("memory")
|
|
397
413
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
398
414
|
memory_field.annotation = MemoryAnnotation
|
nat/data_models/function.py
CHANGED
|
@@ -24,7 +24,16 @@ from .common import TypedBaseModel
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
27
|
-
|
|
27
|
+
"""Base configuration for functions.
|
|
28
|
+
|
|
29
|
+
Attributes:
|
|
30
|
+
middleware: List of function middleware names to apply to this function.
|
|
31
|
+
These must match names defined in the `middleware` section of the YAML configuration.
|
|
32
|
+
"""
|
|
33
|
+
middleware: list[str] = Field(
|
|
34
|
+
default_factory=list,
|
|
35
|
+
description="List of function middleware names to apply to this function in order",
|
|
36
|
+
)
|
|
28
37
|
|
|
29
38
|
|
|
30
39
|
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
@@ -40,6 +49,10 @@ class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
|
40
49
|
default_factory=list,
|
|
41
50
|
description="The list of function names which should be excluded from default access to the group",
|
|
42
51
|
)
|
|
52
|
+
middleware: list[str] = Field(
|
|
53
|
+
default_factory=list,
|
|
54
|
+
description="List of function middleware names to apply to all functions in this group",
|
|
55
|
+
)
|
|
43
56
|
|
|
44
57
|
@field_validator("include", "exclude")
|
|
45
58
|
@classmethod
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
|
|
18
|
+
from .common import BaseModelRegistryTag
|
|
19
|
+
from .common import TypedBaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MiddlewareBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
23
|
+
"""The base level config object for middleware.
|
|
24
|
+
|
|
25
|
+
Middleware provides middleware-style wrapping of calls with
|
|
26
|
+
preprocessing and postprocessing logic.
|
|
27
|
+
"""
|
|
28
|
+
pass
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
MiddlewareBaseConfigT = typing.TypeVar("MiddlewareBaseConfigT", bound=MiddlewareBaseConfig)
|
|
32
|
+
|
|
33
|
+
# Specialized type for function-specific middleware
|
|
34
|
+
FunctionMiddlewareBaseConfig = MiddlewareBaseConfig
|
|
35
|
+
FunctionMiddlewareBaseConfigT = MiddlewareBaseConfigT
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import enum
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RuntimeTypeEnum(str, enum.Enum):
|
|
20
|
+
"""
|
|
21
|
+
Enum representing different runtime types.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
RUN_OR_SERVE = "run_or_serve"
|
|
25
|
+
EVALUATE = "evaluate"
|
|
26
|
+
OTHER = "other"
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import fnmatch
|
|
17
|
+
|
|
16
18
|
import pandas as pd
|
|
17
19
|
|
|
18
20
|
from nat.data_models.dataset_handler import EvalFilterConfig
|
|
@@ -24,6 +26,7 @@ class DatasetFilter:
|
|
|
24
26
|
- If a allowlist is provided, only keep rows matching the filter values.
|
|
25
27
|
- If a denylist is provided, remove rows matching the filter values.
|
|
26
28
|
- If the filter column does not exist in the DataFrame, the filtering is skipped for that column.
|
|
29
|
+
- Supports Unix shell-style wildcards (``*``, ``?``, ``[seq]``, ``[!seq]``) for string matching.
|
|
27
30
|
|
|
28
31
|
This is a utility class that is dataset agnostic and can be used to filter any DataFrame based on the provided
|
|
29
32
|
filter configuration.
|
|
@@ -33,6 +36,33 @@ class DatasetFilter:
|
|
|
33
36
|
|
|
34
37
|
self.filter_config = filter_config
|
|
35
38
|
|
|
39
|
+
@staticmethod
|
|
40
|
+
def _match_wildcard_patterns(series: pd.Series, patterns: list[str | int | float]) -> pd.Series:
|
|
41
|
+
"""
|
|
42
|
+
Match series values against wildcard patterns and exact values.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
series (pd.Series): pandas Series to match against
|
|
46
|
+
patterns (list[str | int | float]): List of patterns/values
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
pd.Series: Boolean Series indicating matches
|
|
50
|
+
"""
|
|
51
|
+
# Convert series to string for pattern matching
|
|
52
|
+
str_series = series.astype(str)
|
|
53
|
+
|
|
54
|
+
# Initialize boolean mask
|
|
55
|
+
matches = pd.Series([False] * len(series), index=series.index)
|
|
56
|
+
|
|
57
|
+
# Check each pattern using fnmatch with list comprehension to avoid lambda capture
|
|
58
|
+
for pattern in patterns:
|
|
59
|
+
pattern_str = str(pattern)
|
|
60
|
+
pattern_matches = pd.Series([fnmatch.fnmatch(val, pattern_str) for val in str_series],
|
|
61
|
+
index=str_series.index)
|
|
62
|
+
matches |= pattern_matches
|
|
63
|
+
|
|
64
|
+
return matches
|
|
65
|
+
|
|
36
66
|
def apply_filters(self, df) -> pd.DataFrame:
|
|
37
67
|
|
|
38
68
|
filtered_df = df.copy()
|
|
@@ -41,12 +71,14 @@ class DatasetFilter:
|
|
|
41
71
|
if self.filter_config.allowlist:
|
|
42
72
|
for column, values in self.filter_config.allowlist.field.items():
|
|
43
73
|
if column in filtered_df.columns:
|
|
44
|
-
|
|
74
|
+
matches = self._match_wildcard_patterns(filtered_df[column], values)
|
|
75
|
+
filtered_df = filtered_df[matches]
|
|
45
76
|
|
|
46
77
|
# Apply denylist (remove specified rows)
|
|
47
78
|
if self.filter_config.denylist:
|
|
48
79
|
for column, values in self.filter_config.denylist.field.items():
|
|
49
80
|
if column in filtered_df.columns:
|
|
50
|
-
|
|
81
|
+
matches = self._match_wildcard_patterns(filtered_df[column], values)
|
|
82
|
+
filtered_df = filtered_df[~matches]
|
|
51
83
|
|
|
52
84
|
return filtered_df
|
nat/eval/evaluate.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import logging
|
|
18
18
|
import shutil
|
|
19
|
+
import warnings
|
|
19
20
|
from pathlib import Path
|
|
20
21
|
from typing import Any
|
|
21
22
|
from uuid import uuid4
|
|
@@ -25,6 +26,7 @@ from tqdm import tqdm
|
|
|
25
26
|
|
|
26
27
|
from nat.data_models.evaluate import EvalConfig
|
|
27
28
|
from nat.data_models.evaluate import JobEvictionPolicy
|
|
29
|
+
from nat.data_models.runtime_enum import RuntimeTypeEnum
|
|
28
30
|
from nat.eval.config import EvaluationRunConfig
|
|
29
31
|
from nat.eval.config import EvaluationRunOutput
|
|
30
32
|
from nat.eval.dataset_handler.dataset_handler import DatasetHandler
|
|
@@ -67,7 +69,13 @@ class EvaluationRun:
|
|
|
67
69
|
# Create evaluation trace context
|
|
68
70
|
try:
|
|
69
71
|
from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
|
|
70
|
-
|
|
72
|
+
with warnings.catch_warnings():
|
|
73
|
+
# Ignore deprecation warnings being triggered by weave. https://github.com/wandb/weave/issues/3666
|
|
74
|
+
warnings.filterwarnings("ignore",
|
|
75
|
+
category=DeprecationWarning,
|
|
76
|
+
message=r"`sentry_sdk\.Hub` is deprecated")
|
|
77
|
+
|
|
78
|
+
self.eval_trace_context = WeaveEvalTraceContext()
|
|
71
79
|
except Exception:
|
|
72
80
|
from nat.eval.utils.eval_trace_ctx import EvalTraceContext
|
|
73
81
|
self.eval_trace_context = EvalTraceContext()
|
|
@@ -161,7 +169,7 @@ class EvaluationRun:
|
|
|
161
169
|
if stop_event.is_set():
|
|
162
170
|
return "", []
|
|
163
171
|
|
|
164
|
-
async with session_manager.run(item.input_obj) as runner:
|
|
172
|
+
async with session_manager.run(item.input_obj, runtime_type=RuntimeTypeEnum.EVALUATE) as runner:
|
|
165
173
|
if not session_manager.workflow.has_single_output:
|
|
166
174
|
# raise an error if the workflow has multiple outputs
|
|
167
175
|
raise NotImplementedError("Multiple outputs are not supported")
|
|
@@ -514,7 +522,7 @@ class EvaluationRun:
|
|
|
514
522
|
# Run workflow and evaluate
|
|
515
523
|
async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
|
|
516
524
|
# Initialize Weave integration
|
|
517
|
-
self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
|
|
525
|
+
self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config, job_id=job_id)
|
|
518
526
|
|
|
519
527
|
with self.eval_trace_context.evaluation_context():
|
|
520
528
|
# Run workflow
|
nat/eval/utils/weave_eval.py
CHANGED
|
@@ -82,7 +82,7 @@ class WeaveEvaluationIntegration:
|
|
|
82
82
|
"""Get the full dataset for Weave."""
|
|
83
83
|
return [item.full_dataset_entry for item in eval_input.eval_input_items]
|
|
84
84
|
|
|
85
|
-
def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config: Any):
|
|
85
|
+
def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config: Any, job_id: str | None = None):
|
|
86
86
|
"""Initialize the Weave evaluation logger."""
|
|
87
87
|
if not self.client and not self.initialize_client():
|
|
88
88
|
# lazy init the client
|
|
@@ -92,10 +92,16 @@ class WeaveEvaluationIntegration:
|
|
|
92
92
|
weave_dataset = self._get_weave_dataset(eval_input)
|
|
93
93
|
config_dict = config.model_dump(mode="json")
|
|
94
94
|
config_dict["name"] = workflow_alias
|
|
95
|
+
|
|
96
|
+
# Include job_id in eval_attributes if provided
|
|
97
|
+
eval_attributes = {}
|
|
98
|
+
if job_id:
|
|
99
|
+
eval_attributes["job_id"] = job_id
|
|
100
|
+
|
|
95
101
|
self.eval_logger = self.evaluation_logger_cls(model=config_dict,
|
|
96
102
|
dataset=weave_dataset,
|
|
97
103
|
name=workflow_alias,
|
|
98
|
-
eval_attributes=
|
|
104
|
+
eval_attributes=eval_attributes)
|
|
99
105
|
self.pred_loggers = {}
|
|
100
106
|
|
|
101
107
|
# Capture the current evaluation call for context propagation
|
|
@@ -136,9 +142,17 @@ class WeaveEvaluationIntegration:
|
|
|
136
142
|
coros = []
|
|
137
143
|
for eval_output_item in eval_output.eval_output_items:
|
|
138
144
|
if eval_output_item.id in self.pred_loggers:
|
|
145
|
+
# Structure the score as a dict and include reasoning if available
|
|
146
|
+
score_value = {
|
|
147
|
+
"score": eval_output_item.score,
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
if eval_output_item.reasoning is not None:
|
|
151
|
+
score_value["reasoning"] = eval_output_item.reasoning
|
|
152
|
+
|
|
139
153
|
coros.append(self.pred_loggers[eval_output_item.id].alog_score(
|
|
140
154
|
scorer=evaluator_name,
|
|
141
|
-
score=
|
|
155
|
+
score=score_value,
|
|
142
156
|
))
|
|
143
157
|
|
|
144
158
|
# Execute all coroutines concurrently
|
|
@@ -27,6 +27,8 @@ from pydantic import field_validator
|
|
|
27
27
|
from nat.data_models.component_ref import ObjectStoreRef
|
|
28
28
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
29
29
|
from nat.data_models.step_adaptor import StepAdaptorConfig
|
|
30
|
+
from nat.eval.evaluator.evaluator_model import EvalInputItem
|
|
31
|
+
from nat.eval.evaluator.evaluator_model import EvalOutputItem
|
|
30
32
|
|
|
31
33
|
logger = logging.getLogger(__name__)
|
|
32
34
|
|
|
@@ -133,6 +135,19 @@ class AsyncGenerationStatusResponse(BaseAsyncStatusResponse):
|
|
|
133
135
|
description="Output of the generate request, this is only available if the job completed successfully.")
|
|
134
136
|
|
|
135
137
|
|
|
138
|
+
class EvaluateItemRequest(BaseModel):
|
|
139
|
+
"""Request model for single-item evaluation endpoint."""
|
|
140
|
+
item: EvalInputItem = Field(description="Single evaluation input item to evaluate")
|
|
141
|
+
evaluator_name: str = Field(description="Name of the evaluator to use (must match config)")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class EvaluateItemResponse(BaseModel):
|
|
145
|
+
"""Response model for single-item evaluation endpoint."""
|
|
146
|
+
success: bool = Field(description="Whether the evaluation completed successfully")
|
|
147
|
+
result: EvalOutputItem | None = Field(default=None, description="Evaluation result if successful")
|
|
148
|
+
error: str | None = Field(default=None, description="Error message if evaluation failed")
|
|
149
|
+
|
|
150
|
+
|
|
136
151
|
class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
137
152
|
"""
|
|
138
153
|
A FastAPI based front end that allows a NAT workflow to be served as a microservice.
|
|
@@ -211,6 +226,13 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
|
211
226
|
"Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
|
|
212
227
|
"This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
|
|
213
228
|
ge=1)
|
|
229
|
+
dask_workers: typing.Literal["threads", "processes"] = Field(
|
|
230
|
+
default="processes",
|
|
231
|
+
description=(
|
|
232
|
+
"Type of Dask workers to use. Options are 'threads' for Threaded Dask workers or 'processes' for "
|
|
233
|
+
"Process based Dask workers. This parameter is only used when scheduler_address is `None` and a local Dask "
|
|
234
|
+
"cluster is created."),
|
|
235
|
+
)
|
|
214
236
|
dask_log_level: str = Field(
|
|
215
237
|
default="WARNING",
|
|
216
238
|
description="Logging level for Dask.",
|
|
@@ -232,6 +254,13 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
|
232
254
|
description="Evaluates the performance and accuracy of the workflow on a dataset",
|
|
233
255
|
)
|
|
234
256
|
|
|
257
|
+
evaluate_item: typing.Annotated[EndpointBase,
|
|
258
|
+
Field(description="Endpoint for evaluating a single item.")] = EndpointBase(
|
|
259
|
+
method="POST",
|
|
260
|
+
path="/evaluate/item",
|
|
261
|
+
description="Evaluate a single item with a specified evaluator",
|
|
262
|
+
)
|
|
263
|
+
|
|
235
264
|
oauth2_callback_path: str | None = Field(
|
|
236
265
|
default="/auth/redirect",
|
|
237
266
|
description="OAuth2.0 authentication callback endpoint. If None, no OAuth2 callback endpoint is created.")
|
|
@@ -120,18 +120,24 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
|
|
|
120
120
|
|
|
121
121
|
from dask.distributed import LocalCluster
|
|
122
122
|
|
|
123
|
-
self.
|
|
123
|
+
use_threads = self.front_end_config.dask_workers == 'threads'
|
|
124
|
+
|
|
125
|
+
# set n_workers to max_running_async_jobs + 1 to allow for one worker to handle the cleanup task
|
|
126
|
+
self._cluster = LocalCluster(processes=not use_threads,
|
|
124
127
|
silence_logs=dask_log_level,
|
|
125
|
-
|
|
126
|
-
|
|
128
|
+
protocol="tcp",
|
|
129
|
+
n_workers=self.front_end_config.max_running_async_jobs + 1)
|
|
127
130
|
|
|
128
131
|
self._scheduler_address = self._cluster.scheduler.address
|
|
129
132
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
+
if not use_threads and sys.platform != "win32":
|
|
134
|
+
with self.blocking_client(self._scheduler_address) as client:
|
|
135
|
+
# Client.run submits a function to be run on each worker
|
|
136
|
+
client.run(self._setup_worker)
|
|
133
137
|
|
|
134
|
-
logger.info("Created local Dask cluster with scheduler at %s",
|
|
138
|
+
logger.info("Created local Dask cluster with scheduler at %s using %s workers",
|
|
139
|
+
self._scheduler_address,
|
|
140
|
+
self.front_end_config.dask_workers)
|
|
135
141
|
|
|
136
142
|
except ImportError:
|
|
137
143
|
logger.warning("Dask is not installed, async execution and evaluation will not be available.")
|