nvidia-nat 1.4.0a20251112__py3-none-any.whl → 1.4.0a20251120__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. nat/builder/builder.py +52 -0
  2. nat/builder/component_utils.py +7 -1
  3. nat/builder/context.py +17 -0
  4. nat/builder/framework_enum.py +1 -0
  5. nat/builder/function.py +74 -3
  6. nat/builder/workflow.py +4 -2
  7. nat/builder/workflow_builder.py +129 -0
  8. nat/cli/register_workflow.py +50 -0
  9. nat/cli/type_registry.py +68 -0
  10. nat/data_models/component.py +2 -0
  11. nat/data_models/component_ref.py +11 -0
  12. nat/data_models/config.py +16 -0
  13. nat/data_models/function.py +14 -1
  14. nat/data_models/middleware.py +35 -0
  15. nat/data_models/runtime_enum.py +26 -0
  16. nat/eval/evaluate.py +10 -2
  17. nat/front_ends/fastapi/fastapi_front_end_config.py +22 -0
  18. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +124 -0
  19. nat/front_ends/mcp/mcp_front_end_plugin.py +4 -0
  20. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +26 -0
  21. nat/middleware/__init__.py +35 -0
  22. nat/middleware/cache_middleware.py +256 -0
  23. nat/middleware/function_middleware.py +186 -0
  24. nat/middleware/middleware.py +184 -0
  25. nat/middleware/register.py +35 -0
  26. nat/profiler/decorators/framework_wrapper.py +16 -0
  27. nat/retriever/milvus/register.py +11 -3
  28. nat/retriever/milvus/retriever.py +102 -40
  29. nat/runtime/runner.py +12 -1
  30. nat/runtime/session.py +10 -3
  31. nat/tool/code_execution/code_sandbox.py +1 -1
  32. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/METADATA +9 -3
  33. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/RECORD +38 -31
  34. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/entry_points.txt +1 -0
  35. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/WHEEL +0 -0
  36. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  37. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/licenses/LICENSE.md +0 -0
  38. {nvidia_nat-1.4.0a20251112.dist-info → nvidia_nat-1.4.0a20251120.dist-info}/top_level.txt +0 -0
nat/cli/type_registry.py CHANGED
@@ -64,6 +64,8 @@ from nat.data_models.logging import LoggingBaseConfig
64
64
  from nat.data_models.logging import LoggingMethodConfigT
65
65
  from nat.data_models.memory import MemoryBaseConfig
66
66
  from nat.data_models.memory import MemoryBaseConfigT
67
+ from nat.data_models.middleware import MiddlewareBaseConfig
68
+ from nat.data_models.middleware import MiddlewareBaseConfigT
67
69
  from nat.data_models.object_store import ObjectStoreBaseConfig
68
70
  from nat.data_models.object_store import ObjectStoreBaseConfigT
69
71
  from nat.data_models.registry_handler import RegistryHandlerBaseConfig
@@ -76,6 +78,7 @@ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
76
78
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfigT
77
79
  from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
78
80
  from nat.memory.interfaces import MemoryEditor
81
+ from nat.middleware.middleware import Middleware
79
82
  from nat.object_store.interfaces import ObjectStore
80
83
  from nat.observability.exporter.base_exporter import BaseExporter
81
84
  from nat.registry_handlers.registry_handler_base import AbstractRegistryHandler
@@ -89,6 +92,7 @@ EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIte
89
92
  FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
90
93
  FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
91
94
  FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
95
+ MiddlewareBuildCallableT = Callable[[MiddlewareBaseConfigT, Builder], AsyncIterator[Middleware]]
92
96
  TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
93
97
  LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
94
98
  LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
@@ -111,6 +115,7 @@ FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncC
111
115
  FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
112
116
  AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
113
117
  FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
118
+ MiddlewareRegisteredCallableT = Callable[[MiddlewareBaseConfigT, Builder], AbstractAsyncContextManager[Middleware]]
114
119
  TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
115
120
  LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
116
121
  LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
@@ -179,6 +184,8 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
179
184
  and a description.
