smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/modkit_extract_to_adata.py
@@ -1,64 +1,81 @@
 import concurrent.futures
 import gc
+import re
+import shutil
+from pathlib import Path
+from typing import Iterable, Optional, Union
+
+import numpy as np
 import pandas as pd
 from tqdm import tqdm
+
+from smftools.logging_utils import get_logger
+
+from .bam_functions import count_aligned_reads
+
+logger = get_logger(__name__)
+
 
 def filter_bam_records(bam, mapping_threshold):
     """Processes a single BAM file, counts reads, and determines records to analyze."""
     aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)
+
     total_reads = aligned_reads_count + unaligned_reads_count
     percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
+    logger.info(f"{percent_aligned:.2f}% of reads in {bam} aligned successfully")
 
     records = []
     for record, (count, percentage) in record_counts_dict.items():
+        logger.info(
+            f"{count} reads mapped to reference {record}. This is {percentage * 100:.2f}% of all mapped reads in {bam}"
+        )
         if percentage >= mapping_threshold:
             records.append(record)
+
     return set(records)
 
+
 def parallel_filter_bams(bam_path_list, mapping_threshold):
     """Parallel processing for multiple BAM files."""
     records_to_analyze = set()
 
     with concurrent.futures.ProcessPoolExecutor() as executor:
+        results = executor.map(
+            filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list)
+        )
 
         # Aggregate results
         for result in results:
             records_to_analyze.update(result)
 
+    logger.info(f"Records to analyze: {records_to_analyze}")
     return records_to_analyze
 
+
 def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
     """
     Loads and filters a single TSV file based on chromosome and position criteria.
     """
+    temp_df = pd.read_csv(tsv, sep="\t", header=0)
     filtered_records = {}
 
     for record in records_to_analyze:
         if record not in reference_dict:
             continue
+
         ref_length = reference_dict[record][0]
+        filtered_df = temp_df[
+            (temp_df["chrom"] == record)
+            & (temp_df["ref_position"] >= 0)
+            & (temp_df["ref_position"] < ref_length)
+        ]
 
         if not filtered_df.empty:
             filtered_records[record] = {sample_index: filtered_df}
 
     return filtered_records
 
+
 def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
     """
     Loads and filters TSV files in parallel.
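The hunk above replaces print-based status messages with a module-level logger created by the new smftools/logging_utils.py helper (a +51-line file in this release). That helper's implementation is not shown in this diff; the following is only a minimal sketch of what a get_logger of this shape typically looks like, under that assumption, not the released code.

import logging

def get_logger(name: str, level: int = logging.INFO) -> logging.Logger:
    # Hypothetical sketch: return a namespaced logger with one stream handler,
    # attaching the handler only once so repeated imports do not duplicate output.
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(level)
    return logger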
@@ -78,52 +95,60 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
 
     with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
         futures = {
+            executor.submit(
+                process_tsv, tsv, records_to_analyze, reference_dict, sample_index
+            ): sample_index
             for sample_index, tsv in enumerate(tsv_batch)
         }
 
+        for future in tqdm(
+            concurrent.futures.as_completed(futures),
+            desc=f"Processing batch {batch}",
+            total=batch_size,
+        ):
             result = future.result()
             for record, sample_data in result.items():
                 dict_total[record].update(sample_data)
 
     return dict_total
 
+
 def update_dict_to_skip(dict_to_skip, detected_modifications):
     """
     Updates the dict_to_skip set based on the detected modifications.
+
     Parameters:
         dict_to_skip (set): The initial set of dictionary indices to skip.
         detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.
+
     Returns:
         set: The updated dict_to_skip set.
     """
     # Define which indices correspond to modification-specific or strand-specific dictionaries
-    A_stranded_dicts = {2, 3}
-    C_stranded_dicts = {5, 6}
+    A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
+    C_stranded_dicts = {5, 6}  # 5mC bottom and top strand dictionaries
+    combined_dicts = {7, 8}  # Combined strand dictionaries
 
     # If '6mA' is present, remove the A_stranded indices from the skip set
+    if "6mA" in detected_modifications:
         dict_to_skip -= A_stranded_dicts
     # If '5mC' is present, remove the C_stranded indices from the skip set
+    if "5mC" in detected_modifications:
         dict_to_skip -= C_stranded_dicts
     # If both modifications are present, remove the combined indices from the skip set
+    if "6mA" in detected_modifications and "5mC" in detected_modifications:
         dict_to_skip -= combined_dicts
 
     return dict_to_skip
 
+
 def process_modifications_for_sample(args):
     """
     Processes a single (record, sample) pair to extract modification-specific data.
+
     Parameters:
         args: (record, sample_index, sample_df, mods, max_reference_length)
+
     Returns:
         (record, sample_index, result) where result is a dict with keys:
         'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
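For clarity, a worked example of how update_dict_to_skip behaves with the indices listed above (toy inputs, not package data):

dict_to_skip = {2, 3, 5, 6, 7, 8}  # skip every stranded/combined dictionary by default

# Only 6mA detected: the m6A strand dictionaries (indices 2 and 3) are kept.
print(update_dict_to_skip(set(dict_to_skip), ["6mA"]))         # {5, 6, 7, 8}

# Both 6mA and 5mC detected: all six stranded/combined indices drop out of the skip set.
print(update_dict_to_skip(set(dict_to_skip), ["6mA", "5mC"]))  # set()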
@@ -131,29 +156,30 @@ def process_modifications_for_sample(args):
     """
     record, sample_index, sample_df, mods, max_reference_length = args
     result = {}
+    if "6mA" in mods:
+        m6a_df = sample_df[sample_df["modified_primary_base"] == "A"]
+        result["m6A"] = m6a_df
+        result["m6A_minus"] = m6a_df[m6a_df["ref_strand"] == "-"]
+        result["m6A_plus"] = m6a_df[m6a_df["ref_strand"] == "+"]
         m6a_df = None
         gc.collect()
+    if "5mC" in mods:
+        m5c_df = sample_df[sample_df["modified_primary_base"] == "C"]
+        result["5mC"] = m5c_df
+        result["5mC_minus"] = m5c_df[m5c_df["ref_strand"] == "-"]
+        result["5mC_plus"] = m5c_df[m5c_df["ref_strand"] == "+"]
         m5c_df = None
         gc.collect()
+    if "6mA" in mods and "5mC" in mods:
+        result["combined_minus"] = []
+        result["combined_plus"] = []
     return record, sample_index, result
 
+
 def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
     """
     Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.
+
     Returns:
         processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
     """
@@ -164,18 +190,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
     processed_results = {}
     with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
         for record, sample_index, result in tqdm(
+            executor.map(process_modifications_for_sample, tasks),
+            total=len(tasks),
+            desc="Processing modifications",
+        ):
             if record not in processed_results:
                 processed_results[record] = {}
             processed_results[record][sample_index] = result
     return processed_results
 
+
 def merge_modification_results(processed_results, mods):
     """
     Merges individual sample results into global dictionaries.
+
     Returns:
         A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
     """
@@ -189,44 +217,52 @@ def merge_modification_results(processed_results, mods):
     combined_plus = {}
     for record, sample_results in processed_results.items():
         for sample_index, res in sample_results.items():
+            if "6mA" in mods:
                 if record not in m6A_dict:
                     m6A_dict[record], m6A_minus[record], m6A_plus[record] = {}, {}, {}
+                m6A_dict[record][sample_index] = res.get("m6A", pd.DataFrame())
+                m6A_minus[record][sample_index] = res.get("m6A_minus", pd.DataFrame())
+                m6A_plus[record][sample_index] = res.get("m6A_plus", pd.DataFrame())
+            if "5mC" in mods:
                 if record not in c5m_dict:
                     c5m_dict[record], c5m_minus[record], c5m_plus[record] = {}, {}, {}
+                c5m_dict[record][sample_index] = res.get("5mC", pd.DataFrame())
+                c5m_minus[record][sample_index] = res.get("5mC_minus", pd.DataFrame())
+                c5m_plus[record][sample_index] = res.get("5mC_plus", pd.DataFrame())
+            if "6mA" in mods and "5mC" in mods:
                 if record not in combined_minus:
                     combined_minus[record], combined_plus[record] = {}, {}
+                combined_minus[record][sample_index] = res.get("combined_minus", [])
+                combined_plus[record][sample_index] = res.get("combined_plus", [])
+    return (
+        m6A_dict,
+        m6A_minus,
+        m6A_plus,
+        c5m_dict,
+        c5m_minus,
+        c5m_plus,
+        combined_minus,
+        combined_plus,
+    )
+
 
 def process_stranded_methylation(args):
     """
     Processes a single (dict_index, record, sample) task.
+
     For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
+    For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
     NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
+
     Parameters:
         args: (dict_index, record, sample, dict_list, max_reference_length)
+
     Returns:
         (dict_index, record, sample, processed_data)
     """
     dict_index, record, sample, dict_list, max_reference_length = args
     processed_data = {}
+
     # For combined bottom strand (index 7)
     if dict_index == 7:
         temp_a = dict_list[2][record].get(sample, {}).copy()
@@ -235,18 +271,18 @@ def process_stranded_methylation(args):
         for read in set(temp_a.keys()) | set(temp_c.keys()):
             if read in temp_a:
                 # Convert using pd.to_numeric with errors='coerce'
+                value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
             else:
                 value_a = None
             if read in temp_c:
+                value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
             else:
                 value_c = None
             if value_a is not None and value_c is not None:
                 processed_data[read] = np.where(
                     np.isnan(value_a) & np.isnan(value_c),
                     np.nan,
-                    np.nan_to_num(value_a) + np.nan_to_num(value_c)
+                    np.nan_to_num(value_a) + np.nan_to_num(value_c),
                 )
             elif value_a is not None:
                 processed_data[read] = value_a
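The merge above keeps a position NaN only when it is NaN in both strand-specific arrays, and otherwise adds the two NaN-padded arrays with NaN treated as zero. A small self-contained illustration of that np.where / nan_to_num pattern (toy arrays, not package data):

import numpy as np

value_a = np.array([0.9, np.nan, np.nan, 0.2])
value_c = np.array([np.nan, 0.7, np.nan, 0.1])

combined = np.where(
    np.isnan(value_a) & np.isnan(value_c),            # stay NaN only where both inputs are NaN
    np.nan,
    np.nan_to_num(value_a) + np.nan_to_num(value_c),  # otherwise add, treating NaN as 0
)
print(combined)  # [0.9 0.7 nan 0.3]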
@@ -261,18 +297,18 @@ def process_stranded_methylation(args):
         processed_data = {}
         for read in set(temp_a.keys()) | set(temp_c.keys()):
             if read in temp_a:
+                value_a = pd.to_numeric(np.array(temp_a[read]), errors="coerce")
             else:
                 value_a = None
             if read in temp_c:
+                value_c = pd.to_numeric(np.array(temp_c[read]), errors="coerce")
             else:
                 value_c = None
             if value_a is not None and value_c is not None:
                 processed_data[read] = np.where(
                     np.isnan(value_a) & np.isnan(value_c),
                     np.nan,
-                    np.nan_to_num(value_a) + np.nan_to_num(value_c)
+                    np.nan_to_num(value_a) + np.nan_to_num(value_c),
                 )
             elif value_a is not None:
                 processed_data[read] = value_a
@@ -286,24 +322,28 @@ def process_stranded_methylation(args):
         temp_df = dict_list[dict_index][record][sample]
         processed_data = {}
         # Extract columns and convert probabilities to float (coercing errors)
+        read_ids = temp_df["read_id"].values
+        positions = temp_df["ref_position"].values
+        call_codes = temp_df["call_code"].values
+        probabilities = pd.to_numeric(temp_df["call_prob"].values, errors="coerce")
+
+        modified_codes = {"a", "h", "m"}
+        canonical_codes = {"-"}
+
         # Compute methylation probabilities (vectorized)
         methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
+        methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[
+            np.isin(call_codes, list(modified_codes))
+        ]
+        methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+            1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+        )
+
         # Preallocate storage for each unique read
         unique_reads = np.unique(read_ids)
         for read in unique_reads:
             processed_data[read] = np.full(max_reference_length, np.nan, dtype=float)
+
         # Assign values efficiently
         for i in range(len(read_ids)):
             read = read_ids[i]
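The vectorized block above maps modkit call codes to per-call methylation probabilities: modified calls ("a", "h", "m") keep call_prob, canonical calls ("-") use 1 - call_prob, and anything else stays NaN. A compact illustration with made-up values (not package data):

import numpy as np

call_codes = np.array(["m", "-", "a", "?"])
probabilities = np.array([0.95, 0.90, 0.80, 0.60])

modified_codes = {"a", "h", "m"}
canonical_codes = {"-"}

methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
is_mod = np.isin(call_codes, list(modified_codes))
is_canon = np.isin(call_codes, list(canonical_codes))
methylation_prob[is_mod] = probabilities[is_mod]            # keep P(modified)
methylation_prob[is_canon] = 1 - probabilities[is_canon]    # convert P(canonical) to P(modified)
print(methylation_prob)  # [0.95 0.1  0.8   nan]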
@@ -314,10 +354,11 @@ def process_stranded_methylation(args):
     gc.collect()
     return dict_index, record, sample, processed_data
 
+
 def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
     """
     Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.
+
     Returns:
         Updated dict_list with processed (nested) dictionaries.
     """
@@ -327,16 +368,17 @@ def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
         for record in current_dict.keys():
             for sample in current_dict[record].keys():
                 tasks.append((dict_index, record, sample, dict_list, max_reference_length))
+
     with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
         for dict_index, record, sample, processed_data in tqdm(
             executor.map(process_stranded_methylation, tasks),
             total=len(tasks),
-            desc="Extracting stranded methylation states"
+            desc="Extracting stranded methylation states",
         ):
             dict_list[dict_index][record][sample] = processed_data
     return dict_list
 
+
 def delete_intermediate_h5ads_and_tmpdir(
     h5_dir: Union[str, Path, Iterable[str], None],
     tmp_dir: Optional[Union[str, Path]] = None,
@@ -360,25 +402,27 @@ def delete_intermediate_h5ads_and_tmpdir(
     verbose : bool
         Print progress / warnings.
     """
+
     # Helper: remove a single file path (Path-like or string)
     def _maybe_unlink(p: Path):
+        """Remove a file path if it exists and is a file."""
        if not p.exists():
             if verbose:
+                logger.debug(f"[skip] not found: {p}")
             return
         if not p.is_file():
             if verbose:
+                logger.debug(f"[skip] not a file: {p}")
             return
         if dry_run:
+            logger.debug(f"[dry-run] would remove file: {p}")
             return
         try:
             p.unlink()
             if verbose:
+                logger.info(f"Removed file: {p}")
         except Exception as e:
+            logger.warning(f"[error] failed to remove file {p}: {e}")
 
     # Handle h5_dir input (directory OR iterable of file paths)
     if h5_dir is not None:
@@ -393,7 +437,7 @@ def delete_intermediate_h5ads_and_tmpdir(
                 else:
                     if verbose:
                         # optional: comment this out if too noisy
+                        logger.debug(f"[skip] not matching pattern: {p.name}")
         else:
             # treat as iterable of file paths
             for f in h5_dir:
@@ -403,30 +447,44 @@ def delete_intermediate_h5ads_and_tmpdir(
                     _maybe_unlink(p)
                 else:
                     if verbose:
+                        logger.debug(f"[skip] not matching pattern or not a file: {p}")
 
     # Remove tmp_dir recursively (if provided)
     if tmp_dir is not None:
         td = Path(tmp_dir)
         if not td.exists():
             if verbose:
+                logger.debug(f"[skip] tmp_dir not found: {td}")
         else:
             if not td.is_dir():
                 if verbose:
+                    logger.debug(f"[skip] tmp_dir is not a directory: {td}")
             else:
                 if dry_run:
+                    logger.debug(f"[dry-run] would remove directory tree: {td}")
                 else:
                     try:
                         shutil.rmtree(td)
                         if verbose:
+                            logger.info(f"Removed directory tree: {td}")
                     except Exception as e:
+                        logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
+
+
+def modkit_extract_to_adata(
+    fasta,
+    bam_dir,
+    out_dir,
+    input_already_demuxed,
+    mapping_threshold,
+    experiment_name,
+    mods,
+    batch_size,
+    mod_tsv_dir,
+    delete_batch_hdfs=False,
+    threads=None,
+    double_barcoded_path=None,
+):
     """
     Takes modkit extract outputs and organizes it into an adata object
 
@@ -448,50 +506,87 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
     """
     ###################################################
     # Package imports
-    from .. import readwrite
-    from ..readwrite import safe_write_h5ad, make_dirs
-    from .fasta_functions import get_native_references
-    from .bam_functions import extract_base_identities
-    from .ohe import ohe_batching
-    import pandas as pd
-    import anndata as ad
-    import os
     import gc
     import math
+
+    import anndata as ad
     import numpy as np
+    import pandas as pd
     from Bio.Seq import Seq
     from tqdm import tqdm
+
+    from .. import readwrite
+    from ..readwrite import make_dirs
+    from .bam_functions import extract_base_identities
+    from .fasta_functions import get_native_references
+    from .ohe import ohe_batching
     ###################################################
 
     ################## Get input tsv and bam file names into a sorted list ################
     # Make output dirs
+    h5_dir = out_dir / "h5ads"
+    tmp_dir = out_dir / "tmp"
     make_dirs([h5_dir, tmp_dir])
 
+    existing_h5s = h5_dir.iterdir()
+    existing_h5s = [h5 for h5 in existing_h5s if ".h5ad.gz" in str(h5)]
+    final_hdf = f"{experiment_name}.h5ad.gz"
     final_adata_path = h5_dir / final_hdf
     final_adata = None
+
     if final_adata_path.exists():
+        logger.debug(f"{final_adata_path} already exists. Using existing adata")
         return final_adata, final_adata_path
+
     # List all files in the directory
     tsvs = sorted(
+        p
+        for p in mod_tsv_dir.iterdir()
+        if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
+    )
     bams = sorted(
+        p
+        for p in bam_dir.iterdir()
+        if p.is_file()
+        and p.suffix == ".bam"
+        and "unclassified" not in p.name
+        and ".bai" not in p.name
+    )
+
+    tsv_path_list = [tsv for tsv in tsvs]
+    bam_path_list = [bam for bam in bams]
+    logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
+    logger.info(f"{len(bams)} sample bams found: {bams}")
+
+    # Map global sample index (bami / final_sample_index) -> sample name / barcode
+    sample_name_map = {}
+    barcode_map = {}
+
+    for idx, bam_path in enumerate(bam_path_list):
+        stem = bam_path.stem
+
+        # Try to peel off a "barcode..." suffix if present.
+        # This handles things like:
+        # "mySample_barcode01" -> sample="mySample", barcode="barcode01"
+        # "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
+        # "barcode01" -> sample="barcode01", barcode="barcode01"
+        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
+        if m:
+            sample_name = m.group(1) or stem
+            barcode = m.group(2)
+        else:
+            # Fallback: treat the whole stem as both sample & barcode
+            sample_name = stem
+            barcode = stem
+
+        # make sample name of the format of the bam file stem
+        sample_name = sample_name + f"_{barcode}"
+
+        # Clean the barcode name to be an integer
+        barcode = int(barcode.split("barcode")[1])
 
-    print(f'{len(tsvs)} sample tsv files found: {tsvs}')
-    print(f'{len(bams)} sample bams found: {bams}')
+        sample_name_map[idx] = sample_name
+        barcode_map[idx] = str(barcode)
     ##########################################################################################
 
     ######### Get Record names that have over a passed threshold of mapped reads #############
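The new sample/barcode mapping relies on the regex shown above to split a BAM file stem into a sample prefix and a barcodeNN suffix before numbering the barcode. A standalone check of that parsing on the example stems from the code comments (illustrative script only, not part of the package):

import re

for stem in ["mySample_barcode01", "run1-s1_barcode05", "barcode01"]:
    m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
    sample_name = (m.group(1) or stem) if m else stem
    barcode = m.group(2) if m else stem
    sample_name = f"{sample_name}_{barcode}"        # e.g. "mySample_barcode01"
    barcode_num = int(barcode.split("barcode")[1])  # e.g. 1
    print(stem, "->", sample_name, barcode_num)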
@@ -503,27 +598,29 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
     ########### Determine the maximum record length to analyze in the dataset ################
     # Get all references within the FASTA and indicate the length and identity of the record sequence
     max_reference_length = 0
+    reference_dict = get_native_references(
+        str(fasta)
+    )  # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
     # Get the max record length in the dataset.
     for record in records_to_analyze:
         if reference_dict[record][0] > max_reference_length:
             max_reference_length = reference_dict[record][0]
-    batches = math.ceil(len(tsvs) / batch_size)
+    logger.info(f"Max reference length in dataset: {max_reference_length}")
+    batches = math.ceil(len(tsvs) / batch_size)  # Number of batches to process
+    logger.info("Processing input tsvs in {0} batches of {1} tsvs ".format(batches, batch_size))
     ##########################################################################################
 
     ##########################################################################################
+    # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
     # Save the file paths in the bam_record_ohe_files dict.
     bam_record_ohe_files = {}
+    bam_record_save = tmp_dir / "tmp_file_dict.h5ad"
     fwd_mapped_reads = set()
     rev_mapped_reads = set()
     # If this step has already been performed, read in the tmp_dile_dict
     if bam_record_save.exists():
         bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
+        logger.debug("Found existing OHE reads, using these")
     else:
         # Iterate over split bams
         for bami, bam in enumerate(bam_path_list):
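The batch bookkeeping above is a plain ceil split of the TSV list: for example, 10 TSVs with batch_size=4 give math.ceil(10/4) = 3 batches of sizes 4, 4, and 2, with the final batch taking whatever remains. A tiny sketch of that slicing logic, mirroring the batch loop later in this file (toy file names only):

import math

tsv_path_list = [f"sample_{i}.extract.tsv" for i in range(10)]  # toy names
batch_size = 4
batches = math.ceil(len(tsv_path_list) / batch_size)  # 3

for batch in range(batches):
    if batch == batches - 1:
        tsv_batch = tsv_path_list               # final batch: whatever is left
    else:
        tsv_batch = tsv_path_list[:batch_size]  # pop the next batch off the queue
        tsv_path_list = tsv_path_list[batch_size:]
    print(batch, len(tsv_batch))  # 0 4, 1 4, 2 2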
@@ -533,18 +630,37 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                 positions = range(current_reference_length)
                 ref_seq = reference_dict[record][1]
                 # Extract the base identities of reads aligned to the record
+                (
+                    fwd_base_identities,
+                    rev_base_identities,
+                    mismatch_counts_per_read,
+                    mismatch_trend_per_read,
+                ) = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
                 # Store read names of fwd and rev mapped reads
                 fwd_mapped_reads.update(fwd_base_identities.keys())
                 rev_mapped_reads.update(rev_base_identities.keys())
                 # One hot encode the sequence string of the reads
+                fwd_ohe_files = ohe_batching(
+                    fwd_base_identities,
+                    tmp_dir,
+                    record,
+                    f"{bami}_fwd",
+                    batch_size=100000,
+                    threads=threads,
+                )
+                rev_ohe_files = ohe_batching(
+                    rev_base_identities,
+                    tmp_dir,
+                    record,
+                    f"{bami}_rev",
+                    batch_size=100000,
+                    threads=threads,
+                )
+                bam_record_ohe_files[f"{bami}_{record}"] = fwd_ohe_files + rev_ohe_files
                 del fwd_base_identities, rev_base_identities
         # Save out the ohe file paths
         X = np.random.rand(1, 1)
+        tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
         tmp_ad.write_h5ad(bam_record_save)
     ##########################################################################################
 
@@ -554,39 +670,73 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
     for record in records_to_analyze:
         current_reference_length = reference_dict[record][0]
         delta_max_length = max_reference_length - current_reference_length
+        sequence = reference_dict[record][1] + "N" * delta_max_length
+        complement = (
+            str(Seq(reference_dict[record][1]).complement()).upper() + "N" * delta_max_length
+        )
         record_seq_dict[record] = (sequence, complement)
     ##########################################################################################
 
     ###################################################
     # Begin iterating over batches
     for batch in range(batches):
+        logger.info("Processing tsvs for batch {0} ".format(batch))
         # For the final batch, just take the remaining tsv and bam files
         if batch == batches - 1:
             tsv_batch = tsv_path_list
             bam_batch = bam_path_list
+        # For all other batches, take the next batch of tsvs and bams out of the file queue.
         else:
             tsv_batch = tsv_path_list[:batch_size]
             bam_batch = bam_path_list[:batch_size]
             tsv_path_list = tsv_path_list[batch_size:]
             bam_path_list = bam_path_list[batch_size:]
+        logger.info("tsvs in batch {0} ".format(tsv_batch))
 
+        batch_already_processed = sum([1 for h5 in existing_h5s if f"_{batch}_" in h5.name])
+        ###################################################
         if batch_already_processed:
+            logger.debug(
+                f"Batch {batch} has already been processed into h5ads. Skipping batch and using existing files"
+            )
         else:
             ###################################################
             ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
             # # Initialize dictionaries and place them in a list
+            (
+                dict_total,
+                dict_a,
+                dict_a_bottom,
+                dict_a_top,
+                dict_c,
+                dict_c_bottom,
+                dict_c_top,
+                dict_combined_bottom,
+                dict_combined_top,
+            ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
+            dict_list = [
+                dict_total,
+                dict_a,
+                dict_a_bottom,
+                dict_a_top,
+                dict_c,
+                dict_c_bottom,
+                dict_c_top,
+                dict_combined_bottom,
+                dict_combined_top,
+            ]
             # Give names to represent each dictionary in the list
+            sample_types = [
+                "total",
+                "m6A",
+                "m6A_bottom_strand",
+                "m6A_top_strand",
+                "5mC",
+                "5mC_bottom_strand",
+                "5mC_top_strand",
+                "combined_bottom_strand",
+                "combined_top_strand",
+            ]
             # Give indices of dictionaries to skip for analysis and final dictionary saving.
             dict_to_skip = [0, 1, 4]
             combined_dicts = [7, 8]
@@ -596,7 +746,14 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
             dict_to_skip = set(dict_to_skip)
 
             # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
+            dict_total = parallel_load_tsvs(
+                tsv_batch,
+                records_to_analyze,
+                reference_dict,
+                batch,
+                batch_size=len(tsv_batch),
+                threads=threads,
+            )
 
             # # Step 2: Extract modification-specific data (per (record,sample)) in parallel
             # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
@@ -621,56 +778,112 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
             # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
             for record in dict_total.keys():
                 for sample_index in dict_total[record].keys():
+                    if "6mA" in mods:
                         # Remove Adenine stranded dicts from the dicts to skip set
                         dict_to_skip.difference_update(set(A_stranded_dicts))
 
+                        if (
+                            record not in dict_a.keys()
+                            and record not in dict_a_bottom.keys()
+                            and record not in dict_a_top.keys()
+                        ):
                             dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
 
                         # get a dictionary of dataframes that only contain methylated adenine positions
+                        dict_a[record][sample_index] = dict_total[record][sample_index][
+                            dict_total[record][sample_index]["modified_primary_base"] == "A"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a methyl-adenine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
+
                         # Stratify the adenine dictionary into two strand specific dictionaries.
+                        dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
+                            dict_a[record][sample_index]["ref_strand"] == "-"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
+                        dict_a_top[record][sample_index] = dict_a[record][sample_index][
+                            dict_a[record][sample_index]["ref_strand"] == "+"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
+                                str(sample_index)
+                            )
+                        )
 
                         # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
                         dict_a[record][sample_index] = None
                         gc.collect()
 
+                    if "5mC" in mods:
                         # Remove Cytosine stranded dicts from the dicts to skip set
                         dict_to_skip.difference_update(set(C_stranded_dicts))
 
+                        if (
+                            record not in dict_c.keys()
+                            and record not in dict_c_bottom.keys()
+                            and record not in dict_c_top.keys()
+                        ):
                             dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
 
                         # get a dictionary of dataframes that only contain methylated cytosine positions
+                        dict_c[record][sample_index] = dict_total[record][sample_index][
+                            dict_total[record][sample_index]["modified_primary_base"] == "C"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a methyl-cytosine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         # Stratify the cytosine dictionary into two strand specific dictionaries.
+                        dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
+                            dict_c[record][sample_index]["ref_strand"] == "-"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
+                        dict_c_top[record][sample_index] = dict_c[record][sample_index][
+                            dict_c[record][sample_index]["ref_strand"] == "+"
+                        ]
+                        logger.debug(
+                            "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
                         dict_c[record][sample_index] = None
                         gc.collect()
+
+                    if "6mA" in mods and "5mC" in mods:
                         # Remove combined stranded dicts from the dicts to skip set
+                        dict_to_skip.difference_update(set(combined_dicts))
                         # Initialize the sample keys for the combined dictionaries
 
+                        if (
+                            record not in dict_combined_bottom.keys()
+                            and record not in dict_combined_top.keys()
+                        ):
+                            dict_combined_bottom[record], dict_combined_top[record] = {}, {}
+
+                        logger.debug(
+                            "Successfully created a minus strand combined methylation dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         dict_combined_bottom[record][sample_index] = []
+                        logger.debug(
+                            "Successfully created a plus strand combined methylation dictionary for {}".format(
+                                str(sample_index)
+                            )
+                        )
                         dict_combined_top[record][sample_index] = []
 
                     # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
@@ -681,14 +894,24 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
             for dict_index, dict_type in enumerate(dict_list):
                 # Only iterate over stranded dictionaries
                 if dict_index not in dict_to_skip:
+                    logger.debug(
+                        "Extracting methylation states for {} dictionary".format(
+                            sample_types[dict_index]
+                        )
+                    )
                    for record in dict_type.keys():
                         # Get the dictionary for the modification type of interest from the reference mapping of interest
                         mod_strand_record_sample_dict = dict_type[record]
+                        logger.debug(
+                            "Extracting methylation states for {} dictionary".format(record)
+                        )
                         # For each sample in a stranded dictionary
                         n_samples = len(mod_strand_record_sample_dict.keys())
+                        for sample in tqdm(
+                            mod_strand_record_sample_dict.keys(),
+                            desc=f"Extracting {sample_types[dict_index]} dictionary from record {record} for sample",
+                            total=n_samples,
+                        ):
                             # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
                             if dict_index == 7:
                                 # Load the minus strand dictionaries for each sample into temporary variables
@@ -699,16 +922,26 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                                 for read in set(temp_a_dict) | set(temp_c_dict):
                                     # Add the arrays element-wise if the read is present in both dictionaries
                                     if read in temp_a_dict and read in temp_c_dict:
+                                        mod_strand_record_sample_dict[sample][read] = np.where(
+                                            np.isnan(temp_a_dict[read])
+                                            & np.isnan(temp_c_dict[read]),
+                                            np.nan,
+                                            np.nan_to_num(temp_a_dict[read])
+                                            + np.nan_to_num(temp_c_dict[read]),
+                                        )
                                     # If the read is present in only one dictionary, copy its value
                                     elif read in temp_a_dict:
+                                        mod_strand_record_sample_dict[sample][read] = temp_a_dict[
+                                            read
+                                        ]
                                     elif read in temp_c_dict:
+                                        mod_strand_record_sample_dict[sample][read] = temp_c_dict[
+                                            read
+                                        ]
                                 del temp_a_dict, temp_c_dict
+                            # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
                             elif dict_index == 8:
+                                # Load the plus strand dictionaries for each sample into temporary variables
                                 temp_a_dict = dict_list[3][record][sample].copy()
                                 temp_c_dict = dict_list[6][record][sample].copy()
                                 mod_strand_record_sample_dict[sample] = {}
@@ -716,105 +949,163 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
|
|
|
716
949
|
for read in set(temp_a_dict) | set(temp_c_dict):
|
|
717
950
|
# Add the arrays element-wise if the read is present in both dictionaries
|
|
718
951
|
if read in temp_a_dict and read in temp_c_dict:
|
|
719
|
-
mod_strand_record_sample_dict[sample][read] = np.where(
|
|
952
|
+
mod_strand_record_sample_dict[sample][read] = np.where(
|
|
953
|
+
np.isnan(temp_a_dict[read])
|
|
954
|
+
& np.isnan(temp_c_dict[read]),
|
|
955
|
+
np.nan,
|
|
956
|
+
np.nan_to_num(temp_a_dict[read])
|
|
957
|
+
+ np.nan_to_num(temp_c_dict[read]),
|
|
958
|
+
)
|
|
720
959
|
# If the read is present in only one dictionary, copy its value
|
|
721
960
|
elif read in temp_a_dict:
|
|
722
|
-
mod_strand_record_sample_dict[sample][read] = temp_a_dict[
|
|
961
|
+
mod_strand_record_sample_dict[sample][read] = temp_a_dict[
|
|
962
|
+
read
|
|
963
|
+
]
|
|
723
964
|
elif read in temp_c_dict:
|
|
724
|
-
mod_strand_record_sample_dict[sample][read] = temp_c_dict[
|
|
965
|
+
mod_strand_record_sample_dict[sample][read] = temp_c_dict[
|
|
966
|
+
read
|
|
967
|
+
]
|
|
725
968
|
del temp_a_dict, temp_c_dict
|
|
726
969
|
# For all other dictionaries
|
|
727
970
|
else:
|
|
728
|
-
|
|
729
971
|
# use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
|
|
730
972
|
temp_df = mod_strand_record_sample_dict[sample]
|
|
731
973
|
# reassign the dictionary pointer to a nested dictionary.
|
|
732
974
|
mod_strand_record_sample_dict[sample] = {}
|
|
733
975
|
|
|
734
976
|
# Get relevant columns as NumPy arrays
|
|
735
|
-
read_ids = temp_df[
|
|
736
|
-
positions = temp_df[
|
|
737
|
-
call_codes = temp_df[
|
|
738
|
-
probabilities = temp_df[
|
|
977
|
+
read_ids = temp_df["read_id"].values
|
|
978
|
+
positions = temp_df["ref_position"].values
|
|
979
|
+
call_codes = temp_df["call_code"].values
|
|
980
|
+
probabilities = temp_df["call_prob"].values
|
|
739
981
|
|
|
740
982
|
# Define valid call code categories
|
|
741
|
-
modified_codes = {
|
|
742
|
-
canonical_codes = {
|
|
983
|
+
modified_codes = {"a", "h", "m"}
|
|
984
|
+
canonical_codes = {"-"}
|
|
743
985
|
|
|
744
986
|
         # Vectorized methylation calculation with NaN for other codes
-        methylation_prob = np.full_like(
-
-
+        methylation_prob = np.full_like(
+            probabilities, np.nan
+        )  # Default all to NaN
+        methylation_prob[np.isin(call_codes, list(modified_codes))] = (
+            probabilities[np.isin(call_codes, list(modified_codes))]
+        )
+        methylation_prob[np.isin(call_codes, list(canonical_codes))] = (
+            1 - probabilities[np.isin(call_codes, list(canonical_codes))]
+        )

         # Find unique reads
         unique_reads = np.unique(read_ids)
         # Preallocate storage for each read
         for read in unique_reads:
-            mod_strand_record_sample_dict[sample][read] = np.full(
+            mod_strand_record_sample_dict[sample][read] = np.full(
+                max_reference_length, np.nan
+            )

         # Efficient NumPy indexing to assign values
         for i in range(len(read_ids)):
             read = read_ids[i]
             pos = positions[i]
             prob = methylation_prob[i]
-
+
             # Assign methylation probability
             mod_strand_record_sample_dict[sample][read][pos] = prob

-
     # Save the sample files in the batch as gzipped hdf5 files
-
+    logger.info("Converting batch {} dictionaries to anndata objects".format(batch))
     for dict_index, dict_type in enumerate(dict_list):
         if dict_index not in dict_to_skip:
             # Initialize an hdf5 file for the current modified strand
             adata = None
-
+            logger.info(
+                "Converting {} dictionary to an anndata object".format(
+                    sample_types[dict_index]
+                )
+            )
             for record in dict_type.keys():
                 # Get the dictionary for the modification type of interest from the reference mapping of interest
                 mod_strand_record_sample_dict = dict_type[record]
                 for sample in mod_strand_record_sample_dict.keys():
-
+                    logger.info(
+                        "Converting {0} dictionary for sample {1} to an anndata object".format(
+                            sample_types[dict_index], sample
+                        )
+                    )
                     sample = int(sample)
                     final_sample_index = sample + (batch * batch_size)
-
-
-
-
+                    logger.info(
+                        "Final sample index for sample: {}".format(final_sample_index)
+                    )
+                    logger.debug(
+                        "Converting {0} dictionary for sample {1} to a dataframe".format(
+                            sample_types[dict_index],
+                            final_sample_index,
+                        )
+                    )
+                    temp_df = pd.DataFrame.from_dict(
+                        mod_strand_record_sample_dict[sample], orient="index"
+                    )
+                    mod_strand_record_sample_dict[sample] = (
+                        None  # reassign pointer to facilitate memory usage
+                    )
                     sorted_index = sorted(temp_df.index)
                     temp_df = temp_df.reindex(sorted_index)
                     X = temp_df.values
-                    dataset, strand = sample_types[dict_index].split(
-
-
+                    dataset, strand = sample_types[dict_index].split("_")[:2]
+
+                    logger.info(
+                        "Loading {0} dataframe for sample {1} into a temp anndata object".format(
+                            sample_types[dict_index],
+                            final_sample_index,
+                        )
+                    )
                     temp_adata = ad.AnnData(X)
                     if temp_adata.shape[0] > 0:
-
+                        logger.info(
+                            "Adding read names and position ids to {0} anndata for sample {1}".format(
+                                sample_types[dict_index],
+                                final_sample_index,
+                            )
+                        )
                         temp_adata.obs_names = temp_df.index
                         temp_adata.obs_names = temp_adata.obs_names.astype(str)
                         temp_adata.var_names = temp_df.columns
                         temp_adata.var_names = temp_adata.var_names.astype(str)
-
-
-
-
-
-
-                        temp_adata.obs[
-
-
+                        logger.info(
+                            "Adding {0} anndata for sample {1}".format(
+                                sample_types[dict_index],
+                                final_sample_index,
+                            )
+                        )
+                        temp_adata.obs["Sample"] = [
+                            sample_name_map[final_sample_index]
+                        ] * len(temp_adata)
+                        temp_adata.obs["Barcode"] = [barcode_map[final_sample_index]] * len(
+                            temp_adata
+                        )
+                        temp_adata.obs["Reference"] = [f"{record}"] * len(temp_adata)
+                        temp_adata.obs["Strand"] = [strand] * len(temp_adata)
+                        temp_adata.obs["Dataset"] = [dataset] * len(temp_adata)
+                        temp_adata.obs["Reference_dataset_strand"] = [
+                            f"{record}_{dataset}_{strand}"
+                        ] * len(temp_adata)
+                        temp_adata.obs["Reference_strand"] = [f"{record}_{strand}"] * len(
+                            temp_adata
+                        )
+
                         # Load in the one hot encoded reads from the current sample and record
                         one_hot_reads = {}
                         n_rows_OHE = 5
-                        ohe_files = bam_record_ohe_files[f
-
+                        ohe_files = bam_record_ohe_files[f"{final_sample_index}_{record}"]
+                        logger.info(f"Loading OHEs from {ohe_files}")
                         fwd_mapped_reads = set()
                         rev_mapped_reads = set()
                         for ohe_file in ohe_files:
                             tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
                             one_hot_reads.update(tmp_ohe_dict)
-                            if
+                            if "_fwd_" in ohe_file:
                                 fwd_mapped_reads.update(tmp_ohe_dict.keys())
-                            elif
+                            elif "_rev_" in ohe_file:
                                 rev_mapped_reads.update(tmp_ohe_dict.keys())
                             del tmp_ohe_dict

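The rewritten block at the top of this hunk vectorizes the per-call methylation probability: every call defaults to NaN, calls whose code is in `modified_codes` keep their reported probability, and calls in `canonical_codes` are flipped to `1 - p`. A minimal sketch of that masking pattern with toy arrays; the code values and variable names below are illustrative, not the package's API:

```python
import numpy as np

# Toy per-call columns as they might come out of a modkit extract table:
# a classification code and the probability of the emitted call.
call_codes = np.array(["m", "-", "m", "h", "-"])
probabilities = np.array([0.91, 0.88, 0.40, 0.75, 0.97])
modified_codes, canonical_codes = {"m", "h"}, {"-"}  # illustrative code sets

# Default everything to NaN, then fill in only the codes we recognize.
methylation_prob = np.full_like(probabilities, np.nan)
mod_mask = np.isin(call_codes, list(modified_codes))
can_mask = np.isin(call_codes, list(canonical_codes))
methylation_prob[mod_mask] = probabilities[mod_mask]      # keep P(modified)
methylation_prob[can_mask] = 1 - probabilities[can_mask]  # convert P(canonical) to P(modified)

print(methylation_prob)  # [0.91 0.12 0.4  0.75 0.03]
```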
@@ -823,18 +1114,20 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                         read_mapping_direction = []
                         for read_id in temp_adata.obs_names:
                             if read_id in fwd_mapped_reads:
-                                read_mapping_direction.append(
+                                read_mapping_direction.append("fwd")
                             elif read_id in rev_mapped_reads:
-                                read_mapping_direction.append(
+                                read_mapping_direction.append("rev")
                             else:
-                                read_mapping_direction.append(
+                                read_mapping_direction.append("unk")

-                        temp_adata.obs[
+                        temp_adata.obs["Read_mapping_direction"] = read_mapping_direction

                         del temp_df
-
+
                         # Initialize NumPy arrays
-                        sequence_length =
+                        sequence_length = (
+                            one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+                        )
                         df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
                         df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
                         df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
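In this hunk the per-position length is recovered from the stored one-hot encodings rather than tracked separately: each read's encoding appears to be kept as a flattened 5 x L array (rows A, C, G, T, N), so reshaping with `n_rows_OHE` rows and taking `.shape[1]` yields the sequence length. A small sketch under that assumption; the encoding helper and read name are hypothetical:

```python
import numpy as np

n_rows_OHE = 5  # one row per base: A, C, G, T, N
base_to_row = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}

def encode_read(seq):
    # Build a 5 x L one-hot matrix and flatten it, mimicking how a read
    # encoding could be stored (hypothetical helper for illustration only).
    ohe = np.zeros((n_rows_OHE, len(seq)), dtype=int)
    for pos, base in enumerate(seq):
        ohe[base_to_row[base], pos] = 1
    return ohe.flatten()

one_hot_reads = {"read_0": encode_read("ACGTNACGT")}

# Recover the per-position length the same way the updated code does.
sequence_length = one_hot_reads["read_0"].reshape(n_rows_OHE, -1).shape[1]
print(sequence_length)  # 9
```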
@@ -855,7 +1148,11 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
                         gc.collect()

                         # Fill the arrays
-                        for j, read_name in tqdm(
+                        for j, read_name in tqdm(
+                            enumerate(sorted_index),
+                            desc="Loading dataframes of OHE reads",
+                            total=len(sorted_index),
+                        ):
                             df_A[j, :] = dict_A[read_name]
                             df_C[j, :] = dict_C[read_name]
                             df_G[j, :] = dict_G[read_name]
@@ -867,43 +1164,78 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp

                         # Store the results in AnnData layers
                         ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
-                        for j, base in enumerate([
-                            temp_adata.layers[f
-
-
-
+                        for j, base in enumerate(["A", "C", "G", "T", "N"]):
+                            temp_adata.layers[f"{base}_binary_sequence_encoding"] = (
+                                ohe_df_map[j]
+                            )
+                            ohe_df_map[j] = (
+                                None  # Reassign pointer for memory usage purposes
+                            )
+
+                        # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
                         if adata:
                             if temp_adata.shape[0] > 0:
-
-
+                                logger.info(
+                                    "Concatenating {0} anndata object for sample {1}".format(
+                                        sample_types[dict_index],
+                                        final_sample_index,
+                                    )
+                                )
+                                adata = ad.concat(
+                                    [adata, temp_adata], join="outer", index_unique=None
+                                )
                                 del temp_adata
                             else:
-
+                                logger.warning(
+                                    f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
+                                )
                         else:
                             if temp_adata.shape[0] > 0:
-
+                                logger.info(
+                                    "Initializing {0} anndata object for sample {1}".format(
+                                        sample_types[dict_index],
+                                        final_sample_index,
+                                    )
+                                )
                                 adata = temp_adata
                             else:
-
+                                logger.warning(
+                                    f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata"
+                                )

                         gc.collect()
                     else:
-
+                        logger.warning(
+                            f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omiting from final adata. Skipping sample."
+                        )

             try:
-
-
-
-
-
-
-
+                logger.info(
+                    "Writing {0} anndata out as a hdf5 file".format(
+                        sample_types[dict_index]
+                    )
+                )
+                adata.write_h5ad(
+                    h5_dir
+                    / "{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz".format(
+                        readwrite.date_string(), batch, sample_types[dict_index]
+                    ),
+                    compression="gzip",
+                )
+            except Exception:
+                logger.debug("Skipping writing anndata for sample")
+
+    try:
+        # Delete the batch dictionaries from memory
+        del dict_list, adata
+    except Exception:
+        pass
    gc.collect()

# Iterate over all of the batched hdf5 files and concatenate them.
-files = h5_dir.iterdir()
+files = h5_dir.iterdir()
# Filter file names that contain the search string in their filename and keep them in a list
-hdfs = [hdf for hdf in files if
+hdfs = [hdf for hdf in files if "hdf5.h5ad" in hdf.name and hdf != final_hdf]
 combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf.name]
 if len(combined_hdfs) > 0:
     hdfs = combined_hdfs
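The batch assembly above grows one AnnData per modification/strand dictionary by outer-joining each sample's temporary object, so reads mapped to references with different coordinate sets still share a single, NaN-padded variable axis, and the result is written out as a gzip-compressed .h5ad. A minimal sketch of that accumulation pattern with toy matrices; the sample builder and output file name are illustrative:

```python
import anndata as ad
import numpy as np

def toy_sample(n_reads, positions, name):
    # One row per read, one column per reference position.
    adata = ad.AnnData(np.random.rand(n_reads, len(positions)))
    adata.obs_names = [f"{name}_read{i}" for i in range(n_reads)]
    adata.var_names = [str(p) for p in positions]
    adata.obs["Sample"] = [name] * n_reads
    return adata

adata = None
for temp_adata in (toy_sample(3, range(0, 5), "s1"), toy_sample(2, range(2, 7), "s2")):
    if adata is None:
        adata = temp_adata
    else:
        # Outer join keeps the union of positions; missing entries become NaN.
        adata = ad.concat([adata, temp_adata], join="outer", index_unique=None)

print(adata.shape)  # (5, 7): 5 reads over the union of positions
adata.write_h5ad("toy_batch.h5ad.gz", compression="gzip")
```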
@@ -911,55 +1243,62 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
     pass
 # Sort file list by names and print the list of file names
 hdfs.sort()
-
-hdf_paths = [
+logger.info("{0} sample files found: {1}".format(len(hdfs), hdfs))
+hdf_paths = [hd5 for hd5 in hdfs]
 final_adata = None
 for hdf_index, hdf in enumerate(hdf_paths):
-
+    logger.info("Reading in {} hdf5 file".format(hdfs[hdf_index]))
     temp_adata = ad.read_h5ad(hdf)
     if final_adata:
-
-
+        logger.info(
+            "Concatenating final adata object with {} hdf5 file".format(hdfs[hdf_index])
+        )
+        final_adata = ad.concat([final_adata, temp_adata], join="outer", index_unique=None)
     else:
-
+        logger.info("Initializing final adata object with {} hdf5 file".format(hdfs[hdf_index]))
         final_adata = temp_adata
     del temp_adata

 # Set obs columns to type 'category'
 for col in final_adata.obs.columns:
-    final_adata.obs[col] = final_adata.obs[col].astype(
+    final_adata.obs[col] = final_adata.obs[col].astype("category")

-ohe_bases = [
-ohe_layers = [f"{ohe_base}
-final_adata.uns[
+ohe_bases = ["A", "C", "G", "T"]  # ignore N bases for consensus
+ohe_layers = [f"{ohe_base}_binary_sequence_encoding" for ohe_base in ohe_bases]
+final_adata.uns["References"] = {}
 for record in records_to_analyze:
     # Add FASTA sequence to the object
     sequence = record_seq_dict[record][0]
     complement = record_seq_dict[record][1]
-    final_adata.var[f
-    final_adata.var[f
-    final_adata.uns[f
-    final_adata.uns[
+    final_adata.var[f"{record}_top_strand_FASTA_base"] = list(sequence)
+    final_adata.var[f"{record}_bottom_strand_FASTA_base"] = list(complement)
+    final_adata.uns[f"{record}_FASTA_sequence"] = sequence
+    final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
     # Add consensus sequence of samples mapped to the record to the object
-    record_subset = final_adata[final_adata.obs[
-    for strand in record_subset.obs[
-        strand_subset = record_subset[record_subset.obs[
-        for mapping_dir in strand_subset.obs[
-            mapping_dir_subset = strand_subset[
+    record_subset = final_adata[final_adata.obs["Reference"] == record]
+    for strand in record_subset.obs["Strand"].cat.categories:
+        strand_subset = record_subset[record_subset.obs["Strand"] == strand]
+        for mapping_dir in strand_subset.obs["Read_mapping_direction"].cat.categories:
+            mapping_dir_subset = strand_subset[
+                strand_subset.obs["Read_mapping_direction"] == mapping_dir
+            ]
             layer_map, layer_counts = {}, []
             for i, layer in enumerate(ohe_layers):
-                layer_map[i] = layer.split(
+                layer_map[i] = layer.split("_")[0]
                 layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
             count_array = np.array(layer_counts)
             nucleotide_indexes = np.argmax(count_array, axis=0)
             consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
-            final_adata.var[
+            final_adata.var[
+                f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
+            ] = consensus_sequence_list

 if input_already_demuxed:
     final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
     final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
 else:
     from .h5ad_functions import add_demux_type_annotation
+
     double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
     add_demux_type_annotation(final_adata, double_barcoded_reads)

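The consensus-sequence step above picks, for every reference position, the base whose binary sequence-encoding layer has the most read support: a column-wise argmax over the per-base count sums. A compact sketch of that argmax logic with made-up counts:

```python
import numpy as np

# Column-wise sums of the A/C/G/T binary sequence-encoding layers
# (one row per base, one column per reference position; toy numbers).
layer_map = {0: "A", 1: "C", 2: "G", 3: "T"}
layer_counts = [
    np.array([9, 0, 1, 2]),  # reads supporting A at each position
    np.array([1, 8, 0, 1]),  # C
    np.array([0, 1, 7, 0]),  # G
    np.array([0, 1, 2, 7]),  # T
]

count_array = np.array(layer_counts)
nucleotide_indexes = np.argmax(count_array, axis=0)  # winning base index per column
consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
print("".join(consensus_sequence_list))  # ACGT
```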
@@ -967,4 +1306,4 @@ def modkit_extract_to_adata(fasta, bam_dir, out_dir, input_already_demuxed, mapp
 if delete_batch_hdfs:
     delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)

-return final_adata, final_adata_path
+return final_adata, final_adata_path