zipstrain 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zipstrain/__init__.py +7 -0
- zipstrain/cli.py +808 -0
- zipstrain/compare.py +377 -0
- zipstrain/database.py +871 -0
- zipstrain/profile.py +221 -0
- zipstrain/task_manager.py +1978 -0
- zipstrain/utils.py +451 -0
- zipstrain/visualize.py +586 -0
- zipstrain-0.2.4.dist-info/METADATA +27 -0
- zipstrain-0.2.4.dist-info/RECORD +12 -0
- zipstrain-0.2.4.dist-info/WHEEL +4 -0
- zipstrain-0.2.4.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1978 @@
|
|
|
1
|
+
"""zipstrain.task_manager
|
|
2
|
+
========================
|
|
3
|
+
Lightweight, asyncio-driven orchestration primitives for building and running
|
|
4
|
+
scientific data-processing pipelines. This module provides a small, composable
|
|
5
|
+
framework for defining Tasks with explicit inputs/outputs, bundling Tasks into
|
|
6
|
+
Batches (local or Slurm), and coordinating their execution with a live terminal
|
|
7
|
+
UI. It is designed to be easy to extend for new Task types and execution
|
|
8
|
+
environments. For most users, this module is not directly used. However,
|
|
9
|
+
it can be used to define new pipelines that chain together multiple steps with clear inputs/outputs.
|
|
10
|
+
The unit of execution is a batch, which is a collection of tasks to be executed together.
|
|
11
|
+
Each batch can have an optional finalization step that runs after all tasks are complete.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
Key concepts
|
|
15
|
+
------------
|
|
16
|
+
|
|
17
|
+
- Inputs and Outputs:
|
|
18
|
+
These classes encapsulate task inputs and outputs with validation logic.
|
|
19
|
+
Input and Output classes for files, strings, and integers are provided by default.
|
|
20
|
+
If needed, new types can be defined by subclassing Input or Output.
|
|
21
|
+
|
|
22
|
+
- Engines:
|
|
23
|
+
Any task object can use a container engine (Docker or Apptainer) or run natively (LocalEngine).
|
|
24
|
+
|
|
25
|
+
- Task
|
|
26
|
+
Each task runs a unit of bash script with defined inputs and expected outputs. If an engine is provided,
|
|
27
|
+
the command will be wrapped accordingly to run inside the container.
|
|
28
|
+
|
|
29
|
+
- Batches:
|
|
30
|
+
A batch is a collection of tasks to be executed together. Batches can be run locally or submitted to Slurm.
|
|
31
|
+
Each batch monitors the status of its tasks and updates its own status accordingly. A batch can also have
|
|
32
|
+
expected outputs that are checked after all tasks are complete. Additionally, a batch can have a finalization step that runs after all tasks are complete.
|
|
33
|
+
|
|
34
|
+
- Runner:
|
|
35
|
+
The Runner class orchestrates task generation, batching, and execution. It manages concurrent batch execution,
|
|
36
|
+
monitors progress, and provides a live terminal UI using the rich library.
|
|
37
|
+
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
from __future__ import annotations
|
|
42
|
+
from enum import StrEnum
|
|
43
|
+
import re
|
|
44
|
+
import pathlib
|
|
45
|
+
from abc import ABC, abstractmethod
|
|
46
|
+
import asyncio
|
|
47
|
+
import subprocess
|
|
48
|
+
import aiofiles
|
|
49
|
+
from pydantic import BaseModel, Field, field_validator
|
|
50
|
+
from rich.live import Live
|
|
51
|
+
from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn
|
|
52
|
+
from rich.console import Console, Group
|
|
53
|
+
from rich.panel import Panel
|
|
54
|
+
from rich.align import Align
|
|
55
|
+
from zipstrain import database
|
|
56
|
+
from rich.columns import Columns
|
|
57
|
+
import polars as pl
|
|
58
|
+
import psutil
|
|
59
|
+
import shutil
|
|
60
|
+
import signal
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class SlurmConfig(BaseModel):
    """Configuration model for Slurm batch jobs.

    Attributes:
        time (str): Time limit for the job in HH:MM:SS format (1 to 3 hour digits).
        tasks (int): Number of tasks, emitted as ``--ntasks``.
        mem (int): Memory in GB, emitted as ``--mem=<mem>G``.
        additional_params (dict): Additional SLURM parameters as key-value pairs.

    NOTE: Additional parameters for Slurm should be provided in the additional_params dict in the form
    of {"param-name": "param-value"}, e.g., {"cpus-per-task": "4"} will result in the addition of
    "#SBATCH --cpus-per-task=4" to the sbatch script.
    """
    time: str = Field(description="Time limit for the job.")
    tasks: int = Field(default=1, description="Number of tasks.")
    mem: int = Field(default=4, description="Memory in GB.")
    additional_params: dict[str, str] = Field(default_factory=dict, description="Additional SLURM parameters as key-value pairs.")

    @field_validator("time")
    def validate_time(cls, v):
        """Validate time format HH:MM:SS (H..HHH allowed).

        ``re.fullmatch`` is used instead of ``re.match`` with ``^...$``: ``$`` also
        matches just before a trailing newline, which would let such a value through
        and inject a line break into the generated ``#SBATCH`` header. ``[0-9]`` is
        used instead of ``\\d`` so that only ASCII digits are accepted.

        Raises:
            ValueError: If the value is not in HH:MM:SS form.
        """
        if not re.fullmatch(r"[0-9]{1,3}:[0-9]{2}:[0-9]{2}", v):
            raise ValueError("Time must be in the format HH:MM:SS (H..HHH allowed)")
        return v

    def to_slurm_args(self) -> str:
        """Generate the Slurm batch file header from the configuration object.

        Returns:
            Newline-separated ``#SBATCH`` lines: time, ntasks and mem first, then one
            line per ``additional_params`` entry in insertion order.
        """
        args = [
            f"#SBATCH --time={self.time}",
            f"#SBATCH --ntasks={self.tasks}",
            f"#SBATCH --mem={self.mem}G",
        ]
        for key, value in self.additional_params.items():
            args.append(f"#SBATCH --{key}={value}")
        return "\n".join(args)

    @classmethod
    def from_json(cls, json_path: str | pathlib.Path) -> SlurmConfig:
        """Load a SlurmConfig from a JSON file.

        Raises:
            FileNotFoundError: If ``json_path`` does not exist.
        """
        path = pathlib.Path(json_path)
        if not path.exists():
            raise FileNotFoundError(f"Slurm config file {json_path} does not exist.")
        return cls.model_validate_json(path.read_text())
|
|
107
|
+
|
|
108
|
+
async def write_file(path: pathlib.Path, text: str, file_semaphore: asyncio.Semaphore | None = None) -> None:
    """Asynchronously write ``text`` to ``path``, replacing any existing content.

    Args:
        path: Destination file.
        text: Content to write.
        file_semaphore: Optional semaphore limiting how many files are open at once.
            ``Task`` and ``Batch`` default their ``file_semaphore`` to ``None``. In that
            case the write proceeds unthrottled instead of failing on ``async with None``.
    """
    import contextlib  # local import: nullcontext stands in for a missing semaphore

    guard = file_semaphore if file_semaphore is not None else contextlib.nullcontext()
    async with guard:
        async with aiofiles.open(path, "w") as handle:
            await handle.write(text)
|
|
112
|
+
|
|
113
|
+
async def read_file(path: pathlib.Path, file_semaphore: asyncio.Semaphore | None = None) -> str:
    """Asynchronously read and return the full text content of ``path``.

    Args:
        path: File to read.
        file_semaphore: Optional semaphore limiting how many files are open at once.
            With ``None`` the read proceeds unthrottled instead of failing on
            ``async with None``.

    Returns:
        The file content as a string.
    """
    import contextlib  # local import: nullcontext stands in for a missing semaphore

    guard = file_semaphore if file_semaphore is not None else contextlib.nullcontext()
    async with guard:
        async with aiofiles.open(path, "r") as handle:
            return await handle.read()
|
|
118
|
+
|
|
119
|
+
class Status(StrEnum):
|
|
120
|
+
"""Enumeration of possible task and batch statuses."""
|
|
121
|
+
BATCH_NOT_ASSIGNED = "batch_not_assigned"
|
|
122
|
+
NOT_STARTED = "not_started"
|
|
123
|
+
RUNNING = "running"
|
|
124
|
+
DONE = "done"
|
|
125
|
+
FAILED = "failed"
|
|
126
|
+
SUBMITTED = "submitted"
|
|
127
|
+
SUCCESS = "success"
|
|
128
|
+
PENDING = "pending"
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Input(ABC):
    """Abstract base for a task input. DO NOT INSTANTIATE DIRECTLY.

    Subclasses for the common kinds (files, strings, integers) already exist. To add
    another kind, subclass this and implement ``validate()`` and ``get_value()``.
    Validation runs inside the constructor, so a bad value fails where the input is
    created rather than later at run time.
    """

    def __init__(self, value: str | int) -> None:
        # Store the raw value first; validate() and get_value() both read it.
        self.value = value
        self.validate()

    @abstractmethod
    def validate(self) -> None:
        """Raise if ``self.value`` is not acceptable for this input type."""

    @abstractmethod
    def get_value(self) -> str | int:
        """Return the value substituted into a task's command template."""
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class FileInput(Input):
    """Input that refers to a file on disk.

    By default ``validate`` requires the path to exist, and ``get_value`` renders the
    path as an absolute string.
    """

    def validate(self, check_exists: bool = True) -> None:
        """Raise FileNotFoundError if ``check_exists`` is set and the path is missing."""
        if not check_exists:
            return
        if pathlib.Path(self.value).exists():
            return
        raise FileNotFoundError(f"Input file {self.value} does not exist.")

    def get_value(self) -> str:
        """Return the absolute path of the input file as a string."""
        absolute_path = pathlib.Path(self.value).absolute()
        return str(absolute_path)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class StringInput(Input):
    """Input that holds a plain string value."""

    def validate(self) -> None:
        """Reject any value that is not a ``str`` instance."""
        if isinstance(self.value, str):
            return
        raise ValueError(f"Input value {self.value!r} is not a string.")

    def get_value(self) -> str:
        """Return the value as a string."""
        text = str(self.value)
        return text
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class IntInput(Input):
    """Input that holds an integer value, rendered into commands as its decimal string.

    ``bool`` is rejected even though it subclasses ``int``. Otherwise ``IntInput(True)``
    would be rendered as the literal text ``True`` in a shell command.
    """

    def validate(self) -> None:
        """Raise ValueError unless the value is an ``int`` that is not a ``bool``."""
        if isinstance(self.value, bool) or not isinstance(self.value, int):
            raise ValueError(f"Input value {self.value!r} is not an integer.")

    def get_value(self) -> str:
        """Return the integer value as a string."""
        return str(self.value)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class Output(ABC):
    """Abstract base for task outputs. DO NOT INSTANTIATE DIRECTLY.

    Subclasses for the common output kinds are provided. A new kind subclasses this and
    implements ``ready()``, which reports whether the output is present and valid once
    the producing task has finished.
    """

    def __init__(self) -> None:
        # Expected to be filled in by the producing task when it completes.
        self._value = None
        # Filled in by register_task() when the output is attached to a task.
        self.task = None

    @property
    def value(self):
        """The produced value, or None while it has not been set."""
        return self._value

    @abstractmethod
    def ready(self) -> bool:
        """Return True once the output is available and valid."""

    def register_task(self, task: Task) -> None:
        """Attach the task that produces this output. Subclasses rarely need to override this.

        Args:
            task (Task): The task that produces this output.
        """
        self.task = task
