nemo-evaluator-launcher 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nemo-evaluator-launcher might be problematic. Click here for more details.

Files changed (60)
  1. nemo_evaluator_launcher/__init__.py +79 -0
  2. nemo_evaluator_launcher/api/__init__.py +24 -0
  3. nemo_evaluator_launcher/api/functional.py +698 -0
  4. nemo_evaluator_launcher/api/types.py +98 -0
  5. nemo_evaluator_launcher/api/utils.py +19 -0
  6. nemo_evaluator_launcher/cli/__init__.py +15 -0
  7. nemo_evaluator_launcher/cli/export.py +267 -0
  8. nemo_evaluator_launcher/cli/info.py +512 -0
  9. nemo_evaluator_launcher/cli/kill.py +41 -0
  10. nemo_evaluator_launcher/cli/ls_runs.py +134 -0
  11. nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
  12. nemo_evaluator_launcher/cli/main.py +226 -0
  13. nemo_evaluator_launcher/cli/run.py +200 -0
  14. nemo_evaluator_launcher/cli/status.py +164 -0
  15. nemo_evaluator_launcher/cli/version.py +55 -0
  16. nemo_evaluator_launcher/common/__init__.py +16 -0
  17. nemo_evaluator_launcher/common/execdb.py +283 -0
  18. nemo_evaluator_launcher/common/helpers.py +366 -0
  19. nemo_evaluator_launcher/common/logging_utils.py +357 -0
  20. nemo_evaluator_launcher/common/mapping.py +295 -0
  21. nemo_evaluator_launcher/common/printing_utils.py +93 -0
  22. nemo_evaluator_launcher/configs/__init__.py +15 -0
  23. nemo_evaluator_launcher/configs/default.yaml +28 -0
  24. nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
  25. nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
  26. nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
  27. nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
  28. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
  29. nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
  30. nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
  31. nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
  32. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
  33. nemo_evaluator_launcher/executors/__init__.py +22 -0
  34. nemo_evaluator_launcher/executors/base.py +120 -0
  35. nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
  36. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
  37. nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
  38. nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
  39. nemo_evaluator_launcher/executors/local/__init__.py +15 -0
  40. nemo_evaluator_launcher/executors/local/executor.py +605 -0
  41. nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
  42. nemo_evaluator_launcher/executors/registry.py +38 -0
  43. nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
  44. nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
  45. nemo_evaluator_launcher/exporters/__init__.py +36 -0
  46. nemo_evaluator_launcher/exporters/base.py +121 -0
  47. nemo_evaluator_launcher/exporters/gsheets.py +409 -0
  48. nemo_evaluator_launcher/exporters/local.py +502 -0
  49. nemo_evaluator_launcher/exporters/mlflow.py +619 -0
  50. nemo_evaluator_launcher/exporters/registry.py +40 -0
  51. nemo_evaluator_launcher/exporters/utils.py +624 -0
  52. nemo_evaluator_launcher/exporters/wandb.py +490 -0
  53. nemo_evaluator_launcher/package_info.py +38 -0
  54. nemo_evaluator_launcher/resources/mapping.toml +380 -0
  55. nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
  56. nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
  57. nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
  58. nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
  59. nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
  60. nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
@@ -0,0 +1,357 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ """Logging configuration module for nemo-evaluator-launcher.
17
+
18
+ This module provides a centralized logging configuration using structlog that outputs
19
+ to both stderr and a log file. All modules should import and use the logger from this
20
+ module to ensure consistent logging behavior across the application.
21
+
22
+ LOGGING POLICY:
23
+ ==============
24
+ All logging in this project MUST go through this module. This is enforced by a pre-commit
25
+ hook that checks for violations.
26
+
27
+ DO NOT:
28
+ - import structlog directly
29
+ - import logging directly
30
+ - call structlog.get_logger() directly
31
+ - call logging.getLogger() directly
32
+
33
+ DO:
34
+ - from nemo_evaluator_launcher.common.logging_utils import logger
35
+ - from nemo_evaluator_launcher.common.logging_utils import get_logger
36
+
37
+ Examples:
38
+ # Correct
39
+ from nemo_evaluator_launcher.common.logging_utils import logger
40
+ logger.info("User logged in", user_id="12345")
41
+
42
+ # Incorrect
43
+ import structlog
44
+ logger = structlog.get_logger()
45
+ logger.info("User logged in")
46
+ """
47
+
48
import json
import logging
import logging.config
import os
import pathlib
import sys
from datetime import datetime, timezone
from pprint import pformat
from typing import Any, Dict