180
185
  """
181
186
 
187
+ model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
188
+
182
189
  build_fn: FunctionRegisteredCallableT = Field(repr=False)
183
190
  framework_wrappers: list[str] = Field(default_factory=list)
184
191
 
@@ -193,6 +200,15 @@ class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
193
200
  framework_wrappers: list[str] = Field(default_factory=list)
194
201
 
195
202
 
203
class RegisteredMiddlewareInfo(RegisteredInfo[MiddlewareBaseConfig]):
    """
    Represents registered middleware. Middleware provides middleware-style wrapping of
    calls with preprocessing and postprocessing logic.
    """

    # Callable registered for this middleware; it receives the middleware config and a
    # Builder (see MiddlewareRegisteredCallableT). Excluded from the model repr.
    build_fn: MiddlewareRegisteredCallableT = Field(repr=False)
210
+
211
+
196
212
  class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
197
213
  """
198
214
  Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
@@ -331,6 +347,9 @@ class TypeRegistry:
331
347
  # Function Groups
332
348
  self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
333
349
 
350
+ # Middleware
351
+ self._registered_middleware: dict[type[MiddlewareBaseConfig], RegisteredMiddlewareInfo] = {}
352
+
334
353
  # LLMs
335
354
  self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
336
355
  self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
@@ -540,6 +559,49 @@ class TypeRegistry:
540
559
  """
541
560
  return list(self._registered_function_groups.values())
542
561
 
562
def register_middleware(self, registration: RegisteredMiddlewareInfo):
    """Store a middleware registration under its config type.

    Each config type may be registered only once. After a successful insert the
    registry's change hook (`_registration_changed`) is invoked.

    Args:
        registration: The middleware registration information.

    Raises:
        ValueError: If middleware with the same config type is already registered.
    """
    config_type = registration.config_type
    known_middleware = self._registered_middleware

    # Refuse duplicates rather than silently replacing an earlier registration.
    if config_type in known_middleware:
        raise ValueError(f"Middleware with the same config type `{config_type}` has already been "
                         "registered.")

    known_middleware[config_type] = registration
    self._registration_changed()
578
+
579
def get_middleware(self, config_type: type[MiddlewareBaseConfig]) -> RegisteredMiddlewareInfo:
    """Look up the registration stored for a middleware config type.

    Args:
        config_type: The middleware configuration type.

    Returns:
        RegisteredMiddlewareInfo: The registration stored for `config_type`.

    Raises:
        KeyError: If no middleware is registered for the given config type. The
            message lists the config types that are registered, and the original
            lookup error is chained as the cause.
    """
    known_middleware = self._registered_middleware

    try:
        registration = known_middleware[config_type]
    except KeyError as err:
        registered_configs = set(known_middleware.keys())
        raise KeyError(f"Could not find registered middleware for config `{config_type}`. "
                       f"Registered configs: {registered_configs}") from err

    return registration
596
+
597
def get_registered_middleware(self) -> list[RegisteredInfo[MiddlewareBaseConfig]]:
    """Return every middleware registration currently held by the registry.

    Returns:
        list[RegisteredInfo[MiddlewareBaseConfig]]: A new list with all registered
        middleware; mutating it does not touch the registry's own mapping.
    """
    registrations = self._registered_middleware.values()
    return [*registrations]
604
+
543
605
  def register_llm_provider(self, info: RegisteredLLMProviderInfo):
544
606
 
545
607
  if (info.config_type in self._registered_llm_provider_infos):
@@ -912,6 +974,9 @@ class TypeRegistry:
912
974
  if component_type == ComponentEnum.TTC_STRATEGY:
913
975
  return self._registered_ttc_strategies
914
976
 
977
+ if component_type == ComponentEnum.MIDDLEWARE:
978
+ return self._registered_middleware
979
+
915
980
  raise ValueError(f"Supplied an unsupported component type {component_type}")
916
981
 
917
982
  def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
@@ -1038,6 +1103,9 @@ class TypeRegistry:
1038
1103
  if issubclass(cls, TTCStrategyBaseConfig):
1039
1104
  return self._do_compute_annotation(cls, self.get_registered_ttc_strategies())
1040
1105
 
1106
+ if issubclass(cls, MiddlewareBaseConfig):
1107
+ return self._do_compute_annotation(cls, self.get_registered_middleware())
1108
+
1041
1109
  raise ValueError(f"Supplied an unsupported component type {cls}")
1042
1110
 
1043
1111
 
@@ -28,6 +28,7 @@ class ComponentEnum(StrEnum):
28
28
  FRONT_END = "front_end"
29
29
  FUNCTION = "function"
30
30
  FUNCTION_GROUP = "function_group"
31
+ MIDDLEWARE = "middleware"
31
32
  TTC_STRATEGY = "ttc_strategy"
32
33
  LLM_CLIENT = "llm_client"
33
34
  LLM_PROVIDER = "llm_provider"
@@ -49,6 +50,7 @@ class ComponentGroup(StrEnum):
49
50
  EMBEDDERS = "embedders"
50
51
  FUNCTIONS = "functions"
51
52
  FUNCTION_GROUPS = "function_groups"
53
+ MIDDLEWARE = "middleware"
52
54
  TTC_STRATEGIES = "ttc_strategies"
53
55
  LLMS = "llms"
54
56
  MEMORY = "memory"
@@ -177,3 +177,14 @@ class TTCStrategyRef(ComponentRef):
177
177
  @override
178
178
  def component_group(self):
179
179
  return ComponentGroup.TTC_STRATEGIES
180
+
181
+
182
class MiddlewareRef(ComponentRef):
    """
    A reference to middleware in a NAT configuration object.
    """

    @property
    @override
    def component_group(self):
        # Middleware references always report the MIDDLEWARE component group.
        return ComponentGroup.MIDDLEWARE
nat/data_models/config.py CHANGED
@@ -43,6 +43,7 @@ from .common import TypedBaseModel
43
43
  from .embedder import EmbedderBaseConfig
44
44
  from .llm import LLMBaseConfig
45
45
  from .memory import MemoryBaseConfig
46
+ from .middleware import FunctionMiddlewareBaseConfig
46
47
  from .object_store import ObjectStoreBaseConfig
47
48
  from .retriever import RetrieverBaseConfig
48
49
 
@@ -86,6 +87,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
86
87
  registered_keys = GlobalTypeRegistry.get().get_registered_front_ends()
87
88
  elif (info.field_name == "ttc_strategies"):
88
89
  registered_keys = GlobalTypeRegistry.get().get_registered_ttc_strategies()
90
+ elif (info.field_name == "middleware"):
91
+ registered_keys = GlobalTypeRegistry.get().get_registered_middleware()
89
92
 
90
93
  else:
91
94
  assert False, f"Unknown field name {info.field_name} in validator"
@@ -253,6 +256,9 @@ class Config(HashableBaseModel):
253
256
  # Function Groups Configuration
254
257
  function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
255
258
 
259
+ # Middleware Configuration
260
+ middleware: dict[str, FunctionMiddlewareBaseConfig] = Field(default_factory=dict)
261
+
256
262
  # LLMs Configuration
257
263
  llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
258
264
 
@@ -303,6 +309,7 @@ class Config(HashableBaseModel):
303
309
 
304
310
  @field_validator("functions",
305
311
  "function_groups",
312
+ "middleware",
306
313
  "llms",
307
314
  "embedders",
308
315
  "memory",
@@ -348,6 +355,10 @@ class Config(HashableBaseModel):
348
355
  typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
349
356
  Discriminator(TypedBaseModel.discriminator)]]
350
357
 
358
+ MiddlewareAnnotation = dict[str,
359
+ typing.Annotated[type_registry.compute_annotation(FunctionMiddlewareBaseConfig),
360
+ Discriminator(TypedBaseModel.discriminator)]]
361
+
351
362
  MemoryAnnotation = dict[str,
352
363
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
353
364
  Discriminator(TypedBaseModel.discriminator)]]
@@ -393,6 +404,11 @@ class Config(HashableBaseModel):
393
404
  function_groups_field.annotation = FunctionGroupsAnnotation
394
405
  should_rebuild = True
395
406
 
407
+ middleware_field = cls.model_fields.get("middleware")
408
+ if (middleware_field is not None and middleware_field.annotation != MiddlewareAnnotation):
409
+ middleware_field.annotation = MiddlewareAnnotation
410
+ should_rebuild = True
411
+
396
412
  memory_field = cls.model_fields.get("memory")
397
413
  if memory_field is not None and memory_field.annotation != MemoryAnnotation:
398
414
  memory_field.annotation = MemoryAnnotation
@@ -24,7 +24,16 @@ from .common import TypedBaseModel
24
24
 
25
25
 
26
26
  class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
27
- pass
27
+ """Base configuration for functions.
28
+
29
+ Attributes:
30
+ middleware: List of function middleware names to apply to this function.
31
+ These must match names defined in the `middleware` section of the YAML configuration.
32
+ """
33
+ middleware: list[str] = Field(
34
+ default_factory=list,
35
+ description="List of function middleware names to apply to this function in order",
36
+ )
28
37
 
29
38
 
30
39
  class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
@@ -40,6 +49,10 @@ class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
40
49
  default_factory=list,
41
50
  description="The list of function names which should be excluded from default access to the group",
42
51
  )
52
+ middleware: list[str] = Field(
53
+ default_factory=list,
54
+ description="List of function middleware names to apply to all functions in this group",
55
+ )
43
56
 
44
57
  @field_validator("include", "exclude")
45
58
  @classmethod
@@ -0,0 +1,35 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+
18
+ from .common import BaseModelRegistryTag
19
+ from .common import TypedBaseModel
20
+
21
+
22
# Root of the middleware config hierarchy. It declares no fields of its own; the class
# docstring is the only body (it is kept verbatim because it is visible at runtime).
class MiddlewareBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """The base level config object for middleware.

    Middleware provides middleware-style wrapping of calls with
    preprocessing and postprocessing logic.
    """
29
+
30
+
31
+ MiddlewareBaseConfigT = typing.TypeVar("MiddlewareBaseConfigT", bound=MiddlewareBaseConfig)
32
+
33
+ # Specialized type for function-specific middleware
34
+ FunctionMiddlewareBaseConfig = MiddlewareBaseConfig
35
+ FunctionMiddlewareBaseConfigT = MiddlewareBaseConfigT
@@ -0,0 +1,26 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import enum
17
+
18
+
19
class RuntimeTypeEnum(str, enum.Enum):
    """
    Identifies the kind of runtime a workflow invocation belongs to.

    Members are also `str` instances, so each one compares equal to its string value.
    """

    # NOTE(review): presumably the regular run/serve entry points -- confirm against callers.
    RUN_OR_SERVE = "run_or_serve"
    # Used for invocations made as part of an evaluation run.
    EVALUATE = "evaluate"
    # Any runtime that is neither of the above.
    OTHER = "other"
nat/eval/evaluate.py CHANGED
@@ -16,6 +16,7 @@
16
16
  import asyncio
17
17
  import logging
18
18
  import shutil
19
+ import warnings
19
20
  from pathlib import Path
20
21
  from typing import Any
21
22
  from uuid import uuid4
@@ -25,6 +26,7 @@ from tqdm import tqdm
25
26
 
26
27
  from nat.data_models.evaluate import EvalConfig
27
28
  from nat.data_models.evaluate import JobEvictionPolicy
29
+ from nat.data_models.runtime_enum import RuntimeTypeEnum
28
30
  from nat.eval.config import EvaluationRunConfig
29
31
  from nat.eval.config import EvaluationRunOutput
30
32
  from nat.eval.dataset_handler.dataset_handler import DatasetHandler
@@ -67,7 +69,13 @@ class EvaluationRun:
67
69
  # Create evaluation trace context
68
70
  try:
69
71
  from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
70
- self.eval_trace_context = WeaveEvalTraceContext()
72
+ with warnings.catch_warnings():
73
+ # Ignore deprecation warnings being triggered by weave. https://github.com/wandb/weave/issues/3666
74
+ warnings.filterwarnings("ignore",
75
+ category=DeprecationWarning,
76
+ message=r"`sentry_sdk\.Hub` is deprecated")
77
+
78
+ self.eval_trace_context = WeaveEvalTraceContext()
71
79
  except Exception:
72
80
  from nat.eval.utils.eval_trace_ctx import EvalTraceContext
73
81
  self.eval_trace_context = EvalTraceContext()
@@ -161,7 +169,7 @@ class EvaluationRun:
161
169
  if stop_event.is_set():
162
170
  return "", []
163
171
 
164
- async with session_manager.run(item.input_obj) as runner:
172
+ async with session_manager.run(item.input_obj, runtime_type=RuntimeTypeEnum.EVALUATE) as runner:
165
173
  if not session_manager.workflow.has_single_output:
166
174
  # raise an error if the workflow has multiple outputs
167
175
  raise NotImplementedError("Multiple outputs are not supported")
@@ -27,6 +27,8 @@ from pydantic import field_validator
27
27
  from nat.data_models.component_ref import ObjectStoreRef
28
28
  from nat.data_models.front_end import FrontEndBaseConfig
29
29
  from nat.data_models.step_adaptor import StepAdaptorConfig
30
+ from nat.eval.evaluator.evaluator_model import EvalInputItem
31
+ from nat.eval.evaluator.evaluator_model import EvalOutputItem
30
32
 
31
33
  logger = logging.getLogger(__name__)
32
34
 
@@ -133,6 +135,19 @@ class AsyncGenerationStatusResponse(BaseAsyncStatusResponse):
133
135
  description="Output of the generate request, this is only available if the job completed successfully.")
134
136
 
135
137
 
138
class EvaluateItemRequest(BaseModel):
    """Request model for single-item evaluation endpoint."""
    # The one input item to score; it is wrapped into a single-element EvalInput by the endpoint.
    item: EvalInputItem = Field(description="Single evaluation input item to evaluate")
    # Lookup key for the evaluator; the endpoint answers 404 when the name is unknown.
    evaluator_name: str = Field(description="Name of the evaluator to use (must match config)")
142
+
143
+
144
class EvaluateItemResponse(BaseModel):
    """Response model for single-item evaluation endpoint."""
    # True only when the evaluator ran and produced at least one output item.
    success: bool = Field(description="Whether the evaluation completed successfully")
    # First output item returned by the evaluator; None on failure.
    result: EvalOutputItem | None = Field(default=None, description="Evaluation result if successful")
    # Human-readable failure reason; None on success.
    error: str | None = Field(default=None, description="Error message if evaluation failed")
149
+
150
+
136
151
  class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
137
152
  """