|
|
212
|
+
|
|
213
|
+
class FileOutput(Output):
    """Output that is a file inside the producing task's directory.

    Args:
        expected_file (str): The expected output file name relative to the task directory.
    """

    def __init__(self, expected_file: str) -> None:
        # Initialise the base attributes (``_value`` and ``task``) so that ``.value``
        # and ``.task`` work before the output is registered.
        super().__init__()
        # When the task finishes, the file must be at task.task_dir / expected_file,
        # otherwise ready() returns False.
        self._expected_file_name = expected_file
        # Resolved by register_task(); None until then.
        self.expected_file: pathlib.Path | None = None

    def ready(self) -> bool:
        """Return True if the expected output file exists. Returns False before registration."""
        if self.expected_file is None:
            return False
        return self.expected_file.absolute().exists()

    def register_task(self, task: Task) -> None:
        """Register the task that produces this output and resolve the expected file path.

        Args:
            task (Task): The task that produces this output.
        """
        super().register_task(task)
        self.expected_file = self.task.task_dir / self._expected_file_name
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class BatchFileOutput(Output):
    """Output that is a file relative to the batch directory.

    It is registered to the batch (through ``register_batch``) instead of to a task.

    Args:
        expected_file (str): The expected output file name relative to the batch directory.
    """

    def __init__(self, expected_file: str) -> None:
        # Initialise the base attributes (``_value`` and ``task``) so ``.value`` works.
        super().__init__()
        self._expected_file_name = expected_file
        # Resolved by register_batch(); None until then.
        self.expected_file: pathlib.Path | None = None

    def ready(self) -> bool:
        """Return True if the expected output file exists. Returns False before registration."""
        if self.expected_file is None:
            return False
        return self.expected_file.absolute().exists()

    def register_batch(self, batch: Batch) -> None:
        """Resolve the expected file path against the batch directory.

        Args:
            batch (Batch): The batch that produces this output.
        """
        self.expected_file = batch.batch_dir / self._expected_file_name
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class StringOutput(Output):
    """Output whose value is a string."""

    def ready(self) -> bool:
        """Return True when the value is a string, and False while it is unset.

        Raises:
            ValueError: If a value has been set but it is not a string.
        """
        if isinstance(self._value, str):
            return True
        if self._value is None:
            return False
        # ``self.task`` stays None until register_task() runs. Guard against that so
        # the intended ValueError is not replaced by an AttributeError.
        task_id = self.task.id if self.task is not None else None
        raise ValueError(f"Output value for task {task_id} is not a string.")
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class IntOutput(Output):
    """Output whose value is an integer. ``bool`` is not accepted as an integer."""

    def ready(self) -> bool:
        """Return True when the value is an int (not a bool), and False while it is unset.

        Raises:
            ValueError: If a value has been set but it is not an integer.
        """
        if isinstance(self._value, int) and not isinstance(self._value, bool):
            return True
        if self._value is None:
            return False
        # ``self.task`` stays None until register_task() runs. Guard against that so
        # the intended ValueError is not replaced by an AttributeError.
        task_id = self.task.id if self.task is not None else None
        raise ValueError(f"Output value for task {task_id} is not an integer.")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class Engine(ABC):
    """Abstract execution environment that wraps a task's shell command.

    Args:
        address (str): Engine-specific location string. The container engines in this
            module place it where ``docker run`` / ``apptainer run`` expect the image.
    """

    def __init__(self, address: str) -> None:
        self.address = address

    @abstractmethod
    def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
        """Return ``command`` rewritten to run in this engine, with ``file_inputs`` made available."""
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class DockerEngine(Engine):
    """Runs commands inside a Docker container, mounting each input file at its own path."""

    def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
        """Return a ``docker run`` command line for ``command``.

        Each input file is bind-mounted at the same absolute path inside the container,
        so the paths already substituted into ``command`` resolve unchanged. Paths are
        de-duplicated with first-seen order preserved. Two inputs that point at the same
        file would otherwise declare the same mount point twice, which Docker rejects.
        """
        paths = dict.fromkeys(file_input.get_value() for file_input in file_inputs)
        volume_mounts = " ".join(f"-v {path}:{path}" for path in paths)
        return f"docker run {volume_mounts} {self.address} {command}"
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class ApptainerEngine(Engine):
    """Runs commands inside an Apptainer container, binding each input file at its own path."""

    def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
        """Return an ``apptainer run`` command line for ``command``.

        Input paths are bound at identical locations inside the container and
        de-duplicated with first-seen order preserved. When there are no file inputs,
        ``--bind`` is omitted entirely. An empty ``--bind`` would consume the image
        address as its argument.
        """
        paths = dict.fromkeys(file_input.get_value() for file_input in file_inputs)
        if not paths:
            return f"apptainer run {self.address} {command}"
        volume_mounts = "--bind " + ",".join(f"{path}:{path}" for path in paths)
        return f"apptainer run {volume_mounts} {self.address} {command}"
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class LocalEngine(Engine):
    """Runs commands directly on the host, with no container involved."""

    def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
        """Return ``command`` untouched. There is nothing to mount, so ``file_inputs`` is ignored."""
        wrapped = command
        return wrapped
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class Task(ABC):
    """Abstract base class for tasks. DO NOT INSTANTIATE DIRECTLY.

    A new task type subclasses this and sets the ``TEMPLATE_CMD`` class attribute.
    Inputs and expected outputs are referenced in the template as ``<handle>``
    placeholders. For example, if a task has an input called "input-file" and an
    expected output called "output-file", the template could be::

        TEMPLATE_CMD = "some_command --input <input-file> --output <output-file>"

    The placeholders are filled in by ``map_io()``, which the Batch calls when the task
    is added to it.

    Args:
        id (str): Task identifier. It is also the task's directory name inside the batch directory.
        inputs (dict[str, Input | Output]): Maps placeholder handle to input.
        expected_outputs (dict[str, Output]): Maps placeholder handle to expected output.
        engine (Engine): Engine used to wrap the final command.
        batch_obj (Batch | None): Owning batch, normally assigned by the Batch constructor.
        file_semaphore (asyncio.Semaphore | None): Passed to the async status-file helpers.
    """
    TEMPLATE_CMD = ""

    def __init__(
        self,
        id: str,
        inputs: dict[str, Input | Output],
        expected_outputs: dict[str, Output],
        engine: Engine,
        batch_obj: Batch | None = None,
        file_semaphore: asyncio.Semaphore | None = None
    ) -> None:
        self.id = id
        self.inputs = inputs
        self.expected_outputs = expected_outputs
        self._batch_obj = batch_obj
        self.engine = engine
        self._status = self._get_initial_status()
        self.file_semaphore = file_semaphore

    def map_io(self) -> None:
        """Substitute inputs and expected outputs into ``TEMPLATE_CMD``.

        When this is called, every placeholder in TEMPLATE_CMD must correspond to a key
        of ``inputs`` or ``expected_outputs``. Users do not call it directly; the Batch
        calls it when the task is added to a batch.

        Raises:
            ValueError: If any ``<handle>`` placeholder is left unmapped.
        """
        cmd = self.TEMPLATE_CMD
        for handle, source in self.inputs.items():
            cmd = cmd.replace(f"<{handle}>", source.get_value())
        for handle, output in self.expected_outputs.items():
            cmd = cmd.replace(f"<{handle}>", str(output.expected_file.absolute()))
        # Report any placeholder that is still present. Handles used in this module
        # contain hyphens (e.g. "bam-file"), so ``\w`` alone would miss them.
        remaining = re.findall(r"<[\w-]+>", cmd)
        if remaining:
            raise ValueError(f"Not all inputs were mapped in task {self.id}. Remaining placeholders: {remaining}")
        self._command = cmd

    @property
    def batch_dir(self) -> pathlib.Path:
        """Return the batch directory path.

        Raises:
            ValueError: If the task is not associated with any batch yet.
        """
        if self._batch_obj is None:
            raise ValueError(f"Task {self.id} is not associated with any batch yet.")
        return self._batch_obj.batch_dir

    @property
    def task_dir(self) -> pathlib.Path:
        """Return the task directory path (``batch_dir / id``)."""
        return self.batch_dir / self.id

    @property
    def command(self) -> str:
        """Return the mapped command, wrapped by the engine together with all FileInput inputs."""
        file_inputs = [v for v in self.inputs.values() if isinstance(v, FileInput)]
        return self.engine.wrap(self._command, file_inputs)

    @property
    def pre_run(self) -> str:
        """Shell snippet run before the task command: mark the task running and cd into task_dir.

        Subclasses should not override this unless a task needs special setup, such as
        batch aggregation.
        """
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self.task_dir.absolute()}"

    @property
    def status(self) -> str:
        """Return the last known status of the task."""
        return self._status

    @property
    def post_run(self) -> str:
        """Shell snippet run after the task command: cd back to batch_dir and mark the task done.

        Subclasses should not override this unless a task needs special teardown, such
        as batch aggregation.
        """
        return f"cd {self.batch_dir.absolute()} && echo {Status.DONE.value} > {self.task_dir.absolute()}/.status"

    def _missing_outputs(self) -> list[str]:
        """Return the handles of expected outputs that are not ready.

        An output whose ``ready()`` raises is treated as not ready.
        """
        missing = []
        for handle, output in self.expected_outputs.items():
            try:
                is_ready = output.ready()
            except Exception:
                is_ready = False
            if not is_ready:
                missing.append(handle)
        return missing

    async def get_status(self) -> str:
        """Refresh the status from ``task_dir/.status`` and return it.

        If the file says "done", the expected outputs are checked. The status is then
        promoted to "success" or demoted to "failed", and the result is written back to
        the file.

        Raises:
            ValueError: If the task reported done but some outputs are missing or invalid.
        """
        status_path = self.task_dir / ".status"
        if status_path.exists():
            raw = await read_file(status_path, self.file_semaphore)
            self._status = raw.strip()

        if self._status == Status.DONE.value:
            missing = self._missing_outputs()
            if not missing:
                self._status = Status.SUCCESS.value
                await write_file(status_path, Status.SUCCESS.value, self.file_semaphore)
            else:
                self._status = Status.FAILED.value
                await write_file(status_path, Status.FAILED.value, self.file_semaphore)
                # Name the missing handles instead of assuming an "output-file" key,
                # which tasks with differently named outputs do not have.
                raise ValueError(
                    f"Task {self.id} reported done but outputs are not ready or invalid: {missing}"
                )

        return self._status

    def _get_initial_status(self) -> str:
        """Return the task's status based on what a previous run left on disk.

        - No batch: "batch_not_assigned".
        - No task directory, or no ``.status`` file inside it: "not_started".
        - Recorded "done"/"success": "success" if all outputs are ready, else "failed".
        - Any other recorded value is returned as written.
        """
        if self._batch_obj is None:
            return Status.BATCH_NOT_ASSIGNED.value
        if not self.task_dir.exists():
            return Status.NOT_STARTED.value
        status_file = self.task_dir / ".status"
        if not status_file.exists():
            # The directory exists but no status was ever recorded.
            return Status.NOT_STARTED.value
        with open(status_file, mode="r") as f:
            status_as_written = f.read().strip()
        if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
            return Status.FAILED.value if self._missing_outputs() else Status.SUCCESS.value
        # Report other recorded states (e.g. "running", "failed") instead of implicitly returning None.
        return status_as_written
