smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +34 -6
  7. smftools/cli/hmm_adata.py +239 -33
  8. smftools/cli/latent_adata.py +318 -0
  9. smftools/cli/load_adata.py +167 -131
  10. smftools/cli/preprocess_adata.py +180 -53
  11. smftools/cli/spatial_adata.py +152 -100
  12. smftools/cli_entry.py +38 -1
  13. smftools/config/__init__.py +2 -0
  14. smftools/config/conversion.yaml +11 -1
  15. smftools/config/default.yaml +42 -2
  16. smftools/config/experiment_config.py +59 -1
  17. smftools/constants.py +65 -0
  18. smftools/datasets/__init__.py +2 -0
  19. smftools/hmm/HMM.py +97 -3
  20. smftools/hmm/__init__.py +24 -13
  21. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  22. smftools/hmm/archived/calculate_distances.py +2 -0
  23. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  24. smftools/hmm/archived/train_hmm.py +2 -0
  25. smftools/hmm/call_hmm_peaks.py +5 -2
  26. smftools/hmm/display_hmm.py +4 -1
  27. smftools/hmm/hmm_readwrite.py +7 -2
  28. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  29. smftools/informatics/__init__.py +59 -34
  30. smftools/informatics/archived/bam_conversion.py +2 -0
  31. smftools/informatics/archived/bam_direct.py +2 -0
  32. smftools/informatics/archived/basecall_pod5s.py +2 -0
  33. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  34. smftools/informatics/archived/conversion_smf.py +2 -0
  35. smftools/informatics/archived/deaminase_smf.py +1 -0
  36. smftools/informatics/archived/direct_smf.py +2 -0
  37. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  38. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  39. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  40. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  41. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  42. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  43. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  44. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  45. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  48. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  49. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  50. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  52. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  53. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  54. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  55. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  56. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  57. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  58. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  59. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  60. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  61. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  62. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  63. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  64. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  65. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  66. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  67. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  68. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  69. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  70. smftools/informatics/archived/subsample_pod5.py +2 -0
  71. smftools/informatics/bam_functions.py +1093 -176
  72. smftools/informatics/basecalling.py +2 -0
  73. smftools/informatics/bed_functions.py +271 -61
  74. smftools/informatics/binarize_converted_base_identities.py +3 -0
  75. smftools/informatics/complement_base_list.py +2 -0
  76. smftools/informatics/converted_BAM_to_adata.py +641 -176
  77. smftools/informatics/fasta_functions.py +94 -10
  78. smftools/informatics/h5ad_functions.py +123 -4
  79. smftools/informatics/modkit_extract_to_adata.py +1019 -431
  80. smftools/informatics/modkit_functions.py +2 -0
  81. smftools/informatics/ohe.py +2 -0
  82. smftools/informatics/pod5_functions.py +3 -2
  83. smftools/informatics/sequence_encoding.py +72 -0
  84. smftools/logging_utils.py +21 -2
  85. smftools/machine_learning/__init__.py +22 -6
  86. smftools/machine_learning/data/__init__.py +2 -0
  87. smftools/machine_learning/data/anndata_data_module.py +18 -4
  88. smftools/machine_learning/data/preprocessing.py +2 -0
  89. smftools/machine_learning/evaluation/__init__.py +2 -0
  90. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  91. smftools/machine_learning/evaluation/evaluators.py +14 -9
  92. smftools/machine_learning/inference/__init__.py +2 -0
  93. smftools/machine_learning/inference/inference_utils.py +2 -0
  94. smftools/machine_learning/inference/lightning_inference.py +6 -1
  95. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  96. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  97. smftools/machine_learning/models/__init__.py +2 -0
  98. smftools/machine_learning/models/base.py +7 -2
  99. smftools/machine_learning/models/cnn.py +7 -2
  100. smftools/machine_learning/models/lightning_base.py +16 -11
  101. smftools/machine_learning/models/mlp.py +5 -1
  102. smftools/machine_learning/models/positional.py +7 -2
  103. smftools/machine_learning/models/rnn.py +5 -1
  104. smftools/machine_learning/models/sklearn_models.py +14 -9
  105. smftools/machine_learning/models/transformer.py +7 -2
  106. smftools/machine_learning/models/wrappers.py +6 -2
  107. smftools/machine_learning/training/__init__.py +2 -0
  108. smftools/machine_learning/training/train_lightning_model.py +13 -3
  109. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  110. smftools/machine_learning/utils/__init__.py +2 -0
  111. smftools/machine_learning/utils/device.py +5 -1
  112. smftools/machine_learning/utils/grl.py +5 -1
  113. smftools/metadata.py +1 -1
  114. smftools/optional_imports.py +31 -0
  115. smftools/plotting/__init__.py +41 -31
  116. smftools/plotting/autocorrelation_plotting.py +9 -5
  117. smftools/plotting/classifiers.py +16 -4
  118. smftools/plotting/general_plotting.py +2415 -629
  119. smftools/plotting/hmm_plotting.py +97 -9
  120. smftools/plotting/position_stats.py +15 -7
  121. smftools/plotting/qc_plotting.py +6 -1
  122. smftools/preprocessing/__init__.py +36 -37
  123. smftools/preprocessing/append_base_context.py +17 -17
  124. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  125. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  126. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  127. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  128. smftools/preprocessing/archived/preprocessing.py +2 -0
  129. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  130. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  131. smftools/preprocessing/calculate_complexity_II.py +4 -1
  132. smftools/preprocessing/calculate_consensus.py +1 -1
  133. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  134. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  135. smftools/preprocessing/calculate_position_Youden.py +9 -2
  136. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  137. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  138. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  139. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  140. smftools/preprocessing/make_dirs.py +2 -1
  141. smftools/preprocessing/min_non_diagonal.py +2 -0
  142. smftools/preprocessing/recipes.py +2 -0
  143. smftools/readwrite.py +53 -17
  144. smftools/schema/anndata_schema_v1.yaml +15 -1
  145. smftools/tools/__init__.py +30 -18
  146. smftools/tools/archived/apply_hmm.py +2 -0
  147. smftools/tools/archived/classifiers.py +2 -0
  148. smftools/tools/archived/classify_methylated_features.py +2 -0
  149. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  150. smftools/tools/archived/subset_adata_v1.py +2 -0
  151. smftools/tools/archived/subset_adata_v2.py +2 -0
  152. smftools/tools/calculate_leiden.py +57 -0
  153. smftools/tools/calculate_nmf.py +119 -0
  154. smftools/tools/calculate_umap.py +93 -8
  155. smftools/tools/cluster_adata_on_methylation.py +7 -1
  156. smftools/tools/position_stats.py +17 -27
  157. smftools/tools/rolling_nn_distance.py +235 -0
  158. smftools/tools/tensor_factorization.py +169 -0
  159. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
  160. smftools-0.3.1.dist-info/RECORD +189 -0
  161. smftools-0.2.5.dist-info/RECORD +0 -181
  162. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  163. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  164. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,29 +1,117 @@
