zipstrain 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1978 @@
+ """zipstrain.task_manager
+ ========================
+
+ Lightweight, asyncio-driven orchestration primitives for building and running
+ scientific data-processing pipelines. This module provides a small, composable
+ framework for defining Tasks with explicit inputs/outputs, bundling Tasks into
+ Batches (local or Slurm), and coordinating their execution with a live terminal
+ UI. It is designed to be easy to extend with new Task types and execution
+ environments. Most users will not use this module directly; however, it can be
+ used to define new pipelines that chain together multiple steps with clear
+ inputs/outputs. The unit of execution is a batch: a collection of tasks that
+ are executed together. Each batch can have an optional finalization step that
+ runs after all of its tasks are complete.
+
+
+ Key concepts
+ ------------
+
+ - Inputs and Outputs:
+     These classes encapsulate task inputs and outputs with validation logic.
+     Input and Output classes for files, strings, and integers are provided.
+     If needed, new types can be defined by subclassing Input or Output.
+
+ - Engines:
+     Any task object can use a container engine (Docker or Apptainer) or run
+     natively (LocalEngine).
+
+ - Tasks:
+     Each task runs a unit of bash script with defined inputs and expected
+     outputs. If a container engine is provided, the command is wrapped
+     accordingly to run inside the container.
+
+ - Batches:
+     A batch is a collection of tasks to be executed together. Batches can be
+     run locally or submitted to Slurm. Each batch monitors the status of its
+     tasks and updates its own status accordingly. A batch can also have
+     expected outputs that are checked after all tasks are complete, as well
+     as a finalization step that runs once all tasks have finished.
+
+ - Runner:
+     The Runner class orchestrates task generation, batching, and execution.
+     It manages concurrent batch execution, monitors progress, and provides a
+     live terminal UI using the rich library. A minimal usage sketch is shown
+     below.
+
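+ Example (a minimal sketch of the wiring; ``EchoTask`` and all argument
+ values here are hypothetical and are not part of the package API)::
+
+     class EchoTask(Task):
+         TEMPLATE_CMD = "echo <message> > <out-file>"
+
+     task = EchoTask(
+         id="echo_1",
+         inputs={"message": StringInput("hello")},
+         expected_outputs={"out-file": FileOutput("hello.txt")},
+         engine=LocalEngine(address=""),
+     )
+     batch = LocalBatch([task], id="batch_0", run_dir="runs", expected_outputs=[])
+
+ In practice, batches are created and executed by a Runner rather than by hand.
+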
+ """
+
+
+ from __future__ import annotations
+
+ import asyncio
+ import pathlib
+ import re
+ import shutil
+ import signal
+ import subprocess
+ from abc import ABC, abstractmethod
+ from collections.abc import AsyncIterator
+ from enum import StrEnum
+
+ import aiofiles
+ import polars as pl
+ import psutil
+ from pydantic import BaseModel, Field, field_validator
+ from rich.align import Align
+ from rich.columns import Columns
+ from rich.console import Console, Group
+ from rich.live import Live
+ from rich.panel import Panel
+ from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn
+
+ from zipstrain import database
+
+
+ class SlurmConfig(BaseModel):
+     """Configuration model for Slurm batch jobs.
+
+     Attributes:
+         time (str): Time limit for the job in HH:MM:SS format.
+         tasks (int): Number of tasks.
+         mem (int): Memory in GB.
+         additional_params (dict): Additional SLURM parameters as key-value pairs.
+
+     NOTE: Additional parameters for Slurm should be provided in the
+     additional_params dict in the form of {"param-name": "param-value"};
+     e.g., {"cpus-per-task": "4"} results in "#SBATCH --cpus-per-task=4"
+     being added to the sbatch script.
+     """
+
+     time: str = Field(description="Time limit for the job.")
+     tasks: int = Field(default=1, description="Number of tasks.")
+     mem: int = Field(default=4, description="Memory in GB.")
+     additional_params: dict[str, str] = Field(default_factory=dict, description="Additional SLURM parameters as key-value pairs.")
+
+     @field_validator("time")
+     def validate_time(cls, v):
+         """Validate time format HH:MM:SS (1-3 hour digits allowed)."""
+         if not re.match(r"^\d{1,3}:\d{2}:\d{2}$", v):
+             raise ValueError("Time must be in the format HH:MM:SS (1-3 hour digits allowed)")
+         return v
+
+     def to_slurm_args(self) -> str:
+         """Generates the sbatch script header from the configuration object."""
+         args = [
+             f"#SBATCH --time={self.time}",
+             f"#SBATCH --ntasks={self.tasks}",
+             f"#SBATCH --mem={self.mem}G",
+         ]
+         for key, value in self.additional_params.items():
+             args.append(f"#SBATCH --{key}={value}")
+         return "\n".join(args)
+
+     @classmethod
+     def from_json(cls, json_path: str | pathlib.Path) -> SlurmConfig:
+         """Load SlurmConfig from a JSON file."""
+         path = pathlib.Path(json_path)
+         if not path.exists():
+             raise FileNotFoundError(f"Slurm config file {json_path} does not exist.")
+         return cls.model_validate_json(path.read_text())
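+
+
+ # Example SlurmConfig usage (illustrative values, not shipped defaults):
+ #
+ #     config = SlurmConfig(time="02:00:00", mem=8,
+ #                          additional_params={"cpus-per-task": "4"})
+ #     print(config.to_slurm_args())
+ #     # #SBATCH --time=02:00:00
+ #     # #SBATCH --ntasks=1
+ #     # #SBATCH --mem=8G
+ #     # #SBATCH --cpus-per-task=4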
+
+
+ async def write_file(path: pathlib.Path, text: str, file_semaphore: asyncio.Semaphore) -> None:
+     async with file_semaphore:
+         async with aiofiles.open(path, "w") as f:
+             await f.write(text)
+
+
+ async def read_file(path: pathlib.Path, file_semaphore: asyncio.Semaphore) -> str:
+     async with file_semaphore:
+         async with aiofiles.open(path, "r") as f:
+             content = await f.read()
+     return content
+
+
+ class Status(StrEnum):
+     """Enumeration of possible task and batch statuses."""
+     BATCH_NOT_ASSIGNED = "batch_not_assigned"
+     NOT_STARTED = "not_started"
+     RUNNING = "running"
+     DONE = "done"
+     FAILED = "failed"
+     SUBMITTED = "submitted"
+     SUCCESS = "success"
+     PENDING = "pending"
+
+
+ class Input(ABC):
+     """Abstract base class for task inputs. DO NOT INSTANTIATE DIRECTLY.
+     The most commonly used Input types are provided; if you want to define a
+     new one, subclass this and implement validate() and get_value().
+     """
+     def __init__(self, value: str | int) -> None:
+         self.value = value
+         self.validate()
+
+     @abstractmethod
+     def validate(self) -> None:
+         ...
+
+     @abstractmethod
+     def get_value(self) -> str | int:
+         ...
+
+
+ class FileInput(Input):
+     """This is used when the input is a file path. By default, the validate method checks for file existence."""
+     def validate(self, check_exists: bool = True) -> None:
+         if check_exists and not pathlib.Path(self.value).exists():
+             raise FileNotFoundError(f"Input file {self.value} does not exist.")
+
+     def get_value(self) -> str:
+         """Returns the absolute path of the input file as a string."""
+         return str(pathlib.Path(self.value).absolute())
+
+
+ class StringInput(Input):
+     """This is used when the input is a string."""
+
+     def validate(self) -> None:
+         """Validate that the input value is a string."""
+         if not isinstance(self.value, str):
+             raise ValueError(f"Input value {self.value!r} is not a string.")
+
+     def get_value(self) -> str:
+         """Returns the string value."""
+         return str(self.value)
+
+
+ class IntInput(Input):
+     """This is used when the input is an integer."""
+     def validate(self) -> None:
+         """Validate that the input value is an integer."""
+         if not isinstance(self.value, int):
+             raise ValueError(f"Input value {self.value!r} is not an integer.")
+
+     def get_value(self) -> str:
+         """Returns the integer value as a string."""
+         return str(self.value)
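+
+ # Defining a new Input type is just a subclass. A minimal sketch (FloatInput
+ # is illustrative, not part of the package):
+ #
+ #     class FloatInput(Input):
+ #         def validate(self) -> None:
+ #             if not isinstance(self.value, float):
+ #                 raise ValueError(f"Input value {self.value!r} is not a float.")
+ #
+ #         def get_value(self) -> str:
+ #             return str(self.value)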
+
+
+ class Output(ABC):
+     """Abstract base class for task outputs. DO NOT INSTANTIATE DIRECTLY.
+     The most commonly used Output types are provided; if you want to define a
+     new one, subclass this and implement ready(), which is used to check
+     whether the output is ready/valid after task completion.
+     """
+     def __init__(self) -> None:
+         self._value = None  # Will be set by the task when it completes
+         self.task = None  # Will be set when the output is registered to a task
+
+     @property
+     def value(self):
+         return self._value
+
+     @abstractmethod
+     def ready(self) -> bool:
+         ...
+
+     def register_task(self, task: Task) -> None:
+         """Registers the task that produces this output. In most cases, you won't need to override this.
+
+         Args:
+             task (Task): The task that produces this output.
+         """
+         self.task = task
+
+
+ class FileOutput(Output):
+     """This is used when the output is a file path.
+
+     Args:
+         expected_file (str): The expected output file name relative to the task directory.
+     """
+     def __init__(self, expected_file: str) -> None:
+         super().__init__()
+         # When the task finishes, the expected file should be at
+         # task.task_dir / expected_file; otherwise ready() returns False.
+         self._expected_file_name = expected_file
+
+     def ready(self) -> bool:
+         """Check if the expected output file exists."""
+         return self.expected_file.absolute().exists()
+
+     def register_task(self, task: Task) -> None:
+         """Registers the task that produces this output and sets the expected file path.
+
+         Args:
+             task (Task): The task that produces this output.
+         """
+         super().register_task(task)
+         self.expected_file = self.task.task_dir / self._expected_file_name
+
+
+ class BatchFileOutput(Output):
+     """This is used when the output is a file path relative to the batch directory.
+     It is registered to the batch instead of a task.
+     """
+     def __init__(self, expected_file: str) -> None:
+         super().__init__()
+         self._expected_file_name = expected_file
+
+     def ready(self) -> bool:
+         """Check if the expected output file exists."""
+         return self.expected_file.absolute().exists()
+
+     def register_batch(self, batch: Batch) -> None:
+         """Registers the batch that produces this output and sets the expected file path.
+
+         Args:
+             batch (Batch): The batch that produces this output.
+         """
+         self.expected_file = batch.batch_dir / self._expected_file_name
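+
+ # For example, FileOutput("sample.parquet") registered to a task whose
+ # task_dir is /runs/batch_0/sample resolves to
+ # /runs/batch_0/sample/sample.parquet, whereas
+ # BatchFileOutput("merged.parquet") resolves against the batch directory,
+ # e.g. /runs/batch_0/merged.parquet (paths illustrative).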
+
+
+ class StringOutput(Output):
+     """This is used when the output is a string."""
+     def ready(self) -> bool:
+         """Check if the output value is a string."""
+         if isinstance(self._value, str):
+             return True
+         elif self._value is not None:
+             raise ValueError(f"Output value for task {self.task.id} is not a string.")
+         else:
+             return False
+
+
+ class IntOutput(Output):
+     """This is used when the output is an integer."""
+     def ready(self) -> bool:
+         """Check if the output value is an integer."""
+         if isinstance(self._value, int):
+             return True
+         elif self._value is not None:
+             raise ValueError(f"Output value for task {self.task.id} is not an integer.")
+         else:
+             return False
+
+
+ class Engine(ABC):
+     def __init__(self, address: str) -> None:
+         self.address = address
+
+     @abstractmethod
+     def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
+         ...
+
+
+ class DockerEngine(Engine):
+     def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
+         volume_mounts = " ".join(
+             [f"-v {file_input.get_value()}:{file_input.get_value()}" for file_input in file_inputs]
+         )
+         return f"docker run {volume_mounts} {self.address} {command}"
+
+
+ class ApptainerEngine(Engine):
+     def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
+         volume_mounts = "--bind " + ",".join(
+             [f"{file_input.get_value()}:{file_input.get_value()}" for file_input in file_inputs]
+         )
+         return f"apptainer run {volume_mounts} {self.address} {command}"
+
+
+ class LocalEngine(Engine):
+     def wrap(self, command: str, file_inputs: list[FileInput]) -> str:
+         return command
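+
+ # For example (illustrative image and path), wrapping "samtools view input.bam"
+ # with DockerEngine(address="quay.io/biocontainers/samtools") and one
+ # FileInput("/data/input.bam") produces:
+ #
+ #     docker run -v /data/input.bam:/data/input.bam quay.io/biocontainers/samtools samtools view input.bam
+ #
+ # LocalEngine returns the command unchanged.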
+
+
+ class Task(ABC):
+     """Abstract base class for tasks. DO NOT INSTANTIATE DIRECTLY. Any new task type should subclass this
+     and define the TEMPLATE_CMD class attribute. Inputs and expected outputs are referenced using <>
+     placeholders. For example, if a task has an input file called "input-file" and an expected output file
+     called "output-file", the TEMPLATE_CMD could be something like:
+
+         TEMPLATE_CMD = "some_command --input <input-file> --output <output-file>"
+
+     The inputs and outputs are mapped onto the command when map_io() is called later at runtime.
+     """
+     TEMPLATE_CMD = ""
+
+     def __init__(
+         self,
+         id: str,
+         inputs: dict[str, Input | Output],
+         expected_outputs: dict[str, Output],
+         engine: Engine,
+         batch_obj: Batch | None = None,
+         file_semaphore: asyncio.Semaphore | None = None
+     ) -> None:
+         self.id = id
+         self.inputs = inputs
+         self.expected_outputs = expected_outputs
+         self._batch_obj = batch_obj
+         self.engine = engine
+         self._status = self._get_initial_status()
+         self.file_semaphore = file_semaphore
+
+     def map_io(self) -> None:
+         """Maps inputs and expected outputs onto the command template. When this method is called, every
+         placeholder in TEMPLATE_CMD must be defined in the inputs or expected_outputs dictionaries.
+         This method is not called by the user directly; the Batch calls it when the task is added to a batch.
+         """
+         cmd = self.TEMPLATE_CMD
+         for key, value in self.inputs.items():
+             cmd = cmd.replace(f"<{key}>", value.get_value())
+         for handle, output in self.expected_outputs.items():
+             cmd = cmd.replace(f"<{handle}>", str(output.expected_file.absolute()))
+         # If any placeholders remain unmapped, report them. Placeholder names
+         # may contain hyphens, so the pattern must allow them.
+         remaining = re.findall(r"<[\w-]+>", cmd)
+         if remaining:
+             raise ValueError(f"Not all inputs were mapped in task {self.id}. Remaining placeholders: {remaining}")
+         self._command = cmd
+
+     @property
+     def batch_dir(self) -> pathlib.Path:
+         """Returns the batch directory path. Raises an error if the task is not associated with any batch yet."""
+         if self._batch_obj is None:
+             raise ValueError(f"Task {self.id} is not associated with any batch yet.")
+         return self._batch_obj.batch_dir
+
+     @property
+     def task_dir(self) -> pathlib.Path:
+         """Returns the task directory path."""
+         return self.batch_dir / self.id
+
+     @property
+     def command(self) -> str:
+         """Returns the command to be executed, wrapped with the engine if applicable."""
+         file_inputs = [v for v in self.inputs.values() if isinstance(v, FileInput)]
+         return self.engine.wrap(self._command, file_inputs)
+
+     @property
+     def pre_run(self) -> str:
+         """Performs the necessary setup before running the task command. This should not be overridden by
+         subclasses unless a task needs special setup, such as batch aggregation."""
+         return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self.task_dir.absolute()}"
+
+     @property
+     def status(self) -> str:
+         """Returns the current status of the task."""
+         return self._status
+
+     @property
+     def post_run(self) -> str:
+         """Performs the necessary steps after running the task command. This should not be overridden by
+         subclasses unless a task needs special teardown, such as batch aggregation."""
+         return f"cd {self.batch_dir.absolute()} && echo {Status.DONE.value} > {self.task_dir.absolute()}/.status"
+
+     async def get_status(self) -> str:
+         """Asynchronously reads the task status from the .status file in the task directory."""
+         status_path = self.task_dir / ".status"
+         # Read the status file if it exists.
+         if status_path.exists():
+             raw = await read_file(status_path, self.file_semaphore)
+             self._status = raw.strip()
+
+         # If the task reported 'done', check the outputs to decide success/failure.
+         if self._status == Status.DONE.value:
+             all_ready = True
+             try:
+                 for output in self.expected_outputs.values():
+                     if not output.ready():
+                         all_ready = False
+                         break
+             except Exception:
+                 all_ready = False
+
+             if all_ready:
+                 self._status = Status.SUCCESS.value
+                 await write_file(status_path, Status.SUCCESS.value, self.file_semaphore)
+             else:
+                 self._status = Status.FAILED.value
+                 await write_file(status_path, Status.FAILED.value, self.file_semaphore)
+                 raise ValueError(f"Task {self.id} reported done but its outputs are not ready or are invalid.")
+
+         return self._status
+
+     def _get_initial_status(self) -> str:
+         """Returns the initial status of the task based on the presence of the batch and task directories."""
+         if self._batch_obj is None:
+             return Status.BATCH_NOT_ASSIGNED.value
+         if not self.task_dir.exists():
+             return Status.NOT_STARTED.value
+         status_file = self.task_dir / ".status"
+         if not status_file.exists():
+             return Status.NOT_STARTED.value
+         with open(status_file, mode="r") as f:
+             status_as_written = f.read().strip()
+         if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
+             all_ready = True
+             try:
+                 for output in self.expected_outputs.values():
+                     if not output.ready():
+                         all_ready = False
+                         break
+             except Exception:
+                 all_ready = False
+
+             if all_ready:
+                 return Status.SUCCESS.value
+             else:
+                 return Status.FAILED.value
+         return status_as_written
+
+
+ class TaskGenerator(ABC):
+     """Abstract base class for task generators. DO NOT INSTANTIATE DIRECTLY. A subclass of this class
+     should provide an async generator method called generate_tasks() that yields lists of Task objects.
+     Some important concepts:
+
+     - generate_tasks() is an async generator that yields lists of Task objects.
+
+     - yield_size determines how many tasks are generated and yielded at a time.
+
+     - get_total_tasks() returns the total number of tasks that can be generated.
+     """
+     def __init__(self,
+                  data,
+                  yield_size: int,
+                  ):
+         self.data = data
+         self.yield_size = yield_size
+         self._total_tasks = self.get_total_tasks()
+
+     @abstractmethod
+     async def generate_tasks(self) -> AsyncIterator[list[Task]]:
+         pass
+
+     @abstractmethod
+     def get_total_tasks(self) -> int:
+         pass
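+
+ # A minimal custom generator sketch (ListTaskGenerator is illustrative, not
+ # part of the package; it assumes `data` is a plain list of Task objects):
+ #
+ #     class ListTaskGenerator(TaskGenerator):
+ #         def get_total_tasks(self) -> int:
+ #             return len(self.data)
+ #
+ #         async def generate_tasks(self):
+ #             for i in range(0, len(self.data), self.yield_size):
+ #                 yield self.data[i:i + self.yield_size]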
+
+
+ class ProfileTaskGenerator(TaskGenerator):
+     """This TaskGenerator generates ProfileBamTask objects from a polars LazyFrame. Each task profiles a BAM file."""
+     def __init__(
+         self,
+         data: pl.LazyFrame,
+         yield_size: int,
+         container_engine: Engine,
+         stb_file: str,
+         profile_bed_file: str,
+         gene_range_file: str,
+         genome_length_file: str,
+         num_procs: int = 4,
+         breadth_min_cov: int = 1,
+     ) -> None:
+         super().__init__(data, yield_size)
+         self.stb_file = pathlib.Path(stb_file)
+         self.profile_bed_file = pathlib.Path(profile_bed_file)
+         self.gene_range_file = pathlib.Path(gene_range_file)
+         self.genome_length_file = pathlib.Path(genome_length_file)
+         self.num_procs = num_procs
+         self.breadth_min_cov = breadth_min_cov
+         self.engine = container_engine
+         if not isinstance(self.data, pl.LazyFrame):
+             raise ValueError("data must be a polars LazyFrame.")
+         for path_attr in [
+             self.stb_file,
+             self.profile_bed_file,
+             self.gene_range_file,
+             self.genome_length_file,
+         ]:
+             if not path_attr.exists():
+                 raise FileNotFoundError(f"File {path_attr} does not exist.")
+
+     def get_total_tasks(self) -> int:
+         """Returns the total number of profiles to be generated."""
+         return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]
+
+     async def generate_tasks(self) -> AsyncIterator[list[Task]]:
+         """Yields lists of ProfileBamTask objects based on the data, in batches of yield_size. This method
+         yields control back to the event loop while polars is collecting data, to avoid blocking.
+         """
+         for offset in range(0, self._total_tasks, self.yield_size):
+             batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
+             tasks = []
+             for row in batch_df.iter_rows(named=True):
+                 inputs = {
+                     "bam-file": FileInput(row["bamfile"]),
+                     "sample-name": StringInput(row["sample_name"]),
+                     "stb-file": FileInput(self.stb_file),
+                     "bed-file": FileInput(self.profile_bed_file),
+                     "gene-range-table": FileInput(self.gene_range_file),
+                     "genome-length-file": FileInput(self.genome_length_file),
+                     "num-threads": IntInput(self.num_procs),
+                     "breadth-min-cov": IntInput(self.breadth_min_cov),
+                 }
+                 expected_outputs = {
+                     "profile": FileOutput(row["sample_name"] + ".parquet"),
+                     "breadth": FileOutput(row["sample_name"] + "_breadth.parquet"),
+                     "scaffold": FileOutput(row["sample_name"] + ".parquet.scaffolds"),
+                 }
+                 task = ProfileBamTask(id=row["sample_name"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
+                 tasks.append(task)
+             yield tasks
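+
+ # The input LazyFrame is expected to provide one row per BAM file with at
+ # least the columns used above (values illustrative):
+ #
+ #     data = pl.LazyFrame({
+ #         "sample_name": ["s1", "s2"],
+ #         "bamfile": ["/data/s1.bam", "/data/s2.bam"],
+ #     })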
+
+
+ class CompareTaskGenerator(TaskGenerator):
+     """This TaskGenerator generates FastCompareTask objects from a polars LazyFrame. Each task compares two
+     profiles using the compare_genomes functionality in the zipstrain.compare module.
+
+     Args:
+         data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
+         yield_size (int): Number of tasks to yield at a time.
+         container_engine (Engine): Container engine used to wrap task commands.
+         comp_config (database.GenomeComparisonConfig): Configuration for genome comparison.
+         memory_mode (str): Memory mode for the comparison task. Default is "heavy".
+         polars_engine (str): Polars engine to use. Default is "streaming".
+         chrom_batch_size (int): Chromosome batch size for the comparison task in light memory mode. Default is 10000.
+     """
+     def __init__(
+         self,
+         data: pl.LazyFrame,
+         yield_size: int,
+         container_engine: Engine,
+         comp_config: database.GenomeComparisonConfig,
+         memory_mode: str = "heavy",
+         polars_engine: str = "streaming",
+         chrom_batch_size: int = 10000,
+     ) -> None:
+         super().__init__(data, yield_size)
+         self.comp_config = comp_config
+         self.engine = container_engine
+         self.memory_mode = memory_mode
+         self.polars_engine = polars_engine
+         self.chrom_batch_size = chrom_batch_size
+         if not isinstance(self.data, pl.LazyFrame):
+             raise ValueError("data must be a polars LazyFrame.")
+
+     def get_total_tasks(self) -> int:
+         """Returns the total number of pairwise comparisons to be made."""
+         return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]
+
+     async def generate_tasks(self) -> AsyncIterator[list[Task]]:
+         """Yields lists of FastCompareTask objects based on the data, in batches of yield_size. This method
+         yields control back to the event loop while polars is collecting data, to avoid blocking.
+         """
+         for offset in range(0, self._total_tasks, self.yield_size):
+             batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
+             tasks = []
+             for row in batch_df.iter_rows(named=True):
+                 inputs = {
+                     "mpile_1_file": FileInput(row["profile_location_1"]),
+                     "mpile_2_file": FileInput(row["profile_location_2"]),
+                     "scaffold_1_file": FileInput(row["scaffold_location_1"]),
+                     "scaffold_2_file": FileInput(row["scaffold_location_2"]),
+                     "null_model_file": FileInput(self.comp_config.null_model_loc),
+                     "stb_file": FileInput(self.comp_config.stb_file_loc),
+                     "min_cov": IntInput(self.comp_config.min_cov),
+                     "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
+                     "memory-mode": StringInput(self.memory_mode),
+                     "chrom-batch-size": IntInput(self.chrom_batch_size),
+                     "genome-name": StringInput(self.comp_config.scope),
+                     "engine": StringInput(self.polars_engine),
+                 }
+                 expected_outputs = {
+                     "output-file": FileOutput(row["sample_name_1"] + "_" + row["sample_name_2"] + "_comparison.parquet"),
+                 }
+                 task = FastCompareTask(id=row["sample_name_1"] + "_" + row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
+                 tasks.append(task)
+             yield tasks
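+
+ # Each row describes one pairwise comparison. The LazyFrame must provide the
+ # columns used above (values illustrative):
+ #
+ #     data = pl.LazyFrame({
+ #         "sample_name_1": ["s1"],
+ #         "sample_name_2": ["s2"],
+ #         "profile_location_1": ["/data/s1.parquet"],
+ #         "profile_location_2": ["/data/s2.parquet"],
+ #         "scaffold_location_1": ["/data/s1.parquet.scaffolds"],
+ #         "scaffold_location_2": ["/data/s2.parquet.scaffolds"],
+ #     })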
+
+
+ class Batch(ABC):
+     """A Batch is a collection of tasks to be executed as a group. This is a base class and should not be
+     instantiated directly. A batch is the unit of execution, meaning that the entire batch is either run
+     locally or submitted to a job scheduler like Slurm.
+
+     Args:
+         tasks (list[Task]): List of Task objects to be included in the batch.
+         id (str): Unique identifier for the batch.
+         run_dir (pathlib.Path): Directory where the batch will be executed.
+         expected_outputs (list[Output]): List of expected outputs for the batch.
+     """
+     TEMPLATE_CMD = ""
+
+     def __init__(self, tasks: list[Task],
+                  id: str,
+                  run_dir: pathlib.Path,
+                  expected_outputs: list[Output],
+                  file_semaphore: asyncio.Semaphore | None = None
+                  ) -> None:
+         self.id = id
+         self.tasks = tasks
+         self.run_dir = pathlib.Path(run_dir)
+         self.batch_dir = self.run_dir / self.id
+         self.retry_count = 0
+         self.expected_outputs = expected_outputs
+         self.file_semaphore = file_semaphore
+         for output in self.expected_outputs:
+             if isinstance(output, BatchFileOutput):
+                 output.register_batch(self)
+         self._status = self._get_initial_status()
+         for task in self.tasks:
+             task._batch_obj = self
+             task.file_semaphore = self.file_semaphore
+             for output in task.expected_outputs.values():
+                 output.register_task(task)
+             task._status = task._get_initial_status()
+             task.map_io()
+
+         self._runner_obj: Runner | None = None
+
+     def _get_initial_status(self) -> str:
+         """Returns the initial status of the batch based on the presence of the batch directory."""
+         if not self.batch_dir.exists():
+             return Status.NOT_STARTED.value
+         status_file = self.batch_dir / ".status"
+         if not status_file.exists():
+             return Status.NOT_STARTED.value
+         with open(status_file, mode="r") as f:
+             status_as_written = f.read().strip()
+         if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
+             all_ready = True
+             try:
+                 for output in self.expected_outputs:
+                     if not output.ready():
+                         all_ready = False
+                         break
+             except Exception:
+                 all_ready = False
+
+             if all_ready:
+                 return Status.SUCCESS.value
+             else:
+                 return Status.FAILED.value
+         return status_as_written
+
+     def cleanup(self) -> None:
+         """Defines any cleanup needed after batch success. By default, it does nothing."""
+         return None
+
+     @abstractmethod
+     async def cancel(self) -> None:
+         """Cancels the batch. This method should be implemented by subclasses."""
+         ...
+
+     def outputs_ready(self) -> bool:
+         """Check if all BATCH-LEVEL expected outputs are ready."""
+         try:
+             for output in self.expected_outputs:
+                 if not output.ready():
+                     return False
+             return True
+         except Exception:
+             return False
+
+     async def _collect_task_status(self) -> list[str]:
+         """Collects the status of all tasks asynchronously."""
+         return await asyncio.gather(*[task.get_status() for task in self.tasks])
+
+     @abstractmethod
+     async def run(self) -> None:
+         """Runs the batch. This method should be implemented by subclasses."""
+         ...
+
+     @abstractmethod
+     def _parse_job_id(self, sbatch_output: str) -> str:
+         """Parses the job ID from the sbatch output. This method should be implemented by subclasses."""
+         ...
+
+     @property
+     def status(self) -> str:
+         """Returns the current status of the batch."""
+         return self._status
+
+     @property
+     def stats(self) -> dict[str, str]:
+         """Returns a dictionary of task IDs and their statuses."""
+         return {task.id: task.status for task in self.tasks}
+
+     async def update_status(self) -> None:
+         """Updates the status of the batch by collecting the status of all tasks."""
+         await self._collect_task_status()
+
+     def _set_file_semaphore(self, file_semaphore: asyncio.Semaphore) -> None:
+         self.file_semaphore = file_semaphore
+         for task in self.tasks:
+             task.file_semaphore = file_semaphore
+
+
+ class LocalBatch(Batch):
+     """Batch that runs tasks locally in a single shell script."""
+     TEMPLATE_CMD = "#!/bin/bash\n"
+
+     def __init__(self, tasks, id, run_dir, expected_outputs) -> None:
+         super().__init__(tasks, id, run_dir, expected_outputs)
+         self._script = self.TEMPLATE_CMD + "\nset -o pipefail\n"
+         self._proc: asyncio.subprocess.Process | None = None
+
+     async def run(self) -> None:
+         """Runs all tasks in the batch locally by creating a shell script and executing it."""
+         if self.status != Status.SUCCESS.value and self.status != Status.FAILED.value:
+             self.batch_dir.mkdir(parents=True, exist_ok=True)
+             self._status = Status.RUNNING.value
+
+             await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+
+             for task in self.tasks:
+                 if task.status == Status.NOT_STARTED.value:
+                     task.task_dir.mkdir(parents=True, exist_ok=True)  # Create task directory
+                     await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)
+
+             script_path = self.batch_dir / f"{self.id}.sh"  # Path to the shell script for the batch
+
+             script = self._script
+             for task in self.tasks:
+                 if task.status == Status.NOT_STARTED.value or task.status == Status.FAILED.value:
+                     script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"
+
+             await write_file(script_path, script, self.file_semaphore)
+
+             self._proc = await asyncio.create_subprocess_exec(
+                 "bash", f"{self.id}.sh",
+                 stdout=asyncio.subprocess.PIPE,
+                 stderr=asyncio.subprocess.PIPE,
+                 cwd=self.batch_dir,
+             )
+             try:
+                 out_bytes, err_bytes = await self._proc.communicate()
+             except asyncio.CancelledError:
+                 if self._proc and self._proc.returncode is None:
+                     self._proc.terminate()
+                 raise
+
+             await write_file(self.batch_dir / f"{self.id}.out", out_bytes.decode(), self.file_semaphore)
+             await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)
+
+             if self._proc.returncode == 0 and self.outputs_ready():
+                 self.cleanup()
+                 self._status = Status.SUCCESS.value
+                 await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+             else:
+                 self._status = Status.FAILED.value
+                 await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+
+         elif self.status == Status.SUCCESS.value and self.outputs_ready():
+             self._status = Status.SUCCESS.value
+         else:
+             self._status = Status.FAILED.value
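+
+     # The generated script has this shape (paths illustrative):
+     #
+     #     #!/bin/bash
+     #
+     #     set -o pipefail
+     #
+     #     echo running > /runs/batch_0/task_a/.status && cd /runs/batch_0/task_a
+     #     <task command>
+     #     cd /runs/batch_0 && echo done > /runs/batch_0/task_a/.status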
+
+     def _parse_job_id(self, sbatch_output):
+         return super()._parse_job_id(sbatch_output)
+
+     def cleanup(self) -> None:
+         super().cleanup()
+
+     async def cancel(self) -> None:
+         """Cancels the local batch by terminating the subprocess if it is running."""
+         if self._proc and self._proc.returncode is None:
+             self._proc.terminate()
+             try:
+                 await asyncio.wait_for(self._proc.wait(), timeout=5.0)
+             except asyncio.TimeoutError:
+                 self._proc.kill()
+                 await self._proc.wait()
+         self._status = Status.FAILED.value
+         await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+
+
+ class SlurmBatch(Batch):
+     """Batch that submits tasks to a Slurm job scheduler.
+
+     Args:
+         tasks (list[Task]): List of Task objects to be included in the batch.
+         id (str): Unique identifier for the batch.
+         run_dir (pathlib.Path): Directory where the batch will be executed.
+         expected_outputs (list[Output]): List of expected outputs for the batch.
+         slurm_config (SlurmConfig): Configuration for Slurm job submission. Refer to the SlurmConfig class for details.
+     """
+     TEMPLATE_CMD = "#!/bin/bash\n"
+
+     def __init__(self, tasks, id, run_dir, expected_outputs, slurm_config: SlurmConfig) -> None:
+         super().__init__(tasks, id, run_dir, expected_outputs)
+         self._check_slurm_works()
+         self.slurm_config = slurm_config
+         self._script = self.TEMPLATE_CMD + self.slurm_config.to_slurm_args() + "\nset -o pipefail\n"
+         self._job_id = None
+
+     def _check_slurm_works(self) -> None:
+         """Checks if Slurm commands are available on the system."""
+         try:
+             subprocess.run(["sbatch", "--version"], capture_output=True, text=True, check=True)
+             subprocess.run(["sacct", "--version"], capture_output=True, text=True, check=True)
+         except (OSError, subprocess.CalledProcessError) as exc:
+             raise EnvironmentError("Slurm does not seem to be available or configured properly on this system.") from exc
+
+     async def cancel(self) -> None:
+         """Cancel a running or submitted Slurm job."""
+         if self._job_id:
+             proc = await asyncio.create_subprocess_exec(
+                 "scancel", self._job_id,
+                 stdout=asyncio.subprocess.PIPE,
+                 stderr=asyncio.subprocess.PIPE,
+             )
+             await proc.wait()
+
+         self._status = Status.FAILED.value
+         await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+
+     async def run(self) -> None:
+         """Submits the batch to Slurm by creating a batch script and calling sbatch, then monitors the job
+         status until completion. This method is unavoidably different from LocalBatch.run() because of the
+         nature of Slurm job submission.
+         """
+         if self.status != Status.SUCCESS.value and self.status != Status.FAILED.value:
+             self.batch_dir.mkdir(parents=True, exist_ok=True)
+             self._status = Status.RUNNING.value
+             await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+
+             # Create task directories and initialize .status files if needed.
+             for task in self.tasks:
+                 if task.status == Status.NOT_STARTED.value:
+                     task.task_dir.mkdir(parents=True, exist_ok=True)
+                     await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)
+
+             # Write the batch script (all tasks included).
+             batch_path = self.batch_dir / f"{self.id}.batch"
+             script = self._script
+             for task in self.tasks:
+                 if task.status == Status.NOT_STARTED.value:
+                     script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"
+
+             await write_file(batch_path, script, self.file_semaphore)
+
+             proc = await asyncio.create_subprocess_exec(
+                 "sbatch", "--parsable", batch_path.name,
+                 stdout=asyncio.subprocess.PIPE,
+                 stderr=asyncio.subprocess.PIPE,
+                 cwd=str(self.batch_dir),
+             )
+             out_bytes, _ = await proc.communicate()
+             out = out_bytes.decode().strip() if out_bytes else ""
+             if proc.returncode == 0:
+                 try:
+                     self._job_id = self._parse_job_id(out)
+                     self._status = Status.SUBMITTED.value
+                     await self._wait_to_finish()
+                 except Exception:
+                     self._status = Status.FAILED.value
+             else:
+                 self._status = Status.FAILED.value
+
+             if self._status == Status.SUCCESS.value and self.outputs_ready():
+                 self.cleanup()
+                 self._status = Status.SUCCESS.value
+                 await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+             else:
+                 self._status = Status.FAILED.value
+                 await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
+
+         else:
+             if self.status == Status.SUCCESS.value and self.outputs_ready():
+                 self._status = Status.SUCCESS.value
+             else:
+                 self._status = Status.FAILED.value
+
+     def _parse_job_id(self, sbatch_output: str) -> str:
+         if match := re.search(r"(\d+)", sbatch_output):
+             return match.group(1)
+         else:
+             raise ValueError("Could not parse job ID from sbatch output.")
+
+     async def _wait_to_finish(self, sleep_duration: float = 1.0):
+         while self.status not in (Status.SUCCESS.value, Status.FAILED.value):
+             await self.update_status()
+             await asyncio.sleep(sleep_duration)
+
+     async def update_status(self):
+         if self._job_id is None:
+             self._status = Status.NOT_STARTED.value
+         else:
+             await self._collect_task_status()
+             proc = await asyncio.create_subprocess_exec(
+                 "sacct", "-j", self._job_id, "--format=State", "--noheader", "--allocations",
+                 stdout=asyncio.subprocess.PIPE,
+                 stderr=asyncio.subprocess.PIPE,
+             )
+             out_bytes, _ = await proc.communicate()
+             if out_bytes:
+                 state = out_bytes.decode().strip()
+                 if state in ["FAILED", "CANCELLED", "TIMEOUT"]:
+                     self._status = Status.FAILED.value
+                 elif state == "RUNNING":
+                     self._status = Status.RUNNING.value
+                 elif state in ["COMPLETED", "COMPLETING"]:
+                     self._status = Status.SUCCESS.value
+                 else:
+                     self._status = Status.PENDING.value
+             await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
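+
+     # Note: with --noheader and --allocations, sacct prints one state string
+     # per allocation (e.g. "RUNNING" or "COMPLETED"). Cancelled jobs may be
+     # reported as "CANCELLED by <uid>", which the exact-match check above
+     # classifies as pending; this is a known limitation of the simple parsing.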
+
+
+ console = Console()
+
+
+ def get_cpu_usage():
+     """Returns the current CPU usage percentage."""
+     return psutil.cpu_percent(interval=0.1)
+
+
+ def get_memory_usage():
+     """Returns the current memory usage percentage."""
+     return psutil.virtual_memory().percent
+
+
+ class Runner(ABC):
+     """Base Runner class to manage task generation, batching, and execution.
+
+     Args:
+         run_dir (str | pathlib.Path): Directory where the runner will operate.
+         task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
+         container_engine (Engine): An instance of Engine to wrap task commands.
+         batch_factory (Batch): The class used to create Batch instances. It should be a subclass of Batch.
+         final_batch_factory (Batch): A callable that creates the final Batch instance.
+         max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
+         poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
+         tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
+         batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
+         slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type is "slurm".
+             Default is None.
+         max_retries (int): Maximum number of times a failed batch is retried. Default is 3.
+     """
+     TERMINAL_BATCH_STATES = {Status.SUCCESS.value, Status.FAILED.value}
+
+     def __init__(self,
+                  run_dir: str | pathlib.Path,
+                  task_generator: TaskGenerator,
+                  container_engine: Engine,
+                  batch_factory: Batch,
+                  final_batch_factory: Batch,
+                  max_concurrent_batches: int = 1,
+                  poll_interval: float = 1.0,
+                  tasks_per_batch: int = 10,
+                  batch_type: str = "local",
+                  slurm_config: SlurmConfig | None = None,
+                  max_retries: int = 3,
+                  ) -> None:
+         self.run_dir = pathlib.Path(run_dir)
+         self.run_dir.mkdir(parents=True, exist_ok=True)
+         self.task_generator = task_generator
+         self.container_engine = container_engine
+         self.max_concurrent_batches = max_concurrent_batches
+         self.poll_interval = poll_interval
+         self.tasks_per_batch = tasks_per_batch
+         self.batch_type = batch_type
+         self.slurm_config = slurm_config
+         self.tasks_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches * tasks_per_batch)
+         self.batches_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches)
+         self._finished_batches_count = 0
+         self._success_batches_count = 0
+         self._produced_tasks_count = 0
+         self._active_batches: list[Batch] = []
+         self._batch_counter = 0
+         self._batcher_done = False
+         self._final_batch_created = False
+         self.batch_factory = batch_factory
+         self.final_batch_factory = final_batch_factory
+         self._failed_batches_count = 0
+         self.max_retries = max_retries
+         self.total_expected_tasks = self.task_generator.get_total_tasks()
+         self.total_expected_batches = (self.total_expected_tasks + tasks_per_batch - 1) // tasks_per_batch
+         self._shutdown_event = asyncio.Event()
+         self._shutdown_initiated = False
+
+     async def _refill_tasks(self):
+         """Drains the task generator into tasks_queue, then enqueues a None sentinel. Waits asynchronously
+         for the queue to have space whenever it is full."""
+         async for tasks in self.task_generator.generate_tasks():
+             for task in tasks:
+                 await self.tasks_queue.put(task)
+                 self._produced_tasks_count += 1
+         await self.tasks_queue.put(None)
+
+     @abstractmethod
+     async def _batcher(self):
+         ...
+
+     def _create_final_batch(self) -> Batch | None:
+         """Creates the final batch using the final_batch_factory callable. Runners without a finalization
+         step simply return None."""
+         return None
+
+     async def _shutdown(self):
+         """Cancel all active batches and signal shutdown."""
+         if self._shutdown_initiated:
+             return
+         self._shutdown_initiated = True
+         console.print("[yellow]Shutdown requested. Cancelling active jobs...[/]")
+
+         for batch in list(self._active_batches):
+             try:
+                 await batch.cancel()
+             except Exception as e:
+                 console.print(f"[red]Error cancelling batch {batch.id}: {e}[/]")
+
+         # Signal the main loop to stop.
+         self._shutdown_event.set()
+
+     async def run(self):
+         """Run the producer, batcher, and worker coroutines and present a live UI while working.
+         Runs the task generator to produce tasks, batches them using the batcher, and executes batches
+         with up to max_concurrent_batches parallel workers.
+
+         UI: displays an overall panel (produced/finished counts), per-batch Progress bars for active
+         batches, and system stats (CPU/RAM), rendered with rich's Live display.
+         """
+         # Keep references to the background tasks so they are not garbage collected.
+         self._batcher_task = asyncio.create_task(self._batcher())
+         self._refill_task = asyncio.create_task(self._refill_tasks())
+         semaphore = asyncio.Semaphore(self.max_concurrent_batches)
+         file_semaphore = asyncio.Semaphore(20)
+
+         async def run_batch(batch: Batch):
+             async with semaphore:
+                 while batch.status != Status.SUCCESS.value and batch.retry_count < self.max_retries:
+                     await batch.run()
+                     if batch.status == Status.SUCCESS.value:
+                         break
+                     else:
+                         batch.retry_count += 1
+
+                 self._finished_batches_count += 1
+
+                 if batch.status == Status.SUCCESS.value:
+                     self._success_batches_count += 1
+                 elif batch.status == Status.FAILED.value:
+                     self._failed_batches_count += 1
+
+                 if batch in self._active_batches:
+                     self._active_batches.remove(batch)
+
+         # Rich progress objects
+         overall_progress = Progress(
+             TextColumn(f"[bold white]{type(self).__name__}[/]"),
+             BarColumn(),
+             TextColumn("• {task.fields[produced_tasks]}/{task.fields[total_expected_tasks]} tasks produced"),
+             TextColumn("• {task.fields[finished_batches]}/{task.fields[total_expected_batches]} batches finished • {task.fields[failed_batches]} batches failed"),
+             TimeElapsedColumn(),
+             expand=True,
+         )
+         overall_task = overall_progress.add_task("overall", produced_tasks=0, total_expected_tasks=self.total_expected_tasks, finished_batches=0, total_expected_batches=self.total_expected_batches, failed_batches=0)
+
+         batch_progress = Progress(
+             TextColumn("[bold white]{task.fields[batch_id]}[/]"),
+             BarColumn(),
+             TextColumn("{task.completed}/{task.total}"),
+             TextColumn("• {task.fields[status]}"),
+             TimeElapsedColumn(),
+             expand=True,
+         )
+
+         batch_to_progress_id: dict[Batch, int] = {}
+         batch_task_totals: dict[Batch, int] = {}
+
+         body = Panel(Group(
+             Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
+             Panel(overall_progress, title="Overall Progress"),
+             Panel(batch_progress, title="Active Batches", height=10),
+             Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
+         ), border_style="magenta")
+
+         loop = asyncio.get_running_loop()
+         for sig in (signal.SIGINT, signal.SIGTERM):
+             loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._shutdown()))
+
+         with Live(body, console=console, refresh_per_second=2) as live:
+             while not self._shutdown_event.is_set():
+                 await self._update_statuses()
+                 if self._batcher_done and self.batches_queue.empty() and len(self._active_batches) == 0:
+                     if not self._final_batch_created:
+                         final_batch = self._create_final_batch()
+                         if final_batch is not None:
+                             await self.batches_queue.put(final_batch)
+                         self._final_batch_created = True
+                     else:
+                         break
+                 while len(self._active_batches) < self.max_concurrent_batches and not self.batches_queue.empty():
+                     batch = await self.batches_queue.get()
+                     if batch is not None:
+                         batch._set_file_semaphore(file_semaphore)
+                         self._active_batches.append(batch)
+                         asyncio.create_task(run_batch(batch))
+                 # Update overall progress fields.
+                 overall_progress.update(overall_task, produced_tasks=self._produced_tasks_count, finished_batches=self._finished_batches_count, failed_batches=self._failed_batches_count)
+
+                 # Add newly queued batches into the UI.
+                 for batch in list(self._active_batches):
+                     if batch not in batch_to_progress_id and batch.status not in self.TERMINAL_BATCH_STATES:
+                         total = len(batch.tasks) if batch.tasks else 1
+                         task_id = batch_progress.add_task("", total=total, batch_id=f"Batch {batch.id}", status=batch.status)
+                         batch_to_progress_id[batch] = task_id
+                         batch_task_totals[batch] = total
+                 # Remove finished batches from the UI.
+                 for batch, tid in list(batch_to_progress_id.items()):
+                     if batch.status in self.TERMINAL_BATCH_STATES:
+                         try:
+                             batch_progress.remove_task(tid)
+                         except Exception:
+                             pass
+                         del batch_to_progress_id[batch]
+                         if batch in batch_task_totals:
+                             del batch_task_totals[batch]
+                 # Update per-batch progress.
+                 for batch, tid in batch_to_progress_id.items():
+                     completed = sum(1 for t in batch.tasks if t.status in self.TERMINAL_BATCH_STATES)
+                     total = batch_task_totals.get(batch, max(1, len(batch.tasks)))
+                     batch_progress.update(tid, completed=completed, total=total, status=batch.status)
+                 # Update the system panel.
+                 body = Panel(Group(
+                     Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
+                     Panel(overall_progress, title="Overall Progress"),
+                     Panel(batch_progress, title="Active Batches"),
+                     Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
+                 ), border_style="magenta")
+                 live.update(body)
+                 await asyncio.sleep(self.poll_interval)
+
+         # Final UI summary.
+         console.clear()
+         total_batches = self._batch_counter + (1 if self._final_batch_created and self.final_batch_factory is not None else 0)
+         summary = Panel(
+             f"[bold green]Run finished![/]\n\n{self._success_batches_count}/{total_batches} batches succeeded.\n\nProduced tasks: {self._produced_tasks_count}\nElapsed: (see time in UI)",
+             expand=True,
+             title="Summary",
+             border_style="green",
+         )
+         console.print(summary)
+
+     async def _update_statuses(self):
+         await asyncio.gather(*[batch.update_status() for batch in self._active_batches if batch.status not in self.TERMINAL_BATCH_STATES])
+
+     def _make_system_stats_panel(self):
+         """Helper that creates a system stats panel for the live UI."""
+         def usage_bar(label: str, percent: float, color: str):
+             p = Progress(
+                 TextColumn(f"[bold]{label}[/]"),
+                 BarColumn(bar_width=None, complete_style=color),
+                 TextColumn(f"{percent:.1f}%"),
+                 expand=True,
+             )
+             p.add_task("", total=100, completed=percent)
+             return Panel(p, expand=True, width=30)
+
+         cpu = psutil.cpu_percent(interval=None)
+         ram = psutil.virtual_memory().percent
+         cpu_panel = usage_bar("CPU", cpu, "cyan")
+         ram_panel = usage_bar("RAM", ram, "magenta")
+         return Columns([cpu_panel, ram_panel], expand=True, equal=True, align="center")
+
+
+ class ProfileRunner(Runner):
+     """Creates and schedules batches of ProfileBamTask tasks using either local or Slurm batches.
+
+     Args:
+         run_dir (str | pathlib.Path): Directory where the runner will operate.
+         task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
+         container_engine (Engine): An instance of Engine to wrap task commands.
+         max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
+         poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
+         tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
+         batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
+         slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type is "slurm".
+             Default is None.
+     """
+     def __init__(
+         self,
+         run_dir: str | pathlib.Path,
+         task_generator: TaskGenerator,
+         container_engine: Engine,
+         max_concurrent_batches: int = 1,
+         poll_interval: float = 1.0,
+         tasks_per_batch: int = 10,
+         batch_type: str = "local",
+         slurm_config: SlurmConfig | None = None,
+     ) -> None:
+         if batch_type == "slurm":
+             if slurm_config is None:
+                 raise ValueError("Slurm config must be provided for slurm batch type.")
+             batch_factory = SlurmBatch
+         else:
+             batch_factory = LocalBatch
+         # Profiling runs have no finalization step, so there is no final batch factory.
+         final_batch_factory = None
+
+         super().__init__(
+             run_dir=run_dir,
+             task_generator=task_generator,
+             container_engine=container_engine,
+             batch_factory=batch_factory,
+             final_batch_factory=final_batch_factory,
+             max_concurrent_batches=max_concurrent_batches,
+             poll_interval=poll_interval,
+             tasks_per_batch=tasks_per_batch,
+             batch_type=batch_type,
+             slurm_config=slurm_config,
+         )
+
+     def _make_batch(self, tasks: list[Task]) -> Batch:
+         """Builds one batch from a list of tasks, passing the Slurm config when needed."""
+         batch_id = f"batch_{self._batch_counter}"
+         self._batch_counter += 1
+         if self.batch_type == "slurm":
+             return self.batch_factory(
+                 tasks=tasks,
+                 id=batch_id,
+                 run_dir=self.run_dir,
+                 expected_outputs=[],
+                 slurm_config=self.slurm_config,
+             )
+         return self.batch_factory(
+             tasks=tasks,
+             id=batch_id,
+             run_dir=self.run_dir,
+             expected_outputs=[],
+         )
+
+     async def _batcher(self):
+         """The batcher coroutine: collects tasks from tasks_queue, groups them into batches, and puts the
+         batches into batches_queue. A usage sketch for this runner follows the class body."""
+         buffer: list[Task] = []
+         while True:
+             task = await self.tasks_queue.get()
+             if task is None:
+                 if buffer:
+                     await self.batches_queue.put(self._make_batch(buffer))
+                 await self.batches_queue.put(None)
+                 self._batcher_done = True
+                 break
+             buffer.append(task)
+             if len(buffer) == self.tasks_per_batch:
+                 await self.batches_queue.put(self._make_batch(buffer))
+                 buffer = []
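+
+
+ # Typical ProfileRunner usage (a sketch; the generator construction and all
+ # paths are illustrative):
+ #
+ #     generator = ProfileTaskGenerator(data, yield_size=50,
+ #                                      container_engine=LocalEngine(""),
+ #                                      stb_file="ref.stb",
+ #                                      profile_bed_file="regions.bed",
+ #                                      gene_range_file="genes.bed",
+ #                                      genome_length_file="lengths.tsv")
+ #     runner = ProfileRunner(run_dir="runs/profile", task_generator=generator,
+ #                            container_engine=LocalEngine(""), tasks_per_batch=10)
+ #     asyncio.run(runner.run())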
1285
+
1286
+ class CompareRunner(Runner):
1287
+ """
1288
+ Creates and schedules batches of FastCompareTask tasks using either local or Slurm batches.
1289
+
1290
+ Args:
1291
+ run_dir (str | pathlib.Path): Directory where the runner will operate.
1292
+ task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
1293
+ container_engine (Engine): An instance of Engine to wrap task commands.
1294
+ max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
1295
+ poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
1296
+ tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
1297
+ batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
1298
+ slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
1299
+ is "slurm". Default is None.
1300
+ """
1301
+
1302
+ def __init__(
1303
+ self,
1304
+ run_dir: str | pathlib.Path,
1305
+ task_generator: TaskGenerator,
1306
+ container_engine: Engine,
1307
+ max_concurrent_batches: int = 1,
1308
+ poll_interval: float = 1.0,
1309
+ tasks_per_batch: int = 10,
1310
+ batch_type: str = "local",
1311
+ slurm_config: SlurmConfig | None = None,
1312
+ ) -> None:
1313
+ if batch_type == "slurm":
1314
+ if slurm_config is None:
1315
+ raise ValueError("Slurm config must be provided for slurm batch type.")
1316
+ batch_factory = FastCompareSlurmBatch
1317
+ final_batch_factory = PrepareCompareGenomeRunOutputsSlurmBatch
1318
+ else:
1319
+ batch_factory = FastCompareLocalBatch
1320
+ final_batch_factory = PrepareCompareGenomeRunOutputsLocalBatch
1321
+ super().__init__(
1322
+ run_dir=run_dir,
1323
+ task_generator=task_generator,
1324
+ container_engine=container_engine,
1325
+ batch_factory=batch_factory,
1326
+ final_batch_factory=final_batch_factory,
1327
+ max_concurrent_batches=max_concurrent_batches,
1328
+ poll_interval=poll_interval,
1329
+ tasks_per_batch=tasks_per_batch,
1330
+ batch_type=batch_type,
1331
+ slurm_config=slurm_config,
1332
+ )
1333
+
1334
+
1335
+
1336
+
1337
+ async def _batcher(self):
1338
+ """
1339
+ Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
1340
+ and puts the batches into the batches_queue. Each batch includes a CollectComps task to merge the outputs of the tasks in the batch.
1341
+ """
1342
+ buffer: list[Task] = []
1343
+ while True:
1344
+ task = await self.tasks_queue.get()
1345
+ if task is None:
1346
+ if buffer:
1347
+ batch_id = f"batch_{self._batch_counter}"
1348
+ self._batch_counter += 1
1349
+ batch_tasks = buffer + [
1350
+ CollectComps(
1351
+ "concat_parquet",
1352
+ {},
1353
+ {"output-file": FileOutput(f"Merged_batch_{batch_id}.parquet")},
1354
+ engine=self.container_engine,
1355
+ )
1356
+ ]
1357
+ expected_outputs = [BatchFileOutput(f"concat_parquet/Merged_batch_{batch_id}.parquet")]
1358
+ if self.batch_type == "slurm":
1359
+ batch = self.batch_factory(
1360
+ tasks=batch_tasks,
1361
+ id=batch_id,
1362
+ run_dir=self.run_dir,
1363
+ expected_outputs=expected_outputs,
1364
+ slurm_config=self.slurm_config,
1365
+ )
1366
+ else:
1367
+ batch = self.batch_factory(
1368
+ tasks=batch_tasks,
1369
+ id=batch_id,
1370
+ run_dir=self.run_dir,
1371
+ expected_outputs=expected_outputs,
1372
+ )
1373
+ await self.batches_queue.put(batch)
1374
+ await self.batches_queue.put(None)
1375
+ self._batcher_done = True
1376
+ break
1377
+ buffer.append(task)
1378
+ if len(buffer) == self.tasks_per_batch:
1379
+ batch_id = f"batch_{self._batch_counter}"
1380
+ self._batch_counter += 1
1381
+ batch_tasks = buffer + [
1382
+ CollectComps(
1383
+ "concat_parquet",
1384
+ {},
1385
+ {"output-file": FileOutput(f"Merged_batch_{batch_id}.parquet")},
1386
+ engine=self.container_engine,
1387
+ )
1388
+ ]
1389
+ expected_outputs = [BatchFileOutput(f"concat_parquet/Merged_batch_{batch_id}.parquet")]
1390
+ if self.batch_type == "slurm":
1391
+ batch = self.batch_factory(
1392
+ tasks=batch_tasks,
1393
+ id=batch_id,
1394
+ run_dir=self.run_dir,
1395
+ expected_outputs=expected_outputs,
1396
+ slurm_config=self.slurm_config,
1397
+ )
1398
+ else:
1399
+ batch = self.batch_factory(
1400
+ tasks=batch_tasks,
1401
+ id=batch_id,
1402
+ run_dir=self.run_dir,
1403
+ expected_outputs=expected_outputs,
1404
+ )
1405
+ await self.batches_queue.put(batch)
1406
+ buffer = []
1407
+
1408
+
1409
+
1410
+ def _create_final_batch(self) -> Batch:
1411
+ """Creates the final batch that prepares the overall outputs after all comparison batches are done."""
1412
+ final_task = PrepareCompareGenomeRunOutputs(
1413
+ id="prepare_outputs",
1414
+ inputs={"output-dir": StringInput("Outputs")},
1415
+ expected_outputs={},
1416
+ engine=self.container_engine,
1417
+ )
1418
+ expected_outputs = [BatchFileOutput("all_comparisons.parquet")]
1419
+ if self.batch_type == "slurm":
1420
+ final_batch = self.final_batch_factory(
1421
+ tasks=[final_task],
1422
+ id="Outputs",
1423
+ run_dir=self.run_dir,
1424
+ expected_outputs=expected_outputs,
1425
+ slurm_config=self.slurm_config,
1426
+ )
1427
+ final_batch._runner_obj = self
1428
+ return final_batch
1429
+ else:
1430
+ final_batch = self.final_batch_factory(
1431
+ tasks=[final_task],
1432
+ id="Outputs",
1433
+ run_dir=self.run_dir,
1434
+ expected_outputs=expected_outputs,
1435
+ )
1436
+ final_batch._runner_obj = self
1437
+ return final_batch
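+ # NOTE: the "slurm"/"local" branches above (here and in _batcher) differ only in whether
+ # slurm_config is forwarded to the factory. A minimal consolidation sketch, assuming the
+ # factories accept the keyword only when it is supplied:
+ #
+ #     kwargs = {"slurm_config": self.slurm_config} if self.batch_type == "slurm" else {}
+ #     final_batch = self.final_batch_factory(
+ #         tasks=[final_task],
+ #         id="Outputs",
+ #         run_dir=self.run_dir,
+ #         expected_outputs=expected_outputs,
+ #         **kwargs,
+ #     )
+ #     final_batch._runner_obj = self
+ #     return final_batch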
1438
+
1439
+
1440
+
1441
+ class ProfileBamTask(Task):
1442
+ """A Task that generates a mpileup file and genome breadth file in parquet format for a given BAM file using the fast_profile profile_bam command.
1443
+ The inputs to this task include:
1444
+
1445
+ - bam-file: The input BAM file to be profiled.
1446
+
1447
+ - bed-file: The BED file specifying the regions to profile.
1448
+
1449
+ - sample-name: The name of the sample being processed.
1450
+
1451
+ - gene-range-table: A BED file specifying the gene ranges for the sample.
1452
+
1453
+ - num-workers: The number of workers to use for processing.
1454
+
1455
+ - genome-length-file: A file containing the lengths of the genomes in the reference fasta.
1456
+
1457
+ - stb-file: The STB file used for profiling.
+
+ - breadth-min-cov: The minimum coverage used when computing the genome breadth matrix.
1458
+
1459
+ Args:
1460
+ id (str): Unique identifier for the task.
1461
+ inputs (dict[str, Input]): Dictionary of input parameters for the task.
1462
+ expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
1463
+ engine (Engine): Container engine to wrap the command."""
1464
+
1465
+ TEMPLATE_CMD="""
1466
+ ln -s <bam-file> input.bam
1467
+ ln -s <bed-file> bed_file.bed
1468
+ ln -s <gene-range-table> gene-range-table.bed
1469
+ samtools index <bam-file>
1470
+ zipstrain profile profile-single --bam-file input.bam \
1471
+ --bed-file bed_file.bed \
1472
+ --gene-range-table gene-range-table.bed \
1473
+ --num-workers <num-workers> \
1474
+ --output-dir .
1475
+ mv input.bam.parquet <sample-name>.parquet
1476
+ samtools idxstats <bam-file> | awk '$3 > 0 {print $1}' > <sample-name>.parquet.scaffolds
1477
+ zipstrain utilities genome_breadth_matrix --profile <sample-name>.parquet \
1478
+ --genome-length <genome-length-file> \
1479
+ --stb <stb-file> \
1480
+ --min-cov <breadth-min-cov> \
1481
+ --output-file <sample-name>_breadth.parquet
1482
+ """
1483
+
1484
+ class FastCompareTask(Task):
1485
+ """A Task that performs a fast genome comparison using the fast_profile compare single_compare_genome command.
1486
+
1487
+ Args:
1488
+ id (str): Unique identifier for the task.
1489
+ inputs (dict[str, Input]): Dictionary of input parameters for the task.
1490
+ expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
1491
+ engine (Engine): Container engine to wrap the command.
1492
+ """
1493
+ TEMPLATE_CMD="""
1494
+ zipstrain compare single_compare_genome --mpileup-contig-1 <mpile_1_file> \
1495
+ --mpileup-contig-2 <mpile_2_file> \
1496
+ --scaffolds-1 <scaffold_1_file> \
1497
+ --scaffolds-2 <scaffold_2_file> \
1498
+ --null-model <null_model_file> \
1499
+ --stb-file <stb_file> \
1500
+ --min-cov <min_cov> \
1501
+ --min-gene-compare-len <min-gene-compare-len> \
1502
+ --memory-mode <memory-mode> \
1503
+ --chrom-batch-size <chrom-batch-size> \
1504
+ --output-file <output-file> \
1505
+ --genome <genome-name> \
1506
+ --engine <engine>
1507
+ """
1508
+
1509
+
1510
+ class CollectComps(Task):
1511
+ """A Task that collects and merges comparison parquet files from multiple FastCompareTask tasks into a single parquet file.
1512
+
1513
+ Args:
1514
+ id (str): Unique identifier for the task.
1515
+ inputs (dict[str, Input]): Dictionary of input parameters for the task.
1516
+ expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
1517
+ engine (Engine): Container engine to wrap the command."""
1518
+ TEMPLATE_CMD="""
1519
+ mkdir -p comps
1520
+ cp */*_comparison.parquet comps/
1521
+ zipstrain utilities merge_parquet --input-dir comps --output-file <output-file>
1522
+ rm -rf comps
1523
+ """
1524
+
1525
+ @property
1526
+ def pre_run(self) -> str:
1527
+ return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status"
1528
+
1529
+
1530
+
1531
+ class PrepareCompareGenomeRunOutputs(Task):
1532
+ """A Task that prepares the final output by merging all parquet files after all genome comparisons are done."""
1533
+ TEMPLATE_CMD="""
1534
+ mkdir -p <output-dir>/comps
1535
+ find "$(pwd)" -type f -name "Merged_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/comps/
1536
+ zipstrain utilities merge_parquet --input-dir <output-dir>/comps --output-file <output-dir>/all_comparisons.parquet
1537
+ rm -rf <output-dir>/comps
1538
+ """
1539
+
1540
+ @property
1541
+ def pre_run(self) -> str:
1542
+ """Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs."""
1543
+ return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self._batch_obj._runner_obj.run_dir.absolute()}"
1544
+
1545
+
1546
+
1547
+
1548
+ class FastCompareLocalBatch(LocalBatch):
1549
+ """A LocalBatch that runs FastCompareTask tasks locally."""
1550
+ def cleanup(self) -> None:
+ """Removes FastCompareTask tasks from the batch and deletes their task directories to free disk space."""
1551
+ tasks_to_remove = [task for task in self.tasks if isinstance(task, FastCompareTask)]
1552
+ for task in tasks_to_remove:
1553
+ self.tasks.remove(task)
1554
+ shutil.rmtree(task.task_dir)
1555
+
1556
+ class FastCompareSlurmBatch(SlurmBatch):
1557
+ """A SlurmBatch that runs FastCompareTask tasks on a Slurm cluster. Maybe removed in future"""
1558
+ def cleanup(self) -> None:
+ """Removes FastCompareTask tasks from the batch and deletes their task directories to free disk space."""
1559
+ tasks_to_remove = [task for task in self.tasks if isinstance(task, FastCompareTask)]
1560
+ for task in tasks_to_remove:
1561
+ self.tasks.remove(task)
1562
+ shutil.rmtree(task.task_dir)
1563
+
1564
+ class PrepareCompareGenomeRunOutputsLocalBatch(LocalBatch):
1565
+ pass
1566
+
1567
+
1568
+ class PrepareCompareGenomeRunOutputsSlurmBatch(SlurmBatch):
1569
+ pass
1570
+
1571
+
1572
+ def lazy_run_profile(
1573
+ run_dir: str | pathlib.Path,
1574
+ container_engine: Engine,
1575
+ bams_lf: pl.LazyFrame,
+ stb_file: pathlib.Path,
+ gene_range_table: pathlib.Path,
+ bed_file: pathlib.Path,
+ genome_length_file: pathlib.Path,
+ num_procs: int = 8,
+ tasks_per_batch: int = 10,
+ max_concurrent_batches: int = 1,
+ poll_interval: float = 5.0,
+ execution_mode: str = "local",
+ slurm_config: SlurmConfig | None = None,
+ ) -> None:
+ """A helper function to quickly set up and run a ProfileRunner with the given parameters.
+
+ Args:
+ run_dir (str | pathlib.Path): Directory where the runner will operate.
+ container_engine (Engine): An instance of Engine to wrap task commands.
+ bams_lf (pl.LazyFrame): Polars LazyFrame describing the BAM files to profile.
+ stb_file (pathlib.Path): The STB file used for profiling.
+ gene_range_table (pathlib.Path): BED file specifying the gene ranges.
+ bed_file (pathlib.Path): BED file specifying the regions to profile.
+ genome_length_file (pathlib.Path): File containing the lengths of the genomes in the reference fasta.
+ num_procs (int): Number of workers per profiling task. Default is 8.
+ tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
+ max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
+ poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
+ execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
+ slurm_config (SlurmConfig | None): Configuration for Slurm batches if execution_mode is "slurm". Default is None.
+ """
1587
+ profile_task_generator = ProfileTaskGenerator(
1588
+ data=bams_lf,
1589
+ yield_size=tasks_per_batch,
1590
+ container_engine=container_engine,
1591
+ stb_file=stb_file,
1592
+ profile_bed_file=bed_file,
1593
+ gene_range_file=gene_range_table,
1594
+ genome_length_file=genome_length_file,
1595
+ num_procs=num_procs
1596
+ )
1597
+ if execution_mode=="local":
1598
+ batch_type="local"
1599
+ elif execution_mode=="slurm":
1600
+ batch_type="slurm"
1601
+ else:
1602
+ raise ValueError(f"Unknown execution mode: {execution_mode}")
1603
+
1604
+ runner = ProfileRunner(
1605
+ run_dir=pathlib.Path(run_dir),
1606
+ task_generator=profile_task_generator,
1607
+ container_engine=container_engine,
1608
+ max_concurrent_batches=max_concurrent_batches,
1609
+ poll_interval=poll_interval,
1610
+ tasks_per_batch=tasks_per_batch,
1611
+ batch_type=batch_type,
1612
+ slurm_config=slurm_config,
1613
+ )
1614
+ asyncio.run(runner.run())
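+ # Example usage (illustrative sketch): profiling BAM files locally. The manifest path,
+ # the other file paths, and the `engine` variable are hypothetical placeholders;
+ # construct bams_lf however your BAM listing is stored.
+ #
+ #     bams_lf = pl.scan_csv("bam_manifest.csv")
+ #     lazy_run_profile(
+ #         run_dir="profile_run",
+ #         container_engine=engine,
+ #         bams_lf=bams_lf,
+ #         stb_file=pathlib.Path("genomes.stb"),
+ #         gene_range_table=pathlib.Path("gene_ranges.bed"),
+ #         bed_file=pathlib.Path("regions.bed"),
+ #         genome_length_file=pathlib.Path("genome_lengths.tsv"),
+ #         num_procs=8,
+ #     )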
1615
+
1616
+
1617
+ def lazy_run_compares(
1618
+ run_dir: str | pathlib.Path,
1619
+ container_engine: Engine,
1620
+ comps_db: database.GenomeComparisonDatabase|None = None,
1621
+ tasks_per_batch: int = 10,
1622
+ max_concurrent_batches: int = 1,
1623
+ poll_interval: float = 5.0,
1624
+ execution_mode: str = "local",
1625
+ slurm_config: SlurmConfig | None = None,
1626
+ memory_mode: str = "heavy",
1627
+ chrom_batch_size: int = 10000,
1628
+ polars_engine: str = "streaming"
1629
+ ) -> None:
1630
+ """A helper function to quickly set up and run a CompareRunner with given parameters.
1631
+
1632
+ Args:
1633
+ run_dir (str | pathlib.Path): Directory where the runner will operate.
1634
+ container_engine (Engine): An instance of Engine to wrap task commands.
1635
+ comps_db (GenomeComparisonDatabase | None): An instance of GenomeComparisonDatabase containing comparison data.
1636
+ tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
1637
+ max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
1638
+ poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
1639
+ execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
1640
+ """
1641
+ task_generator = CompareTaskGenerator(
1642
+ data=comps_db.to_complete_input_table(),
1643
+ yield_size=tasks_per_batch,
1644
+ container_engine=container_engine,
1645
+ comp_config=comps_db.config,
1646
+ memory_mode=memory_mode,
1647
+ polars_engine=polars_engine,
1648
+ chrom_batch_size=chrom_batch_size,
1649
+ )
1650
+ if execution_mode=="local":
1651
+ batch_type="local"
1652
+ elif execution_mode=="slurm":
1653
+ batch_type="slurm"
1654
+ else:
1655
+ raise ValueError(f"Unknown execution mode: {execution_mode}")
1656
+ runner = CompareRunner(
1657
+ run_dir=pathlib.Path(run_dir),
1658
+ task_generator=task_generator,
1659
+ container_engine=container_engine,
1660
+ max_concurrent_batches=max_concurrent_batches,
1661
+ poll_interval=poll_interval,
1662
+ tasks_per_batch=tasks_per_batch,
1663
+ batch_type=batch_type,
1664
+ slurm_config=slurm_config,
1665
+ )
1666
+ asyncio.run(runner.run())
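+ # Example usage (illustrative sketch): assuming `comps_db` is an already-built
+ # database.GenomeComparisonDatabase and `engine` is an Engine instance (both
+ # placeholders here), a local run could look like:
+ #
+ #     lazy_run_compares(
+ #         run_dir="compare_run",
+ #         container_engine=engine,
+ #         comps_db=comps_db,
+ #         tasks_per_batch=20,
+ #         max_concurrent_batches=2,
+ #         memory_mode="heavy",
+ #     )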
1667
+
1668
+
1669
+ class FastGeneCompareTask(Task):
1670
+ """A Task that performs a fast gene comparison using the compare single_compare_gene command.
1671
+
1672
+ Args:
1673
+ id (str): Unique identifier for the task.
1674
+ inputs (dict[str, Input]): Dictionary of input parameters for the task.
1675
+ expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
1676
+ engine (Engine): Container engine to wrap the command.
1677
+ """
1678
+ TEMPLATE_CMD="""
1679
+ zipstrain compare single_compare_gene --mpileup-contig-1 <mpile_1_file> \
1680
+ --mpileup-contig-2 <mpile_2_file> \
1681
+ --null-model <null_model_file> \
1682
+ --stb-file <stb_file> \
1683
+ --min-cov <min_cov> \
1684
+ --min-gene-compare-len <min-gene-compare-len> \
1685
+ --output-file <output-file> \
1686
+ --engine <engine> \
1687
+ --ani-method <ani-method>
1688
+ """
1689
+
1690
+ class GeneCompareTaskGenerator(TaskGenerator):
1691
+ """This TaskGenerator generates FastGeneCompareTask objects from a polars DataFrame. Each task compares two profiles using compare_genes functionality in
1692
+ zipstrain.compare module.
1693
+
1694
+ Args:
1695
+ data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
1696
+ yield_size (int): Number of tasks to yield at a time.
+ container_engine (Engine): An instance of Engine to wrap task commands.
1697
+ comp_config (database.GeneComparisonConfig): Configuration for gene comparison.
1698
+ polars_engine (str): Polars engine to use. Default is "streaming".
1699
+ ani_method (str): ANI calculation method to use. Default is "popani".
1700
+ """
1701
+ def __init__(
1702
+ self,
1703
+ data: pl.LazyFrame,
1704
+ yield_size: int,
1705
+ container_engine: Engine,
1706
+ comp_config: database.GeneComparisonConfig,
1707
+ polars_engine: str = "streaming",
1708
+ ani_method: str = "popani",
1709
+ ) -> None:
1710
+ super().__init__(data, yield_size)
1711
+ self.comp_config = comp_config
1712
+ self.engine = container_engine
1713
+ self.polars_engine = polars_engine
1714
+ self.ani_method = ani_method
1715
+ if not isinstance(self.data, pl.LazyFrame):
1716
+ raise ValueError("data must be a polars LazyFrame.")
1717
+
1718
+ def get_total_tasks(self) -> int:
1719
+ """Returns total number of pairwise comparisons to be made."""
1720
+ return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]
1721
+
1722
+ async def generate_tasks(self) -> list[Task]:
1723
+ """Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
1724
+ while polars is collecting data to avoid blocking.
1725
+ """
1726
+ for offset in range(0, self._total_tasks, self.yield_size):
1727
+ batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
1728
+ tasks = []
1729
+ for row in batch_df.iter_rows(named=True):
1730
+ inputs = {
1731
+ "mpile_1_file": FileInput(row["profile_location_1"]),
1732
+ "mpile_2_file": FileInput(row["profile_location_2"]),
1733
+ "null_model_file": FileInput(self.comp_config.null_model_loc),
1734
+ "stb_file": FileInput(self.comp_config.stb_file_loc),
1735
+ "min_cov": IntInput(self.comp_config.min_cov),
1736
+ "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
1737
+ "engine": StringInput(self.polars_engine),
1738
+ "ani-method": StringInput(self.ani_method),
1739
+ }
1740
+ expected_outputs = {
1741
+ "output-file": FileOutput(row["sample_name_1"] + "_" + row["sample_name_2"] + "_gene_comparison.parquet"),
1742
+ }
1743
+ task = FastGeneCompareTask(id=row["sample_name_1"] + "_" + row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
1744
+ tasks.append(task)
1745
+ yield tasks
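+ # Example (illustrative sketch): draining the generator outside a Runner, assuming
+ # `gen` is a fully configured GeneCompareTaskGenerator (a placeholder here):
+ #
+ #     async def drain(gen: GeneCompareTaskGenerator) -> None:
+ #         async for task_list in gen.generate_tasks():
+ #             for task in task_list:
+ #                 print(task.id)
+ #
+ #     asyncio.run(drain(gen))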
1746
+
1747
+ class GeneCompareRunner(Runner):
1748
+ """
1749
+ Creates and schedules batches of FastGeneCompareTask tasks using either local or Slurm batches.
1750
+
1751
+ Args:
1752
+ run_dir (str | pathlib.Path): Directory where the runner will operate.
1753
+ task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
1754
+ container_engine (Engine): An instance of Engine to wrap task commands.
1755
+ max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
1756
+ poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
1757
+ tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
1758
+ batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
1759
+ slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
1760
+ is "slurm". Default is None.
1761
+ """
1762
+
1763
+ def __init__(
1764
+ self,
1765
+ run_dir: str | pathlib.Path,
1766
+ task_generator: TaskGenerator,
1767
+ container_engine: Engine,
1768
+ max_concurrent_batches: int = 1,
1769
+ poll_interval: float = 1.0,
1770
+ tasks_per_batch: int = 10,
1771
+ batch_type: str = "local",
1772
+ slurm_config: SlurmConfig | None = None,
1773
+ ) -> None:
1774
+ if batch_type == "slurm":
1775
+ if slurm_config is None:
1776
+ raise ValueError("Slurm config must be provided for slurm batch type.")
1777
+ batch_factory = FastGeneCompareSlurmBatch
1778
+ final_batch_factory = PrepareGeneCompareRunOutputsSlurmBatch
1779
+ else:
1780
+ batch_factory = FastGeneCompareLocalBatch
1781
+ final_batch_factory = PrepareGeneCompareRunOutputsLocalBatch
1782
+ super().__init__(
1783
+ run_dir=run_dir,
1784
+ task_generator=task_generator,
1785
+ container_engine=container_engine,
1786
+ batch_factory=batch_factory,
1787
+ final_batch_factory=final_batch_factory,
1788
+ max_concurrent_batches=max_concurrent_batches,
1789
+ poll_interval=poll_interval,
1790
+ tasks_per_batch=tasks_per_batch,
1791
+ batch_type=batch_type,
1792
+ slurm_config=slurm_config,
1793
+ )
1794
+
1795
+ async def _batcher(self):
1796
+ """
1797
+ Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
1798
+ and puts the batches into the batches_queue. Each batch includes a CollectGeneComps task to merge the outputs of the tasks in the batch.
1799
+ """
1800
+ buffer: list[Task] = []
1801
+ while True:
1802
+ task = await self.tasks_queue.get()
1803
+ if task is None:
1804
+ if buffer:
1805
+ collect_task = CollectGeneComps(
1806
+ id="collect_gene_comps",
1807
+ inputs={},
1808
+ expected_outputs={"output-file": FileOutput(f"Merged_gene_batch_{self._batch_counter}.parquet")},
1809
+ engine=self.container_engine,
1810
+ )
1811
+ buffer.append(collect_task)
1812
+ batch = self.batch_factory(
1813
+ tasks=buffer,
1814
+ id=f"gene_batch_{self._batch_counter}",
1815
+ run_dir=self.run_dir,
1816
+ expected_outputs=[],
1817
+ slurm_config=self.slurm_config if self.batch_type == "slurm" else None,
1818
+ )
1819
+ await self.batches_queue.put(batch)
1820
+ self._batch_counter += 1
1821
+ await self.batches_queue.put(None)  # sentinel so the batch consumer stops waiting (mirrors CompareRunner._batcher)
+ self._batcher_done = True
1822
+ break
1823
+
1824
+ buffer.append(task)
1825
+ if len(buffer) == self.tasks_per_batch:
1826
+ collect_task = CollectGeneComps(
1827
+ id="collect_gene_comps",
1828
+ inputs={},
1829
+ expected_outputs={"output-file": FileOutput(f"Merged_gene_batch_{self._batch_counter}.parquet")},
1830
+ engine=self.container_engine,
1831
+ )
1832
+ buffer.append(collect_task)
1833
+ batch = self.batch_factory(
1834
+ tasks=buffer,
1835
+ id=f"gene_batch_{self._batch_counter}",
1836
+ run_dir=self.run_dir,
1837
+ expected_outputs=[],
1838
+ slurm_config=self.slurm_config if self.batch_type == "slurm" else None,
1839
+ )
1840
+ await self.batches_queue.put(batch)
1841
+ self._batch_counter += 1
1842
+ buffer = []
1843
+
1844
+ def _create_final_batch(self) -> Batch:
1845
+ """Creates the final batch that prepares the overall outputs after all gene comparison batches are done."""
1846
+ final_task = PrepareGeneCompareRunOutputs(
1847
+ id="prepare_gene_outputs",
1848
+ inputs={"output-dir": StringInput("Outputs")},
1849
+ expected_outputs={},
1850
+ engine=self.container_engine,
1851
+ )
1852
+ expected_outputs = [BatchFileOutput("all_gene_comparisons.parquet")]
1853
+ if self.batch_type == "slurm":
1854
+ final_batch = self.final_batch_factory(
1855
+ tasks=[final_task],
1856
+ id="Outputs",
1857
+ run_dir=self.run_dir,
1858
+ expected_outputs=expected_outputs,
1859
+ slurm_config=self.slurm_config,
1860
+ )
1861
+ final_batch._runner_obj = self
1862
+ return final_batch
1863
+ else:
1864
+ final_batch = self.final_batch_factory(
1865
+ tasks=[final_task],
1866
+ id="Outputs",
1867
+ run_dir=self.run_dir,
1868
+ expected_outputs=expected_outputs,
1869
+ )
1870
+ final_batch._runner_obj = self
1871
+ return final_batch
1872
+
1873
+ class CollectGeneComps(Task):
1874
+ """A Task that collects and merges gene comparison parquet files from multiple FastGeneCompareTask tasks into a single parquet file.
1875
+
1876
+ Args:
1877
+ id (str): Unique identifier for the task.
1878
+ inputs (dict[str, Input]): Dictionary of input parameters for the task.
1879
+ expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
1880
+ engine (Engine): Container engine to wrap the command."""
1881
+ TEMPLATE_CMD="""
1882
+ mkdir -p gene_comps
1883
+ cp */*_gene_comparison.parquet gene_comps/
1884
+ zipstrain utilities merge_parquet --input-dir gene_comps --output-file <output-file>
1885
+ rm -rf gene_comps
1886
+ """
1887
+
1888
+ @property
1889
+ def pre_run(self) -> str:
1890
+ return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status"
1891
+
1892
+ class PrepareGeneCompareRunOutputs(Task):
1893
+ """A Task that prepares the final output by merging all gene comparison parquet files after all gene comparisons are done."""
1894
+ TEMPLATE_CMD="""
1895
+ mkdir -p <output-dir>/gene_comps
1896
+ find "$(pwd)" -type f -name "Merged_gene_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/gene_comps/
1897
+ zipstrain utilities merge_parquet --input-dir <output-dir>/gene_comps --output-file <output-dir>/all_gene_comparisons.parquet
1898
+ rm -rf <output-dir>/gene_comps
1899
+ """
1900
+
1901
+ @property
1902
+ def pre_run(self) -> str:
1903
+ """Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs."""
1904
+ return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self._batch_obj._runner_obj.run_dir.absolute()}"
1905
+
1906
+ class FastGeneCompareLocalBatch(LocalBatch):
1907
+ """A LocalBatch that runs FastGeneCompareTask tasks locally."""
1908
+ def cleanup(self) -> None:
+ """Removes FastGeneCompareTask tasks from the batch and deletes their task directories to free disk space."""
1909
+ tasks_to_remove = [task for task in self.tasks if isinstance(task, FastGeneCompareTask)]
1910
+ for task in tasks_to_remove:
1911
+ self.tasks.remove(task)
1912
+ shutil.rmtree(task.task_dir)
1913
+
1914
+ class FastGeneCompareSlurmBatch(SlurmBatch):
1915
+ """A SlurmBatch that runs FastGeneCompareTask tasks on a Slurm cluster."""
1916
+ def cleanup(self) -> None:
+ """Removes FastGeneCompareTask tasks from the batch and deletes their task directories to free disk space."""
1917
+ tasks_to_remove = [task for task in self.tasks if isinstance(task, FastGeneCompareTask)]
1918
+ for task in tasks_to_remove:
1919
+ self.tasks.remove(task)
1920
+ shutil.rmtree(task.task_dir)
1921
+
1922
+ class PrepareGeneCompareRunOutputsLocalBatch(LocalBatch):
1923
+ pass
1924
+
1925
+ class PrepareGeneCompareRunOutputsSlurmBatch(SlurmBatch):
1926
+ pass
1927
+
1928
+ def lazy_run_gene_compares(
1929
+ run_dir: str | pathlib.Path,
1930
+ container_engine: Engine,
1931
+ comps_db: database.GeneComparisonDatabase | None = None,
1932
+ tasks_per_batch: int = 10,
1933
+ max_concurrent_batches: int = 1,
1934
+ poll_interval: float = 5.0,
1935
+ execution_mode: str = "local",
1936
+ slurm_config: SlurmConfig | None = None,
1937
+ polars_engine: str = "streaming",
1938
+ ani_method: str = "popani"
1939
+ ) -> None:
1940
+ """A helper function to quickly set up and run a GeneCompareRunner with given parameters.
1941
+
1942
+ Args:
1943
+ run_dir (str | pathlib.Path): Directory where the runner will operate.
1944
+ container_engine (Engine): An instance of Engine to wrap task commands.
1945
+ comps_db (GeneComparisonDatabase | None): An instance of GeneComparisonDatabase containing gene comparison data.
1946
+ tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
1947
+ max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
1948
+ poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
1949
+ execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
+ slurm_config (SlurmConfig | None): Configuration for Slurm batches if execution_mode is "slurm". Default is None.
1950
+ polars_engine (str): Polars engine to use. Default is "streaming".
1951
+ ani_method (str): ANI calculation method to use. Default is "popani".
1952
+ """
1953
+ task_generator = GeneCompareTaskGenerator(
1954
+ data=comps_db.to_complete_input_table(),
1955
+ yield_size=tasks_per_batch,
1956
+ container_engine=container_engine,
1957
+ comp_config=comps_db.config,
1958
+ polars_engine=polars_engine,
1959
+ ani_method=ani_method,
1960
+ )
1961
+ if execution_mode=="local":
1962
+ batch_type="local"
1963
+ elif execution_mode=="slurm":
1964
+ batch_type="slurm"
1965
+ else:
1966
+ raise ValueError(f"Unknown execution mode: {execution_mode}")
1967
+ runner = GeneCompareRunner(
1968
+ run_dir=pathlib.Path(run_dir),
1969
+ task_generator=task_generator,
1970
+ container_engine=container_engine,
1971
+ max_concurrent_batches=max_concurrent_batches,
1972
+ poll_interval=poll_interval,
1973
+ tasks_per_batch=tasks_per_batch,
1974
+ batch_type=batch_type,
1975
+ slurm_config=slurm_config,
1976
+ )
1977
+ asyncio.run(runner.run())
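+ # Example usage (illustrative sketch): a Slurm run, assuming `gene_db` is an
+ # already-built database.GeneComparisonDatabase and `engine` is an Engine instance
+ # (both placeholders). The SlurmConfig values are placeholders as well.
+ #
+ #     slurm_cfg = SlurmConfig(time="04:00:00", tasks=4, mem=32, additional_params={})
+ #     lazy_run_gene_compares(
+ #         run_dir="gene_compare_run",
+ #         container_engine=engine,
+ #         comps_db=gene_db,
+ #         execution_mode="slurm",
+ #         slurm_config=slurm_cfg,
+ #         ani_method="popani",
+ #     )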
1978
+