snakemake-executor-plugin-vastai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,846 @@
1
+ __author__ = "Michał Pogoda"
2
+ __copyright__ = "Copyright 2026, bards.ai"
3
+ __email__ = "michal.pogoda@bards.ai"
4
+ __license__ = "MIT"
5
+
6
+ import base64
7
+ import importlib.metadata
8
+ import os
9
+ import re
10
+ import shutil
11
+ import subprocess
12
+ import time
13
+ import uuid
14
+ from dataclasses import dataclass, field, replace
15
+ from pathlib import Path
16
+ from typing import AsyncGenerator, List, Mapping, Optional
17
+
18
+ from snakemake_interface_common.exceptions import WorkflowError
19
+ from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo
20
+ from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor
21
+ from snakemake_interface_executor_plugins.jobs import JobExecutorInterface
22
+ from snakemake_interface_executor_plugins.settings import (
23
+ CommonSettings,
24
+ ExecutorSettingsBase,
25
+ )
26
+
27
+ from snakemake_executor_plugin_vastai._common import (
28
+ PYTHON_PATH_SETUP,
29
+ fatal_boot_error,
30
+ )
31
+ from snakemake_executor_plugin_vastai.sshtransfer import SSHJobRunner
32
+
33
# Printed by the job wrapper script as the very last line of container output.
# _finalize_job() (reached from check_active_jobs()) searches the instance
# logs for it and uses the last match as the job's exit code.
EXIT_CODE_PATTERN = re.compile(r"snakemake_vastai_exit_code=(\d+)")

# Number of top-ranked matching offers (per the configured sort order, which
# is cheapest-first by default) to attempt when creating an instance.
# Offers are single-use and can be taken by other users between search and
# create, so we fall through to the next candidate on failure.
OFFER_ATTEMPTS = 5

# Instances reporting stopped/offline/unknown may recover (host heartbeat
# loss is often transient); only give up after this many seconds.
UNREACHABLE_GRACE_SECONDS = 300

# Tolerated consecutive failures of the status API before declaring the
# instance lost. The counter is reset by every successful status query.
MAX_STATUS_API_MISSES = 5


# Local environment variables that are forwarded to job containers when
# forward_credentials is enabled, so that the default storage provider (e.g.
# S3, GCS, Azure, or any S3-compatible service) works remotely without any
# --envvars ceremony. Only variables that are set to a non-empty value are
# forwarded (see credential_envvars()).
CREDENTIAL_ENVVARS = (
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "AWS_SESSION_TOKEN",
    "AWS_DEFAULT_REGION",
    "AWS_REGION",
    "AWS_ENDPOINT_URL",
    "GOOGLE_CLOUD_PROJECT",
    "AZURE_STORAGE_CONNECTION_STRING",
    "AZURE_STORAGE_ACCOUNT",
    "AZURE_STORAGE_KEY",
    "AZURE_STORAGE_SAS_TOKEN",
)

# Google credentials are a *file* (GOOGLE_APPLICATION_CREDENTIALS points to
# it), so the content is shipped base64-encoded in this variable and
# materialized inside the container by the credential setup script.
GCP_CREDENTIALS_CONTENT_VAR = "SNAKEMAKE_VASTAI_GCP_CREDENTIALS_B64"

# Location of gcloud's application default credentials; used as a fallback
# when GOOGLE_APPLICATION_CREDENTIALS is not set in the local environment.
GCP_ADC_DEFAULT_PATH = "~/.config/gcloud/application_default_credentials.json"

# Region shortcuts for the geolocation setting, matching the vastai CLI's
# georegion expansion. Keys are looked up upper-cased; values are
# comma-separated country codes inserted verbatim into the offer query.
# NOTE(review): the lists are kept as-is to stay in sync with the CLI, but
# several entries look suspicious and should be confirmed upstream: "AS"
# contains KZ/LV/LI (also listed under "EU"), "AF" contains YE, "OC" contains
# SL, and "LC" contains RD, SUR, UR and VZ, which do not appear to be ISO
# 3166-1 alpha-2 codes.
GEO_REGIONS = {
    "EU": (
        "AL,AD,AT,BY,BE,BA,BG,HR,CY,CZ,DK,EE,FI,FR,GE,DE,GR,HU,IS,IT,KZ,LV,"
        "LI,LT,LU,MT,MD,MC,ME,NL,NO,PL,PT,RO,RU,RS,SK,SI,ES,SE,CH,UA,GB,VA,MK"
    ),
    "NA": "CA,US",
    "AS": (
        "AE,AM,AR,AU,AZ,BD,BH,BN,BT,MM,KH,KP,IN,ID,IR,IQ,IL,JP,JO,KZ,LV,LI,"
        "MY,MV,MN,NP,KR,PK,PH,QA,SA,SG,LK,SY,TW,TJ,TH,TR,TM,VN,YE,HK,CN,OM"
    ),
    "AF": (
        "DZ,AO,BJ,BW,BF,BI,CM,CV,CF,TD,KM,CG,CD,DJ,EG,GQ,ER,ET,GA,GM,GH,GN,"
        "GW,KE,LS,LR,LY,MW,MA,ML,MR,MU,MZ,NA,NE,NG,RW,SH,ST,SN,SC,SL,SO,ZA,"
        "SS,SD,SZ,TZ,TG,TN,UG,YE,ZM,ZW"
    ),
    "LC": (
        "AG,AR,BS,BB,BZ,BO,BR,CL,CO,CR,CU,DO,EC,SV,GY,HT,HN,JM,MX,NI,PA,PY,"
        "PE,PR,RD,SUR,TT,UR,VZ"
    ),
    "OC": "AU,FJ,GU,KI,MH,FM,NR,NZ,PG,PW,SL,TO,TV,VU",
}