|
|
442
|
+
|
|
443
|
+
class TaskGenerator(ABC):
    """Abstract producer of Task objects. DO NOT INSTANTIATE DIRECTLY.

    Subclasses provide:

    - ``generate_tasks()``: an async generator that yields lists of Task objects.
      ``yield_size`` controls how many tasks are generated and yielded at a time.
    - ``get_total_tasks()``: the total number of tasks that can be generated.

    ``get_total_tasks()`` is called from ``__init__``, after ``data`` and
    ``yield_size`` are stored, and its result is cached as ``_total_tasks``.
    """

    def __init__(self,
                 data,
                 yield_size: int,
                 ):
        self.data, self.yield_size = data, yield_size
        self._total_tasks = self.get_total_tasks()

    @abstractmethod
    async def generate_tasks(self) -> list[Task]:
        """Yield lists of Task objects. Subclasses implement this as an async generator."""

    @abstractmethod
    def get_total_tasks(self) -> int:
        """Return the total number of tasks this generator will produce."""
|
|
471
|
+
|
|
472
|
+
class ProfileTaskGenerator(TaskGenerator):
    """Generates ProfileBamTask objects from a polars LazyFrame. Each task profiles one BAM file.

    Each row of ``data`` must provide a ``bamfile`` and a ``sample_name`` column. The
    sample name is used as the task id and as the prefix of the expected output files.

    Args:
        data (pl.LazyFrame): One row per BAM file to profile.
        yield_size (int): Maximum number of tasks per yielded list.
        container_engine (Engine): Engine used to run every generated task.
        stb_file (str): Passed to each task as the ``stb-file`` input.
        profile_bed_file (str): Passed to each task as the ``bed-file`` input.
        gene_range_file (str): Passed to each task as the ``gene-range-table`` input.
        genome_length_file (str): Passed to each task as the ``genome-length-file`` input.
        num_procs (int): Passed to each task as the ``num-threads`` input. Default 4.
        breadth_min_cov (int): Passed to each task as the ``breadth-min-cov`` input. Default 1.

    Raises:
        ValueError: If ``data`` is not a polars LazyFrame.
        FileNotFoundError: If any of the four reference files does not exist.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        stb_file: str,
        profile_bed_file: str,
        gene_range_file: str,
        genome_length_file: str,
        num_procs: int = 4,
        breadth_min_cov: int = 1,
    ) -> None:
        # Validate before the base constructor runs. It immediately calls
        # get_total_tasks(), which would fail on non-LazyFrame input with an unrelated
        # error and make this check unreachable.
        if not isinstance(data, pl.LazyFrame):
            raise ValueError("data must be a polars LazyFrame.")
        self.stb_file = pathlib.Path(stb_file)
        self.profile_bed_file = pathlib.Path(profile_bed_file)
        self.gene_range_file = pathlib.Path(gene_range_file)
        self.genome_length_file = pathlib.Path(genome_length_file)
        # Check reference files before the (potentially costly) row count.
        for path_attr in (
            self.stb_file,
            self.profile_bed_file,
            self.gene_range_file,
            self.genome_length_file,
        ):
            if not path_attr.exists():
                raise FileNotFoundError(f"File {path_attr} does not exist.")
        self.num_procs = num_procs
        self.breadth_min_cov = breadth_min_cov
        self.engine = container_engine
        super().__init__(data, yield_size)

    def get_total_tasks(self) -> int:
        """Return the total number of profiles to be generated (the row count of ``data``)."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yield lists of ProfileBamTask objects, ``yield_size`` rows at a time.

        Each slice is collected with ``collect_async``, so control returns to the event
        loop while polars is reading data.
        """
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                inputs = {
                    "bam-file": FileInput(row["bamfile"]),
                    "sample-name": StringInput(row["sample_name"]),
                    "stb-file": FileInput(self.stb_file),
                    "bed-file": FileInput(self.profile_bed_file),
                    "gene-range-table": FileInput(self.gene_range_file),
                    "genome-length-file": FileInput(self.genome_length_file),
                    "num-threads": IntInput(self.num_procs),
                    "breadth-min-cov": IntInput(self.breadth_min_cov),
                }
                expected_outputs = {
                    "profile": FileOutput(row["sample_name"] + ".parquet"),
                    "breadth": FileOutput(row["sample_name"] + "_breadth.parquet"),
                    "scaffold": FileOutput(row["sample_name"] + ".parquet.scaffolds"),
                }
                task = ProfileBamTask(id=row["sample_name"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
|
|
535
|
+
|
|
536
|
+
class CompareTaskGenerator(TaskGenerator):
    """Generates FastCompareTask objects from a polars LazyFrame.

    Each task compares two profiles using the compare_genomes functionality in the
    zipstrain.compare module. Each row of ``data`` must provide ``profile_location_1``,
    ``profile_location_2``, ``scaffold_location_1``, ``scaffold_location_2``,
    ``sample_name_1`` and ``sample_name_2``.

    Args:
        data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
        yield_size (int): Number of tasks to yield at a time.
        container_engine (Engine): Engine used to run every generated task.
        comp_config (database.GenomeComparisonConfig): Configuration for genome comparison.
            Its ``null_model_loc``, ``stb_file_loc``, ``min_cov``, ``min_gene_compare_len``
            and ``scope`` fields are passed to each task.
        memory_mode (str): Memory mode for the comparison task. Default is "heavy".
        polars_engine (str): Polars engine passed to each task. Default is "streaming".
        chrom_batch_size (int): Chromosome batch size for the comparison task in light memory mode. Default is 10000.

    Raises:
        ValueError: If ``data`` is not a polars LazyFrame.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        comp_config: database.GenomeComparisonConfig,
        memory_mode: str = "heavy",
        polars_engine: str = "streaming",
        chrom_batch_size: int = 10000,
    ) -> None:
        # Validate before the base constructor runs. It immediately calls
        # get_total_tasks(), which would fail on non-LazyFrame input with an unrelated
        # error and make this check unreachable.
        if not isinstance(data, pl.LazyFrame):
            raise ValueError("data must be a polars LazyFrame.")
        self.comp_config = comp_config
        self.engine = container_engine
        self.memory_mode = memory_mode
        self.polars_engine = polars_engine
        self.chrom_batch_size = chrom_batch_size
        super().__init__(data, yield_size)

    def get_total_tasks(self) -> int:
        """Return the total number of pairwise comparisons to be made (the row count of ``data``)."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yield lists of FastCompareTask objects, ``yield_size`` rows at a time.

        Each slice is collected with ``collect_async``, so control returns to the event
        loop while polars is reading data.
        """
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                inputs = {
                    "mpile_1_file": FileInput(row["profile_location_1"]),
                    "mpile_2_file": FileInput(row["profile_location_2"]),
                    "scaffold_1_file": FileInput(row["scaffold_location_1"]),
                    "scaffold_2_file": FileInput(row["scaffold_location_2"]),
                    "null_model_file": FileInput(self.comp_config.null_model_loc),
                    "stb_file": FileInput(self.comp_config.stb_file_loc),
                    "min_cov": IntInput(self.comp_config.min_cov),
                    "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
                    "memory-mode": StringInput(self.memory_mode),
                    "chrom-batch-size": IntInput(self.chrom_batch_size),
                    "genome-name": StringInput(self.comp_config.scope),
                    "engine": StringInput(self.polars_engine),
                }
                expected_outputs = {
                    "output-file": FileOutput(row["sample_name_1"] + "_" + row["sample_name_2"] + "_comparison.parquet"),
                }
                task = FastCompareTask(id=row["sample_name_1"] + "_" + row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
class Batch(ABC):
|
|
603
|
+
"""Batch is a collection of tasks to be executed as a group. This is a base class and should not be instantiated directly.
|
|
604
|
+
A batch is the unit of execution meaning that the enitre batch is either run locally or submitted to a job scheduler like Slurm.
|
|
605
|
+
|
|
606
|
+
Args:
|
|
607
|
+
tasks (list[Task]): List of Task objects to be included in the batch.
|
|
608
|
+
id (str): Unique identifier for the batch.
|
|
609
|
+
run_dir (pathlib.Path): Directory where the batch will be executed.
|
|
610
|
+
expected_outputs (list[Output]): List of expected outputs for the batch.
|
|
611
|
+
"""
|
|
612
|
+
TEMPLATE_CMD = ""
|
|
613
|
+
|
|
614
|
+
def __init__(self, tasks: list[Task],
|
|
615
|
+
id: str,
|
|
616
|
+
run_dir: pathlib.Path,
|
|
617
|
+
expected_outputs: list[Output],
|
|
618
|
+
file_semaphore: asyncio.Semaphore| None = None
|
|
619
|
+
) -> None:
|
|
620
|
+
self.id = id
|
|
621
|
+
self.tasks = tasks
|
|
622
|
+
self.run_dir = pathlib.Path(run_dir)
|
|
623
|
+
self.batch_dir = self.run_dir / self.id
|
|
624
|
+
self.retry_count = 0
|
|
625
|
+
self.expected_outputs = expected_outputs
|
|
626
|
+
self.file_semaphore = file_semaphore
|
|
627
|
+
for output in self.expected_outputs:
|
|
628
|
+
if isinstance(output, BatchFileOutput):
|
|
629
|
+
output.register_batch(self)
|
|
630
|
+
self._status = self._get_initial_status()
|
|
631
|
+
for task in self.tasks:
|
|
632
|
+
task._batch_obj = self
|
|
633
|
+
task.file_semaphore = self.file_semaphore
|
|
634
|
+
for output in task.expected_outputs.values():
|
|
635
|
+
output.register_task(task)
|
|
636
|
+
task._status= task._get_initial_status()
|
|
637
|
+
task.map_io()
|
|
638
|
+
|
|
639
|
+
self._runner_obj:Runner = None
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def _get_initial_status(self) -> str:
|
|
644
|
+
"""Returns the initial status of the batch based on the presence of the batch directory."""
|
|
645
|
+
if not self.batch_dir.exists():
|
|
646
|
+
return Status.NOT_STARTED.value
|
|
647
|
+
with open(self.batch_dir / ".status", mode="r") as f:
|
|
648
|
+
status_as_written = f.read().strip()
|
|
649
|
+
if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
|
|
650
|
+
all_ready = True
|
|
651
|
+
try:
|
|
652
|
+
for output in self.expected_outputs:
|
|
653
|
+
if not output.ready():
|
|
654
|
+
all_ready = False
|
|
655
|
+
break
|
|
656
|
+
except Exception:
|
|
657
|
+
all_ready = False
|
|
658
|
+
|
|
659
|
+
if all_ready:
|
|
660
|
+
return Status.SUCCESS.value
|
|
661
|
+
else:
|
|
662
|
+
return Status.FAILED.value
|
|
663
|
+
|
|
664
|
+
def cleanup(self) -> None:
|
|
665
|
+
"""The base class defines if any cleanup is needed after batch success. By default, it does nothing."""
|
|
666
|
+
return None
|
|
667
|
+
|
|
668
|
+
@abstractmethod
|
|
669
|
+
async def cancel(self) -> None:
|
|
670
|
+
"""Cancels the batch. This method should be implemented by subclasses."""
|
|
671
|
+
...
|
|
672
|
+
|
|
673
|
+
def outputs_ready(self) -> bool:
|
|
674
|
+
"""Check if all BATCH-LEVEL expected outputs are ready."""
|
|
675
|
+
try:
|
|
676
|
+
for output in self.expected_outputs:
|
|
677
|
+
if not output.ready():
|
|
678
|
+
return False
|
|
679
|
+
return True
|
|
680
|
+
except Exception:
|
|
681
|
+
return False
|
|
682
|
+
|
|
683
|
+
async def _collect_task_status(self) -> list[str]:
|
|
684
|
+
"""Collects the status of all tasks asynchronously."""
|
|
685
|
+
return await asyncio.gather(*[task.get_status() for task in self.tasks])
|
|
686
|
+
|
|
687
|
+
@abstractmethod
async def run(self) -> None:
    """Execute the batch; concrete batch types must supply the implementation."""
    pass
|
|
691
|
+
|
|
692
|
+
@abstractmethod
|
|
693
|
+
def _parse_job_id(self, sbatch_output: str) -> str:
|
|
694
|
+
"""Parses the job ID from the sbatch output. This method should be implemented by subclasses."""
|
|
695
|
+
...
|
|
696
|
+
|
|
697
|
+
@property
def status(self) -> str:
    """The batch's current status string, as cached in ``self._status``."""
    current = self._status
    return current
|
|
701
|
+
|
|
702
|
+
@property
def stats(self) -> dict[str, str]:
    """Mapping from each task's id to that task's current status."""
    return dict((member.id, member.status) for member in self.tasks)
|
|
706
|
+
|
|
707
|
+
async def update_status(self) -> None:
    """Refresh the status of every task in the batch.

    The base implementation only polls the tasks (via ``_collect_task_status``) and
    leaves the batch's own status untouched; subclasses that track an external job
    (e.g. ``SlurmBatch``) override it to update the batch status as well.
    Returns ``None``.
    """
    await self._collect_task_status()
|
|
710
|
+
def _set_file_semaphore(self, file_semaphore: asyncio.Semaphore) -> None:
|
|
711
|
+
self.file_semaphore = file_semaphore
|
|
712
|
+
for task in self.tasks:
|
|
713
|
+
task.file_semaphore = file_semaphore
|
|
714
|
+
|
|
715
|
+
class LocalBatch(Batch):
    """Batch that runs tasks locally in a single shell script.

    Every task that still needs to run contributes its ``pre_run``, ``command`` and
    ``post_run`` to ``<batch_dir>/<id>.sh``, which is executed with ``bash`` from inside
    the batch directory. The script's stdout/stderr are saved to ``<id>.out`` and
    ``<id>.err``; the batch status is persisted to ``<batch_dir>/.status``.
    """
    TEMPLATE_CMD = "#!/bin/bash\n"

    def __init__(self, tasks, id, run_dir, expected_outputs) -> None:
        super().__init__(tasks, id, run_dir, expected_outputs)
        self._script = self.TEMPLATE_CMD + "\nset -o pipefail\n"
        # Handle on the running bash process, kept so cancel() can stop it.
        self._proc: asyncio.subprocess.Process | None = None

    async def run(self) -> None:
        """Run all pending tasks of the batch locally and record the outcome.

        A batch whose status is already SUCCESS is not executed again; it is only
        downgraded to FAILED if its expected outputs are no longer ready. Any other
        status (including FAILED from an earlier attempt) triggers an execution, so
        retries issued by the Runner really re-run the batch. Tasks whose status is
        NOT_STARTED or FAILED are included in the script.

        The batch ends as SUCCESS only if the script exits with code 0 and all
        batch-level expected outputs are ready; otherwise it ends as FAILED.
        """
        if self.status == Status.SUCCESS.value:
            if not self.outputs_ready():
                self._status = Status.FAILED.value
            return

        self.batch_dir.mkdir(parents=True, exist_ok=True)
        self._status = Status.RUNNING.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

        for task in self.tasks:
            if task.status == Status.NOT_STARTED.value:
                task.task_dir.mkdir(parents=True, exist_ok=True)  # Create task directory
                await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)

        rerun_states = (Status.NOT_STARTED.value, Status.FAILED.value)
        script = self._script + "".join(
            f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"
            for task in self.tasks
            if task.status in rerun_states
        )
        script_path = self.batch_dir / f"{self.id}.sh"  # Path to the shell script for the batch
        await write_file(script_path, script, self.file_semaphore)

        self._proc = await asyncio.create_subprocess_exec(
            "bash", script_path.name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=self.batch_dir,
        )
        try:
            out_bytes, err_bytes = await self._proc.communicate()
        except asyncio.CancelledError:
            # Do not leave the script running when this coroutine is cancelled.
            if self._proc.returncode is None:
                self._proc.terminate()
            raise

        # Tools may emit non-UTF-8 bytes; keep the logs instead of crashing on decode.
        await write_file(self.batch_dir / f"{self.id}.out", out_bytes.decode(errors="replace"), self.file_semaphore)
        await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(errors="replace"), self.file_semaphore)

        if self._proc.returncode == 0 and self.outputs_ready():
            self.cleanup()
            self._status = Status.SUCCESS.value
        else:
            self._status = Status.FAILED.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

    def _parse_job_id(self, sbatch_output):
        return super()._parse_job_id(sbatch_output)

    def cleanup(self) -> None:
        super().cleanup()

    async def cancel(self) -> None:
        """Cancel the local batch by terminating the bash subprocess if it is running.

        The process gets 5 seconds to exit after SIGTERM before it is killed; the batch
        is then marked FAILED on disk. Does nothing if no process is running.
        """
        if self._proc and self._proc.returncode is None:
            self._proc.terminate()
            try:
                await asyncio.wait_for(self._proc.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                self._proc.kill()
                await self._proc.wait()
            self._status = Status.FAILED.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
class SlurmBatch(Batch):
    """Batch that submits tasks to a Slurm job scheduler.

    Args:
        tasks (list[Task]): List of Task objects to be included in the batch.
        id (str): Unique identifier for the batch.
        run_dir (pathlib.Path): Directory where the batch will be executed.
        expected_outputs (list[Output]): List of expected outputs for the batch.
        slurm_config (SlurmConfig): Configuration for Slurm job submission. Refer to SlurmConfig class for details."""
    TEMPLATE_CMD = "#!/bin/bash\n"
    # sacct job states after which the job will not run any further and has not succeeded.
    _SLURM_FAILED_STATES = frozenset({
        "BOOT_FAIL", "CANCELLED", "DEADLINE", "FAILED", "NODE_FAIL",
        "OUT_OF_MEMORY", "PREEMPTED", "REVOKED", "TIMEOUT",
    })
    _SLURM_SUCCESS_STATES = frozenset({"COMPLETED", "COMPLETING"})

    def __init__(self, tasks, id, run_dir, expected_outputs, slurm_config: SlurmConfig) -> None:
        super().__init__(tasks, id, run_dir, expected_outputs)
        self._check_slurm_works()
        self.slurm_config = slurm_config
        self._script = self.TEMPLATE_CMD + self.slurm_config.to_slurm_args() + "\nset -o pipefail\n"
        # Slurm job id of the current submission; None until sbatch succeeds.
        self._job_id = None

    def _check_slurm_works(self) -> None:
        """Check that the ``sbatch`` and ``sacct`` commands are available and runnable.

        Raises:
            EnvironmentError: If either command is missing or exits with an error.
        """
        try:
            subprocess.run(["sbatch", "--version"], capture_output=True, text=True, check=True)
            subprocess.run(["sacct", "--version"], capture_output=True, text=True, check=True)
        except (OSError, subprocess.SubprocessError) as err:
            raise EnvironmentError("Slurm does not seem to be available or configured properly on this system.") from err

    async def cancel(self) -> None:
        """Cancel a running or submitted Slurm job with ``scancel`` and mark the batch FAILED.

        Does nothing if no job has been submitted yet.
        """
        if self._job_id:
            proc = await asyncio.create_subprocess_exec(
                "scancel", self._job_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            # communicate() drains both pipes, so scancel can never block on a full pipe.
            await proc.communicate()
            self._status = Status.FAILED.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

    async def run(self) -> None:
        """Submit the batch to Slurm with ``sbatch`` and wait until the job finishes.

        This method is unavoidably different from LocalBatch.run() because of the nature
        of Slurm job submission. A batch whose status is already SUCCESS is not submitted
        again; it is only downgraded to FAILED if its outputs are no longer ready. Any
        other status (including FAILED from an earlier attempt) triggers a submission
        containing every task whose status is NOT_STARTED or FAILED. The batch ends as
        SUCCESS only if the job completes and all batch-level outputs are ready.
        """
        if self.status == Status.SUCCESS.value:
            if not self.outputs_ready():
                self._status = Status.FAILED.value
            return

        self.batch_dir.mkdir(parents=True, exist_ok=True)
        # Forget the job of a previous attempt so that it is not polled again.
        self._job_id = None
        self._status = Status.RUNNING.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

        # create task directories and initialize .status if needed
        for task in self.tasks:
            if task.status == Status.NOT_STARTED.value:
                task.task_dir.mkdir(parents=True, exist_ok=True)
                await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)

        # write the batch script with every task that still has to run
        rerun_states = (Status.NOT_STARTED.value, Status.FAILED.value)
        script = self._script + "".join(
            f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"
            for task in self.tasks
            if task.status in rerun_states
        )
        batch_path = self.batch_dir / f"{self.id}.batch"
        await write_file(batch_path, script, self.file_semaphore)

        proc = await asyncio.create_subprocess_exec(
            "sbatch", "--parsable", batch_path.name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.batch_dir),
        )
        out_bytes, _ = await proc.communicate()
        out = out_bytes.decode(errors="replace").strip() if out_bytes else ""
        if proc.returncode == 0:
            try:
                self._job_id = self._parse_job_id(out)
                self._status = Status.SUBMITTED.value
                await self._wait_to_finish()
            except Exception:
                self._status = Status.FAILED.value
        else:
            self._status = Status.FAILED.value

        if self._status == Status.SUCCESS.value and self.outputs_ready():
            self.cleanup()
            self._status = Status.SUCCESS.value
        else:
            self._status = Status.FAILED.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

    def _parse_job_id(self, sbatch_output: str) -> str:
        """Return the first run of digits in the ``sbatch --parsable`` output.

        Raises:
            ValueError: If the output contains no digits.
        """
        if match := re.search(r"(\d+)", sbatch_output):
            return match.group(1)
        else:
            raise ValueError("Could not parse job ID from sbatch output.")

    async def _wait_to_finish(self, sleep_duration: float = 1.0):
        """Poll the job status every *sleep_duration* seconds until it is SUCCESS or FAILED."""
        while self.status not in (Status.SUCCESS.value, Status.FAILED.value):
            await self.update_status()
            await asyncio.sleep(sleep_duration)

    async def update_status(self) -> None:
        """Refresh task statuses and map the job's sacct state onto the batch status.

        Before submission (no job id) the batch is reported as NOT_STARTED. Otherwise
        the first state token printed by ``sacct`` is used, so annotated states such as
        ``CANCELLED by <uid>`` are recognised. Failure states map to FAILED, RUNNING
        to RUNNING, COMPLETED/COMPLETING to SUCCESS, and anything else (e.g. PENDING,
        REQUEUED, SUSPENDED) to PENDING. If sacct prints nothing yet, the status is
        left unchanged.
        """
        if self._job_id is None:
            self._status = Status.NOT_STARTED.value
            return

        await self._collect_task_status()
        proc = await asyncio.create_subprocess_exec(
            "sacct", "-j", self._job_id, "--format=State", "--noheader", "--allocations", "--parsable2",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        out_bytes, _ = await proc.communicate()
        tokens = out_bytes.decode(errors="replace").split() if out_bytes else []
        if tokens:
            # A trailing '+' marks a field truncated by sacct; strip it defensively.
            state = tokens[0].rstrip("+")
            if state in self._SLURM_FAILED_STATES:
                self._status = Status.FAILED.value
            elif state == "RUNNING":
                self._status = Status.RUNNING.value
            elif state in self._SLURM_SUCCESS_STATES:
                self._status = Status.SUCCESS.value
            else:
                self._status = Status.PENDING.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
|
|
927
|
+
|
|
928
|
+
# Shared Rich console used by the runners for the live UI, shutdown/error messages
# and the final run summary.
console = Console()
|
|
929
|
+
|
|
930
|
+
def get_cpu_usage():
    """Return system-wide CPU utilisation in percent, sampled over a 0.1 s window."""
    sample_seconds = 0.1
    usage = psutil.cpu_percent(interval=sample_seconds)
    return usage
|
|
933
|
+
|
|
934
|
+
def get_memory_usage():
    """Return the percentage of system memory currently in use."""
    memory = psutil.virtual_memory()
    return memory.percent
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
class Runner(ABC):
    """Base Runner class to manage task generation, batching, and execution.

    A producer coroutine feeds tasks from ``task_generator`` into ``tasks_queue``, the
    subclass-specific ``_batcher`` groups them into batches on ``batches_queue``, and
    ``run`` executes up to ``max_concurrent_batches`` batches at a time while showing
    a live terminal UI.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        batch_factory (type[Batch]): The class that creates Batch instances. It should be a subclass of Batch or its subclasses.
        final_batch_factory (type[Batch] | None): The class used by subclasses to create the final Batch instance, or None.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.
        max_retries (int): Maximum number of attempts made for a batch before it is
            counted as failed. Default is 3.
    """
    TERMINAL_BATCH_STATES = {Status.SUCCESS.value, Status.FAILED.value}

    def __init__(self,
                 run_dir: str | pathlib.Path,
                 task_generator: TaskGenerator,
                 container_engine: Engine,
                 batch_factory: type[Batch],
                 final_batch_factory: type[Batch] | None,
                 max_concurrent_batches: int = 1,
                 poll_interval: float = 1.0,
                 tasks_per_batch: int = 10,
                 batch_type: str = "local",
                 slurm_config: SlurmConfig | None = None,
                 max_retries: int = 3,
                 ) -> None:
        self.run_dir = pathlib.Path(run_dir)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.task_generator = task_generator
        self.container_engine = container_engine
        self.max_concurrent_batches = max_concurrent_batches
        self.poll_interval = poll_interval
        self.tasks_per_batch = tasks_per_batch
        self.batch_type = batch_type
        self.slurm_config = slurm_config
        self.tasks_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches * tasks_per_batch)
        self.batches_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches)
        self._finished_batches_count = 0
        self._success_batches_count = 0
        self._produced_tasks_count = 0
        self._active_batches: list[Batch] = []
        self._batch_counter = 0
        self._batcher_done = False
        self._final_batch_created = False
        # True only if a final batch was actually created and queued (used for the summary).
        self._final_batch_queued = False
        self.batch_factory = batch_factory
        self.final_batch_factory = final_batch_factory
        self._failed_batches_count = 0
        self.max_retries = max_retries
        self.total_expected_tasks = self.task_generator.get_total_tasks()
        self.total_expected_batches = (self.total_expected_tasks + tasks_per_batch - 1) // tasks_per_batch
        self._shutdown_event = asyncio.Event()
        self._shutdown_initiated = False
        # Strong references to fire-and-forget tasks: the event loop keeps only weak
        # references, so unreferenced tasks may be garbage-collected mid-execution.
        self._background_tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _refill_tasks(self):
        """Feed tasks from ``task_generator.generate_tasks()`` into ``tasks_queue``.

        The generator yields lists of tasks; each task is put on the queue (waiting
        asynchronously while the queue is full). A ``None`` sentinel is queued once the
        generator is exhausted.
        """
        async for tasks in self.task_generator.generate_tasks():
            for task in tasks:
                await self.tasks_queue.put(task)
                self._produced_tasks_count += 1
        await self.tasks_queue.put(None)

    @abstractmethod
    async def _batcher(self):
        ...

    def _create_final_batch(self) -> Batch | None:
        """Return the batch to run after all regular batches, or None for no final batch.

        The base implementation returns None; subclasses override it, typically using
        ``final_batch_factory``.
        """
        return None

    async def _shutdown(self):
        """Cancel all active batches and signal shutdown."""
        if self._shutdown_initiated:
            return
        self._shutdown_initiated = True
        console.print("[yellow]Shutdown requested. Cancelling active jobs...[/]")

        for batch in list(self._active_batches):
            try:
                await batch.cancel()
            except Exception as e:
                console.print(f"[red]Error cancelling batch {batch.id}: {e}[/]")

        # Signal the main loop to stop
        self._shutdown_event.set()

    async def run(self):
        """
        Run the producer, batcher and worker coroutines and present a live UI while working.

        Runs the task generator to produce tasks, batches them using the batcher,
        and executes batches with up to [max_concurrent_batches] parallel workers.
        Each batch is attempted up to ``max_retries`` times until it succeeds.
        If the producer or batcher crashes, the run is shut down instead of hanging.
        UI: displays an overall panel (produced/finished counts), active batch Progress bars,
        and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.
        """
        batcher_task = self._spawn(self._batcher())
        producer_task = self._spawn(self._refill_tasks())
        semaphore = asyncio.Semaphore(self.max_concurrent_batches)
        file_semaphore = asyncio.Semaphore(20)

        async def run_batch(batch: Batch):
            async with semaphore:
                try:
                    while batch.status != Status.SUCCESS.value and batch.retry_count < self.max_retries:
                        try:
                            await batch.run()
                        except Exception as e:
                            # Count an unexpected error as one failed attempt instead of
                            # letting it kill this worker and leave the batch "active".
                            console.print(f"[red]Error running batch {batch.id}: {e}[/]")
                        if batch.status == Status.SUCCESS.value:
                            break
                        batch.retry_count += 1
                finally:
                    # Always account for the batch and release its slot, otherwise the
                    # main loop would wait for it forever.
                    self._finished_batches_count += 1
                    if batch.status == Status.SUCCESS.value:
                        self._success_batches_count += 1
                    else:
                        self._failed_batches_count += 1
                    if batch in self._active_batches:
                        self._active_batches.remove(batch)

        # Rich progress objects
        overall_progress = Progress(
            TextColumn(f"[bold white]{type(self).__name__}[/]"),
            BarColumn(),
            TextColumn("• {task.fields[produced_tasks]}/{task.fields[total_expected_tasks]} tasks produced"),
            TextColumn("• {task.fields[finished_batches]}/{task.fields[total_expected_batches]} batches finished • {task.fields[failed_batches]} batches failed"),
            TimeElapsedColumn(),
            expand=True,
        )
        overall_task = overall_progress.add_task("overall", produced_tasks=0, total_expected_tasks=self.total_expected_tasks, finished_batches=0, total_expected_batches=self.total_expected_batches, failed_batches=0)

        batch_progress = Progress(
            TextColumn("[bold white]{task.fields[batch_id]}[/]"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TextColumn("• {task.fields[status]}"),
            TimeElapsedColumn(),
            expand=True,
        )

        batch_to_progress_id: dict[Batch, int] = {}
        batch_task_totals: dict[Batch, int] = {}

        def make_body(batch_panel_height: int | None = None) -> Panel:
            """Assemble the full live-UI panel from the current progress objects."""
            return Panel(Group(
                Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
                Panel(overall_progress, title="Overall Progress"),
                Panel(batch_progress, title="Active Batches", height=batch_panel_height),
                Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
            ), border_style="magenta")

        loop = asyncio.get_running_loop()
        installed_signals = []
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, lambda: self._spawn(self._shutdown()))
            except (NotImplementedError, RuntimeError):
                # Not supported by this event loop (e.g. on Windows): keep default handling.
                continue
            installed_signals.append(sig)

        try:
            with Live(make_body(batch_panel_height=10), console=console, refresh_per_second=2) as live:
                while not self._shutdown_event.is_set():
                    # Without a live producer/batcher no more batches can arrive: stop
                    # instead of waiting forever.
                    crashed = next(
                        (t for t in (producer_task, batcher_task)
                         if t.done() and not t.cancelled() and t.exception() is not None),
                        None,
                    )
                    if crashed is not None:
                        console.print(f"[red]Task production failed: {crashed.exception()!r}[/]")
                        await self._shutdown()
                        break

                    await self._update_statuses()
                    if self._batcher_done and self.batches_queue.empty() and len(self._active_batches) == 0:
                        if self._final_batch_created:
                            # The final batch (if any) has already run: we are done.
                            break
                        final_batch = self._create_final_batch()
                        self._final_batch_created = True
                        if final_batch is None:
                            break
                        await self.batches_queue.put(final_batch)
                        self._final_batch_queued = True
                    while len(self._active_batches) < self.max_concurrent_batches and not self.batches_queue.empty():
                        batch = await self.batches_queue.get()
                        if batch is not None:
                            batch._set_file_semaphore(file_semaphore)
                            self._active_batches.append(batch)
                            self._spawn(run_batch(batch))
                    # Update overall progress fields
                    overall_progress.update(overall_task, produced_tasks=self._produced_tasks_count, finished_batches=self._finished_batches_count, failed_batches=self._failed_batches_count)
                    # Add newly queued batches into UI
                    for batch in list(self._active_batches):
                        if batch not in batch_to_progress_id and batch.status not in self.TERMINAL_BATCH_STATES:
                            total = len(batch.tasks) if batch.tasks else 1
                            task_id = batch_progress.add_task("", total=total, batch_id=f"Batch {batch.id}", status=batch.status)
                            batch_to_progress_id[batch] = task_id
                            batch_task_totals[batch] = total
                    # Remove finished (or no longer active) batches from UI
                    for batch, tid in list(batch_to_progress_id.items()):
                        if batch.status in self.TERMINAL_BATCH_STATES or batch not in self._active_batches:
                            try:
                                batch_progress.remove_task(tid)
                            except Exception:
                                pass  # purely cosmetic: the row may already be gone
                            del batch_to_progress_id[batch]
                            batch_task_totals.pop(batch, None)
                    # Update per-batch progress
                    for batch, tid in batch_to_progress_id.items():
                        completed = sum(1 for t in batch.tasks if t.status in self.TERMINAL_BATCH_STATES)
                        total = batch_task_totals.get(batch, max(1, len(batch.tasks)))
                        batch_progress.update(tid, completed=completed, total=total, status=batch.status)
                    # Update system panel
                    live.update(make_body())
                    await asyncio.sleep(self.poll_interval)
        finally:
            for sig in installed_signals:
                loop.remove_signal_handler(sig)

        # final UI summary
        console.clear()
        total_batches = self._batch_counter + (1 if self._final_batch_queued else 0)
        summary = Panel(
            f"[bold green]Run finished![/]\n\n{self._success_batches_count}/{total_batches} batches succeeded.\n\nProduced tasks: {self._produced_tasks_count}\nElapsed: (see time in UI)",
            expand=True,
            title="Summary",
            border_style="green",
        )
        console.print(summary)

    async def _update_statuses(self):
        """Concurrently refresh the status of every active batch that is not yet terminal."""
        await asyncio.gather(*[batch.update_status() for batch in self._active_batches if batch.status not in self.TERMINAL_BATCH_STATES])

    def _make_system_stats_panel(self):
        """helpers to create a system stats panel for the live UI."""
        def usage_bar(label: str, percent: float, color: str):
            p = Progress(
                TextColumn(f"[bold]{label}[/]"),
                BarColumn(bar_width=None, complete_style=color),
                TextColumn(f"{percent:.1f}%"),
                expand=True,
            )
            p.add_task("", total=100, completed=percent)
            return Panel(p, expand=True, width=30)

        cpu = psutil.cpu_percent(interval=None)
        ram = psutil.virtual_memory().percent
        cpu_panel = usage_bar("CPU", cpu, "cyan")
        ram_panel = usage_bar("RAM", ram, "magenta")
        return Columns([cpu_panel, ram_panel], expand=True, equal=True, align="center")
