nvidia-nat 1.4.0a20251102__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. nat/builder/builder.py +52 -0
  2. nat/builder/component_utils.py +7 -1
  3. nat/builder/context.py +17 -0
  4. nat/builder/framework_enum.py +1 -0
  5. nat/builder/function.py +74 -3
  6. nat/builder/workflow.py +4 -2
  7. nat/builder/workflow_builder.py +129 -0
  8. nat/cli/commands/workflow/workflow_commands.py +3 -2
  9. nat/cli/register_workflow.py +50 -0
  10. nat/cli/type_registry.py +68 -0
  11. nat/data_models/component.py +2 -0
  12. nat/data_models/component_ref.py +11 -0
  13. nat/data_models/config.py +16 -0
  14. nat/data_models/function.py +14 -1
  15. nat/data_models/middleware.py +35 -0
  16. nat/data_models/runtime_enum.py +26 -0
  17. nat/eval/dataset_handler/dataset_filter.py +34 -2
  18. nat/eval/evaluate.py +11 -3
  19. nat/eval/utils/weave_eval.py +17 -3
  20. nat/front_ends/fastapi/fastapi_front_end_config.py +29 -0
  21. nat/front_ends/fastapi/fastapi_front_end_plugin.py +13 -7
  22. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +144 -14
  23. nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
  24. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
  25. nat/llm/aws_bedrock_llm.py +11 -9
  26. nat/llm/azure_openai_llm.py +12 -4
  27. nat/llm/litellm_llm.py +11 -4
  28. nat/llm/nim_llm.py +11 -9
  29. nat/llm/openai_llm.py +12 -9
  30. nat/middleware/__init__.py +35 -0
  31. nat/middleware/cache_middleware.py +256 -0
  32. nat/middleware/function_middleware.py +186 -0
  33. nat/middleware/middleware.py +184 -0
  34. nat/middleware/register.py +35 -0
  35. nat/profiler/decorators/framework_wrapper.py +16 -0
  36. nat/retriever/milvus/register.py +11 -3
  37. nat/retriever/milvus/retriever.py +102 -40
  38. nat/runtime/runner.py +12 -1
  39. nat/runtime/session.py +10 -3
  40. nat/tool/code_execution/code_sandbox.py +4 -7
  41. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  42. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +5 -0
  43. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  44. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  45. nat/tool/server_tools.py +15 -2
  46. nat/utils/__init__.py +8 -4
  47. nat/utils/io/yaml_tools.py +73 -3
  48. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +11 -3
  49. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +54 -50
  50. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
  51. nat/data_models/temperature_mixin.py +0 -44
  52. nat/data_models/top_p_mixin.py +0 -44
  53. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  54. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
  55. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  56. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
  57. {nvidia_nat-1.4.0a20251102.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/cli/type_registry.py CHANGED
@@ -64,6 +64,8 @@ from nat.data_models.logging import LoggingBaseConfig
64
64
  from nat.data_models.logging import LoggingMethodConfigT
65
65
  from nat.data_models.memory import MemoryBaseConfig
66
66
  from nat.data_models.memory import MemoryBaseConfigT
67
+ from nat.data_models.middleware import MiddlewareBaseConfig
68
+ from nat.data_models.middleware import MiddlewareBaseConfigT
67
69
  from nat.data_models.object_store import ObjectStoreBaseConfig
68
70
  from nat.data_models.object_store import ObjectStoreBaseConfigT
69
71
  from nat.data_models.registry_handler import RegistryHandlerBaseConfig
@@ -76,6 +78,7 @@ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
76
78
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfigT
77
79
  from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
78
80
  from nat.memory.interfaces import MemoryEditor
81
+ from nat.middleware.middleware import Middleware
79
82
  from nat.object_store.interfaces import ObjectStore
80
83
  from nat.observability.exporter.base_exporter import BaseExporter
81
84
  from nat.registry_handlers.registry_handler_base import AbstractRegistryHandler
@@ -89,6 +92,7 @@ EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIte
89
92
  FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
90
93
  FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
91
94
  FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
95
+ MiddlewareBuildCallableT = Callable[[MiddlewareBaseConfigT, Builder], AsyncIterator[Middleware]]
92
96
  TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
93
97
  LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
94
98
  LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
@@ -111,6 +115,7 @@ FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncC
111
115
  FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
112
116
  AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
113
117
  FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
118
+ MiddlewareRegisteredCallableT = Callable[[MiddlewareBaseConfigT, Builder], AbstractAsyncContextManager[Middleware]]
114
119
  TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
115
120
  LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
116
121
  LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
@@ -179,6 +184,8 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
179
184
  and a description.
180
185
  """
181
186
 
187
+ model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
188
+
182
189
  build_fn: FunctionRegisteredCallableT = Field(repr=False)
183
190
  framework_wrappers: list[str] = Field(default_factory=list)
184
191
 
@@ -193,6 +200,15 @@ class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
193
200
  framework_wrappers: list[str] = Field(default_factory=list)
194
201
 
195
202
 
203
+ class RegisteredMiddlewareInfo(RegisteredInfo[MiddlewareBaseConfig]):
204
+ """
205
+ Represents registered middleware. Middleware provides middleware-style wrapping of
206
+ calls with preprocessing and postprocessing logic.
207
+ """
208
+
209
+ build_fn: MiddlewareRegisteredCallableT = Field(repr=False)
210
+
211
+
196
212
  class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
197
213
  """
198
214
  Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
@@ -331,6 +347,9 @@ class TypeRegistry:
331
347
  # Function Groups
332
348
  self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
333
349
 
350
+ # Middleware
351
+ self._registered_middleware: dict[type[MiddlewareBaseConfig], RegisteredMiddlewareInfo] = {}
352
+
334
353
  # LLMs
335
354
  self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
336
355
  self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
@@ -540,6 +559,49 @@ class TypeRegistry:
540
559
  """
541
560
  return list(self._registered_function_groups.values())
542
561
 
562
+ def register_middleware(self, registration: RegisteredMiddlewareInfo):
563
+ """Register middleware with the type registry.
564
+
565
+ Args:
566
+ registration: The middleware registration information
567
+
568
+ Raises:
569
+ ValueError: If middleware with the same config type is already registered
570
+ """
571
+ if (registration.config_type in self._registered_middleware):
572
+ raise ValueError(f"Middleware with the same config type `{registration.config_type}` has already been "
573
+ "registered.")
574
+
575
+ self._registered_middleware[registration.config_type] = registration
576
+
577
+ self._registration_changed()
578
+
579
+ def get_middleware(self, config_type: type[MiddlewareBaseConfig]) -> RegisteredMiddlewareInfo:
580
+ """Get registered middleware by its config type.
581
+
582
+ Args:
583
+ config_type: The middleware configuration type
584
+
585
+ Returns:
586
+ RegisteredMiddlewareInfo: The registered middleware information
587
+
588
+ Raises:
589
+ KeyError: If no middleware is registered for the given config type
590
+ """
591
+ try:
592
+ return self._registered_middleware[config_type]
593
+ except KeyError as err:
594
+ raise KeyError(f"Could not find registered middleware for config `{config_type}`. "
595
+ f"Registered configs: {set(self._registered_middleware.keys())}") from err
596
+
597
+ def get_registered_middleware(self) -> list[RegisteredInfo[MiddlewareBaseConfig]]:
598
+ """Get all registered middleware.
599
+
600
+ Returns:
601
+ list[RegisteredInfo[MiddlewareBaseConfig]]: List of all registered middleware
602
+ """
603
+ return list(self._registered_middleware.values())
604
+
543
605
  def register_llm_provider(self, info: RegisteredLLMProviderInfo):
544
606
 
545
607
  if (info.config_type in self._registered_llm_provider_infos):
@@ -912,6 +974,9 @@ class TypeRegistry:
912
974
  if component_type == ComponentEnum.TTC_STRATEGY:
913
975
  return self._registered_ttc_strategies
914
976
 
977
+ if component_type == ComponentEnum.MIDDLEWARE:
978
+ return self._registered_middleware
979
+
915
980
  raise ValueError(f"Supplied an unsupported component type {component_type}")
916
981
 
917
982
  def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
@@ -1038,6 +1103,9 @@ class TypeRegistry:
1038
1103
  if issubclass(cls, TTCStrategyBaseConfig):
1039
1104
  return self._do_compute_annotation(cls, self.get_registered_ttc_strategies())
1040
1105
 
1106
+ if issubclass(cls, MiddlewareBaseConfig):
1107
+ return self._do_compute_annotation(cls, self.get_registered_middleware())
1108
+
1041
1109
  raise ValueError(f"Supplied an unsupported component type {cls}")
1042
1110
 
1043
1111
 
nat/data_models/component.py CHANGED
@@ -28,6 +28,7 @@ class ComponentEnum(StrEnum):
28
28
  FRONT_END = "front_end"
29
29
  FUNCTION = "function"
30
30
  FUNCTION_GROUP = "function_group"
31
+ MIDDLEWARE = "middleware"
31
32
  TTC_STRATEGY = "ttc_strategy"
32
33
  LLM_CLIENT = "llm_client"
33
34
  LLM_PROVIDER = "llm_provider"
@@ -49,6 +50,7 @@ class ComponentGroup(StrEnum):
49
50
  EMBEDDERS = "embedders"
50
51
  FUNCTIONS = "functions"
51
52
  FUNCTION_GROUPS = "function_groups"
53
+ MIDDLEWARE = "middleware"
52
54
  TTC_STRATEGIES = "ttc_strategies"
53
55
  LLMS = "llms"
54
56
  MEMORY = "memory"
nat/data_models/component_ref.py CHANGED
@@ -177,3 +177,14 @@ class TTCStrategyRef(ComponentRef):
177
177
  @override
178
178
  def component_group(self):
179
179
  return ComponentGroup.TTC_STRATEGIES
180
+
181
+
182
+ class MiddlewareRef(ComponentRef):
183
+ """
184
+ A reference to middleware in a NAT configuration object.
185
+ """
186
+
187
+ @property
188
+ @override
189
+ def component_group(self):
190
+ return ComponentGroup.MIDDLEWARE
nat/data_models/config.py CHANGED
@@ -43,6 +43,7 @@ from .common import TypedBaseModel
43
43
  from .embedder import EmbedderBaseConfig
44
44
  from .llm import LLMBaseConfig
45
45
  from .memory import MemoryBaseConfig
46
+ from .middleware import FunctionMiddlewareBaseConfig
46
47
  from .object_store import ObjectStoreBaseConfig
47
48
  from .retriever import RetrieverBaseConfig
48
49
 
@@ -86,6 +87,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
86
87
  registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
87
88
  elif (info.field_name == "ttc_strategies"):
88
89
  registered_keys = GlobalTypeRegistry.get().get_registered_ttc_strategies()
90
+ elif (info.field_name == "middleware"):
91
+ registered_keys = GlobalTypeRegistry.get().get_registered_middleware()
89
92
 
90
93
  else:
91
94
  assert False, f"Unknown field name {info.field_name} in validator"
@@ -253,6 +256,9 @@ class Config(HashableBaseModel):
253
256
  # Function Groups Configuration
254
257
  function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
255
258
 
259
+ # Middleware Configuration
260
+ middleware: dict[str, FunctionMiddlewareBaseConfig] = Field(default_factory=dict)
261
+
256
262
  # LLMs Configuration
257
263
  llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
258
264
 
@@ -303,6 +309,7 @@ class Config(HashableBaseModel):
303
309
 
304
310
  @field_validator("functions",
305
311
  "function_groups",
312
+ "middleware",
306
313
  "llms",
307
314
  "embedders",
308
315
  "memory",
@@ -348,6 +355,10 @@ class Config(HashableBaseModel):
348
355
  typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
349
356
  Discriminator(TypedBaseModel.discriminator)]]
350
357
 
358
+ MiddlewareAnnotation = dict[str,
359
+ typing.Annotated[type_registry.compute_annotation(FunctionMiddlewareBaseConfig),
360
+ Discriminator(TypedBaseModel.discriminator)]]
361
+
351
362
  MemoryAnnotation = dict[str,
352
363
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
353
364
  Discriminator(TypedBaseModel.discriminator)]]
@@ -393,6 +404,11 @@ class Config(HashableBaseModel):
393
404
  function_groups_field.annotation = FunctionGroupsAnnotation
394
405
  should_rebuild = True
395
406
 
407
+ middleware_field = cls.model_fields.get("middleware")
408
+ if (middleware_field is not None and middleware_field.annotation != MiddlewareAnnotation):
409
+ middleware_field.annotation = MiddlewareAnnotation
410
+ should_rebuild = True
411
+
396
412
  memory_field = cls.model_fields.get("memory")
397
413
  if memory_field is not None and memory_field.annotation != MemoryAnnotation:
398
414
  memory_field.annotation = MemoryAnnotation
nat/data_models/function.py CHANGED
@@ -24,7 +24,16 @@ from .common import TypedBaseModel
24
24
 
25
25
 
26
26
  class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
27
- pass
27
+ """Base configuration for functions.
28
+
29
+ Attributes:
30
+ middleware: List of function middleware names to apply to this function.
31
+ These must match names defined in the `middleware` section of the YAML configuration.
32
+ """
33
+ middleware: list[str] = Field(
34
+ default_factory=list,
35
+ description="List of function middleware names to apply to this function in order",
36
+ )
28
37
 
29
38
 
30
39
  class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
@@ -40,6 +49,10 @@ class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
40
49
  default_factory=list,
41
50
  description="The list of function names which should be excluded from default access to the group",
42
51
  )
52
+ middleware: list[str] = Field(
53
+ default_factory=list,
54
+ description="List of function middleware names to apply to all functions in this group",
55
+ )
43
56
 
44
57
  @field_validator("include", "exclude")
45
58
  @classmethod
nat/data_models/middleware.py ADDED
@@ -0,0 +1,35 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+
18
+ from .common import BaseModelRegistryTag
19
+ from .common import TypedBaseModel
20
+
21
+
22
+ class MiddlewareBaseConfig(TypedBaseModel, BaseModelRegistryTag):
23
+ """The base level config object for middleware.
24
+
25
+ Middleware provides middleware-style wrapping of calls with
26
+ preprocessing and postprocessing logic.
27
+ """
28
+ pass
29
+
30
+
31
+ MiddlewareBaseConfigT = typing.TypeVar("MiddlewareBaseConfigT", bound=MiddlewareBaseConfig)
32
+
33
+ # Specialized type for function-specific middleware
34
+ FunctionMiddlewareBaseConfig = MiddlewareBaseConfig
35
+ FunctionMiddlewareBaseConfigT = MiddlewareBaseConfigT
nat/data_models/runtime_enum.py ADDED
@@ -0,0 +1,26 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import enum
17
+
18
+
19
+ class RuntimeTypeEnum(str, enum.Enum):
20
+ """
21
+ Enum representing different runtime types.
22
+ """
23
+
24
+ RUN_OR_SERVE = "run_or_serve"
25
+ EVALUATE = "evaluate"
26
+ OTHER = "other"
nat/eval/dataset_handler/dataset_filter.py CHANGED
@@ -13,6 +13,8 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import fnmatch
17
+
16
18
  import pandas as pd
17
19
 
18
20
  from nat.data_models.dataset_handler import EvalFilterConfig
@@ -24,6 +26,7 @@ class DatasetFilter:
24
26
  - If a allowlist is provided, only keep rows matching the filter values.
25
27
  - If a denylist is provided, remove rows matching the filter values.
26
28
  - If the filter column does not exist in the DataFrame, the filtering is skipped for that column.
29
+ - Supports Unix shell-style wildcards (``*``, ``?``, ``[seq]``, ``[!seq]``) for string matching.
27
30
 
28
31
  This is a utility class that is dataset agnostic and can be used to filter any DataFrame based on the provided
29
32
  filter configuration.
@@ -33,6 +36,33 @@ class DatasetFilter:
33
36
 
34
37
  self.filter_config = filter_config
35
38
 
39
+ @staticmethod
40
+ def _match_wildcard_patterns(series: pd.Series, patterns: list[str | int | float]) -> pd.Series:
41
+ """
42
+ Match series values against wildcard patterns and exact values.
43
+
44
+ Args:
45
+ series (pd.Series): pandas Series to match against
46
+ patterns (list[str | int | float]): List of patterns/values
47
+
48
+ Returns:
49
+ pd.Series: Boolean Series indicating matches
50
+ """
51
+ # Convert series to string for pattern matching
52
+ str_series = series.astype(str)
53
+
54
+ # Initialize boolean mask
55
+ matches = pd.Series([False] * len(series), index=series.index)
56
+
57
+ # Check each pattern using fnmatch with list comprehension to avoid lambda capture
58
+ for pattern in patterns:
59
+ pattern_str = str(pattern)
60
+ pattern_matches = pd.Series([fnmatch.fnmatch(val, pattern_str) for val in str_series],
61
+ index=str_series.index)
62
+ matches |= pattern_matches
63
+
64
+ return matches
65
+
36
66
  def apply_filters(self, df) -> pd.DataFrame:
37
67
 
38
68
  filtered_df = df.copy()
@@ -41,12 +71,14 @@ class DatasetFilter:
41
71
  if self.filter_config.allowlist:
42
72
  for column, values in self.filter_config.allowlist.field.items():
43
73
  if column in filtered_df.columns:
44
- filtered_df = filtered_df[filtered_df[column].isin(values)]
74
+ matches = self._match_wildcard_patterns(filtered_df[column], values)
75
+ filtered_df = filtered_df[matches]
45
76
 
46
77
  # Apply denylist (remove specified rows)
47
78
  if self.filter_config.denylist:
48
79
  for column, values in self.filter_config.denylist.field.items():
49
80
  if column in filtered_df.columns:
50
- filtered_df = filtered_df[~filtered_df[column].isin(values)]
81
+ matches = self._match_wildcard_patterns(filtered_df[column], values)
82
+ filtered_df = filtered_df[~matches]
51
83
 
52
84
  return filtered_df
nat/eval/evaluate.py CHANGED
@@ -16,6 +16,7 @@
16
16
  import asyncio
17
17
  import logging
18
18
  import shutil
19
+ import warnings
19
20
  from pathlib import Path
20
21
  from typing import Any
21
22
  from uuid import uuid4
@@ -25,6 +26,7 @@ from tqdm import tqdm
25
26
 
26
27
  from nat.data_models.evaluate import EvalConfig
27
28
  from nat.data_models.evaluate import JobEvictionPolicy
29
+ from nat.data_models.runtime_enum import RuntimeTypeEnum
28
30
  from nat.eval.config import EvaluationRunConfig
29
31
  from nat.eval.config import EvaluationRunOutput
30
32
  from nat.eval.dataset_handler.dataset_handler import DatasetHandler
@@ -67,7 +69,13 @@ class EvaluationRun:
67
69
  # Create evaluation trace context
68
70
  try:
69
71
  from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
70
- self.eval_trace_context = WeaveEvalTraceContext()
72
+ with warnings.catch_warnings():
73
+ # Ignore deprecation warnings being triggered by weave. https://github.com/wandb/weave/issues/3666
74
+ warnings.filterwarnings("ignore",
75
+ category=DeprecationWarning,
76
+ message=r"`sentry_sdk\.Hub` is deprecated")
77
+
78
+ self.eval_trace_context = WeaveEvalTraceContext()
71
79
  except Exception:
72
80
  from nat.eval.utils.eval_trace_ctx import EvalTraceContext
73
81
  self.eval_trace_context = EvalTraceContext()
@@ -161,7 +169,7 @@ class EvaluationRun:
161
169
  if stop_event.is_set():
162
170
  return "", []
163
171
 
164
- async with session_manager.run(item.input_obj) as runner:
172
+ async with session_manager.run(item.input_obj, runtime_type=RuntimeTypeEnum.EVALUATE) as runner:
165
173
  if not session_manager.workflow.has_single_output:
166
174
  # raise an error if the workflow has multiple outputs
167
175
  raise NotImplementedError("Multiple outputs are not supported")
@@ -514,7 +522,7 @@ class EvaluationRun:
514
522
  # Run workflow and evaluate
515
523
  async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
516
524
  # Initialize Weave integration
517
- self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
525
+ self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config, job_id=job_id)
518
526
 
519
527
  with self.eval_trace_context.evaluation_context():
520
528
  # Run workflow
nat/eval/utils/weave_eval.py CHANGED
@@ -82,7 +82,7 @@ class WeaveEvaluationIntegration:
82
82
  """Get the full dataset for Weave."""
83
83
  return [item.full_dataset_entry for item in eval_input.eval_input_items]
84
84
 
85
- def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config: Any):
85
+ def initialize_logger(self, workflow_alias: str, eval_input: EvalInput, config: Any, job_id: str | None = None):
86
86
  """Initialize the Weave evaluation logger."""
87
87
  if not self.client and not self.initialize_client():
88
88
  # lazy init the client
@@ -92,10 +92,16 @@ class WeaveEvaluationIntegration:
92
92
  weave_dataset = self._get_weave_dataset(eval_input)
93
93
  config_dict = config.model_dump(mode="json")
94
94
  config_dict["name"] = workflow_alias
95
+
96
+ # Include job_id in eval_attributes if provided
97
+ eval_attributes = {}
98
+ if job_id:
99
+ eval_attributes["job_id"] = job_id
100
+
95
101
  self.eval_logger = self.evaluation_logger_cls(model=config_dict,
96
102
  dataset=weave_dataset,
97
103
  name=workflow_alias,
98
- eval_attributes={})
104
+ eval_attributes=eval_attributes)
99
105
  self.pred_loggers = {}
100
106
 
101
107
  # Capture the current evaluation call for context propagation
@@ -136,9 +142,17 @@ class WeaveEvaluationIntegration:
136
142
  coros = []
137
143
  for eval_output_item in eval_output.eval_output_items:
138
144
  if eval_output_item.id in self.pred_loggers:
145
+ # Structure the score as a dict and include reasoning if available
146
+ score_value = {
147
+ "score": eval_output_item.score,
148
+ }
149
+
150
+ if eval_output_item.reasoning is not None:
151
+ score_value["reasoning"] = eval_output_item.reasoning
152
+
139
153
  coros.append(self.pred_loggers[eval_output_item.id].alog_score(
140
154
  scorer=evaluator_name,
141
- score=eval_output_item.score,
155
+ score=score_value,
142
156
  ))
143
157
 
144
158
  # Execute all coroutines concurrently
nat/front_ends/fastapi/fastapi_front_end_config.py CHANGED
@@ -27,6 +27,8 @@ from pydantic import field_validator
27
27
  from nat.data_models.component_ref import ObjectStoreRef
28
28
  from nat.data_models.front_end import FrontEndBaseConfig
29
29
  from nat.data_models.step_adaptor import StepAdaptorConfig
30
+ from nat.eval.evaluator.evaluator_model import EvalInputItem
31
+ from nat.eval.evaluator.evaluator_model import EvalOutputItem
30
32
 
31
33
  logger = logging.getLogger(__name__)
32
34
 
@@ -133,6 +135,19 @@ class AsyncGenerationStatusResponse(BaseAsyncStatusResponse):
133
135
  description="Output of the generate request, this is only available if the job completed successfully.")
134
136
 
135
137
 
138
+ class EvaluateItemRequest(BaseModel):
139
+ """Request model for single-item evaluation endpoint."""
140
+ item: EvalInputItem = Field(description="Single evaluation input item to evaluate")
141
+ evaluator_name: str = Field(description="Name of the evaluator to use (must match config)")
142
+
143
+
144
+ class EvaluateItemResponse(BaseModel):
145
+ """Response model for single-item evaluation endpoint."""
146
+ success: bool = Field(description="Whether the evaluation completed successfully")
147
+ result: EvalOutputItem | None = Field(default=None, description="Evaluation result if successful")
148
+ error: str | None = Field(default=None, description="Error message if evaluation failed")
149
+
150
+
136
151
  class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
137
152
  """
138
153
  A FastAPI based front end that allows a NAT workflow to be served as a microservice.
@@ -211,6 +226,13 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
211
226
  "Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
212
227
  "This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
213
228
  ge=1)
