py-gbcms 2.0.0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gbcms/__init__.py +1 -13
- gbcms/cli.py +134 -716
- gbcms/core/kernel.py +126 -0
- gbcms/io/input.py +222 -0
- gbcms/io/output.py +361 -0
- gbcms/models/core.py +133 -0
- gbcms/pipeline.py +212 -0
- gbcms/py.typed +0 -0
- py_gbcms-2.1.1.dist-info/METADATA +216 -0
- py_gbcms-2.1.1.dist-info/RECORD +13 -0
- gbcms/config.py +0 -98
- gbcms/counter.py +0 -1074
- gbcms/models.py +0 -295
- gbcms/numba_counter.py +0 -394
- gbcms/output.py +0 -573
- gbcms/parallel.py +0 -129
- gbcms/processor.py +0 -293
- gbcms/reference.py +0 -86
- gbcms/variant.py +0 -390
- py_gbcms-2.0.0.dist-info/METADATA +0 -506
- py_gbcms-2.0.0.dist-info/RECORD +0 -16
- {py_gbcms-2.0.0.dist-info → py_gbcms-2.1.1.dist-info}/WHEEL +0 -0
- {py_gbcms-2.0.0.dist-info → py_gbcms-2.1.1.dist-info}/entry_points.txt +0 -0
- {py_gbcms-2.0.0.dist-info → py_gbcms-2.1.1.dist-info}/licenses/LICENSE +0 -0
gbcms/models.py
DELETED
|
@@ -1,295 +0,0 @@
|
|
|
1
|
-
"""Pydantic models for type-safe configuration and data structures."""
|
|
2
|
-
|
|
3
|
-
from enum import IntEnum
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class CountType(IntEnum):
    """Slot index for each kind of per-variant count.

    Values are contiguous integers starting at 0, so a member can be used
    directly as an index into a count array of length ``len(CountType)``.
    """

    # Read-level counts.
    DP = 0  # total depth
    RD = 1  # reference-allele depth
    AD = 2  # alternate-allele depth

    # Positive-strand read counts.
    DPP = 3  # positive-strand depth
    RDP = 4  # positive-strand reference depth
    ADP = 5  # positive-strand alternate depth

    # Fragment-level counts.
    DPF = 6  # fragment depth
    RDF = 7  # fragment reference depth
    ADF = 8  # fragment alternate depth
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class BamFileConfig(BaseModel):
    """Configuration for a single BAM file.

    Attributes:
        sample_name: Sample name associated with the BAM file.
        bam_path: Path to the BAM file; must exist.
        bai_path: Path to the BAM index. When omitted, the index is looked
            up next to the BAM file (see ``validate_bai``).
    """

    sample_name: str = Field(..., description="Sample name")
    bam_path: Path = Field(..., description="Path to BAM file")
    bai_path: Path | None = Field(None, description="Path to BAM index")

    @field_validator("bam_path")
    @classmethod
    def validate_bam_exists(cls, v: Path) -> Path:
        """Validate BAM file exists.

        Raises:
            ValueError: If ``v`` does not exist on disk.
        """
        if not v.exists():
            raise ValueError(f"BAM file not found: {v}")
        return v

    @model_validator(mode="after")
    def validate_bai(self) -> "BamFileConfig":
        """Locate the BAM index when none was given explicitly.

        Two conventional locations are tried in order:
        ``<stem>.bai`` (final suffix swapped) and ``<bam_path>.bai``
        (suffix appended). An explicitly supplied ``bai_path`` is kept as-is.

        Raises:
            ValueError: If no index is found at either location.
        """
        if self.bai_path is not None:
            return self

        # Swap only the final suffix. A plain string replace of ".bam" would
        # also rewrite directory names, and for paths without ".bam" it would
        # return the alignment file itself and accept it as its own index.
        candidates = (
            self.bam_path.with_suffix(".bai"),
            Path(f"{self.bam_path}.bai"),
        )
        for candidate in candidates:
            if candidate.exists():
                self.bai_path = candidate
                return self

        raise ValueError(f"BAM index not found for: {self.bam_path}")

    model_config = {"arbitrary_types_allowed": True}
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class VariantFileConfig(BaseModel):
    """Location and format of one variant input file."""

    file_path: Path = Field(..., description="Path to variant file")
    file_format: str = Field(..., description="File format (vcf or maf)")

    @field_validator("file_path")
    @classmethod
    def validate_file_exists(cls, v: Path) -> Path:
        """Return ``v`` unchanged if it exists, otherwise raise ``ValueError``."""
        if v.exists():
            return v
        raise ValueError(f"Variant file not found: {v}")

    @field_validator("file_format")
    @classmethod
    def validate_format(cls, v: str) -> str:
        """Return the lower-cased format, accepting only ``vcf`` or ``maf``."""
        normalized = v.lower()
        if normalized in ["vcf", "maf"]:
            return normalized
        raise ValueError(f"Invalid format: {v}. Must be 'vcf' or 'maf'")

    model_config = {"arbitrary_types_allowed": True}
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
class QualityFilters(BaseModel):
    """Thresholds and on/off switches controlling read filtering."""

    # Numeric thresholds (both must be non-negative).
    mapping_quality_threshold: int = Field(
        20,
        ge=0,
        description="Mapping quality threshold",
    )
    base_quality_threshold: int = Field(
        0,
        ge=0,
        description="Base quality threshold",
    )

    # Boolean filters; only duplicate filtering is enabled by default.
    filter_duplicate: bool = Field(True, description="Filter duplicate reads")
    filter_improper_pair: bool = Field(False, description="Filter improper pairs")
    filter_qc_failed: bool = Field(False, description="Filter QC failed reads")
    filter_indel: bool = Field(False, description="Filter reads with indels")
    filter_non_primary: bool = Field(False, description="Filter non-primary alignments")
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
class OutputOptions(BaseModel):
    """Where results are written and which count columns are emitted."""

    # Destination and format.
    output_file: Path = Field(..., description="Output file path")
    output_maf: bool = Field(False, description="Output in MAF format")

    # Optional count columns; only positive-strand counts are on by default.
    output_positive_count: bool = Field(
        True,
        description="Output positive strand counts",
    )
    output_negative_count: bool = Field(
        False,
        description="Output negative strand counts",
    )
    output_fragment_count: bool = Field(
        False,
        description="Output fragment counts",
    )
    fragment_fractional_weight: bool = Field(
        False,
        description="Use fractional weights for fragments",
    )

    model_config = {"arbitrary_types_allowed": True}
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
class PerformanceConfig(BaseModel):
    """Performance and parallelization configuration.

    Attributes:
        num_threads: Number of threads (>= 1).
        max_block_size: Maximum variants per block (>= 1).
        max_block_dist: Maximum block distance in bp (>= 1).
        use_numba: Whether to use Numba JIT compilation.
        backend: Name of the parallel backend; must be one of the values
            accepted by ``validate_backend``.
    """

    num_threads: int = Field(1, ge=1, description="Number of threads")
    max_block_size: int = Field(10000, ge=1, description="Maximum variants per block")
    max_block_dist: int = Field(100000, ge=1, description="Maximum block distance in bp")
    use_numba: bool = Field(True, description="Use Numba JIT compilation")
    # The validator below targets "backend", so the field has to be declared:
    # pydantic rejects a field_validator bound to an unknown field when the
    # class is built. It is added last, with a valid default, so existing
    # callers and field order are unaffected.
    # NOTE(review): "joblib" is assumed to be the intended default - confirm.
    backend: str = Field("joblib", description="Parallel backend")

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, v: str) -> str:
        """Validate backend choice.

        Returns:
            The backend name, lower-cased.

        Raises:
            ValueError: If ``v`` (case-insensitively) is not a known backend.
        """
        valid_backends = ["joblib", "loky", "threading", "multiprocessing"]
        normalized = v.lower()
        if normalized not in valid_backends:
            raise ValueError(f"Invalid backend: {v}. Must be one of: {', '.join(valid_backends)}")
        return normalized
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
class GetBaseCountsConfig(BaseModel):
    """Complete configuration for GetBaseCounts with Pydantic validation.

    Groups the input files (reference FASTA, BAM files, variant files) with
    the quality, output and performance option models, and checks that the
    inputs are mutually consistent.
    """

    # Input files
    fasta_file: Path = Field(..., description="Reference FASTA file")
    bam_files: list[BamFileConfig] = Field(..., description="BAM files to process")
    variant_files: list[VariantFileConfig] = Field(..., description="Variant files")

    # Options
    quality_filters: QualityFilters = Field(
        default_factory=QualityFilters, description="Quality filtering options"  # type: ignore[arg-type]
    )
    output_options: OutputOptions = Field(..., description="Output options")
    performance: PerformanceConfig = Field(
        default_factory=PerformanceConfig, description="Performance options"  # type: ignore[arg-type]
    )

    # Advanced
    generic_counting: bool = Field(False, description="Use generic counting algorithm")
    max_warning_per_type: int = Field(3, ge=0, description="Maximum warnings per type")

    @field_validator("fasta_file")
    @classmethod
    def validate_fasta_exists(cls, v: Path) -> Path:
        """Validate that the FASTA file and its ``.fai`` index both exist.

        Raises:
            ValueError: If either ``v`` or ``<v>.fai`` is missing.
        """
        if not v.exists():
            raise ValueError(f"FASTA file not found: {v}")

        fai_file = Path(f"{v}.fai")
        if not fai_file.exists():
            raise ValueError(f"FASTA index not found: {fai_file}")

        return v

    @model_validator(mode="after")
    def validate_variant_format_consistency(self) -> "GetBaseCountsConfig":
        """Validate variant file format consistency.

        Raises:
            ValueError: If the variant files mix formats, or if MAF output
                is requested without MAF input.
        """
        formats = {vf.file_format for vf in self.variant_files}
        if len(formats) > 1:
            raise ValueError("All variant files must be the same format (all VCF or all MAF)")

        # Check MAF output compatibility
        if self.output_options.output_maf and "maf" not in formats:
            raise ValueError("--omaf can only be used with MAF input")

        return self

    def get_sample_names(self) -> list[str]:
        """Get list of sample names in order."""
        return [bam.sample_name for bam in self.bam_files]

    def is_maf_input(self) -> bool:
        """Check if input is MAF format (``False`` when there are no variant files)."""
        # An empty variant list passes validation, so guard the [0] lookup
        # instead of raising IndexError.
        return bool(self.variant_files) and self.variant_files[0].file_format == "maf"

    def is_vcf_input(self) -> bool:
        """Check if input is VCF format (``False`` when there are no variant files)."""
        return bool(self.variant_files) and self.variant_files[0].file_format == "vcf"

    model_config = {"arbitrary_types_allowed": True}
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
class VariantCounts(BaseModel):
    """Counts for one sample, stored as an array indexed by ``CountType``."""

    sample_name: str
    counts: np.ndarray = Field(..., description="Count array")

    @field_validator("counts")
    @classmethod
    def validate_counts_shape(cls, v: np.ndarray) -> np.ndarray:
        """Accept only a 1-D array with one slot per ``CountType`` member."""
        expected_shape = (len(CountType),)
        if v.shape == expected_shape:
            return v
        raise ValueError(f"Counts array must have shape ({len(CountType)},)")

    def get_count(self, count_type: CountType) -> float:
        """Return the stored value for ``count_type`` as a Python float."""
        stored = self.counts[count_type]
        return float(stored)

    def set_count(self, count_type: CountType, value: float) -> None:
        """Overwrite the stored value for ``count_type`` in place."""
        self.counts[count_type] = value

    model_config = {"arbitrary_types_allowed": True}
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
class VariantModel(BaseModel):
    """A single variant plus its annotations and per-sample counts.

    If none of the type flags (``snp``, ``dnp``, ``insertion``,
    ``deletion``) is set, the type is inferred from the lengths of ``ref``
    and ``alt`` during validation.
    """

    chrom: str = Field(..., description="Chromosome")
    pos: int = Field(..., ge=0, description="Position (0-indexed)")
    end_pos: int = Field(..., ge=0, description="End position")
    ref: str = Field(..., min_length=1, description="Reference allele")
    alt: str = Field(..., min_length=1, description="Alternate allele")

    # Variant type flags
    snp: bool = Field(False, description="Is SNP")
    dnp: bool = Field(False, description="Is DNP")
    dnp_len: int = Field(0, ge=0, description="DNP length")
    insertion: bool = Field(False, description="Is insertion")
    deletion: bool = Field(False, description="Is deletion")

    # Sample information
    tumor_sample: str = Field("", description="Tumor sample name")
    normal_sample: str = Field("", description="Normal sample name")

    # Annotation
    gene: str = Field("", description="Gene name")
    effect: str = Field("", description="Variant effect")

    # Original MAF coordinates
    maf_pos: int = Field(0, ge=0, description="Original MAF position")
    maf_end_pos: int = Field(0, ge=0, description="Original MAF end position")
    maf_ref: str = Field("", description="Original MAF reference")
    maf_alt: str = Field("", description="Original MAF alternate")
    caller: str = Field("", description="Variant caller")

    # Counts
    sample_counts: dict[str, VariantCounts] = Field(
        default_factory=dict, description="Counts per sample"
    )

    @model_validator(mode="after")
    def validate_positions(self) -> "VariantModel":
        """Reject variants whose end position precedes the start position."""
        if self.end_pos >= self.pos:
            return self
        raise ValueError(f"End position {self.end_pos} < start position {self.pos}")

    @model_validator(mode="after")
    def validate_variant_type(self) -> "VariantModel":
        """Infer the variant type from allele lengths when no flag is set."""
        if any((self.snp, self.dnp, self.insertion, self.deletion)):
            # A caller-supplied classification is left untouched.
            return self

        ref_len = len(self.ref)
        alt_len = len(self.alt)
        if ref_len == alt_len:
            if ref_len == 1:
                self.snp = True
            elif ref_len > 1:
                self.dnp = True
                self.dnp_len = ref_len
        elif alt_len > ref_len:
            self.insertion = True
        else:
            self.deletion = True

        return self

    def get_variant_key(self) -> tuple[str, int, str, str]:
        """Return ``(chrom, pos, ref, alt)`` as the variant's identity key."""
        return (self.chrom, self.pos, self.ref, self.alt)

    def initialize_counts(self, sample_names: list[str]) -> None:
        """Create a zeroed count record for every sample not yet present."""
        for name in sample_names:
            if name in self.sample_counts:
                continue
            zeros = np.zeros(len(CountType), dtype=np.float32)
            self.sample_counts[name] = VariantCounts(sample_name=name, counts=zeros)

    def get_count(self, sample: str, count_type: CountType) -> float:
        """Return the count for ``sample``, or ``0.0`` if the sample is unknown."""
        try:
            record = self.sample_counts[sample]
        except KeyError:
            return 0.0
        return record.get_count(count_type)

    model_config = {"arbitrary_types_allowed": True}