|
|
1183
|
+
|
|
1184
|
+
class ProfileRunner(Runner):
    """
    Creates and schedules batches of ProfileBamTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.
        max_retries (int): Maximum number of attempts made for a batch before it is
            counted as failed. Default is 3.

    Raises:
        ValueError: If batch_type is "slurm" but no slurm_config is given.
    """
    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
        max_retries: int = 3,
    ) -> None:
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = SlurmBatch
        else:
            batch_factory = LocalBatch

        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=None,  # profiling runs have no final batch
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
            max_retries=max_retries,
        )

    def _make_batch(self, tasks: list[Task]) -> Batch:
        """Build the next batch (local or Slurm) from *tasks* with a fresh ``batch_<n>`` id."""
        batch_id = f"batch_{self._batch_counter}"
        self._batch_counter += 1
        kwargs = dict(tasks=tasks, id=batch_id, run_dir=self.run_dir, expected_outputs=[])
        if self.batch_type == "slurm":
            kwargs["slurm_config"] = self.slurm_config
        return self.batch_factory(**kwargs)

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue.

        Batches hold ``tasks_per_batch`` tasks; the last one may be smaller. When the
        ``None`` sentinel arrives, any remaining tasks are flushed as a final partial
        batch, a ``None`` sentinel is queued and ``_batcher_done`` is set.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                if buffer:
                    await self.batches_queue.put(self._make_batch(buffer))
                await self.batches_queue.put(None)
                self._batcher_done = True
                break
            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                await self.batches_queue.put(self._make_batch(buffer))
                buffer = []
