smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/helpers/converted_BAM_to_adata_II.py
@@ -0,0 +1,505 @@
+ import numpy as np
+ import time
+ import os
+ import gc
+ import pandas as pd
+ import anndata as ad
+ from tqdm import tqdm
+ import multiprocessing
+ from multiprocessing import Manager, Lock, current_process, Pool
+ import traceback
+ import gzip
+ import torch
+
+ import shutil
+ from pathlib import Path
+ from typing import Union, Iterable, Optional
+
+ from ... import readwrite
+ from .binarize_converted_base_identities import binarize_converted_base_identities
+ from .find_conversion_sites import find_conversion_sites
+ from .count_aligned_reads import count_aligned_reads
+ from .extract_base_identities import extract_base_identities
+ from .make_dirs import make_dirs
+ from .ohe_batching import ohe_batching
+
+ if __name__ == "__main__":
+     multiprocessing.set_start_method("forkserver", force=True)
+
+ def converted_BAM_to_adata_II(converted_FASTA,
+                               split_dir,
+                               mapping_threshold,
+                               experiment_name,
+                               conversions,
+                               bam_suffix,
+                               device='cpu',
+                               num_threads=8,
+                               deaminase_footprinting=False,
+                               delete_intermediates=True
+                               ):
+     """
+     Converts BAM files into an AnnData object by binarizing modified base identities.
+
+     Parameters:
+         converted_FASTA (str): Path to the converted FASTA reference.
+         split_dir (str): Directory containing converted BAM files.
+         mapping_threshold (float): Minimum fraction of aligned reads required for inclusion.
+         experiment_name (str): Name for the output AnnData object.
+         conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+         bam_suffix (str): File suffix for BAM files.
+         device (str): Requested compute device; overridden below by automatic CUDA/MPS detection.
+         num_threads (int): Number of parallel processing threads.
+         deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
+         delete_intermediates (bool): Whether to delete intermediate h5ad files and the temp directory after saving.
+
+     Returns:
+         final_adata (AnnData or None): The concatenated AnnData object (None if the output file already existed).
+         final_adata_path (str): Path to the final AnnData object on disk.
+     """
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     elif torch.backends.mps.is_available():
+         device = torch.device("mps")
+     else:
+         device = torch.device("cpu")
+
+     print(f"Using device: {device}")
+
+     ## Set Up Directories and File Paths
+     # parent_dir = os.path.dirname(split_dir)
+     h5_dir = os.path.join(split_dir, 'h5ads')
+     tmp_dir = os.path.join(split_dir, 'tmp')
+     final_adata = None
+     final_adata_path = os.path.join(h5_dir, f'{experiment_name}_{os.path.basename(split_dir)}.h5ad.gz')
+
+     if os.path.exists(final_adata_path):
+         print(f"{final_adata_path} already exists. Using existing AnnData object.")
+         return final_adata, final_adata_path
+
+     make_dirs([h5_dir, tmp_dir])
+
+     ## Get BAM Files ##
+     bam_files = [f for f in os.listdir(split_dir) if f.endswith(bam_suffix) and not f.endswith('.bai') and 'unclassified' not in f]
+     bam_files.sort()
+     bam_path_list = [os.path.join(split_dir, f) for f in bam_files]
+     print(f"Found {len(bam_files)} BAM files: {bam_files}")
+
+     ## Process Conversion Sites
+     max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(converted_FASTA, conversions, deaminase_footprinting)
+
+     ## Filter BAM Files by Mapping Threshold
+     records_to_analyze = filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold)
+
+     ## Process BAMs in Parallel
+     final_adata = process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device, deaminase_footprinting)
+
+     for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
+         final_adata.var[f'{chromosome}_top_strand_FASTA_base'] = list(seq)
+         final_adata.var[f'{chromosome}_bottom_strand_FASTA_base'] = list(comp)
+         final_adata.uns[f'{chromosome}_FASTA_sequence'] = seq
+
+     final_adata.obs_names_make_unique()
+     cols = final_adata.obs.columns
+
+     # Make obs cols categorical
+     for col in cols:
+         final_adata.obs[col] = final_adata.obs[col].astype('category')
+
+     ## Save Final AnnData
+     print(f"Saving AnnData to {final_adata_path}")
+     backup_dir = os.path.join(os.path.dirname(final_adata_path), 'adata_accessory_data')
+     readwrite.safe_write_h5ad(final_adata, final_adata_path, compression='gzip', backup=True, backup_dir=backup_dir)
+
+     ## Delete intermediate h5ad files and temp directories
+     if delete_intermediates:
+         delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)
+
+     return final_adata, final_adata_path
+
+
+ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC'], deaminase_footprinting=False):
+     """
+     Extracts conversion sites and determines the max reference length.
+
+     Parameters:
+         converted_FASTA (str): Path to the converted reference FASTA.
+         conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+         deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
+
+     Returns:
+         max_reference_length (int): The length of the longest sequence.
+         record_FASTA_dict (dict): Dictionary of sequence information for **both converted & unconverted** records.
+         chromosome_FASTA_dict (dict): Dictionary mapping each chromosome to its padded top- and bottom-strand sequences.
+     """
+     modification_dict = {}
+     record_FASTA_dict = {}
+     chromosome_FASTA_dict = {}
+     max_reference_length = 0
+     unconverted = conversions[0]
+     conversion_types = conversions[1:]
+
+     # Process the unconverted sequence once
+     modification_dict[unconverted] = find_conversion_sites(converted_FASTA, unconverted, conversions, deaminase_footprinting)
+     # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
+
+     # Get **max sequence length** from unconverted records
+     max_reference_length = max(values[0] for values in modification_dict[unconverted].values())
+
+     # Add **unconverted records** to `record_FASTA_dict`
+     for record, values in modification_dict[unconverted].items():
+         sequence_length, top_coords, bottom_coords, sequence, complement = values
+
+         if not deaminase_footprinting:
+             chromosome = record.replace(f"_{unconverted}_top", "")
+         else:
+             chromosome = record
+
+         # Store **original sequence**
+         record_FASTA_dict[record] = [
+             sequence + "N" * (max_reference_length - sequence_length),
+             complement + "N" * (max_reference_length - sequence_length),
+             chromosome, record, sequence_length, max_reference_length - sequence_length, unconverted, "top"
+         ]
+
+         if chromosome not in chromosome_FASTA_dict:
+             chromosome_FASTA_dict[chromosome] = [sequence + "N" * (max_reference_length - sequence_length), complement + "N" * (max_reference_length - sequence_length)]
+
+     # Process converted records
+     for conversion in conversion_types:
+         modification_dict[conversion] = find_conversion_sites(converted_FASTA, conversion, conversions, deaminase_footprinting)
+         # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
+
+         for record, values in modification_dict[conversion].items():
+             sequence_length, top_coords, bottom_coords, sequence, complement = values
+
+             if not deaminase_footprinting:
+                 chromosome = record.split(f"_{unconverted}_")[0]  # Extract chromosome name
+             else:
+                 chromosome = record
+
+             # Add **both strands** for converted records
+             for strand in ["top", "bottom"]:
+                 converted_name = f"{chromosome}_{conversion}_{strand}"
+                 unconverted_name = f"{chromosome}_{unconverted}_top"
+
+                 record_FASTA_dict[converted_name] = [
+                     sequence + "N" * (max_reference_length - sequence_length),
+                     complement + "N" * (max_reference_length - sequence_length),
+                     chromosome, unconverted_name, sequence_length,
+                     max_reference_length - sequence_length, conversion, strand
+                 ]
+
+     print("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
+     return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
+
+
+ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold):
+     """Filters BAM files based on mapping threshold."""
+     records_to_analyze = set()
+
+     for i, bam in enumerate(bam_path_list):
+         aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam)
+         aligned_percent = aligned_reads * 100 / (aligned_reads + unaligned_reads)
+         print(f"{aligned_percent:.2f}% of reads in {bam_files[i]} aligned successfully.")
+
+         for record, (count, percent) in record_counts.items():
+             if percent >= mapping_threshold:
+                 records_to_analyze.add(record)
+
+     print(f"Analyzing the following FASTA records: {records_to_analyze}")
+     return records_to_analyze
+
+
+ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, max_reference_length, device, deaminase_footprinting):
+     """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
+     adata_list = []
+
+     for record in records_to_analyze:
+         sample = os.path.basename(bam).split(sep=".bam")[0]
+         chromosome = record_FASTA_dict[record][2]
+         current_length = record_FASTA_dict[record][4]
+         mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
+         sequence = chromosome_FASTA_dict[chromosome][0]
+
+         # Extract Base Identities
+         fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, range(current_length), max_reference_length, sequence)
+         mismatch_trend_series = pd.Series(mismatch_trend_per_read)
+
+         # Skip processing if both forward and reverse base identities are empty
+         if not fwd_bases and not rev_bases:
+             print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid base identities for {record}.")
+             continue
+
+         merged_bin = {}
+
+         # Binarize the Base Identities if they exist
+         if fwd_bases:
+             fwd_bin = binarize_converted_base_identities(fwd_bases, strand, mod_type, bam, device, deaminase_footprinting, mismatch_trend_per_read)
+             merged_bin.update(fwd_bin)
+
+         if rev_bases:
+             rev_bin = binarize_converted_base_identities(rev_bases, strand, mod_type, bam, device, deaminase_footprinting, mismatch_trend_per_read)
+             merged_bin.update(rev_bin)
+
+         # Skip if merged_bin is empty (no valid binarized data)
+         if not merged_bin:
+             print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid binarized data for {record}.")
+             continue
+
+         # Convert to DataFrame
+         # for key in merged_bin:
+         #     merged_bin[key] = merged_bin[key].cpu().numpy()  # Move to CPU & convert to NumPy
+         bin_df = pd.DataFrame.from_dict(merged_bin, orient='index')
+         sorted_index = sorted(bin_df.index)
+         bin_df = bin_df.reindex(sorted_index)
+
+         # One-Hot Encode Reads if there is valid data
+         one_hot_reads = {}
+
+         if fwd_bases:
+             fwd_ohe_files = ohe_batching(fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000)
+             for ohe_file in fwd_ohe_files:
+                 tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
+                 one_hot_reads.update(tmp_ohe_dict)
+             del tmp_ohe_dict
+
+         if rev_bases:
+             rev_ohe_files = ohe_batching(rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000)
+             for ohe_file in rev_ohe_files:
+                 tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
+                 one_hot_reads.update(tmp_ohe_dict)
+             del tmp_ohe_dict
+
+         # Skip if one_hot_reads is empty
+         if not one_hot_reads:
+             print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}.")
+             continue
+
+         gc.collect()
+
+         # Convert One-Hot Encodings to Numpy Arrays
+         n_rows_OHE = 5
+         read_names = list(one_hot_reads.keys())
+
+         # Skip if no read names exist
+         if not read_names:
+             print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}.")
+             continue
+
+         sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+         df_A, df_C, df_G, df_T, df_N = [np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)]
+
+         # Populate One-Hot Arrays
+         for j, read_name in enumerate(sorted_index):
+             if read_name in one_hot_reads:
+                 one_hot_array = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
+                 df_A[j], df_C[j], df_G[j], df_T[j], df_N[j] = one_hot_array
+
+         # Convert to AnnData
+         X = bin_df.values.astype(np.float32)
+         adata = ad.AnnData(X)
+         adata.obs_names = bin_df.index.astype(str)
+         adata.var_names = bin_df.columns.astype(str)
+         adata.obs["Sample"] = [sample] * len(adata)
+         # Parse the numeric barcode from the sample name; fall back to NaN if it is absent or non-numeric
+         try:
+             barcode = int(sample.split('barcode')[1])
+         except (IndexError, ValueError):
+             barcode = np.nan
+         adata.obs["Barcode"] = [barcode] * len(adata)
+         adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
+         adata.obs["Reference"] = [chromosome] * len(adata)
+         adata.obs["Strand"] = [strand] * len(adata)
+         adata.obs["Dataset"] = [mod_type] * len(adata)
+         adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
+         adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
+         adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)
+
+         # Attach One-Hot Encodings to Layers
+         adata.layers["A_binary_encoding"] = df_A
+         adata.layers["C_binary_encoding"] = df_C
+         adata.layers["G_binary_encoding"] = df_G
+         adata.layers["T_binary_encoding"] = df_T
+         adata.layers["N_binary_encoding"] = df_N
+
+         adata_list.append(adata)
+
+     return ad.concat(adata_list, join="outer") if adata_list else None
+
+ def timestamp():
+     """Returns a formatted timestamp for logging."""
+     return time.strftime("[%Y-%m-%d %H:%M:%S]")
+
+
+ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, deaminase_footprinting, progress_queue):
+     """Worker function that processes a single BAM and writes the output to an H5AD file."""
+     worker_id = current_process().pid  # Get worker process ID
+     sample = os.path.basename(bam).split(sep=".bam")[0]
+
+     try:
+         print(f"{timestamp()} [Worker {worker_id}] Processing BAM: {sample}")
+
+         h5ad_path = os.path.join(h5_dir, f"{sample}.h5ad")
+         if os.path.exists(h5ad_path):
+             print(f"{timestamp()} [Worker {worker_id}] Skipping {sample}: Already processed.")
+             progress_queue.put(sample)
+             return
+
+         # Filter records specific to this BAM
+         bam_records_to_analyze = {record for record in records_to_analyze if record in shared_record_FASTA_dict}
+
+         if not bam_records_to_analyze:
+             print(f"{timestamp()} [Worker {worker_id}] No valid records to analyze for {sample}. Skipping.")
+             progress_queue.put(sample)
+             return
+
+         # Process BAM
+         adata = process_single_bam(bam_index, bam, bam_records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, max_reference_length, device, deaminase_footprinting)
+
+         if adata is not None:
+             adata.write_h5ad(h5ad_path)
+             print(f"{timestamp()} [Worker {worker_id}] Completed processing for BAM: {sample}")
+
+         # Free memory
+         del adata
+         gc.collect()
+
+         progress_queue.put(sample)
+
+     except Exception as e:
+         print(f"{timestamp()} [Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}")
+         progress_queue.put(sample)  # Still signal completion to prevent deadlock
+
+ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device, deaminase_footprinting):
+     """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
+     os.makedirs(h5_dir, exist_ok=True)  # Ensure h5_dir exists
+
+     print(f"{timestamp()} Starting parallel BAM processing with {num_threads} threads...")
+
+     # Ensure macOS uses forkserver to avoid spawning issues
+     try:
+         import multiprocessing
+         multiprocessing.set_start_method("forkserver", force=True)
+     except RuntimeError:
+         print(f"{timestamp()} [WARNING] Multiprocessing context already set. Skipping set_start_method.")
+
+     with Manager() as manager:
+         progress_queue = manager.Queue()
+         shared_record_FASTA_dict = manager.dict(record_FASTA_dict)
+
+         with Pool(processes=num_threads) as pool:
+             results = [
+                 pool.apply_async(worker_function, (i, bam, records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, deaminase_footprinting, progress_queue))
+                 for i, bam in enumerate(bam_path_list)
+             ]
+
+             print(f"{timestamp()} Submitted {len(bam_path_list)} BAMs for processing.")
+
+             # Track completed BAMs
+             completed_bams = set()
+             while len(completed_bams) < len(bam_path_list):
+                 try:
+                     processed_bam = progress_queue.get(timeout=2400)  # Wait for a finished BAM
+                     completed_bams.add(processed_bam)
+                 except Exception as e:
+                     print(f"{timestamp()} [ERROR] Timeout waiting for worker process. Possible crash? {e}")
+
+             pool.close()
+             pool.join()  # Ensure all workers finish
+
+     # Final Concatenation Step
+     h5ad_files = [os.path.join(h5_dir, f) for f in os.listdir(h5_dir) if f.endswith(".h5ad")]
+
+     if not h5ad_files:
+         print(f"{timestamp()} No valid H5AD files generated. Exiting.")
+         return None
+
+     print(f"{timestamp()} Concatenating {len(h5ad_files)} H5AD files into final output...")
+     final_adata = ad.concat([ad.read_h5ad(f) for f in h5ad_files], join="outer")
+
+     print(f"{timestamp()} Successfully generated final AnnData object.")
+     return final_adata
+
+ def delete_intermediate_h5ads_and_tmpdir(
+     h5_dir: Union[str, Path, Iterable[str], None],
+     tmp_dir: Optional[Union[str, Path]] = None,
+     *,
+     dry_run: bool = False,
+     verbose: bool = True,
+ ):
+     """
+     Delete intermediate .h5ad files and a temporary directory.
+
+     Parameters
+     ----------
+     h5_dir : str | Path | iterable[str] | None
+         If a directory path is given, all files directly inside it will be considered.
+         If an iterable of file paths is given, those files will be considered.
+         Only files ending with '.h5ad' (and not ending with '.gz') are removed.
+     tmp_dir : str | Path | None
+         Path to a directory to remove recursively (e.g. a temp dir created earlier).
+     dry_run : bool
+         If True, print what *would* be removed but do not actually delete.
+     verbose : bool
+         Print progress / warnings.
+     """
+     # Helper: remove a single file path (Path-like or string)
+     def _maybe_unlink(p: Path):
+         if not p.exists():
+             if verbose:
+                 print(f"[skip] not found: {p}")
+             return
+         if not p.is_file():
+             if verbose:
+                 print(f"[skip] not a file: {p}")
+             return
+         if dry_run:
+             print(f"[dry-run] would remove file: {p}")
+             return
+         try:
+             p.unlink()
+             if verbose:
+                 print(f"Removed file: {p}")
+         except Exception as e:
+             print(f"[error] failed to remove file {p}: {e}")
+
+     # Handle h5_dir input (directory OR iterable of file paths)
+     if h5_dir is not None:
+         # If it's a path to a directory, iterate its children
+         if isinstance(h5_dir, (str, Path)) and Path(h5_dir).is_dir():
+             dpath = Path(h5_dir)
+             for p in dpath.iterdir():
+                 # only target top-level files (not recursing); require '.h5ad' suffix and exclude gz
+                 name = p.name.lower()
+                 if name.endswith(".h5ad") and not name.endswith(".gz"):
+                     _maybe_unlink(p)
+                 else:
+                     if verbose:
+                         # optional: comment this out if too noisy
+                         print(f"[skip] not matching pattern: {p.name}")
+         else:
+             # treat as iterable of file paths
+             for f in h5_dir:
+                 p = Path(f)
+                 name = p.name.lower()
+                 if name.endswith(".h5ad") and not name.endswith(".gz"):
+                     _maybe_unlink(p)
+                 else:
+                     if verbose:
+                         print(f"[skip] not matching pattern or not a file: {p}")
+
+     # Remove tmp_dir recursively (if provided)
+     if tmp_dir is not None:
+         td = Path(tmp_dir)
+         if not td.exists():
+             if verbose:
+                 print(f"[skip] tmp_dir not found: {td}")
+         else:
+             if not td.is_dir():
+                 if verbose:
+                     print(f"[skip] tmp_dir is not a directory: {td}")
+             else:
+                 if dry_run:
+                     print(f"[dry-run] would remove directory tree: {td}")
+                 else:
+                     try:
+                         shutil.rmtree(td)
+                         if verbose:
+                             print(f"Removed directory tree: {td}")
+                     except Exception as e:
+                         print(f"[error] failed to remove tmp dir {td}: {e}")
smftools/informatics/helpers/count_aligned_reads.py
@@ -0,0 +1,43 @@
+ ## count_aligned_reads
+
+ # General
+ def count_aligned_reads(bam_file):
+     """
+     Counts the number of aligned reads in a BAM file that map to each reference record.
+
+     Parameters:
+         bam_file (str): A string representing the path to an aligned BAM file.
+
+     Returns:
+         aligned_reads_count (int): The total number of reads aligned in the BAM.
+         unaligned_reads_count (int): The total number of reads not aligned in the BAM.
+         record_counts (dict): A dictionary keyed by reference record that points to a tuple containing the total reads mapped to the record and the fraction of mapped reads which map to the record.
+
+     """
+     from .. import readwrite
+     import pysam
+     from tqdm import tqdm
+     from collections import defaultdict
+
+     print('{0}: Counting aligned reads in BAM > {1}'.format(readwrite.time_string(), bam_file))
+     aligned_reads_count = 0
+     unaligned_reads_count = 0
+     # Make a dictionary, keyed by the reference_name of the reference chromosome, that points to the integer number of reads mapped to that chromosome (later extended with the proportion of mapped reads on that chromosome)
+     record_counts = defaultdict(int)
+
+     with pysam.AlignmentFile(bam_file, "rb") as bam:
+         total_reads = bam.mapped + bam.unmapped
+         # Iterate over reads to get the total mapped read counts and the reads that map to each reference
+         for read in tqdm(bam, desc='Counting aligned reads in BAM', total=total_reads):
+             if read.is_unmapped:
+                 unaligned_reads_count += 1
+             else:
+                 aligned_reads_count += 1
+                 record_counts[read.reference_name] += 1  # Automatically increments if key exists, adds if not
+
+     # Reformat the dictionary to contain read counts mapped to the reference, as well as the proportion of mapped reads in the reference
+     for reference in record_counts:
+         proportion_mapped_reads_in_record = record_counts[reference] / aligned_reads_count
+         record_counts[reference] = (record_counts[reference], proportion_mapped_reads_in_record)
+
+     return aligned_reads_count, unaligned_reads_count, dict(record_counts)
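
A minimal sketch of how the return values might be consumed, mirroring the mapping-threshold filter in converted_BAM_to_adata_II above; the BAM path and cutoff are hypothetical.

    from smftools.informatics.helpers.count_aligned_reads import count_aligned_reads

    aligned, unaligned, record_counts = count_aligned_reads("results/split_bams/barcode01.bam")  # hypothetical path
    print(f"{aligned * 100 / (aligned + unaligned):.2f}% of reads aligned")

    # record_counts maps reference name -> (read count, fraction of aligned reads)
    mapping_threshold = 0.05  # hypothetical cutoff
    records_to_analyze = {ref for ref, (count, frac) in record_counts.items() if frac >= mapping_threshold}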
smftools/informatics/helpers/demux_and_index_BAM.py
@@ -0,0 +1,52 @@
+ ## demux_and_index_BAM
+
+ def demux_and_index_BAM(aligned_sorted_BAM, split_dir, bam_suffix, barcode_kit, barcode_both_ends, trim, fasta, make_bigwigs, threads):
+     """
+     A wrapper function for splitting BAMs and indexing them.
+     Splits an input BAM file on barcode value and makes a BAM index file.
+
+     Parameters:
+         aligned_sorted_BAM (str): A string representing the file path of the aligned, sorted BAM file (without the BAM suffix).
+         split_dir (str): A string representing the file path to the directory to split the BAMs into.
+         bam_suffix (str): A suffix to add to the BAM file.
+         barcode_kit (str): Name of the barcoding kit.
+         barcode_both_ends (bool): Whether to require both ends to be barcoded.
+         trim (bool): Whether to trim off barcodes after demultiplexing.
+         fasta (str): File path to the reference genome to align to.
+         make_bigwigs (bool): Whether to make bigwigs.
+         threads (int): Number of threads to use.
+
+     Returns:
+         bam_files (list): List of split BAM file path strings.
+     """
+     from .. import readwrite
+     import os
+     import subprocess
+     import glob
+     from .make_dirs import make_dirs
+
+     input_bam = aligned_sorted_BAM + bam_suffix
+     command = ["dorado", "demux", "--kit-name", barcode_kit]
+     if barcode_both_ends:
+         command.append("--barcode-both-ends")
+     if not trim:
+         command.append("--no-trim")
+     if threads:
+         command += ["-t", str(threads)]
+     command += ["--emit-summary", "--sort-bam", "--output-dir", split_dir]
+     command.append(input_bam)
+     command_string = ' '.join(command)
+     print(f"Running: {command_string}")
+     subprocess.run(command)
+
+     # Make a BAM index file for the BAMs in that directory
+     bam_pattern = '*' + bam_suffix
+     bam_files = glob.glob(os.path.join(split_dir, bam_pattern))
+     bam_files = [bam for bam in bam_files if '.bai' not in bam and 'unclassified' not in bam]
+     bam_files.sort()
+
+     if not bam_files:
+         raise FileNotFoundError(f"No BAM files found in {split_dir} with suffix {bam_suffix}")
+
+     return bam_files
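
A minimal usage sketch for the wrapper above. It assumes `dorado` is available on the PATH (the wrapper shells out to `dorado demux`); the file paths and kit name below are hypothetical.

    from smftools.informatics.helpers.demux_and_index_BAM import demux_and_index_BAM

    split_bams = demux_and_index_BAM(
        aligned_sorted_BAM="results/aligned_sorted",  # ".bam" suffix is appended inside the wrapper
        split_dir="results/split_bams",
        bam_suffix=".bam",
        barcode_kit="SQK-NBD114-24",                  # hypothetical barcoding kit name
        barcode_both_ends=False,
        trim=True,
        fasta="refs/reference.fasta",                 # accepted but not used by this wrapper
        make_bigwigs=False,
        threads=8,
    )
    print(split_bams)  # sorted list of demultiplexed BAM paths (unclassified reads excluded)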