nemo-evaluator 0.1.41__py3-none-any.whl → 0.1.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nemo_evaluator/adapters/adapter_config.py +147 -51
- nemo_evaluator/adapters/interceptors/endpoint_interceptor.py +37 -8
- nemo_evaluator/adapters/interceptors/reasoning_interceptor.py +19 -4
- nemo_evaluator/adapters/pipeline.py +310 -0
- nemo_evaluator/adapters/server.py +50 -220
- nemo_evaluator/adapters/types.py +1 -0
- nemo_evaluator/api/api_dataclasses.py +95 -6
- nemo_evaluator/client/__init__.py +28 -0
- nemo_evaluator/client/adapter_transport.py +356 -0
- nemo_evaluator/client/client.py +371 -0
- nemo_evaluator/core/entrypoint.py +9 -5
- nemo_evaluator/core/evaluate.py +57 -1
- nemo_evaluator/core/input.py +90 -6
- nemo_evaluator/core/resources.py +1 -2
- nemo_evaluator/core/utils.py +123 -0
- nemo_evaluator/logging/__init__.py +2 -0
- nemo_evaluator/logging/context.py +15 -0
- nemo_evaluator/package_info.py +1 -1
- nemo_evaluator/sandbox/__init__.py +33 -0
- nemo_evaluator/sandbox/base.py +115 -0
- nemo_evaluator/sandbox/ecs_fargate.py +1332 -0
- {nemo_evaluator-0.1.41.dist-info → nemo_evaluator-0.1.71.dist-info}/METADATA +1 -1
- {nemo_evaluator-0.1.41.dist-info → nemo_evaluator-0.1.71.dist-info}/RECORD +27 -20
- {nemo_evaluator-0.1.41.dist-info → nemo_evaluator-0.1.71.dist-info}/WHEEL +1 -1
- {nemo_evaluator-0.1.41.dist-info → nemo_evaluator-0.1.71.dist-info}/entry_points.txt +0 -0
- {nemo_evaluator-0.1.41.dist-info → nemo_evaluator-0.1.71.dist-info}/licenses/LICENSE +0 -0
- {nemo_evaluator-0.1.41.dist-info → nemo_evaluator-0.1.71.dist-info}/top_level.txt +0 -0
|
@@ -13,9 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
|
|
16
17
|
from typing import Any
|
|
17
18
|
|
|
18
|
-
from pydantic import BaseModel, Field
|
|
19
|
+
from pydantic import BaseModel, Field, ValidationError
|
|
20
|
+
|
|
21
|
+
from nemo_evaluator.logging import get_logger
|
|
19
22
|
|
|
20
23
|
|
|
21
24
|
class DiscoveryConfig(BaseModel):
|
|
@@ -58,9 +61,122 @@ class PostEvalHookConfig(BaseModel):
|
|
|
58
61
|
use_enum_values = True
|
|
59
62
|
|
|
60
63
|
|
|
64
|
+
class LegacyAdapterConfig(BaseModel):
|
|
65
|
+
"""Legacy adapter configuration parameters (pre-interceptor format).
|
|
66
|
+
|
|
67
|
+
This model validates legacy configuration dictionaries to catch typos
|
|
68
|
+
and invalid parameters early, before conversion to the new interceptor format.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
class Config:
|
|
72
|
+
extra = "forbid" # Reject any extra fields not defined here
|
|
73
|
+
|
|
74
|
+
# Boolean flags for optional features
|
|
75
|
+
use_caching: bool = Field(default=True, description="Enable caching interceptor")
|
|
76
|
+
save_responses: bool = Field(default=False, description="Save responses to disk")
|
|
77
|
+
save_requests: bool = Field(default=False, description="Save requests to disk")
|
|
78
|
+
use_system_prompt: bool = Field(
|
|
79
|
+
default=False, description="Enable system prompt modification"
|
|
80
|
+
)
|
|
81
|
+
use_omni_info: bool = Field(
|
|
82
|
+
default=False, description="Enable omni info processing"
|
|
83
|
+
)
|
|
84
|
+
use_request_logging: bool = Field(
|
|
85
|
+
default=False, description="Enable request logging"
|
|
86
|
+
)
|
|
87
|
+
use_nvcf: bool = Field(default=False, description="Enable NVCF integration")
|
|
88
|
+
use_response_logging: bool = Field(
|
|
89
|
+
default=False, description="Enable response logging"
|
|
90
|
+
)
|
|
91
|
+
use_reasoning: bool = Field(
|
|
92
|
+
default=False, description="Enable reasoning token processing"
|
|
93
|
+
)
|
|
94
|
+
process_reasoning_traces: bool = Field(
|
|
95
|
+
default=False, description="Process reasoning traces"
|
|
96
|
+
)
|
|
97
|
+
use_progress_tracking: bool = Field(
|
|
98
|
+
default=False, description="Enable progress tracking"
|
|
99
|
+
)
|
|
100
|
+
use_raise_client_errors: bool = Field(
|
|
101
|
+
default=False, description="Raise client errors"
|
|
102
|
+
)
|
|
103
|
+
include_json: bool = Field(default=True, description="Include JSON in responses")
|
|
104
|
+
|
|
105
|
+
# Model fields that are also part of AdapterConfig
|
|
106
|
+
mode: str = Field(
|
|
107
|
+
default="server", description="Adapter mode: 'server' or 'client'"
|
|
108
|
+
)
|
|
109
|
+
generate_html_report: bool = Field(default=True, description="Generate HTML report")
|
|
110
|
+
html_report_size: int | None = Field(default=5, description="HTML report size")
|
|
111
|
+
tracking_requests_stats: bool = Field(
|
|
112
|
+
default=True, description="Track request statistics"
|
|
113
|
+
)
|
|
114
|
+
log_failed_requests: bool = Field(default=False, description="Log failed requests")
|
|
115
|
+
endpoint_type: str = Field(default="chat", description="Endpoint type")
|
|
116
|
+
caching_dir: str | None = Field(default=None, description="Caching directory")
|
|
117
|
+
|
|
118
|
+
# Optional string/dict configuration parameters
|
|
119
|
+
custom_system_prompt: str | None = Field(
|
|
120
|
+
default=None, description="Custom system prompt"
|
|
121
|
+
)
|
|
122
|
+
output_dir: str | None = Field(default=None, description="Output directory")
|
|
123
|
+
params_to_add: dict[str, Any] | None = Field(
|
|
124
|
+
default=None, description="Parameters to add"
|
|
125
|
+
)
|
|
126
|
+
params_to_remove: list[str] | None = Field(
|
|
127
|
+
default=None, description="Parameters to remove"
|
|
128
|
+
)
|
|
129
|
+
params_to_rename: dict[str, str] | None = Field(
|
|
130
|
+
default=None, description="Parameters to rename"
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
# Optional integer limits
|
|
134
|
+
max_logged_requests: int | None = Field(
|
|
135
|
+
default=None, description="Max logged requests"
|
|
136
|
+
)
|
|
137
|
+
max_logged_responses: int | None = Field(
|
|
138
|
+
default=None, description="Max logged responses"
|
|
139
|
+
)
|
|
140
|
+
max_saved_requests: int | None = Field(
|
|
141
|
+
default=None, description="Max saved requests"
|
|
142
|
+
)
|
|
143
|
+
max_saved_responses: int | None = Field(
|
|
144
|
+
default=None, description="Max saved responses"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
# Reasoning-specific parameters
|
|
148
|
+
start_reasoning_token: str | None = Field(
|
|
149
|
+
default=None, description="Start reasoning token"
|
|
150
|
+
)
|
|
151
|
+
include_if_reasoning_not_finished: bool | None = Field(
|
|
152
|
+
default=None, description="Include unfinished reasoning"
|
|
153
|
+
)
|
|
154
|
+
track_reasoning: bool | None = Field(default=None, description="Track reasoning")
|
|
155
|
+
end_reasoning_token: str = Field(
|
|
156
|
+
default="</think>", description="End reasoning token"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
# Progress tracking parameters
|
|
160
|
+
progress_tracking_url: str | None = Field(
|
|
161
|
+
default=None, description="Progress tracking URL"
|
|
162
|
+
)
|
|
163
|
+
progress_tracking_interval: int = Field(
|
|
164
|
+
default=1, description="Progress tracking interval"
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
# Logging parameters
|
|
168
|
+
logging_aggregated_stats_interval: int = Field(
|
|
169
|
+
default=100, description="Logging aggregated stats interval"
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
|
|
61
173
|
class AdapterConfig(BaseModel):
|
|
62
174
|
"""Adapter configuration with registry-based interceptor support"""
|
|
63
175
|
|
|
176
|
+
mode: str = Field(
|
|
177
|
+
description="Adapter mode: 'server' (default) or 'client'",
|
|
178
|
+
default="server",
|
|
179
|
+
)
|
|
64
180
|
discovery: DiscoveryConfig = Field(
|
|
65
181
|
description="Configuration for discovering 3rd party modules and directories",
|
|
66
182
|
default_factory=DiscoveryConfig,
|
|
@@ -82,48 +198,6 @@ class AdapterConfig(BaseModel):
|
|
|
82
198
|
default=False,
|
|
83
199
|
)
|
|
84
200
|
|
|
85
|
-
@classmethod
|
|
86
|
-
def get_legacy_defaults(cls) -> dict[str, Any]:
|
|
87
|
-
"""Get default values for legacy configuration parameters."""
|
|
88
|
-
return {
|
|
89
|
-
"generate_html_report": True,
|
|
90
|
-
"html_report_size": 5,
|
|
91
|
-
"tracking_requests_stats": True,
|
|
92
|
-
"caching_dir": None,
|
|
93
|
-
"log_failed_requests": cls.model_fields["log_failed_requests"].default,
|
|
94
|
-
"endpoint_type": cls.model_fields["endpoint_type"].default,
|
|
95
|
-
# Boolean defaults for optional features
|
|
96
|
-
"use_caching": True,
|
|
97
|
-
"save_responses": False,
|
|
98
|
-
"save_requests": False,
|
|
99
|
-
"use_system_prompt": False,
|
|
100
|
-
"use_omni_info": False,
|
|
101
|
-
"use_request_logging": False,
|
|
102
|
-
"use_nvcf": False,
|
|
103
|
-
"use_response_logging": False,
|
|
104
|
-
"use_reasoning": False,
|
|
105
|
-
"process_reasoning_traces": False,
|
|
106
|
-
"use_progress_tracking": False,
|
|
107
|
-
"use_raise_client_errors": False,
|
|
108
|
-
"include_json": True,
|
|
109
|
-
"custom_system_prompt": None,
|
|
110
|
-
"output_dir": None,
|
|
111
|
-
"params_to_add": None,
|
|
112
|
-
"params_to_remove": None,
|
|
113
|
-
"params_to_rename": None,
|
|
114
|
-
"max_logged_requests": None,
|
|
115
|
-
"max_logged_responses": None,
|
|
116
|
-
"max_saved_requests": None,
|
|
117
|
-
"max_saved_responses": None,
|
|
118
|
-
"start_reasoning_token": None,
|
|
119
|
-
"include_if_reasoning_not_finished": None,
|
|
120
|
-
"track_reasoning": None,
|
|
121
|
-
"end_reasoning_token": "</think>",
|
|
122
|
-
"progress_tracking_url": None,
|
|
123
|
-
"progress_tracking_interval": 1,
|
|
124
|
-
"logging_aggregated_stats_interval": 100,
|
|
125
|
-
}
|
|
126
|
-
|
|
127
201
|
@classmethod
|
|
128
202
|
def get_validated_config(cls, run_config: dict[str, Any]) -> "AdapterConfig":
|
|
129
203
|
"""Extract and validate adapter configuration from run_config.
|
|
@@ -156,9 +230,9 @@ class AdapterConfig(BaseModel):
|
|
|
156
230
|
)
|
|
157
231
|
|
|
158
232
|
# Validate that legacy parameters are not mixed with interceptors
|
|
159
|
-
|
|
233
|
+
legacy_params = set(LegacyAdapterConfig.model_fields.keys())
|
|
160
234
|
model_fields = set(cls.model_fields.keys())
|
|
161
|
-
legacy_only_params =
|
|
235
|
+
legacy_only_params = legacy_params - model_fields
|
|
162
236
|
|
|
163
237
|
for config_name, config in [
|
|
164
238
|
("global_adapter_config", global_cfg),
|
|
@@ -203,14 +277,19 @@ class AdapterConfig(BaseModel):
|
|
|
203
277
|
{"name": s} if isinstance(s, str) else s
|
|
204
278
|
for s in merged["post_eval_hooks"]
|
|
205
279
|
]
|
|
280
|
+
|
|
206
281
|
try:
|
|
207
282
|
config = cls(**merged)
|
|
208
283
|
|
|
209
284
|
# If no interceptors are configured, try to convert from legacy format
|
|
210
285
|
if not config.interceptors:
|
|
286
|
+
# Pass mode through merged config so it's preserved in legacy conversion
|
|
211
287
|
config = cls.from_legacy_config(merged, run_config)
|
|
212
288
|
|
|
213
289
|
return config
|
|
290
|
+
except ValidationError:
|
|
291
|
+
# Re-raise ValidationError directly for clear error messages
|
|
292
|
+
raise
|
|
214
293
|
except Exception as e:
|
|
215
294
|
raise ValueError(f"Invalid adapter configuration: {e}") from e
|
|
216
295
|
|
|
@@ -274,10 +353,29 @@ class AdapterConfig(BaseModel):
|
|
|
274
353
|
|
|
275
354
|
Returns:
|
|
276
355
|
AdapterConfig instance with interceptors based on legacy config
|
|
356
|
+
|
|
357
|
+
Raises:
|
|
358
|
+
ValidationError: If legacy_config contains typos or invalid field names
|
|
277
359
|
"""
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
360
|
+
logger = get_logger(__name__)
|
|
361
|
+
|
|
362
|
+
# Validate legacy config using Pydantic model (catches typos early)
|
|
363
|
+
# Filter out modern fields (discovery, interceptors, post_eval_hooks) before validation
|
|
364
|
+
modern_fields = {"discovery", "interceptors", "post_eval_hooks"}
|
|
365
|
+
legacy_only = {k: v for k, v in legacy_config.items() if k not in modern_fields}
|
|
366
|
+
|
|
367
|
+
try:
|
|
368
|
+
validated = LegacyAdapterConfig(**legacy_only)
|
|
369
|
+
legacy_config = validated.model_dump()
|
|
370
|
+
except ValidationError:
|
|
371
|
+
# Log helpful message with list of valid fields
|
|
372
|
+
valid_fields = sorted(LegacyAdapterConfig.model_fields.keys())
|
|
373
|
+
logger.error(
|
|
374
|
+
f"Invalid legacy adapter configuration. "
|
|
375
|
+
f"Supported parameters: {', '.join(valid_fields)}"
|
|
376
|
+
)
|
|
377
|
+
# Re-raise the original ValidationError
|
|
378
|
+
raise
|
|
281
379
|
|
|
282
380
|
interceptors = []
|
|
283
381
|
post_eval_hooks = []
|
|
@@ -474,8 +572,6 @@ class AdapterConfig(BaseModel):
|
|
|
474
572
|
)
|
|
475
573
|
|
|
476
574
|
if legacy_config["use_reasoning"]:
|
|
477
|
-
from nemo_evaluator.logging import get_logger
|
|
478
|
-
|
|
479
575
|
logger = get_logger(__name__)
|
|
480
576
|
logger.warning(
|
|
481
577
|
'"use_reasoning" is deprecated as it might suggest it touches on switching on/off reasoning for mode when it does not. Use "process_reasoning_traces" instead.'
|
|
@@ -543,7 +639,6 @@ class AdapterConfig(BaseModel):
|
|
|
543
639
|
from nemo_evaluator.adapters.interceptors.raise_client_error_interceptor import (
|
|
544
640
|
RaiseClientErrorInterceptor,
|
|
545
641
|
)
|
|
546
|
-
from nemo_evaluator.logging import get_logger
|
|
547
642
|
|
|
548
643
|
logger = get_logger(__name__)
|
|
549
644
|
|
|
@@ -583,6 +678,7 @@ class AdapterConfig(BaseModel):
|
|
|
583
678
|
)
|
|
584
679
|
|
|
585
680
|
return cls(
|
|
681
|
+
mode=legacy_config["mode"],
|
|
586
682
|
interceptors=interceptors,
|
|
587
683
|
post_eval_hooks=post_eval_hooks,
|
|
588
684
|
endpoint_type=legacy_config["endpoint_type"],
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
"""Endpoint interceptor that makes actual requests to the upstream API."""
|
|
17
17
|
|
|
18
|
+
import json
|
|
18
19
|
import time
|
|
19
20
|
from typing import final
|
|
20
21
|
|
|
@@ -80,15 +81,43 @@ class EndpointInterceptor(RequestToResponseInterceptor):
|
|
|
80
81
|
start_time = time.time()
|
|
81
82
|
|
|
82
83
|
# This is a final interceptor, we'll need the flask_request and api
|
|
84
|
+
raw_response = requests.request(
|
|
85
|
+
method=ar.r.method,
|
|
86
|
+
url=context.url,
|
|
87
|
+
headers={k: v for k, v in ar.r.headers if k.lower() != "host"},
|
|
88
|
+
json=ar.r.json,
|
|
89
|
+
cookies=ar.r.cookies,
|
|
90
|
+
allow_redirects=False,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# replace choices[xx].message.content=None with empty string
|
|
94
|
+
if raw_response.content is not None:
|
|
95
|
+
try:
|
|
96
|
+
response_json = json.loads(raw_response.content)
|
|
97
|
+
if (
|
|
98
|
+
"choices" in response_json
|
|
99
|
+
and isinstance(response_json["choices"], list)
|
|
100
|
+
and len(response_json["choices"]) > 0
|
|
101
|
+
):
|
|
102
|
+
for i, choice in enumerate(response_json["choices"]):
|
|
103
|
+
if (
|
|
104
|
+
"message" in choice
|
|
105
|
+
and "content" in choice["message"]
|
|
106
|
+
and choice["message"]["content"] is None
|
|
107
|
+
):
|
|
108
|
+
self.logger.warning(
|
|
109
|
+
f"choices[{i}].message.content is None, replacing with empty string"
|
|
110
|
+
)
|
|
111
|
+
choice["message"]["content"] = ""
|
|
112
|
+
raw_response._content = json.dumps(response_json).encode("utf-8")
|
|
113
|
+
except (json.JSONDecodeError, TypeError, KeyError) as e:
|
|
114
|
+
# If JSON parsing fails or unexpected structure, leave response unchanged
|
|
115
|
+
self.logger.debug(
|
|
116
|
+
"Could not parse response as JSON, leaving unchanged", error=str(e)
|
|
117
|
+
)
|
|
118
|
+
|
|
83
119
|
resp = AdapterResponse(
|
|
84
|
-
r=
|
|
85
|
-
method=ar.r.method,
|
|
86
|
-
url=context.url,
|
|
87
|
-
headers={k: v for k, v in ar.r.headers if k.lower() != "host"},
|
|
88
|
-
json=ar.r.json,
|
|
89
|
-
cookies=ar.r.cookies,
|
|
90
|
-
allow_redirects=False,
|
|
91
|
-
),
|
|
120
|
+
r=raw_response,
|
|
92
121
|
rctx=ar.rctx,
|
|
93
122
|
latency_ms=round(
|
|
94
123
|
(time.time() - start_time) * 1000, 2
|
|
@@ -127,6 +127,8 @@ class ResponseReasoningInterceptor(ResponseInterceptor, PostEvalHook):
|
|
|
127
127
|
"responses_with_reasoning": 0,
|
|
128
128
|
"reasoning_finished_count": 0,
|
|
129
129
|
"reasoning_started_count": 0,
|
|
130
|
+
"reasoning_unfinished_count": 0,
|
|
131
|
+
"reasoning_finished_ratio": 0,
|
|
130
132
|
"avg_reasoning_words": None,
|
|
131
133
|
"avg_original_content_words": None,
|
|
132
134
|
"avg_updated_content_words": None,
|
|
@@ -281,12 +283,18 @@ class ResponseReasoningInterceptor(ResponseInterceptor, PostEvalHook):
|
|
|
281
283
|
)
|
|
282
284
|
|
|
283
285
|
# Increment counters
|
|
284
|
-
if
|
|
286
|
+
if (
|
|
287
|
+
reasoning_words == "unknown"
|
|
288
|
+
and reasoning_info.get("reasoning_started") is True
|
|
289
|
+
) or (isinstance(reasoning_words, int) and reasoning_words > 0):
|
|
290
|
+
# if reasoning started but not finished, or finished and we have non-zero reasoning words
|
|
285
291
|
self._reasoning_stats["responses_with_reasoning"] += 1
|
|
286
|
-
if reasoning_info.get("reasoning_started"):
|
|
292
|
+
if reasoning_info.get("reasoning_started") is True:
|
|
287
293
|
self._reasoning_stats["reasoning_started_count"] += 1
|
|
288
|
-
|
|
289
|
-
|
|
294
|
+
if reasoning_info.get("reasoning_finished"):
|
|
295
|
+
self._reasoning_stats["reasoning_finished_count"] += 1
|
|
296
|
+
else:
|
|
297
|
+
self._reasoning_stats["reasoning_unfinished_count"] += 1
|
|
290
298
|
|
|
291
299
|
# Update running averages
|
|
292
300
|
for stat_key, value in [
|
|
@@ -340,6 +348,13 @@ class ResponseReasoningInterceptor(ResponseInterceptor, PostEvalHook):
|
|
|
340
348
|
updated_content_tokens
|
|
341
349
|
)
|
|
342
350
|
|
|
351
|
+
# Update ratio
|
|
352
|
+
if self._reasoning_stats["responses_with_reasoning"]:
|
|
353
|
+
self._reasoning_stats["reasoning_finished_ratio"] = (
|
|
354
|
+
self._reasoning_stats["reasoning_finished_count"]
|
|
355
|
+
/ self._reasoning_stats["responses_with_reasoning"]
|
|
356
|
+
)
|
|
357
|
+
|
|
343
358
|
# Log aggregated stats at specified interval
|
|
344
359
|
if (
|
|
345
360
|
self._reasoning_stats["total_responses"]
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
"""Shared adapter pipeline logic used by both server and client modes."""
|
|
17
|
+
|
|
18
|
+
from typing import List
|
|
19
|
+
|
|
20
|
+
from nemo_evaluator.adapters.adapter_config import AdapterConfig
|
|
21
|
+
from nemo_evaluator.adapters.registry import InterceptorRegistry
|
|
22
|
+
from nemo_evaluator.adapters.types import (
|
|
23
|
+
AdapterGlobalContext,
|
|
24
|
+
AdapterRequest,
|
|
25
|
+
AdapterResponse,
|
|
26
|
+
FatalErrorException,
|
|
27
|
+
PostEvalHook,
|
|
28
|
+
RequestInterceptor,
|
|
29
|
+
RequestToResponseInterceptor,
|
|
30
|
+
ResponseInterceptor,
|
|
31
|
+
)
|
|
32
|
+
from nemo_evaluator.logging import get_logger
|
|
33
|
+
|
|
34
|
+
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AdapterPipeline:
|
|
38
|
+
"""Shared adapter pipeline that processes requests/responses through interceptors.
|
|
39
|
+
|
|
40
|
+
This class encapsulates the core adapter logic that is used by both:
|
|
41
|
+
- Server mode (AdapterServer with Flask)
|
|
42
|
+
- Client mode (AdapterTransport with httpx)
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(
|
|
46
|
+
self,
|
|
47
|
+
adapter_config: AdapterConfig,
|
|
48
|
+
output_dir: str,
|
|
49
|
+
model_name: str | None = None,
|
|
50
|
+
):
|
|
51
|
+
"""Initialize the adapter pipeline.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
adapter_config: Adapter configuration with interceptors and hooks
|
|
55
|
+
output_dir: Directory for output files
|
|
56
|
+
model_name: Optional model name for logging context
|
|
57
|
+
"""
|
|
58
|
+
self.adapter_config = adapter_config
|
|
59
|
+
self.output_dir = output_dir
|
|
60
|
+
self.model_name = model_name
|
|
61
|
+
|
|
62
|
+
# Initialize interceptor chain and hooks
|
|
63
|
+
self.interceptor_chain: List[
|
|
64
|
+
RequestInterceptor | RequestToResponseInterceptor | ResponseInterceptor
|
|
65
|
+
] = []
|
|
66
|
+
self.post_eval_hooks: List[PostEvalHook] = []
|
|
67
|
+
self._post_eval_hooks_executed: bool = False
|
|
68
|
+
|
|
69
|
+
# Initialize registry and discover components
|
|
70
|
+
self.registry = InterceptorRegistry.get_instance()
|
|
71
|
+
self.registry.discover_components(
|
|
72
|
+
modules=adapter_config.discovery.modules,
|
|
73
|
+
dirs=adapter_config.discovery.dirs,
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
# Validate and build chains
|
|
77
|
+
self._validate_and_build_chains()
|
|
78
|
+
|
|
79
|
+
def _validate_and_build_chains(self) -> None:
|
|
80
|
+
"""Validate configuration and build interceptor chains."""
|
|
81
|
+
try:
|
|
82
|
+
# Check if adapter chain is properly defined
|
|
83
|
+
self._validate_adapter_chain_definition()
|
|
84
|
+
|
|
85
|
+
# Validate interceptor order
|
|
86
|
+
self._validate_interceptor_order()
|
|
87
|
+
|
|
88
|
+
# Build the chains
|
|
89
|
+
self._build_interceptor_chain()
|
|
90
|
+
self._build_post_eval_hooks()
|
|
91
|
+
|
|
92
|
+
except Exception as e:
|
|
93
|
+
logger.error(f"Failed to build interceptor chains: {e}")
|
|
94
|
+
raise
|
|
95
|
+
|
|
96
|
+
def _validate_adapter_chain_definition(self) -> None:
|
|
97
|
+
"""Validate that the adapter chain is properly defined with at least one enabled component."""
|
|
98
|
+
enabled_interceptors = [
|
|
99
|
+
ic for ic in self.adapter_config.interceptors if ic.enabled
|
|
100
|
+
]
|
|
101
|
+
enabled_post_eval_hooks = [
|
|
102
|
+
hook for hook in self.adapter_config.post_eval_hooks if hook.enabled
|
|
103
|
+
]
|
|
104
|
+
|
|
105
|
+
if not enabled_interceptors and not enabled_post_eval_hooks:
|
|
106
|
+
warning_msg = (
|
|
107
|
+
"Adapter pipeline cannot start: No enabled interceptors or "
|
|
108
|
+
"post-eval hooks found. The pipeline requires at least one enabled "
|
|
109
|
+
"interceptor or post-eval hook to function properly. "
|
|
110
|
+
f"Configured interceptors: "
|
|
111
|
+
f"{[ic.name for ic in self.adapter_config.interceptors]}, "
|
|
112
|
+
f"Configured post-eval hooks: "
|
|
113
|
+
f"{[hook.name for hook in self.adapter_config.post_eval_hooks]}"
|
|
114
|
+
)
|
|
115
|
+
logger.warning(warning_msg)
|
|
116
|
+
raise RuntimeError(warning_msg)
|
|
117
|
+
|
|
118
|
+
def _validate_interceptor_order(self) -> None:
|
|
119
|
+
"""Validate that the configured interceptor list follows the correct stage order.
|
|
120
|
+
|
|
121
|
+
The order must be: Request -> RequestToResponse -> Response
|
|
122
|
+
"""
|
|
123
|
+
# Define stage hierarchy and allowed transitions
|
|
124
|
+
STAGE_ORDER = ["request", "request_to_response", "response"]
|
|
125
|
+
current_stage_idx = 0
|
|
126
|
+
|
|
127
|
+
for interceptor_config in self.adapter_config.interceptors:
|
|
128
|
+
if not interceptor_config.enabled:
|
|
129
|
+
continue
|
|
130
|
+
|
|
131
|
+
metadata = self.registry.get_metadata(interceptor_config.name)
|
|
132
|
+
if metadata is None:
|
|
133
|
+
raise ValueError(f"Unknown interceptor: {interceptor_config.name}")
|
|
134
|
+
|
|
135
|
+
# Determine the stage of this interceptor
|
|
136
|
+
if metadata.supports_request_to_response_interception():
|
|
137
|
+
interceptor_stage = "request_to_response"
|
|
138
|
+
elif metadata.supports_request_interception():
|
|
139
|
+
interceptor_stage = "request"
|
|
140
|
+
elif metadata.supports_response_interception():
|
|
141
|
+
interceptor_stage = "response"
|
|
142
|
+
else:
|
|
143
|
+
raise ValueError(
|
|
144
|
+
f"Interceptor {interceptor_config.name} doesn't implement any known interface"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
# Find the stage index
|
|
148
|
+
try:
|
|
149
|
+
stage_idx = STAGE_ORDER.index(interceptor_stage)
|
|
150
|
+
except ValueError:
|
|
151
|
+
raise ValueError(f"Unknown stage: {interceptor_stage}")
|
|
152
|
+
|
|
153
|
+
# Validate progression: can only move forward or stay at same stage
|
|
154
|
+
if stage_idx < current_stage_idx:
|
|
155
|
+
raise ValueError(
|
|
156
|
+
f"Invalid stage order: interceptor {interceptor_config.name} (stage: {interceptor_stage}) "
|
|
157
|
+
f"appears after {STAGE_ORDER[current_stage_idx]} stage. "
|
|
158
|
+
f"Expected order: Request -> RequestToResponse -> Response"
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
# Update current stage if we've moved forward
|
|
162
|
+
current_stage_idx = max(current_stage_idx, stage_idx)
|
|
163
|
+
|
|
164
|
+
def _build_interceptor_chain(self) -> None:
|
|
165
|
+
"""Build interceptor chain from validated configuration."""
|
|
166
|
+
self.interceptor_chain = []
|
|
167
|
+
for interceptor_config in self.adapter_config.interceptors:
|
|
168
|
+
if interceptor_config.enabled:
|
|
169
|
+
interceptor = self.registry._get_or_create_instance(
|
|
170
|
+
interceptor_config.name,
|
|
171
|
+
interceptor_config.config,
|
|
172
|
+
)
|
|
173
|
+
self.interceptor_chain.append(interceptor)
|
|
174
|
+
|
|
175
|
+
logger.info(
|
|
176
|
+
"Built interceptor chain",
|
|
177
|
+
interceptors=[type(i).__name__ for i in self.interceptor_chain],
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
def _build_post_eval_hooks(self) -> None:
|
|
181
|
+
"""Build post-evaluation hooks from validated configuration."""
|
|
182
|
+
self.post_eval_hooks = []
|
|
183
|
+
|
|
184
|
+
# Add configured post-eval hooks
|
|
185
|
+
for hook_config in self.adapter_config.post_eval_hooks:
|
|
186
|
+
if hook_config.enabled:
|
|
187
|
+
hook = self.registry._get_or_create_instance(
|
|
188
|
+
hook_config.name, hook_config.config
|
|
189
|
+
)
|
|
190
|
+
self.post_eval_hooks.append(hook)
|
|
191
|
+
|
|
192
|
+
# Also add interceptors that implement PostEvalHook
|
|
193
|
+
for interceptor in self.interceptor_chain:
|
|
194
|
+
if hasattr(interceptor, "post_eval_hook") and callable(
|
|
195
|
+
getattr(interceptor, "post_eval_hook")
|
|
196
|
+
):
|
|
197
|
+
self.post_eval_hooks.append(interceptor)
|
|
198
|
+
|
|
199
|
+
logger.info(
|
|
200
|
+
"Built post-eval hooks",
|
|
201
|
+
hooks=[type(h).__name__ for h in self.post_eval_hooks],
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
def process_request(
|
|
205
|
+
self, adapter_request: AdapterRequest, global_context: AdapterGlobalContext
|
|
206
|
+
) -> tuple[AdapterRequest, AdapterResponse | None]:
|
|
207
|
+
"""Process request through the interceptor chain.
|
|
208
|
+
|
|
209
|
+
Args:
|
|
210
|
+
adapter_request: The request to process
|
|
211
|
+
global_context: Global context for the request
|
|
212
|
+
|
|
213
|
+
Returns:
|
|
214
|
+
Tuple of (modified_request, optional_response)
|
|
215
|
+
- If an interceptor returns a response, it's returned as the second element
|
|
216
|
+
- Otherwise, the second element is None and the first is the modified request
|
|
217
|
+
"""
|
|
218
|
+
current_request = adapter_request
|
|
219
|
+
request_logger = get_logger()
|
|
220
|
+
|
|
221
|
+
for interceptor in self.interceptor_chain:
|
|
222
|
+
try:
|
|
223
|
+
if isinstance(
|
|
224
|
+
interceptor, (RequestInterceptor, RequestToResponseInterceptor)
|
|
225
|
+
):
|
|
226
|
+
result = interceptor.intercept_request(
|
|
227
|
+
current_request, global_context
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
# If interceptor returns a response, we're done with request processing
|
|
231
|
+
if isinstance(result, AdapterResponse):
|
|
232
|
+
return current_request, result
|
|
233
|
+
else:
|
|
234
|
+
current_request = result
|
|
235
|
+
else:
|
|
236
|
+
# This is a ResponseInterceptor, skip in request phase
|
|
237
|
+
continue
|
|
238
|
+
|
|
239
|
+
except FatalErrorException:
|
|
240
|
+
# Re-raise fatal errors
|
|
241
|
+
raise
|
|
242
|
+
except Exception as e:
|
|
243
|
+
request_logger.error(
|
|
244
|
+
f"Request interceptor {type(interceptor).__name__} failed: {e}"
|
|
245
|
+
)
|
|
246
|
+
# Continue with next interceptor
|
|
247
|
+
continue
|
|
248
|
+
|
|
249
|
+
return current_request, None
|
|
250
|
+
|
|
251
|
+
def process_response(
|
|
252
|
+
self, adapter_response: AdapterResponse, global_context: AdapterGlobalContext
|
|
253
|
+
) -> AdapterResponse:
|
|
254
|
+
"""Process response through the interceptor chain (in reverse order).
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
adapter_response: The response to process
|
|
258
|
+
global_context: Global context for the response
|
|
259
|
+
|
|
260
|
+
Returns:
|
|
261
|
+
Modified response after processing through all response interceptors
|
|
262
|
+
"""
|
|
263
|
+
current_response = adapter_response
|
|
264
|
+
request_logger = get_logger()
|
|
265
|
+
|
|
266
|
+
for interceptor in reversed(self.interceptor_chain):
|
|
267
|
+
try:
|
|
268
|
+
if isinstance(interceptor, ResponseInterceptor):
|
|
269
|
+
current_response = interceptor.intercept_response(
|
|
270
|
+
current_response, global_context
|
|
271
|
+
)
|
|
272
|
+
except FatalErrorException:
|
|
273
|
+
# Re-raise fatal errors
|
|
274
|
+
raise
|
|
275
|
+
except Exception as e:
|
|
276
|
+
request_logger.error(
|
|
277
|
+
f"Response interceptor {type(interceptor).__name__} failed: {e}"
|
|
278
|
+
)
|
|
279
|
+
# Continue with next interceptor
|
|
280
|
+
continue
|
|
281
|
+
|
|
282
|
+
return current_response
|
|
283
|
+
|
|
284
|
+
def run_post_eval_hooks(self, url: str = "") -> None:
|
|
285
|
+
"""Run all configured post-evaluation hooks.
|
|
286
|
+
|
|
287
|
+
Args:
|
|
288
|
+
url: Optional URL for global context (not always relevant)
|
|
289
|
+
"""
|
|
290
|
+
if self._post_eval_hooks_executed:
|
|
291
|
+
logger.warning("Post-eval hooks have already been executed, skipping")
|
|
292
|
+
return
|
|
293
|
+
|
|
294
|
+
global_context = AdapterGlobalContext(
|
|
295
|
+
output_dir=self.output_dir,
|
|
296
|
+
url=url,
|
|
297
|
+
model_name=self.model_name,
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
for hook in self.post_eval_hooks:
|
|
301
|
+
try:
|
|
302
|
+
hook.post_eval_hook(global_context)
|
|
303
|
+
logger.info(f"Successfully ran post-eval hook: {type(hook).__name__}")
|
|
304
|
+
except Exception as e:
|
|
305
|
+
logger.error(f"Post-eval hook {type(hook).__name__} failed: {e}")
|
|
306
|
+
# Continue with other hooks
|
|
307
|
+
continue
|
|
308
|
+
|
|
309
|
+
self._post_eval_hooks_executed = True
|
|
310
|
+
logger.info("Post-eval hooks execution completed")
|