|
|
1284
|
+
|
|
1285
|
+
|
|
1286
|
+
class CompareRunner(Runner):
|
|
1287
|
+
"""
|
|
1288
|
+
Creates and schedules batches of FastCompareTask tasks using either local or Slurm batches.
|
|
1289
|
+
|
|
1290
|
+
Args:
|
|
1291
|
+
run_dir (str | pathlib.Path): Directory where the runner will operate.
|
|
1292
|
+
task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
|
|
1293
|
+
container_engine (Engine): An instance of Engine to wrap task commands.
|
|
1294
|
+
max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
|
|
1295
|
+
poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
|
|
1296
|
+
tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
|
|
1297
|
+
batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
|
|
1298
|
+
slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
|
|
1299
|
+
is "slurm". Default is None.
|
|
1300
|
+
"""
|
|
1301
|
+
|
|
1302
|
+
def __init__(
|
|
1303
|
+
self,
|
|
1304
|
+
run_dir: str | pathlib.Path,
|
|
1305
|
+
task_generator: TaskGenerator,
|
|
1306
|
+
container_engine: Engine,
|
|
1307
|
+
max_concurrent_batches: int = 1,
|
|
1308
|
+
poll_interval: float = 1.0,
|
|
1309
|
+
tasks_per_batch: int = 10,
|
|
1310
|
+
batch_type: str = "local",
|
|
1311
|
+
slurm_config: SlurmConfig | None = None,
|
|
1312
|
+
) -> None:
|
|
1313
|
+
if batch_type == "slurm":
|
|
1314
|
+
if slurm_config is None:
|
|
1315
|
+
raise ValueError("Slurm config must be provided for slurm batch type.")
|
|
1316
|
+
batch_factory = FastCompareSlurmBatch
|
|
1317
|
+
final_batch_factory = PrepareCompareGenomeRunOutputsSlurmBatch
|
|
1318
|
+
else:
|
|
1319
|
+
batch_factory = FastCompareLocalBatch
|
|
1320
|
+
final_batch_factory = PrepareCompareGenomeRunOutputsLocalBatch
|
|
1321
|
+
super().__init__(
|
|
1322
|
+
run_dir=run_dir,
|
|
1323
|
+
task_generator=task_generator,
|
|
1324
|
+
container_engine=container_engine,
|
|
1325
|
+
batch_factory=batch_factory,
|
|
1326
|
+
final_batch_factory=final_batch_factory,
|
|
1327
|
+
max_concurrent_batches=max_concurrent_batches,
|
|
1328
|
+
poll_interval=poll_interval,
|
|
1329
|
+
tasks_per_batch=tasks_per_batch,
|
|
1330
|
+
batch_type=batch_type,
|
|
1331
|
+
slurm_config=slurm_config,
|
|
1332
|
+
)
|
|
1333
|
+
|
|
1334
|
+
|
|
1335
|
+
|
|
1336
|
+
|
|
1337
|
+
async def _batcher(self):
    """Group incoming comparison tasks into batches and put them on ``self.batches_queue``.

    Tasks are read from ``self.tasks_queue`` until the ``None`` sentinel arrives.
    Every ``self.tasks_per_batch`` tasks are turned into one batch. Any remainder
    seen when the sentinel arrives becomes a final, smaller batch. Each batch also
    gets a ``CollectComps`` task that merges the per-task outputs.
    After the sentinel, ``None`` is forwarded to ``self.batches_queue`` and
    ``self._batcher_done`` is set to True.
    """
    buffer: list[Task] = []
    while True:
        task = await self.tasks_queue.get()
        if task is None:
            # Flush the last (possibly partial) batch before signalling completion.
            if buffer:
                await self.batches_queue.put(self._build_compare_batch(buffer))
            await self.batches_queue.put(None)
            self._batcher_done = True
            break
        buffer.append(task)
        if len(buffer) == self.tasks_per_batch:
            await self.batches_queue.put(self._build_compare_batch(buffer))
            buffer = []