229
+ dask_workers: typing.Literal["threads", "processes"] = Field(
230
+ default="processes",
231
+ description=(
232
+ "Type of Dask workers to use. Options are 'threads' for Threaded Dask workers or 'processes' for "
233
+ "Process based Dask workers. This parameter is only used when scheduler_address is `None` and a local Dask "
234
+ "cluster is created."),
235
+ )
214
236
  dask_log_level: str = Field(
215
237
  default="WARNING",
216
238
  description="Logging level for Dask.",
@@ -232,6 +254,13 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
232
254
  description="Evaluates the performance and accuracy of the workflow on a dataset",
233
255
  )
234
256
 
257
+ evaluate_item: typing.Annotated[EndpointBase,
258
+ Field(description="Endpoint for evaluating a single item.")] = EndpointBase(
259
+ method="POST",
260
+ path="/evaluate/item",
261
+ description="Evaluate a single item with a specified evaluator",
262
+ )
263
+
235
264
  oauth2_callback_path: str | None = Field(
236
265
  default="/auth/redirect",
237
266
  description="OAuth2.0 authentication callback endpoint. If None, no OAuth2 callback endpoint is created.")
nat/front_ends/fastapi/fastapi_front_end_plugin.py CHANGED
@@ -120,18 +120,24 @@ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]
120
120
 