138
153
  A FastAPI based front end that allows a NAT workflow to be served as a microservice.
@@ -239,6 +254,13 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
239
254
  description="Evaluates the performance and accuracy of the workflow on a dataset",
240
255
  )
241
256
 
257
+ evaluate_item: typing.Annotated[EndpointBase,
258
+ Field(description="Endpoint for evaluating a single item.")] = EndpointBase(
259
+ method="POST",
260
+ path="/evaluate/item",
261
+ description="Evaluate a single item with a specified evaluator",
262
+ )
263
+
242
264
  oauth2_callback_path: str | None = Field(
243
265
  default="/auth/redirect",
244
266
  description="OAuth2.0 authentication callback endpoint. If None, no OAuth2 callback endpoint is created.")
@@ -39,6 +39,8 @@ from pydantic import BaseModel
39
39
  from pydantic import Field
40
40
  from starlette.websockets import WebSocket
41
41
 
42
+ from nat.builder.eval_builder import WorkflowEvalBuilder
43
+ from nat.builder.evaluator import EvaluatorInfo
42
44
  from nat.builder.function import Function
43
45
  from nat.builder.workflow_builder import WorkflowBuilder
44
46
  from nat.data_models.api_server import ChatRequest
@@ -51,11 +53,14 @@ from nat.data_models.object_store import NoSuchKeyError
51
53
  from nat.eval.config import EvaluationRunOutput
