zipstrain 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zipstrain/profile.py ADDED
@@ -0,0 +1,221 @@
1
+ """zipstrain.profile
2
+ ========================
3
+ This module provides functions and utilities to profile a bamfile.
4
+ By profile we mean generating gene, genome, and nucleotide counts at each position on the reference.
5
+ This is a fundamental step for downstream analysis in zipstrain.
6
+ """
7
+ import pathlib
8
+ import polars as pl
9
+ from typing import Generator
10
+ from zipstrain import utils
11
+ import asyncio
12
+ import os
13
+
14
def parse_gene_loc_table(fasta_file:pathlib.Path) -> Generator[tuple,None,None]:
    """
    Yield gene coordinates parsed from the headers of a Prodigal-style FASTA file.

    Each header line (a line starting with ``>``) is split on whitespace. The
    first field is the gene ID, the third field is the start and the fifth field
    is the end, matching headers shaped like ``>gene_id # start # end # ...``.
    The scaffold name is the gene ID with its last ``_``-separated suffix removed
    (an empty string if the ID contains no ``_``).

    Parameters:
        fasta_file (pathlib.Path): Path to the FASTA file.

    Yields:
        tuple: ``(gene_id, scaffold, start, end)``. ``start`` and ``end`` are the
        raw strings taken from the header; they are not converted to int.
    """
    with open(fasta_file, 'r') as handle:
        headers = (row for row in handle if row.startswith('>'))
        for header in headers:
            fields = header[1:].strip().split()
            gene_id = fields[0]
            # Everything before the final "_" names the scaffold the gene is on.
            scaffold = gene_id.rpartition('_')[0]
            yield gene_id, scaffold, fields[2], fields[4]
37
+
38
+
39
def build_gene_loc_table(fasta_file:pathlib.Path,scaffold:set)->pl.DataFrame:
    """
    Build a per-position gene location table from a Prodigal-style FASTA file.

    Every gene whose scaffold is in ``scaffold`` is expanded into one row per
    position, from its start to its end coordinate (both inclusive).

    Parameters:
        fasta_file (pathlib.Path): Path to the FASTA file.
        scaffold (set): Scaffold names to keep. Genes on other scaffolds are skipped.

    Returns:
        pl.DataFrame: Columns ``scaffold`` (str), ``gene`` (str) and ``pos`` (int),
        one row per covered position. If no gene matches, the result is an empty
        frame with the same columns and dtypes.
    """
    scaffolds = []
    gene_ids = []
    pos = []
    for gene_id, gene_scaffold, start, end in parse_gene_loc_table(fasta_file):
        if gene_scaffold not in scaffold:
            continue
        # The parser yields raw header strings; convert them once per gene.
        start, end = int(start), int(end)
        length = end - start + 1
        scaffolds.extend([gene_scaffold] * length)
        gene_ids.extend([gene_id] * length)
        pos.extend(range(start, end + 1))
    # Use an explicit schema so an empty result still has typed columns,
    # instead of whatever dtype polars would infer from empty lists.
    return pl.DataFrame(
        {"scaffold": scaffolds, "gene": gene_ids, "pos": pos},
        schema={"scaffold": pl.Utf8, "gene": pl.Utf8, "pos": pl.Int64},
    )
62
+
63
def build_gene_range_table(fasta_file:pathlib.Path)->pl.DataFrame:
    """
    Build a gene range table with one ``gene, scaffold, start, end`` row per gene.

    Parameters:
        fasta_file (pathlib.Path): Path to a Prodigal-style FASTA file.

    Returns:
        pl.DataFrame: Columns ``gene``, ``scaffold``, ``start`` and ``end``, one row
        per FASTA header, in file order. ``start`` and ``end`` hold the raw strings
        yielded by ``parse_gene_loc_table``.
    """
    rows = list(parse_gene_loc_table(fasta_file))
    column_names = ["gene", "scaffold", "start", "end"]
    return pl.DataFrame(rows, schema=column_names, orient='row')
76
+
77
+
78
+
79
def add_gene_info_to_mpileup(mpileup_df:pl.LazyFrame, gene_range:pl.DataFrame)->pl.LazyFrame:
    """
    Annotate each pileup position with the gene that covers it.

    Parameters:
        mpileup_df (pl.LazyFrame): Pileup table with ``chrom``, ``pos`` and an
            existing ``gene`` column. Nulls in ``gene`` are first replaced by ``"NA"``.
        gene_range (pl.DataFrame): Rows of ``(gene, scaffold, start, end)``, for
            example as produced by ``build_gene_range_table``. ``start`` and ``end``
            may be ints or numeric strings, and both bounds are inclusive.

    Returns:
        pl.LazyFrame: ``mpileup_df`` (same frame type as passed in) with ``gene``
        filled in. When genes overlap, the gene appearing later in ``gene_range`` wins.
    """
    mpileup_df = mpileup_df.with_columns(pl.col("gene").fill_null("NA"))
    for gene, scaffold, start, end in gene_range.iter_rows():
        # build_gene_range_table yields start/end as strings; int() accepts both forms.
        start, end = int(start), int(end)
        in_gene = (
            (pl.col("chrom") == scaffold)
            & (pl.col("pos") >= start)
            & (pl.col("pos") <= end)
        )
        mpileup_df = mpileup_df.with_columns(
            # pl.lit is required here: a bare string passed to then() is parsed
            # by polars as a column name, not as a literal value.
            pl.when(in_gene)
            .then(pl.lit(gene))
            .otherwise(pl.col("gene"))
            .alias("gene")
        )
    return mpileup_df