121
121
  from dask.distributed import LocalCluster
122
122
 
123
- self._cluster = LocalCluster(processes=True,
123
+ use_threads = self.front_end_config.dask_workers == 'threads'
124
+
125
+ # set n_workers to max_running_async_jobs + 1 to allow for one worker to handle the cleanup task
126
+ self._cluster = LocalCluster(processes=not use_threads,
124
127
  silence_logs=dask_log_level,
125
- n_workers=self.front_end_config.max_running_async_jobs,
126
- threads_per_worker=1)
128
+ protocol="tcp",
129
+ n_workers=self.front_end_config.max_running_async_jobs + 1)
127
130
 
128
131
  self._scheduler_address = self._cluster.scheduler.address
129
132
 
130
- with self.blocking_client(self._scheduler_address) as client:
131
- # Client.run submits a function to be run on each worker
132
- client.run(self._setup_worker)
133
+ if not use_threads and sys.platform != "win32":
134
+ with self.blocking_client(self._scheduler_address) as client:
135
+ # Client.run submits a function to be run on each worker
136
+ client.run(self._setup_worker)
133
137
 
134
- logger.info("Created local Dask cluster with scheduler at %s", self._scheduler_address)
138
+ logger.info("Created local Dask cluster with scheduler at %s using %s workers",
139
+ self._scheduler_address,
140
+ self.front_end_config.dask_workers)
135
141
 
136
142
  except ImportError:
137
143
  logger.warning("Dask is not installed, async execution and evaluation will not be available.")