import structlog
59
+
60
# If this env var is set, it will override a more standard "LOG_LEVEL". If
# both are unset, default would be used.
_LOG_LEVEL_ENV_VAR = "NEMO_EVALUATOR_LOG_LEVEL"
_DEFAULT_LOG_LEVEL = "WARNING"
# Substrings that mark a log key as secret-bearing. They are matched against a
# *normalized* key (lowercased, spaces/hyphens/underscores removed), so one
# entry covers many spellings (e.g. "apikey" matches "API-Key" and "api_key").
_SENSITIVE_KEY_SUBSTRINGS_NORMALIZED = {
    # Keep minimal, broad substrings
    # NOTE: normalized: lowercased, no spaces/_/-
    "authorization",  # covers proxy-authorization, etc.
    "apikey",  # covers api_key, api-key, x-api-key, nvidia_api_key, ...
    "accesskey",  # covers access_key / access-key
    "privatekey",
    "token",  # covers access_token, id_token, refresh_token, *_token
    "secret",  # covers openai_client_secret, aws_secret_access_key, *_secret
    "password",
    "pwd",  # common shorthand
    "passwd",  # common variant
}
# Substrings matched against the *raw* (non-normalized) key; a match exempts
# the key from redaction even when a sensitive substring also matched above.
_ALLOWLISTED_KEYS_SUBSTRINGS = {
    # NOTE: non-normalized (for allowlisting we want more control)
    "_tokens",  # This likely would allow us to not redact useful stuff like `limit_tokens`, `max_new_tokens`
}
81
+
82
+
83
+ def _mask(val: object) -> str:
84
+ s = str(val)
85
+ if len(s) <= 10:
86
+ return "[REDACTED]"
87
+ return f"{s[:2]}…{s[-2:]}"
88
+
89
+
90
+ def _normalize(name: object) -> str:
91
+ if not isinstance(name, str):
92
+ return ""
93
+ s = name.lower()
94
+ # drop spaces, hyphens, underscores
95
+ return s.replace(" ", "").replace("-", "").replace("_", "")
96
+
97
+
98
def _is_sensitive_key(key: object) -> bool:
    """Return True when *key* looks secret-bearing and is not allowlisted.

    Sensitivity is decided on the normalized key; the allowlist is checked
    against the raw key string for finer control.
    """
    normalized = _normalize(key)
    raw = str(key)
    looks_sensitive = any(
        marker in normalized for marker in _SENSITIVE_KEY_SUBSTRINGS_NORMALIZED
    )
    allowlisted = any(marker in raw for marker in _ALLOWLISTED_KEYS_SUBSTRINGS)
    return looks_sensitive and not allowlisted
104
+
105
+
106
def _redact_value(v: Any) -> Any:
    """Recursively redact container values; scalars pass through unchanged."""
    if isinstance(v, dict):
        return _redact_mapping(v)
    if isinstance(v, list):
        # Secrets can hide inside lists of dicts (e.g. lists of header
        # mappings); the previous implementation only recursed into dicts,
        # so such values escaped redaction.
        return [_redact_value(item) for item in v]
    return v


def _redact_mapping(m: dict) -> dict:
    """Return a copy of *m* with values under sensitive keys masked.

    Recurses into nested dicts and into lists so secrets buried in
    structured payloads are masked as well.

    Args:
        m: Event dict (or sub-dict) to sanitize.

    Returns:
        dict: New dict; sensitive values replaced by their masked form.
    """
    red = {}
    for k, v in m.items():
        red[k] = _mask(v) if _is_sensitive_key(k) else _redact_value(v)
    return red
