smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,64 +1,88 @@
+from __future__ import annotations
+
 import concurrent.futures
 import gc
-
+import re
+import shutil
+from pathlib import Path
+from typing import Iterable, Optional, Union
+
+import numpy as np
 import pandas as pd
 from tqdm import tqdm
-import numpy as np
-from pathlib import Path
-from typing import Union, Iterable, Optional
-import shutil
 
-
+from smftools.logging_utils import get_logger
+
+from .bam_functions import count_aligned_reads
+
+logger = get_logger(__name__)
+
+
+def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
     """Processes a single BAM file, counts reads, and determines records to analyze."""
-    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
-
+    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
+        bam, samtools_backend
+    )
+
     total_reads = aligned_reads_count + unaligned_reads_count
     percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
-
+    logger.info(f"{percent_aligned:.2f}% of reads in {bam} aligned successfully")
 
     records = []
     for record, (count, percentage) in record_counts_dict.items():
-
+        logger.info(
+            f"{count} reads mapped to reference {record}. This is {percentage * 100:.2f}% of all mapped reads in {bam}"
+        )
         if percentage >= mapping_threshold:
             records.append(record)
-
+
     return set(records)
 
-
+
+def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
     """Parallel processing for multiple BAM files."""
    records_to_analyze = set()
 
     with concurrent.futures.ProcessPoolExecutor() as executor:
-        results = executor.map(
+        results = executor.map(
+            filter_bam_records,
+            bam_path_list,
+            [mapping_threshold] * len(bam_path_list),
+            [samtools_backend] * len(bam_path_list),
+        )
 
         # Aggregate results
         for result in results:
             records_to_analyze.update(result)
 
-
+    logger.info(f"Records to analyze: {records_to_analyze}")
     return records_to_analyze
 
+
 def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
     """
     Loads and filters a single TSV file based on chromosome and position criteria.
     """
-    temp_df = pd.read_csv(tsv, sep=
+    temp_df = pd.read_csv(tsv, sep="\t", header=0)
     filtered_records = {}
 
     for record in records_to_analyze:
         if record not in reference_dict:
             continue
-
+
         ref_length = reference_dict[record][0]
-        filtered_df = temp_df[
-
-
+        filtered_df = temp_df[
+            (temp_df["chrom"] == record)
+            & (temp_df["ref_position"] >= 0)
+            & (temp_df["ref_position"] < ref_length)
+        ]
 
         if not filtered_df.empty:
             filtered_records[record] = {sample_index: filtered_df}
 
     return filtered_records
 
+
 def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
     """
     Loads and filters TSV files in parallel.
@@ -78,52 +102,60 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, bat
 
     with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
         futures = {
-            executor.submit(
+            executor.submit(
+                process_tsv, tsv, records_to_analyze, reference_dict, sample_index
+            ): sample_index
             for sample_index, tsv in enumerate(tsv_batch)
         }
 
-        for future in tqdm(
+        for future in tqdm(
+            concurrent.futures.as_completed(futures),
+            desc=f"Processing batch {batch}",
+            total=batch_size,
+        ):
             result = future.result()
             for record, sample_data in result.items():
                 dict_total[record].update(sample_data)
 
     return dict_total
 
+
 def update_dict_to_skip(dict_to_skip, detected_modifications):
     """
     Updates the dict_to_skip set based on the detected modifications.
-
+
     Parameters:
         dict_to_skip (set): The initial set of dictionary indices to skip.
         detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.
-
+
     Returns:
         set: The updated dict_to_skip set.
     """
     # Define which indices correspond to modification-specific or strand-specific dictionaries
-    A_stranded_dicts = {2, 3}
-    C_stranded_dicts = {5, 6}
-    combined_dicts
+    A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
+    C_stranded_dicts = {5, 6}  # 5mC bottom and top strand dictionaries
+    combined_dicts = {7, 8}  # Combined strand dictionaries
 
     # If '6mA' is present, remove the A_stranded indices from the skip set
-    if
+    if "6mA" in detected_modifications:
         dict_to_skip -= A_stranded_dicts
     # If '5mC' is present, remove the C_stranded indices from the skip set
-    if
+    if "5mC" in detected_modifications:
         dict_to_skip -= C_stranded_dicts
     # If both modifications are present, remove the combined indices from the skip set
-    if
+    if "6mA" in detected_modifications and "5mC" in detected_modifications:
         dict_to_skip -= combined_dicts
 
     return dict_to_skip
 
+
 def process_modifications_for_sample(args):
     """
     Processes a single (record, sample) pair to extract modification-specific data.
-
+
     Parameters:
         args: (record, sample_index, sample_df, mods, max_reference_length)
-
+
     Returns:
         (record, sample_index, result) where result is a dict with keys:
         'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
@@ -131,29 +163,30 @@ def process_modifications_for_sample(args):
     """
     record, sample_index, sample_df, mods, max_reference_length = args
     result = {}
-    if
-        m6a_df = sample_df[sample_df[
-        result[
-        result[
-        result[
+    if "6mA" in mods:
+        m6a_df = sample_df[sample_df["modified_primary_base"] == "A"]
+        result["m6A"] = m6a_df
+        result["m6A_minus"] = m6a_df[m6a_df["ref_strand"] == "-"]
+        result["m6A_plus"] = m6a_df[m6a_df["ref_strand"] == "+"]
         m6a_df = None
         gc.collect()
-    if
-        m5c_df = sample_df[sample_df[
-        result[
-        result[
-        result[
+    if "5mC" in mods:
+        m5c_df = sample_df[sample_df["modified_primary_base"] == "C"]
+        result["5mC"] = m5c_df
+        result["5mC_minus"] = m5c_df[m5c_df["ref_strand"] == "-"]
+        result["5mC_plus"] = m5c_df[m5c_df["ref_strand"] == "+"]
         m5c_df = None
         gc.collect()
-    if
-        result[
-        result[
+    if "6mA" in mods and "5mC" in mods:
+        result["combined_minus"] = []
+        result["combined_plus"] = []
     return record, sample_index, result
 
+
 def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
     """
     Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.
-
+
     Returns:
         processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
     """
@@ -164,18 +197,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threa
     processed_results = {}
     with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
         for record, sample_index, result in tqdm(
-
-
-
+            executor.map(process_modifications_for_sample, tasks),
+            total=len(tasks),
+            desc="Processing modifications",
+        ):
             if record not in processed_results:
                 processed_results[record] = {}
             processed_results[record][sample_index] = result
     return processed_results
 
+
 def merge_modification_results(processed_results, mods):
     """
     Merges individual sample results into global dictionaries.
-
+
     Returns:
         A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
     """
@@ -189,44 +224,52 @@ def merge_modification_results(processed_results, mods):
     combined_plus = {}
     for record, sample_results in processed_results.items():
         for sample_index, res in sample_results.items():
-            if
+            if "6mA" in mods:
                 if record not in m6A_dict:
                     m6A_dict[record], m6A_minus[record], m6A_plus[record] = {}, {}, {}
-                m6A_dict[record][sample_index] = res.get(
-                m6A_minus[record][sample_index] = res.get(
-                m6A_plus[record][sample_index] = res.get(
-            if
+                m6A_dict[record][sample_index] = res.get("m6A", pd.DataFrame())
+                m6A_minus[record][sample_index] = res.get("m6A_minus", pd.DataFrame())
+                m6A_plus[record][sample_index] = res.get("m6A_plus", pd.DataFrame())
+            if "5mC" in mods:
                 if record not in c5m_dict:
                     c5m_dict[record], c5m_minus[record], c5m_plus[record] = {}, {}, {}
-                c5m_dict[record][sample_index] = res.get(
-                c5m_minus[record][sample_index] = res.get(
-                c5m_plus[record][sample_index] = res.get(
-            if
+                c5m_dict[record][sample_index] = res.get("5mC", pd.DataFrame())
+                c5m_minus[record][sample_index] = res.get("5mC_minus", pd.DataFrame())
+                c5m_plus[record][sample_index] = res.get("5mC_plus", pd.DataFrame())
+            if "6mA" in mods and "5mC" in mods:
                 if record not in combined_minus:
                     combined_minus[record], combined_plus[record] = {}, {}
-                combined_minus[record][sample_index] = res.get(
-                combined_plus[record][sample_index] = res.get(
-    return (
-
-
+                combined_minus[record][sample_index] = res.get("combined_minus", [])
+                combined_plus[record][sample_index] = res.get("combined_plus", [])
+    return (
+        m6A_dict,
+        m6A_minus,
+        m6A_plus,
+        c5m_dict,
+        c5m_minus,
+        c5m_plus,
+        combined_minus,
+        combined_plus,
+    )
+
 
 def process_stranded_methylation(args):
     """
     Processes a single (dict_index, record, sample) task.
-
+
     For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
-    For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
+    For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
     NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
-
+
     Parameters:
         args: (dict_index, record, sample, dict_list, max_reference_length)
-
+
     Returns:
         (dict_index, record, sample, processed_data)
     """
     dict_index, record, sample, dict_list, max_reference_length = args
     processed_data = {}
-
+
     # For combined bottom strand (index 7)
     if dict_index == 7:
         temp_a = dict_list[2][record].get(sample, {}).copy()
@@ -235,18 +278,18 @@ def process_stranded_methylation(args):
         for read in set(temp_a.keys()) | set(temp_c.keys()):
             if read in temp_a:
                 # Convert using pd.to_numeric with errors='coerce'
-                value_a = pd.to_numeric(np.array(temp_a[read]), errors=
+                value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
             else:
                 value_a = None
             if read in temp_c:
-                value_c = pd.to_numeric(np.array(temp_c[read]), errors=
+                value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
             else:
                 value_c = None
             if value_a is not None and value_c is not None:
                 processed_data[read] = np.where(
                     np.isnan(value_a) & np.isnan(value_c),
                     np.nan,
-                    np.nan_to_num(value_a) + np.nan_to_num(value_c)
+                    np.nan_to_num(value_a) + np.nan_to_num(value_c),
                 )
             elif value_a is not None:
                 processed_data[read] = value_a
@@ -261,18 +304,18 @@ def process_stranded_methylation(args):
         processed_data = {}
         for read in set(temp_a.keys()) | set(temp_c.keys()):
             if read in temp_a:
-                value_a = pd.to_numeric(np.array(temp_a[read]), errors=
+                value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
             else:
                 value_a = None
             if read in temp_c:
-                value_c = pd.to_numeric(np.array(temp_c[read]), errors=
+                value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
             else:
                 value_c = None
             if value_a is not None and value_c is not None:
                 processed_data[read] = np.where(
                     np.isnan(value_a) & np.isnan(value_c),
                     np.nan,
-                    np.nan_to_num(value_a) + np.nan_to_num(value_c)
+                    np.nan_to_num(value_a) + np.nan_to_num(value_c),
                 )
             elif value_a is not None:
                 processed_data[read] = value_a
@@ -286,24 +329,28 @@ def process_stranded_methylation(args):
         temp_df = dict_list[dict_index][record][sample]
         processed_data = {}
         # Extract columns and convert probabilities to float (coercing errors)
-        read_ids = temp_df[
-        positions = temp_df[
-        call_codes = temp_df[
-        probabilities = pd.to_numeric(temp_df[
-
-        modified_codes = {
-        canonical_codes = {
-
+        read_ids = temp_df["read_id"].values
+        positions = temp_df["ref_position"].values
+        call_codes = temp_df["call_code"].values
+        probabilities = pd.to_numeric(temp_df["call_prob"].values, errors="coerce")
+
+        modified_codes = {"a", "h", "m"}
+        canonical_codes = {"-"}
+
         # Compute methylation probabilities (vectorized)
         methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
-        methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[
-
-
+        methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[
+            np.isin(call_codes, list(modified_codes))
+        ]
+        methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+            1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+        )
+
         # Preallocate storage for each unique read
         unique_reads = np.unique(read_ids)
         for read in unique_reads:
             processed_data[read] = np.full(max_reference_length, np.nan, dtype=float)
-
+
         # Assign values efficiently
         for i in range(len(read_ids)):
             read = read_ids[i]
@@ -314,10 +361,11 @@ def process_stranded_methylation(args):
     gc.collect()
     return dict_index, record, sample, processed_data
 
+
 def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
     """
     Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.
-
+
     Returns:
         Updated dict_list with processed (nested) dictionaries.
     """
@@ -327,16 +375,17 @@ def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference
         for record in current_dict.keys():
             for sample in current_dict[record].keys():
                 tasks.append((dict_index, record, sample, dict_list, max_reference_length))
-
+
     with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
         for dict_index, record, sample, processed_data in tqdm(
             executor.map(process_stranded_methylation, tasks),
             total=len(tasks),
-            desc="Extracting stranded methylation states"
+            desc="Extracting stranded methylation states",
         ):
             dict_list[dict_index][record][sample] = processed_data
     return dict_list
 
+
 def delete_intermediate_h5ads_and_tmpdir(
     h5_dir: Union[str, Path, Iterable[str], None],
     tmp_dir: Optional[Union[str, Path]] = None,
@@ -360,25 +409,27 @@ def delete_intermediate_h5ads_and_tmpdir(
     verbose : bool
         Print progress / warnings.
     """
+
     # Helper: remove a single file path (Path-like or string)
     def _maybe_unlink(p: Path):
+        """Remove a file path if it exists and is a file."""
         if not p.exists():
             if verbose:
-
+                logger.debug(f"[skip] not found: {p}")
             return
         if not p.is_file():
             if verbose:
-
+                logger.debug(f"[skip] not a file: {p}")
             return
         if dry_run:
-
+            logger.debug(f"[dry-run] would remove file: {p}")
             return
         try:
             p.unlink()
             if verbose:
-
+                logger.info(f"Removed file: {p}")
         except Exception as e:
-
+            logger.warning(f"[error] failed to remove file {p}: {e}")
 
     # Handle h5_dir input (directory OR iterable of file paths)
     if h5_dir is not None:
@@ -393,7 +444,7 @@ def delete_intermediate_h5ads_and_tmpdir(
                 else:
                     if verbose:
                         # optional: comment this out if too noisy
-
+                        logger.debug(f"[skip] not matching pattern: {p.name}")
     else:
         # treat as iterable of file paths
         for f in h5_dir:
@@ -403,30 +454,45 @@ def delete_intermediate_h5ads_and_tmpdir(
                 _maybe_unlink(p)
             else:
                 if verbose:
-
+                    logger.debug(f"[skip] not matching pattern or not a file: {p}")
 
     # Remove tmp_dir recursively (if provided)
     if tmp_dir is not None:
         td = Path(tmp_dir)
         if not td.exists():
             if verbose:
-
+                logger.debug(f"[skip] tmp_dir not found: {td}")
         else:
             if not td.is_dir():
                 if verbose:
-
+                    logger.debug(f"[skip] tmp_dir is not a directory: {td}")
             else:
                 if dry_run:
-
+                    logger.debug(f"[dry-run] would remove directory tree: {td}")
                 else:
                     try:
                         shutil.rmtree(td)
                         if verbose:
-
+                            logger.info(f"Removed directory tree: {td}")
                     except Exception as e:
-
-
-
+                        logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
+
+
+def modkit_extract_to_adata(
+    fasta,
+    bam_dir,
+    out_dir,
+    input_already_demuxed,
+    mapping_threshold,
+    experiment_name,
+    mods,
+    batch_size,
+    mod_tsv_dir,
+    delete_batch_hdfs=False,
+    threads=None,
+    double_barcoded_path=None,
+    samtools_backend: str | None = "auto",
+):
     """
     Takes modkit extract outputs and organizes it into an adata object
 
@@ -448,82 +514,121 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
     """
     ###################################################
     # Package imports
-    from .. import readwrite
-    from ..readwrite import safe_write_h5ad, make_dirs
-    from .fasta_functions import get_native_references
-    from .bam_functions import extract_base_identities
-    from .ohe import ohe_batching
-    import pandas as pd
-    import anndata as ad
-    import os
     import gc
     import math
+
+    import anndata as ad
     import numpy as np
+    import pandas as pd
     from Bio.Seq import Seq
     from tqdm import tqdm
-
+
+    from .. import readwrite
+    from ..readwrite import make_dirs
+    from .bam_functions import extract_base_identities
+    from .fasta_functions import get_native_references
+    from .ohe import ohe_batching
     ###################################################
 
     ################## Get input tsv and bam file names into a sorted list ################
     # Make output dirs
-    h5_dir = out_dir /
-    tmp_dir = out_dir /
+    h5_dir = out_dir / "h5ads"
+    tmp_dir = out_dir / "tmp"
     make_dirs([h5_dir, tmp_dir])
 
-    existing_h5s =
-    existing_h5s = [h5 for h5 in existing_h5s if
-    final_hdf = f
+    existing_h5s = h5_dir.iterdir()
+    existing_h5s = [h5 for h5 in existing_h5s if ".h5ad.gz" in str(h5)]
+    final_hdf = f"{experiment_name}.h5ad.gz"
     final_adata_path = h5_dir / final_hdf
     final_adata = None
-
+
     if final_adata_path.exists():
-
+        logger.debug(f"{final_adata_path} already exists. Using existing adata")
         return final_adata, final_adata_path
-
+
     # List all files in the directory
     tsvs = sorted(
-        p
-
+        p
+        for p in mod_tsv_dir.iterdir()
+        if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
+    )
     bams = sorted(
-        p
-
+        p
+        for p in bam_dir.iterdir()
+        if p.is_file()
+        and p.suffix == ".bam"
+        and "unclassified" not in p.name
+        and ".bai" not in p.name
+    )
+
+    tsv_path_list = [tsv for tsv in tsvs]
+    bam_path_list = [bam for bam in bams]
+    logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
+    logger.info(f"{len(bams)} sample bams found: {bams}")
+
+    # Map global sample index (bami / final_sample_index) -> sample name / barcode
+    sample_name_map = {}
+    barcode_map = {}
+
+    for idx, bam_path in enumerate(bam_path_list):
+        stem = bam_path.stem
+
+        # Try to peel off a "barcode..." suffix if present.
+        # This handles things like:
+        #   "mySample_barcode01" -> sample="mySample", barcode="barcode01"
+        #   "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
+        #   "barcode01" -> sample="barcode01", barcode="barcode01"
+        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
+        if m:
+            sample_name = m.group(1) or stem
+            barcode = m.group(2)
+        else:
+            # Fallback: treat the whole stem as both sample & barcode
+            sample_name = stem
+            barcode = stem
 
-
-
-
-
+        # make sample name of the format of the bam file stem
+        sample_name = sample_name + f"_{barcode}"
+
+        # Clean the barcode name to be an integer
+        barcode = int(barcode.split("barcode")[1])
+
+        sample_name_map[idx] = sample_name
+        barcode_map[idx] = str(barcode)
     ##########################################################################################
 
     ######### Get Record names that have over a passed threshold of mapped reads #############
     # get all records that are above a certain mapping threshold in at least one sample bam
-    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)
+    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend)
 
     ##########################################################################################
 
     ########### Determine the maximum record length to analyze in the dataset ################
     # Get all references within the FASTA and indicate the length and identity of the record sequence
     max_reference_length = 0
-    reference_dict = get_native_references(
+    reference_dict = get_native_references(
+        str(fasta)
+    )  # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
     # Get the max record length in the dataset.
     for record in records_to_analyze:
         if reference_dict[record][0] > max_reference_length:
             max_reference_length = reference_dict[record][0]
-
-    batches = math.ceil(len(tsvs) / batch_size)
-
+    logger.info(f"Max reference length in dataset: {max_reference_length}")
+    batches = math.ceil(len(tsvs) / batch_size)  # Number of batches to process
+    logger.info("Processing input tsvs in {0} batches of {1} tsvs ".format(batches, batch_size))
    ##########################################################################################
 
     ##########################################################################################
-    # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
+    # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
     # Save the file paths in the bam_record_ohe_files dict.
     bam_record_ohe_files = {}
-    bam_record_save = tmp_dir /
+    bam_record_save = tmp_dir / "tmp_file_dict.h5ad"
     fwd_mapped_reads = set()
     rev_mapped_reads = set()
     # If this step has already been performed, read in the tmp_dile_dict
     if bam_record_save.exists():
         bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
-
+        logger.debug("Found existing OHE reads, using these")
     else:
         # Iterate over split bams
         for bami, bam in enumerate(bam_path_list):
@@ -533,18 +638,39 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                 positions = range(current_reference_length)
                 ref_seq = reference_dict[record][1]
                 # Extract the base identities of reads aligned to the record
-
+                (
+                    fwd_base_identities,
+                    rev_base_identities,
+                    mismatch_counts_per_read,
+                    mismatch_trend_per_read,
+                ) = extract_base_identities(
+                    bam, record, positions, max_reference_length, ref_seq, samtools_backend
+                )
                 # Store read names of fwd and rev mapped reads
                 fwd_mapped_reads.update(fwd_base_identities.keys())
                 rev_mapped_reads.update(rev_base_identities.keys())
                 # One hot encode the sequence string of the reads
-                fwd_ohe_files = ohe_batching(
-
-
+                fwd_ohe_files = ohe_batching(
+                    fwd_base_identities,
+                    tmp_dir,
+                    record,
+                    f"{bami}_fwd",
+                    batch_size=100000,
+                    threads=threads,
+                )
+                rev_ohe_files = ohe_batching(
+                    rev_base_identities,
+                    tmp_dir,
+                    record,
+                    f"{bami}_rev",
+                    batch_size=100000,
+                    threads=threads,
+                )
+                bam_record_ohe_files[f"{bami}_{record}"] = fwd_ohe_files + rev_ohe_files
                 del fwd_base_identities, rev_base_identities
         # Save out the ohe file paths
         X = np.random.rand(1, 1)
-        tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
+        tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
         tmp_ad.write_h5ad(bam_record_save)
     ##########################################################################################
 
@@ -554,39 +680,73 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
     for record in records_to_analyze:
         current_reference_length = reference_dict[record][0]
         delta_max_length = max_reference_length - current_reference_length
-        sequence = reference_dict[record][1] +
-        complement =
+        sequence = reference_dict[record][1] + "N" * delta_max_length
+        complement = (
+            str(Seq(reference_dict[record][1]).complement()).upper() + "N" * delta_max_length
+        )
         record_seq_dict[record] = (sequence, complement)
     ##########################################################################################
 
     ###################################################
     # Begin iterating over batches
     for batch in range(batches):
-
+        logger.info("Processing tsvs for batch {0} ".format(batch))
         # For the final batch, just take the remaining tsv and bam files
         if batch == batches - 1:
             tsv_batch = tsv_path_list
             bam_batch = bam_path_list
-        # For all other batches, take the next batch of tsvs and bams out of the file queue.
+        # For all other batches, take the next batch of tsvs and bams out of the file queue.
         else:
             tsv_batch = tsv_path_list[:batch_size]
             bam_batch = bam_path_list[:batch_size]
             tsv_path_list = tsv_path_list[batch_size:]
             bam_path_list = bam_path_list[batch_size:]
-
+        logger.info("tsvs in batch {0} ".format(tsv_batch))
 
-        batch_already_processed = sum([1 for h5 in existing_h5s if f
-
+        batch_already_processed = sum([1 for h5 in existing_h5s if f"_{batch}_" in h5.name])
+        ###################################################
         if batch_already_processed:
-
+            logger.debug(
+                f"Batch {batch} has already been processed into h5ads. Skipping batch and using existing files"
+            )
         else:
             ###################################################
             ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
             # # Initialize dictionaries and place them in a list
-
-
+            (
+                dict_total,
+                dict_a,
+                dict_a_bottom,
+                dict_a_top,
+                dict_c,
+                dict_c_bottom,
+                dict_c_top,
+                dict_combined_bottom,
+                dict_combined_top,
+            ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
+            dict_list = [
+                dict_total,
+                dict_a,
+                dict_a_bottom,
+                dict_a_top,
+                dict_c,
+                dict_c_bottom,
+                dict_c_top,
+                dict_combined_bottom,
+                dict_combined_top,
+            ]
             # Give names to represent each dictionary in the list
-            sample_types = [
+            sample_types = [
+                "total",
+                "m6A",
+                "m6A_bottom_strand",
+                "m6A_top_strand",
+                "5mC",
+                "5mC_bottom_strand",
+                "5mC_top_strand",
+                "combined_bottom_strand",
+                "combined_top_strand",
+            ]
             # Give indices of dictionaries to skip for analysis and final dictionary saving.
             dict_to_skip = [0, 1, 4]
             combined_dicts = [7, 8]
@@ -596,7 +756,14 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
             dict_to_skip = set(dict_to_skip)
 
             # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
-            dict_total = parallel_load_tsvs(
+            dict_total = parallel_load_tsvs(
+                tsv_batch,
+                records_to_analyze,
+                reference_dict,
+                batch,
+                batch_size=len(tsv_batch),
+                threads=threads,
+            )
 
             # # Step 2: Extract modification-specific data (per (record,sample)) in parallel
             # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
@@ -621,56 +788,112 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
             # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
             for record in dict_total.keys():
                 for sample_index in dict_total[record].keys():
-                    if
+                    if "6mA" in mods:
                         # Remove Adenine stranded dicts from the dicts to skip set
                         dict_to_skip.difference_update(set(A_stranded_dicts))
 
-                        if
+                        if (
+                            record not in dict_a.keys()
+                            and record not in dict_a_bottom.keys()
+                            and record not in dict_a_top.keys()
+                        ):
                             dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
 
                         # get a dictionary of dataframes that only contain methylated adenine positions
-                        dict_a[record][sample_index] = dict_total[record][sample_index][
-
+                        dict_a[record][sample_index] = dict_total[record][sample_index][
+                            dict_total[record][sample_index]["modified_primary_base"] == "A"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a methyl-adenine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
+
                         # Stratify the adenine dictionary into two strand specific dictionaries.
-                        dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
-
-
-
+                        dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
+                            dict_a[record][sample_index]["ref_strand"] == "-"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
+                        dict_a_top[record][sample_index] = dict_a[record][sample_index][
+                            dict_a[record][sample_index]["ref_strand"] == "+"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
+                                str(sample_index)
+                            )
+                        )
 
                         # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
                         dict_a[record][sample_index] = None
                         gc.collect()
 
-                    if
+                    if "5mC" in mods:
                         # Remove Cytosine stranded dicts from the dicts to skip set
                         dict_to_skip.difference_update(set(C_stranded_dicts))
 
-                        if
+                        if (
+                            record not in dict_c.keys()
+                            and record not in dict_c_bottom.keys()
+                            and record not in dict_c_top.keys()
+                        ):
                             dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
 
                         # get a dictionary of dataframes that only contain methylated cytosine positions
-                        dict_c[record][sample_index] = dict_total[record][sample_index][
-
+                        dict_c[record][sample_index] = dict_total[record][sample_index][
+                            dict_total[record][sample_index]["modified_primary_base"] == "C"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a methyl-cytosine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         # Stratify the cytosine dictionary into two strand specific dictionaries.
-                        dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
-
-
-
+                        dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
+                            dict_c[record][sample_index]["ref_strand"] == "-"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
+                        dict_c_top[record][sample_index] = dict_c[record][sample_index][
+                            dict_c[record][sample_index]["ref_strand"] == "+"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
                         dict_c[record][sample_index] = None
                         gc.collect()
-
-                    if
+
+                    if "6mA" in mods and "5mC" in mods:
                         # Remove combined stranded dicts from the dicts to skip set
-                        dict_to_skip.difference_update(set(combined_dicts))
+                        dict_to_skip.difference_update(set(combined_dicts))
                         # Initialize the sample keys for the combined dictionaries
 
-                        if
-
-
-
+                        if (
+                            record not in dict_combined_bottom.keys()
+                            and record not in dict_combined_top.keys()
+                        ):
+                            dict_combined_bottom[record], dict_combined_top[record] = {}, {}
+
+                        logger.debug(
+                            "Successfully created a minus strand combined methylation dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         dict_combined_bottom[record][sample_index] = []
-
+                        logger.debug(
+                            "Successfully created a plus strand combined methylation dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         dict_combined_top[record][sample_index] = []
 
                 # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
@@ -681,14 +904,24 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
             for dict_index, dict_type in enumerate(dict_list):
                 # Only iterate over stranded dictionaries
                 if dict_index not in dict_to_skip:
-
+                    logger.debug(
+                        "Extracting methylation states for {} dictionary".format(
+                            sample_types[dict_index]
+                        )
+                    )
                     for record in dict_type.keys():
                         # Get the dictionary for the modification type of interest from the reference mapping of interest
                         mod_strand_record_sample_dict = dict_type[record]
-
+                        logger.debug(
+                            "Extracting methylation states for {} dictionary".format(record)
+                        )
                         # For each sample in a stranded dictionary
                         n_samples = len(mod_strand_record_sample_dict.keys())
-                        for sample in tqdm(
+                        for sample in tqdm(
+                            mod_strand_record_sample_dict.keys(),
+                            desc=f"Extracting {sample_types[dict_index]} dictionary from record {record} for sample",
+                            total=n_samples,
+                        ):
                             # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
                             if dict_index == 7:
                                 # Load the minus strand dictionaries for each sample into temporary variables
@@ -699,16 +932,26 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                                 for read in set(temp_a_dict) | set(temp_c_dict):
                                     # Add the arrays element-wise if the read is present in both dictionaries
                                     if read in temp_a_dict and read in temp_c_dict:
-                                        mod_strand_record_sample_dict[sample][read] = np.where(
+                                        mod_strand_record_sample_dict[sample][read] = np.where(
+                                            np.isnan(temp_a_dict[read])
+                                            & np.isnan(temp_c_dict[read]),
+                                            np.nan,
+                                            np.nan_to_num(temp_a_dict[read])
+                                            + np.nan_to_num(temp_c_dict[read]),
+                                        )
                                     # If the read is present in only one dictionary, copy its value
                                     elif read in temp_a_dict:
-                                        mod_strand_record_sample_dict[sample][read] = temp_a_dict[
+                                        mod_strand_record_sample_dict[sample][read] = temp_a_dict[
+                                            read
+                                        ]
                                     elif read in temp_c_dict:
-                                        mod_strand_record_sample_dict[sample][read] = temp_c_dict[
+                                        mod_strand_record_sample_dict[sample][read] = temp_c_dict[
+                                            read
+                                        ]
                                 del temp_a_dict, temp_c_dict
-
+                            # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
                             elif dict_index == 8:
-
+                                # Load the plus strand dictionaries for each sample into temporary variables
                                 temp_a_dict = dict_list[3][record][sample].copy()
                                 temp_c_dict = dict_list[6][record][sample].copy()
                                 mod_strand_record_sample_dict[sample] = {}
@@ -716,105 +959,163 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                  for read in set(temp_a_dict) | set(temp_c_dict):
                      # Add the arrays element-wise if the read is present in both dictionaries
                      if read in temp_a_dict and read in temp_c_dict:
-                         mod_strand_record_sample_dict[sample][read] = np.where(
+                         mod_strand_record_sample_dict[sample][read] = np.where(
+                             np.isnan(temp_a_dict[read])
+                             & np.isnan(temp_c_dict[read]),
+                             np.nan,
+                             np.nan_to_num(temp_a_dict[read])
+                             + np.nan_to_num(temp_c_dict[read]),
+                         )
                      # If the read is present in only one dictionary, copy its value
                      elif read in temp_a_dict:
-                         mod_strand_record_sample_dict[sample][read] = temp_a_dict[
+                         mod_strand_record_sample_dict[sample][read] = temp_a_dict[
+                             read
+                         ]
                      elif read in temp_c_dict:
-                         mod_strand_record_sample_dict[sample][read] = temp_c_dict[
+                         mod_strand_record_sample_dict[sample][read] = temp_c_dict[
+                             read
+                         ]
                  del temp_a_dict, temp_c_dict
              # For all other dictionaries
              else:
-
                  # use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
                  temp_df = mod_strand_record_sample_dict[sample]
                  # reassign the dictionary pointer to a nested dictionary.
                  mod_strand_record_sample_dict[sample] = {}

                  # Get relevant columns as NumPy arrays
-                 read_ids = temp_df[
-                 positions = temp_df[
-                 call_codes = temp_df[
-                 probabilities = temp_df[
+                 read_ids = temp_df["read_id"].values
+                 positions = temp_df["ref_position"].values
+                 call_codes = temp_df["call_code"].values
+                 probabilities = temp_df["call_prob"].values

                  # Define valid call code categories
-                 modified_codes = {
-                 canonical_codes = {
+                 modified_codes = {"a", "h", "m"}
+                 canonical_codes = {"-"}

                  # Vectorized methylation calculation with NaN for other codes
-                 methylation_prob = np.full_like(
-
-
+                 methylation_prob = np.full_like(
+                     probabilities, np.nan
+                 )  # Default all to NaN
+                 methylation_prob[np.isin(call_codes, list(modified_codes))] = (
+                     probabilities[np.isin(call_codes, list(modified_codes))]
+                 )
+                 methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+                     1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+                 )

                  # Find unique reads
                  unique_reads = np.unique(read_ids)
                  # Preallocate storage for each read
                  for read in unique_reads:
-                     mod_strand_record_sample_dict[sample][read] = np.full(
+                     mod_strand_record_sample_dict[sample][read] = np.full(
+                         max_reference_length, np.nan
+                     )

                  # Efficient NumPy indexing to assign values
                  for i in range(len(read_ids)):
                      read = read_ids[i]
                      pos = positions[i]
                      prob = methylation_prob[i]
-
+
                      # Assign methylation probability
                      mod_strand_record_sample_dict[sample][read][pos] = prob

-
      # Save the sample files in the batch as gzipped hdf5 files
-
+     logger.info("Converting batch {} dictionaries to anndata objects".format(batch))
      for dict_index, dict_type in enumerate(dict_list):
          if dict_index not in dict_to_skip:
              # Initialize an hdf5 file for the current modified strand
              adata = None
-
+             logger.info(
+                 "Converting {} dictionary to an anndata object".format(
+                     sample_types[dict_index]
+                 )
+             )
              for record in dict_type.keys():
                  # Get the dictionary for the modification type of interest from the reference mapping of interest
                  mod_strand_record_sample_dict = dict_type[record]
                  for sample in mod_strand_record_sample_dict.keys():
-
+                     logger.info(
+                         "Converting {0} dictionary for sample {1} to an anndata object".format(
+                             sample_types[dict_index], sample
+                         )
+                     )
                      sample = int(sample)
                      final_sample_index = sample + (batch * batch_size)
-
-
-
-
+                     logger.info(
+                         "Final sample index for sample: {}".format(final_sample_index)
+                     )
+                     logger.debug(
+                         "Converting {0} dictionary for sample {1} to a dataframe".format(
+                             sample_types[dict_index],
+                             final_sample_index,
+                         )
+                     )
+                     temp_df = pd.DataFrame.from_dict(
+                         mod_strand_record_sample_dict[sample], orient="index"
+                     )
+                     mod_strand_record_sample_dict[sample] = (
+                         None  # reassign pointer to facilitate memory usage
+                     )
                      sorted_index = sorted(temp_df.index)
                      temp_df = temp_df.reindex(sorted_index)
                      X = temp_df.values
-                     dataset, strand = sample_types[dict_index].split(
-
-
+                     dataset, strand = sample_types[dict_index].split("_")[:2]
+
+                     logger.info(
+                         "Loading {0} dataframe for sample {1} into a temp anndata object".format(
+                             sample_types[dict_index],
+                             final_sample_index,
+                         )
+                     )
                      temp_adata = ad.AnnData(X)
                      if temp_adata.shape[0] > 0:
-
+                         logger.info(
+                             "Adding read names and position ids to {0} anndata for sample {1}".format(
+                                 sample_types[dict_index],
+                                 final_sample_index,
+                             )
+                         )
                          temp_adata.obs_names = temp_df.index
                          temp_adata.obs_names = temp_adata.obs_names.astype(str)
                          temp_adata.var_names = temp_df.columns
                          temp_adata.var_names = temp_adata.var_names.astype(str)
-
-
-
-
-
-
-                         temp_adata.obs[
-
-
+                         logger.info(
+                             "Adding {0} anndata for sample {1}".format(
+                                 sample_types[dict_index],
+                                 final_sample_index,
+                             )
+                         )
+                         temp_adata.obs["Sample"] = [
+                             sample_name_map[final_sample_index]
+                         ] * len(temp_adata)
+                         temp_adata.obs["Barcode"] = [barcode_map[final_sample_index]] * len(
+                             temp_adata
+                         )
+                         temp_adata.obs["Reference"] = [f"{record}"] * len(temp_adata)
+                         temp_adata.obs["Strand"] = [strand] * len(temp_adata)
+                         temp_adata.obs["Dataset"] = [dataset] * len(temp_adata)
+                         temp_adata.obs["Reference_dataset_strand"] = [
+                             f"{record}_{dataset}_{strand}"
+                         ] * len(temp_adata)
+                         temp_adata.obs["Reference_strand"] = [f"{record}_{strand}"] * len(
+                             temp_adata
+                         )
+
                          # Load in the one hot encoded reads from the current sample and record
                          one_hot_reads = {}
                          n_rows_OHE = 5
-                         ohe_files = bam_record_ohe_files[f
-
+                         ohe_files = bam_record_ohe_files[f"{final_sample_index}_{record}"]
+                         logger.info(f"Loading OHEs from {ohe_files}")
                          fwd_mapped_reads = set()
                          rev_mapped_reads = set()
                          for ohe_file in ohe_files:
                              tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
                              one_hot_reads.update(tmp_ohe_dict)
-                             if
+                             if "_fwd_" in ohe_file:
                                  fwd_mapped_reads.update(tmp_ohe_dict.keys())
-                             elif
+                             elif "_rev_" in ohe_file:
                                  rev_mapped_reads.update(tmp_ohe_dict.keys())
                              del tmp_ohe_dict

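The largest change in this hunk swaps row-wise handling for NumPy masking: modkit-style call codes are mapped to a methylation probability in one pass (modified calls keep their probability, canonical calls are flipped to 1 - p, everything else stays NaN), then scattered into fixed-length per-read vectors. A hedged sketch with toy data; the column names and code sets mirror the diff, while max_reference_length is an assumed constant:

import numpy as np
import pandas as pd

max_reference_length = 10  # assumed length of the toy reference

calls = pd.DataFrame({
    "read_id": ["r1", "r1", "r2"],
    "ref_position": [2, 5, 3],
    "call_code": ["m", "-", "a"],
    "call_prob": [0.95, 0.80, 0.60],
})

codes = calls["call_code"].values
probs = calls["call_prob"].values

methylation_prob = np.full_like(probs, np.nan, dtype=float)  # default: uninformative
mod_mask = np.isin(codes, ["a", "h", "m"])   # modified-base calls keep their probability
can_mask = np.isin(codes, ["-"])             # canonical calls become 1 - probability
methylation_prob[mod_mask] = probs[mod_mask]
methylation_prob[can_mask] = 1 - probs[can_mask]

# One NaN-initialised vector per read, filled by reference position.
per_read = {r: np.full(max_reference_length, np.nan) for r in np.unique(calls["read_id"])}
for read, pos, p in zip(calls["read_id"], calls["ref_position"], methylation_prob):
    per_read[read][pos] = p

print(per_read["r1"])  # NaN everywhere except positions 2 (0.95) and 5 (0.20)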
@@ -823,18 +1124,20 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                          read_mapping_direction = []
                          for read_id in temp_adata.obs_names:
                              if read_id in fwd_mapped_reads:
-                                 read_mapping_direction.append(
+                                 read_mapping_direction.append("fwd")
                              elif read_id in rev_mapped_reads:
-                                 read_mapping_direction.append(
+                                 read_mapping_direction.append("rev")
                              else:
-                                 read_mapping_direction.append(
+                                 read_mapping_direction.append("unk")

-                         temp_adata.obs[
+                         temp_adata.obs["Read_mapping_direction"] = read_mapping_direction

                          del temp_df
-
+
                          # Initialize NumPy arrays
-                         sequence_length =
+                         sequence_length = (
+                             one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+                         )
                          df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
                          df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
                          df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
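The reference length here is recovered by reshaping one of the flattened one-hot read encodings back into its five base rows (A, C, G, T, N). A toy sketch of that layout; the encode_read helper is illustrative only and not part of the package:

import numpy as np

BASES = "ACGTN"
n_rows_OHE = 5

def encode_read(seq):
    # Illustrative helper: one-hot encode a read as a flattened 5 x L array.
    ohe = np.zeros((n_rows_OHE, len(seq)), dtype=int)
    for pos, base in enumerate(seq):
        ohe[BASES.index(base), pos] = 1
    return ohe.flatten()

flat = encode_read("ACGTNACG")
# Same move as the diff: reshape to 5 rows and read off the number of columns.
sequence_length = flat.reshape(n_rows_OHE, -1).shape[1]
print(sequence_length)  # 8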
@@ -855,7 +1158,11 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                          gc.collect()

                          # Fill the arrays
-                         for j, read_name in tqdm(
+                         for j, read_name in tqdm(
+                             enumerate(sorted_index),
+                             desc="Loading dataframes of OHE reads",
+                             total=len(sorted_index),
+                         ):
                              df_A[j, :] = dict_A[read_name]
                              df_C[j, :] = dict_C[read_name]
                              df_G[j, :] = dict_G[read_name]
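Those encodings are then unpacked into one binary matrix per base (reads by positions), filled row by row in read-sorted order under a tqdm progress bar. A small sketch with toy per-read dictionaries standing in for the dict_A/dict_C built from the one-hot files:

import numpy as np
from tqdm import tqdm

sequence_length = 4
# Toy per-base dictionaries: read name -> binary vector over reference positions.
dict_A = {"read_1": np.array([1, 0, 0, 0]), "read_2": np.array([0, 0, 1, 0])}
dict_C = {"read_1": np.array([0, 1, 0, 0]), "read_2": np.array([0, 0, 0, 1])}

sorted_index = sorted(dict_A)  # deterministic read order, matching the obs index
df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)

# One row per read, in the same order used later for the AnnData observations.
for j, read_name in tqdm(enumerate(sorted_index), total=len(sorted_index)):
    df_A[j, :] = dict_A[read_name]
    df_C[j, :] = dict_C[read_name]

print(df_A)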
@@ -867,43 +1174,78 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp

                          # Store the results in AnnData layers
                          ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
-                         for j, base in enumerate([
-                             temp_adata.layers[f
-
-
-
+                         for j, base in enumerate(["A", "C", "G", "T", "N"]):
+                             temp_adata.layers[f"{base}_binary_sequence_encoding"] = (
+                                 ohe_df_map[j]
+                             )
+                             ohe_df_map[j] = (
+                                 None  # Reassign pointer for memory usage purposes
+                             )
+
+                         # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
                          if adata:
                              if temp_adata.shape[0] > 0:
-
-
+                                 logger.info(
+                                     "Concatenating {0} anndata object for sample {1}".format(
+                                         sample_types[dict_index],
+                                         final_sample_index,
+                                     )
+                                 )
+                                 adata = ad.concat(
+                                     [adata, temp_adata], join="outer", index_unique=None
+                                 )
                                  del temp_adata
                              else:
-
+                                 logger.warning(
+                                     f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
+                                 )
                          else:
                              if temp_adata.shape[0] > 0:
-
+                                 logger.info(
+                                     "Initializing {0} anndata object for sample {1}".format(
+                                         sample_types[dict_index],
+                                         final_sample_index,
+                                     )
+                                 )
                                  adata = temp_adata
                              else:
-
+                                 logger.warning(
+                                     f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
+                                 )

                          gc.collect()
                      else:
-
+                         logger.warning(
+                             f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample."
+                         )

              try:
-
-
-
-
-
-
-
+                 logger.info(
+                     "Writing {0} anndata out as a hdf5 file".format(
+                         sample_types[dict_index]
+                     )
+                 )
+                 adata.write_h5ad(
+                     h5_dir
+                     / "{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz".format(
+                         readwrite.date_string(), batch, sample_types[dict_index]
+                     ),
+                     compression="gzip",
+                 )
+             except Exception:
+                 logger.debug("Skipping writing anndata for sample")
+
+             try:
+                 # Delete the batch dictionaries from memory
+                 del dict_list, adata
+             except Exception:
+                 pass
              gc.collect()

  # Iterate over all of the batched hdf5 files and concatenate them.
- files = h5_dir.iterdir()
+ files = h5_dir.iterdir()
  # Filter file names that contain the search string in their filename and keep them in a list
- hdfs = [hdf for hdf in files if
+ hdfs = [hdf for hdf in files if "hdf5.h5ad" in hdf.name and hdf != final_hdf]
  combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf.name]
  if len(combined_hdfs) > 0:
      hdfs = combined_hdfs
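Per-sample matrices become AnnData objects that are concatenated with an outer join (so samples mapped to different references keep the union of positions) and written per batch as gzip-compressed .h5ad files. A minimal sketch against anndata's public API; the matrices, obs values, and output name are placeholders:

import numpy as np
import anndata as ad

a = ad.AnnData(np.random.rand(3, 5))
a.obs_names = [f"read_{i}" for i in range(3)]
a.obs["Sample"] = ["sample_1"] * 3

b = ad.AnnData(np.random.rand(2, 5))
b.obs_names = [f"read_{i}" for i in range(3, 5)]
b.obs["Sample"] = ["sample_2"] * 2

# join="outer" keeps the union of variables; index_unique=None leaves obs names untouched.
combined = ad.concat([a, b], join="outer", index_unique=None)

# Gzip compression keeps per-batch intermediates small at some write-time cost.
combined.write_h5ad("example_batch.h5ad.gz", compression="gzip")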
@@ -911,55 +1253,62 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
      pass
  # Sort file list by names and print the list of file names
  hdfs.sort()
-
- hdf_paths = [
+ logger.info("{0} sample files found: {1}".format(len(hdfs), hdfs))
+ hdf_paths = [hd5 for hd5 in hdfs]
  final_adata = None
  for hdf_index, hdf in enumerate(hdf_paths):
-
+     logger.info("Reading in {} hdf5 file".format(hdfs[hdf_index]))
      temp_adata = ad.read_h5ad(hdf)
      if final_adata:
-
-
+         logger.info(
+             "Concatenating final adata object with {} hdf5 file".format(hdfs[hdf_index])
+         )
+         final_adata = ad.concat([final_adata, temp_adata], join="outer", index_unique=None)
      else:
-
+         logger.info("Initializing final adata object with {} hdf5 file".format(hdfs[hdf_index]))
          final_adata = temp_adata
      del temp_adata

  # Set obs columns to type 'category'
  for col in final_adata.obs.columns:
-     final_adata.obs[col] = final_adata.obs[col].astype(
+     final_adata.obs[col] = final_adata.obs[col].astype("category")

- ohe_bases = [
- ohe_layers = [f"{ohe_base}
- final_adata.uns[
+ ohe_bases = ["A", "C", "G", "T"]  # ignore N bases for consensus
+ ohe_layers = [f"{ohe_base}_binary_sequence_encoding" for ohe_base in ohe_bases]
+ final_adata.uns["References"] = {}
  for record in records_to_analyze:
      # Add FASTA sequence to the object
      sequence = record_seq_dict[record][0]
      complement = record_seq_dict[record][1]
-     final_adata.var[f
-     final_adata.var[f
-     final_adata.uns[f
-     final_adata.uns[
+     final_adata.var[f"{record}_top_strand_FASTA_base"] = list(sequence)
+     final_adata.var[f"{record}_bottom_strand_FASTA_base"] = list(complement)
+     final_adata.uns[f"{record}_FASTA_sequence"] = sequence
+     final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
      # Add consensus sequence of samples mapped to the record to the object
-     record_subset = final_adata[final_adata.obs[
-     for strand in record_subset.obs[
-         strand_subset = record_subset[record_subset.obs[
-         for mapping_dir in strand_subset.obs[
-             mapping_dir_subset = strand_subset[
+     record_subset = final_adata[final_adata.obs["Reference"] == record]
+     for strand in record_subset.obs["Strand"].cat.categories:
+         strand_subset = record_subset[record_subset.obs["Strand"] == strand]
+         for mapping_dir in strand_subset.obs["Read_mapping_direction"].cat.categories:
+             mapping_dir_subset = strand_subset[
+                 strand_subset.obs["Read_mapping_direction"] == mapping_dir
+             ]
              layer_map, layer_counts = {}, []
              for i, layer in enumerate(ohe_layers):
-                 layer_map[i] = layer.split(
+                 layer_map[i] = layer.split("_")[0]
                  layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
              count_array = np.array(layer_counts)
              nucleotide_indexes = np.argmax(count_array, axis=0)
              consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
-             final_adata.var[
+             final_adata.var[
+                 f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
+             ] = consensus_sequence_list

  if input_already_demuxed:
      final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
      final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
  else:
      from .h5ad_functions import add_demux_type_annotation
+
      double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
      add_demux_type_annotation(final_adata, double_barcoded_reads)

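The consensus sequence per reference, strand, and mapping direction comes from summing each base's binary layer over reads and taking, at every position, the base with the highest count. A self-contained sketch of that argmax; the layer names follow the *_binary_sequence_encoding convention from the diff, and the arrays are toy data:

import numpy as np

# Toy binary layers: reads x positions, one matrix per base.
layers = {
    "A_binary_sequence_encoding": np.array([[1, 0, 0], [1, 0, 0]]),
    "C_binary_sequence_encoding": np.array([[0, 1, 0], [0, 1, 0]]),
    "G_binary_sequence_encoding": np.array([[0, 0, 1], [0, 0, 1]]),
    "T_binary_sequence_encoding": np.array([[0, 0, 0], [0, 0, 0]]),
}

layer_map, layer_counts = {}, []
for i, layer in enumerate(layers):
    layer_map[i] = layer.split("_")[0]              # "A", "C", "G", "T"
    layer_counts.append(layers[layer].sum(axis=0))  # per-position count of that base

count_array = np.array(layer_counts)                 # bases x positions
nucleotide_indexes = np.argmax(count_array, axis=0)  # winning base index per position
consensus = [layer_map[i] for i in nucleotide_indexes]
print("".join(consensus))  # ACG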
@@ -967,4 +1316,4 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
  if delete_batch_hdfs:
      delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)

- return final_adata, final_adata_path
+ return final_adata, final_adata_path