1
+ from __future__ import annotations
2
+
1
3
  import gc
2
- import multiprocessing
4
+ import logging
3
5
  import shutil
4
6
  import time
5
7
  import traceback
8
+ from dataclasses import dataclass
6
9
  from multiprocessing import Manager, Pool, current_process
7
10
  from pathlib import Path
8
- from typing import Iterable, Optional, Union
11
+ from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Union
9
12
 
10
13
  import anndata as ad
11
14
  import numpy as np
12
15
  import pandas as pd
13
- import torch
14
16
 
15
- from smftools.logging_utils import get_logger
17
+ from smftools.constants import (
18
+ BAM_SUFFIX,
19
+ BARCODE,
20
+ BASE_QUALITY_SCORES,
21
+ DATASET,
22
+ DEMUX_TYPE,
23
+ H5_DIR,
24
+ MISMATCH_INTEGER_ENCODING,
25
+ MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
26
+ MODKIT_EXTRACT_SEQUENCE_BASES,
27
+ MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
28
+ MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
29
+ READ_MAPPING_DIRECTION,
30
+ READ_MISMATCH_TREND,
31
+ READ_SPAN_MASK,
32
+ REFERENCE,
33
+ REFERENCE_DATASET_STRAND,
34
+ REFERENCE_STRAND,
35
+ SAMPLE,
36
+ SEQUENCE_INTEGER_DECODING,
37
+ SEQUENCE_INTEGER_ENCODING,
38
+ STRAND,
39
+ )
40
+ from smftools.logging_utils import get_logger, setup_logging
41
+ from smftools.optional_imports import require
16
42
 
17
43
  from ..readwrite import make_dirs
18
44
  from .bam_functions import count_aligned_reads, extract_base_identities
19
45
  from .binarize_converted_base_identities import binarize_converted_base_identities
20
- from .fasta_functions import find_conversion_sites
21
- from .ohe import ohe_batching
46
+ from .fasta_functions import find_conversion_sites, get_native_references
22
47
 
23
48
  logger = get_logger(__name__)
24
49
 
25
- if __name__ == "__main__":
26
- multiprocessing.set_start_method("forkserver", force=True)
50
+ if TYPE_CHECKING:
51
+ import torch
52
+
53
+ torch = require("torch", extra="torch", purpose="converted BAM processing")
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class RecordFastaInfo:
58
+ """Structured FASTA metadata for a single converted record.
59
+
60
+ Attributes:
61
+ sequence: Padded top-strand sequence for the record.
62
+ complement: Padded bottom-strand sequence for the record.
63
+ chromosome: Canonical chromosome name for the record.
64
+ unconverted_name: FASTA record name for the unconverted reference.
65
+ sequence_length: Length of the unpadded reference sequence.
66
+ padding_length: Number of padded bases applied to reach max length.
67
+ conversion: Conversion label (e.g., "unconverted", "5mC").
68
+ strand: Strand label ("top" or "bottom").
69
+ max_reference_length: Maximum reference length across all records.
70
+ """
71
+
72
+ sequence: str
73
+ complement: str
74
+ chromosome: str
75
+ unconverted_name: str
76
+ sequence_length: int
77
+ padding_length: int
78
+ conversion: str
79
+ strand: str
80
+ max_reference_length: int
81
+
82
+
83
+ @dataclass(frozen=True)
84
+ class SequenceEncodingConfig:
85
+ """Configuration for integer sequence encoding.
86
+
87
+ Attributes:
88
+ base_to_int: Mapping of base characters to integer encodings.
89
+ bases: Valid base characters used for encoding.
90
+ padding_base: Base token used for padding.
91
+ batch_size: Number of reads per temporary batch file.
92
+ """
93
+
94
+ base_to_int: Mapping[str, int]
95
+ bases: tuple[str, ...]
96
+ padding_base: str
97
+ batch_size: int = 100000
98
+
99
+ @property
100
+ def padding_value(self) -> int:
101
+ """Return the integer value used for padding positions."""
102
+ return self.base_to_int[self.padding_base]
103
+
104
+ @property
105
+ def unknown_value(self) -> int:
106
+ """Return the integer value used for unknown bases."""
107
+ return self.base_to_int["N"]
108
+
109
+
110
+ SEQUENCE_ENCODING_CONFIG = SequenceEncodingConfig(
111
+ base_to_int=MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
112
+ bases=MODKIT_EXTRACT_SEQUENCE_BASES,
113
+ padding_base=MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
114
+ )
27
115
 
28
116
 
