zipstrain 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zipstrain/utils.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ zipstrain.utils
3
+ ========================
4
+ This module provides utility functions for profiling and comparison operations.
5
+ """
6
+ import pathlib
7
+ import polars as pl
8
+ import sys
9
+ import re
10
+ import pyarrow as pa
11
+ import pyarrow.parquet as pq
12
+ from intervaltree import IntervalTree
13
+ from collections import defaultdict,Counter
14
+ from functools import reduce
15
+ from scipy.stats import poisson
16
+ import subprocess
17
+ import pdb
18
+
19
def build_null_poisson(error_rate: float = 0.001,
                       max_total_reads: int = 10000,
                       p_threshold: float = 0.05) -> list[tuple[int, int]]:
    """
    Build a null model to correct for sequencing errors based on the Poisson distribution.

    For every read depth ``n`` in ``1..max_total_reads`` the Poisson rate is
    ``lam = n * error_rate / 3`` (presumably because an error can land on any of
    the three alternative bases -- confirm with the caller). The value stored for
    ``n`` is the largest count ``k`` for which ``P(X >= k) > p_threshold`` with
    ``X ~ Poisson(lam)``.

    Parameters:
        error_rate (float): Error rate for the sequencing technology.
        max_total_reads (int): Maximum total reads to consider.
        p_threshold (float): Significance threshold for the Poisson distribution.

    Returns:
        list[tuple[int, int]]: One ``(total_reads, max_error_count)`` tuple per
        depth, ordered by increasing depth. ``max_error_count`` is ``-1`` when
        even ``P(X >= 0) = 1`` does not exceed ``p_threshold``.
    """
    records = []
    # For a fixed k the survival function grows with lam, and lam grows with n,
    # so the cut-off can only stay equal or increase from one depth to the next.
    # Carrying k forward avoids re-scanning from zero for every depth.
    k = 0
    for n in range(1, max_total_reads + 1):
        lam = n * (error_rate / 3)
        while poisson.sf(k - 1, lam) > p_threshold:
            k += 1
        records.append((n, k - 1))
    return records
41
+
42
+
43
+
44
def clean_bases(bases: str, indel_re: re.Pattern) -> str:
    """
    Strip read start/end markers and indel annotations from an mpileup bases string.

    The string is scanned left to right. Wherever ``indel_re`` matches at the
    current offset the match is dropped; if the match begins with ``+`` or ``-``
    the inserted/deleted bases that follow it (their count is capture group 1)
    are dropped as well. Every other character is kept and upper-cased.

    Args:
        bases (str): The bases string from mpileup.
        indel_re (re.Pattern): Compiled regex pattern to match indels and markers.

    Returns:
        str: The remaining characters, upper-cased, in their original order.
    """
    kept = []
    cursor = 0
    total = len(bases)
    while cursor < total:
        hit = indel_re.match(bases, cursor)
        if hit is None:
            # Ordinary pileup character: keep it.
            kept.append(bases[cursor].upper())
            cursor += 1
            continue
        cursor = hit.end()
        if hit.group(0).startswith(('+', '-')):
            # Also jump over the indel sequence announced by the length prefix.
            cursor += int(hit.group(1))
    return ''.join(kept)
68
+
69
def count_bases(bases: str):
    """
    Tally how often each of A, C, G and T occurs in a cleaned bases string.

    Args:
        bases (str): Cleaned bases string.
    Returns:
        dict: Mapping with exactly the keys 'A', 'C', 'G', 'T' (in that order);
        any other character in ``bases`` is ignored.
    """
    tally = Counter(bases)
    # Counter returns 0 for symbols it has not seen, so no explicit default is needed.
    return {base: tally[base] for base in ('A', 'C', 'G', 'T')}
84
+
85
def process_mpileup_function(gene_range_table_loc, batch_bed, batch_size, output_file):
    """
    Process mpileup text read from stdin and save per-position base counts in a Parquet file.

    Each non-empty stdin line is split on tabs; lines with fewer than five
    fields are skipped. The first, second and fifth fields are used as
    scaffold name, position and bases string. The bases string is cleaned with
    ``clean_bases`` and tallied with ``count_bases``. Rows are buffered and
    written to ``output_file`` every ``batch_size`` rows.

    Parameters:
        gene_range_table_loc (str): Path to the gene range table in TSV format
            (no header; columns: scaffold, start, end, gene).
        batch_bed (str): Path to the batch BED file (no header); only gene ranges
            on scaffolds listed in its first column are loaded.
        batch_size (int): Buffer size for processing stdin from samtools.
        output_file (str): Path to save the output Parquet file.

    Notes:
        No output file is created when stdin yields no usable rows, because the
        Parquet writer is opened lazily on the first flush.
    """
    indel_re = re.compile(r'\^.|[\$]|[+-](\d+)')
    gene_ranges_pl = pl.scan_csv(gene_range_table_loc, separator='\t', has_header=False).rename({
        "column_1": "scaffold",
        "column_2": "start",
        "column_3": "end",
        "column_4": "gene"
    })
    scaffolds = pl.read_csv(batch_bed, separator='\t', has_header=False)["column_1"].unique().to_list()
    gene_ranges_pl = gene_ranges_pl.filter(pl.col("scaffold").is_in(scaffolds)).collect()
    gene_ranges = defaultdict(IntervalTree)
    for row in gene_ranges_pl.iter_rows(named=True):
        # IntervalTree intervals exclude their upper bound; +1 keeps the table's
        # end coordinate inside the gene.
        gene_ranges[row["scaffold"]].addi(row["start"], row["end"] + 1, row["gene"])

    # NOTE(review): counts are stored as uint16, so a per-base depth above 65535
    # does not fit -- confirm the upstream pileup is depth-capped.
    schema = pa.schema([
        ('chrom', pa.string()),
        ('pos', pa.int32()),
        ('gene', pa.string()),
        ('A', pa.uint16()),
        ('C', pa.uint16()),
        ('G', pa.uint16()),
        ('T', pa.uint16()),
    ])

    chroms = []
    positions = []
    genes = []
    As = []
    Cs = []
    Gs = []
    Ts = []

    writer = None

    def flush_batch():
        """Write the buffered rows as one record batch and empty the buffers."""
        nonlocal writer
        if not chroms:
            return
        batch = pa.RecordBatch.from_arrays([
            pa.array(chroms, type=pa.string()),
            pa.array(positions, type=pa.int32()),
            pa.array(genes, type=pa.string()),
            pa.array(As, type=pa.uint16()),
            pa.array(Cs, type=pa.uint16()),
            pa.array(Gs, type=pa.uint16()),
            pa.array(Ts, type=pa.uint16()),
        ], schema=schema)

        if writer is None:
            # Open writer for the first time
            writer = pq.ParquetWriter(output_file, schema, compression='snappy')
        writer.write_table(pa.Table.from_batches([batch]))

        # Clear buffers
        chroms.clear()
        positions.clear()
        genes.clear()
        As.clear()
        Cs.clear()
        Gs.clear()
        Ts.clear()

    try:
        for line in sys.stdin:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            if len(fields) < 5:
                continue
            chrom, pos, _, _, bases = fields[:5]
            position = int(pos)

            cleaned = clean_bases(bases, indel_re)
            counts = count_bases(cleaned)

            chroms.append(chrom)
            positions.append(position)
            # If several gene intervals overlap the position, whichever one the
            # returned set yields first is recorded.
            matches = gene_ranges[chrom][position]
            genes.append(next(iter(matches)).data if matches else "NA")
            As.append(counts['A'])
            Cs.append(counts['C'])
            Gs.append(counts['G'])
            Ts.append(counts['T'])

            if len(chroms) >= batch_size:
                flush_batch()

        # Flush remaining data
        flush_batch()
    finally:
        # Close the writer even when parsing or writing fails, so the file
        # handle is not leaked; the exception still propagates.
        if writer is not None:
            writer.close()
182
+
183
def extract_genome_length(stb: pl.LazyFrame, bed_table: pl.LazyFrame) -> pl.LazyFrame:
    """
    Compute the total length of every genome from a BED table.

    The length of each BED region (``end - start``) is summed per scaffold,
    scaffolds are mapped to genomes through ``stb`` with a left join, and the
    scaffold lengths are summed per genome.

    Parameters:
        stb (pl.LazyFrame): Scaffold-to-bin mapping table with ``scaffold`` and ``genome`` columns.
        bed_table (pl.LazyFrame): BED table with ``scaffold``, ``start`` and ``end`` columns.

    Returns:
        pl.LazyFrame: A LazyFrame with the columns ``genome`` and ``genome_length``.
    """
    region_sizes = bed_table.select(
        pl.col("scaffold"),
        (pl.col("end") - pl.col("start")).alias("scaffold_length"),
    )
    per_scaffold = (
        region_sizes
        .group_by("scaffold")
        .agg(scaffold_length=pl.col("scaffold_length").sum())
        .select("scaffold", "scaffold_length")
    )
    scaffold_to_genome = stb.select("scaffold", "genome")
    per_genome = (
        per_scaffold
        .join(scaffold_to_genome, on="scaffold", how="left")
        .group_by("genome")
        .agg(genome_length=pl.col("scaffold_length").sum())
    )
    return per_genome.select("genome", "genome_length")
216
+
217
def _bed_windows(scaffold: str, length: int, max_scaffold_length: int) -> list[tuple[str, int, int]]:
    """Split a scaffold of ``length`` bases into ``(scaffold, start, end)`` windows of at most ``max_scaffold_length``."""
    return [
        (scaffold, start, min(start + max_scaffold_length, length))
        for start in range(0, length, max_scaffold_length)
    ]


def make_the_bed(db_fasta_dir: str | pathlib.Path, max_scaffold_length: int = 500_000) -> pl.DataFrame:
    """
    Create a BED table from the database in fasta format.

    The fasta file is streamed line by line. Only the running length of each
    scaffold is tracked, so the sequences themselves are never held in memory.
    A scaffold with no sequence characters produces no rows.

    Parameters:
        db_fasta_dir (Union[str, pathlib.Path]): Path to the fasta file.
        max_scaffold_length (int): Splits scaffolds longer than this into multiple entries of length <= max_scaffold_length.

    Returns:
        pl.DataFrame: A DataFrame with the columns ``scaffold``, ``start`` and
        ``end`` (``start`` is 0-based, ``end`` is exclusive).

    Raises:
        FileNotFoundError: If ``db_fasta_dir`` is not an existing file.
    """
    db_fasta_dir = pathlib.Path(db_fasta_dir)
    if not db_fasta_dir.is_file():
        raise FileNotFoundError(f"{db_fasta_dir} is not a valid fasta file.")

    records = []
    scaffold = None
    seq_length = 0
    with db_fasta_dir.open() as f:
        for line in f:
            line = line.strip()
            if line.startswith(">"):
                # A new header closes the previous scaffold.
                if scaffold is not None:
                    records.extend(_bed_windows(scaffold, seq_length, max_scaffold_length))
                scaffold = line[1:].split()[0]  # ID only (up to first whitespace)
                seq_length = 0
            else:
                seq_length += len(line)

    # The last scaffold has no following header to trigger its flush.
    if scaffold is not None:
        records.extend(_bed_windows(scaffold, seq_length, max_scaffold_length))

    return pl.DataFrame(records, schema=["scaffold", "start", "end"], orient="row")
260
+
261
+
262
def get_genome_breadth_matrix(
    profile: pl.LazyFrame,
    name: str,
    genome_length: pl.LazyFrame,
    stb: pl.LazyFrame,
    min_cov: int = 1) -> pl.LazyFrame:
    """
    Compute the per-genome breadth of coverage for a single profile.

    Positions whose summed A/C/G/T count is at least ``min_cov`` are counted per
    scaffold, the counts are summed per genome, and the sum is divided by the
    genome length.

    Parameters:
        profile (pl.LazyFrame): Profile with ``chrom``, ``A``, ``C``, ``G`` and ``T`` columns.
        name (str): Name given to the breadth column in the result.
        genome_length (pl.LazyFrame): Table with ``genome`` and ``genome_length`` columns.
        stb (pl.LazyFrame): Scaffold-to-genome mapping table with ``scaffold`` and ``genome`` columns.
        min_cov (int): Minimum coverage to consider a position.
    Returns:
        pl.LazyFrame: A LazyFrame with the columns ``genome`` and ``<name>``.
    """
    covered = profile.filter((pl.col("A") + pl.col("C") + pl.col("G") + pl.col("T")) >= min_cov)
    # pl.len() is the supported way to count rows per group; calling
    # pl.count() without arguments is deprecated.
    per_scaffold = covered.group_by("chrom").agg(
        breadth=pl.len()
    ).select(
        pl.col("chrom").alias("scaffold"),
        pl.col("breadth")
    ).join(
        stb,
        on="scaffold",
        how="left"
    )
    with_lengths = per_scaffold.join(genome_length, on="genome", how="left")

    per_genome = with_lengths.group_by("genome").agg(
        genome_length=pl.first("genome_length"),
        breadth=pl.col("breadth").sum())
    per_genome = per_genome.with_columns(
        (pl.col("breadth") / pl.col("genome_length")).alias("breadth")
    )
    return per_genome.select(
        pl.col("genome"),
        pl.col("breadth").alias(name)
    )
300
+
301
def collect_breadth_tables(
    breadth_tables: list[pl.LazyFrame],
) -> pl.LazyFrame:
    """
    Collect multiple genome breadth tables into a single LazyFrame.

    The tables are joined one after another on ``genome`` with a full join, so
    a genome present in any table appears in the result. The join keys are
    coalesced into a single ``genome`` column.

    Parameters:
        breadth_tables (list[pl.LazyFrame]): List of LazyFrames containing genome breadth data.

    Returns:
        pl.LazyFrame: A LazyFrame containing the combined genome breadth data.
        A single-element list returns that element unchanged.

    Raises:
        ValueError: If ``breadth_tables`` is empty.
    """
    if not breadth_tables:
        raise ValueError("No breadth tables provided.")

    # "full" is the current name of the join strategy formerly spelled "outer".
    return reduce(
        lambda left, right: left.join(right, on="genome", how="full", coalesce=True),
        breadth_tables,
    )
317
+
318
def check_samtools():
    """
    Check whether ``samtools`` can be executed from the current PATH.

    Runs ``samtools --version`` with its output discarded.

    Returns:
        bool: True when the command exits successfully. Otherwise a hint is
        printed and False is returned.
    """
    try:
        subprocess.run(
            ["samtools", "--version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: executable missing or not runnable.
        # SubprocessError: non-zero exit status (check=True).
        # KeyboardInterrupt / SystemExit are deliberately not caught.
        print("Samtools is not installed or not found in PATH. Please install samtools to use all of the ZipStrain's functionalities.")
        return False
    return True
330
+
331
def split_lf_to_chunks(lf: pl.LazyFrame, num_chunks: int) -> list[pl.LazyFrame]:
    """
    Split a Polars LazyFrame into smaller chunks.

    The row count is collected once. Every chunk except the last holds
    ``total_rows // num_chunks`` rows; the last chunk additionally receives the
    remainder. When there are fewer rows than chunks, the leading chunks are
    empty and the last one holds all rows.

    Parameters:
        lf (pl.LazyFrame): The input LazyFrame to be split.
        num_chunks (int): The number of chunks to split the LazyFrame into.

    Returns:
        list[pl.LazyFrame]: A list of ``num_chunks`` lazy slices of ``lf``.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError("num_chunks must be a positive integer.")

    # pl.len() counts rows; calling pl.count() without arguments is deprecated.
    total_rows = lf.select(pl.len()).collect().item()
    chunk_size = total_rows // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * chunk_size
        # The final chunk absorbs whatever integer division left over.
        length = chunk_size if i < num_chunks - 1 else total_rows - start
        chunks.append(lf.slice(start, length))
    return chunks
