smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,369 @@
1
+ import numpy as np
2
+ import time
3
+ import os
4
+ import gc
5
+ import pandas as pd
6
+ import anndata as ad
7
+ from tqdm import tqdm
8
+ import multiprocessing
9
+ from multiprocessing import Manager, Lock, current_process, Pool
10
+ import traceback
11
+ import gzip
12
+ import torch
13
+
14
+ from .. import readwrite
15
+ from .binarize_converted_base_identities import binarize_converted_base_identities
16
+ from .find_conversion_sites import find_conversion_sites
17
+ from .count_aligned_reads import count_aligned_reads
18
+ from .extract_base_identities import extract_base_identities
19
+ from .make_dirs import make_dirs
20
+ from .ohe_batching import ohe_batching
21
+
22
+ if __name__ == "__main__":
23
+ multiprocessing.set_start_method("forkserver", force=True)
24
+
25
+ def converted_BAM_to_adata_II(converted_FASTA, split_dir, mapping_threshold, experiment_name, conversion_types, bam_suffix, device='cpu', num_threads=8):
26
+ """
27
+ Converts BAM files into an AnnData object by binarizing converted base identities.
28
+
29
+ Parameters:
30
+ converted_FASTA (str): Path to the converted FASTA reference.
31
+ split_dir (str): Directory containing converted BAM files.
32
+ mapping_threshold (float): Minimum fraction of aligned reads required for inclusion.
33
+ experiment_name (str): Name for the output AnnData object.
34
+ conversion_types (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
35
+ bam_suffix (str): File suffix for BAM files.
36
+ device (str): Requested torch device; the function re-selects CUDA, MPS, or CPU based on availability.
+ num_threads (int): Number of parallel processing threads.
37
+
38
+ Returns:
39
+ tuple: The final AnnData object and the path to its .h5ad output file.
40
+ """
41
+ if torch.cuda.is_available():
42
+ device = torch.device("cuda")
43
+ elif torch.backends.mps.is_available():
44
+ device = torch.device("mps")
45
+ else:
46
+ device = torch.device("cpu")
47
+
48
+ print(f"Using device: {device}")
49
+
50
+ ## Set Up Directories and File Paths
51
+ parent_dir = os.path.dirname(split_dir)
52
+ h5_dir = os.path.join(parent_dir, 'h5ads')
53
+ tmp_dir = os.path.join(parent_dir, 'tmp')
54
+ final_adata_path = os.path.join(h5_dir, f'{experiment_name}_{os.path.basename(split_dir)}.h5ad.gz')
55
+
56
+ if os.path.exists(final_adata_path):
57
+ print(f"{final_adata_path} already exists. Using existing AnnData object.")
58
+ return final_adata_path
59
+
60
+ make_dirs([h5_dir, tmp_dir])
61
+
62
+ ## Get BAM Files ##
63
+ bam_files = [f for f in os.listdir(split_dir) if f.endswith(bam_suffix) and not f.endswith('.bai') and 'unclassified' not in f]
64
+ bam_files.sort()
65
+ bam_path_list = [os.path.join(split_dir, f) for f in bam_files]
66
+ print(f"Found {len(bam_files)} BAM files: {bam_files}")
67
+
68
+ ## Process Conversion Sites
69
+ max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(converted_FASTA, conversion_types)
70
+
71
+ ## Filter BAM Files by Mapping Threshold
72
+ records_to_analyze = filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold)
73
+
74
+ ## Process BAMs in Parallel
75
+ final_adata = process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device)
76
+
77
+ for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
78
+ final_adata.var[f'{chromosome}_top_strand_FASTA_base'] = list(seq)
79
+ final_adata.var[f'{chromosome}_bottom_strand_FASTA_base'] = list(comp)
80
+ final_adata.uns[f'{chromosome}_FASTA_sequence'] = seq
81
+
82
+ ## Save Final AnnData
83
+ # print(f"Saving AnnData to {final_adata_path}")
84
+ # final_adata.write_h5ad(final_adata_path, compression='gzip')
85
+ return final_adata, final_adata_path
86
+
87
+
88
+ def process_conversion_sites(converted_FASTA, conversion_types):
89
+ """
90
+ Extracts conversion sites and determines the max reference length.
91
+
92
+ Parameters:
93
+ converted_FASTA (str): Path to the converted reference FASTA.
94
+ conversion_types (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
95
+
96
+ Returns:
97
+ max_reference_length (int): The length of the longest sequence.
98
+ record_FASTA_dict (dict): Dictionary of sequence information for both converted and unconverted records.
+ chromosome_FASTA_dict (dict): Dictionary mapping each chromosome to its padded top- and bottom-strand sequences.
99
+ """
100
+ modification_dict = {}
101
+ record_FASTA_dict = {}
102
+ chromosome_FASTA_dict = {}
103
+ max_reference_length = 0
104
+ unconverted = conversion_types[0]
105
+ conversions = conversion_types[1:]
106
+
107
+ # Process the unconverted sequence once
108
+ modification_dict[unconverted] = find_conversion_sites(converted_FASTA, unconverted, conversion_types)
109
+ # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
110
+
111
+ # Get **max sequence length** from unconverted records
112
+ max_reference_length = max(values[0] for values in modification_dict[unconverted].values())
113
+
114
+ # Add **unconverted records** to `record_FASTA_dict`
115
+ for record, values in modification_dict[unconverted].items():
116
+ sequence_length, top_coords, bottom_coords, sequence, complement = values
117
+ chromosome = record.replace(f"_{unconverted}_top", "")
118
+
119
+ # Store **original sequence**
120
+ record_FASTA_dict[record] = [
121
+ sequence + "N" * (max_reference_length - sequence_length),
122
+ complement + "N" * (max_reference_length - sequence_length),
123
+ chromosome, record, sequence_length, max_reference_length - sequence_length, unconverted, "top"
124
+ ]
125
+
126
+ if chromosome not in chromosome_FASTA_dict:
127
+ chromosome_FASTA_dict[chromosome] = [sequence + "N" * (max_reference_length - sequence_length), complement + "N" * (max_reference_length - sequence_length)]
128
+
129
+ # Process converted records
130
+ for conversion in conversions:
131
+ modification_dict[conversion] = find_conversion_sites(converted_FASTA, conversion, conversion_types)
132
+ # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
133
+
134
+ for record, values in modification_dict[conversion].items():
135
+ sequence_length, top_coords, bottom_coords, sequence, complement = values
136
+ chromosome = record.split(f"_{unconverted}_")[0] # Extract chromosome name
137
+
138
+ # Add **both strands** for converted records
139
+ for strand in ["top", "bottom"]:
140
+ converted_name = f"{chromosome}_{conversion}_{strand}"
141
+ unconverted_name = f"{chromosome}_{unconverted}_top"
142
+
143
+ record_FASTA_dict[converted_name] = [
144
+ sequence + "N" * (max_reference_length - sequence_length),
145
+ complement + "N" * (max_reference_length - sequence_length),
146
+ chromosome, unconverted_name, sequence_length,
147
+ max_reference_length - sequence_length, conversion, strand
148
+ ]
149
+
150
+ print("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
151
+ return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
152
+
153
+
154
+ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold):
155
+ """Filters BAM files based on mapping threshold."""
156
+ records_to_analyze = set()
157
+
158
+ for i, bam in enumerate(bam_path_list):
159
+ aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam)
160
+ aligned_percent = aligned_reads * 100 / (aligned_reads + unaligned_reads)
161
+ print(f"{aligned_percent:.2f}% of reads in {bam_files[i]} aligned successfully.")
162
+
163
+ for record, (count, percent) in record_counts.items():
164
+ if percent >= mapping_threshold:
165
+ records_to_analyze.add(record)
166
+
167
+ print(f"Analyzing the following FASTA records: {records_to_analyze}")
168
+ return records_to_analyze
169
+
170
+
171
+ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, tmp_dir, max_reference_length, device):
172
+ """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
173
+ adata_list = []
174
+
175
+ for record in records_to_analyze:
176
+ sample = os.path.basename(bam).split(sep=".bam")[0]
177
+ chromosome = record_FASTA_dict[record][2]
178
+ current_length = record_FASTA_dict[record][4]
179
+ mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
180
+
181
+ # Extract Base Identities
182
+ fwd_bases, rev_bases = extract_base_identities(bam, record, range(current_length), max_reference_length)
183
+
184
+ # Skip processing if both forward and reverse base identities are empty
185
+ if not fwd_bases and not rev_bases:
186
+ print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid base identities for {record}.")
187
+ continue
188
+
189
+ merged_bin = {}
190
+
191
+ # Binarize the Base Identities if they exist
192
+ if fwd_bases:
193
+ fwd_bin = binarize_converted_base_identities(fwd_bases, strand, mod_type, bam, device)
194
+ merged_bin.update(fwd_bin)
195
+
196
+ if rev_bases:
197
+ rev_bin = binarize_converted_base_identities(rev_bases, strand, mod_type, bam, device)
198
+ merged_bin.update(rev_bin)
199
+
200
+ # Skip if merged_bin is empty (no valid binarized data)
201
+ if not merged_bin:
202
+ print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid binarized data for {record}.")
203
+ continue
204
+
205
+ # Convert to DataFrame
206
+ # for key in merged_bin:
207
+ # merged_bin[key] = merged_bin[key].cpu().numpy() # Move to CPU & convert to NumPy
208
+ bin_df = pd.DataFrame.from_dict(merged_bin, orient='index')
209
+ sorted_index = sorted(bin_df.index)
210
+ bin_df = bin_df.reindex(sorted_index)
211
+
212
+ # One-Hot Encode Reads if there is valid data
213
+ one_hot_reads = {}
214
+
215
+ if fwd_bases:
216
+ fwd_ohe_files = ohe_batching(fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000)
217
+ for ohe_file in fwd_ohe_files:
218
+ tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
219
+ one_hot_reads.update(tmp_ohe_dict)
220
+ del tmp_ohe_dict
221
+
222
+ if rev_bases:
223
+ rev_ohe_files = ohe_batching(rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000)
224
+ for ohe_file in rev_ohe_files:
225
+ tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
226
+ one_hot_reads.update(tmp_ohe_dict)
227
+ del tmp_ohe_dict
228
+
229
+ # Skip if one_hot_reads is empty
230
+ if not one_hot_reads:
231
+ print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}.")
232
+ continue
233
+
234
+ gc.collect()
235
+
236
+ # Convert One-Hot Encodings to Numpy Arrays
237
+ n_rows_OHE = 5
238
+ read_names = list(one_hot_reads.keys())
239
+
240
+ # Skip if no read names exist
241
+ if not read_names:
242
+ print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}.")
243
+ continue
244
+
245
+ sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
246
+ df_A, df_C, df_G, df_T, df_N = [np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)]
247
+
248
+ # Populate One-Hot Arrays
249
+ for j, read_name in enumerate(sorted_index):
250
+ if read_name in one_hot_reads:
251
+ one_hot_array = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
252
+ df_A[j], df_C[j], df_G[j], df_T[j], df_N[j] = one_hot_array
253
+
254
+ # Convert to AnnData
255
+ X = bin_df.values.astype(np.float32)
256
+ adata = ad.AnnData(X)
257
+ adata.obs_names = bin_df.index.astype(str)
258
+ adata.var_names = bin_df.columns.astype(str)
259
+ adata.obs["Sample"] = [sample] * len(adata)
260
+ adata.obs["Reference"] = [chromosome] * len(adata)
261
+ adata.obs["Strand"] = [strand] * len(adata)
262
+ adata.obs["Dataset"] = [mod_type] * len(adata)
263
+ adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
264
+ adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
265
+
266
+ # Attach One-Hot Encodings to Layers
267
+ adata.layers["A_binary_encoding"] = df_A
268
+ adata.layers["C_binary_encoding"] = df_C
269
+ adata.layers["G_binary_encoding"] = df_G
270
+ adata.layers["T_binary_encoding"] = df_T
271
+ adata.layers["N_binary_encoding"] = df_N
272
+
273
+ adata_list.append(adata)
274
+
275
+ return ad.concat(adata_list, join="outer") if adata_list else None
276
+
277
+ def timestamp():
278
+ """Returns a formatted timestamp for logging."""
279
+ return time.strftime("[%Y-%m-%d %H:%M:%S]")
280
+
281
+
282
+ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, progress_queue):
283
+ """Worker function that processes a single BAM and writes the output to an H5AD file."""
284
+ worker_id = current_process().pid # Get worker process ID
285
+ sample = os.path.basename(bam).split(sep=".bam")[0]
286
+
287
+ try:
288
+ print(f"{timestamp()} [Worker {worker_id}] Processing BAM: {sample}")
289
+
290
+ h5ad_path = os.path.join(h5_dir, f"{sample}.h5ad")
291
+ if os.path.exists(h5ad_path):
292
+ print(f"{timestamp()} [Worker {worker_id}] Skipping {sample}: Already processed.")
293
+ progress_queue.put(sample)
294
+ return
295
+
296
+ # Filter records specific to this BAM
297
+ bam_records_to_analyze = {record for record in records_to_analyze if record in shared_record_FASTA_dict}
298
+
299
+ if not bam_records_to_analyze:
300
+ print(f"{timestamp()} [Worker {worker_id}] No valid records to analyze for {sample}. Skipping.")
301
+ progress_queue.put(sample)
302
+ return
303
+
304
+ # Process BAM
305
+ adata = process_single_bam(bam_index, bam, bam_records_to_analyze, shared_record_FASTA_dict, tmp_dir, max_reference_length, device)
306
+
307
+ if adata is not None:
308
+ adata.write_h5ad(h5ad_path)
309
+ print(f"{timestamp()} [Worker {worker_id}] Completed processing for BAM: {sample}")
310
+
311
+ # Free memory
312
+ del adata
313
+ gc.collect()
314
+
315
+ progress_queue.put(sample)
316
+
317
+ except Exception as e:
318
+ print(f"{timestamp()} [Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}")
319
+ progress_queue.put(sample) # Still signal completion to prevent deadlock
320
+
321
+ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device):
322
+ """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
323
+ os.makedirs(h5_dir, exist_ok=True) # Ensure h5_dir exists
324
+
325
+ print(f"{timestamp()} Starting parallel BAM processing with {num_threads} threads...")
326
+
327
+ # Ensure macOS uses forkserver to avoid spawning issues
328
+ try:
329
+ import multiprocessing
330
+ multiprocessing.set_start_method("forkserver", force=True)
331
+ except RuntimeError:
332
+ print(f"{timestamp()} [WARNING] Multiprocessing context already set. Skipping set_start_method.")
333
+
334
+ with Manager() as manager:
335
+ progress_queue = manager.Queue()
336
+ shared_record_FASTA_dict = manager.dict(record_FASTA_dict)
337
+
338
+ with Pool(processes=num_threads) as pool:
339
+ results = [
340
+ pool.apply_async(worker_function, (i, bam, records_to_analyze, shared_record_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, progress_queue))
341
+ for i, bam in enumerate(bam_path_list)
342
+ ]
343
+
344
+ print(f"{timestamp()} Submitted {len(bam_path_list)} BAMs for processing.")
345
+
346
+ # Track completed BAMs
347
+ completed_bams = set()
348
+ while len(completed_bams) < len(bam_path_list):
349
+ try:
350
+ processed_bam = progress_queue.get(timeout=2400) # Wait for a finished BAM
351
+ completed_bams.add(processed_bam)
352
+ except Exception as e:
353
+ print(f"{timestamp()} [ERROR] Timeout waiting for worker process. Possible crash? {e}")
354
+
355
+ pool.close()
356
+ pool.join() # Ensure all workers finish
357
+
358
+ # Final Concatenation Step
359
+ h5ad_files = [os.path.join(h5_dir, f) for f in os.listdir(h5_dir) if f.endswith(".h5ad")]
360
+
361
+ if not h5ad_files:
362
+ print(f"{timestamp()} No valid H5AD files generated. Exiting.")
363
+ return None
364
+
365
+ print(f"{timestamp()} Concatenating {len(h5ad_files)} H5AD files into final output...")
366
+ final_adata = ad.concat([ad.read_h5ad(f) for f in h5ad_files], join="outer")
367
+
368
+ print(f"{timestamp()} Successfully generated final AnnData object.")
369
+ return final_adata
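
The hunk above adds smftools/informatics/helpers/converted_BAM_to_adata_II.py (item 17 in the file list). As a minimal sketch of how this entry point might be invoked; the paths, experiment name, conversion types, and threshold below are hypothetical placeholders, and the direct module import assumes it is not re-exported elsewhere:

from smftools.informatics.helpers.converted_BAM_to_adata_II import converted_BAM_to_adata_II

# Assemble demultiplexed, converted BAMs into a single AnnData object.
adata, adata_path = converted_BAM_to_adata_II(
    converted_FASTA="refs/converted_reference.fasta",  # placeholder reference path
    split_dir="results/split_bams",                    # directory of demultiplexed BAMs
    mapping_threshold=0.8,                             # placeholder minimum mapping fraction per record
    experiment_name="example_run",
    conversion_types=["unconverted", "5mC"],
    bam_suffix=".bam",
    num_threads=4,
)
print(adata)
print(f"AnnData would be written to {adata_path}")
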
@@ -0,0 +1,52 @@
1
+ ## demux_and_index_BAM
2
+
3
+ def demux_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, barcode_kit, barcode_both_ends, trim, fasta, make_bigwigs, threads):
4
+ """
5
+ A wrapper function for splitting BAMS and indexing them.
6
+ Parameters:
7
+ aligned_sorted_BAM (str): A string representing the file path of the aligned_sorted BAM file.
8
+ split_dir (str): A string representing the file path to the directory to split the BAMs into.
9
+ bam_suffix (str): A suffix to add to the bam file.
10
+ barcode_kit (str): Name of barcoding kit.
11
+ barcode_both_ends (bool): Whether to require both ends to be barcoded.
12
+ trim (bool): Whether to trim off barcodes after demultiplexing.
13
+ fasta (str): File path to the reference genome to align to.
14
+ make_bigwigs (bool): Whether to make bigwigs
15
+ threads (int): Number of threads to use.
16
+
17
+ Returns:
18
+ bam_files (list): List of split BAM file path strings
19
+ Splits an input BAM file on barcode value and makes a BAM index file.
20
+ """
21
+ from .. import readwrite
22
+ import os
23
+ import subprocess
24
+ import glob
25
+ from .make_dirs import make_dirs
26
+
27
+ input_bam = aligned_sorted_BAM + bam_suffix
28
+ command = ["dorado", "demux", "--kit-name", barcode_kit]
29
+ if barcode_both_ends:
30
+ command.append("--barcode-both-ends")
31
+ if not trim:
32
+ command.append("--no-trim")
33
+ if threads:
34
+ command += ["-t", str(threads)]
35
+ else:
36
+ pass
37
+ command += ["--emit-summary", "--sort-bam", "--output-dir", split_dir]
38
+ command.append(input_bam)
39
+ command_string = ' '.join(command)
40
+ print(f"Running: {command_string}")
41
+ subprocess.run(command)
42
+
43
+ # Collect the demultiplexed BAM files produced in that directory
44
+ bam_pattern = '*' + bam_suffix
45
+ bam_files = glob.glob(os.path.join(split_dir, bam_pattern))
46
+ bam_files = [bam for bam in bam_files if '.bai' not in bam and 'unclassified' not in bam]
47
+ bam_files.sort()
48
+
49
+ if not bam_files:
50
+ raise FileNotFoundError(f"No BAM files found in {split_dir} with suffix {bam_suffix}")
51
+
52
+ return bam_files
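
The hunk above adds smftools/informatics/helpers/demux_and_index_BAM.py (item 18), a thin wrapper around dorado demux. A hedged usage sketch, assuming dorado is on PATH; the kit name and paths are illustrative, and aligned_sorted_BAM is passed without its file suffix:

from smftools.informatics.helpers.demux_and_index_BAM import demux_and_index_BAM

split_bams = demux_and_index_BAM(
    aligned_sorted_BAM="results/aligned_sorted",  # BAM path without the suffix
    split_dir="results/split_bams",
    bam_suffix=".bam",
    barcode_kit="SQK-NBD114-24",                  # example barcoding kit name
    barcode_both_ends=False,
    trim=True,
    fasta="refs/converted_reference.fasta",       # accepted but not used inside this wrapper
    make_bigwigs=False,                           # accepted but not used inside this wrapper
    threads=8,
)
print(split_bams)
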
@@ -1,57 +1,44 @@
1
- ## extract_base_identities
2
-
3
- # General
4
1
  def extract_base_identities(bam_file, chromosome, positions, max_reference_length):
5
2
  """
6
- Extracts the base identities from every position within the mapped reads that have a reference coordinate
3
+ Efficiently extracts base identities from mapped reads with reference coordinates.
7
4
 
8
5
  Parameters:
9
- bam (str): File path to the BAM file to align (excluding the file suffix).
10
- chromosome (str): A string representing the name of the record within the reference FASTA.
11
- positions (list): A list of position coordinates within the record to extract.
12
- max_reference_length (int): The maximum length of a record in the reference set.
6
+ bam_file (str): Path to the BAM file.
7
+ chromosome (str): Name of the reference chromosome.
8
+ positions (list): Positions to extract (0-based).
9
+ max_reference_length (int): Maximum reference length for padding.
13
10
 
14
11
  Returns:
15
- fwd_base_identities (dict): A dictionary, keyed by read name, that points to a list of base identities from forward mapped reads. If the read does not contain that position, fill the list at that index with a N value.
16
- rev_base_identities (dict): A dictionary, keyed by read name, that points to a list of base identities from reverse mapped reads. If the read does not contain that position, fill the list at that index with a N value.
12
+ dict: Base identities from forward mapped reads.
13
+ dict: Base identities from reverse mapped reads.
17
14
  """
18
- from .. import readwrite
19
15
  import pysam
20
- from tqdm import tqdm
21
-
16
+ import numpy as np
17
+ from collections import defaultdict
18
+ import time
19
+
20
+ timestamp = time.strftime("[%Y-%m-%d %H:%M:%S]")
21
+
22
22
  positions = set(positions)
23
- # Initialize a base identity dictionary that will hold key-value pairs that are: key (read-name) and value (list of base identities at positions of interest)
24
- fwd_base_identities = {}
25
- rev_base_identities = {}
26
- # Open the postion sorted BAM file
27
- print('{0}: Reading BAM file: {1}'.format(readwrite.time_string(), bam_file))
23
+ fwd_base_identities = defaultdict(lambda: np.full(max_reference_length, 'N', dtype='<U1'))
24
+ rev_base_identities = defaultdict(lambda: np.full(max_reference_length, 'N', dtype='<U1'))
25
+
26
+ #print(f"{timestamp} Reading reads from {chromosome} BAM file: {bam_file}")
28
27
  with pysam.AlignmentFile(bam_file, "rb") as bam:
29
- # Iterate over every read in the bam that comes from the chromosome of interest
30
- print('{0}: Iterating over reads in bam'.format(readwrite.time_string()))
31
28
  total_reads = bam.mapped
32
- for read in tqdm(bam.fetch(chromosome), desc='Extracting base identities from reads in BAM', total=total_reads):
33
- # Only iterate over mapped reads
34
- if read.is_mapped:
35
- # Get sequence of read. PySam reports fwd mapped reads as the true read sequence. Pysam reports rev mapped reads as the reverse complement of the read.
36
- query_sequence = read.query_sequence
37
- # If the read aligned as a reverse complement, mark that the read is reversed
38
- if read.is_reverse:
39
- # Initialize the read key in a temp base_identities dictionary by pointing to a N filled list of length reference_length.
40
- rev_base_identities[read.query_name] = ['N'] * max_reference_length
41
- # Iterate over a list of tuples for the given read. The tuples contain the 0-indexed position relative to the read.query_sequence start, as well the 0-based index relative to the reference.
42
- for read_position, reference_position in read.get_aligned_pairs(matches_only=True):
43
- # If the aligned read's reference coordinate is in the positions set and if the read position was successfully mapped
44
- if reference_position in positions and read_position:
45
- # get the base_identity in the read corresponding to that position
46
- rev_base_identities[read.query_name][reference_position] = query_sequence[read_position]
47
- else:
48
- # Initialize the read key in a temp base_identities dictionary by pointing to a N filled list of length reference_length.
49
- fwd_base_identities[read.query_name] = ['N'] * max_reference_length
50
- # Iterate over a list of tuples for the given read. The tuples contain the 0-indexed position relative to the read.query_sequence start, as well the 0-based index relative to the reference.
51
- for read_position, reference_position in read.get_aligned_pairs(matches_only=True):
52
- # If the aligned read's reference coordinate is in the positions set and if the read position was successfully mapped
53
- if reference_position in positions and read_position:
54
- # get the base_identity in the read corresponding to that position
55
- fwd_base_identities[read.query_name][reference_position] = query_sequence[read_position]
56
-
57
- return fwd_base_identities, rev_base_identities
29
+ for read in bam.fetch(chromosome):
30
+ if not read.is_mapped:
31
+ continue # Skip unmapped reads
32
+
33
+ read_name = read.query_name
34
+ query_sequence = read.query_sequence
35
+ base_dict = rev_base_identities if read.is_reverse else fwd_base_identities
36
+
37
+ # Use get_aligned_pairs directly with positions filtering
38
+ aligned_pairs = read.get_aligned_pairs(matches_only=True)
39
+
40
+ for read_position, reference_position in aligned_pairs:
41
+ if reference_position in positions:
42
+ base_dict[read_name][reference_position] = query_sequence[read_position]
43
+
44
+ return dict(fwd_base_identities), dict(rev_base_identities)
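
The hunk above rewrites smftools/informatics/helpers/extract_base_identities.py (item 19) to return NumPy-backed per-read base arrays instead of plain lists. A small sketch of the new call pattern, with placeholder paths, record name, and length:

from smftools.informatics.helpers.extract_base_identities import extract_base_identities

record_length = 4000                               # placeholder length of the reference record
fwd, rev = extract_base_identities(
    bam_file="results/split_bams/barcode01.bam",   # placeholder BAM path
    chromosome="example_record",                   # reference record name in the BAM header
    positions=range(record_length),                # 0-based reference positions to extract
    max_reference_length=record_length,
)
# Each dict maps a read name to a length-max_reference_length array of bases, 'N' where uncovered.
print(len(fwd), len(rev))
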
@@ -1,6 +1,6 @@
1
1
  ## extract_mods
2
2
 
3
- def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix):
3
+ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
4
4
  """
5
5
  Takes all of the aligned, sorted, split modified BAM files and runs Nanopore Modkit Extract to load the modification data into zipped TSV files
6
6
 
@@ -9,6 +9,9 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix):
9
9
  mod_tsv_dir (str): A string representing the file path to the directory to hold the modkit extract outputs.
10
10
  split_dir (str): A string representing the file path to the directory containing the converted aligned_sorted_split BAM files.
11
11
  bam_suffix (str): The suffix to use for the BAM file.
12
+ skip_unclassified (bool): Whether to skip the unclassified BAM file when running modkit extract.
+ modkit_summary (bool): Whether to run and display modkit summary.
+ threads (int): Number of threads to use.
12
15
 
13
16
  Returns:
14
17
  None
@@ -23,29 +26,58 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix):
23
26
  os.chdir(mod_tsv_dir)
24
27
  filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
25
28
  bam_files = glob.glob(os.path.join(split_dir, f"*{bam_suffix}"))
29
+
30
+ if threads:
31
+ threads = str(threads)
32
+ else:
33
+ pass
34
+
26
35
  for input_file in bam_files:
27
36
  print(input_file)
28
37
  # Extract the file basename
29
38
  file_name = os.path.basename(input_file)
30
- # Construct the output TSV file path
31
- output_tsv_temp = os.path.join(mod_tsv_dir, file_name)
32
- output_tsv = output_tsv_temp.replace(bam_suffix, "") + "_extract.tsv"
33
- # Run modkit summary
34
- subprocess.run(["modkit", "summary", input_file])
35
- # Run modkit extract
36
- subprocess.run([
37
- "modkit", "extract",
38
- "--filter-threshold", f'{filter_threshold}',
39
- "--mod-thresholds", f"m:{m5C_threshold}",
40
- "--mod-thresholds", f"a:{m6A_threshold}",
41
- "--mod-thresholds", f"h:{hm5C_threshold}",
42
- input_file, "null",
43
- "--read-calls", output_tsv
44
- ])
45
- # Zip the output TSV
46
- print(f'zipping {output_tsv}')
47
- with zipfile.ZipFile(f"{output_tsv}.zip", 'w', zipfile.ZIP_DEFLATED) as zipf:
48
- zipf.write(output_tsv, os.path.basename(output_tsv))
49
- # Remove the non-zipped TSV
50
- print(f'removing {output_tsv}')
51
- os.remove(output_tsv)
39
+ if skip_unclassified and "unclassified" in file_name:
40
+ print("Skipping modkit extract on unclassified reads")
41
+ else:
42
+ # Construct the output TSV file path
43
+ output_tsv_temp = os.path.join(mod_tsv_dir, file_name)
44
+ output_tsv = output_tsv_temp.replace(bam_suffix, "") + "_extract.tsv"
45
+ if os.path.exists(f"{output_tsv}.gz"):
46
+ print(f"{output_tsv}.gz already exists, skipping modkit extract")
47
+ else:
48
+ print(f"Extracting modification data from {input_file}")
49
+ if modkit_summary:
50
+ # Run modkit summary
51
+ subprocess.run(["modkit", "summary", input_file])
52
+ else:
53
+ pass
54
+ # Run modkit extract
55
+ if threads:
56
+ extract_command = [
57
+ "modkit", "extract",
58
+ "calls", "--mapped-only",
59
+ "--filter-threshold", f'{filter_threshold}',
60
+ "--mod-thresholds", f"m:{m5C_threshold}",
61
+ "--mod-thresholds", f"a:{m6A_threshold}",
62
+ "--mod-thresholds", f"h:{hm5C_threshold}",
63
+ "-t", threads,
64
+ input_file, output_tsv
65
+ ]
66
+ else:
67
+ extract_command = [
68
+ "modkit", "extract",
69
+ "calls", "--mapped-only",
70
+ "--filter-threshold", f'{filter_threshold}',
71
+ "--mod-thresholds", f"m:{m5C_threshold}",
72
+ "--mod-thresholds", f"a:{m6A_threshold}",
73
+ "--mod-thresholds", f"h:{hm5C_threshold}",
74
+ input_file, output_tsv
75
+ ]
76
+ subprocess.run(extract_command)
77
+ # Zip the output TSV
78
+ print(f'zipping {output_tsv}')
79
+ if threads:
80
+ zip_command = ["pigz", "-f", "-p", threads, output_tsv]
81
+ else:
82
+ zip_command = ["pigz", "-f", output_tsv]
83
+ subprocess.run(zip_command, check=True)
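
The hunk above reworks smftools/informatics/helpers/extract_mods.py (item 20): unclassified BAMs can be skipped, existing outputs are reused, modkit summary becomes optional, and TSV compression is delegated to pigz. A hedged call sketch, assuming modkit and pigz are on PATH; the threshold values and paths are placeholders:

from smftools.informatics.helpers.extract_mods import extract_mods

extract_mods(
    thresholds=[0.8, 0.8, 0.8, 0.8],  # filter, m6A, m5C, hm5C thresholds (placeholders)
    mod_tsv_dir="results/mod_tsvs",
    split_dir="results/split_bams",
    bam_suffix=".bam",
    skip_unclassified=True,
    modkit_summary=False,
    threads=8,
)
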
@@ -0,0 +1,31 @@
1
+ # extract_read_features_from_bam
2
+
3
+ def extract_read_features_from_bam(bam_file_path):
4
+ """
5
+ Build a dictionary keyed by read name that maps to a list of read metrics: read length, median read Q-score, and reference length.
+ Params:
+ bam_file_path (str): Path to the BAM file to read.
+ Returns:
+ read_metrics (dict): Dictionary keyed by read name, with values [read length, median Q-score, reference length].
10
+ """
11
+ import pysam
12
+ import numpy as np
13
+ # Open the BAM file
14
+ print(f'Extracting read features from BAM: {bam_file_path}')
15
+ with pysam.AlignmentFile(bam_file_path, "rb") as bam_file:
16
+ read_metrics = {}
17
+ reference_lengths = bam_file.lengths # List of lengths for each reference (chromosome)
18
+ for read in bam_file:
19
+ # Skip unmapped reads
20
+ if read.is_unmapped:
21
+ continue
22
+ # Extract the read metrics
23
+ read_quality = read.query_qualities
24
+ median_read_quality = np.median(read_quality)
25
+ # Extract the reference (chromosome) name and its length
26
+ reference_name = read.reference_name
27
+ reference_index = bam_file.references.index(reference_name)
28
+ reference_length = reference_lengths[reference_index]
29
+ read_metrics[read.query_name] = [read.query_length, median_read_quality, reference_length]
30
+
31
+ return read_metrics
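
The hunk above adds smftools/informatics/helpers/extract_read_features_from_bam.py (item 21). A minimal sketch of how the returned dictionary might be consumed, with a placeholder BAM path:

from smftools.informatics.helpers.extract_read_features_from_bam import extract_read_features_from_bam

read_metrics = extract_read_features_from_bam("results/split_bams/barcode01.bam")  # placeholder path
# Inspect the first few mapped reads: length, median Q-score, and reference length.
for read_name, (read_length, median_q, ref_length) in list(read_metrics.items())[:5]:
    print(read_name, read_length, median_q, ref_length)
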