smftools 0.1.6-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +34 -0
- smftools/_settings.py +20 -0
- smftools/_version.py +1 -0
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
- smftools/datasets/F1_sample_sheet.csv +5 -0
- smftools/datasets/__init__.py +9 -0
- smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
- smftools/datasets/datasets.py +28 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/hmm/apply_hmm_batched.py +242 -0
- smftools/hmm/calculate_distances.py +18 -0
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/hmm/display_hmm.py +18 -0
- smftools/hmm/hmm_readwrite.py +16 -0
- smftools/hmm/nucleosome_hmm_refinement.py +104 -0
- smftools/hmm/train_hmm.py +78 -0
- smftools/informatics/__init__.py +14 -0
- smftools/informatics/archived/bam_conversion.py +59 -0
- smftools/informatics/archived/bam_direct.py +63 -0
- smftools/informatics/archived/basecalls_to_adata.py +71 -0
- smftools/informatics/archived/conversion_smf.py +132 -0
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/direct_smf.py +137 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/fast5_to_pod5.py +24 -0
- smftools/informatics/helpers/__init__.py +73 -0
- smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
- smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
- smftools/informatics/helpers/archived/informatics.py +260 -0
- smftools/informatics/helpers/archived/load_adata.py +516 -0
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/bed_to_bigwig.py +39 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
- smftools/informatics/helpers/canoncall.py +34 -0
- smftools/informatics/helpers/complement_base_list.py +21 -0
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
- smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
- smftools/informatics/helpers/count_aligned_reads.py +43 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +70 -0
- smftools/informatics/helpers/extract_mods.py +83 -0
- smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
- smftools/informatics/helpers/find_conversion_sites.py +51 -0
- smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
- smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
- smftools/informatics/helpers/get_native_references.py +28 -0
- smftools/informatics/helpers/index_fasta.py +12 -0
- smftools/informatics/helpers/make_dirs.py +21 -0
- smftools/informatics/helpers/make_modbed.py +27 -0
- smftools/informatics/helpers/modQC.py +27 -0
- smftools/informatics/helpers/modcall.py +36 -0
- smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
- smftools/informatics/helpers/ohe_batching.py +76 -0
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +57 -0
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
- smftools/informatics/helpers/split_and_index_BAM.py +32 -0
- smftools/informatics/readwrite.py +106 -0
- smftools/informatics/subsample_fasta_from_bed.py +47 -0
- smftools/informatics/subsample_pod5.py +104 -0
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/data/preprocessing.py +6 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/__init__.py +9 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/machine_learning/models/positional.py +18 -0
- smftools/machine_learning/models/rnn.py +17 -0
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/models/wrappers.py +20 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +10 -0
- smftools/machine_learning/utils/grl.py +14 -0
- smftools/plotting/__init__.py +18 -0
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +682 -0
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +38 -0
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/archives/mark_duplicates.py +146 -0
- smftools/preprocessing/archives/preprocessing.py +614 -0
- smftools/preprocessing/archives/remove_duplicates.py +21 -0
- smftools/preprocessing/binarize_on_Youden.py +45 -0
- smftools/preprocessing/binary_layers_to_ohe.py +40 -0
- smftools/preprocessing/calculate_complexity.py +72 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_consensus.py +47 -0
- smftools/preprocessing/calculate_coverage.py +51 -0
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
- smftools/preprocessing/calculate_position_Youden.py +115 -0
- smftools/preprocessing/calculate_read_length_stats.py +79 -0
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +62 -0
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1351 -0
- smftools/preprocessing/invert_adata.py +37 -0
- smftools/preprocessing/load_sample_sheet.py +53 -0
- smftools/preprocessing/make_dirs.py +21 -0
- smftools/preprocessing/min_non_diagonal.py +25 -0
- smftools/preprocessing/recipes.py +127 -0
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +1004 -0
- smftools/tools/__init__.py +20 -0
- smftools/tools/archived/apply_hmm.py +202 -0
- smftools/tools/archived/classifiers.py +787 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/position_stats.py +601 -0
- smftools/tools/read_stats.py +184 -0
- smftools/tools/spatial_autocorrelation.py +562 -0
- smftools/tools/subset_adata.py +28 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools-0.1.6.dist-info/RECORD +0 -4
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,887 @@
## modkit_extract_to_adata

import concurrent.futures
import gc
from .count_aligned_reads import count_aligned_reads
import pandas as pd
from tqdm import tqdm
import numpy as np

def filter_bam_records(bam, mapping_threshold):
    """Processes a single BAM file, counts reads, and determines records to analyze."""
    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)

    total_reads = aligned_reads_count + unaligned_reads_count
    percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
    print(f'{percent_aligned:.2f}% of reads in {bam} aligned successfully')

    records = []
    for record, (count, percentage) in record_counts_dict.items():
        print(f'{count} reads mapped to reference {record}. This is {percentage*100:.2f}% of all mapped reads in {bam}')
        if percentage >= mapping_threshold:
            records.append(record)

    return set(records)

def parallel_filter_bams(bam_path_list, mapping_threshold):
    """Parallel processing for multiple BAM files."""
    records_to_analyze = set()

    with concurrent.futures.ProcessPoolExecutor() as executor:
        results = executor.map(filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list))

        # Aggregate results
        for result in results:
            records_to_analyze.update(result)

    print(f'Records to analyze: {records_to_analyze}')
    return records_to_analyze

def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
    """
    Loads and filters a single TSV file based on chromosome and position criteria.
    """
    temp_df = pd.read_csv(tsv, sep='\t', header=0)
    filtered_records = {}

    for record in records_to_analyze:
        if record not in reference_dict:
            continue

        ref_length = reference_dict[record][0]
        filtered_df = temp_df[(temp_df['chrom'] == record) &
                              (temp_df['ref_position'] >= 0) &
                              (temp_df['ref_position'] < ref_length)]

        if not filtered_df.empty:
            filtered_records[record] = {sample_index: filtered_df}

    return filtered_records

def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
    """
    Loads and filters TSV files in parallel.

    Parameters:
        tsv_batch (list): List of TSV file paths.
        records_to_analyze (list): Chromosome records to analyze.
        reference_dict (dict): Dictionary containing reference lengths.
        batch (int): Current batch number.
        batch_size (int): Total files in the batch.
        threads (int): Number of parallel workers.

    Returns:
        dict: Processed `dict_total` dictionary.
    """
    dict_total = {record: {} for record in records_to_analyze}

    with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
        futures = {
            executor.submit(process_tsv, tsv, records_to_analyze, reference_dict, sample_index): sample_index
            for sample_index, tsv in enumerate(tsv_batch)
        }

        for future in tqdm(concurrent.futures.as_completed(futures), desc=f'Processing batch {batch}', total=batch_size):
            result = future.result()
            for record, sample_data in result.items():
                dict_total[record].update(sample_data)

    return dict_total
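
Editorial note, not part of the package diff: a minimal usage sketch of the TSV helpers above, assuming the smftools 0.2.1 wheel is installed. The toy table only carries the `chrom` and `ref_position` columns that `process_tsv` actually filters on; the record name, file path, and reference length are invented for illustration.

```python
import pandas as pd
from smftools.informatics.helpers.modkit_extract_to_adata import process_tsv

# Toy modkit-extract-style table with only the columns process_tsv touches.
toy = pd.DataFrame({
    'chrom': ['amplicon_1', 'amplicon_1', 'amplicon_2'],
    'ref_position': [10, 250, 5],
})
toy.to_csv('toy_extract.tsv', sep='\t', index=False)

# reference_dict maps record name -> (reference length, reference sequence).
reference_dict = {'amplicon_1': (200, 'A' * 200)}

# Keep only rows on amplicon_1 with 0 <= ref_position < 200.
filtered = process_tsv('toy_extract.tsv', {'amplicon_1'}, reference_dict, sample_index=0)
print(filtered['amplicon_1'][0])  # one remaining row (ref_position 10); position 250 is dropped
```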

def update_dict_to_skip(dict_to_skip, detected_modifications):
    """
    Updates the dict_to_skip set based on the detected modifications.

    Parameters:
        dict_to_skip (set): The initial set of dictionary indices to skip.
        detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.

    Returns:
        set: The updated dict_to_skip set.
    """
    # Define which indices correspond to modification-specific or strand-specific dictionaries
    A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
    C_stranded_dicts = {5, 6}  # 5mC bottom and top strand dictionaries
    combined_dicts = {7, 8}    # Combined strand dictionaries

    # If '6mA' is present, remove the A_stranded indices from the skip set
    if '6mA' in detected_modifications:
        dict_to_skip -= A_stranded_dicts
    # If '5mC' is present, remove the C_stranded indices from the skip set
    if '5mC' in detected_modifications:
        dict_to_skip -= C_stranded_dicts
    # If both modifications are present, remove the combined indices from the skip set
    if '6mA' in detected_modifications and '5mC' in detected_modifications:
        dict_to_skip -= combined_dicts

    return dict_to_skip
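
Editorial note, not part of the package diff: the indices above follow the `dict_list` ordering used later in `modkit_extract_to_adata` (0 total, 1 m6A, 2/3 m6A bottom/top strand, 4 5mC, 5/6 5mC bottom/top strand, 7/8 combined bottom/top strand). A minimal sketch of the behaviour, assuming the helper is importable from the installed package:

```python
from smftools.informatics.helpers.modkit_extract_to_adata import update_dict_to_skip

# Start from "skip everything" (all nine dictionary indices), then unmask what the mods need.
print(update_dict_to_skip(set(range(9)), ['6mA']))          # {0, 1, 4, 5, 6, 7, 8}
print(update_dict_to_skip(set(range(9)), ['6mA', '5mC']))   # {0, 1, 4}
```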

def process_modifications_for_sample(args):
    """
    Processes a single (record, sample) pair to extract modification-specific data.

    Parameters:
        args: (record, sample_index, sample_df, mods, max_reference_length)

    Returns:
        (record, sample_index, result) where result is a dict with keys:
        'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
        optionally 'combined_minus' and 'combined_plus' (initialized as empty lists).
    """
    record, sample_index, sample_df, mods, max_reference_length = args
    result = {}
    if '6mA' in mods:
        m6a_df = sample_df[sample_df['modified_primary_base'] == 'A']
        result['m6A'] = m6a_df
        result['m6A_minus'] = m6a_df[m6a_df['ref_strand'] == '-']
        result['m6A_plus'] = m6a_df[m6a_df['ref_strand'] == '+']
        m6a_df = None
        gc.collect()
    if '5mC' in mods:
        m5c_df = sample_df[sample_df['modified_primary_base'] == 'C']
        result['5mC'] = m5c_df
        result['5mC_minus'] = m5c_df[m5c_df['ref_strand'] == '-']
        result['5mC_plus'] = m5c_df[m5c_df['ref_strand'] == '+']
        m5c_df = None
        gc.collect()
    if '6mA' in mods and '5mC' in mods:
        result['combined_minus'] = []
        result['combined_plus'] = []
    return record, sample_index, result

def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
    """
    Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.

    Returns:
        processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
    """
    tasks = []
    for record, sample_dict in dict_total.items():
        for sample_index, sample_df in sample_dict.items():
            tasks.append((record, sample_index, sample_df, mods, max_reference_length))
    processed_results = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
        for record, sample_index, result in tqdm(
                executor.map(process_modifications_for_sample, tasks),
                total=len(tasks),
                desc="Processing modifications"):
            if record not in processed_results:
                processed_results[record] = {}
            processed_results[record][sample_index] = result
    return processed_results

def merge_modification_results(processed_results, mods):
    """
    Merges individual sample results into global dictionaries.

    Returns:
        A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
    """
    m6A_dict = {}
    m6A_minus = {}
    m6A_plus = {}
    c5m_dict = {}
    c5m_minus = {}
    c5m_plus = {}
    combined_minus = {}
    combined_plus = {}
    for record, sample_results in processed_results.items():
        for sample_index, res in sample_results.items():
            if '6mA' in mods:
                if record not in m6A_dict:
                    m6A_dict[record], m6A_minus[record], m6A_plus[record] = {}, {}, {}
                m6A_dict[record][sample_index] = res.get('m6A', pd.DataFrame())
                m6A_minus[record][sample_index] = res.get('m6A_minus', pd.DataFrame())
                m6A_plus[record][sample_index] = res.get('m6A_plus', pd.DataFrame())
            if '5mC' in mods:
                if record not in c5m_dict:
                    c5m_dict[record], c5m_minus[record], c5m_plus[record] = {}, {}, {}
                c5m_dict[record][sample_index] = res.get('5mC', pd.DataFrame())
                c5m_minus[record][sample_index] = res.get('5mC_minus', pd.DataFrame())
                c5m_plus[record][sample_index] = res.get('5mC_plus', pd.DataFrame())
            if '6mA' in mods and '5mC' in mods:
                if record not in combined_minus:
                    combined_minus[record], combined_plus[record] = {}, {}
                combined_minus[record][sample_index] = res.get('combined_minus', [])
                combined_plus[record][sample_index] = res.get('combined_plus', [])
    return (m6A_dict, m6A_minus, m6A_plus,
            c5m_dict, c5m_minus, c5m_plus,
            combined_minus, combined_plus)
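
Editorial note, not part of the package diff: a toy sketch of the per-sample result layout produced by `process_modifications_for_sample` (and regrouped by `merge_modification_results`), assuming the helper is importable from the installed package; the two-column frame below is invented.

```python
import pandas as pd
from smftools.informatics.helpers.modkit_extract_to_adata import process_modifications_for_sample

sample_df = pd.DataFrame({
    'modified_primary_base': ['A', 'A', 'C'],
    'ref_strand':            ['+', '-', '+'],
})

record, sample_index, result = process_modifications_for_sample(
    ('amplicon_1', 0, sample_df, ['6mA', '5mC'], 200)
)
print(sorted(result.keys()))
# ['5mC', '5mC_minus', '5mC_plus', 'combined_minus', 'combined_plus', 'm6A', 'm6A_minus', 'm6A_plus']
```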

def process_stranded_methylation(args):
    """
    Processes a single (dict_index, record, sample) task.

    For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
    For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
    NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.

    Parameters:
        args: (dict_index, record, sample, dict_list, max_reference_length)

    Returns:
        (dict_index, record, sample, processed_data)
    """
    dict_index, record, sample, dict_list, max_reference_length = args
    processed_data = {}

    # For combined bottom strand (index 7)
    if dict_index == 7:
        temp_a = dict_list[2][record].get(sample, {}).copy()
        temp_c = dict_list[5][record].get(sample, {}).copy()
        processed_data = {}
        for read in set(temp_a.keys()) | set(temp_c.keys()):
            if read in temp_a:
                # Convert using pd.to_numeric with errors='coerce'
                value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
            else:
                value_a = None
            if read in temp_c:
                value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
            else:
                value_c = None
            if value_a is not None and value_c is not None:
                processed_data[read] = np.where(
                    np.isnan(value_a) & np.isnan(value_c),
                    np.nan,
                    np.nan_to_num(value_a) + np.nan_to_num(value_c)
                )
            elif value_a is not None:
                processed_data[read] = value_a
            elif value_c is not None:
                processed_data[read] = value_c
        del temp_a, temp_c

    # For combined top strand (index 8)
    elif dict_index == 8:
        temp_a = dict_list[3][record].get(sample, {}).copy()
        temp_c = dict_list[6][record].get(sample, {}).copy()
        processed_data = {}
        for read in set(temp_a.keys()) | set(temp_c.keys()):
            if read in temp_a:
                value_a = pd.to_numeric(np.array(temp_a[read]), errors='coerce')
            else:
                value_a = None
            if read in temp_c:
                value_c = pd.to_numeric(np.array(temp_c[read]), errors='coerce')
            else:
                value_c = None
            if value_a is not None and value_c is not None:
                processed_data[read] = np.where(
                    np.isnan(value_a) & np.isnan(value_c),
                    np.nan,
                    np.nan_to_num(value_a) + np.nan_to_num(value_c)
                )
            elif value_a is not None:
                processed_data[read] = value_a
            elif value_c is not None:
                processed_data[read] = value_c
        del temp_a, temp_c

    # For all other dictionaries
    else:
        # current_data is a DataFrame
        temp_df = dict_list[dict_index][record][sample]
        processed_data = {}
        # Extract columns and convert probabilities to float (coercing errors)
        read_ids = temp_df['read_id'].values
        positions = temp_df['ref_position'].values
        call_codes = temp_df['call_code'].values
        probabilities = pd.to_numeric(temp_df['call_prob'].values, errors='coerce')

        modified_codes = {'a', 'h', 'm'}
        canonical_codes = {'-'}

        # Compute methylation probabilities (vectorized)
        methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
        methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
        methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]

        # Preallocate storage for each unique read
        unique_reads = np.unique(read_ids)
        for read in unique_reads:
            processed_data[read] = np.full(max_reference_length, np.nan, dtype=float)

        # Assign values efficiently
        for i in range(len(read_ids)):
            read = read_ids[i]
            pos = positions[i]
            prob = methylation_prob[i]
            processed_data[read][pos] = prob

    gc.collect()
    return dict_index, record, sample, processed_data

def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
    """
    Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.

    Returns:
        Updated dict_list with processed (nested) dictionaries.
    """
    tasks = []
    for dict_index, current_dict in enumerate(dict_list):
        if dict_index not in dict_to_skip:
            for record in current_dict.keys():
                for sample in current_dict[record].keys():
                    tasks.append((dict_index, record, sample, dict_list, max_reference_length))

    with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
        for dict_index, record, sample, processed_data in tqdm(
                executor.map(process_stranded_methylation, tasks),
                total=len(tasks),
                desc="Extracting stranded methylation states"
        ):
            dict_list[dict_index][record][sample] = processed_data
    return dict_list
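
Editorial note, not part of the package diff: a self-contained sketch of the call-code handling inside `process_stranded_methylation`. Modified calls ('a', 'h', 'm') keep `call_prob` as the methylation probability, canonical calls ('-') are inverted to `1 - call_prob`, and any other code stays NaN; the toy values are invented.

```python
import numpy as np

call_codes = np.array(['a', '-', 'm', '.'])
probabilities = np.array([0.92, 0.88, 0.75, 0.60])

methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
modified = np.isin(call_codes, ['a', 'h', 'm'])
canonical = np.isin(call_codes, ['-'])
methylation_prob[modified] = probabilities[modified]
methylation_prob[canonical] = 1 - probabilities[canonical]

print(methylation_prob)  # [0.92 0.12 0.75  nan]
```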

def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name, mods, batch_size, mod_tsv_dir, delete_batch_hdfs=False, threads=None):
    """
    Takes modkit extract outputs and organizes them into an adata object.

    Parameters:
        fasta (str): File path to the reference genome to align to.
        bam_dir (str): File path to the directory containing the aligned_sorted split modified BAM files.
        mapping_threshold (float): A value between 0 and 1 giving the minimal fraction of aligned reads that must map to a reference region. References with values above the threshold are included in the output adata.
        experiment_name (str): A string to provide an experiment name to the output adata file.
        mods (list): A list of strings of the modification types to use in the analysis.
        batch_size (int): An integer number of TSV files to analyze in memory at once while loading the final adata object.
        mod_tsv_dir (str): String representing the path to the mod TSV directory.
        delete_batch_hdfs (bool): Whether to delete the batch hdf5s after writing out the final concatenated hdf5. Default is False.
        threads (int): Number of parallel workers to use. Default is None.

    Returns:
        final_adata (AnnData or None): The concatenated AnnData object, or None when an existing file on disk is reused.
        final_adata_path (str): Path to the final adata.
    """
    ###################################################
    # Package imports
    from .. import readwrite
    from .get_native_references import get_native_references
    from .extract_base_identities import extract_base_identities
    from .ohe_batching import ohe_batching
    import pandas as pd
    import anndata as ad
    import os
    import gc
    import math
    import numpy as np
    from Bio.Seq import Seq
    from tqdm import tqdm
    import h5py
    from .make_dirs import make_dirs
    ###################################################

    ################## Get input tsv and bam file names into a sorted list ################
    # List all files in the directory
    tsv_files = os.listdir(mod_tsv_dir)
    bam_files = os.listdir(bam_dir)
    # Get the parent directory of the mod TSV directory
    parent_dir = os.path.dirname(mod_tsv_dir)

    # Make output dirs
    h5_dir = os.path.join(parent_dir, 'h5ads')
    tmp_dir = os.path.join(parent_dir, 'tmp')
    make_dirs([h5_dir, tmp_dir])
    existing_h5s = os.listdir(h5_dir)
    existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in h5]
    final_hdf = f'{experiment_name}_final_experiment_hdf5.h5ad'
    final_adata_path = os.path.join(h5_dir, final_hdf)
    final_adata = None

    if os.path.exists(f"{final_adata_path}.gz"):
        print(f'{final_adata_path}.gz already exists. Using existing adata')
        return final_adata, f"{final_adata_path}.gz"

    elif os.path.exists(f"{final_adata_path}"):
        print(f'{final_adata_path} already exists. Using existing adata')
        return final_adata, final_adata_path

    # Filter file names that contain the search string in their filename and keep them in a list
    tsvs = [tsv for tsv in tsv_files if 'extract.tsv' in tsv and 'unclassified' not in tsv]
    bams = [bam for bam in bam_files if '.bam' in bam and '.bai' not in bam and 'unclassified' not in bam]
    # Sort file lists by name and print the list of file names
    tsvs.sort()
    tsv_path_list = [os.path.join(mod_tsv_dir, tsv) for tsv in tsvs]
    bams.sort()
    bam_path_list = [os.path.join(bam_dir, bam) for bam in bams]
    print(f'{len(tsvs)} sample tsv files found: {tsvs}')
    print(f'{len(bams)} sample bams found: {bams}')
    ##########################################################################################

    ######### Get Record names that have over a passed threshold of mapped reads #############
    # Get all records that are above a certain mapping threshold in at least one sample bam
    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)

    ##########################################################################################

    ########### Determine the maximum record length to analyze in the dataset ################
    # Get all references within the FASTA and indicate the length and identity of the record sequence
    max_reference_length = 0
    reference_dict = get_native_references(fasta) # returns a dict keyed by record name. Points to a tuple of (reference length, reference sequence)
    # Get the max record length in the dataset.
    for record in records_to_analyze:
        if reference_dict[record][0] > max_reference_length:
            max_reference_length = reference_dict[record][0]
    print(f'{readwrite.time_string()}: Max reference length in dataset: {max_reference_length}')
    batches = math.ceil(len(tsvs) / batch_size) # Number of batches to process
    print('{0}: Processing input tsvs in {1} batches of {2} tsvs '.format(readwrite.time_string(), batches, batch_size))
    ##########################################################################################

    ##########################################################################################
    # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
    # Save the file paths in the bam_record_ohe_files dict.
    bam_record_ohe_files = {}
    bam_record_save = os.path.join(tmp_dir, 'tmp_file_dict.h5ad')
    fwd_mapped_reads = set()
    rev_mapped_reads = set()
    # If this step has already been performed, read in the existing tmp_file_dict
    if os.path.exists(bam_record_save):
        bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
        print('Found existing OHE reads, using these')
    else:
        # Iterate over split bams
        for bami, bam in enumerate(bam_path_list):
            # Iterate over references to process
            for record in records_to_analyze:
                current_reference_length = reference_dict[record][0]
                positions = range(current_reference_length)
                ref_seq = reference_dict[record][1]
                # Extract the base identities of reads aligned to the record
                fwd_base_identities, rev_base_identities, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
                # Store read names of fwd and rev mapped reads
                fwd_mapped_reads.update(fwd_base_identities.keys())
                rev_mapped_reads.update(rev_base_identities.keys())
                # One hot encode the sequence string of the reads
                fwd_ohe_files = ohe_batching(fwd_base_identities, tmp_dir, record, f"{bami}_fwd", batch_size=100000, threads=threads)
                rev_ohe_files = ohe_batching(rev_base_identities, tmp_dir, record, f"{bami}_rev", batch_size=100000, threads=threads)
                bam_record_ohe_files[f'{bami}_{record}'] = fwd_ohe_files + rev_ohe_files
                del fwd_base_identities, rev_base_identities
        # Save out the ohe file paths
        X = np.random.rand(1, 1)
        tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
        tmp_ad.write_h5ad(bam_record_save)
    ##########################################################################################

    ##########################################################################################
    # Iterate over records to analyze and return a dictionary keyed by the reference name that points to a tuple containing the top strand sequence and the complement
    record_seq_dict = {}
    for record in records_to_analyze:
        current_reference_length = reference_dict[record][0]
        delta_max_length = max_reference_length - current_reference_length
        sequence = reference_dict[record][1] + 'N'*delta_max_length
        complement = str(Seq(reference_dict[record][1]).complement()).upper() + 'N'*delta_max_length
        record_seq_dict[record] = (sequence, complement)
    ##########################################################################################

    ###################################################
    # Begin iterating over batches
    for batch in range(batches):
        print('{0}: Processing tsvs for batch {1} '.format(readwrite.time_string(), batch))
        # For the final batch, just take the remaining tsv and bam files
        if batch == batches - 1:
            tsv_batch = tsv_path_list
            bam_batch = bam_path_list
        # For all other batches, take the next batch of tsvs and bams out of the file queue.
        else:
            tsv_batch = tsv_path_list[:batch_size]
            bam_batch = bam_path_list[:batch_size]
            tsv_path_list = tsv_path_list[batch_size:]
            bam_path_list = bam_path_list[batch_size:]
        print('{0}: tsvs in batch {1} '.format(readwrite.time_string(), tsv_batch))

        batch_already_processed = sum([1 for h5 in existing_h5s if f'_{batch}_' in h5])
        ###################################################
        if batch_already_processed:
            print(f'Batch {batch} has already been processed into h5ads. Skipping batch and using existing files')
        else:
            ###################################################
            ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
            # Initialize dictionaries and place them in a list
            dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top = {}, {}, {}, {}, {}, {}, {}, {}, {}
            dict_list = [dict_total, dict_a, dict_a_bottom, dict_a_top, dict_c, dict_c_bottom, dict_c_top, dict_combined_bottom, dict_combined_top]
            # Give names to represent each dictionary in the list
            sample_types = ['total', 'm6A', 'm6A_bottom_strand', 'm6A_top_strand', '5mC', '5mC_bottom_strand', '5mC_top_strand', 'combined_bottom_strand', 'combined_top_strand']
            # Give indices of dictionaries to skip for analysis and final dictionary saving.
            dict_to_skip = [0, 1, 4]
            combined_dicts = [7, 8]
            A_stranded_dicts = [2, 3]
            C_stranded_dicts = [5, 6]
            dict_to_skip = dict_to_skip + combined_dicts + A_stranded_dicts + C_stranded_dicts
            dict_to_skip = set(dict_to_skip)

            # # Step 1: Load the dict_total dictionary with all of the batch tsv files as dataframes.
            dict_total = parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size=len(tsv_batch), threads=threads)

            # # Step 2: Extract modification-specific data (per (record, sample)) in parallel
            # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
            # (m6A_dict, m6A_minus_strand, m6A_plus_strand,
            #  c5m_dict, c5m_minus_strand, c5m_plus_strand,
            #  combined_minus_strand, combined_plus_strand) = merge_modification_results(processed_mod_results, mods)

            # # Create dict_list with the desired ordering:
            # # 0: dict_total, 1: m6A, 2: m6A_minus, 3: m6A_plus, 4: 5mC, 5: 5mC_minus, 6: 5mC_plus, 7: combined_minus, 8: combined_plus
            # dict_list = [dict_total, m6A_dict, m6A_minus_strand, m6A_plus_strand,
            #              c5m_dict, c5m_minus_strand, c5m_plus_strand,
            #              combined_minus_strand, combined_plus_strand]

            # # Initialize dict_to_skip (default skip all mod-specific indices)
            # dict_to_skip = set([0, 1, 4, 7, 8, 2, 3, 5, 6])
            # # Update dict_to_skip based on modifications present in mods
            # dict_to_skip = update_dict_to_skip(dict_to_skip, mods)

            # # Step 3: Process stranded methylation data in parallel
            # dict_list = parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=threads or 4)

            # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
            for record in dict_total.keys():
                for sample_index in dict_total[record].keys():
                    if '6mA' in mods:
                        # Remove Adenine stranded dicts from the dicts to skip set
                        dict_to_skip.difference_update(set(A_stranded_dicts))

                        if record not in dict_a.keys() and record not in dict_a_bottom.keys() and record not in dict_a_top.keys():
                            dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}

                        # get a dictionary of dataframes that only contain methylated adenine positions
                        dict_a[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'A']
                        print('{}: Successfully loaded a methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        # Stratify the adenine dictionary into two strand specific dictionaries.
                        dict_a_bottom[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '-']
                        print('{}: Successfully loaded a minus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        dict_a_top[record][sample_index] = dict_a[record][sample_index][dict_a[record][sample_index]['ref_strand'] == '+']
                        print('{}: Successfully loaded a plus strand methyl-adenine dictionary for '.format(readwrite.time_string()) + str(sample_index))

                        # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
                        dict_a[record][sample_index] = None
                        gc.collect()

                    if '5mC' in mods:
                        # Remove Cytosine stranded dicts from the dicts to skip set
                        dict_to_skip.difference_update(set(C_stranded_dicts))

                        if record not in dict_c.keys() and record not in dict_c_bottom.keys() and record not in dict_c_top.keys():
                            dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}

                        # get a dictionary of dataframes that only contain methylated cytosine positions
                        dict_c[record][sample_index] = dict_total[record][sample_index][dict_total[record][sample_index]['modified_primary_base'] == 'C']
                        print('{}: Successfully loaded a methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        # Stratify the cytosine dictionary into two strand specific dictionaries.
                        dict_c_bottom[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '-']
                        print('{}: Successfully loaded a minus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        dict_c_top[record][sample_index] = dict_c[record][sample_index][dict_c[record][sample_index]['ref_strand'] == '+']
                        print('{}: Successfully loaded a plus strand methyl-cytosine dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
                        dict_c[record][sample_index] = None
                        gc.collect()

                    if '6mA' in mods and '5mC' in mods:
                        # Remove combined stranded dicts from the dicts to skip set
                        dict_to_skip.difference_update(set(combined_dicts))
                        # Initialize the sample keys for the combined dictionaries

                        if record not in dict_combined_bottom.keys() and record not in dict_combined_top.keys():
                            dict_combined_bottom[record], dict_combined_top[record] = {}, {}

                        print('{}: Successfully created a minus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        dict_combined_bottom[record][sample_index] = []
                        print('{}: Successfully created a plus strand combined methylation dictionary for '.format(readwrite.time_string()) + str(sample_index))
                        dict_combined_top[record][sample_index] = []

                    # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
                    dict_total[record][sample_index] = None
                    gc.collect()

            # Iterate over the stranded modification dictionaries and replace the dataframes with a dictionary of read names pointing to a list of values from the dataframe
            for dict_index, dict_type in enumerate(dict_list):
                # Only iterate over stranded dictionaries
                if dict_index not in dict_to_skip:
                    print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), sample_types[dict_index]))
                    for record in dict_type.keys():
                        # Get the dictionary for the modification type of interest from the reference mapping of interest
                        mod_strand_record_sample_dict = dict_type[record]
                        print('{0}: Extracting methylation states for {1} dictionary'.format(readwrite.time_string(), record))
                        # For each sample in a stranded dictionary
                        n_samples = len(mod_strand_record_sample_dict.keys())
                        for sample in tqdm(mod_strand_record_sample_dict.keys(), desc=f'Extracting {sample_types[dict_index]} dictionary from record {record} for sample', total=n_samples):
                            # Load the combined bottom strand dictionary after all the individual dictionaries have been made for the sample
                            if dict_index == 7:
                                # Load the minus strand dictionaries for each sample into temporary variables
                                temp_a_dict = dict_list[2][record][sample].copy()
                                temp_c_dict = dict_list[5][record][sample].copy()
                                mod_strand_record_sample_dict[sample] = {}
                                # Iterate over the reads present in the merge of both dictionaries
                                for read in set(temp_a_dict) | set(temp_c_dict):
                                    # Add the arrays element-wise if the read is present in both dictionaries
                                    if read in temp_a_dict and read in temp_c_dict:
                                        mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
                                    # If the read is present in only one dictionary, copy its value
                                    elif read in temp_a_dict:
                                        mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
                                    elif read in temp_c_dict:
                                        mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
                                del temp_a_dict, temp_c_dict
                            # Load the combined top strand dictionary after all the individual dictionaries have been made for the sample
                            elif dict_index == 8:
                                # Load the plus strand dictionaries for each sample into temporary variables
                                temp_a_dict = dict_list[3][record][sample].copy()
                                temp_c_dict = dict_list[6][record][sample].copy()
                                mod_strand_record_sample_dict[sample] = {}
                                # Iterate over the reads present in the merge of both dictionaries
                                for read in set(temp_a_dict) | set(temp_c_dict):
                                    # Add the arrays element-wise if the read is present in both dictionaries
                                    if read in temp_a_dict and read in temp_c_dict:
                                        mod_strand_record_sample_dict[sample][read] = np.where(np.isnan(temp_a_dict[read]) & np.isnan(temp_c_dict[read]), np.nan, np.nan_to_num(temp_a_dict[read]) + np.nan_to_num(temp_c_dict[read]))
                                    # If the read is present in only one dictionary, copy its value
                                    elif read in temp_a_dict:
                                        mod_strand_record_sample_dict[sample][read] = temp_a_dict[read]
                                    elif read in temp_c_dict:
                                        mod_strand_record_sample_dict[sample][read] = temp_c_dict[read]
                                del temp_a_dict, temp_c_dict
                            # For all other dictionaries
                            else:
                                # use temp_df to point to the dataframe held in mod_strand_record_sample_dict[sample]
                                temp_df = mod_strand_record_sample_dict[sample]
                                # reassign the dictionary pointer to a nested dictionary.
                                mod_strand_record_sample_dict[sample] = {}

                                # Get relevant columns as NumPy arrays
                                read_ids = temp_df['read_id'].values
                                positions = temp_df['ref_position'].values
                                call_codes = temp_df['call_code'].values
                                probabilities = temp_df['call_prob'].values

                                # Define valid call code categories
                                modified_codes = {'a', 'h', 'm'}
                                canonical_codes = {'-'}

                                # Vectorized methylation calculation with NaN for other codes
                                methylation_prob = np.full_like(probabilities, np.nan)  # Default all to NaN
                                methylation_prob[np.isin(call_codes, list(modified_codes))] = probabilities[np.isin(call_codes, list(modified_codes))]
                                methylation_prob[np.isin(call_codes, list(canonical_codes))] = 1 - probabilities[np.isin(call_codes, list(canonical_codes))]

                                # Find unique reads
                                unique_reads = np.unique(read_ids)
                                # Preallocate storage for each read
                                for read in unique_reads:
                                    mod_strand_record_sample_dict[sample][read] = np.full(max_reference_length, np.nan)

                                # Efficient NumPy indexing to assign values
                                for i in range(len(read_ids)):
                                    read = read_ids[i]
                                    pos = positions[i]
                                    prob = methylation_prob[i]

                                    # Assign methylation probability
                                    mod_strand_record_sample_dict[sample][read][pos] = prob

            # Save the sample files in the batch as gzipped hdf5 files
            os.chdir(h5_dir)
            print('{0}: Converting batch {1} dictionaries to anndata objects'.format(readwrite.time_string(), batch))
            for dict_index, dict_type in enumerate(dict_list):
                if dict_index not in dict_to_skip:
                    # Initialize an hdf5 file for the current modified strand
                    adata = None
                    print('{0}: Converting {1} dictionary to an anndata object'.format(readwrite.time_string(), sample_types[dict_index]))
                    for record in dict_type.keys():
                        # Get the dictionary for the modification type of interest from the reference mapping of interest
                        mod_strand_record_sample_dict = dict_type[record]
                        for sample in mod_strand_record_sample_dict.keys():
                            print('{0}: Converting {1} dictionary for sample {2} to an anndata object'.format(readwrite.time_string(), sample_types[dict_index], sample))
                            sample = int(sample)
                            final_sample_index = sample + (batch * batch_size)
                            print('{0}: Final sample index for sample: {1}'.format(readwrite.time_string(), final_sample_index))
                            print('{0}: Converting {1} dictionary for sample {2} to a dataframe'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                            temp_df = pd.DataFrame.from_dict(mod_strand_record_sample_dict[sample], orient='index')
                            mod_strand_record_sample_dict[sample] = None # reassign pointer to facilitate memory usage
                            sorted_index = sorted(temp_df.index)
                            temp_df = temp_df.reindex(sorted_index)
                            X = temp_df.values
                            dataset, strand = sample_types[dict_index].split('_')[:2]

                            print('{0}: Loading {1} dataframe for sample {2} into a temp anndata object'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                            temp_adata = ad.AnnData(X)
                            if temp_adata.shape[0] > 0:
                                print('{0}: Adding read names and position ids to {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                                temp_adata.obs_names = temp_df.index
                                temp_adata.obs_names = temp_adata.obs_names.astype(str)
                                temp_adata.var_names = temp_df.columns
                                temp_adata.var_names = temp_adata.var_names.astype(str)
                                print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                                temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
                                temp_adata.obs['Barcode'] = [str(final_sample_index)] * len(temp_adata)
                                temp_adata.obs['Reference'] = [f'{record}'] * len(temp_adata)
                                temp_adata.obs['Strand'] = [strand] * len(temp_adata)
                                temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)
                                temp_adata.obs['Reference_dataset_strand'] = [f'{record}_{dataset}_{strand}'] * len(temp_adata)
                                temp_adata.obs['Reference_strand'] = [f'{record}_{strand}'] * len(temp_adata)

                                # Load in the one hot encoded reads from the current sample and record
                                one_hot_reads = {}
                                n_rows_OHE = 5
                                ohe_files = bam_record_ohe_files[f'{final_sample_index}_{record}']
                                print(f'Loading OHEs from {ohe_files}')
                                fwd_mapped_reads = set()
                                rev_mapped_reads = set()
                                for ohe_file in ohe_files:
                                    tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
                                    one_hot_reads.update(tmp_ohe_dict)
                                    if '_fwd_' in ohe_file:
                                        fwd_mapped_reads.update(tmp_ohe_dict.keys())
                                    elif '_rev_' in ohe_file:
                                        rev_mapped_reads.update(tmp_ohe_dict.keys())
                                del tmp_ohe_dict

                                read_names = list(one_hot_reads.keys())

                                read_mapping_direction = []
                                for read_id in temp_adata.obs_names:
                                    if read_id in fwd_mapped_reads:
                                        read_mapping_direction.append('fwd')
                                    elif read_id in rev_mapped_reads:
                                        read_mapping_direction.append('rev')
                                    else:
                                        read_mapping_direction.append('unk')

                                temp_adata.obs['Read_mapping_direction'] = read_mapping_direction

                                del temp_df

                                # Initialize NumPy arrays
                                sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
                                df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
                                df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
                                df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
                                df_T = np.zeros((len(sorted_index), sequence_length), dtype=int)
                                df_N = np.zeros((len(sorted_index), sequence_length), dtype=int)

                                # Process one-hot data into dictionaries
                                dict_A, dict_C, dict_G, dict_T, dict_N = {}, {}, {}, {}, {}
                                for read_name, one_hot_array in one_hot_reads.items():
                                    one_hot_array = one_hot_array.reshape(n_rows_OHE, -1)
                                    dict_A[read_name] = one_hot_array[0, :]
                                    dict_C[read_name] = one_hot_array[1, :]
                                    dict_G[read_name] = one_hot_array[2, :]
                                    dict_T[read_name] = one_hot_array[3, :]
                                    dict_N[read_name] = one_hot_array[4, :]

                                del one_hot_reads
                                gc.collect()

                                # Fill the arrays
                                for j, read_name in tqdm(enumerate(sorted_index), desc='Loading dataframes of OHE reads', total=len(sorted_index)):
                                    df_A[j, :] = dict_A[read_name]
                                    df_C[j, :] = dict_C[read_name]
                                    df_G[j, :] = dict_G[read_name]
                                    df_T[j, :] = dict_T[read_name]
                                    df_N[j, :] = dict_N[read_name]

                                del dict_A, dict_C, dict_G, dict_T, dict_N
                                gc.collect()

                                # Store the results in AnnData layers
                                ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
                                for j, base in enumerate(['A', 'C', 'G', 'T', 'N']):
                                    temp_adata.layers[f'{base}_binary_encoding'] = ohe_df_map[j]
                                    ohe_df_map[j] = None # Reassign pointer for memory usage purposes

                                # If the final adata object already has a sample loaded, concatenate the current sample into the existing adata object
                                if adata:
                                    if temp_adata.shape[0] > 0:
                                        print('{0}: Concatenating {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                                        adata = ad.concat([adata, temp_adata], join='outer', index_unique=None)
                                        del temp_adata
                                    else:
                                        print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omitting from final adata")
                                else:
                                    if temp_adata.shape[0] > 0:
                                        print('{0}: Initializing {1} anndata object for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
                                        adata = temp_adata
                                    else:
                                        print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omitting from final adata")

                                gc.collect()
                            else:
                                print(f"{sample} did not have any mapped reads on {record}_{dataset}_{strand}, omitting from final adata. Skipping sample.")

                    try:
                        print('{0}: Writing {1} anndata out as a hdf5 file'.format(readwrite.time_string(), sample_types[dict_index]))
                        adata.write_h5ad('{0}_{1}_{2}_SMF_binarized_sample_hdf5.h5ad.gz'.format(readwrite.date_string(), batch, sample_types[dict_index]), compression='gzip')
                    except Exception as e:
                        print(f"Skipping writing anndata for {sample_types[dict_index]}: {e}")

            # Delete the batch dictionaries from memory
            del dict_list, adata
            gc.collect()

    # Iterate over all of the batched hdf5 files and concatenate them.
    os.chdir(h5_dir)
    files = os.listdir(h5_dir)
    # Filter file names that contain the search string in their filename and keep them in a list
    hdfs = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
    combined_hdfs = [hdf for hdf in hdfs if "combined" in hdf]
    if len(combined_hdfs) > 0:
        hdfs = combined_hdfs
    else:
        pass
    # Sort file list by names and print the list of file names
    hdfs.sort()
    print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
    hdf_paths = [os.path.join(h5_dir, hd5) for hd5 in hdfs]
    final_adata = None
    for hdf_index, hdf in enumerate(hdf_paths):
        print('{0}: Reading in {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
        temp_adata = ad.read_h5ad(hdf)
        if final_adata:
            print('{0}: Concatenating final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
            final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
        else:
            print('{0}: Initializing final adata object with {1} hdf5 file'.format(readwrite.time_string(), hdfs[hdf_index]))
            final_adata = temp_adata
        del temp_adata

    # Set obs columns to type 'category'
    for col in final_adata.obs.columns:
        final_adata.obs[col] = final_adata.obs[col].astype('category')

    ohe_bases = ['A', 'C', 'G', 'T'] # ignore N bases for consensus
    ohe_layers = [f"{ohe_base}_binary_encoding" for ohe_base in ohe_bases]
    for record in records_to_analyze:
        # Add FASTA sequence to the object
        sequence = record_seq_dict[record][0]
        complement = record_seq_dict[record][1]
        final_adata.var[f'{record}_top_strand_FASTA_base'] = list(sequence)
        final_adata.var[f'{record}_bottom_strand_FASTA_base'] = list(complement)
        final_adata.uns[f'{record}_FASTA_sequence'] = sequence
        # Add consensus sequence of samples mapped to the record to the object
        record_subset = final_adata[final_adata.obs['Reference'] == record]
        for strand in record_subset.obs['Strand'].cat.categories:
            strand_subset = record_subset[record_subset.obs['Strand'] == strand]
            for mapping_dir in strand_subset.obs['Read_mapping_direction'].cat.categories:
                mapping_dir_subset = strand_subset[strand_subset.obs['Read_mapping_direction'] == mapping_dir]
                layer_map, layer_counts = {}, []
                for i, layer in enumerate(ohe_layers):
                    layer_map[i] = layer.split('_')[0]
                    layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
                count_array = np.array(layer_counts)
                nucleotide_indexes = np.argmax(count_array, axis=0)
                consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
                final_adata.var[f'{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples'] = consensus_sequence_list

    #final_adata.write_h5ad(final_adata_path)

    # Delete the individual h5ad files and only keep the final concatenated file
    if delete_batch_hdfs:
        files = os.listdir(h5_dir)
        hdfs_to_delete = [hdf for hdf in files if 'hdf5.h5ad' in hdf and hdf != final_hdf]
        hdf_paths_to_delete = [os.path.join(h5_dir, hdf) for hdf in hdfs_to_delete]
        # Iterate over the files and delete them
        for hdf in hdf_paths_to_delete:
            try:
                os.remove(hdf)
                print(f"Deleted file: {hdf}")
            except OSError as e:
                print(f"Error deleting file {hdf}: {e}")

    return final_adata, final_adata_path
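
Editorial note, not part of the package diff: a usage sketch of the top-level entry point, assuming the smftools 0.2.1 wheel is installed. All paths and parameter values are hypothetical; the signature, the derived `h5ads/` and `tmp/` directories, and the final file name pattern come from the code above.

```python
from smftools.informatics.helpers.modkit_extract_to_adata import modkit_extract_to_adata

# Hypothetical run layout: modkit extract TSVs live in /data/run1/mod_tsvs, so the
# function creates /data/run1/h5ads and /data/run1/tmp next to that directory and
# writes /data/run1/h5ads/exp01_final_experiment_hdf5.h5ad (reused on re-runs if present).
final_adata, final_adata_path = modkit_extract_to_adata(
    fasta='/data/run1/reference.fasta',     # reference FASTA used for alignment
    bam_dir='/data/run1/split_bams',        # aligned, sorted, split modified BAMs
    mapping_threshold=0.01,                 # keep references with >= 1% of a sample's mapped reads
    experiment_name='exp01',
    mods=['6mA', '5mC'],                    # modification types to extract
    batch_size=4,                           # TSVs held in memory per batch
    mod_tsv_dir='/data/run1/mod_tsvs',      # modkit extract TSV outputs
    delete_batch_hdfs=True,                 # drop per-batch h5ads, keep only the final one
    threads=8,
)
print(final_adata_path)
```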