smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +39 -7
- smftools/_settings.py +2 -0
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +34 -6
- smftools/cli/hmm_adata.py +239 -33
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +167 -131
- smftools/cli/preprocess_adata.py +180 -53
- smftools/cli/spatial_adata.py +152 -100
- smftools/cli_entry.py +38 -1
- smftools/config/__init__.py +2 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +42 -2
- smftools/config/experiment_config.py +59 -1
- smftools/constants.py +65 -0
- smftools/datasets/__init__.py +2 -0
- smftools/hmm/HMM.py +97 -3
- smftools/hmm/__init__.py +24 -13
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +2 -0
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +5 -2
- smftools/hmm/display_hmm.py +4 -1
- smftools/hmm/hmm_readwrite.py +7 -2
- smftools/hmm/nucleosome_hmm_refinement.py +2 -0
- smftools/informatics/__init__.py +59 -34
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +2 -0
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1093 -176
- smftools/informatics/basecalling.py +2 -0
- smftools/informatics/bed_functions.py +271 -61
- smftools/informatics/binarize_converted_base_identities.py +3 -0
- smftools/informatics/complement_base_list.py +2 -0
- smftools/informatics/converted_BAM_to_adata.py +641 -176
- smftools/informatics/fasta_functions.py +94 -10
- smftools/informatics/h5ad_functions.py +123 -4
- smftools/informatics/modkit_extract_to_adata.py +1019 -431
- smftools/informatics/modkit_functions.py +2 -0
- smftools/informatics/ohe.py +2 -0
- smftools/informatics/pod5_functions.py +3 -2
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/machine_learning/__init__.py +22 -6
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +18 -4
- smftools/machine_learning/data/preprocessing.py +2 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +2 -0
- smftools/machine_learning/evaluation/evaluators.py +14 -9
- smftools/machine_learning/inference/__init__.py +2 -0
- smftools/machine_learning/inference/inference_utils.py +2 -0
- smftools/machine_learning/inference/lightning_inference.py +6 -1
- smftools/machine_learning/inference/sklearn_inference.py +2 -0
- smftools/machine_learning/inference/sliding_window_inference.py +2 -0
- smftools/machine_learning/models/__init__.py +2 -0
- smftools/machine_learning/models/base.py +7 -2
- smftools/machine_learning/models/cnn.py +7 -2
- smftools/machine_learning/models/lightning_base.py +16 -11
- smftools/machine_learning/models/mlp.py +5 -1
- smftools/machine_learning/models/positional.py +7 -2
- smftools/machine_learning/models/rnn.py +5 -1
- smftools/machine_learning/models/sklearn_models.py +14 -9
- smftools/machine_learning/models/transformer.py +7 -2
- smftools/machine_learning/models/wrappers.py +6 -2
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +13 -3
- smftools/machine_learning/training/train_sklearn_model.py +2 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +5 -1
- smftools/machine_learning/utils/grl.py +5 -1
- smftools/metadata.py +1 -1
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +41 -31
- smftools/plotting/autocorrelation_plotting.py +9 -5
- smftools/plotting/classifiers.py +16 -4
- smftools/plotting/general_plotting.py +2415 -629
- smftools/plotting/hmm_plotting.py +97 -9
- smftools/plotting/position_stats.py +15 -7
- smftools/plotting/qc_plotting.py +6 -1
- smftools/preprocessing/__init__.py +36 -37
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/archived/calculate_complexity.py +2 -0
- smftools/preprocessing/archived/mark_duplicates.py +2 -0
- smftools/preprocessing/archived/preprocessing.py +2 -0
- smftools/preprocessing/archived/remove_duplicates.py +2 -0
- smftools/preprocessing/binary_layers_to_ohe.py +2 -1
- smftools/preprocessing/calculate_complexity_II.py +4 -1
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_pairwise_differences.py +2 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
- smftools/preprocessing/calculate_position_Youden.py +9 -2
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
- smftools/preprocessing/flag_duplicate_reads.py +42 -54
- smftools/preprocessing/make_dirs.py +2 -1
- smftools/preprocessing/min_non_diagonal.py +2 -0
- smftools/preprocessing/recipes.py +2 -0
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +30 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +2 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +2 -0
- smftools/tools/archived/subset_adata_v2.py +2 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +93 -8
- smftools/tools/cluster_adata_on_methylation.py +7 -1
- smftools/tools/position_stats.py +17 -27
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
- smftools-0.3.1.dist-info/RECORD +189 -0
- smftools-0.2.5.dist-info/RECORD +0 -181
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,51 @@
+from __future__ import annotations
+
 import concurrent.futures
 import gc
 import re
 import shutil
+from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Iterable, Mapping, Optional, Union

 import numpy as np
 import pandas as pd
 from tqdm import tqdm

+from smftools.constants import (
+    BARCODE,
+    BASE_QUALITY_SCORES,
+    DATASET,
+    DEMUX_TYPE,
+    H5_DIR,
+    MISMATCH_INTEGER_ENCODING,
+    MODKIT_EXTRACT_CALL_CODE_CANONICAL,
+    MODKIT_EXTRACT_CALL_CODE_MODIFIED,
+    MODKIT_EXTRACT_MODIFIED_BASE_A,
+    MODKIT_EXTRACT_MODIFIED_BASE_C,
+    MODKIT_EXTRACT_REF_STRAND_MINUS,
+    MODKIT_EXTRACT_REF_STRAND_PLUS,
+    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    MODKIT_EXTRACT_SEQUENCE_BASES,
+    MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+    MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+    MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE,
+    MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB,
+    MODKIT_EXTRACT_TSV_COLUMN_CHROM,
+    MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE,
+    MODKIT_EXTRACT_TSV_COLUMN_READ_ID,
+    MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION,
+    MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND,
+    READ_MAPPING_DIRECTION,
+    READ_SPAN_MASK,
+    REFERENCE,
+    REFERENCE_DATASET_STRAND,
+    REFERENCE_STRAND,
+    SAMPLE,
+    SEQUENCE_INTEGER_DECODING,
+    SEQUENCE_INTEGER_ENCODING,
+    STRAND,
+)
 from smftools.logging_utils import get_logger

 from .bam_functions import count_aligned_reads
@@ -16,9 +53,81 @@ from .bam_functions import count_aligned_reads
 logger = get_logger(__name__)


-
-
-
+@dataclass
+class ModkitBatchDictionaries:
+    """Container for per-batch modification dictionaries.
+
+    Attributes:
+        dict_total: Raw TSV DataFrames keyed by record and sample index.
+        dict_a: Adenine modification DataFrames.
+        dict_a_bottom: Adenine minus-strand DataFrames.
+        dict_a_top: Adenine plus-strand DataFrames.
+        dict_c: Cytosine modification DataFrames.
+        dict_c_bottom: Cytosine minus-strand DataFrames.
+        dict_c_top: Cytosine plus-strand DataFrames.
+        dict_combined_bottom: Combined minus-strand methylation arrays.
+        dict_combined_top: Combined plus-strand methylation arrays.
+    """
+
+    dict_total: dict = field(default_factory=dict)
+    dict_a: dict = field(default_factory=dict)
+    dict_a_bottom: dict = field(default_factory=dict)
+    dict_a_top: dict = field(default_factory=dict)
+    dict_c: dict = field(default_factory=dict)
+    dict_c_bottom: dict = field(default_factory=dict)
+    dict_c_top: dict = field(default_factory=dict)
+    dict_combined_bottom: dict = field(default_factory=dict)
+    dict_combined_top: dict = field(default_factory=dict)
+
+    @property
+    def sample_types(self) -> list[str]:
+        """Return ordered labels for the dictionary list."""
+        return [
+            "total",
+            "m6A",
+            "m6A_bottom_strand",
+            "m6A_top_strand",
+            "5mC",
+            "5mC_bottom_strand",
+            "5mC_top_strand",
+            "combined_bottom_strand",
+            "combined_top_strand",
+        ]
+
+    def as_list(self) -> list[dict]:
+        """Return the dictionaries in the expected list ordering."""
+        return [
+            self.dict_total,
+            self.dict_a,
+            self.dict_a_bottom,
+            self.dict_a_top,
+            self.dict_c,
+            self.dict_c_bottom,
+            self.dict_c_top,
+            self.dict_combined_bottom,
+            self.dict_combined_top,
+        ]
+
+
+def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
+    """Identify reference records that exceed a mapping threshold in one BAM.
+
+    Args:
+        bam (Path | str): BAM file to inspect.
+        mapping_threshold (float): Minimum fraction of mapped reads required to keep a record.
+        samtools_backend (str | None): Samtools backend selection.
+
+    Returns:
+        set[str]: Record names that pass the mapping threshold.
+
+    Processing Steps:
+        1. Count aligned/unaligned reads per record.
+        2. Compute percent aligned and per-record mapping percentages.
+        3. Return records whose mapping fraction meets the threshold.
+    """
+    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
+        bam, samtools_backend
+    )

     total_reads = aligned_reads_count + unaligned_reads_count
     percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
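For orientation between hunks: later code in this file refers to these dictionaries by integer index (the skip sets `{0, 1, 4}`, `{2, 3}`, `{5, 6}`, and `{7, 8}` that appear in subsequent hunks), and `as_list()` lines up position-for-position with `sample_types`. A minimal sketch of that pairing, assuming only the dataclass as defined in the hunk above:

```python
# Illustrative sketch: enumerate the dictionaries alongside their labels, the same
# index convention the skip sets in later hunks rely on (e.g. 2/3 = m6A strands).
batch_dicts = ModkitBatchDictionaries()
for index, (label, per_record) in enumerate(
    zip(batch_dicts.sample_types, batch_dicts.as_list())
):
    print(index, label, len(per_record))
```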
@@ -35,13 +144,30 @@ def filter_bam_records(bam, mapping_threshold):
     return set(records)


-def parallel_filter_bams(bam_path_list, mapping_threshold):
-    """
+def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
+    """Aggregate mapping-threshold records across BAM files in parallel.
+
+    Args:
+        bam_path_list (list[Path | str]): BAM files to scan.
+        mapping_threshold (float): Minimum fraction of mapped reads required to keep a record.
+        samtools_backend (str | None): Samtools backend selection.
+
+    Returns:
+        set[str]: Union of all record names passing the threshold in any BAM.
+
+    Processing Steps:
+        1. Spawn workers to compute passing records per BAM.
+        2. Merge all passing records into a single set.
+        3. Log the final record set.
+    """
     records_to_analyze = set()

     with concurrent.futures.ProcessPoolExecutor() as executor:
         results = executor.map(
-            filter_bam_records,
+            filter_bam_records,
+            bam_path_list,
+            [mapping_threshold] * len(bam_path_list),
+            [samtools_backend] * len(bam_path_list),
         )

     # Aggregate results
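The rewritten `parallel_filter_bams` broadcasts its scalar arguments by repeating them once per BAM, because `ProcessPoolExecutor.map` zips its input iterables positionally. A self-contained sketch of that pattern with a hypothetical worker (not smftools code):

```python
import concurrent.futures

def passes_threshold(path: str, threshold: float, backend: str | None) -> set[str]:
    """Hypothetical stand-in for filter_bam_records."""
    return {path} if threshold <= 1.0 else set()

if __name__ == "__main__":
    paths = ["s1.bam", "s2.bam", "s3.bam"]
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # map() zips the iterables, so each worker call receives (path, threshold, backend).
        results = executor.map(
            passes_threshold,
            paths,
            [0.1] * len(paths),
            ["auto"] * len(paths),
        )
    records = set()
    for result in results:
        records |= result
    print(records)
```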
@@ -53,8 +179,21 @@ def parallel_filter_bams(bam_path_list, mapping_threshold):


 def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
-    """
-
+    """Load and filter a modkit TSV file for relevant records and positions.
+
+    Args:
+        tsv (Path | str): TSV file produced by modkit extract.
+        records_to_analyze (Iterable[str]): Record names to keep.
+        reference_dict (dict[str, tuple[int, str]]): Mapping of record to (length, sequence).
+        sample_index (int): Sample index to attach to the filtered results.
+
+    Returns:
+        dict[str, dict[int, pd.DataFrame]]: Filtered data keyed by record and sample index.
+
+    Processing Steps:
+        1. Read the TSV into a DataFrame.
+        2. Filter rows for each record to valid reference positions.
+        3. Emit per-record DataFrames keyed by the provided sample index.
     """
     temp_df = pd.read_csv(tsv, sep="\t", header=0)
     filtered_records = {}
@@ -65,9 +204,9 @@ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):

         ref_length = reference_dict[record][0]
         filtered_df = temp_df[
-            (temp_df[
-            & (temp_df[
-            & (temp_df[
+            (temp_df[MODKIT_EXTRACT_TSV_COLUMN_CHROM] == record)
+            & (temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION] >= 0)
+            & (temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION] < ref_length)
         ]

         if not filtered_df.empty:
@@ -77,19 +216,23 @@ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):


 def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
-    """
-    Loads and filters TSV files in parallel.
+    """Load and filter a batch of TSVs in parallel.

-
-        tsv_batch (list):
-        records_to_analyze (
-        reference_dict (dict):
-        batch (int):
-        batch_size (int):
-        threads (int):
+    Args:
+        tsv_batch (list[Path | str]): TSV file paths for the batch.
+        records_to_analyze (Iterable[str]): Record names to keep.
+        reference_dict (dict[str, tuple[int, str]]): Mapping of record to (length, sequence).
+        batch (int): Batch number for progress logging.
+        batch_size (int): Number of TSVs in the batch.
+        threads (int): Parallel worker count.

     Returns:
-        dict:
+        dict[str, dict[int, pd.DataFrame]]: Per-record DataFrames keyed by sample index.
+
+    Processing Steps:
+        1. Submit each TSV to a worker via `process_tsv`.
+        2. Merge per-record outputs into a single dictionary.
+        3. Return the aggregated per-record dictionary for the batch.
     """
     dict_total = {record: {} for record in records_to_analyze}

@@ -114,15 +257,19 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, bat


 def update_dict_to_skip(dict_to_skip, detected_modifications):
-    """
-    Updates the dict_to_skip set based on the detected modifications.
+    """Update dictionary skip indices based on modifications in the batch.

-
-        dict_to_skip (set):
-        detected_modifications (
+    Args:
+        dict_to_skip (set[int]): Initial set of dictionary indices to skip.
+        detected_modifications (Iterable[str]): Modification labels present (e.g., ["6mA", "5mC"]).

     Returns:
-        set:
+        set[int]: Updated skip set after considering present modifications.
+
+    Processing Steps:
+        1. Define indices for A- and C-stranded dictionaries.
+        2. Remove indices for modifications that are present.
+        3. Return the updated skip set.
     """
     # Define which indices correspond to modification-specific or strand-specific dictionaries
     A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
@@ -143,31 +290,49 @@ def update_dict_to_skip(dict_to_skip, detected_modifications):


 def process_modifications_for_sample(args):
-    """
-    Processes a single (record, sample) pair to extract modification-specific data.
+    """Extract modification-specific subsets for one record/sample pair.

-
-        args: (record, sample_index, sample_df, mods, max_reference_length)
+    Args:
+        args (tuple): (record, sample_index, sample_df, mods, max_reference_length).

     Returns:
-
-
-
+        tuple[str, int, dict[str, pd.DataFrame | list]]:
+            Record, sample index, and a dict of modification-specific DataFrames
+            (with optional combined placeholders).
+
+    Processing Steps:
+        1. Filter by modified base (A/C) when requested.
+        2. Split filtered rows by strand where needed.
+        3. Add empty combined placeholders when both modifications are present.
     """
     record, sample_index, sample_df, mods, max_reference_length = args
     result = {}
     if "6mA" in mods:
-        m6a_df = sample_df[
+        m6a_df = sample_df[
+            sample_df[MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE]
+            == MODKIT_EXTRACT_MODIFIED_BASE_A
+        ]
         result["m6A"] = m6a_df
-        result["m6A_minus"] = m6a_df[
-
+        result["m6A_minus"] = m6a_df[
+            m6a_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_MINUS
+        ]
+        result["m6A_plus"] = m6a_df[
+            m6a_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_PLUS
+        ]
         m6a_df = None
         gc.collect()
     if "5mC" in mods:
-        m5c_df = sample_df[
+        m5c_df = sample_df[
+            sample_df[MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE]
+            == MODKIT_EXTRACT_MODIFIED_BASE_C
+        ]
         result["5mC"] = m5c_df
-        result["5mC_minus"] = m5c_df[
-
+        result["5mC_minus"] = m5c_df[
+            m5c_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_MINUS
+        ]
+        result["5mC_plus"] = m5c_df[
+            m5c_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_PLUS
+        ]
         m5c_df = None
         gc.collect()
     if "6mA" in mods and "5mC" in mods:
@@ -177,11 +342,22 @@ def process_modifications_for_sample(args):


 def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
-    """
-
+    """Parallelize modification extraction across records and samples.
+
+    Args:
+        dict_total (dict[str, dict[int, pd.DataFrame]]): Raw TSV DataFrames per record/sample.
+        mods (list[str]): Modification labels to process.
+        max_reference_length (int): Maximum reference length in the dataset.
+        threads (int): Parallel worker count.

     Returns:
-
+        dict[str, dict[int, dict[str, pd.DataFrame | list]]]: Processed results keyed by
+            record and sample index.
+
+    Processing Steps:
+        1. Build a task list of (record, sample) pairs.
+        2. Submit tasks to a process pool.
+        3. Collect and store results in a nested dictionary.
     """
     tasks = []
     for record, sample_dict in dict_total.items():
@@ -201,11 +377,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threa


 def merge_modification_results(processed_results, mods):
-    """
-
+    """Merge per-sample modification outputs into global dictionaries.
+
+    Args:
+        processed_results (dict[str, dict[int, dict]]): Output of parallel modification extraction.
+        mods (list[str]): Modification labels to include.

     Returns:
-
+        tuple[dict, dict, dict, dict, dict, dict, dict, dict]:
+            Global dictionaries for each modification/strand combination.
+
+    Processing Steps:
+        1. Initialize empty output dictionaries per modification category.
+        2. Populate each dictionary using the processed sample results.
+        3. Return the ordered tuple for downstream processing.
     """
     m6A_dict = {}
     m6A_minus = {}
@@ -247,18 +432,18 @@ def merge_modification_results(processed_results, mods):


 def process_stranded_methylation(args):
-    """
-    Processes a single (dict_index, record, sample) task.
+    """Convert modification DataFrames into per-read methylation arrays.

-
-
-    NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
-
-    Parameters:
-        args: (dict_index, record, sample, dict_list, max_reference_length)
+    Args:
+        args (tuple): (dict_index, record, sample, dict_list, max_reference_length).

     Returns:
-
+        tuple[int, str, int, dict[str, np.ndarray]]: Updated dictionary entries for the task.
+
+    Processing Steps:
+        1. For combined dictionaries (indices 7/8), merge A- and C-strand arrays.
+        2. For other dictionaries, compute methylation probabilities per read/position.
+        3. Return per-read arrays keyed by read name.
     """
     dict_index, record, sample, dict_list, max_reference_length = args
     processed_data = {}
@@ -322,13 +507,15 @@ def process_stranded_methylation(args):
         temp_df = dict_list[dict_index][record][sample]
         processed_data = {}
         # Extract columns and convert probabilities to float (coercing errors)
-        read_ids = temp_df[
-        positions = temp_df[
-        call_codes = temp_df[
-        probabilities = pd.to_numeric(
+        read_ids = temp_df[MODKIT_EXTRACT_TSV_COLUMN_READ_ID].values
+        positions = temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION].values
+        call_codes = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE].values
+        probabilities = pd.to_numeric(
+            temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB].values, errors="coerce"
+        )

-        modified_codes =
-        canonical_codes =
+        modified_codes = MODKIT_EXTRACT_CALL_CODE_MODIFIED
+        canonical_codes = MODKIT_EXTRACT_CALL_CODE_CANONICAL

         # Compute methylation probabilities (vectorized)
         methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
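The hunk above only shows the setup of the vectorized computation (column extraction, call-code constants, and a NaN-filled array); the assignment rule itself continues past this hunk. Purely as an assumed illustration of the boolean-mask pattern such code typically uses, with made-up call codes and an assumed modified/canonical rule rather than the package's verbatim logic:

```python
import numpy as np

# Hypothetical inputs standing in for the modkit extract columns.
call_codes = np.array(["mod", "can", "mod", "can"])
probabilities = np.array([0.91, 0.88, 0.67, 0.95])

methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
modified_mask = call_codes == "mod"    # assumed "modified" call code
canonical_mask = call_codes == "can"   # assumed "canonical" call code
# Assumed rule: keep the call probability for modified calls and invert it for
# canonical calls; positions with neither code stay NaN.
methylation_prob[modified_mask] = probabilities[modified_mask]
methylation_prob[canonical_mask] = 1.0 - probabilities[canonical_mask]
print(methylation_prob)  # roughly [0.91 0.12 0.67 0.05]
```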
@@ -356,11 +543,21 @@ def process_stranded_methylation(args):


 def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
-    """
-
+    """Parallelize per-read methylation extraction over all dictionary entries.
+
+    Args:
+        dict_list (list[dict]): List of modification/strand dictionaries.
+        dict_to_skip (set[int]): Dictionary indices to exclude from processing.
+        max_reference_length (int): Maximum reference length for array sizing.
+        threads (int): Parallel worker count.

     Returns:
-        Updated
+        list[dict]: Updated dictionary list with per-read methylation arrays.
+
+    Processing Steps:
+        1. Build tasks for every (dict_index, record, sample) to process.
+        2. Execute tasks in a process pool.
+        3. Replace DataFrames with per-read arrays in-place.
     """
     tasks = []
     for dict_index, current_dict in enumerate(dict_list):
@@ -386,21 +583,20 @@ def delete_intermediate_h5ads_and_tmpdir(
     dry_run: bool = False,
     verbose: bool = True,
 ):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-        Print progress / warnings.
+    """Delete intermediate .h5ad files and optionally a temporary directory.
+
+    Args:
+        h5_dir (str | Path | Iterable[str] | None): Directory or iterable of h5ad paths.
+        tmp_dir (str | Path | None): Temporary directory to remove recursively.
+        dry_run (bool): If True, log deletions without performing them.
+        verbose (bool): If True, log progress and warnings.
+
+    Returns:
+        None: This function performs deletions in-place.
+
+    Processing Steps:
+        1. Iterate over .h5ad file candidates and delete them (if not dry-run).
+        2. Remove the temporary directory tree if requested.
     """

     # Helper: remove a single file path (Path-like or string)
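Based only on the signature tail and docstring in the hunk above, a hedged usage sketch of the cleanup helper; parameter names follow the Args section, and the paths are hypothetical:

```python
from pathlib import Path

# dry_run=True should only log what would be deleted, per the docstring.
delete_intermediate_h5ads_and_tmpdir(
    h5_dir=Path("results/h5ads"),
    tmp_dir=Path("results/tmp"),
    dry_run=True,
    verbose=True,
)
```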
@@ -471,6 +667,455 @@ def delete_intermediate_h5ads_and_tmpdir(
             logger.warning(f"[error] failed to remove tmp dir {td}: {e}")


+def _collect_input_paths(mod_tsv_dir: Path, bam_dir: Path) -> tuple[list[Path], list[Path]]:
+    """Collect sorted TSV and BAM paths for processing.
+
+    Args:
+        mod_tsv_dir (Path): Directory containing modkit extract TSVs.
+        bam_dir (Path): Directory containing aligned BAM files.
+
+    Returns:
+        tuple[list[Path], list[Path]]: Sorted TSV paths and BAM paths.
+
+    Processing Steps:
+        1. Filter TSVs for extract outputs and exclude unclassified entries.
+        2. Filter BAMs for aligned files and exclude indexes/unclassified entries.
+        3. Sort both lists for deterministic processing.
+    """
+    tsvs = sorted(
+        p
+        for p in mod_tsv_dir.iterdir()
+        if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
+    )
+    bams = sorted(
+        p
+        for p in bam_dir.iterdir()
+        if p.is_file()
+        and p.suffix == ".bam"
+        and "unclassified" not in p.name
+        and ".bai" not in p.name
+    )
+    return tsvs, bams
+
+
+def _build_sample_maps(bam_path_list: list[Path]) -> tuple[dict[int, str], dict[int, str]]:
+    """Build sample name and barcode maps from BAM filenames.
+
+    Args:
+        bam_path_list (list[Path]): Paths to BAM files in sample order.
+
+    Returns:
+        tuple[dict[int, str], dict[int, str]]: Maps of sample index to sample name and barcode.
+
+    Processing Steps:
+        1. Parse the BAM stem for barcode suffixes.
+        2. Build a standardized sample name with barcode suffix.
+        3. Store mappings for downstream metadata annotations.
+    """
+    sample_name_map: dict[int, str] = {}
+    barcode_map: dict[int, str] = {}
+
+    for idx, bam_path in enumerate(bam_path_list):
+        stem = bam_path.stem
+        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
+        if m:
+            sample_name = m.group(1) or stem
+            barcode = m.group(2)
+        else:
+            sample_name = stem
+            barcode = stem
+
+        sample_name = f"{sample_name}_{barcode}"
+        barcode_id = int(barcode.split("barcode")[1])
+
+        sample_name_map[idx] = sample_name
+        barcode_map[idx] = str(barcode_id)
+
+    return sample_name_map, barcode_map
+
+
+def _encode_sequence_array(
+    read_sequence: np.ndarray,
+    valid_length: int,
+    base_to_int: Mapping[str, int],
+    padding_value: int,
+) -> np.ndarray:
+    """Convert a base-identity array into integer encoding with padding.
+
+    Args:
+        read_sequence (np.ndarray): Array of base calls (dtype "<U1").
+        valid_length (int): Number of valid reference positions for this record.
+        base_to_int (Mapping[str, int]): Base-to-integer mapping for A/C/G/T/N/PAD.
+        padding_value (int): Integer value to use for padding.
+
+    Returns:
+        np.ndarray: Integer-encoded sequence with padding applied.
+
+    Processing Steps:
+        1. Initialize an integer array filled with the N value.
+        2. Overwrite values for known bases (A/C/G/T/N).
+        3. Replace positions beyond valid_length with padding.
+    """
+    read_sequence = np.asarray(read_sequence, dtype="<U1")
+    encoded = np.full(read_sequence.shape, base_to_int["N"], dtype=np.int16)
+    for base in MODKIT_EXTRACT_SEQUENCE_BASES:
+        encoded[read_sequence == base] = base_to_int[base]
+    if valid_length < encoded.size:
+        encoded[valid_length:] = padding_value
+    return encoded
+
+
+def _write_sequence_batches(
+    base_identities: Mapping[str, np.ndarray],
+    tmp_dir: Path,
+    record: str,
+    prefix: str,
+    base_to_int: Mapping[str, int],
+    valid_length: int,
+    batch_size: int,
+) -> list[str]:
+    """Encode base identities into integer arrays and write batched H5AD files.
+
+    Args:
+        base_identities (Mapping[str, np.ndarray]): Read name to base identity arrays.
+        tmp_dir (Path): Directory for temporary H5AD files.
+        record (str): Reference record identifier.
+        prefix (str): Prefix used to name batch files.
+        base_to_int (Mapping[str, int]): Base-to-integer mapping.
+        valid_length (int): Valid reference length for padding determination.
+        batch_size (int): Number of reads per H5AD batch file.
+
+    Returns:
+        list[str]: Paths to written H5AD batch files.
+
+    Processing Steps:
+        1. Encode each read sequence to integer values.
+        2. Accumulate encoded reads into batches.
+        3. Persist each batch as an H5AD with the dictionary stored in `.uns`.
+    """
+    import anndata as ad
+
+    padding_value = base_to_int[MODKIT_EXTRACT_SEQUENCE_PADDING_BASE]
+    batch_files: list[str] = []
+    batch: dict[str, np.ndarray] = {}
+    batch_number = 0
+
+    for read_name, sequence in base_identities.items():
+        if sequence is None:
+            continue
+        batch[read_name] = _encode_sequence_array(
+            sequence, valid_length, base_to_int, padding_value
+        )
+        if len(batch) >= batch_size:
+            save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+            batch_files.append(str(save_name))
+            batch = {}
+            batch_number += 1
+
+    if batch:
+        save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+        batch_files.append(str(save_name))
+
+    return batch_files
+
+
+def _write_integer_batches(
+    sequences: Mapping[str, np.ndarray],
+    tmp_dir: Path,
+    record: str,
+    prefix: str,
+    batch_size: int,
+) -> list[str]:
+    """Write integer-encoded sequences into batched H5AD files.
+
+    Args:
+        sequences (Mapping[str, np.ndarray]): Read name to integer arrays.
+        tmp_dir (Path): Directory for temporary H5AD files.
+        record (str): Reference record identifier.
+        prefix (str): Prefix used to name batch files.
+        batch_size (int): Number of reads per H5AD batch file.
+
+    Returns:
+        list[str]: Paths to written H5AD batch files.
+
+    Processing Steps:
+        1. Accumulate integer arrays into batches.
+        2. Persist each batch as an H5AD with the dictionary stored in `.uns`.
+    """
+    import anndata as ad
+
+    batch_files: list[str] = []
+    batch: dict[str, np.ndarray] = {}
+    batch_number = 0
+
+    for read_name, sequence in sequences.items():
+        if sequence is None:
+            continue
+        batch[read_name] = np.asarray(sequence, dtype=np.int16)
+        if len(batch) >= batch_size:
+            save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+            batch_files.append(str(save_name))
+            batch = {}
+            batch_number += 1
+
+    if batch:
+        save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+        batch_files.append(str(save_name))
+
+    return batch_files
+
+
+def _load_sequence_batches(
+    batch_files: list[Path | str],
+) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
+    """Load integer-encoded sequence batches from H5AD files.
+
+    Args:
+        batch_files (list[Path | str]): H5AD paths containing encoded sequences in `.uns`.
+
+    Returns:
+        tuple[dict[str, np.ndarray], set[str], set[str]]:
+            Read-to-sequence mapping and sets of forward/reverse mapped reads.
+
+    Processing Steps:
+        1. Read each H5AD file.
+        2. Merge `.uns` dictionaries into a single mapping.
+        3. Track forward/reverse read IDs based on the filename marker.
+    """
+    import anndata as ad
+
+    sequences: dict[str, np.ndarray] = {}
+    fwd_reads: set[str] = set()
+    rev_reads: set[str] = set()
+    for batch_file in batch_files:
+        batch_path = Path(batch_file)
+        batch_sequences = ad.read_h5ad(batch_path).uns
+        sequences.update(batch_sequences)
+        if "_fwd_" in batch_path.name:
+            fwd_reads.update(batch_sequences.keys())
+        elif "_rev_" in batch_path.name:
+            rev_reads.update(batch_sequences.keys())
+    return sequences, fwd_reads, rev_reads
+
+
+def _load_integer_batches(batch_files: list[Path | str]) -> dict[str, np.ndarray]:
+    """Load integer arrays from batched H5AD files.
+
+    Args:
+        batch_files (list[Path | str]): H5AD paths containing arrays in `.uns`.
+
+    Returns:
+        dict[str, np.ndarray]: Read-to-array mapping.
+
+    Processing Steps:
+        1. Read each H5AD file.
+        2. Merge `.uns` dictionaries into a single mapping.
+    """
+    import anndata as ad
+
+    sequences: dict[str, np.ndarray] = {}
+    for batch_file in batch_files:
+        batch_path = Path(batch_file)
+        sequences.update(ad.read_h5ad(batch_path).uns)
+    return sequences
+
+
+def _normalize_sequence_batch_files(batch_files: object) -> list[Path]:
+    """Normalize cached batch file entries into a list of Paths.
+
+    Args:
+        batch_files (object): Cached batch file entry from AnnData `.uns`.
+
+    Returns:
+        list[Path]: Paths to batch files, filtered to non-empty values.
+
+    Processing Steps:
+        1. Convert numpy arrays and scalars into Python lists.
+        2. Filter out empty/placeholder values.
+        3. Cast remaining entries to Path objects.
+    """
+    if batch_files is None:
+        return []
+    if isinstance(batch_files, np.ndarray):
+        batch_files = batch_files.tolist()
+    if isinstance(batch_files, (str, Path)):
+        batch_files = [batch_files]
+    if not isinstance(batch_files, list):
+        batch_files = list(batch_files)
+    normalized: list[Path] = []
+    for entry in batch_files:
+        if entry is None:
+            continue
+        entry_str = str(entry).strip()
+        if not entry_str or entry_str == ".":
+            continue
+        normalized.append(Path(entry_str))
+    return normalized
+
+
+def _build_modification_dicts(
+    dict_total: dict,
+    mods: list[str],
+) -> tuple[ModkitBatchDictionaries, set[int]]:
+    """Build modification/strand dictionaries from the raw TSV batch dictionary.
+
+    Args:
+        dict_total (dict): Raw TSV DataFrames keyed by record and sample index.
+        mods (list[str]): Modification labels to include (e.g., ["6mA", "5mC"]).
+
+    Returns:
+        tuple[ModkitBatchDictionaries, set[int]]: Batch dictionaries and indices to skip.
+
+    Processing Steps:
+        1. Initialize modification dictionaries and skip-set.
+        2. Filter TSV rows per record/sample into modification and strand subsets.
+        3. Populate combined dict placeholders when both modifications are present.
+    """
+    batch_dicts = ModkitBatchDictionaries(dict_total=dict_total)
+    dict_to_skip = {0, 1, 4}
+    combined_dicts = {7, 8}
+    A_stranded_dicts = {2, 3}
+    C_stranded_dicts = {5, 6}
+    dict_to_skip.update(combined_dicts | A_stranded_dicts | C_stranded_dicts)
+
+    for record in dict_total.keys():
+        for sample_index in dict_total[record].keys():
+            if "6mA" in mods:
+                dict_to_skip.difference_update(A_stranded_dicts)
+                if (
+                    record not in batch_dicts.dict_a.keys()
+                    and record not in batch_dicts.dict_a_bottom.keys()
+                    and record not in batch_dicts.dict_a_top.keys()
+                ):
+                    (
+                        batch_dicts.dict_a[record],
+                        batch_dicts.dict_a_bottom[record],
+                        batch_dicts.dict_a_top[record],
+                    ) = ({}, {}, {})
+
+                batch_dicts.dict_a[record][sample_index] = dict_total[record][sample_index][
+                    dict_total[record][sample_index][
+                        MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE
+                    ]
+                    == MODKIT_EXTRACT_MODIFIED_BASE_A
+                ]
+                logger.debug(
+                    "Successfully loaded a methyl-adenine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_a_bottom[record][sample_index] = batch_dicts.dict_a[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_a[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_MINUS
+                ]
+                logger.debug(
+                    "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_a_top[record][sample_index] = batch_dicts.dict_a[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_a[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_PLUS
+                ]
+                logger.debug(
+                    "Successfully loaded a plus strand methyl-adenine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_a[record][sample_index] = None
+                gc.collect()
+
+            if "5mC" in mods:
+                dict_to_skip.difference_update(C_stranded_dicts)
+                if (
+                    record not in batch_dicts.dict_c.keys()
+                    and record not in batch_dicts.dict_c_bottom.keys()
+                    and record not in batch_dicts.dict_c_top.keys()
+                ):
+                    (
+                        batch_dicts.dict_c[record],
+                        batch_dicts.dict_c_bottom[record],
+                        batch_dicts.dict_c_top[record],
+                    ) = ({}, {}, {})
+
+                batch_dicts.dict_c[record][sample_index] = dict_total[record][sample_index][
+                    dict_total[record][sample_index][
+                        MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE
+                    ]
+                    == MODKIT_EXTRACT_MODIFIED_BASE_C
+                ]
+                logger.debug(
+                    "Successfully loaded a methyl-cytosine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_c_bottom[record][sample_index] = batch_dicts.dict_c[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_c[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_MINUS
+                ]
+                logger.debug(
+                    "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_c_top[record][sample_index] = batch_dicts.dict_c[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_c[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_PLUS
+                ]
+                logger.debug(
+                    "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_c[record][sample_index] = None
+                gc.collect()
+
+            if "6mA" in mods and "5mC" in mods:
+                dict_to_skip.difference_update(combined_dicts)
+                if (
+                    record not in batch_dicts.dict_combined_bottom.keys()
+                    and record not in batch_dicts.dict_combined_top.keys()
+                ):
+                    (
+                        batch_dicts.dict_combined_bottom[record],
+                        batch_dicts.dict_combined_top[record],
+                    ) = ({}, {})
+
+                logger.debug(
+                    "Successfully created a minus strand combined methylation dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_combined_bottom[record][sample_index] = []
+                logger.debug(
+                    "Successfully created a plus strand combined methylation dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_combined_top[record][sample_index] = []
+
+            dict_total[record][sample_index] = None
+            gc.collect()
+
+    return batch_dicts, dict_to_skip
+
+
 def modkit_extract_to_adata(
     fasta,
     bam_dir,
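To make the new encoding helpers concrete, here is a small sketch of `_encode_sequence_array` with a hypothetical base-to-integer mapping (the real values come from `MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT` in `smftools.constants` and are not shown in this diff; the sketch also assumes `MODKIT_EXTRACT_SEQUENCE_BASES` covers A/C/G/T/N as the docstring states):

```python
import numpy as np

# Hypothetical mapping for illustration only.
base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}

read = np.array(list("ACGTN?"), dtype="<U1")  # "?" is an unrecognized base call
encoded = _encode_sequence_array(
    read,
    valid_length=4,            # positions 4 and 5 fall outside the record and get padded
    base_to_int=base_to_int,
    padding_value=base_to_int["PAD"],
)
print(encoded)  # expected: [0 1 2 3 5 5]
```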
@@ -484,25 +1129,34 @@ def modkit_extract_to_adata(
     delete_batch_hdfs=False,
     threads=None,
     double_barcoded_path=None,
+    samtools_backend: str | None = "auto",
 ):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-        double_barcoded_path (Path):
+    """Convert modkit extract TSVs and BAMs into an AnnData object.
+
+    Args:
+        fasta (Path): Reference FASTA path.
+        bam_dir (Path): Directory with aligned BAM files.
+        out_dir (Path): Output directory for intermediate and final H5ADs.
+        input_already_demuxed (bool): Whether reads were already demultiplexed.
+        mapping_threshold (float): Minimum fraction of mapped reads to keep a record.
+        experiment_name (str): Experiment name used in output file naming.
+        mods (list[str]): Modification labels to analyze (e.g., ["6mA", "5mC"]).
+        batch_size (int): Number of TSVs to process per batch.
+        mod_tsv_dir (Path): Directory containing modkit extract TSVs.
+        delete_batch_hdfs (bool): Remove batch H5ADs after concatenation.
+        threads (int | None): Thread count for parallel operations.
+        double_barcoded_path (Path | None): Dorado demux summary directory for double barcodes.
+        samtools_backend (str | None): Samtools backend selection.

     Returns:
-
+        tuple[ad.AnnData | None, Path]: The final AnnData (if created) and its H5AD path.
+
+    Processing Steps:
+        1. Discover input TSV/BAM files and derive sample metadata.
+        2. Identify records that pass mapping thresholds and build reference metadata.
+        3. Encode read sequences into integer arrays and cache them.
+        4. Process TSV batches into per-read methylation matrices.
+        5. Concatenate batch H5ADs into a final AnnData with consensus sequences.
     """
     ###################################################
     # Package imports
@@ -519,12 +1173,11 @@ def modkit_extract_to_adata(
     from ..readwrite import make_dirs
     from .bam_functions import extract_base_identities
     from .fasta_functions import get_native_references
-    from .ohe import ohe_batching
     ###################################################

     ################## Get input tsv and bam file names into a sorted list ################
     # Make output dirs
-    h5_dir = out_dir /
+    h5_dir = out_dir / H5_DIR
     tmp_dir = out_dir / "tmp"
     make_dirs([h5_dir, tmp_dir])

@@ -538,60 +1191,19 @@ def modkit_extract_to_adata(
         logger.debug(f"{final_adata_path} already exists. Using existing adata")
         return final_adata, final_adata_path

-
-
-
-        for p in mod_tsv_dir.iterdir()
-        if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
-    )
-    bams = sorted(
-        p
-        for p in bam_dir.iterdir()
-        if p.is_file()
-        and p.suffix == ".bam"
-        and "unclassified" not in p.name
-        and ".bai" not in p.name
-    )
-
-    tsv_path_list = [tsv for tsv in tsvs]
-    bam_path_list = [bam for bam in bams]
+    tsvs, bams = _collect_input_paths(mod_tsv_dir, bam_dir)
+    tsv_path_list = list(tsvs)
+    bam_path_list = list(bams)
     logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
     logger.info(f"{len(bams)} sample bams found: {bams}")

     # Map global sample index (bami / final_sample_index) -> sample name / barcode
-    sample_name_map =
-    barcode_map = {}
-
-    for idx, bam_path in enumerate(bam_path_list):
-        stem = bam_path.stem
-
-        # Try to peel off a "barcode..." suffix if present.
-        # This handles things like:
-        # "mySample_barcode01" -> sample="mySample", barcode="barcode01"
-        # "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
-        # "barcode01" -> sample="barcode01", barcode="barcode01"
-        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
-        if m:
-            sample_name = m.group(1) or stem
-            barcode = m.group(2)
-        else:
-            # Fallback: treat the whole stem as both sample & barcode
-            sample_name = stem
-            barcode = stem
-
-        # make sample name of the format of the bam file stem
-        sample_name = sample_name + f"_{barcode}"
-
-        # Clean the barcode name to be an integer
-        barcode = int(barcode.split("barcode")[1])
-
-        sample_name_map[idx] = sample_name
-        barcode_map[idx] = str(barcode)
+    sample_name_map, barcode_map = _build_sample_maps(bam_path_list)
     ##########################################################################################

     ######### Get Record names that have over a passed threshold of mapped reads #############
     # get all records that are above a certain mapping threshold in at least one sample bam
-    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)
+    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend)

     ##########################################################################################

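The inline sample/barcode parsing removed in the hunk above (now handled by `_build_sample_maps`) keeps the same regex; the removed comments documented the expected stem formats. A small sketch of what that pattern extracts for those example stems:

```python
import re

pattern = r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$"
for stem in ["mySample_barcode01", "run1-s1_barcode05", "barcode01"]:
    m = re.search(pattern, stem)
    sample = m.group(1) or stem   # an empty capture falls back to the whole stem
    barcode = m.group(2)
    print(f"{stem} -> sample={sample}, barcode={barcode}")
# mySample_barcode01 -> sample=mySample, barcode=barcode01
# run1-s1_barcode05 -> sample=run1-s1, barcode=barcode05
# barcode01 -> sample=barcode01, barcode=barcode01
```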
@@ -611,57 +1223,154 @@ def modkit_extract_to_adata(
     ##########################################################################################
 
     ##########################################################################################
-    #
-
-
-
-
-
-
-    if
-
-
-
-
+    # Encode read sequences into integer arrays and cache in tmp_dir.
+    sequence_batch_files: dict[str, list[str]] = {}
+    mismatch_batch_files: dict[str, list[str]] = {}
+    quality_batch_files: dict[str, list[str]] = {}
+    read_span_batch_files: dict[str, list[str]] = {}
+    sequence_cache_path = tmp_dir / "tmp_sequence_int_file_dict.h5ad"
+    cache_needs_rebuild = True
+    if sequence_cache_path.exists():
+        cached_uns = ad.read_h5ad(sequence_cache_path).uns
+        if "sequence_batch_files" in cached_uns:
+            sequence_batch_files = cached_uns.get("sequence_batch_files", {})
+            mismatch_batch_files = cached_uns.get("mismatch_batch_files", {})
+            quality_batch_files = cached_uns.get("quality_batch_files", {})
+            read_span_batch_files = cached_uns.get("read_span_batch_files", {})
+            cache_needs_rebuild = not (
+                quality_batch_files and read_span_batch_files and sequence_batch_files
+            )
+        else:
+            sequence_batch_files = cached_uns
+            cache_needs_rebuild = True
+        if cache_needs_rebuild:
+            logger.info(
+                "Cached sequence batches missing quality or read-span data; rebuilding cache."
+            )
+        else:
+            logger.debug("Found existing integer-encoded reads, using these")
+    if cache_needs_rebuild:
         for bami, bam in enumerate(bam_path_list):
-
+            logger.info(
+                f"Extracting base level sequences, qualities, reference spans, and mismatches per read for bam {bami}"
+            )
             for record in records_to_analyze:
                 current_reference_length = reference_dict[record][0]
                 positions = range(current_reference_length)
                 ref_seq = reference_dict[record][1]
-                # Extract the base identities of reads aligned to the record
                 (
                     fwd_base_identities,
                     rev_base_identities,
-
-
-
-
-
-
-
-
+                    _mismatch_counts_per_read,
+                    _mismatch_trend_per_read,
+                    mismatch_base_identities,
+                    base_quality_scores,
+                    read_span_masks,
+                ) = extract_base_identities(
+                    bam, record, positions, max_reference_length, ref_seq, samtools_backend
+                )
+                mismatch_fwd = {
+                    read_name: mismatch_base_identities[read_name]
+                    for read_name in fwd_base_identities
+                }
+                mismatch_rev = {
+                    read_name: mismatch_base_identities[read_name]
+                    for read_name in rev_base_identities
+                }
+                quality_fwd = {
+                    read_name: base_quality_scores[read_name] for read_name in fwd_base_identities
+                }
+                quality_rev = {
+                    read_name: base_quality_scores[read_name] for read_name in rev_base_identities
+                }
+                read_span_fwd = {
+                    read_name: read_span_masks[read_name] for read_name in fwd_base_identities
+                }
+                read_span_rev = {
+                    read_name: read_span_masks[read_name] for read_name in rev_base_identities
+                }
+                fwd_sequence_files = _write_sequence_batches(
                     fwd_base_identities,
                     tmp_dir,
                     record,
                     f"{bami}_fwd",
+                    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+                    current_reference_length,
                     batch_size=100000,
-                    threads=threads,
                 )
-
+                rev_sequence_files = _write_sequence_batches(
                     rev_base_identities,
                     tmp_dir,
                     record,
                     f"{bami}_rev",
+                    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+                    current_reference_length,
                     batch_size=100000,
-                    threads=threads,
                 )
-
-
-
-
-
-
+                sequence_batch_files[f"{bami}_{record}"] = fwd_sequence_files + rev_sequence_files
+                mismatch_fwd_files = _write_integer_batches(
+                    mismatch_fwd,
+                    tmp_dir,
+                    record,
+                    f"{bami}_mismatch_fwd",
+                    batch_size=100000,
+                )
+                mismatch_rev_files = _write_integer_batches(
+                    mismatch_rev,
+                    tmp_dir,
+                    record,
+                    f"{bami}_mismatch_rev",
+                    batch_size=100000,
+                )
+                mismatch_batch_files[f"{bami}_{record}"] = mismatch_fwd_files + mismatch_rev_files
+                quality_fwd_files = _write_integer_batches(
+                    quality_fwd,
+                    tmp_dir,
+                    record,
+                    f"{bami}_quality_fwd",
+                    batch_size=100000,
+                )
+                quality_rev_files = _write_integer_batches(
+                    quality_rev,
+                    tmp_dir,
+                    record,
+                    f"{bami}_quality_rev",
+                    batch_size=100000,
+                )
+                quality_batch_files[f"{bami}_{record}"] = quality_fwd_files + quality_rev_files
+                read_span_fwd_files = _write_integer_batches(
+                    read_span_fwd,
+                    tmp_dir,
+                    record,
+                    f"{bami}_read_span_fwd",
+                    batch_size=100000,
+                )
+                read_span_rev_files = _write_integer_batches(
+                    read_span_rev,
+                    tmp_dir,
+                    record,
+                    f"{bami}_read_span_rev",
+                    batch_size=100000,
+                )
+                read_span_batch_files[f"{bami}_{record}"] = (
+                    read_span_fwd_files + read_span_rev_files
+                )
+                del (
+                    fwd_base_identities,
+                    rev_base_identities,
+                    mismatch_base_identities,
+                    base_quality_scores,
+                    read_span_masks,
+                )
+        ad.AnnData(
+            X=np.random.rand(1, 1),
+            uns={
+                "sequence_batch_files": sequence_batch_files,
+                "mismatch_batch_files": mismatch_batch_files,
+                "quality_batch_files": quality_batch_files,
+                "read_span_batch_files": read_span_batch_files,
+            },
+        ).write_h5ad(sequence_cache_path)
     ##########################################################################################
 
     ##########################################################################################
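Note: the rebuilt cache above is an AnnData written purely for its `uns` payload: the per-sample lists of batch files survive a rerun, and a cache missing the newer quality/read-span entries triggers a rebuild. A self-contained illustration of that round-trip pattern (paths and keys here are illustrative, not the package's real file names):

```python
import anndata as ad
import numpy as np
from pathlib import Path

# Throwaway 1x1 matrix; the useful state lives in .uns.
tmp_dir = Path("tmp")  # assumed scratch directory
tmp_dir.mkdir(exist_ok=True)
cache_path = tmp_dir / "tmp_sequence_int_file_dict.h5ad"

batch_files = {"0_recordA": ["0_recordA_fwd_batch0.h5ad"]}  # hypothetical entries
ad.AnnData(
    X=np.random.rand(1, 1),
    uns={
        "sequence_batch_files": batch_files,
        "mismatch_batch_files": {},
        "quality_batch_files": {},
        "read_span_batch_files": {},
    },
).write_h5ad(cache_path)

# On a later run the dictionaries are recovered from .uns; empty quality or
# read-span entries mean the cache predates those layers and must be rebuilt.
cached_uns = ad.read_h5ad(cache_path).uns
needs_rebuild = not (
    cached_uns["quality_batch_files"]
    and cached_uns["read_span_batch_files"]
    and cached_uns["sequence_batch_files"]
)
print(needs_rebuild)  # True here, because the quality/read-span dicts are empty
```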
@@ -703,47 +1412,9 @@ def modkit_extract_to_adata(
     ###################################################
     ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
     # # Initialize dictionaries and place them in a list
-    (
-        dict_total,
-        dict_a,
-        dict_a_bottom,
-        dict_a_top,
-        dict_c,
-        dict_c_bottom,
-        dict_c_top,
-        dict_combined_bottom,
-        dict_combined_top,
-    ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
-    dict_list = [
-        dict_total,
-        dict_a,
-        dict_a_bottom,
-        dict_a_top,
-        dict_c,
-        dict_c_bottom,
-        dict_c_top,
-        dict_combined_bottom,
-        dict_combined_top,
-    ]
-    # Give names to represent each dictionary in the list
-    sample_types = [
-        "total",
-        "m6A",
-        "m6A_bottom_strand",
-        "m6A_top_strand",
-        "5mC",
-        "5mC_bottom_strand",
-        "5mC_top_strand",
-        "combined_bottom_strand",
-        "combined_top_strand",
-    ]
-    # Give indices of dictionaries to skip for analysis and final dictionary saving.
-    dict_to_skip = [0, 1, 4]
-    combined_dicts = [7, 8]
-    A_stranded_dicts = [2, 3]
-    C_stranded_dicts = [5, 6]
-    dict_to_skip = dict_to_skip + combined_dicts + A_stranded_dicts + C_stranded_dicts
-    dict_to_skip = set(dict_to_skip)
+    batch_dicts = ModkitBatchDictionaries()
+    dict_list = batch_dicts.as_list()
+    sample_types = batch_dicts.sample_types
 
     # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
     dict_total = parallel_load_tsvs(
@@ -755,140 +1426,9 @@ def modkit_extract_to_adata(
         threads=threads,
     )
 
-
-
-
-    # c5m_dict, c5m_minus_strand, c5m_plus_strand,
-    # combined_minus_strand, combined_plus_strand) = merge_modification_results(processed_mod_results, mods)
-
-    # # Create dict_list with the desired ordering:
-    # # 0: dict_total, 1: m6A, 2: m6A_minus, 3: m6A_plus, 4: 5mC, 5: 5mC_minus, 6: 5mC_plus, 7: combined_minus, 8: combined_plus
-    # dict_list = [dict_total, m6A_dict, m6A_minus_strand, m6A_plus_strand,
-    # c5m_dict, c5m_minus_strand, c5m_plus_strand,
-    # combined_minus_strand, combined_plus_strand]
-
-    # # Initialize dict_to_skip (default skip all mod-specific indices)
-    # dict_to_skip = set([0, 1, 4, 7, 8, 2, 3, 5, 6])
-    # # Update dict_to_skip based on modifications present in mods
-    # dict_to_skip = update_dict_to_skip(dict_to_skip, mods)
-
-    # # Step 3: Process stranded methylation data in parallel
-    # dict_list = parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=threads or 4)
-
-    # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
-    for record in dict_total.keys():
-        for sample_index in dict_total[record].keys():
-            if "6mA" in mods:
-                # Remove Adenine stranded dicts from the dicts to skip set
-                dict_to_skip.difference_update(set(A_stranded_dicts))
-
-                if (
-                    record not in dict_a.keys()
-                    and record not in dict_a_bottom.keys()
-                    and record not in dict_a_top.keys()
-                ):
-                    dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
-
-                # get a dictionary of dataframes that only contain methylated adenine positions
-                dict_a[record][sample_index] = dict_total[record][sample_index][
-                    dict_total[record][sample_index]["modified_primary_base"] == "A"
-                ]
-                logger.debug(
-                    "Successfully loaded a methyl-adenine dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-
-                # Stratify the adenine dictionary into two strand specific dictionaries.
-                dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
-                    dict_a[record][sample_index]["ref_strand"] == "-"
-                ]
-                logger.debug(
-                    "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-                dict_a_top[record][sample_index] = dict_a[record][sample_index][
-                    dict_a[record][sample_index]["ref_strand"] == "+"
-                ]
-                logger.debug(
-                    "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
-                        str(sample_index)
-                    )
-                )
-
-                # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
-                dict_a[record][sample_index] = None
-                gc.collect()
-
-            if "5mC" in mods:
-                # Remove Cytosine stranded dicts from the dicts to skip set
-                dict_to_skip.difference_update(set(C_stranded_dicts))
-
-                if (
-                    record not in dict_c.keys()
-                    and record not in dict_c_bottom.keys()
-                    and record not in dict_c_top.keys()
-                ):
-                    dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
-
-                # get a dictionary of dataframes that only contain methylated cytosine positions
-                dict_c[record][sample_index] = dict_total[record][sample_index][
-                    dict_total[record][sample_index]["modified_primary_base"] == "C"
-                ]
-                logger.debug(
-                    "Successfully loaded a methyl-cytosine dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-                # Stratify the cytosine dictionary into two strand specific dictionaries.
-                dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
-                    dict_c[record][sample_index]["ref_strand"] == "-"
-                ]
-                logger.debug(
-                    "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-                dict_c_top[record][sample_index] = dict_c[record][sample_index][
-                    dict_c[record][sample_index]["ref_strand"] == "+"
-                ]
-                logger.debug(
-                    "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-                # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
-                dict_c[record][sample_index] = None
-                gc.collect()
-
-            if "6mA" in mods and "5mC" in mods:
-                # Remove combined stranded dicts from the dicts to skip set
-                dict_to_skip.difference_update(set(combined_dicts))
-                # Initialize the sample keys for the combined dictionaries
-
-                if (
-                    record not in dict_combined_bottom.keys()
-                    and record not in dict_combined_top.keys()
-                ):
-                    dict_combined_bottom[record], dict_combined_top[record] = {}, {}
-
-                logger.debug(
-                    "Successfully created a minus strand combined methylation dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-                dict_combined_bottom[record][sample_index] = []
-                logger.debug(
-                    "Successfully created a plus strand combined methylation dictionary for {}".format(
-                        str(sample_index)
-                    )
-                )
-                dict_combined_top[record][sample_index] = []
-
-            # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
-            dict_total[record][sample_index] = None
-            gc.collect()
+    batch_dicts, dict_to_skip = _build_modification_dicts(dict_total, mods)
+    dict_list = batch_dicts.as_list()
+    sample_types = batch_dicts.sample_types
 
     # Iterate over the stranded modification dictionaries and replace the dataframes with a dictionary of read names pointing to a list of values from the dataframe
     for dict_index, dict_type in enumerate(dict_list):
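Note: `ModkitBatchDictionaries` and `_build_modification_dicts` (introduced above) replace the nine inline dictionaries, the `sample_types` list, and the hand-maintained `dict_to_skip` set of the removed block. Their real definitions live elsewhere in the package; a sketch consistent with the removed logic, with field names and the return signature assumed, could look like this:

```python
from dataclasses import dataclass, field


@dataclass
class ModkitBatchDictionaries:
    """Sketch: container for the nine per-modification/strand dictionaries."""

    total: dict = field(default_factory=dict)
    a: dict = field(default_factory=dict)
    a_bottom: dict = field(default_factory=dict)
    a_top: dict = field(default_factory=dict)
    c: dict = field(default_factory=dict)
    c_bottom: dict = field(default_factory=dict)
    c_top: dict = field(default_factory=dict)
    combined_bottom: dict = field(default_factory=dict)
    combined_top: dict = field(default_factory=dict)

    # Names mirror the removed inline sample_types list.
    sample_types = [
        "total",
        "m6A",
        "m6A_bottom_strand",
        "m6A_top_strand",
        "5mC",
        "5mC_bottom_strand",
        "5mC_top_strand",
        "combined_bottom_strand",
        "combined_top_strand",
    ]

    def as_list(self):
        return [
            self.total,
            self.a,
            self.a_bottom,
            self.a_top,
            self.c,
            self.c_bottom,
            self.c_top,
            self.combined_bottom,
            self.combined_top,
        ]


def _build_modification_dicts(dict_total, mods):
    """Sketch: split dict_total per modification/strand and compute the skip set."""
    batch = ModkitBatchDictionaries()
    batch.total.update(dict_total)
    # Start by skipping every mod-specific index, as the removed code did.
    dict_to_skip = {0, 1, 4, 2, 3, 5, 6, 7, 8}
    if "6mA" in mods:
        dict_to_skip.difference_update({2, 3})  # keep the m6A stranded dicts
    if "5mC" in mods:
        dict_to_skip.difference_update({5, 6})  # keep the 5mC stranded dicts
    if "6mA" in mods and "5mC" in mods:
        dict_to_skip.difference_update({7, 8})  # keep the combined stranded dicts
    # ... per-record/per-sample splitting by "modified_primary_base" and "ref_strand"
    return batch, dict_to_skip
```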
@@ -974,14 +1514,14 @@ def modkit_extract_to_adata(
                     mod_strand_record_sample_dict[sample] = {}
 
                     # Get relevant columns as NumPy arrays
-                    read_ids = temp_df[
-                    positions = temp_df[
-                    call_codes = temp_df[
-                    probabilities = temp_df[
+                    read_ids = temp_df[MODKIT_EXTRACT_TSV_COLUMN_READ_ID].values
+                    positions = temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION].values
+                    call_codes = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE].values
+                    probabilities = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB].values
 
                     # Define valid call code categories
-                    modified_codes =
-                    canonical_codes =
+                    modified_codes = MODKIT_EXTRACT_CALL_CODE_MODIFIED
+                    canonical_codes = MODKIT_EXTRACT_CALL_CODE_CANONICAL
 
                     # Vectorized methylation calculation with NaN for other codes
                     methylation_prob = np.full_like(
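Note: the vectorized methylation call that continues past this hunk fills an array with NaN and then overwrites modified calls with 1 and canonical calls with 0. A small standalone illustration of that pattern follows; the concrete call-code sets are assumptions, not values taken from the package constants:

```python
import numpy as np

# Toy call codes and probabilities for five positions on one read.
call_codes = np.array(["a", "-", "m", ".", "-"])
probabilities = np.array([0.97, 0.88, 0.91, 0.50, 0.95])

modified_codes = {"a", "m"}   # assumed modified call codes
canonical_codes = {"-"}       # assumed canonical call code

# NaN by default, 1.0 where the call is modified, 0.0 where canonical.
methylation_prob = np.full_like(probabilities, np.nan, dtype=float)
methylation_prob[np.isin(call_codes, list(modified_codes))] = 1.0
methylation_prob[np.isin(call_codes, list(canonical_codes))] = 0.0
print(methylation_prob)  # [ 1.  0.  1. nan  0.]
```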
@@ -1077,39 +1617,63 @@ def modkit_extract_to_adata(
                                 final_sample_index,
                             )
                         )
-                        temp_adata.obs[
+                        temp_adata.obs[SAMPLE] = [
                             sample_name_map[final_sample_index]
                         ] * len(temp_adata)
-                        temp_adata.obs[
+                        temp_adata.obs[BARCODE] = [barcode_map[final_sample_index]] * len(
                             temp_adata
                         )
-                        temp_adata.obs[
-                        temp_adata.obs[
-                        temp_adata.obs[
-                        temp_adata.obs[
+                        temp_adata.obs[REFERENCE] = [f"{record}"] * len(temp_adata)
+                        temp_adata.obs[STRAND] = [strand] * len(temp_adata)
+                        temp_adata.obs[DATASET] = [dataset] * len(temp_adata)
+                        temp_adata.obs[REFERENCE_DATASET_STRAND] = [
                             f"{record}_{dataset}_{strand}"
                         ] * len(temp_adata)
-                        temp_adata.obs[
+                        temp_adata.obs[REFERENCE_STRAND] = [f"{record}_{strand}"] * len(
                             temp_adata
                         )
 
-                        # Load
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        # Load integer-encoded reads for the current sample/record
+                        sequence_files = _normalize_sequence_batch_files(
+                            sequence_batch_files.get(f"{final_sample_index}_{record}", [])
+                        )
+                        mismatch_files = _normalize_sequence_batch_files(
+                            mismatch_batch_files.get(f"{final_sample_index}_{record}", [])
+                        )
+                        quality_files = _normalize_sequence_batch_files(
+                            quality_batch_files.get(f"{final_sample_index}_{record}", [])
+                        )
+                        read_span_files = _normalize_sequence_batch_files(
+                            read_span_batch_files.get(f"{final_sample_index}_{record}", [])
+                        )
+                        if not sequence_files:
+                            logger.warning(
+                                "No encoded sequence batches found for sample %s record %s",
+                                final_sample_index,
+                                record,
+                            )
+                            continue
+                        logger.info(f"Loading encoded sequences from {sequence_files}")
+                        (
+                            encoded_reads,
+                            fwd_mapped_reads,
+                            rev_mapped_reads,
+                        ) = _load_sequence_batches(sequence_files)
+                        mismatch_reads: dict[str, np.ndarray] = {}
+                        if mismatch_files:
+                            (
+                                mismatch_reads,
+                                _mismatch_fwd_reads,
+                                _mismatch_rev_reads,
+                            ) = _load_sequence_batches(mismatch_files)
+                        quality_reads: dict[str, np.ndarray] = {}
+                        if quality_files:
+                            quality_reads = _load_integer_batches(quality_files)
+                        read_span_reads: dict[str, np.ndarray] = {}
+                        if read_span_files:
+                            read_span_reads = _load_integer_batches(read_span_files)
+
+                        read_names = list(encoded_reads.keys())
 
                         read_mapping_direction = []
                         for read_id in temp_adata.obs_names:
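Note: the read mapping direction recorded in `obs` comes from membership in the forward/reverse read dictionaries returned by `_load_sequence_batches`; only the "unk" fallback branch is visible in the next hunk. A sketch of that assignment with hypothetical read IDs and assumed "fwd"/"rev" labels:

```python
# Membership in the forward/reverse sets decides the label; anything not
# found in either set is marked "unk", matching the visible else branch.
fwd_mapped_reads = {"read_1", "read_3"}   # hypothetical
rev_mapped_reads = {"read_2"}             # hypothetical
obs_names = ["read_1", "read_2", "read_3", "read_4"]

read_mapping_direction = []
for read_id in obs_names:
    if read_id in fwd_mapped_reads:
        read_mapping_direction.append("fwd")
    elif read_id in rev_mapped_reads:
        read_mapping_direction.append("rev")
    else:
        read_mapping_direction.append("unk")

print(read_mapping_direction)  # ['fwd', 'rev', 'fwd', 'unk']
```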
@@ -1120,57 +1684,69 @@ def modkit_extract_to_adata(
                             else:
                                 read_mapping_direction.append("unk")
 
-                        temp_adata.obs[
+                        temp_adata.obs[READ_MAPPING_DIRECTION] = read_mapping_direction
 
                         del temp_df
 
-
-
-
+                        padding_value = MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[
+                            MODKIT_EXTRACT_SEQUENCE_PADDING_BASE
+                        ]
+                        sequence_length = encoded_reads[read_names[0]].shape[0]
+                        encoded_matrix = np.full(
+                            (len(sorted_index), sequence_length),
+                            padding_value,
+                            dtype=np.int16,
                         )
-                        df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
-                        df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
-                        df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
-                        df_T = np.zeros((len(sorted_index), sequence_length), dtype=int)
-                        df_N = np.zeros((len(sorted_index), sequence_length), dtype=int)
-
-                        # Process one-hot data into dictionaries
-                        dict_A, dict_C, dict_G, dict_T, dict_N = {}, {}, {}, {}, {}
-                        for read_name, one_hot_array in one_hot_reads.items():
-                            one_hot_array = one_hot_array.reshape(n_rows_OHE, -1)
-                            dict_A[read_name] = one_hot_array[0, :]
-                            dict_C[read_name] = one_hot_array[1, :]
-                            dict_G[read_name] = one_hot_array[2, :]
-                            dict_T[read_name] = one_hot_array[3, :]
-                            dict_N[read_name] = one_hot_array[4, :]
-
-                        del one_hot_reads
-                        gc.collect()
 
-                        # Fill the arrays
                         for j, read_name in tqdm(
                             enumerate(sorted_index),
-                            desc="Loading
+                            desc="Loading integer-encoded reads",
                             total=len(sorted_index),
                         ):
-                            df_A[j, :] = dict_A[read_name]
-                            df_C[j, :] = dict_C[read_name]
-                            df_G[j, :] = dict_G[read_name]
-                            df_T[j, :] = dict_T[read_name]
-                            df_N[j, :] = dict_N[read_name]
+                            encoded_matrix[j, :] = encoded_reads[read_name]
 
-                        del
+                        del encoded_reads
                         gc.collect()
 
-
-
-
-
-
+                        temp_adata.layers[SEQUENCE_INTEGER_ENCODING] = encoded_matrix
+                        if mismatch_reads:
+                            current_reference_length = reference_dict[record][0]
+                            default_mismatch_sequence = np.full(
+                                sequence_length,
+                                MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["N"],
+                                dtype=np.int16,
                             )
-
-
+                            if current_reference_length < sequence_length:
+                                default_mismatch_sequence[current_reference_length:] = (
+                                    padding_value
+                                )
+                            mismatch_matrix = np.vstack(
+                                [
+                                    mismatch_reads.get(read_name, default_mismatch_sequence)
+                                    for read_name in sorted_index
+                                ]
+                            )
+                            temp_adata.layers[MISMATCH_INTEGER_ENCODING] = mismatch_matrix
+                        if quality_reads:
+                            default_quality_sequence = np.full(
+                                sequence_length, -1, dtype=np.int16
+                            )
+                            quality_matrix = np.vstack(
+                                [
+                                    quality_reads.get(read_name, default_quality_sequence)
+                                    for read_name in sorted_index
+                                ]
                             )
+                            temp_adata.layers[BASE_QUALITY_SCORES] = quality_matrix
+                        if read_span_reads:
+                            default_read_span = np.zeros(sequence_length, dtype=np.int16)
+                            read_span_matrix = np.vstack(
+                                [
+                                    read_span_reads.get(read_name, default_read_span)
+                                    for read_name in sorted_index
+                                ]
+                            )
+                            temp_adata.layers[READ_SPAN_MASK] = read_span_matrix
 
                     # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
                     if adata:
@@ -1263,8 +1839,14 @@ def modkit_extract_to_adata(
     for col in final_adata.obs.columns:
         final_adata.obs[col] = final_adata.obs[col].astype("category")
 
-
-
+    final_adata.uns[f"{SEQUENCE_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+    final_adata.uns[f"{MISMATCH_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+    final_adata.uns[f"{SEQUENCE_INTEGER_DECODING}_map"] = {
+        str(key): value for key, value in MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.items()
+    }
+
+    consensus_bases = MODKIT_EXTRACT_SEQUENCE_BASES[:4]  # ignore N/PAD for consensus
+    consensus_base_ints = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in consensus_bases]
     final_adata.uns["References"] = {}
     for record in records_to_analyze:
         # Add FASTA sequence to the object
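Note: with the integer-encoded layer in place, the per-record consensus computed in the hunk below is just per-position counts of the four base codes followed by an argmax, with all-zero columns reported as "N". A toy illustration under an assumed base-to-int map (the real map lives in the package's constants module):

```python
import numpy as np

# Assumed encoding, for illustration only.
base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}
bases = list(base_to_int)
consensus_bases = bases[:4]                      # ignore N/PAD for consensus
consensus_base_ints = [base_to_int[b] for b in consensus_bases]

# Three integer-encoded reads over five positions (4 = N, 5 = padding).
encoded = np.array(
    [
        [0, 1, 4, 3, 5],
        [0, 1, 2, 3, 5],
        [0, 2, 4, 3, 5],
    ],
    dtype=np.int16,
)

# Per-base counts per position, then argmax picks the consensus base.
layer_counts = [np.sum(encoded == i, axis=0) for i in consensus_base_ints]
count_array = np.array(layer_counts)
consensus = [consensus_bases[i] for i in np.argmax(count_array, axis=0)]

# Positions with no A/C/G/T calls at all are reported as "N".
no_calls = np.sum(count_array, axis=0) == 0
consensus = ["N" if flag else base for base, flag in zip(consensus, no_calls)]
print("".join(consensus))  # "ACGTN"
```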
@@ -1275,27 +1857,33 @@ def modkit_extract_to_adata(
         final_adata.uns[f"{record}_FASTA_sequence"] = sequence
         final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
         # Add consensus sequence of samples mapped to the record to the object
-        record_subset = final_adata[final_adata.obs[
-        for strand in record_subset.obs[
-            strand_subset = record_subset[record_subset.obs[
-            for mapping_dir in strand_subset.obs[
+        record_subset = final_adata[final_adata.obs[REFERENCE] == record]
+        for strand in record_subset.obs[STRAND].cat.categories:
+            strand_subset = record_subset[record_subset.obs[STRAND] == strand]
+            for mapping_dir in strand_subset.obs[READ_MAPPING_DIRECTION].cat.categories:
                 mapping_dir_subset = strand_subset[
-                    strand_subset.obs[
+                    strand_subset.obs[READ_MAPPING_DIRECTION] == mapping_dir
+                ]
+                encoded_sequences = mapping_dir_subset.layers[SEQUENCE_INTEGER_ENCODING]
+                layer_counts = [
+                    np.sum(encoded_sequences == base_int, axis=0)
+                    for base_int in consensus_base_ints
                 ]
-                layer_map, layer_counts = {}, []
-                for i, layer in enumerate(ohe_layers):
-                    layer_map[i] = layer.split("_")[0]
-                    layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
                 count_array = np.array(layer_counts)
                 nucleotide_indexes = np.argmax(count_array, axis=0)
-                consensus_sequence_list = [
+                consensus_sequence_list = [consensus_bases[i] for i in nucleotide_indexes]
+                no_calls_mask = np.sum(count_array, axis=0) == 0
+                if np.any(no_calls_mask):
+                    consensus_sequence_list = np.array(consensus_sequence_list, dtype=object)
+                    consensus_sequence_list[no_calls_mask] = "N"
+                    consensus_sequence_list = consensus_sequence_list.tolist()
                 final_adata.var[
                     f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
                 ] = consensus_sequence_list
 
     if input_already_demuxed:
-        final_adata.obs[
-        final_adata.obs[
+        final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
+        final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
     else:
         from .h5ad_functions import add_demux_type_annotation
 