smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. smftools/__init__.py +9 -4
  2. smftools/_version.py +1 -1
  3. smftools/cli.py +184 -0
  4. smftools/config/__init__.py +1 -0
  5. smftools/config/conversion.yaml +33 -0
  6. smftools/config/deaminase.yaml +56 -0
  7. smftools/config/default.yaml +253 -0
  8. smftools/config/direct.yaml +17 -0
  9. smftools/config/experiment_config.py +1191 -0
  10. smftools/hmm/HMM.py +1576 -0
  11. smftools/hmm/__init__.py +20 -0
  12. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  13. smftools/hmm/call_hmm_peaks.py +106 -0
  14. smftools/{tools → hmm}/display_hmm.py +3 -3
  15. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  16. smftools/{tools → hmm}/train_hmm.py +1 -1
  17. smftools/informatics/__init__.py +0 -2
  18. smftools/informatics/archived/deaminase_smf.py +132 -0
  19. smftools/informatics/fast5_to_pod5.py +4 -1
  20. smftools/informatics/helpers/__init__.py +3 -4
  21. smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
  22. smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
  23. smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
  24. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
  25. smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
  26. smftools/informatics/helpers/discover_input_files.py +100 -0
  27. smftools/informatics/helpers/extract_base_identities.py +29 -3
  28. smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
  29. smftools/informatics/helpers/find_conversion_sites.py +5 -4
  30. smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
  31. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  32. smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
  33. smftools/informatics/helpers/split_and_index_BAM.py +1 -5
  34. smftools/load_adata.py +1346 -0
  35. smftools/machine_learning/__init__.py +12 -0
  36. smftools/machine_learning/data/__init__.py +2 -0
  37. smftools/machine_learning/data/anndata_data_module.py +234 -0
  38. smftools/machine_learning/evaluation/__init__.py +2 -0
  39. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  40. smftools/machine_learning/evaluation/evaluators.py +223 -0
  41. smftools/machine_learning/inference/__init__.py +3 -0
  42. smftools/machine_learning/inference/inference_utils.py +27 -0
  43. smftools/machine_learning/inference/lightning_inference.py +68 -0
  44. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  45. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  46. smftools/machine_learning/models/base.py +295 -0
  47. smftools/machine_learning/models/cnn.py +138 -0
  48. smftools/machine_learning/models/lightning_base.py +345 -0
  49. smftools/machine_learning/models/mlp.py +26 -0
  50. smftools/{tools → machine_learning}/models/positional.py +3 -2
  51. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  52. smftools/machine_learning/models/sklearn_models.py +273 -0
  53. smftools/machine_learning/models/transformer.py +303 -0
  54. smftools/machine_learning/training/__init__.py +2 -0
  55. smftools/machine_learning/training/train_lightning_model.py +135 -0
  56. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  57. smftools/plotting/__init__.py +4 -1
  58. smftools/plotting/autocorrelation_plotting.py +611 -0
  59. smftools/plotting/general_plotting.py +566 -89
  60. smftools/plotting/hmm_plotting.py +260 -0
  61. smftools/plotting/qc_plotting.py +270 -0
  62. smftools/preprocessing/__init__.py +13 -8
  63. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  64. smftools/preprocessing/append_base_context.py +122 -0
  65. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  66. smftools/preprocessing/calculate_complexity_II.py +248 -0
  67. smftools/preprocessing/calculate_coverage.py +10 -1
  68. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  69. smftools/preprocessing/clean_NaN.py +17 -1
  70. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  71. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  72. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  73. smftools/preprocessing/invert_adata.py +12 -5
  74. smftools/preprocessing/load_sample_sheet.py +19 -4
  75. smftools/readwrite.py +849 -43
  76. smftools/tools/__init__.py +3 -32
  77. smftools/tools/calculate_umap.py +5 -5
  78. smftools/tools/general_tools.py +3 -3
  79. smftools/tools/position_stats.py +468 -106
  80. smftools/tools/read_stats.py +115 -1
  81. smftools/tools/spatial_autocorrelation.py +562 -0
  82. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
  83. smftools-0.2.1.dist-info/RECORD +161 -0
  84. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  85. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  86. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  87. smftools/informatics/load_adata.py +0 -182
  88. smftools/preprocessing/append_C_context.py +0 -82
  89. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  90. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  91. smftools/preprocessing/filter_reads_on_length.py +0 -51
  92. smftools/tools/call_hmm_peaks.py +0 -105
  93. smftools/tools/data/__init__.py +0 -2
  94. smftools/tools/data/anndata_data_module.py +0 -90
  95. smftools/tools/evaluation/__init__.py +0 -0
  96. smftools/tools/inference/__init__.py +0 -1
  97. smftools/tools/inference/lightning_inference.py +0 -41
  98. smftools/tools/models/base.py +0 -14
  99. smftools/tools/models/cnn.py +0 -34
  100. smftools/tools/models/lightning_base.py +0 -41
  101. smftools/tools/models/mlp.py +0 -17
  102. smftools/tools/models/sklearn_models.py +0 -40
  103. smftools/tools/models/transformer.py +0 -133
  104. smftools/tools/training/__init__.py +0 -1
  105. smftools/tools/training/train_lightning_model.py +0 -47
  106. smftools-0.1.7.dist-info/RECORD +0 -136
  107. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  108. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  109. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  110. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  111. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  112. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  113. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  114. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  115. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  116. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  117. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  118. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  119. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  120. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/helpers/converted_BAM_to_adata_II.py
@@ -11,7 +11,11 @@ import traceback
  import gzip
  import torch
 
- from .. import readwrite
+ import shutil
+ from pathlib import Path
+ from typing import Union, Iterable, Optional
+
+ from ... import readwrite
  from .binarize_converted_base_identities import binarize_converted_base_identities
  from .find_conversion_sites import find_conversion_sites
  from .count_aligned_reads import count_aligned_reads
@@ -22,7 +26,17 @@ from .ohe_batching import ohe_batching
  if __name__ == "__main__":
  multiprocessing.set_start_method("forkserver", force=True)
 
- def converted_BAM_to_adata_II(converted_FASTA, split_dir, mapping_threshold, experiment_name, conversion_types, bam_suffix, device='cpu', num_threads=8):
+ def converted_BAM_to_adata_II(converted_FASTA,
+ split_dir,
+ mapping_threshold,
+ experiment_name,
+ conversions,
+ bam_suffix,
+ device='cpu',
+ num_threads=8,
+ deaminase_footprinting=False,
+ delete_intermediates=True
+ ):
  """
  Converts BAM files into an AnnData object by binarizing modified base identities.
 
@@ -31,9 +45,10 @@ def converted_BAM_to_adata_II(converted_FASTA, split_dir, mapping_threshold, exp
  split_dir (str): Directory containing converted BAM files.
  mapping_threshold (float): Minimum fraction of aligned reads required for inclusion.
  experiment_name (str): Name for the output AnnData object.
- conversion_types (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+ conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
  bam_suffix (str): File suffix for BAM files.
  num_threads (int): Number of parallel processing threads.
+ deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
 
  Returns:
  str: Path to the final AnnData object.
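For orientation, a caller sketch under the new signature shown above; judging from the return statements changed further down in this file, the function now returns a (AnnData, path) tuple, with the AnnData slot left as None when an existing .h5ad.gz is reused. All argument values here are made-up examples, not package defaults.

    final_adata, final_adata_path = converted_BAM_to_adata_II(
        converted_FASTA="refs/converted.fasta",   # hypothetical path
        split_dir="results/split_bams",           # hypothetical path
        mapping_threshold=0.8,
        experiment_name="exp1",
        conversions=["unconverted", "5mC"],
        bam_suffix=".bam",
        device="cpu",
        num_threads=8,
        deaminase_footprinting=False,
        delete_intermediates=True,
    )
    if final_adata is None:
        print("Reusing existing AnnData at", final_adata_path)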
@@ -48,14 +63,15 @@ def converted_BAM_to_adata_II(converted_FASTA, split_dir, mapping_threshold, exp
  print(f"Using device: {device}")
 
  ## Set Up Directories and File Paths
- parent_dir = os.path.dirname(split_dir)
- h5_dir = os.path.join(parent_dir, 'h5ads')
- tmp_dir = os.path.join(parent_dir, 'tmp')
+ #parent_dir = os.path.dirname(split_dir)
+ h5_dir = os.path.join(split_dir, 'h5ads')
+ tmp_dir = os.path.join(split_dir, 'tmp')
+ final_adata = None
  final_adata_path = os.path.join(h5_dir, f'{experiment_name}_{os.path.basename(split_dir)}.h5ad.gz')
 
  if os.path.exists(final_adata_path):
  print(f"{final_adata_path} already exists. Using existing AnnData object.")
- return final_adata_path
+ return final_adata, final_adata_path
 
  make_dirs([h5_dir, tmp_dir])
 
@@ -66,32 +82,46 @@ def converted_BAM_to_adata_II(converted_FASTA, split_dir, mapping_threshold, exp
  print(f"Found {len(bam_files)} BAM files: {bam_files}")
 
  ## Process Conversion Sites
- max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(converted_FASTA, conversion_types)
+ max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(converted_FASTA, conversions, deaminase_footprinting)
 
  ## Filter BAM Files by Mapping Threshold
  records_to_analyze = filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold)
 
  ## Process BAMs in Parallel
- final_adata = process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device)
+ final_adata = process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device, deaminase_footprinting)
 
  for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
  final_adata.var[f'{chromosome}_top_strand_FASTA_base'] = list(seq)
  final_adata.var[f'{chromosome}_bottom_strand_FASTA_base'] = list(comp)
  final_adata.uns[f'{chromosome}_FASTA_sequence'] = seq
 
+ final_adata.obs_names_make_unique()
+ cols = final_adata.obs.columns
+
+ # Make obs cols categorical
+ for col in cols:
+ final_adata.obs[col] = final_adata.obs[col].astype('category')
+
  ## Save Final AnnData
- # print(f"Saving AnnData to {final_adata_path}")
- # final_adata.write_h5ad(final_adata_path, compression='gzip')
+ print(f"Saving AnnData to {final_adata_path}")
+ backup_dir=os.path.join(os.path.dirname(final_adata_path), 'adata_accessory_data')
+ readwrite.safe_write_h5ad(final_adata, final_adata_path, compression='gzip', backup=True, backup_dir=backup_dir)
+
+ ## Delete intermediate h5ad files and temp directories
+ if delete_intermediates:
+ delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)
+
  return final_adata, final_adata_path
 
 
- def process_conversion_sites(converted_FASTA, conversion_types):
+ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC'], deaminase_footprinting=False):
  """
  Extracts conversion sites and determines the max reference length.
 
  Parameters:
  converted_FASTA (str): Path to the converted reference FASTA.
- conversion_types (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+ conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
+ deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
 
  Returns:
  max_reference_length (int): The length of the longest sequence.
@@ -101,11 +131,11 @@ def process_conversion_sites(converted_FASTA, conversion_types):
  record_FASTA_dict = {}
  chromosome_FASTA_dict = {}
  max_reference_length = 0
- unconverted = conversion_types[0]
- conversions = conversion_types[1:]
+ unconverted = conversions[0]
+ conversion_types = conversions[1:]
 
  # Process the unconverted sequence once
- modification_dict[unconverted] = find_conversion_sites(converted_FASTA, unconverted, conversion_types)
+ modification_dict[unconverted] = find_conversion_sites(converted_FASTA, unconverted, conversions, deaminase_footprinting)
  # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
 
  # Get **max sequence length** from unconverted records
@@ -114,7 +144,11 @@ def process_conversion_sites(converted_FASTA, conversion_types):
  # Add **unconverted records** to `record_FASTA_dict`
  for record, values in modification_dict[unconverted].items():
  sequence_length, top_coords, bottom_coords, sequence, complement = values
- chromosome = record.replace(f"_{unconverted}_top", "")
+
+ if not deaminase_footprinting:
+ chromosome = record.replace(f"_{unconverted}_top", "")
+ else:
+ chromosome = record
 
  # Store **original sequence**
  record_FASTA_dict[record] = [
@@ -127,13 +161,17 @@ def process_conversion_sites(converted_FASTA, conversion_types):
  chromosome_FASTA_dict[chromosome] = [sequence + "N" * (max_reference_length - sequence_length), complement + "N" * (max_reference_length - sequence_length)]
 
  # Process converted records
- for conversion in conversions:
- modification_dict[conversion] = find_conversion_sites(converted_FASTA, conversion, conversion_types)
+ for conversion in conversion_types:
+ modification_dict[conversion] = find_conversion_sites(converted_FASTA, conversion, conversions, deaminase_footprinting)
  # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
 
  for record, values in modification_dict[conversion].items():
  sequence_length, top_coords, bottom_coords, sequence, complement = values
- chromosome = record.split(f"_{unconverted}_")[0] # Extract chromosome name
+
+ if not deaminase_footprinting:
+ chromosome = record.split(f"_{unconverted}_")[0] # Extract chromosome name
+ else:
+ chromosome = record
 
  # Add **both strands** for converted records
  for strand in ["top", "bottom"]:
@@ -168,7 +206,7 @@ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold
  return records_to_analyze
 
 
- def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, tmp_dir, max_reference_length, device):
+ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, max_reference_length, device, deaminase_footprinting):
  """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
  adata_list = []
 
@@ -177,9 +215,11 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, tm
  chromosome = record_FASTA_dict[record][2]
  current_length = record_FASTA_dict[record][4]
  mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
+ sequence = chromosome_FASTA_dict[chromosome][0]
 
  # Extract Base Identities
- fwd_bases, rev_bases = extract_base_identities(bam, record, range(current_length), max_reference_length)
+ fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, range(current_length), max_reference_length, sequence)
+ mismatch_trend_series = pd.Series(mismatch_trend_per_read)
 
  # Skip processing if both forward and reverse base identities are empty
  if not fwd_bases and not rev_bases:
@@ -190,11 +230,11 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, tm
 
  # Binarize the Base Identities if they exist
  if fwd_bases:
- fwd_bin = binarize_converted_base_identities(fwd_bases, strand, mod_type, bam, device)
+ fwd_bin = binarize_converted_base_identities(fwd_bases, strand, mod_type, bam, device,deaminase_footprinting, mismatch_trend_per_read)
  merged_bin.update(fwd_bin)
 
  if rev_bases:
- rev_bin = binarize_converted_base_identities(rev_bases, strand, mod_type, bam, device)
+ rev_bin = binarize_converted_base_identities(rev_bases, strand, mod_type, bam, device, deaminase_footprinting, mismatch_trend_per_read)
  merged_bin.update(rev_bin)
 
  # Skip if merged_bin is empty (no valid binarized data)
@@ -257,11 +297,18 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, tm
  adata.obs_names = bin_df.index.astype(str)
  adata.var_names = bin_df.columns.astype(str)
  adata.obs["Sample"] = [sample] * len(adata)
+ try:
+ barcode = sample.split('barcode')[1]
+ except:
+ barcode = np.nan
+ adata.obs["Barcode"] = [int(barcode)] * len(adata)
+ adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
  adata.obs["Reference"] = [chromosome] * len(adata)
  adata.obs["Strand"] = [strand] * len(adata)
  adata.obs["Dataset"] = [mod_type] * len(adata)
  adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
  adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
+ adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)
 
  # Attach One-Hot Encodings to Layers
  adata.layers["A_binary_encoding"] = df_A
@@ -279,7 +326,7 @@ def timestamp():
  return time.strftime("[%Y-%m-%d %H:%M:%S]")
 
 
- def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, progress_queue):
+ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, deaminase_footprinting, progress_queue):
  """Worker function that processes a single BAM and writes the output to an H5AD file."""
  worker_id = current_process().pid # Get worker process ID
  sample = os.path.basename(bam).split(sep=".bam")[0]
@@ -302,7 +349,7 @@ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict
  return
 
  # Process BAM
- adata = process_single_bam(bam_index, bam, bam_records_to_analyze, shared_record_FASTA_dict, tmp_dir, max_reference_length, device)
+ adata = process_single_bam(bam_index, bam, bam_records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, max_reference_length, device, deaminase_footprinting)
 
  if adata is not None:
  adata.write_h5ad(h5ad_path)
@@ -318,7 +365,7 @@ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict
  print(f"{timestamp()} [Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}")
  progress_queue.put(sample) # Still signal completion to prevent deadlock
 
- def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device):
+ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device, deaminase_footprinting):
  """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
  os.makedirs(h5_dir, exist_ok=True) # Ensure h5_dir exists
 
@@ -337,7 +384,7 @@ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict,
 
  with Pool(processes=num_threads) as pool:
  results = [
- pool.apply_async(worker_function, (i, bam, records_to_analyze, shared_record_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, progress_queue))
+ pool.apply_async(worker_function, (i, bam, records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, deaminase_footprinting, progress_queue))
  for i, bam in enumerate(bam_path_list)
  ]
 
@@ -366,4 +413,93 @@ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict,
  final_adata = ad.concat([ad.read_h5ad(f) for f in h5ad_files], join="outer")
 
  print(f"{timestamp()} Successfully generated final AnnData object.")
- return final_adata
+ return final_adata
+
+ def delete_intermediate_h5ads_and_tmpdir(
+ h5_dir: Union[str, Path, Iterable[str], None],
+ tmp_dir: Optional[Union[str, Path]] = None,
+ *,
+ dry_run: bool = False,
+ verbose: bool = True,
+ ):
+ """
+ Delete intermediate .h5ad files and a temporary directory.
+
+ Parameters
+ ----------
+ h5_dir : str | Path | iterable[str] | None
+ If a directory path is given, all files directly inside it will be considered.
+ If an iterable of file paths is given, those files will be considered.
+ Only files ending with '.h5ad' (and not ending with '.gz') are removed.
+ tmp_dir : str | Path | None
+ Path to a directory to remove recursively (e.g. a temp dir created earlier).
+ dry_run : bool
+ If True, print what *would* be removed but do not actually delete.
+ verbose : bool
+ Print progress / warnings.
+ """
+ # Helper: remove a single file path (Path-like or string)
+ def _maybe_unlink(p: Path):
+ if not p.exists():
+ if verbose:
+ print(f"[skip] not found: {p}")
+ return
+ if not p.is_file():
+ if verbose:
+ print(f"[skip] not a file: {p}")
+ return
+ if dry_run:
+ print(f"[dry-run] would remove file: {p}")
+ return
+ try:
+ p.unlink()
+ if verbose:
+ print(f"Removed file: {p}")
+ except Exception as e:
+ print(f"[error] failed to remove file {p}: {e}")
+
+ # Handle h5_dir input (directory OR iterable of file paths)
+ if h5_dir is not None:
+ # If it's a path to a directory, iterate its children
+ if isinstance(h5_dir, (str, Path)) and Path(h5_dir).is_dir():
+ dpath = Path(h5_dir)
+ for p in dpath.iterdir():
+ # only target top-level files (not recursing); require '.h5ad' suffix and exclude gz
+ name = p.name.lower()
+ if name.endswith(".h5ad") and not name.endswith(".gz"):
+ _maybe_unlink(p)
+ else:
+ if verbose:
+ # optional: comment this out if too noisy
+ print(f"[skip] not matching pattern: {p.name}")
+ else:
+ # treat as iterable of file paths
+ for f in h5_dir:
+ p = Path(f)
+ name = p.name.lower()
+ if name.endswith(".h5ad") and not name.endswith(".gz"):
+ _maybe_unlink(p)
+ else:
+ if verbose:
+ print(f"[skip] not matching pattern or not a file: {p}")
+
+ # Remove tmp_dir recursively (if provided)
+ if tmp_dir is not None:
+ td = Path(tmp_dir)
+ if not td.exists():
+ if verbose:
+ print(f"[skip] tmp_dir not found: {td}")
+ else:
+ if not td.is_dir():
+ if verbose:
+ print(f"[skip] tmp_dir is not a directory: {td}")
+ else:
+ if dry_run:
+ print(f"[dry-run] would remove directory tree: {td}")
+ else:
+ try:
+ shutil.rmtree(td)
+ if verbose:
+ print(f"Removed directory tree: {td}")
+ except Exception as e:
+ print(f"[error] failed to remove tmp dir {td}: {e}")
smftools/informatics/helpers/discover_input_files.py
@@ -0,0 +1,100 @@
+ from pathlib import Path
+ from typing import Dict, List, Any, Tuple
+
+ def discover_input_files(
+ input_data_path: str,
+ bam_suffix: str = ".bam",
+ recursive: bool = False,
+ follow_symlinks: bool = False,
+ ) -> Dict[str, Any]:
+ """
+ Discover input files under `input_data_path`.
+
+ Returns a dict with:
+ - pod5_paths, fast5_paths, fastq_paths, bam_paths (lists of str)
+ - input_is_pod5, input_is_fast5, input_is_fastq, input_is_bam (bools)
+ - all_files_searched (int)
+ Behavior:
+ - If `input_data_path` is a file, returns that single file categorized.
+ - If it is a directory, scans either immediate children (recursive=False)
+ or entire tree (recursive=True). Uses Path.suffixes to detect .fastq.gz etc.
+ """
+ p = Path(input_data_path)
+ pod5_exts = {".pod5", ".p5"}
+ fast5_exts = {".fast5", ".f5"}
+ fastq_exts = {".fastq", ".fq", ".fastq.gz", ".fq.gz", ".fastq.xz", ".fq.xz"}
+ # normalize bam suffix with leading dot
+ if not bam_suffix.startswith("."):
+ bam_suffix = "." + bam_suffix
+ bam_suffix = bam_suffix.lower()
+
+ pod5_paths: List[str] = []
+ fast5_paths: List[str] = []
+ fastq_paths: List[str] = []
+ bam_paths: List[str] = []
+ other_paths: List[str] = []
+
+ def _file_ext_key(pp: Path) -> str:
+ # join suffixes to handle .fastq.gz
+ return "".join(pp.suffixes).lower() if pp.suffixes else pp.suffix.lower()
+
+ if p.exists() and p.is_file():
+ ext_key = _file_ext_key(p)
+ if ext_key in pod5_exts:
+ pod5_paths.append(str(p))
+ elif ext_key in fast5_exts:
+ fast5_paths.append(str(p))
+ elif ext_key in fastq_exts:
+ fastq_paths.append(str(p))
+ elif ext_key == bam_suffix:
+ bam_paths.append(str(p))
+ else:
+ other_paths.append(str(p))
+ total_searched = 1
+ elif p.exists() and p.is_dir():
+ if recursive:
+ iterator = p.rglob("*")
+ else:
+ iterator = p.iterdir()
+ total_searched = 0
+ for fp in iterator:
+ if not fp.is_file():
+ continue
+ total_searched += 1
+ ext_key = _file_ext_key(fp)
+ if ext_key in pod5_exts:
+ pod5_paths.append(str(fp))
+ elif ext_key in fast5_exts:
+ fast5_paths.append(str(fp))
+ elif ext_key in fastq_exts:
+ fastq_paths.append(str(fp))
+ elif ext_key == bam_suffix:
+ bam_paths.append(str(fp))
+ else:
+ # additional heuristic: check filename contains extension fragments (.pod5 etc)
+ name = fp.name.lower()
+ if any(e in name for e in pod5_exts):
+ pod5_paths.append(str(fp))
+ elif any(e in name for e in fast5_exts):
+ fast5_paths.append(str(fp))
+ elif any(e in name for e in [".fastq", ".fq"]):
+ fastq_paths.append(str(fp))
+ elif name.endswith(bam_suffix):
+ bam_paths.append(str(fp))
+ else:
+ other_paths.append(str(fp))
+ else:
+ raise FileNotFoundError(f"input_data_path does not exist: {input_data_path}")
+
+ return {
+ "pod5_paths": sorted(pod5_paths),
+ "fast5_paths": sorted(fast5_paths),
+ "fastq_paths": sorted(fastq_paths),
+ "bam_paths": sorted(bam_paths),
+ "other_paths": sorted(other_paths),
+ "input_is_pod5": len(pod5_paths) > 0,
+ "input_is_fast5": len(fast5_paths) > 0,
+ "input_is_fastq": len(fastq_paths) > 0,
+ "input_is_bam": len(bam_paths) > 0,
+ "all_files_searched": total_searched,
+ }
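A usage sketch of the new discovery helper against a hypothetical run directory; the dictionary keys are the ones documented and returned above.

    found = discover_input_files("runs/run01", bam_suffix=".bam", recursive=True)
    if found["input_is_pod5"]:
        print(f"{len(found['pod5_paths'])} POD5 files found")
    elif found["input_is_bam"]:
        print(f"{len(found['bam_paths'])} BAM files found")
    print("files scanned:", found["all_files_searched"])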
smftools/informatics/helpers/extract_base_identities.py
@@ -1,4 +1,4 @@
- def extract_base_identities(bam_file, chromosome, positions, max_reference_length):
+ def extract_base_identities(bam_file, chromosome, positions, max_reference_length, sequence):
  """
  Efficiently extracts base identities from mapped reads with reference coordinates.
 
@@ -7,6 +7,7 @@ def extract_base_identities(bam_file, chromosome, positions, max_reference_lengt
  chromosome (str): Name of the reference chromosome.
  positions (list): Positions to extract (0-based).
  max_reference_length (int): Maximum reference length for padding.
+ sequence (str): The sequence of the record fasta
 
  Returns:
  dict: Base identities from forward mapped reads.
@@ -16,16 +17,19 @@ def extract_base_identities(bam_file, chromosome, positions, max_reference_lengt
  import numpy as np
  from collections import defaultdict
  import time
+ from collections import defaultdict, Counter
 
  timestamp = time.strftime("[%Y-%m-%d %H:%M:%S]")
 
  positions = set(positions)
  fwd_base_identities = defaultdict(lambda: np.full(max_reference_length, 'N', dtype='<U1'))
  rev_base_identities = defaultdict(lambda: np.full(max_reference_length, 'N', dtype='<U1'))
+ mismatch_counts_per_read = defaultdict(lambda: defaultdict(Counter))
 
  #print(f"{timestamp} Reading reads from {chromosome} BAM file: {bam_file}")
  with pysam.AlignmentFile(bam_file, "rb") as bam:
  total_reads = bam.mapped
+ ref_seq = sequence.upper()
  for read in bam.fetch(chromosome):
  if not read.is_mapped:
  continue # Skip unmapped reads
@@ -39,6 +43,28 @@ def extract_base_identities(bam_file, chromosome, positions, max_reference_lengt
 
  for read_position, reference_position in aligned_pairs:
  if reference_position in positions:
- base_dict[read_name][reference_position] = query_sequence[read_position]
+ read_base = query_sequence[read_position]
+ ref_base = ref_seq[reference_position]
 
- return dict(fwd_base_identities), dict(rev_base_identities)
+ base_dict[read_name][reference_position] = read_base
+
+ # Track mismatches (excluding Ns)
+ if read_base != ref_base and read_base != 'N' and ref_base != 'N':
+ mismatch_counts_per_read[read_name][ref_base][read_base] += 1
+
+ # Determine C→T vs G→A dominance per read
+ mismatch_trend_per_read = {}
+ for read_name, ref_dict in mismatch_counts_per_read.items():
+ c_to_t = ref_dict.get("C", {}).get("T", 0)
+ g_to_a = ref_dict.get("G", {}).get("A", 0)
+
+ if abs(c_to_t - g_to_a) < 0.01 and c_to_t > 0:
+ mismatch_trend_per_read[read_name] = "equal"
+ elif c_to_t > g_to_a:
+ mismatch_trend_per_read[read_name] = "C->T"
+ elif g_to_a > c_to_t:
+ mismatch_trend_per_read[read_name] = "G->A"
+ else:
+ mismatch_trend_per_read[read_name] = "none"
+
+ return dict(fwd_base_identities), dict(rev_base_identities), dict(mismatch_counts_per_read), mismatch_trend_per_read
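A small worked example of the per-read trend call added above, with illustrative counts rather than package data; the ordering of the checks mirrors the new code.

    from collections import Counter, defaultdict

    # Hypothetical mismatch tallies for one read: 12 C->T events, 2 G->A events.
    ref_dict = defaultdict(Counter, {"C": Counter({"T": 12}), "G": Counter({"A": 2})})
    c_to_t = ref_dict.get("C", {}).get("T", 0)   # 12
    g_to_a = ref_dict.get("G", {}).get("A", 0)   # 2

    if abs(c_to_t - g_to_a) < 0.01 and c_to_t > 0:
        trend = "equal"   # counts are integers, so this branch means exactly equal and nonzero
    elif c_to_t > g_to_a:
        trend = "C->T"
    elif g_to_a > c_to_t:
        trend = "G->A"
    else:
        trend = "none"
    print(trend)          # prints: C->T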
smftools/informatics/helpers/extract_read_features_from_bam.py
@@ -2,7 +2,7 @@
 
  def extract_read_features_from_bam(bam_file_path):
  """
- Make a dict of reads from a bam that points to a list of read metrics: read length, read median Q-score, reference length.
+ Make a dict of reads from a bam that points to a list of read metrics: read length, read median Q-score, reference length, mapped length, mapping quality
  Params:
  bam_file_path (str):
  Returns:
@@ -26,6 +26,8 @@ def extract_read_features_from_bam(bam_file_path):
  reference_name = read.reference_name
  reference_index = bam_file.references.index(reference_name)
  reference_length = reference_lengths[reference_index]
- read_metrics[read.query_name] = [read.query_length, median_read_quality, reference_length]
+ mapped_length = sum(end - start for start, end in read.get_blocks())
+ mapping_quality = read.mapping_quality # Phred-scaled MAPQ
+ read_metrics[read.query_name] = [read.query_length, median_read_quality, reference_length, mapped_length, mapping_quality]
 
  return read_metrics
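To illustrate the new mapped_length metric: pysam's read.get_blocks() yields (start, end) reference-aligned blocks, so summing their spans counts only reference bases actually covered by the alignment (insertions and soft-clipped bases add nothing). A toy example with made-up coordinates:

    # Two aligned blocks separated by a gap in the alignment.
    blocks = [(100, 150), (160, 200)]
    mapped_length = sum(end - start for start, end in blocks)
    print(mapped_length)  # 50 + 40 = 90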
smftools/informatics/helpers/find_conversion_sites.py
@@ -1,11 +1,12 @@
- def find_conversion_sites(fasta_file, modification_type, conversion_types):
+ def find_conversion_sites(fasta_file, modification_type, conversions, deaminase_footprinting=False):
  """
  Finds genomic coordinates of modified bases (5mC or 6mA) in a reference FASTA file.
 
  Parameters:
  fasta_file (str): Path to the converted reference FASTA.
  modification_type (str): Modification type ('5mC' or '6mA') or 'unconverted'.
- conversion_types (list): List of conversion types. The first element is the unconverted record type.
+ conversions (list): List of conversion types. The first element is the unconverted record type.
+ deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
 
  Returns:
  dict: Dictionary where keys are **both unconverted & converted record names**.
@@ -14,7 +15,7 @@ def find_conversion_sites(fasta_file, modification_type, conversion_types):
  """
  import numpy as np
  from Bio import SeqIO
- unconverted = conversion_types[0]
+ unconverted = conversions[0]
  record_dict = {}
 
  # Define base mapping based on modification type
@@ -26,7 +27,7 @@ def find_conversion_sites(fasta_file, modification_type, conversion_types):
  # Read FASTA file and process records
  with open(fasta_file, "r") as f:
  for record in SeqIO.parse(f, "fasta"):
- if unconverted in record.id:
+ if unconverted in record.id or deaminase_footprinting:
  sequence = str(record.seq).upper()
  complement = str(record.seq.complement()).upper()
  sequence_length = len(sequence)
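A call sketch for the updated signature, using a hypothetical FASTA path; per the comments in converted_BAM_to_adata_II.py above, each returned value is [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement]. With deaminase_footprinting=True every record is processed, otherwise only records whose ID contains the unconverted tag.

    sites = find_conversion_sites(
        "refs/converted.fasta",            # hypothetical path
        "unconverted",
        ["unconverted", "5mC"],
        deaminase_footprinting=False,
    )
    for record_id, (length, top, bottom, seq, comp) in sites.items():
        print(record_id, length, len(top), len(bottom))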
smftools/informatics/helpers/modkit_extract_to_adata.py
@@ -386,14 +386,15 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  existing_h5s = [h5 for h5 in existing_h5s if '.h5ad.gz' in h5]
  final_hdf = f'{experiment_name}_final_experiment_hdf5.h5ad'
  final_adata_path = os.path.join(h5_dir, final_hdf)
+ final_adata = None
 
  if os.path.exists(f"{final_adata_path}.gz"):
  print(f'{final_adata_path}.gz already exists. Using existing adata')
- return f"{final_adata_path}.gz"
+ return final_adata, f"{final_adata_path}.gz"
 
  elif os.path.exists(f"{final_adata_path}"):
  print(f'{final_adata_path} already exists. Using existing adata')
- return final_adata_path
+ return final_adata, final_adata_path
 
  # Filter file names that contain the search string in their filename and keep them in a list
  tsvs = [tsv for tsv in tsv_files if 'extract.tsv' in tsv and 'unclassified' not in tsv]
@@ -444,8 +445,9 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  for record in records_to_analyze:
  current_reference_length = reference_dict[record][0]
  positions = range(current_reference_length)
+ ref_seq = reference_dict[record][1]
  # Extract the base identities of reads aligned to the record
- fwd_base_identities, rev_base_identities = extract_base_identities(bam, record, positions, max_reference_length)
+ fwd_base_identities, rev_base_identities, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
  # Store read names of fwd and rev mapped reads
  fwd_mapped_reads.update(fwd_base_identities.keys())
  rev_mapped_reads.update(rev_base_identities.keys())
@@ -708,6 +710,7 @@ def modkit_extract_to_adata(fasta, bam_dir, mapping_threshold, experiment_name,
  temp_adata.var_names = temp_adata.var_names.astype(str)
  print('{0}: Adding {1} anndata for sample {2}'.format(readwrite.time_string(), sample_types[dict_index], final_sample_index))
  temp_adata.obs['Sample'] = [str(final_sample_index)] * len(temp_adata)
+ temp_adata.obs['Barcode'] = [str(final_sample_index)] * len(temp_adata)
  temp_adata.obs['Reference'] = [f'{record}'] * len(temp_adata)
  temp_adata.obs['Strand'] = [strand] * len(temp_adata)
  temp_adata.obs['Dataset'] = [dataset] * len(temp_adata)