52
54
  from nat.eval.evaluate import EvaluationRun
53
55
  from nat.eval.evaluate import EvaluationRunConfig
56
+ from nat.eval.evaluator.evaluator_model import EvalInput
54
57
  from nat.front_ends.fastapi.auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
55
58
  from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState
56
59
  from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler
57
60
  from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerateResponse
58
61
  from nat.front_ends.fastapi.fastapi_front_end_config import AsyncGenerationStatusResponse
62
+ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemRequest
63
+ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateItemResponse
59
64
  from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
60
65
  from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
61
66
  from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
@@ -227,6 +232,54 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
227
232
  self._outstanding_flows: dict[str, FlowState] = {}
228
233
  self._outstanding_flows_lock = asyncio.Lock()
229
234
 
235
+ # Evaluator storage for single-item evaluation
236
+ self._evaluators: dict[str, EvaluatorInfo] = {}
237
+ self._eval_builder: WorkflowEvalBuilder | None = None
238
+
239
async def initialize_evaluators(self, config: Config):
    """Build the configured evaluators and keep them available for single-item evaluation.

    A `WorkflowEvalBuilder` is created, entered and populated from `config` (with
    `skip_workflow=True`). Every evaluator named under `config.eval.evaluators` is then
    fetched from the builder and stored in `self._evaluators`. On success the entered
    builder is kept in `self._eval_builder` so its context stays open until
    `cleanup_evaluators` exits it.

    Failures never propagate (startup must not fail because of evaluators): they are
    logged with their traceback and the worker is left with no evaluators. If the builder
    context had already been entered when the failure happened, it is exited right away
    so that a half-initialised builder does not stay open, and `self._eval_builder`
    is reset to `None` so the shutdown hook will not try to exit it again.

    Args:
        config: The full configuration. When `config.eval` is unset or declares no
            evaluators, nothing is built.
    """
    if not config.eval or not config.eval.evaluators:
        logger.info("No evaluators configured, skipping evaluator initialization")
        return

    eval_builder = None
    entered = False  # True once __aenter__ has completed, i.e. __aexit__ is owed

    try:
        # Build evaluators using WorkflowEvalBuilder (same pattern as nat eval)
        # Start with registry=None and let populate_builder set everything up
        eval_builder = WorkflowEvalBuilder(general_config=config.general,
                                           eval_general_config=config.eval.general,
                                           registry=None)

        # Enter the async context and keep it alive
        await eval_builder.__aenter__()
        entered = True

        # Populate builder with config (this sets up LLMs, functions, etc.)
        # Skip workflow build since we already have it from the main builder
        await eval_builder.populate_builder(config, skip_workflow=True)

        # Collect into a local mapping first so a failure part-way through never
        # leaves a partially filled self._evaluators behind.
        evaluators = {}
        for name in config.eval.evaluators:
            evaluators[name] = eval_builder.get_evaluator(name)
            logger.info(f"Initialized evaluator: {name}")

    except Exception as e:
        # Don't fail startup, just log the error (with traceback)
        logger.exception(f"Failed to initialize evaluators: {e}")
        self._evaluators = {}
        self._eval_builder = None

        # Only an entered context may be exited; hand it the triggering exception.
        if entered:
            try:
                await eval_builder.__aexit__(type(e), e, e.__traceback__)
            except Exception as exit_err:
                logger.error(f"Error cleaning up evaluator builder: {exit_err}")
        return

    # Publish the fully initialised state only after everything succeeded.
    self._eval_builder = eval_builder
    self._evaluators.update(evaluators)
    logger.info(f"Successfully initialized {len(self._evaluators)} evaluators")
