smftools-0.2.3-py3-none-any.whl → smftools-0.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,61 +1,65 @@
- import numpy as np
- import time
- import os
  import gc
- import pandas as pd
- import anndata as ad
- from tqdm import tqdm
  import multiprocessing
- from multiprocessing import Manager, Lock, current_process, Pool
+ import shutil
+ import time
  import traceback
- import gzip
+ from multiprocessing import Manager, Pool, current_process
+ from pathlib import Path
+ from typing import Iterable, Optional, Union
+
+ import anndata as ad
+ import numpy as np
+ import pandas as pd
  import torch

- import shutil
- from pathlib import Path
- from typing import Union, Iterable, Optional
+ from smftools.logging_utils import get_logger

- from ..readwrite import make_dirs, safe_write_h5ad
+ from ..readwrite import make_dirs
+ from .bam_functions import count_aligned_reads, extract_base_identities
  from .binarize_converted_base_identities import binarize_converted_base_identities
  from .fasta_functions import find_conversion_sites
- from .bam_functions import count_aligned_reads, extract_base_identities
  from .ohe import ohe_batching

+ logger = get_logger(__name__)
+
  if __name__ == "__main__":
  multiprocessing.set_start_method("forkserver", force=True)

- def converted_BAM_to_adata(converted_FASTA,
- split_dir,
- output_dir,
- input_already_demuxed,
- mapping_threshold,
- experiment_name,
- conversions,
- bam_suffix,
- device='cpu',
- num_threads=8,
- deaminase_footprinting=False,
- delete_intermediates=True,
- double_barcoded_path = None,
- ):
- """
- Converts BAM files into an AnnData object by binarizing modified base identities.

- Parameters:
- converted_FASTA (Path): Path to the converted FASTA reference.
- split_dir (Path): Directory containing converted BAM files.
- output_dir (Path): Directory of the output dir
- input_already_demuxed (bool): Whether input reads were originally demuxed
- mapping_threshold (float): Minimum fraction of aligned reads required for inclusion.
- experiment_name (str): Name for the output AnnData object.
- conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
- bam_suffix (str): File suffix for BAM files.
- num_threads (int): Number of parallel processing threads.
- deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
- double_barcoded_path (Path): Path to dorado demux summary file of double ended barcodes
+ def converted_BAM_to_adata(
+ converted_FASTA: str | Path,
+ split_dir: Path,
+ output_dir: Path,
+ input_already_demuxed: bool,
+ mapping_threshold: float,
+ experiment_name: str,
+ conversions: list[str],
+ bam_suffix: str,
+ device: str | torch.device = "cpu",
+ num_threads: int = 8,
+ deaminase_footprinting: bool = False,
+ delete_intermediates: bool = True,
+ double_barcoded_path: Path | None = None,
+ ) -> tuple[ad.AnnData | None, Path]:
+ """Convert BAM files into an AnnData object by binarizing modified base identities.
+
+ Args:
+ converted_FASTA: Path to the converted FASTA reference.
+ split_dir: Directory containing converted BAM files.
+ output_dir: Output directory for intermediate and final files.
+ input_already_demuxed: Whether input reads were originally demultiplexed.
+ mapping_threshold: Minimum fraction of aligned reads required for inclusion.
+ experiment_name: Name for the output AnnData object.
+ conversions: List of modification types (e.g., ``["unconverted", "5mC", "6mA"]``).
+ bam_suffix: File suffix for BAM files.
+ device: Torch device or device string.
+ num_threads: Number of parallel processing threads.
+ deaminase_footprinting: Whether the footprinting used direct deamination chemistry.
+ delete_intermediates: Whether to remove intermediate files after processing.
+ double_barcoded_path: Path to dorado demux summary file of double-ended barcodes.

  Returns:
- str: Path to the final AnnData object.
+ tuple[anndata.AnnData | None, Path]: The AnnData object (if generated) and its path.
  """
  if torch.cuda.is_available():
  device = torch.device("cuda")
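For orientation, here is a hedged example of calling the rewritten entry point with the keyword signature shown in this hunk. The paths, experiment name, and threshold value are hypothetical placeholders; only the parameter names, defaults, and return tuple come from the diff, and the import path simply follows the file's location in the wheel.

    from pathlib import Path

    from smftools.informatics.converted_BAM_to_adata import converted_BAM_to_adata

    # Hypothetical inputs; parameter names and defaults mirror the new signature above.
    adata, adata_path = converted_BAM_to_adata(
        converted_FASTA=Path("refs/converted.fasta"),
        split_dir=Path("run01/split_bams"),
        output_dir=Path("run01"),
        input_already_demuxed=True,
        mapping_threshold=0.05,
        experiment_name="run01",
        conversions=["unconverted", "5mC"],
        bam_suffix=".bam",
        device="cpu",
        num_threads=4,
    )
    # Per the function body, the final object lands at run01/h5ads/run01.h5ad.gz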
@@ -64,69 +68,88 @@ def converted_BAM_to_adata(converted_FASTA,
  else:
  device = torch.device("cpu")

- print(f"Using device: {device}")
+ logger.debug(f"Using device: {device}")

  ## Set Up Directories and File Paths
- h5_dir = output_dir / 'h5ads'
- tmp_dir = output_dir / 'tmp'
+ h5_dir = output_dir / "h5ads"
+ tmp_dir = output_dir / "tmp"
  final_adata = None
- final_adata_path = h5_dir / f'{experiment_name}.h5ad.gz'
+ final_adata_path = h5_dir / f"{experiment_name}.h5ad.gz"

  if final_adata_path.exists():
- print(f"{final_adata_path} already exists. Using existing AnnData object.")
+ logger.debug(f"{final_adata_path} already exists. Using existing AnnData object.")
  return final_adata, final_adata_path

  make_dirs([h5_dir, tmp_dir])

  bam_files = sorted(
- p for p in split_dir.iterdir()
- if p.is_file()
- and p.suffix == ".bam"
- and "unclassified" not in p.name
+ p
+ for p in split_dir.iterdir()
+ if p.is_file() and p.suffix == ".bam" and "unclassified" not in p.name
  )

- bam_path_list = [split_dir / f for f in bam_files]
- print(f"Found {len(bam_files)} BAM files: {bam_files}")
+ bam_path_list = bam_files
+ logger.info(f"Found {len(bam_files)} BAM files: {bam_files}")

  ## Process Conversion Sites
- max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(converted_FASTA, conversions, deaminase_footprinting)
+ max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(
+ converted_FASTA, conversions, deaminase_footprinting
+ )

  ## Filter BAM Files by Mapping Threshold
- records_to_analyze = filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold)
+ records_to_analyze = filter_bams_by_mapping_threshold(
+ bam_path_list, bam_files, mapping_threshold
+ )

  ## Process BAMs in Parallel
- final_adata = process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device, deaminase_footprinting)
+ final_adata = process_bams_parallel(
+ bam_path_list,
+ records_to_analyze,
+ record_FASTA_dict,
+ chromosome_FASTA_dict,
+ tmp_dir,
+ h5_dir,
+ num_threads,
+ max_reference_length,
+ device,
+ deaminase_footprinting,
+ )

- final_adata.uns['References'] = {}
+ final_adata.uns["References"] = {}
  for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
- final_adata.var[f'{chromosome}_top_strand_FASTA_base'] = list(seq)
- final_adata.var[f'{chromosome}_bottom_strand_FASTA_base'] = list(comp)
- final_adata.uns[f'{chromosome}_FASTA_sequence'] = seq
- final_adata.uns['References'][f'{chromosome}_FASTA_sequence'] = seq
+ final_adata.var[f"{chromosome}_top_strand_FASTA_base"] = list(seq)
+ final_adata.var[f"{chromosome}_bottom_strand_FASTA_base"] = list(comp)
+ final_adata.uns[f"{chromosome}_FASTA_sequence"] = seq
+ final_adata.uns["References"][f"{chromosome}_FASTA_sequence"] = seq

  final_adata.obs_names_make_unique()
  cols = final_adata.obs.columns

  # Make obs cols categorical
  for col in cols:
- final_adata.obs[col] = final_adata.obs[col].astype('category')
+ final_adata.obs[col] = final_adata.obs[col].astype("category")

  if input_already_demuxed:
  final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
  final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
  else:
  from .h5ad_functions import add_demux_type_annotation
+
  double_barcoded_reads = double_barcoded_path / "barcoding_summary.txt"
+ logger.info("Adding demux type to each read")
  add_demux_type_annotation(final_adata, double_barcoded_reads)

  ## Delete intermediate h5ad files and temp directories
  if delete_intermediates:
+ logger.info("Deleting intermediate h5ad files")
  delete_intermediate_h5ads_and_tmpdir(h5_dir, tmp_dir)
-
+
  return final_adata, final_adata_path


- def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC'], deaminase_footprinting=False):
+ def process_conversion_sites(
+ converted_FASTA, conversions=["unconverted", "5mC"], deaminase_footprinting=False
+ ):
  """
  Extracts conversion sites and determines the max reference length.

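The reference annotations attached above can be read back from the final object. A hedged sketch, assuming the .h5ad.gz output is an HDF5 file written with internal gzip compression; "plasmid1" and the file path are hypothetical placeholders, while the key patterns come from the hunk above.

    import anndata as ad

    adata = ad.read_h5ad("run01/h5ads/run01.h5ad.gz")
    print(list(adata.uns["References"]))                     # e.g. ["plasmid1_FASTA_sequence"]
    top = adata.var["plasmid1_top_strand_FASTA_base"]        # per-position top-strand reference base
    bottom = adata.var["plasmid1_bottom_strand_FASTA_base"]  # per-position bottom-strand reference base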
@@ -147,7 +170,9 @@ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC']
  conversion_types = conversions[1:]

  # Process the unconverted sequence once
- modification_dict[unconverted] = find_conversion_sites(converted_FASTA, unconverted, conversions, deaminase_footprinting)
+ modification_dict[unconverted] = find_conversion_sites(
+ converted_FASTA, unconverted, conversions, deaminase_footprinting
+ )
  # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys

  # Get **max sequence length** from unconverted records
@@ -166,15 +191,25 @@ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC']
  record_FASTA_dict[record] = [
  sequence + "N" * (max_reference_length - sequence_length),
  complement + "N" * (max_reference_length - sequence_length),
- chromosome, record, sequence_length, max_reference_length - sequence_length, unconverted, "top"
+ chromosome,
+ record,
+ sequence_length,
+ max_reference_length - sequence_length,
+ unconverted,
+ "top",
  ]

  if chromosome not in chromosome_FASTA_dict:
- chromosome_FASTA_dict[chromosome] = [sequence + "N" * (max_reference_length - sequence_length), complement + "N" * (max_reference_length - sequence_length)]
+ chromosome_FASTA_dict[chromosome] = [
+ sequence + "N" * (max_reference_length - sequence_length),
+ complement + "N" * (max_reference_length - sequence_length),
+ ]

  # Process converted records
  for conversion in conversion_types:
- modification_dict[conversion] = find_conversion_sites(converted_FASTA, conversion, conversions, deaminase_footprinting)
+ modification_dict[conversion] = find_conversion_sites(
+ converted_FASTA, conversion, conversions, deaminase_footprinting
+ )
  # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys

  for record, values in modification_dict[conversion].items():
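Both branches of `process_conversion_sites` pad every record and its complement to the longest unconverted reference with trailing "N"s, so all records share one coordinate width. A minimal standalone sketch of that convention:

    max_reference_length = 12
    sequence = "ACGTACGT"          # an 8 bp record
    complement = "TGCATGCA"
    padded_seq = sequence + "N" * (max_reference_length - len(sequence))
    padded_comp = complement + "N" * (max_reference_length - len(complement))
    assert len(padded_seq) == len(padded_comp) == max_reference_length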
@@ -193,11 +228,15 @@ def process_conversion_sites(converted_FASTA, conversions=['unconverted', '5mC']
  record_FASTA_dict[converted_name] = [
  sequence + "N" * (max_reference_length - sequence_length),
  complement + "N" * (max_reference_length - sequence_length),
- chromosome, unconverted_name, sequence_length,
- max_reference_length - sequence_length, conversion, strand
+ chromosome,
+ unconverted_name,
+ sequence_length,
+ max_reference_length - sequence_length,
+ conversion,
+ strand,
  ]

- print("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
+ logger.debug("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
  return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict

@@ -214,11 +253,21 @@ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold
  if percent >= mapping_threshold:
  records_to_analyze.add(record)

- print(f"Analyzing the following FASTA records: {records_to_analyze}")
+ logger.info(f"Analyzing the following FASTA records: {records_to_analyze}")
  return records_to_analyze


- def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, max_reference_length, device, deaminase_footprinting):
+ def process_single_bam(
+ bam_index,
+ bam,
+ records_to_analyze,
+ record_FASTA_dict,
+ chromosome_FASTA_dict,
+ tmp_dir,
+ max_reference_length,
+ device,
+ deaminase_footprinting,
+ ):
  """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
  adata_list = []

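The thresholding step above keeps only references whose aligned-read fraction clears `mapping_threshold`. The smftools helper `count_aligned_reads` is not shown in this diff, so the following is only a generic pysam sketch of the same idea; the exact denominator smftools uses may differ, and an accompanying .bai index is assumed.

    import pysam

    def records_above_threshold(bam_path, mapping_threshold):
        """Keep reference names whose share of mapped reads meets the threshold."""
        with pysam.AlignmentFile(str(bam_path), "rb") as bam:
            stats = bam.get_index_statistics()     # per-contig mapped/unmapped counts (needs .bai)
            total = sum(s.mapped for s in stats) or 1
            return {s.contig for s in stats if s.mapped / total >= mapping_threshold}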
@@ -230,34 +279,58 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
  sequence = chromosome_FASTA_dict[chromosome][0]

  # Extract Base Identities
- fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = extract_base_identities(bam, record, range(current_length), max_reference_length, sequence)
+ fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = (
+ extract_base_identities(
+ bam, record, range(current_length), max_reference_length, sequence
+ )
+ )
  mismatch_trend_series = pd.Series(mismatch_trend_per_read)

  # Skip processing if both forward and reverse base identities are empty
  if not fwd_bases and not rev_bases:
- print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid base identities for {record}.")
+ logger.debug(
+ f"[Worker {current_process().pid}] Skipping {sample} - No valid base identities for {record}."
+ )
  continue

  merged_bin = {}

  # Binarize the Base Identities if they exist
  if fwd_bases:
- fwd_bin = binarize_converted_base_identities(fwd_bases, strand, mod_type, bam, device,deaminase_footprinting, mismatch_trend_per_read)
+ fwd_bin = binarize_converted_base_identities(
+ fwd_bases,
+ strand,
+ mod_type,
+ bam,
+ device,
+ deaminase_footprinting,
+ mismatch_trend_per_read,
+ )
  merged_bin.update(fwd_bin)

  if rev_bases:
- rev_bin = binarize_converted_base_identities(rev_bases, strand, mod_type, bam, device, deaminase_footprinting, mismatch_trend_per_read)
+ rev_bin = binarize_converted_base_identities(
+ rev_bases,
+ strand,
+ mod_type,
+ bam,
+ device,
+ deaminase_footprinting,
+ mismatch_trend_per_read,
+ )
  merged_bin.update(rev_bin)

  # Skip if merged_bin is empty (no valid binarized data)
  if not merged_bin:
- print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid binarized data for {record}.")
+ logger.debug(
+ f"[Worker {current_process().pid}] Skipping {sample} - No valid binarized data for {record}."
+ )
  continue

  # Convert to DataFrame
  # for key in merged_bin:
  # merged_bin[key] = merged_bin[key].cpu().numpy() # Move to CPU & convert to NumPy
- bin_df = pd.DataFrame.from_dict(merged_bin, orient='index')
+ bin_df = pd.DataFrame.from_dict(merged_bin, orient="index")
  sorted_index = sorted(bin_df.index)
  bin_df = bin_df.reindex(sorted_index)

@@ -265,14 +338,18 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
  one_hot_reads = {}

  if fwd_bases:
- fwd_ohe_files = ohe_batching(fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000)
+ fwd_ohe_files = ohe_batching(
+ fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000
+ )
  for ohe_file in fwd_ohe_files:
  tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
  one_hot_reads.update(tmp_ohe_dict)
  del tmp_ohe_dict

  if rev_bases:
- rev_ohe_files = ohe_batching(rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000)
+ rev_ohe_files = ohe_batching(
+ rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000
+ )
  for ohe_file in rev_ohe_files:
  tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
  one_hot_reads.update(tmp_ohe_dict)
@@ -280,7 +357,9 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch

  # Skip if one_hot_reads is empty
  if not one_hot_reads:
- print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}.")
+ logger.debug(
+ f"[Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}."
+ )
  continue

  gc.collect()
@@ -291,11 +370,15 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch

  # Skip if no read names exist
  if not read_names:
- print(f"{timestamp()} [Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}.")
+ logger.debug(
+ f"[Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}."
+ )
  continue

  sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
- df_A, df_C, df_G, df_T, df_N = [np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)]
+ df_A, df_C, df_G, df_T, df_N = [
+ np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)
+ ]

  # Populate One-Hot Arrays
  for j, read_name in enumerate(sorted_index):
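The reshape above turns each flattened one-hot read into an (n_rows_OHE, sequence_length) matrix whose rows feed df_A through df_N. A small standalone sketch of that mapping, with the A/C/G/T/N row order assumed rather than stated in the diff:

    import numpy as np

    n_rows_OHE = 5                        # one row per base: A, C, G, T, N (assumed order)
    sequence_length = 12
    flat = np.zeros(n_rows_OHE * sequence_length, dtype=int)  # stand-in for one_hot_reads[read_name]
    flat[0 * sequence_length + 3] = 1     # an "A" call at reference position 3

    ohe = flat.reshape(n_rows_OHE, -1)    # shape (5, sequence_length)
    row_A, row_C, row_G, row_T, row_N = ohe
    assert row_A[3] == 1                  # this read would set df_A[j, 3] for its row j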
@@ -310,8 +393,8 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
  adata.var_names = bin_df.columns.astype(str)
  adata.obs["Sample"] = [sample] * len(adata)
  try:
- barcode = sample.split('barcode')[1]
- except:
+ barcode = sample.split("barcode")[1]
+ except Exception:
  barcode = np.nan
  adata.obs["Barcode"] = [int(barcode)] * len(adata)
  adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
@@ -323,49 +406,76 @@ def process_single_bam(bam_index, bam, records_to_analyze, record_FASTA_dict, ch
  adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)

  # Attach One-Hot Encodings to Layers
- adata.layers["A_binary_encoding"] = df_A
- adata.layers["C_binary_encoding"] = df_C
- adata.layers["G_binary_encoding"] = df_G
- adata.layers["T_binary_encoding"] = df_T
- adata.layers["N_binary_encoding"] = df_N
+ adata.layers["A_binary_sequence_encoding"] = df_A
+ adata.layers["C_binary_sequence_encoding"] = df_C
+ adata.layers["G_binary_sequence_encoding"] = df_G
+ adata.layers["T_binary_sequence_encoding"] = df_T
+ adata.layers["N_binary_sequence_encoding"] = df_N

  adata_list.append(adata)

  return ad.concat(adata_list, join="outer") if adata_list else None

+
  def timestamp():
  """Returns a formatted timestamp for logging."""
  return time.strftime("[%Y-%m-%d %H:%M:%S]")


- def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, deaminase_footprinting, progress_queue):
+ def worker_function(
+ bam_index,
+ bam,
+ records_to_analyze,
+ shared_record_FASTA_dict,
+ chromosome_FASTA_dict,
+ tmp_dir,
+ h5_dir,
+ max_reference_length,
+ device,
+ deaminase_footprinting,
+ progress_queue,
+ ):
  """Worker function that processes a single BAM and writes the output to an H5AD file."""
  worker_id = current_process().pid # Get worker process ID
  sample = bam.stem

  try:
- print(f"{timestamp()} [Worker {worker_id}] Processing BAM: {sample}")
+ logger.info(f"[Worker {worker_id}] Processing BAM: {sample}")

  h5ad_path = h5_dir / bam.with_suffix(".h5ad").name
  if h5ad_path.exists():
- print(f"{timestamp()} [Worker {worker_id}] Skipping {sample}: Already processed.")
+ logger.debug(f"[Worker {worker_id}] Skipping {sample}: Already processed.")
  progress_queue.put(sample)
  return

  # Filter records specific to this BAM
- bam_records_to_analyze = {record for record in records_to_analyze if record in shared_record_FASTA_dict}
+ bam_records_to_analyze = {
+ record for record in records_to_analyze if record in shared_record_FASTA_dict
+ }

  if not bam_records_to_analyze:
- print(f"{timestamp()} [Worker {worker_id}] No valid records to analyze for {sample}. Skipping.")
+ logger.debug(
+ f"[Worker {worker_id}] No valid records to analyze for {sample}. Skipping."
+ )
  progress_queue.put(sample)
  return

  # Process BAM
- adata = process_single_bam(bam_index, bam, bam_records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, max_reference_length, device, deaminase_footprinting)
+ adata = process_single_bam(
+ bam_index,
+ bam,
+ bam_records_to_analyze,
+ shared_record_FASTA_dict,
+ chromosome_FASTA_dict,
+ tmp_dir,
+ max_reference_length,
+ device,
+ deaminase_footprinting,
+ )

  if adata is not None:
  adata.write_h5ad(str(h5ad_path))
- print(f"{timestamp()} [Worker {worker_id}] Completed processing for BAM: {sample}")
+ logger.info(f"[Worker {worker_id}] Completed processing for BAM: {sample}")

  # Free memory
  del adata
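Note the layer rename in this hunk: the per-base one-hot layers move from "*_binary_encoding" to "*_binary_sequence_encoding", so downstream code must read the new keys. A hedged read-back sketch, with a hypothetical output path:

    import anndata as ad

    adata = ad.read_h5ad("run01/h5ads/run01.h5ad.gz")    # hypothetical output path
    ohe_A = adata.layers["A_binary_sequence_encoding"]   # was "A_binary_encoding" in 0.2.3
    ohe_N = adata.layers["N_binary_sequence_encoding"]   # was "N_binary_encoding" in 0.2.3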
@@ -373,22 +483,37 @@ def worker_function(bam_index, bam, records_to_analyze, shared_record_FASTA_dict

  progress_queue.put(sample)

- except Exception as e:
- print(f"{timestamp()} [Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}")
+ except Exception:
+ logger.warning(
+ f"[Worker {worker_id}] ERROR while processing {sample}:\n{traceback.format_exc()}"
+ )
  progress_queue.put(sample) # Still signal completion to prevent deadlock

- def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, num_threads, max_reference_length, device, deaminase_footprinting):
+
+ def process_bams_parallel(
+ bam_path_list,
+ records_to_analyze,
+ record_FASTA_dict,
+ chromosome_FASTA_dict,
+ tmp_dir,
+ h5_dir,
+ num_threads,
+ max_reference_length,
+ device,
+ deaminase_footprinting,
+ ):
  """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
  make_dirs(h5_dir) # Ensure h5_dir exists

- print(f"{timestamp()} Starting parallel BAM processing with {num_threads} threads...")
+ logger.info(f"Starting parallel BAM processing with {num_threads} threads...")

  # Ensure macOS uses forkserver to avoid spawning issues
  try:
  import multiprocessing
+
  multiprocessing.set_start_method("forkserver", force=True)
  except RuntimeError:
- print(f"{timestamp()} [WARNING] Multiprocessing context already set. Skipping set_start_method.")
+ logger.warning(f"Multiprocessing context already set. Skipping set_start_method.")

  with Manager() as manager:
  progress_queue = manager.Queue()
@@ -396,11 +521,26 @@ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict,

  with Pool(processes=num_threads) as pool:
  results = [
- pool.apply_async(worker_function, (i, bam, records_to_analyze, shared_record_FASTA_dict, chromosome_FASTA_dict, tmp_dir, h5_dir, max_reference_length, device, deaminase_footprinting, progress_queue))
+ pool.apply_async(
+ worker_function,
+ (
+ i,
+ bam,
+ records_to_analyze,
+ shared_record_FASTA_dict,
+ chromosome_FASTA_dict,
+ tmp_dir,
+ h5_dir,
+ max_reference_length,
+ device,
+ deaminase_footprinting,
+ progress_queue,
+ ),
+ )
  for i, bam in enumerate(bam_path_list)
  ]

- print(f"{timestamp()} Submitted {len(bam_path_list)} BAMs for processing.")
+ logger.info(f"Submitted {len(bam_path_list)} BAMs for processing.")

  # Track completed BAMs
  completed_bams = set()
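The submission above follows a standard manager-queue progress pattern: each worker pushes its sample name when it finishes (even on failure), and the parent drains the queue to track completion. A generic, standalone sketch of the same pattern, not smftools code:

    from multiprocessing import Manager, Pool

    def _work(name, progress_queue):
        try:
            pass  # ... process one BAM here ...
        finally:
            progress_queue.put(name)  # always signal so the parent never blocks forever

    if __name__ == "__main__":
        items = ["barcode01", "barcode02", "barcode03"]
        with Manager() as manager:
            queue = manager.Queue()
            with Pool(processes=2) as pool:
                for name in items:
                    pool.apply_async(_work, (name, queue))
                done = {queue.get(timeout=60) for _ in items}  # wait for each completion token
                pool.close()
                pool.join()
        print(sorted(done))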
@@ -409,24 +549,25 @@ def process_bams_parallel(bam_path_list, records_to_analyze, record_FASTA_dict,
  processed_bam = progress_queue.get(timeout=2400) # Wait for a finished BAM
  completed_bams.add(processed_bam)
  except Exception as e:
- print(f"{timestamp()} [ERROR] Timeout waiting for worker process. Possible crash? {e}")
+ logger.error(f"Timeout waiting for worker process. Possible crash? {e}")

  pool.close()
  pool.join() # Ensure all workers finish

  # Final Concatenation Step
- h5ad_files = [h5_dir / f for f in h5_dir.iterdir() if f.suffix == ".h5ad"]
+ h5ad_files = [f for f in h5_dir.iterdir() if f.suffix == ".h5ad"]

  if not h5ad_files:
- print(f"{timestamp()} No valid H5AD files generated. Exiting.")
+ logger.debug(f"No valid H5AD files generated. Exiting.")
  return None

- print(f"{timestamp()} Concatenating {len(h5ad_files)} H5AD files into final output...")
+ logger.info(f"Concatenating {len(h5ad_files)} H5AD files into final output...")
  final_adata = ad.concat([ad.read_h5ad(f) for f in h5ad_files], join="outer")

- print(f"{timestamp()} Successfully generated final AnnData object.")
+ logger.info(f"Successfully generated final AnnData object.")
  return final_adata

+
  def delete_intermediate_h5ads_and_tmpdir(
  h5_dir: Union[str, Path, Iterable[str], None],
  tmp_dir: Optional[Union[str, Path]] = None,
@@ -450,25 +591,27 @@ def delete_intermediate_h5ads_and_tmpdir(
  verbose : bool
  Print progress / warnings.
  """
+
  # Helper: remove a single file path (Path-like or string)
  def _maybe_unlink(p: Path):
+ """Remove a file path if it exists and is a file."""
  if not p.exists():
  if verbose:
- print(f"[skip] not found: {p}")
+ logger.debug(f"[skip] not found: {p}")
  return
  if not p.is_file():
  if verbose:
- print(f"[skip] not a file: {p}")
+ logger.debug(f"[skip] not a file: {p}")
  return
  if dry_run:
- print(f"[dry-run] would remove file: {p}")
+ logger.debug(f"[dry-run] would remove file: {p}")
  return
  try:
  p.unlink()
  if verbose:
- print(f"Removed file: {p}")
+ logger.info(f"Removed file: {p}")
  except Exception as e:
- print(f"[error] failed to remove file {p}: {e}")
+ logger.warning(f"[error] failed to remove file {p}: {e}")

  # Handle h5_dir input (directory OR iterable of file paths)
  if h5_dir is not None:
@@ -483,7 +626,7 @@ def delete_intermediate_h5ads_and_tmpdir(
  else:
  if verbose:
  # optional: comment this out if too noisy
- print(f"[skip] not matching pattern: {p.name}")
+ logger.debug(f"[skip] not matching pattern: {p.name}")
  else:
  # treat as iterable of file paths
  for f in h5_dir:
@@ -493,25 +636,25 @@ def delete_intermediate_h5ads_and_tmpdir(
  _maybe_unlink(p)
  else:
  if verbose:
- print(f"[skip] not matching pattern or not a file: {p}")
+ logger.debug(f"[skip] not matching pattern or not a file: {p}")

  # Remove tmp_dir recursively (if provided)
  if tmp_dir is not None:
  td = Path(tmp_dir)
  if not td.exists():
  if verbose:
- print(f"[skip] tmp_dir not found: {td}")
+ logger.debug(f"[skip] tmp_dir not found: {td}")
  else:
  if not td.is_dir():
  if verbose:
- print(f"[skip] tmp_dir is not a directory: {td}")
+ logger.debug(f"[skip] tmp_dir is not a directory: {td}")
  else:
  if dry_run:
- print(f"[dry-run] would remove directory tree: {td}")
+ logger.debug(f"[dry-run] would remove directory tree: {td}")
  else:
  try:
  shutil.rmtree(td)
  if verbose:
- print(f"Removed directory tree: {td}")
+ logger.info(f"Removed directory tree: {td}")
  except Exception as e:
- print(f"[error] failed to remove tmp dir {td}: {e}")
+ logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
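Finally, a hedged usage sketch for the cleanup helper, based only on the signature and docstring fragments visible above; it assumes dry_run and verbose are keyword parameters, as the docstring and body suggest, and the output directory is a hypothetical placeholder.

    from pathlib import Path

    from smftools.informatics.converted_BAM_to_adata import delete_intermediate_h5ads_and_tmpdir

    out = Path("run01")  # hypothetical output directory
    # Preview what would be removed before deleting anything for real.
    delete_intermediate_h5ads_and_tmpdir(out / "h5ads", tmp_dir=out / "tmp", dry_run=True, verbose=True)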