116
+
117
+
118
def redact_processor(_: Any, __: str, event_dict: Dict[str, Any]) -> Dict[str, Any]:
    """structlog processor that masks sensitive keys in the event dict.

    Setting ``LOG_DISABLE_REDACTION`` to ``1``/``true``/``yes`` passes events
    through untouched (useful for local debugging).
    """
    opt_out = os.getenv("LOG_DISABLE_REDACTION", "").lower()
    if opt_out in {"1", "true", "yes"}:
        return event_dict
    return _redact_mapping(event_dict)
122
+
123
+
124
+ def _ensure_log_dir() -> pathlib.Path:
125
+ """Ensure the log directory exists and return its path."""
126
+ log_dir = pathlib.Path.home() / ".nemo-evaluator" / "logs"
127
+ log_dir.mkdir(parents=True, exist_ok=True)
128
+ return log_dir
129
+
130
+
131
def _get_env_log_level() -> str:
    """Resolve the console log level from the environment.

    ``NEMO_EVALUATOR_LOG_LEVEL`` overrides ``LOG_LEVEL``; single letters are
    expanded:
    - D -> DEBUG
    - I -> INFO
    - W -> WARNING
    - E -> ERROR
    - F -> CRITICAL

    Returns:
        Uppercase log level string, defaults to WARNING if not set or invalid.
    """
    env_level = os.getenv(_LOG_LEVEL_ENV_VAR, os.getenv("LOG_LEVEL"))
    # If empty or unset, default
    if not env_level:
        env_level = _DEFAULT_LOG_LEVEL
    env_level = env_level.upper()

    # Translate single letters to full level names
    level_map = {
        "D": "DEBUG",
        "I": "INFO",
        "W": "WARNING",
        "E": "ERROR",
        "F": "CRITICAL",
    }
    resolved = level_map.get(env_level, env_level)

    # The docstring promises "defaults to WARNING if ... invalid", but the
    # original passed unknown names straight through to dictConfig, which
    # raises at import time. Validate so a typo degrades gracefully instead.
    valid_levels = {
        "DEBUG",
        "INFO",
        "WARNING",
        "WARN",
        "ERROR",
        "CRITICAL",
        "FATAL",
        "NOTSET",
    }
    return resolved if resolved in valid_levels else _DEFAULT_LOG_LEVEL