def _build_compare_batch(self, tasks: list[Task]) -> Batch:
    """Build one batch from *tasks* plus a ``CollectComps`` merge task.

    Consumes one value of ``self._batch_counter`` for the batch id. ``slurm_config``
    is only passed to the batch factory when running on Slurm.
    """
    batch_id = f"batch_{self._batch_counter}"
    self._batch_counter += 1
    batch_tasks = tasks + [
        CollectComps(
            "concat_parquet",
            {},
            {"output-file": FileOutput(f"Merged_batch_{batch_id}.parquet")},
            engine=self.container_engine,
        )
    ]
    # "concat_parquet" matches the CollectComps task id above, so this points at its merged output.
    expected_outputs = [BatchFileOutput(f"concat_parquet/Merged_batch_{batch_id}.parquet")]
    extra_kwargs = {"slurm_config": self.slurm_config} if self.batch_type == "slurm" else {}
    return self.batch_factory(
        tasks=batch_tasks,
        id=batch_id,
        run_dir=self.run_dir,
        expected_outputs=expected_outputs,
        **extra_kwargs,
    )
|
|
1407
|
+
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
def _create_final_batch(self) -> Batch:
    """Build the closing batch that merges every per-batch comparison parquet.

    The batch holds a single ``PrepareCompareGenomeRunOutputs`` task writing into
    ``Outputs``. Its expected output is ``all_comparisons.parquet``.
    """
    prepare_task = PrepareCompareGenomeRunOutputs(
        id="prepare_outputs",
        inputs={"output-dir": StringInput("Outputs")},
        expected_outputs={},
        engine=self.container_engine,
    )
    factory_kwargs = dict(
        tasks=[prepare_task],
        id="Outputs",
        run_dir=self.run_dir,
        expected_outputs=[BatchFileOutput("all_comparisons.parquet")],
    )
    if self.batch_type == "slurm":
        factory_kwargs["slurm_config"] = self.slurm_config
    closing_batch = self.final_batch_factory(**factory_kwargs)
    # PrepareCompareGenomeRunOutputs.pre_run reads `_batch_obj._runner_obj.run_dir`,
    # so the batch must carry a back-reference to this runner.
    closing_batch._runner_obj = self
    return closing_batch
|
|
1438
|
+
|
|
1439
|
+
|
|
1440
|
+
|
|
1441
|
+
class ProfileBamTask(Task):
    """A Task that profiles one BAM file and computes per-genome breadth.

    ``TEMPLATE_CMD`` runs the following steps in order:

    - It symlinks the BAM, BED and gene-range inputs into the working directory.
    - It indexes the BAM with ``samtools index``.
    - It runs ``zipstrain profile profile-single``, renaming its output to
      ``<sample-name>.parquet``.
    - It writes the names of scaffolds whose third ``samtools idxstats`` column
      (mapped reads) is > 0 to ``<sample-name>.parquet.scaffolds``.
    - It runs ``zipstrain utilities genome_breadth_matrix`` to produce
      ``<sample-name>_breadth.parquet``.

    Placeholders used by the template:

    - bam-file: The input BAM file to be profiled.
    - bed-file: The BED file specifying the regions to profile.
    - gene-range-table: A BED file specifying the gene ranges for the sample.
    - num-workers: Number of workers passed to ``profile-single --num-workers``.
      NOTE(review): an earlier version of this docstring called this input
      ``num-threads``; confirm the key supplied by the task generator matches.
    - sample-name: The name of the sample; used to name all output files.
    - genome-length-file: A file containing the lengths of the genomes in the reference fasta.
    - stb-file: The STB file passed to ``genome_breadth_matrix --stb``.
    - breadth-min-cov: Minimum coverage passed to ``genome_breadth_matrix --min-cov``.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""

    TEMPLATE_CMD="""