270
+
271
async def cleanup_evaluators(self):
    """Exit the evaluator builder context and forget all stored evaluators.

    Does nothing when no builder is held. Errors raised while exiting the builder
    context are logged, not propagated. Whether or not the exit succeeded, the builder
    reference is dropped and the evaluator mapping is emptied.
    """
    builder = self._eval_builder
    if not builder:
        return

    try:
        await builder.__aexit__(None, None, None)
        logger.info("Evaluator builder context cleaned up")
    except Exception as e:
        logger.error(f"Error cleaning up evaluator builder: {e}")
    finally:
        # Reset state even if the exit failed so a later call is a no-op.
        self._eval_builder = None
        self._evaluators.clear()
282
+
230
283
  def get_step_adaptor(self) -> StepAdaptor:
231
284
 
232
285
  return StepAdaptor(self.front_end_config.step_adaptor)
@@ -236,12 +289,20 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
236
289
  # Do things like setting the base URL and global configuration options
237
290
  app.root_path = self.front_end_config.root_path
238
291
 
292
+ # Initialize evaluators for single-item evaluation
293
+ # TODO: we need config control over this as it's not always needed
294
+ await self.initialize_evaluators(self._config)
295
+
296
+ # Ensure evaluator resources are cleaned up when the app shuts down
297
+ app.add_event_handler("shutdown", self.cleanup_evaluators)
298
+
239
299
  await self.add_routes(app, builder)
