nemo_evaluator_launcher-0.1.28-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nemo-evaluator-launcher might be problematic.

Files changed (60)
  1. nemo_evaluator_launcher/__init__.py +79 -0
  2. nemo_evaluator_launcher/api/__init__.py +24 -0
  3. nemo_evaluator_launcher/api/functional.py +698 -0
  4. nemo_evaluator_launcher/api/types.py +98 -0
  5. nemo_evaluator_launcher/api/utils.py +19 -0
  6. nemo_evaluator_launcher/cli/__init__.py +15 -0
  7. nemo_evaluator_launcher/cli/export.py +267 -0
  8. nemo_evaluator_launcher/cli/info.py +512 -0
  9. nemo_evaluator_launcher/cli/kill.py +41 -0
  10. nemo_evaluator_launcher/cli/ls_runs.py +134 -0
  11. nemo_evaluator_launcher/cli/ls_tasks.py +136 -0
  12. nemo_evaluator_launcher/cli/main.py +226 -0
  13. nemo_evaluator_launcher/cli/run.py +200 -0
  14. nemo_evaluator_launcher/cli/status.py +164 -0
  15. nemo_evaluator_launcher/cli/version.py +55 -0
  16. nemo_evaluator_launcher/common/__init__.py +16 -0
  17. nemo_evaluator_launcher/common/execdb.py +283 -0
  18. nemo_evaluator_launcher/common/helpers.py +366 -0
  19. nemo_evaluator_launcher/common/logging_utils.py +357 -0
  20. nemo_evaluator_launcher/common/mapping.py +295 -0
  21. nemo_evaluator_launcher/common/printing_utils.py +93 -0
  22. nemo_evaluator_launcher/configs/__init__.py +15 -0
  23. nemo_evaluator_launcher/configs/default.yaml +28 -0
  24. nemo_evaluator_launcher/configs/deployment/generic.yaml +33 -0
  25. nemo_evaluator_launcher/configs/deployment/nim.yaml +32 -0
  26. nemo_evaluator_launcher/configs/deployment/none.yaml +16 -0
  27. nemo_evaluator_launcher/configs/deployment/sglang.yaml +38 -0
  28. nemo_evaluator_launcher/configs/deployment/trtllm.yaml +24 -0
  29. nemo_evaluator_launcher/configs/deployment/vllm.yaml +42 -0
  30. nemo_evaluator_launcher/configs/execution/lepton/default.yaml +92 -0
  31. nemo_evaluator_launcher/configs/execution/local.yaml +19 -0
  32. nemo_evaluator_launcher/configs/execution/slurm/default.yaml +34 -0
  33. nemo_evaluator_launcher/executors/__init__.py +22 -0
  34. nemo_evaluator_launcher/executors/base.py +120 -0
  35. nemo_evaluator_launcher/executors/lepton/__init__.py +16 -0
  36. nemo_evaluator_launcher/executors/lepton/deployment_helpers.py +609 -0
  37. nemo_evaluator_launcher/executors/lepton/executor.py +1004 -0
  38. nemo_evaluator_launcher/executors/lepton/job_helpers.py +398 -0
  39. nemo_evaluator_launcher/executors/local/__init__.py +15 -0
  40. nemo_evaluator_launcher/executors/local/executor.py +605 -0
  41. nemo_evaluator_launcher/executors/local/run.template.sh +103 -0
  42. nemo_evaluator_launcher/executors/registry.py +38 -0
  43. nemo_evaluator_launcher/executors/slurm/__init__.py +15 -0
  44. nemo_evaluator_launcher/executors/slurm/executor.py +1147 -0
  45. nemo_evaluator_launcher/exporters/__init__.py +36 -0
  46. nemo_evaluator_launcher/exporters/base.py +121 -0
  47. nemo_evaluator_launcher/exporters/gsheets.py +409 -0
  48. nemo_evaluator_launcher/exporters/local.py +502 -0
  49. nemo_evaluator_launcher/exporters/mlflow.py +619 -0
  50. nemo_evaluator_launcher/exporters/registry.py +40 -0
  51. nemo_evaluator_launcher/exporters/utils.py +624 -0
  52. nemo_evaluator_launcher/exporters/wandb.py +490 -0
  53. nemo_evaluator_launcher/package_info.py +38 -0
  54. nemo_evaluator_launcher/resources/mapping.toml +380 -0
  55. nemo_evaluator_launcher-0.1.28.dist-info/METADATA +494 -0
  56. nemo_evaluator_launcher-0.1.28.dist-info/RECORD +60 -0
  57. nemo_evaluator_launcher-0.1.28.dist-info/WHEEL +5 -0
  58. nemo_evaluator_launcher-0.1.28.dist-info/entry_points.txt +3 -0
  59. nemo_evaluator_launcher-0.1.28.dist-info/licenses/LICENSE +451 -0
  60. nemo_evaluator_launcher-0.1.28.dist-info/top_level.txt +1 -0
nemo_evaluator_launcher/executors/local/executor.py (new file)
@@ -0,0 +1,605 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Local executor implementation for nemo-evaluator-launcher.

Handles running evaluation jobs locally using shell scripts and Docker containers.
"""

import copy
import os
import pathlib
import platform
import shlex
import shutil
import subprocess
import time
from typing import List, Optional

import jinja2
import yaml
from omegaconf import DictConfig, OmegaConf

from nemo_evaluator_launcher.common.execdb import (
    ExecutionDB,
    JobData,
    generate_invocation_id,
    generate_job_id,
)
from nemo_evaluator_launcher.common.helpers import (
    get_eval_factory_command,
    get_eval_factory_dataset_size_from_run_config,
    get_timestamp_string,
)
from nemo_evaluator_launcher.common.logging_utils import logger
from nemo_evaluator_launcher.common.mapping import (
    get_task_from_mapping,
    load_tasks_mapping,
)
from nemo_evaluator_launcher.common.printing_utils import bold, cyan, grey, red
from nemo_evaluator_launcher.executors.base import (
    BaseExecutor,
    ExecutionState,
    ExecutionStatus,
)
from nemo_evaluator_launcher.executors.registry import register_executor


@register_executor("local")
class LocalExecutor(BaseExecutor):
    @classmethod
    def execute_eval(cls, cfg: DictConfig, dry_run: bool = False) -> str:
        """Run evaluation jobs locally using the provided configuration.

        Args:
            cfg: The configuration object for the evaluation run.
            dry_run: If True, prepare scripts and save them without execution.

        Returns:
            str: The invocation ID for the evaluation run.

        Raises:
            NotImplementedError: If deployment is not 'none'.
            RuntimeError: If the run script fails.
        """
        if cfg.deployment.type != "none":
            raise NotImplementedError(
                f"type {cfg.deployment.type} is not implemented -- add deployment support"
            )

        # Check if docker is available (skip in dry_run mode)
        if not dry_run and shutil.which("docker") is None:
            raise RuntimeError(
                "Docker is not installed or not in PATH. "
                "Please install Docker to run local evaluations."
            )

        # Generate invocation ID for this evaluation run
        invocation_id = generate_invocation_id()

        output_dir = pathlib.Path(cfg.execution.output_dir).absolute() / (
            get_timestamp_string(include_microseconds=False) + "-" + invocation_id
        )
        output_dir.mkdir(parents=True, exist_ok=True)

        tasks_mapping = load_tasks_mapping()
        evaluation_tasks = []
        job_ids = []

        eval_template = jinja2.Template(
            open(pathlib.Path(__file__).parent / "run.template.sh", "r").read()
        )

        execution_mode = cfg.execution.get("mode", "parallel")
        if execution_mode == "parallel":
            is_execution_mode_sequential = False
        elif execution_mode == "sequential":
            is_execution_mode_sequential = True
        else:
            raise ValueError(
                "unknown execution mode: {}. Choose one of {}".format(
                    repr(execution_mode), ["parallel", "sequential"]
                )
            )

        # Will accumulate if any task contains unsafe commands.
        is_potentially_unsafe = False
        for idx, task in enumerate(cfg.evaluation.tasks):
            task_definition = get_task_from_mapping(task.name, tasks_mapping)

            # Create job ID as <invocation_id>.<n>
            job_id = generate_job_id(invocation_id, idx)
            job_ids.append(job_id)
            container_name = f"{task.name}-{get_timestamp_string()}"

            # collect all env vars
            env_vars = copy.deepcopy(dict(cfg.evaluation.get("env_vars", {})))
            env_vars.update(task.get("env_vars", {}))
            if cfg.target.api_endpoint.api_key_name:
                assert "API_KEY" not in env_vars
                env_vars["API_KEY"] = cfg.target.api_endpoint.api_key_name

            # check if the environment variables are set
            for env_var in env_vars.values():
                if os.getenv(env_var) is None:
                    raise ValueError(
                        f"Trying to pass an unset environment variable {env_var}."
                    )

            # check if required env vars are defined:
            for required_env_var in task_definition.get("required_env_vars", []):
                if required_env_var not in env_vars.keys():
                    raise ValueError(
                        f"{task.name} task requires environment variable {required_env_var}."
                        " Specify it in the task subconfig in the 'env_vars' dict as the following"
                        f" pair {required_env_var}: YOUR_ENV_VAR_NAME"
                    )

            # format env_vars for a template
            env_vars = [
                f"{env_var_dst}=${env_var_src}"
                for env_var_dst, env_var_src in env_vars.items()
            ]

            eval_image = task_definition["container"]
            if "container" in task:
                eval_image = task["container"]

            task_output_dir = output_dir / task.name
            task_output_dir.mkdir(parents=True, exist_ok=True)
            eval_factory_command_struct = get_eval_factory_command(
                cfg, task, task_definition
            )
            eval_factory_command = eval_factory_command_struct.cmd
            # The debug comment for placing into the script and easy debug. Reason
            # (see `CmdAndReadableComment`) is the current way of passing the command
            # is base64-encoded config `echo`-ed into file.
            # TODO(agronskiy): cleaner way is to encode everything with base64, not
            # some parts (like ef_config.yaml) and just output as logs somewhere.
            eval_factory_command_debug_comment = eval_factory_command_struct.debug
            is_potentially_unsafe = (
                is_potentially_unsafe
                or eval_factory_command_struct.is_potentially_unsafe
            )
            evaluation_task = {
                "name": task.name,
                "job_id": job_id,
                "eval_image": eval_image,
                "container_name": container_name,
                "env_vars": env_vars,
                "output_dir": task_output_dir,
                "eval_factory_command": eval_factory_command,
                "eval_factory_command_debug_comment": eval_factory_command_debug_comment,
            }
            evaluation_tasks.append(evaluation_task)

            # Check if auto-export is enabled by presence of destination(s)
            auto_export_config = cfg.execution.get("auto_export", {})
            auto_export_destinations = auto_export_config.get("destinations", [])

            extra_docker_args = cfg.execution.get("extra_docker_args", "")

            run_sh_content = (
                eval_template.render(
                    evaluation_tasks=[evaluation_task],
                    auto_export_destinations=auto_export_destinations,
                    extra_docker_args=extra_docker_args,
                ).rstrip("\n")
                + "\n"
            )

            (task_output_dir / "run.sh").write_text(run_sh_content)

        run_all_sequentially_sh_content = (
            eval_template.render(
                evaluation_tasks=evaluation_tasks,
                auto_export_destinations=auto_export_destinations,
                extra_docker_args=extra_docker_args,
            ).rstrip("\n")
            + "\n"
        )
        (output_dir / "run_all.sequential.sh").write_text(
            run_all_sequentially_sh_content
        )

        if dry_run:
            print(bold("\n\n=============================================\n\n"))
            print(bold(cyan(f"DRY RUN: Scripts prepared and saved to {output_dir}")))
            if is_execution_mode_sequential:
                print(
                    cyan(
                        "\n\n=========== Main script | run_all.sequential.sh =====================\n\n"
                    )
                )

                with open(output_dir / "run_all.sequential.sh", "r") as f:
                    print(grey(f.read()))
            else:
                for idx, task in enumerate(cfg.evaluation.tasks):
                    task_output_dir = output_dir / task.name
                    print(
                        cyan(
                            f"\n\n=========== Task script | {task.name}/run.sh =====================\n\n"
                        )
                    )
                    with open(task_output_dir / "run.sh", "r") as f:
                        print(grey(f.read()))
            print(bold("\nTo execute, run without --dry-run"))

            if is_potentially_unsafe:
                print(
                    red(
                        "\nFound `pre_cmd` which carries security risk. When running without --dry-run "
                        "make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1"
                    )
                )
            return invocation_id

        if is_potentially_unsafe:
            if os.environ.get("NEMO_EVALUATOR_TRUST_PRE_CMD", "") == "1":
                logger.warning(
                    "Found non-empty task commands (e.g. `pre_cmd`) and NEMO_EVALUATOR_TRUST_PRE_CMD "
                    "is set, proceeding with caution."
                )

            else:
                logger.error(
                    "Found non-empty task commands (e.g. `pre_cmd`) and NEMO_EVALUATOR_TRUST_PRE_CMD "
                    "is not set. This might carry security risk and unstable environments. "
                    "To continue, make sure you trust the command and set NEMO_EVALUATOR_TRUST_PRE_CMD=1.",
                )
                raise AttributeError(
                    "Untrusted command found in config, make sure you trust and "
                    "set NEMO_EVALUATOR_TRUST_PRE_CMD=1."
                )

        # Save launched jobs metadata
        db = ExecutionDB()
        for job_id, task, evaluation_task in zip(
            job_ids, cfg.evaluation.tasks, evaluation_tasks
        ):
            db.write_job(
                job=JobData(
                    invocation_id=invocation_id,
                    job_id=job_id,
                    timestamp=time.time(),
                    executor="local",
                    data={
                        "output_dir": str(evaluation_task["output_dir"]),
                        "container": evaluation_task["container_name"],
                        "eval_image": evaluation_task["eval_image"],
                    },
                    config=OmegaConf.to_object(cfg),
                )
            )

        # Launch bash scripts with Popen for non-blocking execution.
        # To ensure subprocess continues after python exits:
        # - on Unix-like systems, to fully detach the subprocess
        #   so it does not die when Python exits, pass start_new_session=True;
        # - on Windows use creationflags=subprocess.CREATE_NEW_PROCESS_GROUP flag.
        os_name = platform.system()
        processes = []

        if is_execution_mode_sequential:
            if os_name == "Windows":
                proc = subprocess.Popen(
                    shlex.split("bash run_all.sequential.sh"),
                    cwd=output_dir,
                    creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
                )
            else:
                proc = subprocess.Popen(
                    shlex.split("bash run_all.sequential.sh"),
                    cwd=output_dir,
                    start_new_session=True,
                )
            processes.append(("run_all.sequential.sh", proc, output_dir))
        else:
            for task in cfg.evaluation.tasks:
                if os_name == "Windows":
                    proc = subprocess.Popen(
                        shlex.split("bash run.sh"),
                        cwd=output_dir / task.name,
                        creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
                    )
                else:
                    proc = subprocess.Popen(
                        shlex.split("bash run.sh"),
                        cwd=output_dir / task.name,
                        start_new_session=True,
                    )
                processes.append((task.name, proc, output_dir / task.name))

        # Wait briefly and check if bash scripts exited immediately (which means error)
        time.sleep(0.3)

        for name, proc, work_dir in processes:
            exit_code = proc.poll()
            if exit_code is not None and exit_code != 0:
                error_msg = f"Script for {name} exited with code {exit_code}"
                raise RuntimeError(f"Job startup failed | {error_msg}")

        print(bold(cyan("\nCommands for real-time monitoring:")))
        for job_id, evaluation_task in zip(job_ids, evaluation_tasks):
            log_file = evaluation_task["output_dir"] / "logs" / "stdout.log"
            print(f"  tail -f {log_file}")

        print(bold(cyan("\nFollow all logs for this invocation:")))
        print(f"  tail -f {output_dir}/*/logs/stdout.log\n")

        return invocation_id

    @staticmethod
    def get_status(id: str) -> List[ExecutionStatus]:
        """Get the status of a specific job or all jobs in an invocation group.

        Args:
            id: Unique job identifier or invocation identifier.

        Returns:
            List containing the execution status for the job(s).
        """
        db = ExecutionDB()

        # If id looks like an invocation_id (no dot), get all jobs for it
        if "." not in id:
            jobs = db.get_jobs(id)
            statuses: List[ExecutionStatus] = []
            for job_id, _ in jobs.items():
                statuses.extend(LocalExecutor.get_status(job_id))
            return statuses

        # Otherwise, treat as job_id
        job_data = db.get_job(id)
        if job_data is None:
            return []
        if job_data.executor != "local":
            return []

        output_dir = pathlib.Path(job_data.data.get("output_dir", ""))
        if not output_dir.exists():
            return [ExecutionStatus(id=id, state=ExecutionState.PENDING)]

        artifacts_dir = output_dir / "artifacts"
        progress = _get_progress(artifacts_dir)

        logs_dir = output_dir / "logs"
        if not logs_dir.exists():
            return [
                ExecutionStatus(
                    id=id,
                    state=ExecutionState.PENDING,
                    progress=dict(progress=progress),
                )
            ]

        # Check if job was killed
        if job_data.data.get("killed", False):
            return [
                ExecutionStatus(
                    id=id, state=ExecutionState.KILLED, progress=dict(progress=progress)
                )
            ]

        stage_files = {
            "pre_start": logs_dir / "stage.pre-start",
            "running": logs_dir / "stage.running",
            "exit": logs_dir / "stage.exit",
        }

        if stage_files["exit"].exists():
            try:
                content = stage_files["exit"].read_text().strip()
                if " " in content:
                    timestamp, exit_code_str = content.rsplit(" ", 1)
                    exit_code = int(exit_code_str)
                    if exit_code == 0:
                        return [
                            ExecutionStatus(
                                id=id,
                                state=ExecutionState.SUCCESS,
                                progress=dict(progress=progress),
                            )
                        ]
                    else:
                        return [
                            ExecutionStatus(
                                id=id,
                                state=ExecutionState.FAILED,
                                progress=dict(progress=progress),
                            )
                        ]
                else:
                    return [
                        ExecutionStatus(
                            id=id,
                            state=ExecutionState.FAILED,
                            progress=dict(progress=progress),
                        )
                    ]
            except (ValueError, OSError):
                return [
                    ExecutionStatus(
                        id=id,
                        state=ExecutionState.FAILED,
                        progress=dict(progress=progress),
                    )
                ]
        elif stage_files["running"].exists():
            return [
                ExecutionStatus(
                    id=id,
                    state=ExecutionState.RUNNING,
                    progress=dict(progress=progress),
                )
            ]
        elif stage_files["pre_start"].exists():
            return [
                ExecutionStatus(
                    id=id,
                    state=ExecutionState.PENDING,
                    progress=dict(progress=progress),
                )
            ]

        return [
            ExecutionStatus(
                id=id, state=ExecutionState.PENDING, progress=dict(progress=progress)
            )
        ]

    @staticmethod
    def kill_job(job_id: str) -> None:
        """Kill a local job.

        Args:
            job_id: The job ID (e.g., abc123.0) to kill.

        Raises:
            ValueError: If job is not found or invalid.
            RuntimeError: If Docker container cannot be stopped.
        """
        db = ExecutionDB()
        job_data = db.get_job(job_id)

        if job_data is None:
            raise ValueError(f"Job {job_id} not found")

        if job_data.executor != "local":
            raise ValueError(
                f"Job {job_id} is not a local job (executor: {job_data.executor})"
            )

        # Get container name from database
        container_name = job_data.data.get("container")
        if not container_name:
            raise ValueError(f"No container name found for job {job_id}")

        killed_something = False

        # First, try to stop the Docker container if it's running
        result = subprocess.run(
            shlex.split(f"docker stop {container_name}"),
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0:
            killed_something = True
        # Don't raise error if container doesn't exist (might be still pulling)

        # Find and kill Docker processes for this container
        result = subprocess.run(
            shlex.split(f"pkill -f 'docker run.*{container_name}'"),
            capture_output=True,
            text=True,
            timeout=10,
        )
        if result.returncode == 0:
            killed_something = True

        # If we successfully killed something, mark as killed
        if killed_something:
            job_data.data["killed"] = True
            db.write_job(job_data)
            LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
            return

        # If nothing was killed, check if this is a pending job
        status_list = LocalExecutor.get_status(job_id)
        if status_list and status_list[0].state == ExecutionState.PENDING:
            # For pending jobs, mark as killed even though there's nothing to kill yet
            job_data.data["killed"] = True
            db.write_job(job_data)
            LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
            return

        # Use common helper to get informative error message based on job status
        current_status = status_list[0].state if status_list else None
        error_msg = LocalExecutor.get_kill_failure_message(
            job_id, f"container: {container_name}", current_status
        )
        raise RuntimeError(error_msg)

    @staticmethod
    def _add_to_killed_jobs(invocation_id: str, job_id: str) -> None:
        """Add a job ID to the killed jobs file for this invocation.

        Args:
            invocation_id: The invocation ID.
            job_id: The job ID to mark as killed.
        """
        db = ExecutionDB()
        jobs = db.get_jobs(invocation_id)
        if not jobs:
            return

        # Get invocation output directory from any job's output_dir
        first_job_data = next(iter(jobs.values()))
        job_output_dir = pathlib.Path(first_job_data.data.get("output_dir", ""))
        if not job_output_dir.exists():
            return

        # Invocation dir is parent of job output dir
        invocation_dir = job_output_dir.parent
        killed_jobs_file = invocation_dir / "killed_jobs.txt"

        # Append job_id to file
        with open(killed_jobs_file, "a") as f:
            f.write(f"{job_id}\n")


def _get_progress(artifacts_dir: pathlib.Path) -> Optional[float]:
    """Get the progress of a local job.

    Args:
        artifacts_dir: The directory containing the evaluation artifacts.

    Returns:
        The progress of the job as a float between 0 and 1.
    """
    progress_filepath = artifacts_dir / "progress"
    if not progress_filepath.exists():
        return None
    progress_str = progress_filepath.read_text().strip()
    try:
        processed_samples = int(progress_str)
    except ValueError:
        return None

    dataset_size = _get_dataset_size(artifacts_dir)
    if dataset_size is not None:
        progress = processed_samples / dataset_size
    else:
        # NOTE(dfridman): if we don't know the dataset size, report the number of processed samples
        progress = processed_samples
    return progress


def _get_dataset_size(artifacts_dir: pathlib.Path) -> Optional[int]:
    """Get the dataset size for a benchmark.

    Args:
        artifacts_dir: The directory containing the evaluation artifacts.

    Returns:
        The dataset size for the benchmark.
    """
    run_config = artifacts_dir / "run_config.yml"
    if not run_config.exists():
        return None
    run_config = yaml.safe_load(run_config.read_text())
    return get_eval_factory_dataset_size_from_run_config(run_config)
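
Note on the stage-file convention (illustrative, not part of the packaged file above): get_status derives job state from marker files under the job's logs/ directory. Those markers are written by run.template.sh, which is not reproduced in this diff, so the snippet below is only a minimal sketch of the convention the parser implies — stage.pre-start, then stage.running, and finally stage.exit containing "<timestamp> <exit_code>". The helper name is hypothetical.

import pathlib
import time

def _write_stage_markers(job_output_dir: pathlib.Path, exit_code: int) -> None:
    # Hypothetical helper mirroring what the (unshown) run.template.sh presumably does.
    logs_dir = job_output_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    (logs_dir / "stage.pre-start").write_text(f"{time.time()}\n")  # before the container starts
    (logs_dir / "stage.running").write_text(f"{time.time()}\n")    # while the evaluation runs
    # get_status() splits this on the last space into (timestamp, exit_code):
    (logs_dir / "stage.exit").write_text(f"{time.time()} {exit_code}\n")

With such markers present, LocalExecutor.get_status(job_id) reports RUNNING while only stage.running exists, SUCCESS once stage.exit records exit code 0, and FAILED for a non-zero or malformed exit record.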