|
gbcms/numba_counter.py
DELETED
|
@@ -1,394 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Numba-optimized counting functions for high performance.
|
|
3
|
-
|
|
4
|
-
This module provides JIT-compiled counting functions that are 50-100x faster
|
|
5
|
-
than the pure Python implementation in `counter.py`. It uses Numba to compile
|
|
6
|
-
Python functions to machine code.
|
|
7
|
-
|
|
8
|
-
**When to use this module:**
|
|
9
|
-
- Large datasets (>10K variants)
|
|
10
|
-
- Production workloads
|
|
11
|
-
- When performance is critical
|
|
12
|
-
- Batch processing
|
|
13
|
-
|
|
14
|
-
**Performance:** 50-100x faster than `counter.py`
|
|
15
|
-
|
|
16
|
-
**Trade-offs:**
|
|
17
|
-
- ✅ Much faster (50-100x)
|
|
18
|
-
- ✅ Parallel processing with prange
|
|
19
|
-
- ✅ Cached compilation
|
|
20
|
-
- ❌ First call is slow (compilation time)
|
|
21
|
-
- ❌ Requires NumPy arrays (not pysam objects)
|
|
22
|
-
- ❌ Harder to debug (compiled code)
|
|
23
|
-
|
|
24
|
-
**Key Functions:**
|
|
25
|
-
- count_snp_base(): Single SNP counting (JIT compiled)
|
|
26
|
-
- count_snp_batch(): Batch SNP counting (parallel)
|
|
27
|
-
- filter_alignments_batch(): Vectorized filtering
|
|
28
|
-
- calculate_fragment_counts(): Fragment-level counting
|
|
29
|
-
|
|
30
|
-
**Usage:**
|
|
31
|
-
from gbcms.numba_counter import count_snp_batch
|
|
32
|
-
import numpy as np
|
|
33
|
-
|
|
34
|
-
# Convert pysam data to NumPy arrays
|
|
35
|
-
bases = np.array([aln.query_sequence for aln in alignments])
|
|
36
|
-
quals = np.array([aln.query_qualities for aln in alignments])
|
|
37
|
-
|
|
38
|
-
# Fast batch counting
|
|
39
|
-
counts = count_snp_batch(bases, quals, positions, ...)
|
|
40
|
-
|
|
41
|
-
**Note:** First call will be slow due to JIT compilation. Subsequent calls
|
|
42
|
-
are very fast. Use `cache=True` to cache compiled functions.
|
|
43
|
-
|
|
44
|
-
**Alternative:** For small datasets or development, see `counter.py` for
|
|
45
|
-
a pure Python implementation that's easier to debug.
|
|
46
|
-
"""
|
|
47
|
-
|
|
48
|
-
import numpy as np
|
|
49
|
-
from numba import jit, prange
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@jit(nopython=True, cache=True)
def count_snp_base(
    query_bases: np.ndarray,
    query_qualities: np.ndarray,
    reference_positions: np.ndarray,
    is_reverse: np.ndarray,
    variant_pos: int,
    ref_base: str,
    alt_base: str,
    base_quality_threshold: int,
) -> tuple[int, int, int, int, int, int]:
    """
    Tally depth counts for one SNP site over parallel per-read arrays.

    Entry ``i`` of each array describes the same observation. An observation
    is counted only if its reference position equals ``variant_pos`` and its
    quality is not below ``base_quality_threshold``.

    Args:
        query_bases: Observed base per entry
        query_qualities: Base quality per entry
        reference_positions: Reference position per entry
        is_reverse: Truthy where the entry is on the reverse strand
        variant_pos: Position of the variant
        ref_base: Reference base
        alt_base: Alternate base
        base_quality_threshold: Minimum quality for an entry to count

    Returns:
        Tuple of (DP, RD, AD, DPP, RDP, ADP)
    """
    depth = 0
    ref_depth = 0
    alt_depth = 0
    fwd_depth = 0
    fwd_ref_depth = 0
    fwd_alt_depth = 0

    for idx in range(len(query_bases)):
        if reference_positions[idx] == variant_pos:
            if query_qualities[idx] < base_quality_threshold:
                continue

            # 1 for a forward-strand entry, 0 otherwise; added to the
            # positive-strand counters alongside the overall ones.
            fwd = 0 if is_reverse[idx] else 1
            observed = query_bases[idx]

            depth += 1
            fwd_depth += fwd

            if observed == ref_base:
                ref_depth += 1
                fwd_ref_depth += fwd
            elif observed == alt_base:
                alt_depth += 1
                fwd_alt_depth += fwd

    return depth, ref_depth, alt_depth, fwd_depth, fwd_ref_depth, fwd_alt_depth
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
@jit(nopython=True, cache=True, parallel=True)
def count_snp_batch(
    query_bases_list: np.ndarray,
    query_qualities_list: np.ndarray,
    reference_positions_list: np.ndarray,
    is_reverse_list: np.ndarray,
    variant_positions: np.ndarray,
    ref_bases: np.ndarray,
    alt_bases: np.ndarray,
    base_quality_threshold: int,
) -> np.ndarray:
    """
    Run ``count_snp_base`` for every variant inside a ``prange`` loop.

    Row ``i`` of each ``*_list`` argument supplies the per-read data for
    variant ``i``.

    Args:
        query_bases_list: Per-variant query base arrays
        query_qualities_list: Per-variant quality arrays
        reference_positions_list: Per-variant position arrays
        is_reverse_list: Per-variant strand arrays
        variant_positions: Position of each variant
        ref_bases: Reference base of each variant
        alt_bases: Alternate base of each variant
        base_quality_threshold: Quality threshold shared by all variants

    Returns:
        int32 array of shape (n_variants, 6); columns are
        (DP, RD, AD, DPP, RDP, ADP)
    """
    total = len(variant_positions)
    table = np.zeros((total, 6), dtype=np.int32)

    for row in prange(total):
        site = count_snp_base(
            query_bases_list[row],
            query_qualities_list[row],
            reference_positions_list[row],
            is_reverse_list[row],
            variant_positions[row],
            ref_bases[row],
            alt_bases[row],
            base_quality_threshold,
        )
        # Copy the six counters into this variant's row.
        table[row, 0] = site[0]
        table[row, 1] = site[1]
        table[row, 2] = site[2]
        table[row, 3] = site[3]
        table[row, 4] = site[4]
        table[row, 5] = site[5]

    return table
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
@jit(nopython=True, cache=True)
def filter_alignment_numba(
    is_duplicate: bool,
    is_proper_pair: bool,
    is_qcfail: bool,
    is_secondary: bool,
    is_supplementary: bool,
    mapping_quality: int,
    has_indel: bool,
    filter_duplicate: bool,
    filter_improper_pair: bool,
    filter_qc_failed: bool,
    filter_non_primary: bool,
    filter_indel: bool,
    mapping_quality_threshold: int,
) -> bool:
    """
    Decide whether one alignment is excluded by the enabled filters.

    Each ``filter_*`` switch enables the matching check; the mapping-quality
    check is always applied.

    Returns:
        True if alignment should be filtered (excluded)
    """
    excluded = False
    if filter_duplicate and is_duplicate:
        excluded = True
    elif filter_improper_pair and not is_proper_pair:
        excluded = True
    elif filter_qc_failed and is_qcfail:
        excluded = True
    elif filter_non_primary and (is_secondary or is_supplementary):
        excluded = True
    elif mapping_quality < mapping_quality_threshold:
        excluded = True
    elif filter_indel and has_indel:
        excluded = True
    return excluded
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
@jit(nopython=True, cache=True, parallel=True)
def filter_alignments_batch(
    is_duplicate: np.ndarray,
    is_proper_pair: np.ndarray,
    is_qcfail: np.ndarray,
    is_secondary: np.ndarray,
    is_supplementary: np.ndarray,
    mapping_quality: np.ndarray,
    has_indel: np.ndarray,
    filter_duplicate: bool,
    filter_improper_pair: bool,
    filter_qc_failed: bool,
    filter_non_primary: bool,
    filter_indel: bool,
    mapping_quality_threshold: int,
) -> np.ndarray:
    """
    Apply ``filter_alignment_numba`` to each alignment inside a ``prange`` loop.

    The array arguments are parallel: entry ``i`` of each describes the same
    alignment. The scalar arguments are shared by all alignments.

    Returns:
        Boolean array where True means keep the alignment
    """
    total = len(is_duplicate)
    keep = np.ones(total, dtype=np.bool_)

    for idx in prange(total):
        rejected = filter_alignment_numba(
            is_duplicate[idx],
            is_proper_pair[idx],
            is_qcfail[idx],
            is_secondary[idx],
            is_supplementary[idx],
            mapping_quality[idx],
            has_indel[idx],
            filter_duplicate,
            filter_improper_pair,
            filter_qc_failed,
            filter_non_primary,
            filter_indel,
            mapping_quality_threshold,
        )
        # The helper reports exclusion; the mask records retention.
        keep[idx] = not rejected

    return keep
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
@jit(nopython=True, cache=True)
def calculate_fragment_counts(
    fragment_ids: np.ndarray,
    end_numbers: np.ndarray,
    has_ref: np.ndarray,
    has_alt: np.ndarray,
    fractional_weight: float,
) -> tuple[int, float, float]:
    """
    Calculate fragment-level counts.

    The array arguments are parallel: entry ``i`` of each describes one read.
    Reads sharing a fragment id are combined into one fragment.

    Args:
        fragment_ids: Array of fragment identifiers
        end_numbers: Array of read end numbers (1 or 2)
        has_ref: Array indicating if the read supports the reference
        has_alt: Array indicating if the read supports the alternate
        fractional_weight: Weight added to both RDF and ADF when a fragment
            supports both alleles (0.5 or 0)

    Returns:
        Tuple of (DPF, RDF, ADF). DPF is the number of distinct fragment
        ids, including fragments that are skipped for RDF/ADF.
    """
    # Get unique fragments
    unique_fragments = np.unique(fragment_ids)
    dpf = len(unique_fragments)

    rdf = 0.0
    adf = 0.0

    for frag_id in unique_fragments:
        # Find all reads for this fragment
        frag_mask = fragment_ids == frag_id
        frag_has_ref = np.any(has_ref[frag_mask])
        frag_has_alt = np.any(has_alt[frag_mask])

        # Detect the same read end appearing more than once. Numba's nopython
        # np.unique does not accept return_counts, so sort the end numbers
        # and compare neighbours instead.
        sorted_ends = np.sort(end_numbers[frag_mask])
        repeated_end = False
        for k in range(1, len(sorted_ends)):
            if sorted_ends[k] == sorted_ends[k - 1]:
                repeated_end = True
                break
        if repeated_end:
            # Skip fragments with overlapping multimapped reads
            continue

        # Count based on ref/alt presence
        if frag_has_ref and frag_has_alt:
            rdf += fractional_weight
            adf += fractional_weight
        elif frag_has_ref:
            rdf += 1.0
        elif frag_has_alt:
            adf += 1.0

    return dpf, rdf, adf
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
@jit(nopython=True, cache=True)
def find_cigar_position(
    cigar_ops: np.ndarray,
    cigar_lens: np.ndarray,
    alignment_start: int,
    target_pos: int,
) -> tuple[int, bool]:
    """
    Find read position corresponding to reference position using CIGAR.

    Operations use the numeric BAM encoding:
    0=M, 1=I, 2=D, 3=N, 4=S, 5=H, 6=P, 7='=', 8=X.

    Args:
        cigar_ops: Array of CIGAR operation codes
        cigar_lens: Array of CIGAR operation lengths
        alignment_start: Reference position where the alignment starts
        target_pos: Target reference position

    Returns:
        Tuple of (read_position, is_covered). ``(-1, False)`` is returned
        when the target falls in a deletion or is not reached by any aligned
        block.
    """
    ref_pos = alignment_start
    read_pos = 0

    for i in range(len(cigar_ops)):
        op = cigar_ops[i]
        length = cigar_lens[i]

        # M, '=' and X all consume both reference and read bases. Treating
        # '=' / X as no-ops would shift every later coordinate for reads
        # written with extended CIGAR strings.
        if op == 0 or op == 7 or op == 8:
            if ref_pos <= target_pos < ref_pos + length:
                return read_pos + (target_pos - ref_pos), True
            ref_pos += length
            read_pos += length
        elif op == 1:  # Insertion (I): consumes read only
            read_pos += length
        elif op == 2:  # Deletion (D): consumes reference only
            if ref_pos <= target_pos < ref_pos + length:
                return -1, False  # Position is in deletion
            ref_pos += length
        elif op == 3:  # Skipped region (N): consumes reference only
            ref_pos += length
        elif op == 4:  # Soft clip (S): consumes read only
            read_pos += length
        # Hard clip (H) and padding (P) don't affect positions

    return -1, False
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
@jit(nopython=True, cache=True)
def compute_base_quality_stats(
    qualities: np.ndarray,
    min_quality: int,
) -> tuple[float, float, int]:
    """
    Summarise an array of base qualities.

    Args:
        qualities: Array of base qualities
        min_quality: Minimum quality threshold

    Returns:
        Tuple of (mean_quality, median_quality, n_passing), where
        ``n_passing`` counts entries that are >= ``min_quality``.
        An empty input yields ``(0.0, 0.0, 0)``.
    """
    if len(qualities) == 0:
        return 0.0, 0.0, 0

    average = np.mean(qualities)
    middle = np.median(qualities)
    passing = np.sum(qualities >= min_quality)

    return float(average), float(middle), int(passing)
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
@jit(nopython=True, cache=True, parallel=True)
def vectorized_quality_filter(
    qualities: np.ndarray,
    threshold: int,
) -> np.ndarray:
    """
    Flag the reads whose every base quality meets the threshold.

    Args:
        qualities: 2D array of qualities (n_reads, read_length)
        threshold: Quality threshold

    Returns:
        Boolean array with one entry per read; True where all of that
        read's qualities are >= ``threshold``
    """
    read_count = qualities.shape[0]
    passing = np.zeros(read_count, dtype=np.bool_)

    for row in prange(read_count):
        read_quals = qualities[row]
        passing[row] = np.all(read_quals >= threshold)

    return passing
|