ln -s <bam-file> input.bam
ln -s <bed-file> bed_file.bed
ln -s <gene-range-table> gene-range-table.bed
samtools index <bam-file>
zipstrain profile profile-single --bam-file input.bam \
--bed-file bed_file.bed \
--gene-range-table gene-range-table.bed \
--num-workers <num-workers> \
--output-dir .
mv input.bam.parquet <sample-name>.parquet
samtools idxstats <bam-file> | awk '$3 > 0 {print $1}' > <sample-name>.parquet.scaffolds
zipstrain utilities genome_breadth_matrix --profile <sample-name>.parquet \
--genome-length <genome-length-file> \
--stb <stb-file> \
--min-cov <breadth-min-cov> \
--output-file <sample-name>_breadth.parquet
"""
|
|
1483
|
+
|
|
1484
|
+
class FastCompareTask(Task):
    """A Task that compares two sample profiles for one genome using ``zipstrain compare single_compare_genome``.

    Placeholders used by ``TEMPLATE_CMD``:

    - mpile_1_file / mpile_2_file: The two profile (mpileup) parquet files.
    - scaffold_1_file / scaffold_2_file: The scaffold lists for the two samples.
    - null_model_file: The null model file.
    - stb_file: The scaffold-to-bin (STB) file.
    - min_cov: Minimum coverage.
    - min-gene-compare-len: Minimum gene length for gene-level comparison.
    - memory-mode: Value passed to ``--memory-mode``.
    - chrom-batch-size: Value passed to ``--chrom-batch-size``.
    - output-file: Output parquet file.
    - genome-name: Genome to compare.
    - engine: Polars engine passed to ``--engine``.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
    """

    TEMPLATE_CMD="""
zipstrain compare single_compare_genome --mpileup-contig-1 <mpile_1_file> \
--mpileup-contig-2 <mpile_2_file> \
--scaffolds-1 <scaffold_1_file> \
--scaffolds-2 <scaffold_2_file> \
--null-model <null_model_file> \
--stb-file <stb_file> \
--min-cov <min_cov> \
--min-gene-compare-len <min-gene-compare-len> \
--memory-mode <memory-mode> \
--chrom-batch-size <chrom-batch-size> \
--output-file <output-file> \
--genome <genome-name> \
--engine <engine>
"""
|
|
1508
|
+
|
|
1509
|
+
|
|
1510
|
+
class CollectComps(Task):
    """Merge the ``*_comparison.parquet`` files of sibling task directories into one parquet.

    The command copies every ``*/*_comparison.parquet`` into a scratch ``comps``
    directory. It then merges them with ``zipstrain utilities merge_parquet`` into
    ``<output-file>`` and removes the scratch directory.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""

    TEMPLATE_CMD = """
mkdir -p comps
cp */*_comparison.parquet comps/
zipstrain utilities merge_parquet --input-dir comps --output-file <output-file>
rm -rf comps
"""

    @property
    def pre_run(self) -> str:
        """Shell snippet that only records RUNNING in this task's ``.status`` file."""
        status_dir = self.task_dir.absolute()
        return f"echo {Status.RUNNING.value} > {status_dir}/.status"
|
|
1528
|
+
|
|
1529
|
+
|
|
1530
|
+
|
|
1531
|
+
class PrepareCompareGenomeRunOutputs(Task):
    """Merge every ``Merged_batch_*.parquet`` under the run directory into ``<output-dir>/all_comparisons.parquet``.

    The merged batch files are symlinked into a scratch ``<output-dir>/comps``
    directory. They are merged with ``zipstrain utilities merge_parquet`` and the
    scratch directory is then removed.
    """

    TEMPLATE_CMD = """
mkdir -p <output-dir>/comps
find "$(pwd)" -type f -name "Merged_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/comps/
zipstrain utilities merge_parquet --input-dir <output-dir>/comps --output-file <output-dir>/all_comparisons.parquet
rm -rf <output-dir>/comps
"""

    @property
    def pre_run(self) -> str:
        """Record RUNNING in ``.status``, then ``cd`` into the runner's run directory.

        The ``cd`` is needed because the merge must see the outputs of all batches,
        which live under the run directory.
        """
        own_dir = self.task_dir.absolute()
        run_root = self._batch_obj._runner_obj.run_dir.absolute()
        return f"echo {Status.RUNNING.value} > {own_dir}/.status && cd {run_root}"
|
|
1544
|
+
|
|
1545
|
+
|
|
1546
|
+
|
|
1547
|
+
|
|
1548
|
+
class FastCompareLocalBatch(LocalBatch):
    """A LocalBatch that runs FastCompareTask tasks locally."""

    def cleanup(self) -> None:
        """Drop every FastCompareTask from this batch and delete its task directory.

        Other tasks (e.g. the CollectComps merge task) are kept. A task directory
        that no longer exists is skipped instead of raising FileNotFoundError.
        """
        removed = [task for task in self.tasks if isinstance(task, FastCompareTask)]
        # Filter in place (keeps the same list object) instead of O(n) list.remove per task.
        self.tasks[:] = [task for task in self.tasks if not isinstance(task, FastCompareTask)]
        for task in removed:
            try:
                shutil.rmtree(task.task_dir)
            except FileNotFoundError:
                # Already gone (e.g. removed by an earlier cleanup); nothing left to delete.
                pass
|
|
1555
|
+
|
|
1556
|
+
class FastCompareSlurmBatch(SlurmBatch):
    """A SlurmBatch that runs FastCompareTask tasks on a Slurm cluster. Maybe removed in future"""

    def cleanup(self) -> None:
        """Drop every FastCompareTask from this batch and delete its task directory.

        Other tasks (e.g. the CollectComps merge task) are kept. A task directory
        that no longer exists is skipped instead of raising FileNotFoundError.
        """
        removed = [task for task in self.tasks if isinstance(task, FastCompareTask)]
        # Filter in place (keeps the same list object) instead of O(n) list.remove per task.
        self.tasks[:] = [task for task in self.tasks if not isinstance(task, FastCompareTask)]
        for task in removed:
            try:
                shutil.rmtree(task.task_dir)
            except FileNotFoundError:
                # Already gone (e.g. removed by an earlier cleanup); nothing left to delete.
                pass
|
|
1563
|
+
|
|
1564
|
+
class PrepareCompareGenomeRunOutputsLocalBatch(LocalBatch):
    """LocalBatch used for the final genome-comparison output-preparation step; adds no behavior."""
|
|
1566
|
+
|
|
1567
|
+
|
|
1568
|
+
class PrepareCompareGenomeRunOutputsSlurmBatch(SlurmBatch):
    """SlurmBatch used for the final genome-comparison output-preparation step; adds no behavior."""
|
|
1570
|
+
|
|
1571
|
+
|
|
1572
|
+
def lazy_run_profile(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    bams_lf: pl.LazyFrame,
    stb_file: pathlib.Path,
    gene_range_table: pathlib.Path,
    bed_file: pathlib.Path,
    genome_length_file: pathlib.Path,
    num_procs: int = 8,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
) -> None:
    """Set up a ProfileRunner for the BAM files in *bams_lf* and run it to completion.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        bams_lf (pl.LazyFrame): Table of BAM files to profile. Its expected columns are
            defined by ProfileTaskGenerator.
        stb_file (pathlib.Path): STB file forwarded to ProfileTaskGenerator.
        gene_range_table (pathlib.Path): Gene range table forwarded as ``gene_range_file``.
        bed_file (pathlib.Path): BED file forwarded as ``profile_bed_file``.
        genome_length_file (pathlib.Path): Genome length file forwarded to ProfileTaskGenerator.
        num_procs (int): Forwarded to ProfileTaskGenerator as ``num_procs``. Default is 8.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration, needed when execution_mode is "slurm".

    Raises:
        ValueError: If ``execution_mode`` is not "local" or "slurm".
    """
    # Validate before building the task generator so a typo fails fast.
    if execution_mode not in ("local", "slurm"):
        raise ValueError(f"Unknown execution mode: {execution_mode}")

    profile_task_generator = ProfileTaskGenerator(
        data=bams_lf,
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        stb_file=stb_file,
        profile_bed_file=bed_file,
        gene_range_file=gene_range_table,
        genome_length_file=genome_length_file,
        num_procs=num_procs,
    )
    runner = ProfileRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=profile_task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=execution_mode,
        slurm_config=slurm_config,
    )
    asyncio.run(runner.run())
|
|
1615
|
+
|
|
1616
|
+
|
|
1617
|
+
def lazy_run_compares(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    comps_db: database.GenomeComparisonDatabase|None = None,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
    memory_mode: str = "heavy",
    chrom_batch_size: int = 10000,
    polars_engine: str = "streaming"
) -> None:
    """A helper function to quickly set up and run a CompareRunner with given parameters.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        comps_db (GenomeComparisonDatabase | None): Database providing the comparison input table
            (``to_complete_input_table()``) and configuration (``config``). Required despite the
            ``None`` default.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration, needed when execution_mode is "slurm".
        memory_mode (str): Forwarded to CompareTaskGenerator. Default is "heavy".
        chrom_batch_size (int): Forwarded to CompareTaskGenerator. Default is 10000.
        polars_engine (str): Forwarded to CompareTaskGenerator. Default is "streaming".

    Raises:
        ValueError: If ``comps_db`` is None or ``execution_mode`` is not "local" or "slurm".
    """
    if comps_db is None:
        # Previously this surfaced as an AttributeError on `None.to_complete_input_table`.
        raise ValueError("comps_db must be provided to run genome comparisons.")
    if execution_mode not in ("local", "slurm"):
        raise ValueError(f"Unknown execution mode: {execution_mode}")

    task_generator = CompareTaskGenerator(
        data=comps_db.to_complete_input_table(),
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        comp_config=comps_db.config,
        memory_mode=memory_mode,
        polars_engine=polars_engine,
        chrom_batch_size=chrom_batch_size,
    )
    runner = CompareRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=execution_mode,
        slurm_config=slurm_config,
    )
    asyncio.run(runner.run())
