nemo-evaluator-launcher 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nemo-evaluator-launcher might be problematic.
- nemo_evaluator_launcher/__init__.py +79 -0
- nemo_evaluator_launcher/api/__init__.py +24 -0
- nemo_evaluator_launcher/api/functional.py +698 -0
- nemo_evaluator_launcher/api/types.py +98 -0
- nemo_evaluator_launcher/api/utils.py +19 -0
- nemo_evaluator_launcher/cli/__init__.py +15 -0
- nemo_evaluator_launcher/cli/export.py +267 -0
- nemo_evaluator_launcher/cli/info.py +512 -0
- nemo_evaluator_launcher/cli/kill.py +41 -0
- nemo_evaluator_launcher/cli/ls_runs.py +134 -0
- nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
- nemo_evaluator_launcher/cli/main.py +226 -0
- nemo_evaluator_launcher/cli/run.py +200 -0
- nemo_evaluator_launcher/cli/status.py +164 -0
- nemo_evaluator_launcher/cli/version.py +55 -0
- nemo_evaluator_launcher/common/__init__.py +16 -0
- nemo_evaluator_launcher/common/execdb.py +283 -0
- nemo_evaluator_launcher/common/helpers.py +366 -0
- nemo_evaluator_launcher/common/logging_utils.py +357 -0
- nemo_evaluator_launcher/common/mapping.py +295 -0
- nemo_evaluator_launcher/common/printing_utils.py +93 -0
- nemo_evaluator_launcher/configs/__init__.py +15 -0
- nemo_evaluator_launcher/configs/default.yaml +28 -0
- nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
- nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
- nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
- nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
- nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
- nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
- nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
- nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
- nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
- nemo_evaluator_launcher/executors/__init__.py +22 -0
- nemo_evaluator_launcher/executors/base.py +120 -0
- nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
- nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
- nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
- nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
- nemo_evaluator_launcher/executors/local/__init__.py +15 -0
- nemo_evaluator_launcher/executors/local/executor.py +605 -0
- nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
- nemo_evaluator_launcher/executors/registry.py +38 -0
- nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
- nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
- nemo_evaluator_launcher/exporters/__init__.py +36 -0
- nemo_evaluator_launcher/exporters/base.py +121 -0
- nemo_evaluator_launcher/exporters/gsheets.py +409 -0
- nemo_evaluator_launcher/exporters/local.py +502 -0
- nemo_evaluator_launcher/exporters/mlflow.py +619 -0
- nemo_evaluator_launcher/exporters/registry.py +40 -0
- nemo_evaluator_launcher/exporters/utils.py +624 -0
- nemo_evaluator_launcher/exporters/wandb.py +490 -0
- nemo_evaluator_launcher/package_info.py +38 -0
- nemo_evaluator_launcher/resources/mapping.toml +380 -0
- nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
- nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
- nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
- nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
- nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
- nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
@@ -0,0 +1,357 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Logging configuration module for nemo-evaluator-launcher.
+
+This module provides a centralized logging configuration using structlog that outputs
+to both stderr and a log file. All modules should import and use the logger from this
+module to ensure consistent logging behavior across the application.
+
+LOGGING POLICY:
+==============
+All logging in this project MUST go through this module. This is enforced by a pre-commit
+hook that checks for violations.
+
+DO NOT:
+- import structlog directly
+- import logging directly
+- call structlog.get_logger() directly
+- call logging.getLogger() directly
+
+DO:
+- from nemo_evaluator_launcher.common.logging_utils import logger
+- from nemo_evaluator_launcher.common.logging_utils import get_logger
+
+Examples:
+    # Correct
+    from nemo_evaluator_launcher.common.logging_utils import logger
+    logger.info("User logged in", user_id="12345")
+
+    # Incorrect
+    import structlog
+    logger = structlog.get_logger()
+    logger.info("User logged in")
+"""
+
+import json
+import logging
+import logging.config
+import os
+import pathlib
+import sys
+from datetime import datetime
+from pprint import pformat
+from typing import Any, Dict
+
+import structlog
+
+# If this env var is set, it takes precedence over the more standard "LOG_LEVEL".
+# If both are unset, the default is used.
+_LOG_LEVEL_ENV_VAR = "NEMO_EVALUATOR_LOG_LEVEL"
+_DEFAULT_LOG_LEVEL = "WARNING"
+_SENSITIVE_KEY_SUBSTRINGS_NORMALIZED = {
+    # Keep minimal, broad substrings
+    # NOTE: normalized: lowercased, no spaces/_/-
+    "authorization",  # covers proxy-authorization, etc.
+    "apikey",  # covers api_key, api-key, x-api-key, nvidia_api_key, ...
+    "accesskey",  # covers access_key / access-key
+    "privatekey",
+    "token",  # covers access_token, id_token, refresh_token, *_token
+    "secret",  # covers openai_client_secret, aws_secret_access_key, *_secret
+    "password",
+    "pwd",  # common shorthand
+    "passwd",  # common variant
+}
+_ALLOWLISTED_KEYS_SUBSTRINGS = {
+    # NOTE: non-normalized (for allowlisting we want more control)
+    "_tokens",  # keeps useful keys such as `limit_tokens` or `max_new_tokens` from being redacted
+}
+
+
+def _mask(val: object) -> str:
+    s = str(val)
+    if len(s) <= 10:
+        return "[REDACTED]"
+    return f"{s[:2]}…{s[-2:]}"
+
+
+def _normalize(name: object) -> str:
+    if not isinstance(name, str):
+        return ""
+    s = name.lower()
+    # drop spaces, hyphens, underscores
+    return s.replace(" ", "").replace("-", "").replace("_", "")
+
+
+def _is_sensitive_key(key: object) -> bool:
+    k_norm = _normalize(key)
+    k_non_norm = str(key)
+    return any(
+        substr in k_norm for substr in _SENSITIVE_KEY_SUBSTRINGS_NORMALIZED
+    ) and not any(substr in k_non_norm for substr in _ALLOWLISTED_KEYS_SUBSTRINGS)
+
+
+def _redact_mapping(m: dict) -> dict:
+    red = {}
+    for k, v in m.items():
+        if _is_sensitive_key(k):
+            red[k] = _mask(v)
+        elif isinstance(v, dict):
+            red[k] = _redact_mapping(v)
+        else:
+            red[k] = v
+    return red
+
+
+def redact_processor(_: Any, __: str, event_dict: Dict[str, Any]) -> Dict[str, Any]:
+    if os.getenv("LOG_DISABLE_REDACTION", "").lower() in {"1", "true", "yes"}:
+        return event_dict
+    return _redact_mapping(event_dict)
+
+
+def _ensure_log_dir() -> pathlib.Path:
+    """Ensure the log directory exists and return its path."""
+    log_dir = pathlib.Path.home() / ".nemo-evaluator" / "logs"
+    log_dir.mkdir(parents=True, exist_ok=True)
+    return log_dir
+
+
+def _get_env_log_level() -> str:
+    """Get log level from environment variable, translating single letters to full names.
+
+    Translates:
+    - D -> DEBUG
+    - I -> INFO
+    - W -> WARNING
+    - E -> ERROR
+    - F -> CRITICAL
+
+    Returns:
+        Uppercase log level string; defaults to WARNING if unset. Unrecognized
+        values are returned as-is (uppercased).
+    """
+    env_level = os.getenv(_LOG_LEVEL_ENV_VAR, os.getenv("LOG_LEVEL"))
+    # If empty or unset, default
+    if not env_level:
+        env_level = _DEFAULT_LOG_LEVEL
+    env_level = env_level.upper()
+
+    # Translate single letters to full level names
+    level_map = {
+        "D": "DEBUG",
+        "I": "INFO",
+        "W": "WARNING",
+        "E": "ERROR",
+        "F": "CRITICAL",
+    }
+
+    return level_map.get(env_level, env_level)
+
+
+def custom_timestamper(_: Any, __: Any, event_dict: Dict[str, Any]) -> Dict[str, Any]:
+    """Add an ISO local-time timestamp with milliseconds to event_dict['timestamp']."""
+    now = datetime.now()
+    event_dict["timestamp"] = now.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+    return event_dict
+
+
+class MainConsoleRenderer:
+    """Custom console renderer for [L TIMESTAMP] message with color by level."""
+
+    LEVEL_MAP = {
+        "debug": ("D", "\033[90m"),  # grey
+        "info": ("I", "\033[32m"),  # green
+        "warning": ("W", "\033[33m"),  # yellow
+        "warn": ("W", "\033[33m"),  # yellow
+        "error": ("E", "\033[31m"),  # red
+        "critical": ("F", "\033[41m"),  # red background
+        "fatal": ("F", "\033[41m"),  # alias for critical
+    }
+    RESET = "\033[0m"
+
+    def __init__(self, colors: bool = True):
+        self.colors = colors
+
+    def __call__(
+        self, logger: Any, method_name: str, event_dict: Dict[str, Any]
+    ) -> str:
+        timestamp = event_dict.get("timestamp", "")
+        message = event_dict.get("event", "")
+        level = event_dict.get("level", method_name).lower()
+        letter, color = self.LEVEL_MAP.get(level, ("?", ""))
+        prefix = f"[{letter} {timestamp}]"
+        if self.colors and color:
+            prefix = f"{color}{prefix}{self.RESET}"
+
+        # Build the output with message and key-value pairs
+        output_parts = [prefix]
+
+        # Make the main message bold
+        if self.colors:
+            message = f"\033[1m{message}\033[0m"  # bold
+        output_parts.append(message)
+
+        # Add key-value pairs (excluding internal structlog keys)
+        kv_pairs = []
+        for key, value in event_dict.items():
+            if key not in ["timestamp", "event", "level"]:
+                # Pretty-format complex values (dict/list) as JSON on new lines
+                pretty_value = None
+                if isinstance(value, (dict, list)):
+                    try:
+                        pretty_value = json.dumps(
+                            value, ensure_ascii=False, sort_keys=True, indent=2
+                        )
+                    except Exception:
+                        pretty_value = pformat(value, width=100, compact=False)
+                elif not isinstance(value, (str, int, float, bool, type(None))):
+                    # Fall back to reasonably readable representation for other complex types
+                    pretty_value = pformat(value, width=100, compact=False)
+
+                rendered_value = (
+                    pretty_value if pretty_value is not None else str(value)
+                )
+
+                # If multiline, place value on a new line for readability
+                if "\n" in rendered_value:
+                    if self.colors:
+                        kv_pairs.append(
+                            f"\033[35m{key}\033[0m=\n\033[36m{rendered_value}\033[0m"
+                        )
+                    else:
+                        kv_pairs.append(f"{key}=\n{rendered_value}")
+                else:
+                    if self.colors:
+                        # Format: magenta key + equals + cyan value
+                        kv_pairs.append(
+                            f"\033[35m{key}\033[0m=\033[36m{rendered_value}\033[0m"
+                        )
+                    else:
+                        # No colors for plain output
+                        kv_pairs.append(f"{key}={rendered_value}")
+
+        if kv_pairs:
+            # If any kv is multiline, join with newlines; otherwise keep single line
+            if any("\n" in kv for kv in kv_pairs):
+                kv_text = "\n".join(kv_pairs)
+            else:
+                kv_text = " ".join(kv_pairs)
+            if self.colors:
+                kv_text = f"\033[35m{kv_text}{self.RESET}"  # magenta
+            output_parts.append(kv_text)
+
+        return " ".join(output_parts)
+
+
+def _configure_structlog() -> None:
+    """Configure structlog for both console and file output."""
+    log_dir = _ensure_log_dir()
+    log_file = log_dir / "main.log"
+    json_log_file = log_dir / "main.log.json"
+
+    shared_processors = [
+        structlog.stdlib.ProcessorFormatter.remove_processors_meta,
+        redact_processor,
+        custom_timestamper,
+        structlog.stdlib.add_log_level,
+        structlog.processors.StackInfoRenderer(),
+        structlog.dev.set_exc_info,
+        structlog.processors.format_exc_info,
+        structlog.processors.UnicodeDecoder(),
+    ]
+
+    logging.config.dictConfig(
+        {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                # Formatter for colored console output
+                "colored": {
+                    "()": "structlog.stdlib.ProcessorFormatter",
+                    "processors": [
+                        *shared_processors,
+                        MainConsoleRenderer(colors=True),
+                    ],
+                },
+                # Formatter for plain file output
+                "plain": {
+                    "()": "structlog.stdlib.ProcessorFormatter",
+                    "processors": [
+                        *shared_processors,
+                        MainConsoleRenderer(colors=False),
+                    ],
+                },
+                # Formatter for JSON file output
+                "json": {
+                    "()": "structlog.stdlib.ProcessorFormatter",
+                    "processors": [
+                        *shared_processors,
+                        structlog.processors.JSONRenderer(),
+                    ],
+                },
+            },
+            "handlers": {
+                "console": {
+                    "class": "logging.StreamHandler",
+                    "level": _get_env_log_level(),
+                    "formatter": "colored",
+                    "stream": sys.stderr,
+                },
+                "file": {
+                    "class": "logging.handlers.WatchedFileHandler",
+                    "level": "DEBUG",
+                    "filename": log_file,
+                    "formatter": "plain",
+                },
+                "json_file": {
+                    "class": "logging.handlers.WatchedFileHandler",
+                    "level": "DEBUG",
+                    "filename": json_log_file,
+                    "formatter": "json",
+                },
+            },
+            "loggers": {
+                "": {
+                    "handlers": ["console", "file", "json_file"],
+                    "level": "DEBUG",
+                    "propagate": True,
+                },
+            },
+        }
+    )
+
+    structlog.configure(
+        processors=[
+            structlog.stdlib.filter_by_level,
+            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
+        ],
+        logger_factory=structlog.stdlib.LoggerFactory(),
+        wrapper_class=structlog.stdlib.BoundLogger,
+        cache_logger_on_first_use=True,
+    )
+
+    structlog.get_logger().debug("Logger configured", config=structlog.get_config())
+
+
+# Configure logging on module import
+_configure_structlog()
+
+
+def get_logger(name: str | None = None) -> Any:
+    """Get a configured structlog logger."""
+    return structlog.get_logger(name)
+
+
+# Export the root logger for convenience
+logger = get_logger()
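In effect, importing this module configures three sinks: colored console output on stderr at the level chosen via NEMO_EVALUATOR_LOG_LEVEL (or LOG_LEVEL), plus plain-text and JSON files under ~/.nemo-evaluator/logs that always capture DEBUG. A minimal usage sketch following the policy stated in the docstring (illustrative values, not part of the package):

    from nemo_evaluator_launcher.common.logging_utils import logger

    # Single-letter levels are expanded: NEMO_EVALUATOR_LOG_LEVEL=D means DEBUG.
    logger.info("User logged in", user_id="12345")

    # redact_processor masks values whose keys contain substrings such as
    # "apikey", "token", or "secret": values of 10 characters or fewer become
    # "[REDACTED]"; longer ones keep only the first and last two characters,
    # so api_key="sk-abcdef123456" renders as api_key=sk…56.
    logger.warning("Calling endpoint", api_key="sk-abcdef123456")

Setting LOG_DISABLE_REDACTION=1 (or "true"/"yes") bypasses the masking entirely.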
@@ -0,0 +1,295 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import importlib
+import pathlib
+import sys
+from importlib import resources
+from typing import Any, Optional
+
+import requests
+
+if sys.version_info >= (3, 11):
+    import tomllib
+else:
+    import tomli as tomllib
+
+from nemo_evaluator_launcher.common.logging_utils import logger
+
+# Configuration constants
+# For below, see docs: https://docs.github.com/en/rest/repos/contents
+MAPPING_URL = "https://raw.githubusercontent.com/NVIDIA-NeMo/Eval/main/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/resources/mapping.toml"
+CACHE_DIR = pathlib.Path.home() / ".nemo-evaluator" / "cache"
+CACHE_FILENAME = "mapping.toml"
+INTERNAL_RESOURCES_PKG = "nemo_evaluator_launcher_internal.resources"
+
+
+def _ensure_cache_dir() -> None:
+    """Ensure the cache directory exists."""
+    CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+
+def _get_cache_file() -> pathlib.Path:
+    """Get the cache file path.
+
+    Returns:
+        pathlib.Path: Path to the cache file.
+    """
+    return CACHE_DIR / CACHE_FILENAME
+
+
+def _download_latest_mapping() -> Optional[bytes]:
+    """Download latest mapping from MAPPING_URL and return raw bytes.
+
+    Returns:
+        Optional[bytes]: Downloaded mapping bytes, or None if download fails.
+    """
+    try:
+        response = requests.get(MAPPING_URL, timeout=10)
+        response.raise_for_status()
+
+        # For GitHub raw URLs, the response content is the file content directly
+        mapping_bytes = response.content
+        assert isinstance(mapping_bytes, bytes)
+
+        logger.debug("Successfully downloaded mapping from remote URL")
+        return mapping_bytes
+    except (requests.RequestException, OSError) as e:
+        logger.warning("Failed to download mapping from remote URL", error=str(e))
+        return None
+
+
+def _load_cached_mapping() -> Optional[dict[Any, Any]]:
+    """Load mapping from cache file.
+
+    Returns:
+        Optional[dict]: Loaded mapping data, or None if loading fails.
+    """
+    cache_file = _get_cache_file()
+    if not cache_file.exists():
+        return None
+
+    try:
+        with open(cache_file, "rb") as f:
+            mapping = tomllib.load(f)
+        logger.debug("Loaded mapping from cache")
+        return mapping  # type: ignore[no-any-return]
+    except (OSError, tomllib.TOMLDecodeError) as e:
+        logger.warning("Failed to load mapping from cache", error=str(e))
+        return None
+
+
+def _save_mapping_to_cache(mapping_bytes: bytes) -> None:
+    """Save mapping to cache file.
+
+    Args:
+        mapping_bytes: Mapping data to save.
+    """
+    try:
+        _ensure_cache_dir()
+        cache_file = _get_cache_file()
+
+        # Save the mapping data
+        with open(cache_file, "wb") as f:
+            f.write(mapping_bytes)
+
+    except OSError as e:
+        logger.warning("Failed to save mapping to cache", error=str(e))
+
+
+def _load_packaged_resource(
+    resource_name: str, pkg_name: str = "nemo_evaluator_launcher.resources"
+) -> dict[str, Any]:
+    """Load a resource from the packaged resources.
+
+    Args:
+        resource_name: The name of the resource to load.
+        pkg_name: The package that ships the resource.
+    """
+    try:
+        resource_toml: dict[str, Any] = {}
+        with resources.files(pkg_name).joinpath(resource_name).open("rb") as f:
+            resource_toml = tomllib.load(f)
+        logger.info(
+            "Loaded resource from packaged file", resource=resource_name, pkg=pkg_name
+        )
+        return resource_toml
+    except (OSError, tomllib.TOMLDecodeError) as e:
+        logger.error(
+            "Failed to load from packaged file",
+            resource=resource_name,
+            pkg=pkg_name,
+            error=str(e),
+        )
+        raise RuntimeError(f"Failed to load {resource_name} from packaged file") from e
+
+
+def _process_mapping(mapping_toml: dict) -> dict:
+    """Process the raw mapping TOML into the expected format.
+
+    Args:
+        mapping_toml: Raw mapping TOML data.
+    Returns:
+        dict: Processed mapping in the expected format.
+    """
+    mapping = {}
+    for harness_name, harness_data in mapping_toml.items():
+        assert isinstance(harness_data["tasks"], dict)
+        for endpoint_type, harness_tasks in harness_data["tasks"].items():
+            assert isinstance(harness_tasks, dict)
+            for task_name, task_data in harness_tasks.items():
+                assert isinstance(task_data, dict)
+                key = (harness_name, task_name)
+                if key in mapping:
+                    raise KeyError(
+                        f"(harness,task)-tuple key {repr(key)} already exists in the mapping"
+                    )
+                mapping[key] = {
+                    "task": task_name,
+                    "harness": harness_name,
+                    "container": harness_data["container"],
+                    "endpoint_type": endpoint_type,
+                }
+                for task_data_key in task_data.keys():
+                    if task_data_key in mapping[key]:
+                        raise KeyError(
+                            f"{repr(task_data_key)} is not allowed as a key under {repr(key)} in the mapping"
+                        )
+                mapping[key].update(task_data)
+    return mapping
+
+
+def load_tasks_mapping(
+    latest: bool = False,
+    mapping_toml: pathlib.Path | str | None = None,
+) -> dict[tuple[str, str], dict]:
+    """Load tasks mapping.
+
+    The function obeys the following priority rules:
+    1. (Default) If latest==False and mapping_toml is None -> load packaged mapping.
+    2. If latest==True -> fetch MAPPING_URL, save to cache, load it.
+    3. If mapping_toml is not None -> load mapping from this path.
+
+    Returns:
+        dict: Mapping of (harness_name, task_name) to dict holding their configuration.
+
+    """
+    local_mapping: dict = {}
+    if latest:
+        mapping_bytes = _download_latest_mapping()
+        if mapping_bytes:
+            _save_mapping_to_cache(mapping_bytes)
+            local_mapping = _process_mapping(
+                tomllib.loads(mapping_bytes.decode("utf-8"))
+            )
+        else:
+            # Fall back to the cached mapping; raise only if the cache is missing/invalid
+            cached = _load_cached_mapping()
+            if cached:
+                local_mapping = _process_mapping(cached)
+            else:
+                raise RuntimeError("could not download latest mapping")
+
+    elif mapping_toml is not None:
+        with open(mapping_toml, "rb") as f:
+            local_mapping = _process_mapping(tomllib.load(f))
+    else:
+        local_mapping = _process_mapping(_load_packaged_resource(CACHE_FILENAME))
+
+    # TODO: make more elegant. We consider it ok to avoid a full-blown plugin system.
+    # Check if nemo_evaluator_launcher_internal is available and load its mapping.toml.
+    # CAVEAT: lazy-loading here, not somewhere top level, is important to ensure the
+    # order of package initialization.
+    try:
+        importlib.import_module("nemo_evaluator_launcher_internal")
+        logger.debug("Internal package available, loading internal mapping")
+        internal_mapping = _process_mapping(
+            _load_packaged_resource(CACHE_FILENAME, INTERNAL_RESOURCES_PKG)
+        )
+
+        # Merge internal mapping with local mapping (internal takes precedence)
+        local_mapping.update(internal_mapping)
+        logger.info(
+            "Successfully merged internal mapping", internal_tasks=len(internal_mapping)
+        )
+    except ImportError:
+        logger.debug("Internal package not available, using external mapping only")
+    except Exception as e:
+        logger.warning("Failed to load internal mapping", error=str(e))
+
+    return local_mapping
+
+
+def get_task_from_mapping(query: str, mapping: dict[Any, Any]) -> dict[Any, Any]:
+    """Unambiguously selects one task from the mapping based on the query.
+
+    Args:
+        query: Either `task_name` or `harness_name.task_name`.
+        mapping: The object returned from the `load_tasks_mapping` function.
+
+    Returns:
+        dict: Task data.
+
+    """
+    num_dots = query.count(".")
+
+    # if there are no dots in the query, treat it as a task name
+    if num_dots == 0:
+        matching_keys = [key for key in mapping.keys() if key[1] == query]
+        # if exactly one task matching the query has been found:
+        if len(matching_keys) == 1:
+            key = matching_keys[0]
+            return mapping[key]  # type: ignore[no-any-return]
+        # if more than one task matching the query has been found:
+        elif len(matching_keys) > 1:
+            matching_queries = [
+                f"{harness_name}.{task_name}"
+                for harness_name, task_name in matching_keys
+            ]
+            raise ValueError(
+                f"there are multiple tasks named {repr(query)} in the mapping,"
+                f" please select one of {repr(matching_queries)}"
+            )
+        # no tasks have been found:
+        else:
+            raise ValueError(f"task {repr(query)} does not exist in the mapping")
+
+    # if there is one dot in the query, treat it as "{harness_name}.{task_name}"
+    elif num_dots == 1:
+        harness_name, task_name = query.split(".")
+        matching_keys = [
+            key for key in mapping.keys() if key == (harness_name, task_name)
+        ]
+        # if exactly one task matching the query has been found:
+        if len(matching_keys) == 1:
+            key = matching_keys[0]
+            return mapping[key]  # type: ignore[no-any-return]
+        # if more than one task matching the query has been found:
+        elif len(matching_keys) >= 2:
+            raise ValueError(
+                f"there are multiple matches for {repr(query)} in the mapping,"
+                " which means the mapping is not correct"
+            )
+        # no tasks have been found:
+        else:
+            raise ValueError(
+                f"harness.task {repr(query)} does not exist in the mapping"
+            )
+
+    # invalid query
+    else:
+        raise ValueError(
+            f"invalid query={repr(query)} for task mapping,"
+            " it must contain exactly zero or one occurrence of the '.' character"
+        )