89
+
90
+
91
def get_strain_hetrogeneity(profile:pl.LazyFrame,
                            stb:pl.LazyFrame,
                            min_cov=5,
                            freq_threshold=0.8)->pl.LazyFrame:
    """
    Calculate strain heterogeneity for each genome based on nucleotide frequencies.

    Strain heterogeneity is the fraction of sites with coverage of at least
    ``min_cov`` whose most frequent nucleotide has a frequency strictly below
    ``freq_threshold``.

    Args:
        profile (pl.LazyFrame): Per-position profile with a ``chrom`` column and
            nucleotide count columns ``A``, ``T``, ``C`` and ``G``.
        stb (pl.LazyFrame): Scaffold-to-genome mapping with a ``scaffold`` column
            (joined against ``profile.chrom``) and a ``genome`` column (used for grouping).
        min_cov (int): Minimum coverage (A+T+C+G) for a site to be counted.
        freq_threshold (float): A site is heterogeneous when its dominant
            nucleotide frequency is below this value.

    Returns:
        pl.LazyFrame: One row per ``genome`` with columns
        ``total_sites_at_{min_cov}_coverage``, ``heterogeneous_sites`` and
        ``strain_heterogeneity``. The join is a left join, so sites on scaffolds
        missing from ``stb`` are grouped under a null ``genome``.
    """
    total_col = f"total_sites_at_{min_cov}_coverage"

    # Keep only sites with sufficient total coverage.
    profile = profile.with_columns(
        (pl.col("A") + pl.col("T") + pl.col("C") + pl.col("G")).alias("coverage")
    ).filter(pl.col("coverage") >= min_cov)

    # Flag a site (1/0) when its dominant nucleotide frequency is below the threshold.
    profile = profile.with_columns(
        (pl.max_horizontal(["A", "T", "C", "G"]) / pl.col("coverage") < freq_threshold)
        .cast(pl.Int8)
        .alias("heterogeneous_site")
    )

    per_genome = profile.join(stb, left_on="chrom", right_on="scaffold", how="left").group_by("genome").agg([
        pl.len().alias(total_col),
        pl.sum("heterogeneous_site").alias("heterogeneous_sites"),
    ])

    return per_genome.with_columns(
        (pl.col("heterogeneous_sites") / pl.col(total_col)).alias("strain_heterogeneity")
    )
129
+
130
+
131
+
132
async def _profile_chunk_task(
    bed_file:pathlib.Path,
    bam_file:pathlib.Path,
    gene_range_table:pathlib.Path,
    output_dir:pathlib.Path,
    chunk_id:int
)->None:
    """
    Profile one BED chunk by piping ``samtools mpileup`` into ``zipstrain utilities process_mpileup``.

    The pipeline runs through a shell with ``output_dir`` as the working
    directory, so the relative output file ``{bam_file.stem}_{chunk_id}.parquet``
    is written there.

    Parameters:
        bed_file (pathlib.Path): BED file listing the regions of this chunk.
        bam_file (pathlib.Path): Path to the BAM file.
        gene_range_table (pathlib.Path): Path to the gene range table.
        output_dir (pathlib.Path): Working directory that receives the chunk's parquet file.
        chunk_id (int): Index of the chunk, used in the output file name.

    Raises:
        RuntimeError: If the shell pipeline exits with a non-zero status. The
            shell reports the status of the last command in the pipe, so a
            ``samtools`` failure alone may not be detected here.
    """
    import shlex

    mpileup_cmd = ["samtools", "mpileup", "-A", "-l", str(bed_file.absolute()), str(bam_file.absolute())]
    process_cmd = [
        "zipstrain", "utilities", "process_mpileup",
        "--gene-range-table-loc", str(gene_range_table.absolute()),
        "--batch-bed", str(bed_file.absolute()),
        "--output-file", f"{bam_file.stem}_{chunk_id}.parquet",
    ]
    # Quote every argument so paths with spaces or shell metacharacters are
    # passed literally; only the pipe between the two commands is shell syntax.
    shell_cmd = f"{shlex.join(mpileup_cmd)} | {shlex.join(process_cmd)}"
    proc = await asyncio.create_subprocess_shell(
        shell_cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=output_dir,
    )
    _stdout, stderr = await proc.communicate()
    if proc.returncode != 0:
        # RuntimeError subclasses Exception, so existing `except Exception` callers still work.
        raise RuntimeError(f"Command failed with error: {stderr.decode().strip()}")