351
+
352
+
353
def estimate_genome_presence(
    profile: pl.LazyFrame,
    bed: pl.LazyFrame,
    stb: pl.LazyFrame,
    ber: float = 0.5,
    cv_threshold: float = 2.5,
    min_cov_constant_poisson: float = 0.5,
) -> pl.LazyFrame:
    """
    Estimate which genomes are present in a sample from coverage information.

    Per genome the function computes:

    * ``coverage``: summed A+T+C+G counts divided by the genome length,
    * ``breadth``: number of positions with non-zero coverage divided by the genome length,
    * ``ber``: ``breadth / (1 - exp(-0.883 * coverage))``,
    * ``cv``: standard deviation divided by mean of the gaps (> 1) between
      consecutive profiled positions on a scaffold.

    A genome whose coverage is at least ``min_cov_constant_poisson`` is called
    present when ``ber`` reaches the ``ber`` threshold. Below that coverage it
    is called present when ``cv`` is at most ``cv_threshold`` and ``ber`` is
    not NaN.

    Args:
        profile (pl.LazyFrame): The profile LazyFrame containing coverage information
            (columns ``chrom``, ``pos``, ``gene``, ``A``, ``C``, ``G``, ``T``).
        bed (pl.LazyFrame): The BED table containing genomic regions
            (columns ``scaffold``, ``start``, ``end``).
        stb (pl.LazyFrame): The scaffold-to-bin mapping LazyFrame
            (columns ``scaffold``, ``genome``).
        ber (float): Breadth over expected breadth ratio threshold for genome presence.
        cv_threshold (float): Coefficient of variation threshold for genome presence.
        min_cov_constant_poisson (float): Minimum coverage threshold to use BER for presence estimation.

    Returns:
        pl.LazyFrame: A LazyFrame with the columns ``genome``, ``cv``,
        ``breadth``, ``coverage``, ``ber`` and ``is_present``.
    """
    profile = profile.with_columns(
        (pl.col("A") + pl.col("T") + pl.col("C") + pl.col("G")).alias("coverage")
    )
    # Resolve the schema once; it supplies both the dtypes and the column order
    # that the synthetic boundary rows must match.
    schema = profile.collect_schema()

    def boundary_rows(pos_expr: pl.Expr) -> pl.LazyFrame:
        """Build zero-coverage placeholder rows at ``pos_expr`` for every BED region."""
        return bed.select(
            pl.col("scaffold").cast(schema["chrom"]).alias("chrom"),
            pos_expr.cast(schema["pos"]).alias("pos"),
            pl.lit("NA").cast(schema["gene"]).alias("gene"),
            pl.lit(0).cast(schema["A"]).alias("A"),
            pl.lit(0).cast(schema["T"]).alias("T"),
            pl.lit(0).cast(schema["C"]).alias("C"),
            pl.lit(0).cast(schema["G"]).alias("G"),
            pl.lit(0).cast(schema["coverage"]).alias("coverage")
        ).select(schema.names())  # same column order as the profile for the vertical concat

    starts_df = boundary_rows(pl.col("start"))
    ends_df = boundary_rows(pl.col("end") - 1)

    # Profile rows come first in the concat; keep="first" is presumably meant to
    # prefer them over a placeholder at the same position -- confirm.
    profile = pl.concat([profile, starts_df, ends_df]).unique(subset=["chrom", "pos"], keep="first").sort(["chrom", "pos"])
    genome_lengths = bed.join(
        stb,
        on="scaffold",
        how="left"
    ).group_by("genome").agg(
        genome_length=(pl.col("end") - pl.col("start")).sum()
    ).select(
        pl.col("genome"),
        pl.col("genome_length")
    )
    profile = profile.with_columns(
        pl.col("pos").shift(1).fill_null(0).over("chrom").alias("prev_pos"),
    ).with_columns(
        (pl.col("pos") - pl.col("prev_pos")).clip(lower_bound=1).alias("gap_size")
    ).join(
        stb,
        left_on="chrom",
        right_on="scaffold",
        how="left"
    ).group_by("genome").agg(
        # Only real gaps (distance > 1) contribute to the coefficient of variation.
        cv=pl.col("gap_size").filter(pl.col("gap_size") > 1).std() / pl.col("gap_size").filter(pl.col("gap_size") > 1).mean(),
        total_coverage=pl.col("coverage").sum(),
        covered_positions=(pl.col("coverage") > 0).sum()
    ).join(
        genome_lengths,
        on="genome",
        how="left"
    ).with_columns(
        (pl.col("covered_positions") / pl.col("genome_length")).alias("breadth"),
        (pl.col("total_coverage") / pl.col("genome_length")).alias("coverage"),
    ).select(
        pl.col("genome"),
        pl.col("cv"),
        pl.col("breadth"),
        pl.col("coverage"),
    ).with_columns(
        (pl.col("breadth") / (1 - (-0.883 * pl.col("coverage")).exp())).alias("ber"),
    ).with_columns(
        pl.when(
            pl.col("coverage") >= min_cov_constant_poisson
        ).then(
            pl.col("ber") >= ber
        ).otherwise(
            (pl.col("cv") <= cv_threshold) & (~pl.col("ber").is_nan())
        ).alias("is_present")
    )

    return profile
451
+