240
300
 
241
301
  async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
242
302
 
243
303
  await self.add_default_route(app, SessionManager(await builder.build()))
244
304
  await self.add_evaluate_route(app, SessionManager(await builder.build()))
305
+ await self.add_evaluate_item_route(app, SessionManager(await builder.build()))
245
306
  await self.add_static_files_route(app, builder)
246
307
  await self.add_authorization_route(app)
247
308
  await self.add_mcp_client_tool_list_route(app, builder)
@@ -439,6 +500,69 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
439
500
  else:
440
501
  logger.warning("Dask is not available, evaluation endpoints will not be added.")
441
502
 
503
async def add_evaluate_item_route(self, app: FastAPI, session_manager: SessionManager):
    """Register the endpoint that scores one item with one named evaluator.

    The route is added only when `front_end_config.evaluate_item.path` is set; its
    path, HTTP method and description all come from that config entry.

    The handler runs inside `session_manager.session(...)`. An evaluator name that is
    not in `self._evaluators` is rejected with HTTP 404. Any exception raised while
    evaluating is logged and reported in the response body (`success=False`) rather
    than raised.

    Args:
        app: The FastAPI application the route is added to.
        session_manager: Supplies the session context each request runs in.
    """

    async def evaluate_single_item(request: EvaluateItemRequest, http_request: Request) -> EvaluateItemResponse:
        """Handle single-item evaluation requests."""

        evaluator_name = request.evaluator_name

        async with session_manager.session(http_connection=http_request):

            # Unknown evaluator: answer 404 and list the names that do exist
            if evaluator_name not in self._evaluators:
                raise HTTPException(status_code=404,
                                    detail=f"Evaluator '{evaluator_name}' not found. "
                                    f"Available evaluators: {list(self._evaluators.keys())}")

            try:
                evaluator = self._evaluators[evaluator_name]

                # The evaluator expects a batch, so wrap the single item in one
                single_item_input = EvalInput(eval_input_items=[request.item])
                result = await evaluator.evaluate_fn(single_item_input)

                output_items = result.eval_output_items
                if not output_items:
                    return EvaluateItemResponse(success=False, result=None, error="Evaluator returned no results")

                # Only one item was submitted, so only the first output is returned
                return EvaluateItemResponse(success=True, result=output_items[0], error=None)

            except Exception as e:
                logger.exception(f"Error evaluating item with {evaluator_name}")
                return EvaluateItemResponse(success=False, result=None, error=f"Evaluation failed: {str(e)}")

    endpoint = self.front_end_config.evaluate_item

    # An empty/None path disables the endpoint
    if not endpoint.path:
        return

    # Extra response documentation attached to the route
    documented_responses = {
        404: {
            "description": "Evaluator not found",
            "content": {
                "application/json": {
                    "example": {
                        "detail": "Evaluator 'unknown' not found"
                    }
                }
            }
        },
        500: {
            "description": "Internal Server Error",
            "content": {
                "application/json": {
                    "example": {
                        "detail": "Internal server error occurred"
                    }
                }
            }
        }
    }

    app.add_api_route(path=endpoint.path,
                      endpoint=evaluate_single_item,
                      methods=[endpoint.method],
                      response_model=EvaluateItemResponse,
                      description=endpoint.description,
                      responses=documented_responses)
    logger.info(f"Added evaluate_item route at {endpoint.path}")