160
+
161
+
162
def custom_timestamper(_: Any, __: Any, event_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Add ISO UTC timestamp with milliseconds to event_dict['timestamp'].

    The docstring always promised UTC, but the original used the naive local
    clock (``datetime.now()``); use a timezone-aware UTC clock so timestamps
    are comparable across machines.
    """
    now = datetime.now(timezone.utc)
    # %f yields microseconds; trim the last three digits to milliseconds.
    event_dict["timestamp"] = now.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
    return event_dict
167
+
168
+
169
class MainConsoleRenderer:
    """Render events as ``[<level-letter> <timestamp>] message key=value ...``.

    Levels collapse to a single letter; when ``colors`` is on, the prefix is
    tinted per level, the message is bold, and key/value pairs are colored.
    Dict/list values are pretty-printed as indented JSON on their own lines.
    """

    LEVEL_MAP = {
        "debug": ("D", "\033[90m"),  # grey
        "info": ("I", "\033[32m"),  # green
        "warning": ("W", "\033[33m"),  # yellow
        "warn": ("W", "\033[33m"),  # yellow
        "error": ("E", "\033[31m"),  # red
        "critical": ("F", "\033[41m"),  # red background
        "fatal": ("F", "\033[41m"),  # alias for critical
    }
    RESET = "\033[0m"

    def __init__(self, colors: bool = True):
        self.colors = colors

    @staticmethod
    def _stringify(value: Any) -> str:
        """Return a readable rendering; dicts/lists become indented JSON."""
        if isinstance(value, (dict, list)):
            try:
                return json.dumps(value, ensure_ascii=False, sort_keys=True, indent=2)
            except Exception:
                # Non-JSON-serializable content falls back to pprint.
                return pformat(value, width=100, compact=False)
        if isinstance(value, (str, int, float, bool, type(None))):
            return str(value)
        # Other complex types: reasonably readable representation.
        return pformat(value, width=100, compact=False)

    def _render_pair(self, key: str, rendered_value: str) -> str:
        """Format one key=value pair; multiline values start on a new line."""
        multiline = "\n" in rendered_value
        if not self.colors:
            return f"{key}=\n{rendered_value}" if multiline else f"{key}={rendered_value}"
        if multiline:
            return f"\033[35m{key}\033[0m=\n\033[36m{rendered_value}\033[0m"
        # magenta key + equals + cyan value
        return f"\033[35m{key}\033[0m=\033[36m{rendered_value}\033[0m"

    def __call__(
        self, logger: Any, method_name: str, event_dict: Dict[str, Any]
    ) -> str:
        timestamp = event_dict.get("timestamp", "")
        message = event_dict.get("event", "")
        level = event_dict.get("level", method_name).lower()
        letter, color = self.LEVEL_MAP.get(level, ("?", ""))

        prefix = f"[{letter} {timestamp}]"
        if self.colors and color:
            prefix = f"{color}{prefix}{self.RESET}"
        if self.colors:
            message = f"\033[1m{message}\033[0m"  # bold main message
        parts = [prefix, message]

        # Key/value pairs, skipping structlog-internal keys.
        pairs = [
            self._render_pair(key, self._stringify(value))
            for key, value in event_dict.items()
            if key not in ("timestamp", "event", "level")
        ]
        if pairs:
            # If any pair is multiline, stack pairs; otherwise one line.
            joiner = "\n" if any("\n" in pair for pair in pairs) else " "
            kv_text = joiner.join(pairs)
            if self.colors:
                kv_text = f"\033[35m{kv_text}{self.RESET}"  # magenta
            parts.append(kv_text)

        return " ".join(parts)
255
+
256
+
257
def _configure_structlog() -> None:
    """Configure structlog for both console and file output.

    Sets up three sinks under ``~/.nemo-evaluator/logs``: a colored console
    stream on stderr (level taken from the environment), a plain-text file
    and a JSON file (both at DEBUG). structlog events are routed through the
    stdlib logging machinery so stdlib loggers share the same pipeline.
    """
    log_dir = _ensure_log_dir()
    log_file = log_dir / "main.log"
    json_log_file = log_dir / "main.log.json"

    # Processor chain shared by all three formatters. Order matters:
    # redaction runs before anything is rendered, then timestamp/level are
    # attached and exception info is formatted.
    shared_processors = [
        structlog.stdlib.ProcessorFormatter.remove_processors_meta,
        redact_processor,
        custom_timestamper,
        structlog.stdlib.add_log_level,
        structlog.processors.StackInfoRenderer(),
        structlog.dev.set_exc_info,
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
    ]

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                # Formatter for colored console output
                "colored": {
                    "()": "structlog.stdlib.ProcessorFormatter",
                    "processors": [
                        *shared_processors,
                        MainConsoleRenderer(colors=True),
                    ],
                },
                # Formatter for plain file output
                "plain": {
                    "()": "structlog.stdlib.ProcessorFormatter",
                    "processors": [
                        *shared_processors,
                        MainConsoleRenderer(colors=False),
                    ],
                },
                # Formatter for JSON file output
                "json": {
                    "()": "structlog.stdlib.ProcessorFormatter",
                    "processors": [
                        *shared_processors,
                        structlog.processors.JSONRenderer(),
                    ],
                },
            },
            "handlers": {
                # Console verbosity is user-controlled via the env vars read
                # by _get_env_log_level(); files always capture DEBUG.
                "console": {
                    "class": "logging.StreamHandler",
                    "level": _get_env_log_level(),
                    "formatter": "colored",
                    "stream": sys.stderr,
                },
                # WatchedFileHandler re-opens the file if it is moved/removed,
                # so external rotation (e.g. logrotate) keeps working.
                "file": {
                    "class": "logging.handlers.WatchedFileHandler",
                    "level": "DEBUG",
                    "filename": log_file,
                    "formatter": "plain",
                },
                "json_file": {
                    "class": "logging.handlers.WatchedFileHandler",
                    "level": "DEBUG",
                    "filename": json_log_file,
                    "formatter": "json",
                },
            },
            "loggers": {
                # Root logger stays at DEBUG; per-sink filtering happens at
                # the handler level (console only shows >= env level).
                "": {
                    "handlers": ["console", "file", "json_file"],
                    "level": "DEBUG",
                    "propagate": True,
                },
            },
        }
    )

    # Hand every structlog event over to the stdlib formatters defined above.
    structlog.configure(
        processors=[
            structlog.stdlib.filter_by_level,
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    structlog.get_logger().debug("Logger configured", config=structlog.get_config())
345
+
346
+
347
# Configure logging on module import
# (deliberate import-time side effect: any module importing `logger` below
# gets fully configured handlers without an explicit setup call).
_configure_structlog()
349
+
350
+
351
def get_logger(name: str | None = None) -> Any:
    """Get a configured structlog logger.

    Args:
        name: Optional logger name; ``None`` yields the root logger.

    Returns:
        A structlog logger using the configuration applied at import time.
    """
    return structlog.get_logger(name)
354
+
355
+
356
# Export the root logger for convenience — per the module policy, other code
# imports this `logger` instead of calling structlog/logging directly.
logger = get_logger()
@@ -0,0 +1,295 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ import importlib
17
+ import pathlib
18
+ import sys
19
+ from importlib import resources
20
+ from typing import Any, Optional
21
+
22
+ import requests
23
+
24
+ if sys.version_info >= (3, 11):
25
+ import tomllib
26
+ else:
27
+ import tomli as tomllib
28
+
29
+ from nemo_evaluator_launcher.common.logging_utils import logger
30
+
31
# Configuration constants
# For below, see docs: https://docs.github.com/en/rest/repos/contents
# Raw-content URL of the canonical task mapping file on GitHub.
MAPPING_URL = "https://raw.githubusercontent.com/NVIDIA-NeMo/Eval/main/packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/resources/mapping.toml"
# Downloaded mappings are cached here between runs.
CACHE_DIR = pathlib.Path.home() / ".nemo-evaluator" / "cache"
CACHE_FILENAME = "mapping.toml"
# Optional internal-only package that may ship its own mapping.toml.
INTERNAL_RESOURCES_PKG = "nemo_evaluator_launcher_internal.resources"
37
+
38
+
39
def _ensure_cache_dir() -> None:
    """Ensure the cache directory exists (creates parents; no-op if present)."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
42
+
43
+
44
def _get_cache_file() -> pathlib.Path:
    """Get the cache file path.

    Returns:
        pathlib.Path: Path to the cache file (``CACHE_DIR / CACHE_FILENAME``).
    """
    return CACHE_DIR / CACHE_FILENAME
51
+
52
+
53
def _download_latest_mapping() -> Optional[bytes]:
    """Fetch the latest mapping file from MAPPING_URL.

    Returns:
        Optional[bytes]: Raw TOML bytes on success, or None when the
        download fails for any network/IO reason.
    """
    try:
        response = requests.get(MAPPING_URL, timeout=10)
        response.raise_for_status()
    except (requests.RequestException, OSError) as err:
        logger.warning("Failed to download mapping from remote URL", error=str(err))
        return None

    # GitHub raw URLs serve the file body directly as the response content.
    payload = response.content
    assert isinstance(payload, bytes)

    logger.debug("Successfully downloaded mapping from remote URL")
    return payload
72
+
73
+
74
def _load_cached_mapping() -> Optional[dict[Any, Any]]:
    """Parse the cached mapping TOML, if one exists.

    Returns:
        Optional[dict]: Parsed mapping data, or None when the cache file is
        absent, unreadable, or not valid TOML.
    """
    cache_file = _get_cache_file()
    if not cache_file.exists():
        return None

    try:
        with cache_file.open("rb") as fh:
            parsed = tomllib.load(fh)
    except (OSError, tomllib.TOMLDecodeError) as err:
        logger.warning("Failed to load mapping from cache", error=str(err))
        return None
    logger.debug("Loaded mapping from cache")
    return parsed  # type: ignore[no-any-return]
92
+
93
+
94
def _save_mapping_to_cache(mapping_bytes: bytes) -> None:
    """Persist raw mapping bytes to the cache file (best effort).

    Failures are logged, not raised — a stale or missing cache is tolerable.

    Args:
        mapping_bytes: Raw TOML content to store.
    """
    try:
        _ensure_cache_dir()
        _get_cache_file().write_bytes(mapping_bytes)
    except OSError as err:
        logger.warning("Failed to save mapping to cache", error=str(err))
110
+
111
+
112
def _load_packaged_resource(
    resource_name: str, pkg_name: str = "nemo_evaluator_launcher.resources"
) -> dict[str, Any]:
    """Load and parse a TOML resource bundled with a package.

    Args:
        resource_name: The name of the resource to load.
        pkg_name: Dotted package that carries the resource.

    Raises:
        RuntimeError: When the resource cannot be read or parsed.
    """
    resource = resources.files(pkg_name).joinpath(resource_name)
    try:
        with resource.open("rb") as fh:
            parsed: dict[str, Any] = tomllib.load(fh)
    except (OSError, tomllib.TOMLDecodeError) as err:
        logger.error(
            "Failed to load from packaged file",
            resource=resource_name,
            pkg=pkg_name,
            error=str(err),
        )
        raise RuntimeError(f"Failed to load {resource_name} from packaged file") from err
    logger.info(
        "Loaded resource from packaged file", resource=resource_name, pkg=pkg_name
    )
    return parsed
136
+
137
+
138
+ def _process_mapping(mapping_toml: dict) -> dict:
139
+ """Process the raw mapping TOML into the expected format.
140
+
141
+ Args:
142
+ mapping_toml: Raw mapping TOML data.
143
+ Returns:
144
+ dict: Processed mapping in the expected format.
145
+ """
146
+ mapping = {}
147
+ for harness_name, harness_data in mapping_toml.items():
148
+ assert isinstance(harness_data["tasks"], dict)
149
+ for endpoint_type, harness_tasks in harness_data["tasks"].items():
150
+ assert isinstance(harness_tasks, dict)
151
+ for task_name, task_data in harness_tasks.items():
152
+ assert isinstance(task_data, dict)
153
+ key = (harness_name, task_name)
154
+ if key in mapping:
155
+ raise KeyError(
156
+ f"(harness,task)-tuple key {repr(key)} already exists in the mapping"
157
+ )
158
+ mapping[key] = {
159
+ "task": task_name,
160
+ "harness": harness_name,
161
+ "container": harness_data["container"],
162
+ "endpoint_type": endpoint_type,
163
+ }
164
+ for task_data_key in task_data.keys():
165
+ if task_data_key in mapping[key]:
166
+ raise KeyError(
167
+ f"{repr(task_data_key)} is not allowed as key under {repr(key)} in the mapping"
168
+ )
169
+ mapping[key].update(task_data)
170
+ return mapping
171
+
172
+
173
def load_tasks_mapping(
    latest: bool = False,
    mapping_toml: pathlib.Path | str | None = None,
) -> dict[tuple[str, str], dict]:
    """Load tasks mapping.

    The function obeys the following priority rules:
    1. (Default) If latest==False and mapping_toml is None -> load packaged mapping.
    2. If latest==True -> fetch MAPPING_URL, save to cache, load it
       (mapping_toml is ignored in this case; on download failure the cached
       copy is used, and a RuntimeError is raised only if no cache exists).
    3. If mapping_toml is not None -> load mapping from this path.

    Args:
        latest: Fetch the newest mapping from the remote URL instead of the
            packaged one.
        mapping_toml: Optional path to an explicit mapping TOML file.

    Returns:
        dict: Mapping of (harness_name, task_name) to dict holding their configuration.

    Raises:
        RuntimeError: If latest==True, the download fails, and no usable
            cached copy exists.
    """
    local_mapping: dict = {}
    if latest:
        mapping_bytes = _download_latest_mapping()
        if mapping_bytes:
            _save_mapping_to_cache(mapping_bytes)
            local_mapping = _process_mapping(
                tomllib.loads(mapping_bytes.decode("utf-8"))
            )
        else:
            # Fallback to cached mapping; raise only if cache is missing/invalid
            cached = _load_cached_mapping()
            if cached:
                local_mapping = _process_mapping(cached)
            else:
                raise RuntimeError("could not download latest mapping")

    elif mapping_toml is not None:
        with open(mapping_toml, "rb") as f:
            local_mapping = _process_mapping(tomllib.load(f))
    else:
        # CACHE_FILENAME doubles as the packaged resource's file name.
        local_mapping = _process_mapping(_load_packaged_resource(CACHE_FILENAME))

    # TODO: make more elegant. We consider it ok to avoid a fully-blown plugin system.
    # Check if nemo_evaluator_launcher_internal is available and load its mapping.toml
    # CAVEAT: lazy-loading here, not somewhere top level, is important, to ensure
    # order of package initialization.
    try:
        importlib.import_module("nemo_evaluator_launcher_internal")
        logger.debug("Internal package available, loading internal mapping")
        internal_mapping = _process_mapping(
            _load_packaged_resource(CACHE_FILENAME, INTERNAL_RESOURCES_PKG)
        )

        # Merge internal mapping with local mapping (internal takes precedence)
        local_mapping.update(internal_mapping)
        logger.info(
            "Successfully merged internal mapping", internal_tasks=len(internal_mapping)
        )
    except ImportError:
        # Expected in the open-source install: internal package is optional.
        logger.debug("Internal package not available, using external mapping only")
    except Exception as e:
        # Best-effort merge: a broken internal mapping must not take down the
        # external one, so we log and continue.
        logger.warning("Failed to load internal mapping", error=str(e))

    return local_mapping
232
+
233
+
234
def get_task_from_mapping(query: str, mapping: dict[Any, Any]) -> dict[Any, Any]:
    """Unambiguously selects one task from the mapping based on the query.

    Args:
        query: Either `task_name` or `harness_name.task_name`.
        mapping: The object returned from `load_tasks_mapping` function.

    Returns:
        dict: Task data.

    Raises:
        ValueError: If the query matches nothing, matches tasks in several
            harnesses (bare task name), or contains more than one '.'.
    """
    num_dots = query.count(".")

    # if there are no dots in query, treat it like a task name
    if num_dots == 0:
        matching_keys = [key for key in mapping.keys() if key[1] == query]
        # if exactly one task matching the query has been found:
        if len(matching_keys) == 1:
            return mapping[matching_keys[0]]  # type: ignore[no-any-return]
        # no tasks have been found:
        if not matching_keys:
            raise ValueError(f"task {repr(query)} does not exist in the mapping")
        # more than one task matched: ask the caller to qualify the harness.
        matching_queries = [
            f"{harness_name}.{task_name}" for harness_name, task_name in matching_keys
        ]
        raise ValueError(
            f"there are multiple tasks named {repr(query)} in the mapping,"
            f" please select one of {repr(matching_queries)}"
        )

    # if there is one dot in query, treat it like "{harness_name}.{task_name}"
    if num_dots == 1:
        harness_name, task_name = query.split(".")
        # Exact tuple key: dict keys are unique, so a direct O(1) lookup
        # replaces the original linear scan (whose "multiple exact matches"
        # branch was unreachable).
        key = (harness_name, task_name)
        if key in mapping:
            return mapping[key]  # type: ignore[no-any-return]
        raise ValueError(
            f"harness.task {repr(query)} does not exist in the mapping"
        )

    # invalid query
    raise ValueError(
        f"invalid query={repr(query)} for task mapping,"
        " it must contain exactly zero or one occurrence of '.' character"
    )