# Vast.ai's curated PyTorch image: CUDA matched to the machine's driver via
# the server-side @vastai-automatic-tag, OpenSSH preinstalled (required for
# SSH transfer mode), python on PATH. Snakemake itself is pip-installed by
# the job bootstrap below.
DEFAULT_IMAGE = "vastai/pytorch:@vastai-automatic-tag"

# Shell snippet run inside the container before the job command. If the
# base64 credential variable is set and non-empty, it is decoded into
# /tmp/gcp-credentials.json, GOOGLE_APPLICATION_CREDENTIALS is pointed at
# that file, and the variable holding the encoded content is unset again.
CREDENTIAL_SETUP_SCRIPT = (
    f'if [ -n "${{{GCP_CREDENTIALS_CONTENT_VAR}:-}}" ]; then\n'
    f' echo "${{{GCP_CREDENTIALS_CONTENT_VAR}}}" | base64 -d '
    "> /tmp/gcp-credentials.json\n"
    " export GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcp-credentials.json\n"
    f" unset {GCP_CREDENTIALS_CONTENT_VAR}\n"
    "fi\n"
)
114
+
115
+
116
@dataclass
class ExecutorSettings(ExecutorSettingsBase):
    """Plugin-specific settings of the Vast.ai executor.

    Every field carries its user-facing description in ``metadata["help"]``.
    The messages emitted by this module refer to the settings as
    ``--vastai-<name>`` command line options (e.g. ``--vastai-max-price``).
    """

    # Credentials for the Vast.ai API; handed to the VastAI client as-is.
    api_key: Optional[str] = field(
        default=None,
        metadata={
            "help": "Vast.ai API key. If not set, it is resolved like the vastai "
            "CLI does: from the VAST_API_KEY environment variable or from "
            "~/.config/vastai/vast_api_key.",
            "env_var": True,
            "required": False,
        },
    )
    # Fallback GPU model; the per-job gpu_model resource wins when present.
    gpu_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Default GPU model to request, in Vast.ai naming with "
            "underscores for spaces (e.g. RTX_4090, H100_SXM, A100_SXM4). "
            "Can be overridden per job with the gpu_model resource. If unset, "
            "any GPU matching the other constraints is used.",
        },
    )
    # Highest-priority source for the job image (see resolve_container_image).
    image: Optional[str] = field(
        default=None,
        metadata={
            "help": "Docker image to run jobs in. Defaults to "
            f"{DEFAULT_IMAGE} (CUDA+PyTorch matched to the machine, OpenSSH "
            "included); an explicitly set --container-image takes "
            "precedence over that default. Snakemake is pip-installed into "
            "the container automatically if the image does not provide it; "
            "bake `pip install snakemake` into your own image to skip that "
            "startup cost.",
        },
    )
    # Minimum disk size in GB; the effective value is the maximum of this and
    # the job's disk_mb resource (see required_disk_gb).
    disk: float = field(
        default=40.0,
        metadata={
            "help": "Disk space to allocate per instance in GB. Can be raised "
            "per job with the disk_mb resource.",
        },
    )
    # Translated into a "dph_total<=..." filter when set.
    max_price: Optional[float] = field(
        default=None,
        metadata={
            "help": "Maximum on-demand price per instance in $/h (dph_total, "
            "including storage). Offers above this price are never rented.",
        },
    )
    # Translated into a "reliability>..." filter when greater than zero.
    reliability: float = field(
        default=0.98,
        metadata={
            "help": "Minimum host reliability score (0-1) required for offers. "
            "Set to 0 to disable the filter.",
        },
    )
    # Passed unchanged as the order argument of the offer search.
    order: str = field(
        default="dph_total",
        metadata={
            "help": "Sort order for choosing among matching offers (vastai "
            "search order syntax; append '-' for descending). The default "
            "rents the cheapest matching offer. Use e.g. 'dlperf_usd-' for "
            "best DL performance per dollar.",
        },
    )
    # Region shortcuts are expanded through GEO_REGIONS.
    geolocation: Optional[str] = field(
        default=None,
        metadata={
            "help": "Restrict offers to a geographic area: a region "
            "shortcut (EU, NA, AS, AF, LC, OC) or comma-separated ISO "
            "country codes (e.g. 'PL,DE,CZ').",
        },
    )
    # When False (the default) a "datacenter=true" filter is added.
    no_datacenter: bool = field(
        default=False,
        metadata={
            "help": "Also rent from non-datacenter (hobbyist) hosts. By "
            "default only verified datacenter hosts are used — they are "
            "slightly pricier but avoid the most common marketplace "
            "flakiness (Docker Hub pull rate limits, slow residential "
            "uplinks). Disabling the filter gives access to the cheapest "
            "offers.",
        },
    )
    # Appended verbatim to every offer query.
    search_query: Optional[str] = field(
        default=None,
        metadata={
            "help": "Additional filters appended to every offer search, in "
            "vastai query syntax (e.g. 'cuda_vers>=12.4 geolocation=EU "
            "inet_down>=500'). Can be extended per job with the vastai_query "
            "resource.",
        },
    )
    # Only enforced while the instance has never been seen running.
    boot_timeout: int = field(
        default=1800,
        metadata={
            "help": "Maximum seconds to wait for an instance to reach the "
            "running state (includes docker image pull) before the job is "
            "failed and the instance destroyed.",
        },
    )
    # Opt-out switch for forwarding the variables in CREDENTIAL_ENVVARS and
    # the Google application credentials file content.
    no_forward_credentials: bool = field(
        default=False,
        metadata={
            "help": "Do not forward cloud storage credentials (AWS_*, "
            "AZURE_STORAGE_*, GOOGLE_* including the application "
            "credentials file content) from the local environment into the "
            "job containers. By default they are forwarded so that the "
            "default storage provider works remotely without declaring "
            "--envvars. Credentials configured via the storage plugin's own "
            "settings (e.g. --storage-s3-access-key or "
            "SNAKEMAKE_STORAGE_S3_ACCESS_KEY) are always forwarded by "
            "Snakemake itself, independent of this option.",
        },
    )
    # Debugging aid: makes Executor._destroy_instance a no-op.
    keep_instances: bool = field(
        default=False,
        metadata={
            "help": "Do not destroy instances after job completion or failure. "
            "Useful for debugging, but instances keep accruing charges until "
            "destroyed manually!",
        },
    )