150
+
151
async def profile_bam_in_chunks(
    bed_file:str,
    bam_file:str,
    gene_range_table:str,
    output_dir:str,
    num_workers:int=4
)->None:
    """
    Profile a BAM file by splitting the BED regions into chunks profiled concurrently.

    The BED file is split into chunks with ``utils.split_lf_to_chunks(..., num_workers)``.
    Each chunk is written to ``{output_dir}/tmp`` and profiled by its own
    ``_profile_chunk_task``, and all tasks run concurrently. The per-chunk parquet
    files are then concatenated into ``{output_dir}/{bam stem}.parquet`` with zstd
    compression, and the ``tmp`` directory is removed on a best-effort basis.

    Parameters:
        bed_file (str | pathlib.Path): A BED file describing all regions to be profiled.
        bam_file (str | pathlib.Path): Path to the BAM file.
        gene_range_table (str | pathlib.Path): Path to the gene range table.
        output_dir (str | pathlib.Path): Directory to save output files (created if missing).
        num_workers (int): Number of chunks requested from the splitter, and hence
            the number of concurrent profiling tasks.

    Raises:
        ValueError: If no chunk produced a parquet file.
        Exception: Propagated from a failing chunk task. In that case the ``tmp``
            directory is left in place.
    """
    import shutil

    output_dir = pathlib.Path(output_dir)
    bam_file = pathlib.Path(bam_file)
    bed_file = pathlib.Path(bed_file)
    gene_range_table = pathlib.Path(gene_range_table)
    tmp_dir = output_dir / "tmp"

    output_dir.mkdir(parents=True, exist_ok=True)
    tmp_dir.mkdir(exist_ok=True)

    # Split the BED regions into chunk files (distinct loop name: don't shadow bed_file).
    bed_lf = pl.scan_csv(bed_file, has_header=False, separator="\t")
    bed_chunk_files = []
    for chunk_id, bed_chunk in enumerate(utils.split_lf_to_chunks(bed_lf, num_workers)):
        chunk_path = tmp_dir / f"bed_chunk_{chunk_id}.bed"
        bed_chunk.sink_csv(chunk_path, include_header=False, separator="\t")
        bed_chunk_files.append(chunk_path)

    await asyncio.gather(*(
        _profile_chunk_task(
            bed_file=chunk_path,
            bam_file=bam_file,
            gene_range_table=gene_range_table,
            output_dir=tmp_dir,
            chunk_id=chunk_id,
        )
        for chunk_id, chunk_path in enumerate(bed_chunk_files)
    ))

    # Chunks that produced no parquet file are skipped.
    expected_outputs = [tmp_dir / f"{bam_file.stem}_{chunk_id}.parquet" for chunk_id in range(len(bed_chunk_files))]
    pfs = [pf for pf in expected_outputs if pf.exists()]
    if not pfs:
        raise ValueError(f"No chunk outputs were produced in {tmp_dir}; nothing to merge.")
    mpileup_df = pl.concat([pl.scan_parquet(pf) for pf in pfs])
    mpileup_df.sink_parquet(output_dir / f"{bam_file.stem}.parquet", compression='zstd')
    # shutil.rmtree instead of an unquoted `rm -r` shell string: a path with spaces
    # could otherwise make rm delete unrelated paths. Kept best-effort like before.
    shutil.rmtree(tmp_dir, ignore_errors=True)
196
+
197
def profile_bam(
    bed_file:str,
    bam_file:str,
    gene_range_table:str,
    output_dir:str,
    num_workers:int=4
)->None:
    """
    Synchronous entry point that runs ``profile_bam_in_chunks`` to completion.

    The coroutine is executed with ``asyncio.run``, which starts a fresh event
    loop. This function therefore cannot be called from code that is already
    running inside an event loop.

    Parameters:
        bed_file (str | pathlib.Path): A BED file describing all regions to be profiled.
        bam_file (str | pathlib.Path): Path to the BAM file.
        gene_range_table (str | pathlib.Path): Path to the gene range table.
        output_dir (str | pathlib.Path): Directory to save output files.
        num_workers (int): Number of concurrent workers to use.
    """
    coroutine = profile_bam_in_chunks(
        bed_file=bed_file,
        bam_file=bam_file,
        gene_range_table=gene_range_table,
        output_dir=output_dir,
        num_workers=num_workers,
    )
    asyncio.run(coroutine)
221
+