29
117
  def converted_BAM_to_adata(
@@ -40,8 +128,9 @@ def converted_BAM_to_adata(
40
128
  deaminase_footprinting: bool = False,
41
129
  delete_intermediates: bool = True,
42
130
  double_barcoded_path: Path | None = None,
131
+ samtools_backend: str | None = "auto",
43
132
  ) -> tuple[ad.AnnData | None, Path]:
44
- """Convert BAM files into an AnnData object by binarizing modified base identities.
133
+ """Convert converted BAM files into an AnnData object with integer sequence encoding.
45
134
 
46
135
  Args:
47
136
  converted_FASTA: Path to the converted FASTA reference.
@@ -57,9 +146,18 @@ def converted_BAM_to_adata(
57
146
  deaminase_footprinting: Whether the footprinting used direct deamination chemistry.
58
147
  delete_intermediates: Whether to remove intermediate files after processing.
59
148
  double_barcoded_path: Path to dorado demux summary file of double-ended barcodes.
149
+ samtools_backend: Samtools backend choice for alignment parsing.
60
150
 
61
151
  Returns:
62
152
  tuple[anndata.AnnData | None, Path]: The AnnData object (if generated) and its path.
153
+
154
+ Processing Steps:
155
+ 1. Resolve the best available torch device and create output directories.
156
+ 2. Load converted FASTA records and compute conversion sites.
157
+ 3. Filter BAMs based on mapping thresholds.
158
+ 4. Process each BAM in parallel, building per-sample H5AD files.
159
+ 5. Concatenate per-sample AnnData objects and attach reference metadata.
160
+ 6. Add demultiplexing annotations and clean intermediate artifacts.
63
161
  """
64
162
  if torch.cuda.is_available():
65
163
  device = torch.device("cuda")
@@ -71,7 +169,7 @@ def converted_BAM_to_adata(
71
169
  logger.debug(f"Using device: {device}")
72
170
 
73
171
  ## Set Up Directories and File Paths
74
- h5_dir = output_dir / "h5ads"
172
+ h5_dir = output_dir / H5_DIR
75
173
  tmp_dir = output_dir / "tmp"
76
174
  final_adata = None
77
175
  final_adata_path = h5_dir / f"{experiment_name}.h5ad.gz"
@@ -85,11 +183,13 @@ def converted_BAM_to_adata(
85
183
  bam_files = sorted(
86
184
  p
87
185
  for p in split_dir.iterdir()
88
- if p.is_file() and p.suffix == ".bam" and "unclassified" not in p.name
186
+ if p.is_file() and p.suffix == BAM_SUFFIX and "unclassified" not in p.name
89
187
  )
90
188
 
91
189
  bam_path_list = bam_files
92
- logger.info(f"Found {len(bam_files)} BAM files: {bam_files}")
190
+
191
+ bam_names = [bam.name for bam in bam_files]
192
+ logger.info(f"Found {len(bam_files)} BAM files within {split_dir}: {bam_names}")
93
193
 
94
194
  ## Process Conversion Sites
95
195
  max_reference_length, record_FASTA_dict, chromosome_FASTA_dict = process_conversion_sites(
@@ -98,9 +198,19 @@ def converted_BAM_to_adata(
98
198
 
99
199
  ## Filter BAM Files by Mapping Threshold
100
200
  records_to_analyze = filter_bams_by_mapping_threshold(
101
- bam_path_list, bam_files, mapping_threshold
201
+ bam_path_list, bam_files, mapping_threshold, samtools_backend
102
202
  )
103
203
 
204
+ # Get converted record sequences:
205
+ converted_FASTA_record_seq_map = get_native_references(converted_FASTA)
206
+ # Pad the record sequences
207
+ for record, [record_length, seq] in converted_FASTA_record_seq_map.items():
208
+ if max_reference_length > record_length:
209
+ pad_number = max_reference_length - record_length
210
+ record_length += pad_number
211
+ seq += "N" * pad_number
212
+ converted_FASTA_record_seq_map[record] = [record_length, seq]
213
+
104
214
  ## Process BAMs in Parallel
105
215
  final_adata = process_bams_parallel(
106
216
  bam_path_list,
@@ -113,8 +223,16 @@ def converted_BAM_to_adata(
113
223
  max_reference_length,
114
224
  device,
115
225
  deaminase_footprinting,
226
+ samtools_backend,
227
+ converted_FASTA_record_seq_map,
116
228
  )
117
229
 
230
+ final_adata.uns[f"{SEQUENCE_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
231
+ final_adata.uns[f"{MISMATCH_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
232
+ final_adata.uns[f"{SEQUENCE_INTEGER_DECODING}_map"] = {
233
+ str(key): value for key, value in MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.items()
234
+ }
235
+
118
236
  final_adata.uns["References"] = {}
119
237
  for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
120
238
  final_adata.var[f"{chromosome}_top_strand_FASTA_base"] = list(seq)
@@ -122,6 +240,11 @@ def converted_BAM_to_adata(
122
240
  final_adata.uns[f"{chromosome}_FASTA_sequence"] = seq
123
241
  final_adata.uns["References"][f"{chromosome}_FASTA_sequence"] = seq
124
242
 
243
+ if not deaminase_footprinting:
244
+ for record, [_length, seq] in converted_FASTA_record_seq_map.items():
245
+ if "unconverted" not in record:
246
+ final_adata.var[f"{record}_top_strand_FASTA_base"] = list(seq)
247
+
125
248
  final_adata.obs_names_make_unique()
126
249
  cols = final_adata.obs.columns
127
250
 
@@ -129,9 +252,29 @@ def converted_BAM_to_adata(
129
252
  for col in cols:
130
253
  final_adata.obs[col] = final_adata.obs[col].astype("category")
131
254
 
255
+ consensus_bases = MODKIT_EXTRACT_SEQUENCE_BASES[:4] # ignore N/PAD for consensus
256
+ consensus_base_ints = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in consensus_bases]
257
+ for ref_group in final_adata.obs[REFERENCE_DATASET_STRAND].cat.categories:
258
+ group_subset = final_adata[final_adata.obs[REFERENCE_DATASET_STRAND] == ref_group]
259
+ encoded_sequences = group_subset.layers[SEQUENCE_INTEGER_ENCODING]
260
+ layer_counts = [
261
+ np.sum(encoded_sequences == base_int, axis=0) for base_int in consensus_base_ints
262
+ ]
263
+ count_array = np.array(layer_counts)
264
+ nucleotide_indexes = np.argmax(count_array, axis=0)
265
+ consensus_sequence_list = [consensus_bases[i] for i in nucleotide_indexes]
266
+ no_calls_mask = np.sum(count_array, axis=0) == 0
267
+ if np.any(no_calls_mask):
268
+ consensus_sequence_list = np.array(consensus_sequence_list, dtype=object)
269
+ consensus_sequence_list[no_calls_mask] = "N"
270
+ consensus_sequence_list = consensus_sequence_list.tolist()
271
+ final_adata.var[f"{ref_group}_consensus_sequence_from_all_samples"] = (
272
+ consensus_sequence_list
273
+ )
274
+
132
275
  if input_already_demuxed:
133
- final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
134
- final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
276
+ final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
277
+ final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
135
278
  else:
136
279
  from .h5ad_functions import add_demux_type_annotation
137
280
 
@@ -148,23 +291,31 @@ def converted_BAM_to_adata(
148
291
 
149
292
 
150
293
  def process_conversion_sites(
151
- converted_FASTA, conversions=["unconverted", "5mC"], deaminase_footprinting=False
152
- ):
153
- """
154
- Extracts conversion sites and determines the max reference length.
294
+ converted_FASTA: str | Path,
295
+ conversions: list[str] | None = None,
296
+ deaminase_footprinting: bool = False,
297
+ ) -> tuple[int, dict[str, RecordFastaInfo], dict[str, tuple[str, str]]]:
298
+ """Extract conversion sites and FASTA metadata for converted references.
155
299
 
156
- Parameters:
157
- converted_FASTA (str): Path to the converted reference FASTA.
158
- conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
159
- deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
300
+ Args:
301
+ converted_FASTA: Path to the converted reference FASTA.
302
+ conversions: List of modification types (e.g., ["unconverted", "5mC", "6mA"]).
303
+ deaminase_footprinting: Whether the footprinting was done with direct deamination chemistry.
160
304
 
161
305
  Returns:
162
- max_reference_length (int): The length of the longest sequence.
163
- record_FASTA_dict (dict): Dictionary of sequence information for **both converted & unconverted** records.
306
+ tuple[int, dict[str, RecordFastaInfo], dict[str, tuple[str, str]]]:
307
+ Maximum reference length, record metadata, and chromosome sequences.
308
+
309
+ Processing Steps:
310
+ 1. Parse unconverted FASTA records to determine the max reference length.
311
+ 2. Build record metadata for unconverted and converted strands.
312
+ 3. Cache chromosome-level sequences for downstream annotation.
164
313
  """
165
- modification_dict = {}
166
- record_FASTA_dict = {}
167
- chromosome_FASTA_dict = {}
314
+ if conversions is None:
315
+ conversions = ["unconverted", "5mC"]
316
+ modification_dict: dict[str, dict] = {}
317
+ record_FASTA_dict: dict[str, RecordFastaInfo] = {}
318
+ chromosome_FASTA_dict: dict[str, tuple[str, str]] = {}
168
319
  max_reference_length = 0
169
320
  unconverted = conversions[0]
170
321
  conversion_types = conversions[1:]
@@ -188,22 +339,23 @@ def process_conversion_sites(
188
339
  chromosome = record
189
340
 
190
341
  # Store **original sequence**
191
- record_FASTA_dict[record] = [
192
- sequence + "N" * (max_reference_length - sequence_length),
193
- complement + "N" * (max_reference_length - sequence_length),
194
- chromosome,
195
- record,
196
- sequence_length,
197
- max_reference_length - sequence_length,
198
- unconverted,
199
- "top",
200
- ]
342
+ record_FASTA_dict[record] = RecordFastaInfo(
343
+ sequence=sequence + "N" * (max_reference_length - sequence_length),
344
+ complement=complement + "N" * (max_reference_length - sequence_length),
345
+ chromosome=chromosome,
346
+ unconverted_name=record,
347
+ sequence_length=sequence_length,
348
+ padding_length=max_reference_length - sequence_length,
349
+ conversion=unconverted,
350
+ strand="top",
351
+ max_reference_length=max_reference_length,
352
+ )
201
353
 
202
354
  if chromosome not in chromosome_FASTA_dict:
203
- chromosome_FASTA_dict[chromosome] = [
355
+ chromosome_FASTA_dict[chromosome] = (
204
356
  sequence + "N" * (max_reference_length - sequence_length),
205
357
  complement + "N" * (max_reference_length - sequence_length),
206
- ]
358
+ )
207
359
 
208
360
  # Process converted records
209
361
  for conversion in conversion_types:
@@ -225,29 +377,49 @@ def process_conversion_sites(
225
377
  converted_name = f"{chromosome}_{conversion}_{strand}"
226
378
  unconverted_name = f"{chromosome}_{unconverted}_top"
227
379
 
228
- record_FASTA_dict[converted_name] = [
229
- sequence + "N" * (max_reference_length - sequence_length),
230
- complement + "N" * (max_reference_length - sequence_length),
231
- chromosome,
232
- unconverted_name,
233
- sequence_length,
234
- max_reference_length - sequence_length,
235
- conversion,
236
- strand,
237
- ]
238
-
239
- logger.debug("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
380
+ record_FASTA_dict[converted_name] = RecordFastaInfo(
381
+ sequence=sequence + "N" * (max_reference_length - sequence_length),
382
+ complement=complement + "N" * (max_reference_length - sequence_length),
383
+ chromosome=chromosome,
384
+ unconverted_name=unconverted_name,
385
+ sequence_length=sequence_length,
386
+ padding_length=max_reference_length - sequence_length,
387
+ conversion=conversion,
388
+ strand=strand,
389
+ max_reference_length=max_reference_length,
390
+ )
391
+
392
+ logger.debug("Updated record_FASTA_dict keys: %s", list(record_FASTA_dict.keys()))
240
393
  return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
241
394
 
242
395
 
243
- def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold):
244
- """Filters BAM files based on mapping threshold."""
245
- records_to_analyze = set()
396
+ def filter_bams_by_mapping_threshold(
397
+ bam_path_list: list[Path],
398
+ bam_files: list[Path],
399
+ mapping_threshold: float,
400
+ samtools_backend: str | None,
401
+ ) -> set[str]:
402
+ """Filter FASTA records based on per-BAM mapping thresholds.
403
+
404
+ Args:
405
+ bam_path_list: Ordered list of BAM paths to evaluate.
406
+ bam_files: Matching list of BAM paths used for reporting.
407
+ mapping_threshold: Minimum percentage of aligned reads to include a record.
408
+ samtools_backend: Samtools backend choice for alignment parsing.
409
+
410
+ Returns:
411
+ set[str]: FASTA record IDs that pass the mapping threshold.
412
+
413
+ Processing Steps:
414
+ 1. Count aligned/unaligned reads and per-record percentages.
415
+ 2. Collect record IDs that meet the mapping threshold.
416
+ """
417
+ records_to_analyze: set[str] = set()
246
418
 
247
419
  for i, bam in enumerate(bam_path_list):
248
- aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam)
420
+ aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam, samtools_backend)
249
421
  aligned_percent = aligned_reads * 100 / (aligned_reads + unaligned_reads)
250
- print(f"{aligned_percent:.2f}% of reads in {bam_files[i]} aligned successfully.")
422
+ logger.info(f"{aligned_percent:.2f}% of reads in {bam_files[i].name} aligned successfully.")
251
423
 
252
424
  for record, (count, percent) in record_counts.items():
253
425
  if percent >= mapping_threshold:
@@ -257,32 +429,179 @@ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold
257
429
  return records_to_analyze
258
430
 
259
431
 
432
+ def _encode_sequence_array(
433
+ read_sequence: np.ndarray,
434
+ valid_length: int,
435
+ config: SequenceEncodingConfig,
436
+ ) -> np.ndarray:
437
+ """Encode a base-identity array into integer values with padding.
438
+
439
+ Args:
440
+ read_sequence: Array of base calls (dtype "<U1").
441
+ valid_length: Number of valid reference positions for this record.
442
+ config: Integer encoding configuration.
443
+
444
+ Returns:
445
+ np.ndarray: Integer-encoded sequence with padding applied.
446
+
447
+ Processing Steps:
448
+ 1. Initialize an array filled with the unknown base encoding.
449
+ 2. Map A/C/G/T/N bases into integer values.
450
+ 3. Mark positions beyond valid_length with the padding value.
451
+ """
452
+ read_sequence = np.asarray(read_sequence, dtype="<U1")
453
+ encoded = np.full(read_sequence.shape, config.unknown_value, dtype=np.int16)
454
+ for base in config.bases:
455
+ encoded[read_sequence == base] = config.base_to_int[base]
456
+ if valid_length < encoded.size:
457
+ encoded[valid_length:] = config.padding_value
458
+ return encoded
459
+
460
+
461
+ def _write_sequence_batches(
462
+ base_identities: Mapping[str, np.ndarray],
463
+ tmp_dir: Path,
464
+ record: str,
465
+ prefix: str,
466
+ valid_length: int,
467
+ config: SequenceEncodingConfig,
468
+ ) -> list[str]:
469
+ """Encode base identities into integer arrays and write batched H5AD files.
470
+
471
+ Args:
472
+ base_identities: Mapping of read name to base identity arrays.
473
+ tmp_dir: Directory for temporary H5AD files.
474
+ record: Reference record identifier.
475
+ prefix: Prefix used to name batch files.
476
+ valid_length: Valid reference length for padding determination.
477
+ config: Integer encoding configuration.
478
+
479
+ Returns:
480
+ list[str]: Paths to written H5AD batch files.
481
+
482
+ Processing Steps:
483
+ 1. Encode each read sequence into integers.
484
+ 2. Accumulate encoded reads into batches.
485
+ 3. Persist each batch to an H5AD file with `.uns` storage.
486
+ """
487
+ batch_files: list[str] = []
488
+ batch: dict[str, np.ndarray] = {}
489
+ batch_number = 0
490
+
491
+ for read_name, sequence in base_identities.items():
492
+ if sequence is None:
493
+ continue
494
+ batch[read_name] = _encode_sequence_array(sequence, valid_length, config)
495
+ if len(batch) >= config.batch_size:
496
+ save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
497
+ ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
498
+ batch_files.append(str(save_name))
499
+ batch = {}
500
+ batch_number += 1
501
+
502
+ if batch:
503
+ save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
504
+ ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
505
+ batch_files.append(str(save_name))
506
+
507
+ return batch_files
508
+
509
+
510
+ def _load_sequence_batches(
511
+ batch_files: list[Path | str],
512
+ ) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
513
+ """Load integer-encoded sequence batches from H5AD files.
514
+
515
+ Args:
516
+ batch_files: H5AD paths containing encoded sequences in `.uns`.
517
+
518
+ Returns:
519
+ tuple[dict[str, np.ndarray], set[str], set[str]]:
520
+ Read-to-sequence mapping and sets of forward/reverse mapped reads.
521
+
522
+ Processing Steps:
523
+ 1. Read each H5AD file.
524
+ 2. Merge `.uns` dictionaries into a single mapping.
525
+ 3. Track forward/reverse read IDs based on filename markers.
526
+ """
527
+ sequences: dict[str, np.ndarray] = {}
528
+ fwd_reads: set[str] = set()
529
+ rev_reads: set[str] = set()
530
+ for batch_file in batch_files:
531
+ batch_path = Path(batch_file)
532
+ batch_sequences = ad.read_h5ad(batch_path).uns
533
+ sequences.update(batch_sequences)
534
+ if "_fwd_" in batch_path.name:
535
+ fwd_reads.update(batch_sequences.keys())
536
+ elif "_rev_" in batch_path.name:
537
+ rev_reads.update(batch_sequences.keys())
538
+ return sequences, fwd_reads, rev_reads
539
+
540
+
260
541
  def process_single_bam(
261
- bam_index,
262
- bam,
263
- records_to_analyze,
264
- record_FASTA_dict,
265
- chromosome_FASTA_dict,
266
- tmp_dir,
267
- max_reference_length,
268
- device,
269
- deaminase_footprinting,
270
- ):
271
- """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
272
- adata_list = []
542
+ bam_index: int,
543
+ bam: Path,
544
+ records_to_analyze: set[str],
545
+ record_FASTA_dict: dict[str, RecordFastaInfo],
546
+ chromosome_FASTA_dict: dict[str, tuple[str, str]],
547
+ tmp_dir: Path,
548
+ max_reference_length: int,
549
+ device: torch.device,
550
+ deaminase_footprinting: bool,
551
+ samtools_backend: str | None,
552
+ converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
553
+ ) -> ad.AnnData | None:
554
+ """Process a single BAM file into per-record AnnData objects.
555
+
556
+ Args:
557
+ bam_index: Index of the BAM within the processing batch.
558
+ bam: Path to the BAM file.
559
+ records_to_analyze: FASTA record IDs that passed the mapping threshold.
560
+ record_FASTA_dict: FASTA metadata keyed by record ID.
561
+ chromosome_FASTA_dict: Chromosome sequences for annotations.
562
+ tmp_dir: Directory for temporary batch files.
563
+ max_reference_length: Maximum reference length for padding.
564
+ device: Torch device used for binarization.
565
+ deaminase_footprinting: Whether direct deamination chemistry was used.
566
+ samtools_backend: Samtools backend choice for alignment parsing.
567
+ converted_FASTA_record_seq_map: record to seq map
568
+
569
+ Returns:
570
+ anndata.AnnData | None: Concatenated AnnData object or None if no data.
571
+
572
+ Processing Steps:
573
+ 1. Extract base identities and mismatch profiles for each record.
574
+ 2. Binarize modified base identities into feature matrices.
575
+ 3. Encode read sequences into integer arrays and cache batches.
576
+ 4. Build AnnData layers/obs metadata for each record and concatenate.
577
+ """
578
+ adata_list: list[ad.AnnData] = []
273
579
 
274
580
  for record in records_to_analyze:
275
581
  sample = bam.stem
276
- chromosome = record_FASTA_dict[record][2]
277
- current_length = record_FASTA_dict[record][4]
278
- mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
279
- sequence = chromosome_FASTA_dict[chromosome][0]
582
+ record_info = record_FASTA_dict[record]
583
+ chromosome = record_info.chromosome
584
+ current_length = record_info.sequence_length
585
+ mod_type, strand = record_info.conversion, record_info.strand
586
+ non_converted_sequence = chromosome_FASTA_dict[chromosome][0]
587
+ record_sequence = converted_FASTA_record_seq_map[record][1]
280
588
 
281
589
  # Extract Base Identities
282
- fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = (
283
- extract_base_identities(
284
- bam, record, range(current_length), max_reference_length, sequence
285
- )
590
+ (
591
+ fwd_bases,
592
+ rev_bases,
593
+ mismatch_counts_per_read,
594
+ mismatch_trend_per_read,
595
+ mismatch_base_identities,
596
+ base_quality_scores,
597
+ read_span_masks,
598
+ ) = extract_base_identities(
599
+ bam,
600
+ record,
601
+ range(current_length),
602
+ max_reference_length,
603
+ record_sequence,
604
+ samtools_backend,
286
605
  )
287
606
  mismatch_trend_series = pd.Series(mismatch_trend_per_read)
288
607
 
@@ -334,83 +653,115 @@ def process_single_bam(
334
653
  sorted_index = sorted(bin_df.index)
335
654
  bin_df = bin_df.reindex(sorted_index)
336
655
 
337
- # One-Hot Encode Reads if there is valid data
338
- one_hot_reads = {}
339
-
656
+ # Integer-encode reads if there is valid data
657
+ batch_files: list[str] = []
340
658
  if fwd_bases:
341
- fwd_ohe_files = ohe_batching(
342
- fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000
659
+ batch_files.extend(
660
+ _write_sequence_batches(
661
+ fwd_bases,
662
+ tmp_dir,
663
+ record,
664
+ f"{bam_index}_fwd",
665
+ current_length,
666
+ SEQUENCE_ENCODING_CONFIG,
667
+ )
343
668
  )
344
- for ohe_file in fwd_ohe_files:
345
- tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
346
- one_hot_reads.update(tmp_ohe_dict)
347
- del tmp_ohe_dict
348
669
 
349
670
  if rev_bases:
350
- rev_ohe_files = ohe_batching(
351
- rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000
671
+ batch_files.extend(
672
+ _write_sequence_batches(
673
+ rev_bases,
674
+ tmp_dir,
675
+ record,
676
+ f"{bam_index}_rev",
677
+ current_length,
678
+ SEQUENCE_ENCODING_CONFIG,
679
+ )
352
680
  )
353
- for ohe_file in rev_ohe_files:
354
- tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
355
- one_hot_reads.update(tmp_ohe_dict)
356
- del tmp_ohe_dict
357
681
 
358
- # Skip if one_hot_reads is empty
359
- if not one_hot_reads:
682
+ if not batch_files:
360
683
  logger.debug(
361
- f"[Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}."
684
+ f"[Worker {current_process().pid}] Skipping {sample} - No valid encoded data for {record}."
362
685
  )
363
686
  continue
364
687
 
365
688
  gc.collect()
366
689
 
367
- # Convert One-Hot Encodings to Numpy Arrays
368
- n_rows_OHE = 5
369
- read_names = list(one_hot_reads.keys())
370
-
371
- # Skip if no read names exist
372
- if not read_names:
690
+ encoded_reads, fwd_reads, rev_reads = _load_sequence_batches(batch_files)
691
+ if not encoded_reads:
373
692
  logger.debug(
374
- f"[Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}."
693
+ f"[Worker {current_process().pid}] Skipping {sample} - No reads found in encoded data for {record}."
375
694
  )
376
695
  continue
377
696
 
378
- sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
379
- df_A, df_C, df_G, df_T, df_N = [
380
- np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)
381
- ]
697
+ sequence_length = max_reference_length
698
+ default_sequence = np.full(
699
+ sequence_length, SEQUENCE_ENCODING_CONFIG.unknown_value, dtype=np.int16
700
+ )
701
+ if current_length < sequence_length:
702
+ default_sequence[current_length:] = SEQUENCE_ENCODING_CONFIG.padding_value
382
703
 
383
- # Populate One-Hot Arrays
384
- for j, read_name in enumerate(sorted_index):
385
- if read_name in one_hot_reads:
386
- one_hot_array = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
387
- df_A[j], df_C[j], df_G[j], df_T[j], df_N[j] = one_hot_array
704
+ encoded_matrix = np.vstack(
705
+ [encoded_reads.get(read_name, default_sequence) for read_name in sorted_index]
706
+ )
707
+ default_mismatch_sequence = np.full(
708
+ sequence_length, SEQUENCE_ENCODING_CONFIG.unknown_value, dtype=np.int16
709
+ )
710
+ if current_length < sequence_length:
711
+ default_mismatch_sequence[current_length:] = SEQUENCE_ENCODING_CONFIG.padding_value
712
+ mismatch_encoded_matrix = np.vstack(
713
+ [
714
+ mismatch_base_identities.get(read_name, default_mismatch_sequence)
715
+ for read_name in sorted_index
716
+ ]
717
+ )
718
+ default_quality_sequence = np.full(sequence_length, -1, dtype=np.int16)
719
+ quality_matrix = np.vstack(
720
+ [
721
+ base_quality_scores.get(read_name, default_quality_sequence)
722
+ for read_name in sorted_index
723
+ ]
724
+ )
725
+ default_read_span = np.zeros(sequence_length, dtype=np.int16)
726
+ read_span_matrix = np.vstack(
727
+ [read_span_masks.get(read_name, default_read_span) for read_name in sorted_index]
728
+ )
388
729
 
389
730
  # Convert to AnnData
390
731
  X = bin_df.values.astype(np.float32)
391
732
  adata = ad.AnnData(X)
392
733
  adata.obs_names = bin_df.index.astype(str)
393
734
  adata.var_names = bin_df.columns.astype(str)
394
- adata.obs["Sample"] = [sample] * len(adata)
735
+ adata.obs[SAMPLE] = [sample] * len(adata)
395
736
  try:
396
737
  barcode = sample.split("barcode")[1]
397
738
  except Exception:
398
739
  barcode = np.nan
399
- adata.obs["Barcode"] = [int(barcode)] * len(adata)
400
- adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
401
- adata.obs["Reference"] = [chromosome] * len(adata)
402
- adata.obs["Strand"] = [strand] * len(adata)
403
- adata.obs["Dataset"] = [mod_type] * len(adata)
404
- adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
405
- adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
406
- adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)
407
-
408
- # Attach One-Hot Encodings to Layers
409
- adata.layers["A_binary_sequence_encoding"] = df_A
410
- adata.layers["C_binary_sequence_encoding"] = df_C
411
- adata.layers["G_binary_sequence_encoding"] = df_G
412
- adata.layers["T_binary_sequence_encoding"] = df_T
413
- adata.layers["N_binary_sequence_encoding"] = df_N
740
+ adata.obs[BARCODE] = [int(barcode)] * len(adata)
741
+ adata.obs[BARCODE] = adata.obs[BARCODE].astype(str)
742
+ adata.obs[REFERENCE] = [chromosome] * len(adata)
743
+ adata.obs[STRAND] = [strand] * len(adata)
744
+ adata.obs[DATASET] = [mod_type] * len(adata)
745
+ adata.obs[REFERENCE_DATASET_STRAND] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
746
+ adata.obs[REFERENCE_STRAND] = [f"{chromosome}_{strand}"] * len(adata)
747
+ adata.obs[READ_MISMATCH_TREND] = adata.obs_names.map(mismatch_trend_series)
748
+
749
+ read_mapping_direction = []
750
+ for read_id in adata.obs_names:
751
+ if read_id in fwd_reads:
752
+ read_mapping_direction.append("fwd")
753
+ elif read_id in rev_reads:
754
+ read_mapping_direction.append("rev")
755
+ else:
756
+ read_mapping_direction.append("unk")
757
+
758
+ adata.obs[READ_MAPPING_DIRECTION] = read_mapping_direction
759
+
760
+ # Attach integer sequence encoding to layers
761
+ adata.layers[SEQUENCE_INTEGER_ENCODING] = encoded_matrix
762
+ adata.layers[MISMATCH_INTEGER_ENCODING] = mismatch_encoded_matrix
763
+ adata.layers[BASE_QUALITY_SCORES] = quality_matrix
764
+ adata.layers[READ_SPAN_MASK] = read_span_matrix
414
765
 
415
766
  adata_list.append(adata)
416
767
 
@@ -418,24 +769,57 @@ def process_single_bam(
418
769
 
419
770
 
420
771
  def timestamp():
421
- """Returns a formatted timestamp for logging."""
772
+ """Return a formatted timestamp for logging.
773
+
774
+ Returns:
775
+ str: Timestamp string in the format ``[YYYY-MM-DD HH:MM:SS]``.
776
+ """
422
777
  return time.strftime("[%Y-%m-%d %H:%M:%S]")
423
778
 
424
779
 
425
780
  def worker_function(
426
- bam_index,
427
- bam,
428
- records_to_analyze,
429
- shared_record_FASTA_dict,
430
- chromosome_FASTA_dict,
431
- tmp_dir,
432
- h5_dir,
433
- max_reference_length,
434
- device,
435
- deaminase_footprinting,
781
+ bam_index: int,
782
+ bam: Path,
783
+ records_to_analyze: set[str],
784
+ shared_record_FASTA_dict: dict[str, RecordFastaInfo],
785
+ chromosome_FASTA_dict: dict[str, tuple[str, str]],
786
+ tmp_dir: Path,
787
+ h5_dir: Path,
788
+ max_reference_length: int,
789
+ device: torch.device,
790
+ deaminase_footprinting: bool,
791
+ samtools_backend: str | None,
792
+ converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
436
793
  progress_queue,
794
+ log_level: int,
795
+ log_file: Path | None,
437
796
  ):
438
- """Worker function that processes a single BAM and writes the output to an H5AD file."""
797
+ """Process a single BAM and write the output to an H5AD file.
798
+
799
+ Args:
800
+ bam_index: Index of the BAM within the processing batch.
801
+ bam: Path to the BAM file.
802
+ records_to_analyze: FASTA record IDs that passed the mapping threshold.
803
+ shared_record_FASTA_dict: Shared FASTA metadata keyed by record ID.
804
+ chromosome_FASTA_dict: Chromosome sequences for annotations.
805
+ tmp_dir: Directory for temporary batch files.
806
+ h5_dir: Directory for per-BAM H5AD outputs.
807
+ max_reference_length: Maximum reference length for padding.
808
+ device: Torch device used for binarization.
809
+ deaminase_footprinting: Whether direct deamination chemistry was used.
810
+ samtools_backend: Samtools backend choice for alignment parsing.
811
+ converted_FASTA_record_seq_map: record to sequence map
812
+ progress_queue: Queue used to signal completion.
813
+ log_level: Logging level to configure in workers.
814
+ log_file: Optional log file path.
815
+
816
+ Processing Steps:
817
+ 1. Skip processing if an output H5AD already exists.
818
+ 2. Filter records to those present in the FASTA metadata.
819
+ 3. Run per-record processing and write AnnData output.
820
+ 4. Signal completion via the progress queue.
821
+ """
822
+ _ensure_worker_logging(log_level, log_file)
439
823
  worker_id = current_process().pid # Get worker process ID
440
824
  sample = bam.stem
441
825
 
@@ -471,6 +855,8 @@ def worker_function(
471
855
  max_reference_length,
472
856
  device,
473
857
  deaminase_footprinting,
858
+ samtools_backend,
859
+ converted_FASTA_record_seq_map,
474
860
  )
475
861
 
476
862
  if adata is not None:
@@ -491,29 +877,47 @@ def worker_function(
491
877
 
492
878
 
493
879
  def process_bams_parallel(
494
- bam_path_list,
495
- records_to_analyze,
496
- record_FASTA_dict,
497
- chromosome_FASTA_dict,
498
- tmp_dir,
499
- h5_dir,
500
- num_threads,
501
- max_reference_length,
502
- device,
503
- deaminase_footprinting,
504
- ):
505
- """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
506
- make_dirs(h5_dir) # Ensure h5_dir exists
880
+ bam_path_list: list[Path],
881
+ records_to_analyze: set[str],
882
+ record_FASTA_dict: dict[str, RecordFastaInfo],
883
+ chromosome_FASTA_dict: dict[str, tuple[str, str]],
884
+ tmp_dir: Path,
885
+ h5_dir: Path,
886
+ num_threads: int,
887
+ max_reference_length: int,
888
+ device: torch.device,
889
+ deaminase_footprinting: bool,
890
+ samtools_backend: str | None,
891
+ converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
892
+ ) -> ad.AnnData | None:
893
+ """Process BAM files in parallel and concatenate the resulting AnnData.
507
894
 
508
- logger.info(f"Starting parallel BAM processing with {num_threads} threads...")
895
+ Args:
896
+ bam_path_list: List of BAM files to process.
897
+ records_to_analyze: FASTA record IDs that passed the mapping threshold.
898
+ record_FASTA_dict: FASTA metadata keyed by record ID.
899
+ chromosome_FASTA_dict: Chromosome sequences for annotations.
900
+ tmp_dir: Directory for temporary batch files.
901
+ h5_dir: Directory for per-BAM H5AD outputs.
902
+ num_threads: Number of worker processes.
903
+ max_reference_length: Maximum reference length for padding.
904
+ device: Torch device used for binarization.
905
+ deaminase_footprinting: Whether direct deamination chemistry was used.
906
+ samtools_backend: Samtools backend choice for alignment parsing.
907
+ converted_FASTA_record_seq_map: map from converted record name to the converted reference length and sequence.
509
908
 
510
- # Ensure macOS uses forkserver to avoid spawning issues
511
- try:
512
- import multiprocessing
909
+ Returns:
910
+ anndata.AnnData | None: Concatenated AnnData or None if no H5ADs produced.
513
911
 
514
- multiprocessing.set_start_method("forkserver", force=True)
515
- except RuntimeError:
516
- logger.warning(f"Multiprocessing context already set. Skipping set_start_method.")
912
+ Processing Steps:
913
+ 1. Spawn worker processes to handle each BAM.
914
+ 2. Track completion via a multiprocessing queue.
915
+ 3. Concatenate per-BAM H5AD files into a final AnnData.
916
+ """
917
+ make_dirs(h5_dir) # Ensure h5_dir exists
918
+
919
+ logger.info(f"Starting parallel BAM processing with {num_threads} threads...")
920
+ log_level, log_file = _get_logger_config()
517
921
 