237
+
238
+
239
# Module-level executor capabilities announced to Snakemake. Note that
# Executor.common_settings returns a modified copy of this object when a
# default storage provider is configured.
common_settings = CommonSettings(
    non_local_exec=True,
    implies_no_shared_fs=True,
    job_deploy_sources=True,
    pass_default_storage_provider_args=True,
    pass_default_resources_args=True,
    # Envvars are injected through the container environment instead of being
    # inlined into the command (see run_job).
    pass_envvar_declarations_to_cmd=False,
    # Lifts Snakemake's requirement for a default storage provider: with a
    # storage provider the executor restores the storage-based behavior
    # itself (see Executor.common_settings), without one it transfers files
    # over SSH (see sshtransfer.py).
    can_transfer_local_files=True,
    auto_deploy_default_storage_provider=False,
    # Instances need to be scheduled and pull the image first; no point in
    # checking immediately.
    init_seconds_before_status_checks=30,
)
258
+
259
+
260
def build_offer_query(
    settings: ExecutorSettings,
    resources: Mapping,
    threads: int,
) -> str:
    """Build a vastai offer search query for a job.

    Resources understood: gpu / nvidia_gpu (count), gpu_model, mem_mb,
    disk_mb, vastai_query (extra filters in vastai query syntax).

    Args:
        settings: Executor settings providing the workflow-wide defaults and
            filters (GPU model, disk, reliability, geolocation, price, ...).
        resources: The job's resources, accessed through ``get``.
        threads: Number of threads of the job; a falsy value adds no CPU
            filter.

    Returns:
        The individual filters joined by single spaces.

    Raises:
        WorkflowError: If the gpu / nvidia_gpu resource is not an integer.
    """
    parts = []

    num_gpus = resources.get("gpu", resources.get("nvidia_gpu", 1))
    try:
        # At least one GPU is always requested.
        num_gpus = max(1, int(num_gpus))
    except (TypeError, ValueError) as err:
        raise WorkflowError(
            f"Resource gpu/nvidia_gpu must be an integer, got {num_gpus!r}."
        ) from err
    parts.append(f"num_gpus={num_gpus}")

    gpu_model = resources.get("gpu_model") or settings.gpu_name
    if gpu_model:
        parts.append(f"gpu_name={str(gpu_model).replace(' ', '_')}")

    if threads:
        parts.append(f"cpu_cores_effective>={int(threads)}")

    mem_mb = resources.get("mem_mb")
    if mem_mb:
        # cpu_ram is specified in GB in vastai query syntax.
        parts.append(f"cpu_ram>={float(mem_mb) / 1000:g}")

    parts.append(f"disk_space>={required_disk_gb(settings, resources):g}")

    if settings.reliability > 0:
        parts.append(f"reliability>{settings.reliability:g}")

    if settings.geolocation:
        geo = settings.geolocation.strip()
        codes = GEO_REGIONS.get(geo.upper())
        if codes is None:
            # Explicit country codes get the same case-insensitive treatment
            # as the region shortcuts; stray blanks and empty entries are
            # dropped so that "pl, de" and "PL,DE" yield the same filter.
            codes = ",".join(
                code.strip().upper() for code in geo.split(",") if code.strip()
            )
        # A blank setting must not produce an empty "geolocation in []".
        if codes:
            parts.append(f"geolocation in [{codes}]")

    if not settings.no_datacenter:
        parts.append("datacenter=true")

    if settings.max_price is not None:
        parts.append(f"dph_total<={settings.max_price:g}")

    # Free-form filters come last: first the workflow-wide ones, then the
    # job-specific ones.
    if settings.search_query:
        parts.append(settings.search_query)

    job_query = resources.get("vastai_query")
    if job_query:
        parts.append(str(job_query))

    return " ".join(parts)