|
|
1667
|
+
|
|
1668
|
+
|
|
1669
|
+
class FastGeneCompareTask(Task):
    """Compare two sample profiles gene by gene via ``zipstrain compare single_compare_gene``.

    Placeholders used by ``TEMPLATE_CMD``:

    - mpile_1_file / mpile_2_file: The two profile (mpileup) parquet files.
    - null_model_file: The null model file.
    - stb_file: The scaffold-to-bin (STB) file.
    - min_cov: Minimum coverage.
    - min-gene-compare-len: Minimum gene length for comparison.
    - output-file: Output parquet file.
    - engine: Polars engine passed to ``--engine``.
    - ani-method: ANI method passed to ``--ani-method``.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
    """

    TEMPLATE_CMD = """
zipstrain compare single_compare_gene --mpileup-contig-1 <mpile_1_file> \
--mpileup-contig-2 <mpile_2_file> \
--null-model <null_model_file> \
--stb-file <stb_file> \
--min-cov <min_cov> \
--min-gene-compare-len <min-gene-compare-len> \
--output-file <output-file> \
--engine <engine> \
--ani-method <ani-method>
"""
|
|
1689
|
+
|
|
1690
|
+
class GeneCompareTaskGenerator(TaskGenerator):
    """This TaskGenerator generates FastGeneCompareTask objects from a polars LazyFrame. Each task compares two profiles using compare_genes functionality in
    zipstrain.compare module.

    Each row of ``data`` must provide ``profile_location_1``, ``profile_location_2``,
    ``sample_name_1`` and ``sample_name_2``. The task id is
    ``<sample_name_1>_<sample_name_2>``, and the task writes
    ``<sample_name_1>_<sample_name_2>_gene_comparison.parquet``.

    Args:
        data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
        yield_size (int): Number of tasks to yield at a time.
        container_engine (Engine): Container engine used to wrap each task's command.
        comp_config (database.GeneComparisonConfig): Configuration for gene comparison (null model,
            STB file, minimum coverage, minimum gene compare length).
        polars_engine (str): Engine name passed to the comparison command's ``--engine``. Default is "streaming".
        ani_method (str): ANI calculation method to use. Default is "popani".

    Raises:
        ValueError: If ``data`` is not a polars LazyFrame.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        comp_config: database.GeneComparisonConfig,
        polars_engine: str = "streaming",
        ani_method: str = "popani",
    ) -> None:
        # Validate before handing `data` to the base class, which may already use it.
        if not isinstance(data, pl.LazyFrame):
            raise ValueError("data must be a polars LazyFrame.")
        super().__init__(data, yield_size)
        self.comp_config = comp_config
        self.engine = container_engine
        self.polars_engine = polars_engine
        self.ani_method = ani_method

    def get_total_tasks(self) -> int:
        """Returns total number of pairwise comparisons to be made."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Asynchronously yield lists of up to ``yield_size`` FastGeneCompareTask objects.

        This is an async generator, despite the ``list[Task]`` annotation. Each slice
        of ``data`` is collected with ``collect_async``, which yields control back to
        the event loop while polars collects the data.
        """
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                inputs = {
                    "mpile_1_file": FileInput(row["profile_location_1"]),
                    "mpile_2_file": FileInput(row["profile_location_2"]),
                    "null_model_file": FileInput(self.comp_config.null_model_loc),
                    "stb_file": FileInput(self.comp_config.stb_file_loc),
                    "min_cov": IntInput(self.comp_config.min_cov),
                    "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
                    "engine": StringInput(self.polars_engine),
                    "ani-method": StringInput(self.ani_method),
                }
                pair_id = row["sample_name_1"] + "_" + row["sample_name_2"]
                expected_outputs = {
                    "output-file": FileOutput(pair_id + "_gene_comparison.parquet"),
                }
                tasks.append(
                    FastGeneCompareTask(id=pair_id, inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                )
            yield tasks
|
|
1746
|
+
|
|
1747
|
+
class GeneCompareRunner(Runner):
    """
    Creates and schedules batches of FastGeneCompareTask tasks using either local or Slurm batches.

    Each batch also contains a CollectGeneComps task that merges its gene comparison
    outputs into ``Merged_gene_batch_<n>.parquet``. A final batch then merges all of
    those into ``Outputs/all_gene_comparisons.parquet``.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.

    Raises:
        ValueError: If ``batch_type`` is "slurm" and ``slurm_config`` is None.
    """

    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = FastGeneCompareSlurmBatch
            final_batch_factory = PrepareGeneCompareRunOutputsSlurmBatch
        else:
            batch_factory = FastGeneCompareLocalBatch
            final_batch_factory = PrepareGeneCompareRunOutputsLocalBatch
        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=final_batch_factory,
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue. Each batch includes a CollectGeneComps task to merge the outputs of the tasks in the batch.

        After the ``None`` sentinel arrives on tasks_queue, any leftover tasks form a final
        batch. ``None`` is then forwarded to batches_queue and ``_batcher_done`` is set.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                if buffer:
                    await self.batches_queue.put(self._build_gene_batch(buffer))
                # Forward the end-of-stream sentinel, as CompareRunner._batcher does,
                # so a consumer blocked on batches_queue is released.
                await self.batches_queue.put(None)
                self._batcher_done = True
                break

            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                await self.batches_queue.put(self._build_gene_batch(buffer))
                buffer = []

    def _build_gene_batch(self, tasks: list[Task]) -> Batch:
        """Build one batch from *tasks* plus a CollectGeneComps merge task.

        Consumes one value of ``self._batch_counter``.
        """
        batch_number = self._batch_counter
        collect_task = CollectGeneComps(
            id="collect_gene_comps",
            inputs={},
            expected_outputs={"output-file": FileOutput(f"Merged_gene_batch_{batch_number}.parquet")},
            engine=self.container_engine,
        )
        # Only pass slurm_config to Slurm batches. This matches _create_final_batch
        # and CompareRunner, which never hand slurm_config to local batch classes.
        extra_kwargs = {"slurm_config": self.slurm_config} if self.batch_type == "slurm" else {}
        batch = self.batch_factory(
            tasks=[*tasks, collect_task],
            id=f"gene_batch_{batch_number}",
            run_dir=self.run_dir,
            expected_outputs=[],
            **extra_kwargs,
        )
        self._batch_counter += 1
        return batch

    def _create_final_batch(self) -> Batch:
        """Creates the final batch that prepares the overall outputs after all gene comparison batches are done."""
        final_task = PrepareGeneCompareRunOutputs(
            id="prepare_gene_outputs",
            inputs={"output-dir": StringInput("Outputs")},
            expected_outputs={},
            engine=self.container_engine,
        )
        factory_kwargs = dict(
            tasks=[final_task],
            id="Outputs",
            run_dir=self.run_dir,
            expected_outputs=[BatchFileOutput("all_gene_comparisons.parquet")],
        )
        if self.batch_type == "slurm":
            factory_kwargs["slurm_config"] = self.slurm_config
        final_batch = self.final_batch_factory(**factory_kwargs)
        # PrepareGeneCompareRunOutputs.pre_run reads `_batch_obj._runner_obj.run_dir`.
        final_batch._runner_obj = self
        return final_batch
|
|
1872
|
+
|
|
1873
|
+
class CollectGeneComps(Task):
    """Merge the ``*_gene_comparison.parquet`` files of sibling task directories into one parquet.

    The command copies every ``*/*_gene_comparison.parquet`` into a scratch
    ``gene_comps`` directory. It then merges them with ``zipstrain utilities
    merge_parquet`` into ``<output-file>`` and removes the scratch directory.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""

    TEMPLATE_CMD = """
mkdir -p gene_comps
cp */*_gene_comparison.parquet gene_comps/
zipstrain utilities merge_parquet --input-dir gene_comps --output-file <output-file>
rm -rf gene_comps
"""

    @property
    def pre_run(self) -> str:
        """Shell snippet that only records RUNNING in this task's ``.status`` file."""
        status_dir = self.task_dir.absolute()
        return f"echo {Status.RUNNING.value} > {status_dir}/.status"
|
|
1891
|
+
|
|
1892
|
+
class PrepareGeneCompareRunOutputs(Task):
    """Merge every ``Merged_gene_batch_*.parquet`` under the run directory into ``<output-dir>/all_gene_comparisons.parquet``.

    The merged batch files are symlinked into a scratch ``<output-dir>/gene_comps``
    directory. They are merged with ``zipstrain utilities merge_parquet`` and the
    scratch directory is then removed.
    """

    TEMPLATE_CMD = """
mkdir -p <output-dir>/gene_comps
find "$(pwd)" -type f -name "Merged_gene_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/gene_comps/
zipstrain utilities merge_parquet --input-dir <output-dir>/gene_comps --output-file <output-dir>/all_gene_comparisons.parquet
rm -rf <output-dir>/gene_comps
"""

    @property
    def pre_run(self) -> str:
        """Record RUNNING in ``.status``, then ``cd`` into the runner's run directory.

        The ``cd`` is needed because the merge must see the outputs of all gene
        batches, which live under the run directory.
        """
        own_dir = self.task_dir.absolute()
        run_root = self._batch_obj._runner_obj.run_dir.absolute()
        return f"echo {Status.RUNNING.value} > {own_dir}/.status && cd {run_root}"
|
|
1905
|
+
|
|
1906
|
+
class FastGeneCompareLocalBatch(LocalBatch):
    """A LocalBatch that runs FastGeneCompareTask tasks locally."""

    def cleanup(self) -> None:
        """Drop every FastGeneCompareTask from this batch and delete its task directory.

        Other tasks (e.g. the CollectGeneComps merge task) are kept. A task directory
        that no longer exists is skipped instead of raising FileNotFoundError.
        """
        removed = [task for task in self.tasks if isinstance(task, FastGeneCompareTask)]
        # Filter in place (keeps the same list object) instead of O(n) list.remove per task.
        self.tasks[:] = [task for task in self.tasks if not isinstance(task, FastGeneCompareTask)]
        for task in removed:
            try:
                shutil.rmtree(task.task_dir)
            except FileNotFoundError:
                # Already gone (e.g. removed by an earlier cleanup); nothing left to delete.
                pass
|
|
1913
|
+
|
|
1914
|
+
class FastGeneCompareSlurmBatch(SlurmBatch):
    """A SlurmBatch that runs FastGeneCompareTask tasks on a Slurm cluster."""

    def cleanup(self) -> None:
        """Drop every FastGeneCompareTask from this batch and delete its task directory.

        Other tasks (e.g. the CollectGeneComps merge task) are kept. A task directory
        that no longer exists is skipped instead of raising FileNotFoundError.
        """
        removed = [task for task in self.tasks if isinstance(task, FastGeneCompareTask)]
        # Filter in place (keeps the same list object) instead of O(n) list.remove per task.
        self.tasks[:] = [task for task in self.tasks if not isinstance(task, FastGeneCompareTask)]
        for task in removed:
            try:
                shutil.rmtree(task.task_dir)
            except FileNotFoundError:
                # Already gone (e.g. removed by an earlier cleanup); nothing left to delete.
                pass
|
|
1921
|
+
|
|
1922
|
+
class PrepareGeneCompareRunOutputsLocalBatch(LocalBatch):
    """LocalBatch used for the final gene-comparison output-preparation step; adds no behavior."""
|
|
1924
|
+
|
|
1925
|
+
class PrepareGeneCompareRunOutputsSlurmBatch(SlurmBatch):
    """SlurmBatch used for the final gene-comparison output-preparation step; adds no behavior."""
|
|
1927
|
+
|
|
1928
|
+
def lazy_run_gene_compares(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    comps_db: database.GeneComparisonDatabase | None = None,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
    polars_engine: str = "streaming",
    ani_method: str = "popani"
) -> None:
    """A helper function to quickly set up and run a GeneCompareRunner with given parameters.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        comps_db (GeneComparisonDatabase | None): Database providing the comparison input table
            (``to_complete_input_table()``) and configuration (``config``). Required despite the
            ``None`` default.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration, needed when execution_mode is "slurm".
        polars_engine (str): Polars engine to use. Default is "streaming".
        ani_method (str): ANI calculation method to use. Default is "popani".

    Raises:
        ValueError: If ``comps_db`` is None or ``execution_mode`` is not "local" or "slurm".
    """
    if comps_db is None:
        # Previously this surfaced as an AttributeError on `None.to_complete_input_table`.
        raise ValueError("comps_db must be provided to run gene comparisons.")
    if execution_mode not in ("local", "slurm"):
        raise ValueError(f"Unknown execution mode: {execution_mode}")

    task_generator = GeneCompareTaskGenerator(
        data=comps_db.to_complete_input_table(),
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        comp_config=comps_db.config,
        polars_engine=polars_engine,
        ani_method=ani_method,
    )
    runner = GeneCompareRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=execution_mode,
        slurm_config=slurm_config,
    )
    asyncio.run(runner.run())
|
|
1978
|
+
|