565
+
442
566
  async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
443
567
 
444
568
  if not self.front_end_config.object_store:
@@ -140,6 +140,10 @@ class MCPFrontEndPlugin(FrontEndBase[MCPFrontEndConfig]):
140
140
  # Mount the MCP server's ASGI app at the configured base_path
141
141
  app.mount(self.front_end_config.base_path, mcp.streamable_http_app())
142
142
 
143
+ # Allow plugins to add routes to the wrapper app (e.g., OAuth discovery endpoints)
144
+ worker = self._get_worker_instance()
145
+ await worker.add_root_level_routes(app, mcp)
146
+
143
147
  # Configure and start uvicorn server
144
148
  config = uvicorn.Config(
145
149
  app,
@@ -17,12 +17,16 @@ import logging
17
17
  from abc import ABC
18
18
  from abc import abstractmethod
19
19
  from collections.abc import Mapping
20
+ from typing import TYPE_CHECKING
20
21
  from typing import Any
21
22
 
22
23
  from mcp.server.fastmcp import FastMCP
23
24
  from starlette.exceptions import HTTPException
24
25
  from starlette.requests import Request
25
26
 
27
+ if TYPE_CHECKING:
28
+ from fastapi import FastAPI
29
+
26
30
  from nat.builder.function import Function
27
31
  from nat.builder.function_base import FunctionBase
28
32
  from nat.builder.workflow import Workflow
@@ -192,6 +196,28 @@ class MCPFrontEndPluginWorkerBase(ABC):
192
196
 
193
197
  return functions
194
198
 
199
async def add_root_level_routes(self, wrapper_app: "FastAPI", mcp: FastMCP) -> None:
    """Optional hook for adding routes to the wrapper FastAPI app.

    When a base_path is configured, the MCP server is mounted inside a wrapper
    FastAPI app; this hook receives that wrapper so a subclass can register routes
    at the root level, outside the mounted MCP server path.

    Typical uses:
    - OAuth discovery endpoints (e.g., /.well-known/oauth-protected-resource)
    - Root-level health checks
    - Static file serving
    - Custom authentication/authorization endpoints

    The base implementation adds nothing, so overriding it is entirely optional.

    Args:
        wrapper_app: The FastAPI wrapper application that mounts the MCP server
        mcp: The FastMCP server instance (already mounted at base_path)
    """
    # Default: no additional root-level routes
    return None
220
+
195
221
  def _setup_debug_endpoints(self, mcp: FastMCP, functions: Mapping[str, FunctionBase]) -> None:
196
222
  """Set up HTTP debug endpoints for introspecting tools and schemas.
197
223
 
@@ -0,0 +1,35 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Middleware implementations for NeMo Agent Toolkit."""
16
+
17
+ from nat.middleware.cache_middleware import CacheMiddleware
18
+ from nat.middleware.function_middleware import FunctionMiddleware
19
+ from nat.middleware.function_middleware import FunctionMiddlewareChain
20
+ from nat.middleware.function_middleware import validate_middleware
21
+ from nat.middleware.middleware import CallNext
22
+ from nat.middleware.middleware import CallNextStream
23
+ from nat.middleware.middleware import FunctionMiddlewareContext
24
+ from nat.middleware.middleware import Middleware
25
+
26
+ __all__ = [
27
+ "CacheMiddleware",
28
+ "CallNext",
29
+ "CallNextStream",
30
+ "FunctionMiddlewareContext",
31
+ "Middleware",
32
+ "FunctionMiddleware",
33
+ "FunctionMiddlewareChain",
34
+ "validate_middleware",
35
+ ]