317
+
318
+
319
def required_disk_gb(settings: ExecutorSettings, resources: Mapping) -> float:
    """Return the disk size in GB to request for a job.

    The result is the larger of the configured per-instance disk size and
    the job's ``disk_mb`` resource converted to GB (divided by 1000). A
    missing or falsy ``disk_mb`` counts as zero.
    """
    job_disk_gb = 0.0
    requested_mb = resources.get("disk_mb")
    if requested_mb:
        job_disk_gb = float(requested_mb) / 1000
    return max(float(settings.disk), job_disk_gb)
323
+
324
+
325
def resolve_container_image(
    settings_image: Optional[str], snakemake_image: Optional[str]
) -> str:
    """Choose the container image that jobs run in.

    Precedence: --vastai-image, then an explicit --container-image, then
    the Vast.ai PyTorch default.

    Snakemake's own --container-image default is snakemake/snakemake, which
    offers neither CUDA nor OpenSSH. Any value starting with that name is
    therefore treated as "not chosen by the user" and replaced by the
    default image.
    """
    if settings_image:
        return settings_image
    if not snakemake_image:
        return DEFAULT_IMAGE
    if snakemake_image.startswith("snakemake/snakemake"):
        return DEFAULT_IMAGE
    return snakemake_image
340
+
341
+
342
+ def _pip_pin(package: str) -> str:
343
+ try:
344
+ return f"{package}=={importlib.metadata.version(package)}"
345
+ except importlib.metadata.PackageNotFoundError:
346
+ return package
347
+
348
+
349
def snakemake_bootstrap_script(with_storage_plugins: bool = False) -> str:
    """Return the shell lines that pip-install snakemake in the container.

    Snakemake is pinned to the locally installed version so that the CLI
    arguments of spawned jobs stay compatible.

    When with_storage_plugins is set (SSH mode), every locally registered
    storage plugin is installed as well: plugin settings are serialized into
    each job command (e.g. --storage-s3-retries), and the remote snakemake
    can only parse them with the plugins present. Storage mode relies on the
    auto-deploy precommand for this instead.
    """
    requirements = [_pip_pin("snakemake")]
    if with_storage_plugins:
        from snakemake_interface_storage_plugins.registry import (
            StoragePluginRegistry,
        )

        plugin_registry = StoragePluginRegistry()
        plugin_pins = [
            _pip_pin(plugin_registry.get_plugin_package_name(plugin))
            for plugin in plugin_registry.get_registered_plugins()
        ]
        # Sorted for a deterministic install command.
        plugin_pins.sort()
        requirements += plugin_pins

    pip_args = " ".join(f'"{requirement}"' for requirement in requirements)
    announce = (
        "echo '[snakemake-vastai] ensuring job dependencies: "
        f"{len(requirements)} package(s)'\n"
    )
    install = f"python -m pip install --quiet {pip_args}\n"
    return announce + install
378
+
379
+
380
def credential_envvars(environ: Mapping[str, str] = os.environ) -> dict:
    """Collect the credential variables to forward into job containers.

    Returns every variable of CREDENTIAL_ENVVARS that has a non-empty value
    in *environ*. If a Google credentials file can be read, its content is
    added base64-encoded under GCP_CREDENTIALS_CONTENT_VAR. The file is
    taken from GOOGLE_APPLICATION_CREDENTIALS; gcloud's application default
    credentials are used as a fallback only when that variable is absent and
    *environ* is the real process environment. An unreadable file is
    silently skipped.
    """
    forwarded = {}
    for name in CREDENTIAL_ENVVARS:
        if environ.get(name):
            forwarded[name] = environ[name]

    credentials_file = environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if credentials_file is None and environ is os.environ:
        # Fall back to gcloud's application default credentials.
        credentials_file = os.path.expanduser(GCP_ADC_DEFAULT_PATH)
    if not credentials_file:
        return forwarded

    try:
        with open(credentials_file, "rb") as handle:
            raw_credentials = handle.read()
    except OSError:
        # Best effort: missing or unreadable credentials are not an error.
        return forwarded
    forwarded[GCP_CREDENTIALS_CONTENT_VAR] = base64.b64encode(
        raw_credentials
    ).decode()
    return forwarded