518
922
  with Manager() as manager:
519
923
  progress_queue = manager.Queue()
@@ -534,13 +938,17 @@ def process_bams_parallel(
534
938
  max_reference_length,
535
939
  device,
536
940
  deaminase_footprinting,
941
+ samtools_backend,
942
+ converted_FASTA_record_seq_map,
537
943
  progress_queue,
944
+ log_level,
945
+ log_file,
538
946
  ),
539
947
  )
540
948
  for i, bam in enumerate(bam_path_list)
541
949
  ]
542
950
 
543
- logger.info(f"Submitted {len(bam_path_list)} BAMs for processing.")
951
+ logger.info(f"Submitting {len(results)} BAMs for processing.")
544
952
 
545
953
  # Track completed BAMs
546
954
  completed_bams = set()
@@ -550,15 +958,18 @@ def process_bams_parallel(
550
958
  completed_bams.add(processed_bam)
551
959
  except Exception as e:
552
960
  logger.error(f"Timeout waiting for worker process. Possible crash? {e}")
961
+ _log_async_result_errors(results, bam_path_list)
553
962
 
554
963
  pool.close()
555
964
  pool.join() # Ensure all workers finish
556
965
 
966
+ _log_async_result_errors(results, bam_path_list)
967
+
557
968
  # Final Concatenation Step
558
969
  h5ad_files = [f for f in h5_dir.iterdir() if f.suffix == ".h5ad"]
559
970
 
560
971
  if not h5ad_files:
561
- logger.debug(f"No valid H5AD files generated. Exiting.")
972
+ logger.warning(f"No valid H5AD files generated. Exiting.")
562
973
  return None
563
974
 
564
975
  logger.info(f"Concatenating {len(h5ad_files)} H5AD files into final output...")
@@ -568,6 +979,64 @@ def process_bams_parallel(
568
979
  return final_adata
569
980
 
570
981
 
982
+ def _log_async_result_errors(results, bam_path_list):
983
+ """Log worker failures captured by multiprocessing AsyncResult objects.
984
+
985
+ Args:
986
+ results: Iterable of AsyncResult objects from multiprocessing.
987
+ bam_path_list: List of BAM paths matching the async results.
988
+
989
+ Processing Steps:
990
+ 1. Iterate over async results.
991
+ 2. Retrieve results to surface worker exceptions.
992
+ """
993
+ for bam, result in zip(bam_path_list, results):
994
+ if not result.ready():
995
+ continue
996
+ try:
997
+ result.get()
998
+ except Exception as exc:
999
+ logger.error("Worker process failed for %s: %s", bam, exc)
1000
+
1001
+
1002
+ def _get_logger_config() -> tuple[int, Path | None]:
1003
+ """Return the active smftools logger level and optional file path.
1004
+
1005
+ Returns:
1006
+ tuple[int, Path | None]: Log level and log file path (if configured).
1007
+
1008
+ Processing Steps:
1009
+ 1. Inspect the smftools logger for configured handlers.
1010
+ 2. Extract log level and file handler path.
1011
+ """
1012
+ smftools_logger = logging.getLogger("smftools")
1013
+ level = smftools_logger.level
1014
+ if level == logging.NOTSET:
1015
+ level = logging.INFO
1016
+ log_file: Path | None = None
1017
+ for handler in smftools_logger.handlers:
1018
+ if isinstance(handler, logging.FileHandler):
1019
+ log_file = Path(handler.baseFilename)
1020
+ break
1021
+ return level, log_file
1022
+
1023
+
1024
+ def _ensure_worker_logging(log_level: int, log_file: Path | None) -> None:
1025
+ """Ensure worker processes have logging configured.
1026
+
1027
+ Args:
1028
+ log_level: Logging level to configure.
1029
+ log_file: Optional log file path.
1030
+
1031
+ Processing Steps:
1032
+ 1. Check if handlers are already configured.
1033
+ 2. Initialize logging with the provided level and file path.
1034
+ """
1035
+ smftools_logger = logging.getLogger("smftools")
1036
+ if not smftools_logger.handlers:
1037
+ setup_logging(level=log_level, log_file=log_file)
1038
+
1039
+
571
1040
  def delete_intermediate_h5ads_and_tmpdir(
572
1041
  h5_dir: Union[str, Path, Iterable[str], None],
573
1042
  tmp_dir: Optional[Union[str, Path]] = None,
@@ -575,21 +1044,17 @@ def delete_intermediate_h5ads_and_tmpdir(
575
1044
  dry_run: bool = False,
576
1045
  verbose: bool = True,
577
1046
  ):
578
- """
579
- Delete intermediate .h5ad files and a temporary directory.
580
-
581
- Parameters
582
- ----------
583
- h5_dir : str | Path | iterable[str] | None
584
- If a directory path is given, all files directly inside it will be considered.
585
- If an iterable of file paths is given, those files will be considered.
586
- Only files ending with '.h5ad' (and not ending with '.gz') are removed.
587
- tmp_dir : str | Path | None
588
- Path to a directory to remove recursively (e.g. a temp dir created earlier).
589
- dry_run : bool
590
- If True, print what *would* be removed but do not actually delete.
591
- verbose : bool
592
- Print progress / warnings.
1047
+ """Delete intermediate .h5ad files and a temporary directory.
1048
+
1049
+ Args:
1050
+ h5_dir: Directory path or iterable of file paths to inspect for `.h5ad` files.
1051
+ tmp_dir: Optional directory to remove recursively.
1052
+ dry_run: If True, log what would be removed without deleting.
1053
+ verbose: If True, log progress and warnings.
1054
+
1055
+ Processing Steps:
1056
+ 1. Remove `.h5ad` files (excluding `.gz`) from the provided directory or list.
1057
+ 2. Optionally remove the temporary directory tree.
593
1058
  """
594
1059
 
595
1060
  # Helper: remove a single file path (Path-like or string)