smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/helpers/converted_BAM_to_adata_II.py (new file)
@@ -0,0 +1,369 @@
+import numpy as np
+import time
+import os
+import gc
+import pandas as pd
+import anndata as ad
+from tqdm import tqdm
+import multiprocessing
+from multiprocessing import Manager, Lock, current_process, Pool
+import traceback
+import gzip
+import torch
+
+from .. import readwrite
+from .binarize_converted_base_identities import binarize_converted_base_identities
+from .find_conversion_sites import find_conversion_sites
+from .count_aligned_reads import count_aligned_reads
+from .extract_base_identities import extract_base_identities
+from .make_dirs import make_dirs
+from .ohe_batching import ohe_batching
+
+if __name__ == "__main__":
+    multiprocessing.set_start_method("forkserver", force=True)
+
+def converted_BAM_to_adata_II(converted_FASTA, split_dir, mapping_threshold, experiment_name, conversion_types, bam_suffix, device='cpu', num_threads=8):
+    """
+    Converts BAM files into an AnnData object by binarizing modified base identities.
+
+    Parameters:
+        converted_FASTA (str): Path to the converted FASTA reference.
+        split_dir (str): Directory containing converted BAM files.
+        mapping_threshold (float): Minimum fraction of aligned reads required for inclusion.
+        experiment_name (str): Name for the output AnnData object.
+        conversion_types (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+        bam_suffix (str): File suffix for BAM files.
+        num_threads (int): Number of parallel processing threads.
+
+    Returns:
+        str: Path to the final AnnData object.
+    """
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    ## Set Up Directories and File Paths
+    parent_dir = os.path.dirname(split_dir)
+    h5_dir = os.path.join(parent_dir, 'h5ads')
+    tmp_dir = os.path.join(parent_dir, 'tmp')
+    final_adata_path = os.path.join(h5_dir, f'{experiment_name}_{os.path.basename(split_dir)}.h5ad.gz')
+
+    if os.path.exists(final_adata_path):
+        print(f"{final_adata_path} already exists. Using existing AnnData object.")
+        return final_adata_path
+
+    make_dirs([h5_dir, tmp_dir])
+
+    ## Get BAM Files ##
+    bam_files = [f for f in os.listdir(split_dir) if f.endswith(bam_suffix) and not f.endswith('.bai') and 'unclassified' not in f]
+    bam_files.sort()
+    bam_path_list = [os.path.join(split_dir, f) for f in bam_files]
+    print(f"Found {len(bam_files)} BAM files: {bam_files}")
+
+    ## Process Conversion Sites
+    max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(converted_FASTA, conversion_types)
+
+    ## Filter BAM Files by Mapping Threshold
+    records_to_analyze = filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold)
+
+    ## Process BAMs in Parallel
+    final_adata = process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device)
+
+    for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
+        final_adata.var[f'{chromosome}_top_strand_FASTA_base'] = list(seq)
+        final_adata.var[f'{chromosome}_bottom_strand_FASTA_base'] = list(comp)
+        final_adata.uns[f'{chromosome}_FASTA_sequence'] = seq
+
+    ## Save Final AnnData
+    # print(f"Saving AnnData to {final_adata_path}")
+    # final_adata.write_h5ad(final_adata_path, compression='gzip')
+    return final_adata, final_adata_path
+
+
+def process_conversion_sites(converted_FASTA, conversion_types):
+    """
+    Extracts conversion sites and determines the max reference length.
+
+    Parameters:
+        converted_FASTA (str): Path to the converted reference FASTA.
+        conversion_types (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+
+    Returns:
+        max_reference_length (int): The length of the longest sequence.
+        record_FASTA_dict (dict): Dictionary of sequence information for **both converted & unconverted** records.
+    """
+    modification_dict = {}
+    record_FASTA_dict = {}
+    chromosome_FASTA_dict = {}
+    max_reference_length = 0
+    unconverted = conversion_types[0]
+    conversions = conversion_types[1:]
+
+    # Process the unconverted sequence once
+    modification_dict[unconverted] = find_conversion_sites(converted_FASTA, unconverted, conversion_types)
+    # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
+
+    # Get **max sequence length** from unconverted records
+    max_reference_length = max(values[0] for values in modification_dict[unconverted].values())
+
+    # Add **unconverted records** to `record_FASTA_dict`
+    for record, values in modification_dict[unconverted].items():
+        sequence_length, top_coords, bottom_coords, sequence, complement = values
+        chromosome = record.replace(f"_{unconverted}_top", "")
+
+        # Store **original sequence**
+        record_FASTA_dict[record] = [
+            sequence + "N" * (max_reference_length - sequence_length),
+            complement + "N" * (max_reference_length - sequence_length),
+            chromosome, record, sequence_length, max_reference_length - sequence_length, unconverted, "top"
+        ]
+
+        if chromosome not in chromosome_FASTA_dict:
+            chromosome_FASTA_dict[chromosome] = [sequence + "N" * (max_reference_length - sequence_length), complement + "N" * (max_reference_length - sequence_length)]
+
+    # Process converted records
+    for conversion in conversions:
+        modification_dict[conversion] = find_conversion_sites(converted_FASTA, conversion, conversion_types)
+        # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
+
+        for record, values in modification_dict[conversion].items():
+            sequence_length, top_coords, bottom_coords, sequence, complement = values
+            chromosome = record.split(f"_{unconverted}_")[0]  # Extract chromosome name
+
+            # Add **both strands** for converted records
+            for strand in ["top", "bottom"]:
+                converted_name = f"{chromosome}_{conversion}_{strand}"
+                unconverted_name = f"{chromosome}_{unconverted}_top"
+
+                record_FASTA_dict[converted_name] = [
+                    sequence + "N" * (max_reference_length - sequence_length),
+                    complement + "N" * (max_reference_length - sequence_length),
+                    chromosome, unconverted_name, sequence_length,
+                    max_reference_length - sequence_length, conversion, strand
+                ]
+
+    print("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
+    return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
+
+
+def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold):
+    """Filters BAM files based on mapping threshold."""
+    records_to_analyze = set()
+
+    for i, bam in enumerate(bam_path_list):
+        aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam)
+        aligned_percent = aligned_reads * 100 / (aligned_reads + unaligned_reads)
+        print(f"{aligned_percent:.2f}% of reads in {bam_files[i]} aligned successfully.")
+
+        for record, (count, percent) in record_counts.items():
+            if percent >= mapping_threshold:
+                records_to_analyze.add(record)
+
+    print(f"Analyzing the following FASTA records: {records_to_analyze}")
+    return records_to_analyze
+
+
+def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, tmp_dir, max_reference_length, device):
+    """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
+    adata_list = []
+
+    for record in records_to_analyze:
+        sample = os.path.basename(bam).split(sep=".bam")[0]
+        chromosome = record_FASTA_dict[record][2]
+        current_length = record_FASTA_dict[record][4]
+        mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
+
+        # Extract Base Identities
+        fwd_bases, rev_bases = extract_base_identities(bam, record, range(current_length), max_reference_length)
+
+        # Skip processing if both forward and reverse base identities are empty
+        if not fwd_bases and not rev_bases:
+            print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid base identities for {record}.")
+            continue
+
+        merged_bin = {}
+
+        # Binarize the Base Identities if they exist
+        if fwd_bases:
+            fwd_bin = binarize_converted_base_identities(fwd_bases, strand, mod_type, bam, device)
+            merged_bin.update(fwd_bin)
+
+        if rev_bases:
+            rev_bin = binarize_converted_base_identities(rev_bases, strand, mod_type, bam, device)
+            merged_bin.update(rev_bin)
+
+        # Skip if merged_bin is empty (no valid binarized data)
+        if not merged_bin:
+            print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid binarized data for {record}.")
+            continue
+
+        # Convert to DataFrame
+        # for key in merged_bin:
+        #     merged_bin[key] = merged_bin[key].cpu().numpy()  # Move to CPU & convert to NumPy
+        bin_df = pd.DataFrame.from_dict(merged_bin, orient='index')
+        sorted_index = sorted(bin_df.index)
+        bin_df = bin_df.reindex(sorted_index)
+
+        # One-Hot Encode Reads if there is valid data
+        one_hot_reads = {}
+
+        if fwd_bases:
+            fwd_ohe_files = ohe_batching(fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000)
+            for ohe_file in fwd_ohe_files:
+                tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
+                one_hot_reads.update(tmp_ohe_dict)
+                del tmp_ohe_dict
+
+        if rev_bases:
+            rev_ohe_files = ohe_batching(rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000)
+            for ohe_file in rev_ohe_files:
+                tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
+                one_hot_reads.update(tmp_ohe_dict)
+                del tmp_ohe_dict
+
+        # Skip if one_hot_reads is empty
+        if not one_hot_reads:
+            print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}.")
+            continue
+
+        gc.collect()
+
+        # Convert One-Hot Encodings to Numpy Arrays
+        n_rows_OHE = 5
+        read_names = list(one_hot_reads.keys())
+
+        # Skip if no read names exist
+        if not read_names:
+            print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}.")
+            continue
+
+        sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+        df_A, df_C, df_G, df_T, df_N = [np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)]
+
+        # Populate One-Hot Arrays
+        for j, read_name in enumerate(sorted_index):
+            if read_name in one_hot_reads:
+                one_hot_array = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
+                df_A[j], df_C[j], df_G[j], df_T[j], df_N[j] = one_hot_array
+
+        # Convert to AnnData
+        X = bin_df.values.astype(np.float32)
+        adata = ad.AnnData(X)
+        adata.obs_names = bin_df.index.astype(str)
+        adata.var_names = bin_df.columns.astype(str)
+        adata.obs["Sample"] = [sample] * len(adata)
+        adata.obs["Reference"] = [chromosome] * len(adata)
+        adata.obs["Strand"] = [strand] * len(adata)
+        adata.obs["Dataset"] = [mod_type] * len(adata)
+        adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
+        adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
+
+        # Attach One-Hot Encodings to Layers
+        adata.layers["A_binary_encoding"] = df_A
+        adata.layers["C_binary_encoding"] = df_C
+        adata.layers["G_binary_encoding"] = df_G
+        adata.layers["T_binary_encoding"] = df_T
+        adata.layers["N_binary_encoding"] = df_N
+
+        adata_list.append(adata)
+
+    return ad.concat(adata_list, join="outer") if adata_list else None
+
+def timestamp():
+    """Returns a formatted timestamp for logging."""
+    return time.strftime("[%Y-%m-%d %H:%M:%S]")
+
+
+def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, progress_queue):
+    """Worker function that processes a single BAM and writes the output to an H5AD file."""
+    worker_id = current_process().pid  # Get worker process ID
+    sample = os.path.basename(bam).split(sep=".bam")[0]
+
+    try:
+        print(f"{timestamp()} [Worker {worker_id}] Processing BAM: {sample}")
+
+        h5ad_path = os.path.join(h5_dir, f"{sample}.h5ad")
+        if os.path.exists(h5ad_path):
+            print(f"{timestamp()} [Worker {worker_id}] Skipping {sample}: Already processed.")
+            progress_queue.put(sample)
+            return
+
+        # Filter records specific to this BAM
+        bam_records_to_analyze = {record for record in records_to_analyze if record in shared_record_FASTA_dict}
+
+        if not bam_records_to_analyze:
+            print(f"{timestamp()} [Worker {worker_id}] No valid records to analyze for {sample}. Skipping.")
+            progress_queue.put(sample)
+            return
+
+        # Process BAM
+        adata = process_single_bam(bam_index, bam, bam_records_to_analyze, shared_record_FASTA_dict, tmp_dir, max_reference_length, device)
+
+        if adata is not None:
+            adata.write_h5ad(h5ad_path)
+            print(f"{timestamp()} [Worker {worker_id}] Completed processing for BAM: {sample}")
+
+        # Free memory
+        del adata
+        gc.collect()
+
+        progress_queue.put(sample)
+
+    except Exception as e:
+        print(f"{timestamp()} [Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}")
+        progress_queue.put(sample)  # Still signal completion to prevent deadlock
+
+def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device):
+    """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
+    os.makedirs(h5_dir, exist_ok=True)  # Ensure h5_dir exists
+
+    print(f"{timestamp()} Starting parallel BAM processing with {num_threads} threads...")
+
+    # Ensure macOS uses forkserver to avoid spawning issues
+    try:
+        import multiprocessing
+        multiprocessing.set_start_method("forkserver", force=True)
+    except RuntimeError:
+        print(f"{timestamp()} [WARNING] Multiprocessing context already set. Skipping set_start_method.")
+
+    with Manager() as manager:
+        progress_queue = manager.Queue()
+        shared_record_FASTA_dict = manager.dict(record_FASTA_dict)
+
+        with Pool(processes=num_threads) as pool:
+            results = [
+                pool.apply_async(worker_function, (i, bam, records_to_analyze, shared_record_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, progress_queue))
+                for i, bam in enumerate(bam_path_list)
+            ]
+
+            print(f"{timestamp()} Submitted {len(bam_path_list)} BAMs for processing.")
+
+            # Track completed BAMs
+            completed_bams = set()
+            while len(completed_bams) < len(bam_path_list):
+                try:
+                    processed_bam = progress_queue.get(timeout=2400)  # Wait for a finished BAM
+                    completed_bams.add(processed_bam)
+                except Exception as e:
+                    print(f"{timestamp()} [ERROR] Timeout waiting for worker process. Possible crash? {e}")
+
+            pool.close()
+            pool.join()  # Ensure all workers finish
+
+    # Final Concatenation Step
+    h5ad_files = [os.path.join(h5_dir, f) for f in os.listdir(h5_dir) if f.endswith(".h5ad")]
+
+    if not h5ad_files:
+        print(f"{timestamp()} No valid H5AD files generated. Exiting.")
+        return None
+
+    print(f"{timestamp()} Concatenating {len(h5ad_files)} H5AD files into final output...")
+    final_adata = ad.concat([ad.read_h5ad(f) for f in h5ad_files], join="outer")
+
+    print(f"{timestamp()} Successfully generated final AnnData object.")
+    return final_adata
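
For orientation, a minimal usage sketch of the new converted_BAM_to_adata_II entry point added in this release. The import path simply mirrors the file location above; every argument value (paths, mapping threshold, conversion types) is an illustrative placeholder rather than a value taken from the package.

    # Hypothetical invocation of the new parallel BAM-to-AnnData converter (placeholder values).
    from smftools.informatics.helpers.converted_BAM_to_adata_II import converted_BAM_to_adata_II

    final_adata, final_adata_path = converted_BAM_to_adata_II(
        converted_FASTA="refs/converted_reference.fasta",  # assumed example path
        split_dir="run1/aligned_sorted_split",             # directory of demultiplexed BAMs
        mapping_threshold=5,                               # per-record mapping cutoff (placeholder)
        experiment_name="run1",
        conversion_types=["unconverted", "5mC"],
        bam_suffix=".bam",
        num_threads=8,
    )
    # final_adata holds per-read binarized calls plus A/C/G/T/N one-hot layers;
    # final_adata_path is where the gzipped .h5ad would be written under <parent_dir>/h5ads.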
smftools/informatics/helpers/demux_and_index_BAM.py (new file)
@@ -0,0 +1,52 @@
+## demux_and_index_BAM
+
+def demux_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, barcode_kit, barcode_both_ends, trim, fasta, make_bigwigs, threads):
+    """
+    A wrapper function for splitting BAMS and indexing them.
+    Parameters:
+        aligned_sorted_BAM (str): A string representing the file path of the aligned_sorted BAM file.
+        split_dir (str): A string representing the file path to the directory to split the BAMs into.
+        bam_suffix (str): A suffix to add to the bam file.
+        barcode_kit (str): Name of barcoding kit.
+        barcode_both_ends (bool): Whether to require both ends to be barcoded.
+        trim (bool): Whether to trim off barcodes after demultiplexing.
+        fasta (str): File path to the reference genome to align to.
+        make_bigwigs (bool): Whether to make bigwigs
+        threads (int): Number of threads to use.
+
+    Returns:
+        bam_files (list): List of split BAM file path strings
+    Splits an input BAM file on barcode value and makes a BAM index file.
+    """
+    from .. import readwrite
+    import os
+    import subprocess
+    import glob
+    from .make_dirs import make_dirs
+
+    input_bam = aligned_sorted_BAM + bam_suffix
+    command = ["dorado", "demux", "--kit-name", barcode_kit]
+    if barcode_both_ends:
+        command.append("--barcode-both-ends")
+    if not trim:
+        command.append("--no-trim")
+    if threads:
+        command += ["-t", str(threads)]
+    else:
+        pass
+    command += ["--emit-summary", "--sort-bam", "--output-dir", split_dir]
+    command.append(input_bam)
+    command_string = ' '.join(command)
+    print(f"Running: {command_string}")
+    subprocess.run(command)
+
+    # Make a BAM index file for the BAMs in that directory
+    bam_pattern = '*' + bam_suffix
+    bam_files = glob.glob(os.path.join(split_dir, bam_pattern))
+    bam_files = [bam for bam in bam_files if '.bai' not in bam and 'unclassified' not in bam]
+    bam_files.sort()
+
+    if not bam_files:
+        raise FileNotFoundError(f"No BAM files found in {split_dir} with suffix {bam_suffix}")
+
+    return bam_files
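
A hedged sketch of how the new demux_and_index_BAM wrapper might be called; the kit name, paths, and thread count are example values only. The leading comment spells out the dorado demux command the function would assemble for those inputs.

    # Hypothetical call (placeholder paths and kit name); for these arguments the wrapper builds roughly:
    #   dorado demux --kit-name SQK-NBD114-24 --no-trim -t 4 --emit-summary --sort-bam \
    #       --output-dir run1/split run1/aligned_sorted.bam
    from smftools.informatics.helpers.demux_and_index_BAM import demux_and_index_BAM

    split_bams = demux_and_index_BAM(
        aligned_sorted_BAM="run1/aligned_sorted",  # bam_suffix is appended to form the input BAM path
        split_dir="run1/split",
        bam_suffix=".bam",
        barcode_kit="SQK-NBD114-24",               # example ONT barcoding kit name
        barcode_both_ends=False,
        trim=False,
        fasta="refs/reference.fasta",
        make_bigwigs=False,
        threads=4,
    )
    print(split_bams)  # sorted list of per-barcode BAM paths, unclassified reads excluded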
smftools/informatics/helpers/extract_base_identities.py
@@ -1,57 +1,44 @@
-## extract_base_identities
-
-# General
 def extract_base_identities(bam_file, chromosome, positions, max_reference_length):
     """
-
+    Efficiently extracts base identities from mapped reads with reference coordinates.

     Parameters:
-
-        chromosome (str):
-        positions (list):
-        max_reference_length (int):
+        bam_file (str): Path to the BAM file.
+        chromosome (str): Name of the reference chromosome.
+        positions (list): Positions to extract (0-based).
+        max_reference_length (int): Maximum reference length for padding.

     Returns:
-
-
+        dict: Base identities from forward mapped reads.
+        dict: Base identities from reverse mapped reads.
     """
-    from .. import readwrite
     import pysam
-
-
+    import numpy as np
+    from collections import defaultdict
+    import time
+
+    timestamp = time.strftime("[%Y-%m-%d %H:%M:%S]")
+
     positions = set(positions)
-
-
-
-    #
-    print('{0}: Reading BAM file: {1}'.format(readwrite.time_string(), bam_file))
+    fwd_base_identities = defaultdict(lambda: np.full(max_reference_length, 'N', dtype='<U1'))
+    rev_base_identities = defaultdict(lambda: np.full(max_reference_length, 'N', dtype='<U1'))
+
+    #print(f"{timestamp} Reading reads from {chromosome} BAM file: {bam_file}")
     with pysam.AlignmentFile(bam_file, "rb") as bam:
-        # Iterate over every read in the bam that comes from the chromosome of interest
-        print('{0}: Iterating over reads in bam'.format(readwrite.time_string()))
         total_reads = bam.mapped
-        for read in
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Initialize the read key in a temp base_identities dictionary by pointing to a N filled list of length reference_length.
-            fwd_base_identities[read.query_name] = ['N'] * max_reference_length
-            # Iterate over a list of tuples for the given read. The tuples contain the 0-indexed position relative to the read.query_sequence start, as well the 0-based index relative to the reference.
-            for read_position, reference_position in read.get_aligned_pairs(matches_only=True):
-                # If the aligned read's reference coordinate is in the positions set and if the read position was successfully mapped
-                if reference_position in positions and read_position:
-                    # get the base_identity in the read corresponding to that position
-                    fwd_base_identities[read.query_name][reference_position] = query_sequence[read_position]
-
-    return fwd_base_identities, rev_base_identities
+        for read in bam.fetch(chromosome):
+            if not read.is_mapped:
+                continue  # Skip unmapped reads
+
+            read_name = read.query_name
+            query_sequence = read.query_sequence
+            base_dict = rev_base_identities if read.is_reverse else fwd_base_identities
+
+            # Use get_aligned_pairs directly with positions filtering
+            aligned_pairs = read.get_aligned_pairs(matches_only=True)
+
+            for read_position, reference_position in aligned_pairs:
+                if reference_position in positions:
+                    base_dict[read_name][reference_position] = query_sequence[read_position]
+
+    return dict(fwd_base_identities), dict(rev_base_identities)
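
A small sketch of the rewritten extractor's interface, assuming a demultiplexed BAM and a record from the converted reference; the path, record name, and length below are placeholders.

    # Hypothetical call; reads are now split by mapping orientation into two dicts.
    from smftools.informatics.helpers.extract_base_identities import extract_base_identities

    fwd_bases, rev_bases = extract_base_identities(
        bam_file="run1/split/barcode01.bam",     # placeholder BAM path
        chromosome="amplicon1_unconverted_top",  # placeholder record name from the converted FASTA
        positions=range(3000),                   # reference coordinates to keep
        max_reference_length=3000,               # padding length for the per-read arrays
    )
    # Each dict maps read name -> numpy array of base calls over the reference,
    # 'N'-filled where the read does not cover a requested position.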
smftools/informatics/helpers/extract_mods.py
@@ -1,6 +1,6 @@
 ## extract_mods

-def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix):
+def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
     """
     Takes all of the aligned, sorted, split modified BAM files and runs Nanopore Modkit Extract to load the modification data into zipped TSV files

@@ -9,6 +9,9 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix):
         mod_tsv_dir (str): A string representing the file path to the directory to hold the modkit extract outputs.
         split_dit (str): A string representing the file path to the directory containing the converted aligned_sorted_split BAM files.
         bam_suffix (str): The suffix to use for the BAM file.
+        skip_unclassified (bool): Whether to skip unclassified bam file for modkit extract command
+        modkit_summary (bool): Whether to run and display modkit summary
+        threads (int): Number of threads to use

     Returns:
         None
@@ -23,29 +26,58 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix):
     os.chdir(mod_tsv_dir)
     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
     bam_files = glob.glob(os.path.join(split_dir, f"*{bam_suffix}"))
+
+    if threads:
+        threads = str(threads)
+    else:
+        pass
+
     for input_file in bam_files:
         print(input_file)
         # Extract the file basename
         file_name = os.path.basename(input_file)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if skip_unclassified and "unclassified" in file_name:
+            print("Skipping modkit extract on unclassified reads")
+        else:
+            # Construct the output TSV file path
+            output_tsv_temp = os.path.join(mod_tsv_dir, file_name)
+            output_tsv = output_tsv_temp.replace(bam_suffix, "") + "_extract.tsv"
+            if os.path.exists(f"{output_tsv}.gz"):
+                print(f"{output_tsv}.gz already exists, skipping modkit extract")
+            else:
+                print(f"Extracting modification data from {input_file}")
+                if modkit_summary:
+                    # Run modkit summary
+                    subprocess.run(["modkit", "summary", input_file])
+                else:
+                    pass
+                # Run modkit extract
+                if threads:
+                    extract_command = [
+                        "modkit", "extract",
+                        "calls", "--mapped-only",
+                        "--filter-threshold", f'{filter_threshold}',
+                        "--mod-thresholds", f"m:{m5C_threshold}",
+                        "--mod-thresholds", f"a:{m6A_threshold}",
+                        "--mod-thresholds", f"h:{hm5C_threshold}",
+                        "-t", threads,
+                        input_file, output_tsv
+                    ]
+                else:
+                    extract_command = [
+                        "modkit", "extract",
+                        "calls", "--mapped-only",
+                        "--filter-threshold", f'{filter_threshold}',
+                        "--mod-thresholds", f"m:{m5C_threshold}",
+                        "--mod-thresholds", f"a:{m6A_threshold}",
+                        "--mod-thresholds", f"h:{hm5C_threshold}",
+                        input_file, output_tsv
+                    ]
+                subprocess.run(extract_command)
+                # Zip the output TSV
+                print(f'zipping {output_tsv}')
+                if threads:
+                    zip_command = ["pigz", "-f", "-p", threads, output_tsv]
+                else:
+                    zip_command = ["pigz", "-f", output_tsv]
+                subprocess.run(zip_command, check=True)
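
For the extended extract_mods signature, a hedged example call; the threshold values and directories are placeholders, and the comments note which keyword arguments are new in 0.1.7.

    # Hypothetical call with the new keyword arguments (placeholder thresholds and paths).
    from smftools.informatics.helpers.extract_mods import extract_mods

    extract_mods(
        thresholds=[0.8, 0.8, 0.8, 0.8],  # filter, m6A, m5C, hm5C thresholds, in that order
        mod_tsv_dir="run1/mod_tsvs",
        split_dir="run1/split",
        bam_suffix=".bam",
        skip_unclassified=True,           # new: skip the unclassified BAM during modkit extract
        modkit_summary=False,             # new: optionally run `modkit summary` on each BAM first
        threads=8,                        # new: forwarded to modkit extract (-t) and pigz (-p)
    )
    # Writes one gzipped <sample>_extract.tsv.gz per BAM into mod_tsv_dir.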
smftools/informatics/helpers/extract_read_features_from_bam.py (new file)
@@ -0,0 +1,31 @@
+# extract_read_features_from_bam
+
+def extract_read_features_from_bam(bam_file_path):
+    """
+    Make a dict of reads from a bam that points to a list of read metrics: read length, read median Q-score, reference length.
+    Params:
+        bam_file_path (str):
+    Returns:
+        read_metrics (dict)
+    """
+    import pysam
+    import numpy as np
+    # Open the BAM file
+    print(f'Extracting read features from BAM: {bam_file_path}')
+    with pysam.AlignmentFile(bam_file_path, "rb") as bam_file:
+        read_metrics = {}
+        reference_lengths = bam_file.lengths  # List of lengths for each reference (chromosome)
+        for read in bam_file:
+            # Skip unmapped reads
+            if read.is_unmapped:
+                continue
+            # Extract the read metrics
+            read_quality = read.query_qualities
+            median_read_quality = np.median(read_quality)
+            # Extract the reference (chromosome) name and its length
+            reference_name = read.reference_name
+            reference_index = bam_file.references.index(reference_name)
+            reference_length = reference_lengths[reference_index]
+            read_metrics[read.query_name] = [read.query_length, median_read_quality, reference_length]
+
+    return read_metrics
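
Finally, a minimal sketch of the new per-read metrics helper; the BAM path is a placeholder.

    # Hypothetical usage: collect per-read length, median Q-score, and reference length.
    from smftools.informatics.helpers.extract_read_features_from_bam import extract_read_features_from_bam

    read_metrics = extract_read_features_from_bam("run1/split/barcode01.bam")
    for read_name, (read_length, median_q, ref_length) in read_metrics.items():
        print(read_name, read_length, median_q, ref_length)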