395
+
396
+
397
class Executor(RemoteExecutor):
    """Snakemake executor that runs every job on its own rented Vast.ai instance.

    Two modes exist, selected in ``__post_init__``:

    * storage mode (a default storage provider is configured): the job runs
      as the container's entrypoint and reports its exit code through the
      instance logs;
    * SSH transfer mode (no default storage provider): the container is only
      kept alive and an ``SSHJobRunner`` thread drives the job over SSH.
    """

    def __post_init__(self):
        """Create the Vast.ai client and prepare the selected transfer mode."""
        # Imported lazily so that merely having the plugin installed does not
        # require the vastai package at snakemake startup of other executors.
        from vastai import VastAI

        self.settings: ExecutorSettings = self.workflow.executor_settings
        self.vast = VastAI(api_key=self.settings.api_key, raw=True, quiet=True)
        self.container_image = resolve_container_image(
            self.settings.image,
            self.workflow.remote_execution_settings.container_image,
        )
        # Short random id that makes instance labels and the SSH key comment
        # unique per workflow run.
        self.run_id = uuid.uuid4().hex[:8]
        self.log_dir = Path(self.workflow.persistence.aux_path) / "vastai-logs"
        self.log_dir.mkdir(exist_ok=True, parents=True)
        self.logger.info(
            f"Using container image {self.container_image} for Vast.ai jobs."
        )
        if not self.settings.no_forward_credentials:
            # Only the variable names are logged, never their values.
            forwarded = sorted(credential_envvars())
            if forwarded:
                self.logger.info(
                    "Forwarding credentials to job containers: "
                    f"{', '.join(forwarded)} (disable with "
                    "--vastai-no-forward-credentials)."
                )

        # Ids of instances already destroyed, so that repeated cleanup calls
        # do not hit the API again.
        self._destroyed_instances = set()
        self.storage_mode = (
            self.workflow.storage_registry.default_storage_provider is not None
        )
        if self.storage_mode:
            # Snakemake core skips the source upload when
            # can_transfer_local_files is set, so it is done here instead.
            self.workflow.upload_sources()
        else:
            self.logger.info(
                "No default storage provider configured; transferring files "
                "between this machine and the instances over SSH. For large "
                "data or many jobs, an S3-compatible bucket "
                "(--default-storage-provider s3) is faster and more robust."
            )
            self._init_ssh_transfer()

    def _init_ssh_transfer(self):
        """Prepare SSH transfer mode: per-run key pair and source archive.

        Raises:
            WorkflowError: If ssh, scp or ssh-keygen is not on PATH.
        """
        for tool in ("ssh", "scp", "ssh-keygen"):
            if shutil.which(tool) is None:
                raise WorkflowError(
                    f"SSH transfer mode requires '{tool}' on PATH. Install "
                    "OpenSSH, or configure a default storage provider "
                    "instead (--default-storage-provider)."
                )
        self.ssh_keyfile = Path(self.tmpdir) / "vastai_ssh_key"
        # Passphrase-less ed25519 key, labelled with the run id.
        subprocess.run(
            [
                "ssh-keygen",
                "-t",
                "ed25519",
                "-N",
                "",
                "-q",
                "-C",
                f"snakemake-vastai-{self.run_id}",
                "-f",
                str(self.ssh_keyfile),
            ],
            check=True,
        )
        self.ssh_pubkey = (
            self.ssh_keyfile.with_suffix(".pub").read_text().strip()
        )
        self.source_archive_path = Path(self.tmpdir) / "sources.tar.xz"
        self.workflow.write_source_archive(self.source_archive_path)

    @property
    def common_settings(self):
        """Return the common settings matching the current transfer mode."""
        # The module-level can_transfer_local_files=True makes Snakemake core
        # skip both the storage requirement and the storage-based source
        # deployment. In storage mode the latter is wanted after all, so the
        # spawned-job machinery is presented with settings that re-enable it
        # (pip-install of storage plugins + `snakemake --deploy-sources` in
        # the job's precommand).
        # getattr: the property may be read before __post_init__ has set
        # storage_mode.
        if getattr(self, "storage_mode", False):
            return replace(
                common_settings,
                can_transfer_local_files=False,
                auto_deploy_default_storage_provider=True,
            )
        return common_settings

    def job_setup_script(self) -> str:
        """Shell lines run inside the container before the job command.

        Concatenates the python PATH setup, the snippet materializing the
        shipped Google credentials, and the pip installation of the pinned
        snakemake (plus the storage plugins in SSH transfer mode). The pip
        install is issued unconditionally.
        """
        return (
            PYTHON_PATH_SETUP
            + CREDENTIAL_SETUP_SCRIPT
            + snakemake_bootstrap_script(
                with_storage_plugins=not self.storage_mode
            )
        )

    def _ssh_onstart_script(self) -> str:
        """Instance startup script for SSH transfer mode.

        Provisions the per-run SSH key by writing authorized_keys directly —
        the attach-key API on a live instance is known to be racy, while
        onstart runs as the container user before sshd accepts connections.
        Also best-effort installs OpenSSH for images that lack it (Vast's
        launch script keeps retrying the proxy tunnel until `ssh` appears).
        """
        log = "/tmp/snakemake-vastai-openssh.log"
        return (
            "#!/bin/sh\n"
            "if ! command -v ssh >/dev/null 2>&1 "
            "|| ! command -v sshd >/dev/null 2>&1; then\n"
            " echo '[snakemake-vastai] image lacks OpenSSH, installing it'\n"
            " { apt-get update && apt-get install -y openssh-server "
            f"openssh-client; }} >{log} 2>&1 \\\n"
            " || echo '[snakemake-vastai] could not install OpenSSH; "
            "SSH transfer will fail (use an image with OpenSSH, e.g. the "
            "default vastai/pytorch)'\n"
            "fi\n"
            # The key is installed for both $HOME and /root; directories that
            # cannot be created are skipped.
            "for d in \"$HOME\" /root; do\n"
            " mkdir -p \"$d/.ssh\" 2>/dev/null || continue\n"
            f" echo '{self.ssh_pubkey}' >> \"$d/.ssh/authorized_keys\"\n"
            " chmod 700 \"$d/.ssh\"\n"
            " chmod 600 \"$d/.ssh/authorized_keys\"\n"
            "done\n"
        )

    def run_job(self, job: JobExecutorInterface):
        """Rent a Vast.ai instance for *job* and register the submission.

        Searches matching offers, tries to create an instance on each of
        them in turn and, in SSH transfer mode, starts the SSHJobRunner that
        drives the job.

        Raises:
            WorkflowError: If the offer search fails, no offer matches, none
                of the offers can be rented, or the SSH key cannot be
                attached (the instance is destroyed in that case).
        """
        exec_job = self.format_job_exec(job)
        query = build_offer_query(self.settings, dict(job.resources), job.threads)
        disk_gb = required_disk_gb(self.settings, job.resources)
        self.logger.debug(f"Searching Vast.ai offers with query: {query}")

        try:
            offers = self.vast.search_offers(
                query=query,
                type="on-demand",
                order=self.settings.order,
                limit=OFFER_ATTEMPTS,
                storage=disk_gb,
            )
        except Exception as e:
            raise WorkflowError(f"Failed to search Vast.ai offers: {e}") from e
        if not offers:
            raise WorkflowError(
                f"No Vast.ai offers match the requirements of job {job.name} "
                f"(query: {query}). Relax the filters (e.g. --vastai-max-price, "
                "--vastai-gpu-name, --vastai-search-query) or free up budget."
            )

        label = f"snakemake-{self.run_id}-{job.jobid}-{job.attempt}"
        if self.storage_mode:
            # The wrapper runs the job command and emits the exit code as the
            # last log line, where check_active_jobs() picks it up. The
            # container exits afterwards (entrypoint mode), flipping the
            # instance to 'exited'.
            script = (
                f"echo '[snakemake-vastai] starting job {job.name} "
                f"(jobid={job.jobid}, attempt={job.attempt})'\n"
                "mkdir -p /snakemake-workdir && cd /snakemake-workdir\n"
                f"{self.job_setup_script()}"
                f"{exec_job}\n"
                "ec=$?\n"
                'echo "snakemake_vastai_exit_code=${ec}"\n'
                "exit ${ec}\n"
            )
            create_kwargs = dict(
                runtype="args",
                onstart_cmd="/bin/sh",
                args=["-c", script],
            )
        else:
            # SSH transfer mode: the container only has to be alive and
            # reachable; an SSHJobRunner thread drives the job. Note that
            # only "ssh_proxy" gets the proxy tunnel wired up (plain "ssh"
            # leaves the ssh_host:ssh_port unreachable).
            create_kwargs = dict(
                runtype="ssh_proxy",
                onstart_cmd=self._ssh_onstart_script(),
            )

        # The environment is identical for every candidate offer, so it is
        # built once instead of re-reading the credential file per attempt.
        container_env = self._container_env()

        instance_id = None
        last_error = None
        for offer in offers:
            try:
                response = self.vast.create_instance(
                    id=offer["id"],
                    image=self.container_image,
                    disk=disk_gb,
                    label=label,
                    env=container_env,
                    cancel_unavail=True,
                    **create_kwargs,
                )
            except Exception as e:
                # Offers are single-use; this one was likely taken in the
                # meantime. Try the next candidate.
                last_error = e
                continue
            if response.get("success"):
                instance_id = response["new_contract"]
                self.logger.info(
                    f"Job {job.jobid} ({job.name}): rented Vast.ai instance "
                    f"{instance_id} ({offer.get('num_gpus')}x "
                    f"{offer.get('gpu_name')} @ ${offer.get('dph_total', 0):.3f}/h, "
                    f"machine {offer.get('machine_id')})"
                )
                break
            last_error = WorkflowError(f"Unexpected response: {response}")

        if instance_id is None:
            raise WorkflowError(
                f"Could not rent any of the {len(offers)} matching Vast.ai "
                f"offers for job {job.name}. Last error: {last_error}"
            )

        # aux carries the per-job bookkeeping used by the status checks.
        job_info = SubmittedJobInfo(
            job=job,
            external_jobid=str(instance_id),
            aux={
                "submitted": time.time(),
                "label": label,
                "seen_running": False,
                "unreachable_since": None,
                "api_misses": 0,
            },
        )

        if not self.storage_mode:
            try:
                self.vast.attach_ssh(instance_id, self.ssh_pubkey)
            except Exception as e:
                # Do not leave a billed instance behind.
                self._destroy_instance(instance_id)
                raise WorkflowError(
                    f"Failed to attach SSH key to Vast.ai instance "
                    f"{instance_id}: {e}"
                ) from e
            runner = SSHJobRunner(self, job_info, exec_job)
            job_info.aux["runner"] = runner
            runner.start()

        self.report_job_submission(job_info)

    def _container_env(self) -> dict:
        """Return the environment variables injected into the job container."""
        env = {}
        if not self.settings.no_forward_credentials:
            env.update(credential_envvars())
        # Declared envvars and storage plugin secrets (forwarded by Snakemake
        # itself) take precedence over ambient credentials.
        env.update(self.envvars())
        return env

    async def check_active_jobs(
        self, active_jobs: List[SubmittedJobInfo]
    ) -> AsyncGenerator[SubmittedJobInfo, None]:
        """Check all active jobs and yield those that are still running.

        Finished jobs are reported as success or error by the helpers called
        from here and are not yielded.
        """
        for active_job in active_jobs:
            async with self.status_rate_limiter:
                still_active = self._check_job(active_job)
            if still_active:
                yield active_job

    def _check_job(self, job_info: SubmittedJobInfo) -> bool:
        """Check one job; returns True if it is still running."""
        if not self.storage_mode:
            return self._check_ssh_job(job_info)

        aux = job_info.aux
        instance_id = int(job_info.external_jobid)

        try:
            instance = self.vast.show_instance(instance_id)
            # The API result may be a list; only its first entry is used.
            if isinstance(instance, list):
                instance = instance[0] if instance else {}
        except Exception as e:
            aux["api_misses"] += 1
            if aux["api_misses"] > MAX_STATUS_API_MISSES:
                self._finalize_job(job_info, instance={})
                return False
            self.logger.warning(
                f"Failed to query status of Vast.ai instance {instance_id} "
                f"({aux['api_misses']}/{MAX_STATUS_API_MISSES}): {e}"
            )
            return True
        # Only consecutive failures count.
        aux["api_misses"] = 0

        status = instance.get("actual_status") if instance else None

        if not instance:
            # Instance vanished (e.g. destroyed externally).
            self._finalize_job(job_info, instance={})
            return False

        if status == "running":
            aux["seen_running"] = True
            aux["unreachable_since"] = None
            return True

        if status == "exited":
            # Container entrypoint finished; outcome is in the logs.
            self._finalize_job(job_info, instance)
            return False

        if status in (None, "", "created", "loading", "rebooting"):
            # Boot errors and the boot timeout only apply while the instance
            # has never been seen running.
            if not aux["seen_running"] and fatal_boot_error(
                instance.get("status_msg")
            ):
                self._fail_job(
                    job_info,
                    f"Instance failed to start "
                    f"(status_msg: {instance.get('status_msg')!r}). This "
                    "host will not recover; rerun with --retries to "
                    "resubmit on a different machine.",
                )
                return False
            elapsed = time.time() - aux["submitted"]
            if not aux["seen_running"] and elapsed > self.settings.boot_timeout:
                self._fail_job(
                    job_info,
                    f"Instance did not reach running state within "
                    f"{self.settings.boot_timeout}s (status: {status!r}, "
                    f"status_msg: {instance.get('status_msg')!r}).",
                )
                return False
            return True

        # stopped / offline / unknown / anything unexpected: possibly a
        # transient host problem, give it a grace period.
        now = time.time()
        if aux["unreachable_since"] is None:
            aux["unreachable_since"] = now
            self.logger.warning(
                f"Vast.ai instance {instance_id} reports status {status!r} "
                f"(status_msg: {instance.get('status_msg')!r}); waiting up to "
                f"{UNREACHABLE_GRACE_SECONDS}s for it to recover."
            )
        if now - aux["unreachable_since"] > UNREACHABLE_GRACE_SECONDS:
            self._finalize_job(job_info, instance)
            return False
        return True

    def _check_ssh_job(self, job_info: SubmittedJobInfo) -> bool:
        """Inspect the outcome recorded by the job's SSHJobRunner thread.

        Returns True while the job is still running. Once an outcome is
        available the job is reported as success or error and False is
        returned. A worker thread that ended without recording an outcome
        is treated as a failure and its instance is destroyed.
        """
        runner = job_info.aux["runner"]
        outcome = job_info.aux.get("outcome")
        if outcome is None:
            if runner.thread.is_alive():
                return True
            # The worker may have stored its outcome and exited between the
            # read above and the liveness check. Read again before declaring
            # it dead, otherwise a finished job could be reported as crashed.
            outcome = job_info.aux.get("outcome")
            if outcome is None:
                self._destroy_instance(int(job_info.external_jobid))
                self.report_job_error(
                    job_info,
                    msg="The transfer worker for this job died unexpectedly. ",
                )
                return False
        kind, msg = outcome
        log_file = job_info.aux.get("log_file")
        if kind == "success":
            self.report_job_success(job_info)
        else:
            self.report_job_error(
                job_info,
                msg=f"Vast.ai instance {job_info.external_jobid}: {msg} ",
                aux_logs=[log_file] if log_file else None,
            )
        return False

    def _finalize_job(self, job_info: SubmittedJobInfo, instance: dict):
        """Fetch logs, parse the exit code, destroy the instance and report.

        Args:
            job_info: The job to finalize.
            instance: Last known instance description; an empty dict when
                the instance is gone or its status could not be queried.
        """
        instance_id = int(job_info.external_jobid)

        log_text = ""
        try:
            log_text = self.vast.logs(instance_id, tail="10000") or ""
        except Exception as e:
            self.logger.warning(
                f"Could not retrieve logs of Vast.ai instance {instance_id}: {e}"
            )

        log_file = None
        if log_text:
            log_path = self.log_dir / f"{job_info.aux['label']}.log"
            try:
                # Explicit encoding: container output may contain characters
                # the platform default encoding cannot represent.
                log_path.write_text(log_text, encoding="utf-8", errors="replace")
                log_file = log_path
            except OSError as e:
                # Keeping a local copy is best effort; failing here must not
                # prevent the instance from being destroyed below.
                self.logger.warning(
                    f"Could not store logs of Vast.ai instance {instance_id} "
                    f"in {log_path}: {e}"
                )

        self._destroy_instance(instance_id)

        # The wrapper prints the exit code last, so the final match wins.
        matches = EXIT_CODE_PATTERN.findall(log_text)
        exit_code = int(matches[-1]) if matches else None

        if exit_code == 0:
            self.report_job_success(job_info)
        elif exit_code is not None:
            self.report_job_error(
                job_info,
                msg=f"Job finished with exit code {exit_code} on Vast.ai "
                f"instance {instance_id}. ",
                aux_logs=[str(log_file)] if log_file else None,
            )
        else:
            status = instance.get("actual_status") if instance else "gone"
            status_msg = instance.get("status_msg") if instance else None
            self.report_job_error(
                job_info,
                msg=f"Vast.ai instance {instance_id} terminated without "
                f"reporting a job exit code (instance status: {status!r}, "
                f"status_msg: {status_msg!r}). The host may have failed or "
                "the container crashed. ",
                aux_logs=[str(log_file)] if log_file else None,
            )

    def _fail_job(self, job_info: SubmittedJobInfo, msg: str):
        """Destroy the job's instance and report the job as failed with *msg*."""
        instance_id = int(job_info.external_jobid)
        self._destroy_instance(instance_id)
        self.report_job_error(
            job_info, msg=f"Vast.ai instance {instance_id}: {msg} "
        )

    def _destroy_instance(self, instance_id: int):
        """Destroy an instance unless already done or keep_instances is set.

        API failures are logged instead of raised, so that cleanup of other
        jobs can continue.
        """
        if instance_id in self._destroyed_instances:
            return
        if self.settings.keep_instances:
            self.logger.info(
                f"Keeping Vast.ai instance {instance_id} as requested "
                "(--vastai-keep-instances). Remember to destroy it manually, "
                "it accrues charges until then!"
            )
            return
        try:
            self.vast.destroy_instance(instance_id)
            self._destroyed_instances.add(instance_id)
            self.logger.debug(f"Destroyed Vast.ai instance {instance_id}.")
        except Exception as e:
            self.logger.error(
                f"Failed to destroy Vast.ai instance {instance_id}: {e}. "
                "Please destroy it manually at https://console.vast.ai/instances/ "
                "to stop billing."
            )

    def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]):
        """Cancel all active jobs by destroying their instances."""
        for job_info in active_jobs:
            self.logger.info(
                f"Cancelling job {job_info.job.jobid} on Vast.ai instance "
                f"{job_info.external_jobid}."
            )
            # Presumably set up by the SSHJobRunner so that its thread can
            # stop early; absent in storage mode.
            cancel_event = job_info.aux.get("cancel_event")
            if cancel_event is not None:
                cancel_event.set()
            self._destroy_instance(int(job_info.external_jobid))