smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/converted_BAM_to_adata.py
@@ -5,22 +5,45 @@ import logging
 import shutil
 import time
 import traceback
+from dataclasses import dataclass
 from multiprocessing import Manager, Pool, current_process
 from pathlib import Path
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Union
 
 import anndata as ad
 import numpy as np
 import pandas as pd
 
+from smftools.constants import (
+    BAM_SUFFIX,
+    BARCODE,
+    BASE_QUALITY_SCORES,
+    DATASET,
+    DEMUX_TYPE,
+    H5_DIR,
+    MISMATCH_INTEGER_ENCODING,
+    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    MODKIT_EXTRACT_SEQUENCE_BASES,
+    MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+    MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+    READ_MAPPING_DIRECTION,
+    READ_MISMATCH_TREND,
+    READ_SPAN_MASK,
+    REFERENCE,
+    REFERENCE_DATASET_STRAND,
+    REFERENCE_STRAND,
+    SAMPLE,
+    SEQUENCE_INTEGER_DECODING,
+    SEQUENCE_INTEGER_ENCODING,
+    STRAND,
+)
 from smftools.logging_utils import get_logger, setup_logging
 from smftools.optional_imports import require
 
 from ..readwrite import make_dirs
 from .bam_functions import count_aligned_reads, extract_base_identities
 from .binarize_converted_base_identities import binarize_converted_base_identities
-from .fasta_functions import find_conversion_sites
-from .ohe import ohe_batching
+from .fasta_functions import find_conversion_sites, get_native_references
 
 logger = get_logger(__name__)
 
@@ -30,6 +53,67 @@ if TYPE_CHECKING:
 torch = require("torch", extra="torch", purpose="converted BAM processing")
 
 
+@dataclass(frozen=True)
+class RecordFastaInfo:
+    """Structured FASTA metadata for a single converted record.
+
+    Attributes:
+        sequence: Padded top-strand sequence for the record.
+        complement: Padded bottom-strand sequence for the record.
+        chromosome: Canonical chromosome name for the record.
+        unconverted_name: FASTA record name for the unconverted reference.
+        sequence_length: Length of the unpadded reference sequence.
+        padding_length: Number of padded bases applied to reach max length.
+        conversion: Conversion label (e.g., "unconverted", "5mC").
+        strand: Strand label ("top" or "bottom").
+        max_reference_length: Maximum reference length across all records.
+    """
+
+    sequence: str
+    complement: str
+    chromosome: str
+    unconverted_name: str
+    sequence_length: int
+    padding_length: int
+    conversion: str
+    strand: str
+    max_reference_length: int
+
+
+@dataclass(frozen=True)
+class SequenceEncodingConfig:
+    """Configuration for integer sequence encoding.
+
+    Attributes:
+        base_to_int: Mapping of base characters to integer encodings.
+        bases: Valid base characters used for encoding.
+        padding_base: Base token used for padding.
+        batch_size: Number of reads per temporary batch file.
+    """
+
+    base_to_int: Mapping[str, int]
+    bases: tuple[str, ...]
+    padding_base: str
+    batch_size: int = 100000
+
+    @property
+    def padding_value(self) -> int:
+        """Return the integer value used for padding positions."""
+        return self.base_to_int[self.padding_base]
+
+    @property
+    def unknown_value(self) -> int:
+        """Return the integer value used for unknown bases."""
+        return self.base_to_int["N"]
+
+
+SEQUENCE_ENCODING_CONFIG = SequenceEncodingConfig(
+    base_to_int=MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    bases=MODKIT_EXTRACT_SEQUENCE_BASES,
+    padding_base=MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+)
+
+
 def converted_BAM_to_adata(
     converted_FASTA: str | Path,
     split_dir: Path,
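A quick sanity sketch of the new config, assuming a hypothetical six-token alphabet (the real MODKIT_EXTRACT_* values live in the new smftools/constants.py and are not shown in this diff):

    # Toy mapping for illustration only; real values come from
    # smftools.constants.MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT.
    toy_base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}

    toy_config = SequenceEncodingConfig(
        base_to_int=toy_base_to_int,
        bases=("A", "C", "G", "T", "N", "PAD"),
        padding_base="PAD",
    )
    assert toy_config.unknown_value == 4   # integer stored where a base is unresolvable
    assert toy_config.padding_value == 5   # integer stored past the reference length
    assert toy_config.batch_size == 100000  # default reads per temporary batch file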
@@ -46,7 +130,7 @@ def converted_BAM_to_adata(
     double_barcoded_path: Path | None = None,
     samtools_backend: str | None = "auto",
 ) -> tuple[ad.AnnData | None, Path]:
-    """Convert BAM files into an AnnData object by binarizing modified base identities.
+    """Convert converted BAM files into an AnnData object with integer sequence encoding.
 
     Args:
         converted_FASTA: Path to the converted FASTA reference.
@@ -62,9 +146,18 @@ def converted_BAM_to_adata(
         deaminase_footprinting: Whether the footprinting used direct deamination chemistry.
         delete_intermediates: Whether to remove intermediate files after processing.
         double_barcoded_path: Path to dorado demux summary file of double-ended barcodes.
+        samtools_backend: Samtools backend choice for alignment parsing.
 
     Returns:
         tuple[anndata.AnnData | None, Path]: The AnnData object (if generated) and its path.
+
+    Processing Steps:
+        1. Resolve the best available torch device and create output directories.
+        2. Load converted FASTA records and compute conversion sites.
+        3. Filter BAMs based on mapping thresholds.
+        4. Process each BAM in parallel, building per-sample H5AD files.
+        5. Concatenate per-sample AnnData objects and attach reference metadata.
+        6. Add demultiplexing annotations and clean intermediate artifacts.
     """
     if torch.cuda.is_available():
         device = torch.device("cuda")
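This hunk begins the device-resolution step named in the docstring (step 1). A standalone sketch of the same prefer-CUDA pattern, assuming the unchanged context falls through to MPS and then CPU (resolve_device is a name introduced here for illustration):

    import torch

    def resolve_device() -> torch.device:
        # Prefer CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")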
@@ -76,7 +169,7 @@ def converted_BAM_to_adata(
     logger.debug(f"Using device: {device}")
 
     ## Set Up Directories and File Paths
-    h5_dir = output_dir / "h5ads"
+    h5_dir = output_dir / H5_DIR
     tmp_dir = output_dir / "tmp"
     final_adata = None
     final_adata_path = h5_dir / f"{experiment_name}.h5ad.gz"
@@ -90,7 +183,7 @@ def converted_BAM_to_adata(
     bam_files = sorted(
         p
         for p in split_dir.iterdir()
-        if p.is_file() and p.suffix == ".bam" and "unclassified" not in p.name
+        if p.is_file() and p.suffix == BAM_SUFFIX and "unclassified" not in p.name
     )
 
     bam_path_list = bam_files
@@ -108,6 +201,16 @@ def converted_BAM_to_adata(
         bam_path_list, bam_files, mapping_threshold, samtools_backend
     )
 
+    # Get converted record sequences:
+    converted_FASTA_record_seq_map = get_native_references(converted_FASTA)
+    # Pad the record sequences
+    for record, [record_length, seq] in converted_FASTA_record_seq_map.items():
+        if max_reference_length > record_length:
+            pad_number = max_reference_length - record_length
+            record_length += pad_number
+            seq += "N" * pad_number
+            converted_FASTA_record_seq_map[record] = [record_length, seq]
+
     ## Process BAMs in Parallel
     final_adata = process_bams_parallel(
         bam_path_list,
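The padding loop above right-pads each record sequence with "N" up to the longest reference. A toy illustration, assuming get_native_references() returns {record_name: [record_length, sequence]} as the loop implies:

    seq_map = {"chrA": [5, "ACGTA"], "chrB": [3, "ACG"]}  # hypothetical records
    max_reference_length = max(length for length, _ in seq_map.values())
    for record, (record_length, seq) in seq_map.items():
        if max_reference_length > record_length:
            pad_number = max_reference_length - record_length
            seq_map[record] = [record_length + pad_number, seq + "N" * pad_number]
    assert seq_map["chrB"] == [5, "ACGNN"]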
@@ -121,8 +224,15 @@ def converted_BAM_to_adata(
         device,
         deaminase_footprinting,
         samtools_backend,
+        converted_FASTA_record_seq_map,
     )
 
+    final_adata.uns[f"{SEQUENCE_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+    final_adata.uns[f"{MISMATCH_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+    final_adata.uns[f"{SEQUENCE_INTEGER_DECODING}_map"] = {
+        str(key): value for key, value in MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.items()
+    }
+
     final_adata.uns["References"] = {}
     for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
         final_adata.var[f"{chromosome}_top_strand_FASTA_base"] = list(seq)
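The three `.uns` maps added above make the integer layers self-describing. A minimal sketch of the intended decode round-trip with a hypothetical encoding (the decoding map stringifies its keys, presumably because H5AD serialization of `.uns` dictionaries requires string keys):

    encode_map = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}          # stand-in for the *_map entries
    decode_map = {str(i): base for base, i in encode_map.items()}  # mirrors the str(key) decoding map
    row = [2, 0, 3, 3, 1]                                          # one row of the integer layer
    assert "".join(decode_map[str(i)] for i in row) == "GATTC"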
@@ -130,6 +240,11 @@ def converted_BAM_to_adata(
         final_adata.uns[f"{chromosome}_FASTA_sequence"] = seq
         final_adata.uns["References"][f"{chromosome}_FASTA_sequence"] = seq
 
+    if not deaminase_footprinting:
+        for record, [_length, seq] in converted_FASTA_record_seq_map.items():
+            if "unconverted" not in record:
+                final_adata.var[f"{record}_top_strand_FASTA_base"] = list(seq)
+
     final_adata.obs_names_make_unique()
     cols = final_adata.obs.columns
 
@@ -137,9 +252,33 @@ def converted_BAM_to_adata(
     for col in cols:
         final_adata.obs[col] = final_adata.obs[col].astype("category")
 
+    consensus_bases = MODKIT_EXTRACT_SEQUENCE_BASES[:4]  # ignore N/PAD for consensus
+    consensus_base_ints = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in consensus_bases]
+    for ref_group in final_adata.obs[REFERENCE_DATASET_STRAND].cat.categories:
+        group_subset = final_adata[final_adata.obs[REFERENCE_DATASET_STRAND] == ref_group]
+        encoded_sequences = group_subset.layers[SEQUENCE_INTEGER_ENCODING]
+        layer_counts = [
+            np.sum(encoded_sequences == base_int, axis=0) for base_int in consensus_base_ints
+        ]
+        count_array = np.array(layer_counts)
+        nucleotide_indexes = np.argmax(count_array, axis=0)
+        consensus_sequence_list = [consensus_bases[i] for i in nucleotide_indexes]
+        no_calls_mask = np.sum(count_array, axis=0) == 0
+        if np.any(no_calls_mask):
+            consensus_sequence_list = np.array(consensus_sequence_list, dtype=object)
+            consensus_sequence_list[no_calls_mask] = "N"
+            consensus_sequence_list = consensus_sequence_list.tolist()
+        final_adata.var[f"{ref_group}_consensus_sequence_from_all_samples"] = (
+            consensus_sequence_list
+        )
+
+    from .h5ad_functions import append_reference_strand_quality_stats
+
+    append_reference_strand_quality_stats(final_adata)
+
     if input_already_demuxed:
-        final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
-        final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
+        final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
+        final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
     else:
         from .h5ad_functions import add_demux_type_annotation
 
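The consensus block above counts, per column, how many reads carry each base integer, takes the column-wise argmax, and falls back to "N" where nothing was called. A toy run, assuming for illustration that A/C/G/T encode to 0-3 and padding to 5:

    import numpy as np

    bases = ("A", "C", "G", "T")
    encoded = np.array([[0, 1, 5], [0, 2, 5]])  # two reads; last column is all padding
    counts = np.array([np.sum(encoded == i, axis=0) for i in range(len(bases))])
    consensus = [bases[i] for i in np.argmax(counts, axis=0)]
    no_calls = np.sum(counts, axis=0) == 0
    consensus = ["N" if empty else base for base, empty in zip(consensus, no_calls)]
    assert consensus == ["A", "C", "N"]  # ties resolve to the lowest index, as np.argmax does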
@@ -156,37 +295,47 @@ def converted_BAM_to_adata(
 
 
 def process_conversion_sites(
-    converted_FASTA, conversions=["unconverted", "5mC"], deaminase_footprinting=False
-):
-    """
-    Extracts conversion sites and determines the max reference length.
+    converted_FASTA: str | Path,
+    conversions: list[str] | None = None,
+    deaminase_footprinting: bool = False,
+) -> tuple[int, dict[str, RecordFastaInfo], dict[str, tuple[str, str]]]:
+    """Extract conversion sites and FASTA metadata for converted references.
 
-    Parameters:
-        converted_FASTA (str): Path to the converted reference FASTA.
-        conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
-        deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
+    Args:
+        converted_FASTA: Path to the converted reference FASTA.
+        conversions: List of modification types (e.g., ["unconverted", "5mC", "6mA"]).
+        deaminase_footprinting: Whether the footprinting was done with direct deamination chemistry.
 
     Returns:
-        max_reference_length (int): The length of the longest sequence.
-        record_FASTA_dict (dict): Dictionary of sequence information for **both converted & unconverted** records.
+        tuple[int, dict[str, RecordFastaInfo], dict[str, tuple[str, str]]]:
+            Maximum reference length, record metadata, and chromosome sequences.
+
+    Processing Steps:
+        1. Parse unconverted FASTA records to determine the max reference length.
+        2. Build record metadata for unconverted and converted strands.
+        3. Cache chromosome-level sequences for downstream annotation.
     """
-    modification_dict = {}
-    record_FASTA_dict = {}
-    chromosome_FASTA_dict = {}
+    if conversions is None:
+        conversions = ["unconverted", "5mC"]
+    modification_dict: dict[str, dict] = {}
+    record_FASTA_dict: dict[str, RecordFastaInfo] = {}
+    chromosome_FASTA_dict: dict[str, tuple[str, str]] = {}
     max_reference_length = 0
     unconverted = conversions[0]
     conversion_types = conversions[1:]
 
     # Process the unconverted sequence once
+    # modification_dict is keyed by mod type (i.e., unconverted, 5mC, 6mA).
+    # modification_dict[unconverted] points to a dictionary keyed by unconverted record.id keys.
+    # Each entry maps to [sequence_length, [], [], unconverted sequence, unconverted complement].
     modification_dict[unconverted] = find_conversion_sites(
         converted_FASTA, unconverted, conversions, deaminase_footprinting
     )
-    # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
 
-    # Get **max sequence length** from unconverted records
+    # Get max sequence length from unconverted records
     max_reference_length = max(values[0] for values in modification_dict[unconverted].values())
 
-    # Add **unconverted records** to `record_FASTA_dict`
+    # Add unconverted records to `record_FASTA_dict`
     for record, values in modification_dict[unconverted].items():
         sequence_length, top_coords, bottom_coords, sequence, complement = values
 
@@ -196,61 +345,91 @@ def process_conversion_sites(
             chromosome = record
 
         # Store **original sequence**
-        record_FASTA_dict[record] = [
-            sequence + "N" * (max_reference_length - sequence_length),
-            complement + "N" * (max_reference_length - sequence_length),
-            chromosome,
-            record,
-            sequence_length,
-            max_reference_length - sequence_length,
-            unconverted,
-            "top",
-        ]
+        record_FASTA_dict[record] = RecordFastaInfo(
+            sequence=sequence + "N" * (max_reference_length - sequence_length),
+            complement=complement + "N" * (max_reference_length - sequence_length),
+            chromosome=chromosome,
+            unconverted_name=record,
+            sequence_length=sequence_length,
+            padding_length=max_reference_length - sequence_length,
+            conversion=unconverted,
+            strand="top",
+            max_reference_length=max_reference_length,
+        )
 
         if chromosome not in chromosome_FASTA_dict:
-            chromosome_FASTA_dict[chromosome] = [
+            chromosome_FASTA_dict[chromosome] = (
                 sequence + "N" * (max_reference_length - sequence_length),
                 complement + "N" * (max_reference_length - sequence_length),
-            ]
+            )
 
     # Process converted records
+    # For each conversion type (i.e., 5mC, 6mA), add the conversion type as a key to modification_dict.
+    # This points to a dictionary keyed by the unconverted record id key.
+    # Each entry maps to [sequence_length, top_strand_coordinates, bottom_strand_coordinates, unconverted sequence, unconverted complement].
     for conversion in conversion_types:
         modification_dict[conversion] = find_conversion_sites(
             converted_FASTA, conversion, conversions, deaminase_footprinting
         )
-        # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
 
+        # Iterate over the unconverted record ids in modification_dict, along with the
+        # [sequence_length, top_strand_coordinates, bottom_strand_coordinates, unconverted sequence, unconverted complement] for the conversion type.
         for record, values in modification_dict[conversion].items():
             sequence_length, top_coords, bottom_coords, sequence, complement = values
 
             if not deaminase_footprinting:
-                chromosome = record.split(f"_{unconverted}_")[0]  # Extract chromosome name
+                # For conversion SMF, make the chromosome name the base record name.
+                chromosome = record.split(f"_{unconverted}_")[0]
             else:
+                # For deaminase SMF, make the chromosome and record name the same.
                 chromosome = record
 
-            # Add **both strands** for converted records
+            # Add both strands for converted records
             for strand in ["top", "bottom"]:
+                # Generate converted/unconverted record names that are found in the converted FASTA.
                 converted_name = f"{chromosome}_{conversion}_{strand}"
                 unconverted_name = f"{chromosome}_{unconverted}_top"
 
-                record_FASTA_dict[converted_name] = [
-                    sequence + "N" * (max_reference_length - sequence_length),
-                    complement + "N" * (max_reference_length - sequence_length),
-                    chromosome,
-                    unconverted_name,
-                    sequence_length,
-                    max_reference_length - sequence_length,
-                    conversion,
-                    strand,
-                ]
-
-    logger.debug("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
+                # Use the converted FASTA record names as keys to a dict that points to RecordFastaInfo objects.
+                # These objects will contain the unconverted sequence/complement.
+                record_FASTA_dict[converted_name] = RecordFastaInfo(
+                    sequence=sequence + "N" * (max_reference_length - sequence_length),
+                    complement=complement + "N" * (max_reference_length - sequence_length),
+                    chromosome=chromosome,
+                    unconverted_name=unconverted_name,
+                    sequence_length=sequence_length,
+                    padding_length=max_reference_length - sequence_length,
+                    conversion=conversion,
+                    strand=strand,
+                    max_reference_length=max_reference_length,
+                )
+
+    logger.debug("Updated record_FASTA_dict keys: %s", list(record_FASTA_dict.keys()))
 
     return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
 
 
-def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold, samtools_backend):
-    """Filters BAM files based on mapping threshold."""
-    records_to_analyze = set()
+def filter_bams_by_mapping_threshold(
+    bam_path_list: list[Path],
+    bam_files: list[Path],
+    mapping_threshold: float,
+    samtools_backend: str | None,
+) -> set[str]:
+    """Filter FASTA records based on per-BAM mapping thresholds.
+
+    Args:
+        bam_path_list: Ordered list of BAM paths to evaluate.
+        bam_files: Matching list of BAM paths used for reporting.
+        mapping_threshold: Minimum percentage of aligned reads to include a record.
+        samtools_backend: Samtools backend choice for alignment parsing.
+
+    Returns:
+        set[str]: FASTA record IDs that pass the mapping threshold.
+
+    Processing Steps:
+        1. Count aligned/unaligned reads and per-record percentages.
+        2. Collect record IDs that meet the mapping threshold.
+    """
+    records_to_analyze: set[str] = set()
 
     for i, bam in enumerate(bam_path_list):
         aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam, samtools_backend)
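With the hunk above, the positional list layout of the old code (indices 2, 4, 6, 7) becomes a frozen dataclass. A sketch of constructing and reading one record's metadata with toy values, using the RecordFastaInfo class defined earlier in this diff:

    info = RecordFastaInfo(
        sequence="ACGTNN",
        complement="TGCANN",
        chromosome="chrA",
        unconverted_name="chrA_unconverted_top",
        sequence_length=4,
        padding_length=2,
        conversion="5mC",
        strand="top",
        max_reference_length=6,
    )
    assert info.sequence_length == 4   # was record_FASTA_dict[record][4]
    # info.sequence_length = 10        # would raise FrozenInstanceError (frozen=True)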
@@ -265,33 +444,182 @@ def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold
     return records_to_analyze
 
 
+def _encode_sequence_array(
+    read_sequence: np.ndarray,
+    valid_length: int,
+    config: SequenceEncodingConfig,
+) -> np.ndarray:
+    """Encode a base-identity array into integer values with padding.
+
+    Args:
+        read_sequence: Array of base calls (dtype "<U1").
+        valid_length: Number of valid reference positions for this record.
+        config: Integer encoding configuration.
+
+    Returns:
+        np.ndarray: Integer-encoded sequence with padding applied.
+
+    Processing Steps:
+        1. Initialize an array filled with the unknown base encoding.
+        2. Map A/C/G/T/N bases into integer values.
+        3. Mark positions beyond valid_length with the padding value.
+    """
+    read_sequence = np.asarray(read_sequence, dtype="<U1")
+    encoded = np.full(read_sequence.shape, config.unknown_value, dtype=np.int16)
+    for base in config.bases:
+        encoded[read_sequence == base] = config.base_to_int[base]
+    if valid_length < encoded.size:
+        encoded[valid_length:] = config.padding_value
+    return encoded
+
+
+def _write_sequence_batches(
+    base_identities: Mapping[str, np.ndarray],
+    tmp_dir: Path,
+    record: str,
+    prefix: str,
+    valid_length: int,
+    config: SequenceEncodingConfig,
+) -> list[str]:
+    """Encode base identities into integer arrays and write batched H5AD files.
+
+    Args:
+        base_identities: Mapping of read name to base identity arrays.
+        tmp_dir: Directory for temporary H5AD files.
+        record: Reference record identifier.
+        prefix: Prefix used to name batch files.
+        valid_length: Valid reference length for padding determination.
+        config: Integer encoding configuration.
+
+    Returns:
+        list[str]: Paths to written H5AD batch files.
+
+    Processing Steps:
+        1. Encode each read sequence into integers.
+        2. Accumulate encoded reads into batches.
+        3. Persist each batch to an H5AD file with `.uns` storage.
+    """
+    batch_files: list[str] = []
+    batch: dict[str, np.ndarray] = {}
+    batch_number = 0
+
+    for read_name, sequence in base_identities.items():
+        if sequence is None:
+            continue
+        batch[read_name] = _encode_sequence_array(sequence, valid_length, config)
+        if len(batch) >= config.batch_size:
+            save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+            batch_files.append(str(save_name))
+            batch = {}
+            batch_number += 1
+
+    if batch:
+        save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+        batch_files.append(str(save_name))
+
+    return batch_files
+
+
+def _load_sequence_batches(
+    batch_files: list[Path | str],
+) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
+    """Load integer-encoded sequence batches from H5AD files.
+
+    Args:
+        batch_files: H5AD paths containing encoded sequences in `.uns`.
+
+    Returns:
+        tuple[dict[str, np.ndarray], set[str], set[str]]:
+            Read-to-sequence mapping and sets of forward/reverse mapped reads.
+
+    Processing Steps:
+        1. Read each H5AD file.
+        2. Merge `.uns` dictionaries into a single mapping.
+        3. Track forward/reverse read IDs based on filename markers.
+    """
+    sequences: dict[str, np.ndarray] = {}
+    fwd_reads: set[str] = set()
+    rev_reads: set[str] = set()
+    for batch_file in batch_files:
+        batch_path = Path(batch_file)
+        batch_sequences = ad.read_h5ad(batch_path).uns
+        sequences.update(batch_sequences)
+        if "_fwd_" in batch_path.name:
+            fwd_reads.update(batch_sequences.keys())
+        elif "_rev_" in batch_path.name:
+            rev_reads.update(batch_sequences.keys())
+    return sequences, fwd_reads, rev_reads
+
+
 def process_single_bam(
-    bam_index,
-    bam,
-    records_to_analyze,
-    record_FASTA_dict,
-    chromosome_FASTA_dict,
-    tmp_dir,
-    max_reference_length,
-    device,
-    deaminase_footprinting,
-    samtools_backend,
-):
-    """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
-    adata_list = []
+    bam_index: int,
+    bam: Path,
+    records_to_analyze: set[str],
+    record_FASTA_dict: dict[str, RecordFastaInfo],
+    chromosome_FASTA_dict: dict[str, tuple[str, str]],
+    tmp_dir: Path,
+    max_reference_length: int,
+    device: torch.device,
+    deaminase_footprinting: bool,
+    samtools_backend: str | None,
+    converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
+) -> ad.AnnData | None:
+    """Process a single BAM file into per-record AnnData objects.
+
+    Args:
+        bam_index: Index of the BAM within the processing batch.
+        bam: Path to the BAM file.
+        records_to_analyze: FASTA record IDs that passed the mapping threshold.
+        record_FASTA_dict: FASTA metadata keyed by record ID.
+        chromosome_FASTA_dict: Chromosome sequences for annotations.
+        tmp_dir: Directory for temporary batch files.
+        max_reference_length: Maximum reference length for padding.
+        device: Torch device used for binarization.
+        deaminase_footprinting: Whether direct deamination chemistry was used.
+        samtools_backend: Samtools backend choice for alignment parsing.
+        converted_FASTA_record_seq_map: Map from record name to the converted reference length and padded sequence.
+
+    Returns:
+        anndata.AnnData | None: Concatenated AnnData object or None if no data.
+
+    Processing Steps:
+        1. Extract base identities and mismatch profiles for each record.
+        2. Binarize modified base identities into feature matrices.
+        3. Encode read sequences into integer arrays and cache batches.
+        4. Build AnnData layers/obs metadata for each record and concatenate.
+    """
+    adata_list: list[ad.AnnData] = []
 
+    # Iterate over BAM records that passed filtering.
     for record in records_to_analyze:
         sample = bam.stem
-        chromosome = record_FASTA_dict[record][2]
-        current_length = record_FASTA_dict[record][4]
-        mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
-        sequence = chromosome_FASTA_dict[chromosome][0]
-
-        # Extract Base Identities
-        fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = (
-            extract_base_identities(
-                bam, record, range(current_length), max_reference_length, sequence, samtools_backend
-            )
+        record_info = record_FASTA_dict[record]
+        chromosome = record_info.chromosome
+        current_length = record_info.sequence_length
+        # Note: mod_type and strand are only loaded correctly for conversion SMF, not deaminase.
+        # However, these variables are only used for conversion SMF, so this works.
+        mod_type, strand = record_info.conversion, record_info.strand
+        non_converted_sequence = chromosome_FASTA_dict[chromosome][0]
+        record_sequence = converted_FASTA_record_seq_map[record][1]
+
+        # Extract Base Identities for forward and reverse mapped reads.
+        (
+            fwd_bases,
+            rev_bases,
+            mismatch_counts_per_read,
+            mismatch_trend_per_read,
+            mismatch_base_identities,
+            base_quality_scores,
+            read_span_masks,
+        ) = extract_base_identities(
+            bam,
+            record,
+            range(current_length),
+            max_reference_length,
+            record_sequence,
+            samtools_backend,
         )
         mismatch_trend_series = pd.Series(mismatch_trend_per_read)
 
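A worked toy run of `_encode_sequence_array` from the hunk above, using a hypothetical six-token config (the real integers come from smftools.constants):

    import numpy as np

    cfg = SequenceEncodingConfig(
        base_to_int={"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5},
        bases=("A", "C", "G", "T", "N", "PAD"),
        padding_base="PAD",
    )
    calls = np.array(["A", "C", "x", "T", "G", "G"], dtype="<U1")
    encoded = _encode_sequence_array(calls, valid_length=4, config=cfg)
    # "x" is not in cfg.bases -> unknown_value (4); positions >= 4 -> padding_value (5)
    assert encoded.tolist() == [0, 1, 4, 3, 5, 5]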
@@ -305,13 +633,12 @@ def process_single_bam(
         merged_bin = {}
 
         # Binarize the Base Identities if they exist
+        # Note: for deaminase SMF, mod_type is currently always "unconverted" and strand is always "top"; this works for now.
         if fwd_bases:
             fwd_bin = binarize_converted_base_identities(
                 fwd_bases,
                 strand,
                 mod_type,
-                bam,
-                device,
                 deaminase_footprinting,
                 mismatch_trend_per_read,
             )
@@ -322,8 +649,6 @@ def process_single_bam(
                 rev_bases,
                 strand,
                 mod_type,
-                bam,
-                device,
                 deaminase_footprinting,
                 mismatch_trend_per_read,
             )
@@ -343,83 +668,140 @@ def process_single_bam(
         sorted_index = sorted(bin_df.index)
         bin_df = bin_df.reindex(sorted_index)
 
-        # One-Hot Encode Reads if there is valid data
-        one_hot_reads = {}
-
+        # Integer-encode reads if there is valid data
+        batch_files: list[str] = []
         if fwd_bases:
-            fwd_ohe_files = ohe_batching(
-                fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000
+            batch_files.extend(
+                _write_sequence_batches(
+                    fwd_bases,
+                    tmp_dir,
+                    record,
+                    f"{bam_index}_fwd",
+                    current_length,
+                    SEQUENCE_ENCODING_CONFIG,
+                )
             )
-            for ohe_file in fwd_ohe_files:
-                tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
-                one_hot_reads.update(tmp_ohe_dict)
-            del tmp_ohe_dict
 
         if rev_bases:
-            rev_ohe_files = ohe_batching(
-                rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000
+            batch_files.extend(
+                _write_sequence_batches(
+                    rev_bases,
+                    tmp_dir,
+                    record,
+                    f"{bam_index}_rev",
+                    current_length,
+                    SEQUENCE_ENCODING_CONFIG,
+                )
             )
-            for ohe_file in rev_ohe_files:
-                tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
-                one_hot_reads.update(tmp_ohe_dict)
-            del tmp_ohe_dict
 
-        # Skip if one_hot_reads is empty
-        if not one_hot_reads:
+        if not batch_files:
             logger.debug(
-                f"[Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}."
+                f"[Worker {current_process().pid}] Skipping {sample} - No valid encoded data for {record}."
             )
             continue
 
         gc.collect()
 
-        # Convert One-Hot Encodings to Numpy Arrays
-        n_rows_OHE = 5
-        read_names = list(one_hot_reads.keys())
-
-        # Skip if no read names exist
-        if not read_names:
+        encoded_reads, fwd_reads, rev_reads = _load_sequence_batches(batch_files)
+        if not encoded_reads:
             logger.debug(
-                f"[Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}."
+                f"[Worker {current_process().pid}] Skipping {sample} - No reads found in encoded data for {record}."
            )
            continue
 
-        sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
-        df_A, df_C, df_G, df_T, df_N = [
-            np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)
-        ]
+        sequence_length = max_reference_length
+        default_sequence = np.full(
+            sequence_length, SEQUENCE_ENCODING_CONFIG.unknown_value, dtype=np.int16
+        )
+        if current_length < sequence_length:
+            default_sequence[current_length:] = SEQUENCE_ENCODING_CONFIG.padding_value
 
-        # Populate One-Hot Arrays
-        for j, read_name in enumerate(sorted_index):
-            if read_name in one_hot_reads:
-                one_hot_array = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
-                df_A[j], df_C[j], df_G[j], df_T[j], df_N[j] = one_hot_array
+        encoded_matrix = np.vstack(
+            [encoded_reads.get(read_name, default_sequence) for read_name in sorted_index]
+        )
+        default_mismatch_sequence = np.full(
+            sequence_length, SEQUENCE_ENCODING_CONFIG.unknown_value, dtype=np.int16
+        )
+        if current_length < sequence_length:
+            default_mismatch_sequence[current_length:] = SEQUENCE_ENCODING_CONFIG.padding_value
+        mismatch_encoded_matrix = np.vstack(
+            [
+                mismatch_base_identities.get(read_name, default_mismatch_sequence)
+                for read_name in sorted_index
+            ]
+        )
+        default_quality_sequence = np.full(sequence_length, -1, dtype=np.int16)
+        quality_matrix = np.vstack(
+            [
+                base_quality_scores.get(read_name, default_quality_sequence)
+                for read_name in sorted_index
+            ]
+        )
+        default_read_span = np.zeros(sequence_length, dtype=np.int16)
+        read_span_matrix = np.vstack(
+            [read_span_masks.get(read_name, default_read_span) for read_name in sorted_index]
+        )
 
         # Convert to AnnData
         X = bin_df.values.astype(np.float32)
         adata = ad.AnnData(X)
         adata.obs_names = bin_df.index.astype(str)
         adata.var_names = bin_df.columns.astype(str)
-        adata.obs["Sample"] = [sample] * len(adata)
+        adata.obs[SAMPLE] = [sample] * len(adata)
         try:
             barcode = sample.split("barcode")[1]
         except Exception:
             barcode = np.nan
-        adata.obs["Barcode"] = [int(barcode)] * len(adata)
-        adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
-        adata.obs["Reference"] = [chromosome] * len(adata)
-        adata.obs["Strand"] = [strand] * len(adata)
-        adata.obs["Dataset"] = [mod_type] * len(adata)
-        adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
-        adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
-        adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)
-
-        # Attach One-Hot Encodings to Layers
-        adata.layers["A_binary_sequence_encoding"] = df_A
-        adata.layers["C_binary_sequence_encoding"] = df_C
-        adata.layers["G_binary_sequence_encoding"] = df_G
-        adata.layers["T_binary_sequence_encoding"] = df_T
-        adata.layers["N_binary_sequence_encoding"] = df_N
+        adata.obs[BARCODE] = [int(barcode)] * len(adata)
+        adata.obs[BARCODE] = adata.obs[BARCODE].astype(str)
+        adata.obs[REFERENCE] = [chromosome] * len(adata)
+        adata.obs[STRAND] = [strand] * len(adata)
+        adata.obs[DATASET] = [mod_type] * len(adata)
+        adata.obs[READ_MISMATCH_TREND] = adata.obs_names.map(mismatch_trend_series)
+
+        # Currently, deaminase footprinting uses mismatch trend to define the strand.
+        if deaminase_footprinting:
+            is_ct = adata.obs[READ_MISMATCH_TREND] == "C->T"
+            is_ga = adata.obs[READ_MISMATCH_TREND] == "G->A"
+
+            adata.obs.loc[is_ct, STRAND] = "top"
+            adata.obs.loc[is_ga, STRAND] = "bottom"
+        # Currently, conversion footprinting uses strand to define the mismatch trend.
+        else:
+            is_top = adata.obs[STRAND] == "top"
+            is_bottom = adata.obs[STRAND] == "bottom"
+
+            adata.obs.loc[is_top, READ_MISMATCH_TREND] = "C->T"
+            adata.obs.loc[is_bottom, READ_MISMATCH_TREND] = "G->A"
+
+        adata.obs[REFERENCE_DATASET_STRAND] = (
+            adata.obs[REFERENCE].astype(str)
+            + "_"
+            + adata.obs[DATASET].astype(str)
+            + "_"
+            + adata.obs[STRAND].astype(str)
+        )
+
+        adata.obs[REFERENCE_STRAND] = (
+            adata.obs[REFERENCE].astype(str) + "_" + adata.obs[STRAND].astype(str)
+        )
+
+        read_mapping_direction = []
+        for read_id in adata.obs_names:
+            if read_id in fwd_reads:
+                read_mapping_direction.append("fwd")
+            elif read_id in rev_reads:
+                read_mapping_direction.append("rev")
+            else:
+                read_mapping_direction.append("unk")
+
+        adata.obs[READ_MAPPING_DIRECTION] = read_mapping_direction
+
+        # Attach integer sequence encoding to layers
+        adata.layers[SEQUENCE_INTEGER_ENCODING] = encoded_matrix
+        adata.layers[MISMATCH_INTEGER_ENCODING] = mismatch_encoded_matrix
+        adata.layers[BASE_QUALITY_SCORES] = quality_matrix
+        adata.layers[READ_SPAN_MASK] = read_span_matrix
 
         adata_list.append(adata)
 
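The matrix-assembly block above gives every read in `sorted_index` a row, substituting a default row via `dict.get` for reads missing from a mapping, so each layer is always shaped (n_reads, max_reference_length). The pattern in isolation:

    import numpy as np

    sorted_index = ["read1", "read2"]
    default_row = np.full(4, -1, dtype=np.int16)                     # fallback row
    per_read = {"read1": np.array([0, 1, 2, 3], dtype=np.int16)}     # read2 is absent
    matrix = np.vstack([per_read.get(name, default_row) for name in sorted_index])
    assert matrix.shape == (2, 4) and matrix[1, 0] == -1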
@@ -427,27 +809,56 @@ def process_single_bam(
 
 
 def timestamp():
-    """Returns a formatted timestamp for logging."""
+    """Return a formatted timestamp for logging.
+
+    Returns:
+        str: Timestamp string in the format ``[YYYY-MM-DD HH:MM:SS]``.
+    """
     return time.strftime("[%Y-%m-%d %H:%M:%S]")
 
 
 def worker_function(
-    bam_index,
-    bam,
-    records_to_analyze,
-    shared_record_FASTA_dict,
-    chromosome_FASTA_dict,
-    tmp_dir,
-    h5_dir,
-    max_reference_length,
-    device,
-    deaminase_footprinting,
-    samtools_backend,
+    bam_index: int,
+    bam: Path,
+    records_to_analyze: set[str],
+    shared_record_FASTA_dict: dict[str, RecordFastaInfo],
+    chromosome_FASTA_dict: dict[str, tuple[str, str]],
+    tmp_dir: Path,
+    h5_dir: Path,
+    max_reference_length: int,
+    device: torch.device,
+    deaminase_footprinting: bool,
+    samtools_backend: str | None,
+    converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
     progress_queue,
-    log_level,
-    log_file,
+    log_level: int,
+    log_file: Path | None,
 ):
-    """Worker function that processes a single BAM and writes the output to an H5AD file."""
+    """Process a single BAM and write the output to an H5AD file.
+
+    Args:
+        bam_index: Index of the BAM within the processing batch.
+        bam: Path to the BAM file.
+        records_to_analyze: FASTA record IDs that passed the mapping threshold.
+        shared_record_FASTA_dict: Shared FASTA metadata keyed by record ID.
+        chromosome_FASTA_dict: Chromosome sequences for annotations.
+        tmp_dir: Directory for temporary batch files.
+        h5_dir: Directory for per-BAM H5AD outputs.
+        max_reference_length: Maximum reference length for padding.
+        device: Torch device used for binarization.
+        deaminase_footprinting: Whether direct deamination chemistry was used.
+        samtools_backend: Samtools backend choice for alignment parsing.
+        converted_FASTA_record_seq_map: Map from record name to the converted reference length and padded sequence.
+        progress_queue: Queue used to signal completion.
+        log_level: Logging level to configure in workers.
+        log_file: Optional log file path.
+
+    Processing Steps:
+        1. Skip processing if an output H5AD already exists.
+        2. Filter records to those present in the FASTA metadata.
+        3. Run per-record processing and write AnnData output.
+        4. Signal completion via the progress queue.
+    """
     _ensure_worker_logging(log_level, log_file)
     worker_id = current_process().pid  # Get worker process ID
     sample = bam.stem
@@ -485,6 +896,7 @@ def worker_function(
         device,
         deaminase_footprinting,
         samtools_backend,
+        converted_FASTA_record_seq_map,
     )
 
     if adata is not None:
@@ -505,19 +917,43 @@ def worker_function(
 
 
 def process_bams_parallel(
-    bam_path_list,
-    records_to_analyze,
-    record_FASTA_dict,
-    chromosome_FASTA_dict,
-    tmp_dir,
-    h5_dir,
-    num_threads,
-    max_reference_length,
-    device,
-    deaminase_footprinting,
-    samtools_backend,
-):
-    """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
+    bam_path_list: list[Path],
+    records_to_analyze: set[str],
+    record_FASTA_dict: dict[str, RecordFastaInfo],
+    chromosome_FASTA_dict: dict[str, tuple[str, str]],
+    tmp_dir: Path,
+    h5_dir: Path,
+    num_threads: int,
+    max_reference_length: int,
+    device: torch.device,
+    deaminase_footprinting: bool,
+    samtools_backend: str | None,
+    converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
+) -> ad.AnnData | None:
+    """Process BAM files in parallel and concatenate the resulting AnnData.
+
+    Args:
+        bam_path_list: List of BAM files to process.
+        records_to_analyze: FASTA record IDs that passed the mapping threshold.
+        record_FASTA_dict: FASTA metadata keyed by record ID.
+        chromosome_FASTA_dict: Chromosome sequences for annotations.
+        tmp_dir: Directory for temporary batch files.
+        h5_dir: Directory for per-BAM H5AD outputs.
+        num_threads: Number of worker processes.
+        max_reference_length: Maximum reference length for padding.
+        device: Torch device used for binarization.
+        deaminase_footprinting: Whether direct deamination chemistry was used.
+        samtools_backend: Samtools backend choice for alignment parsing.
+        converted_FASTA_record_seq_map: Map from converted record name to the converted reference length and sequence.
+
+    Returns:
+        anndata.AnnData | None: Concatenated AnnData or None if no H5ADs produced.
+
+    Processing Steps:
+        1. Spawn worker processes to handle each BAM.
+        2. Track completion via a multiprocessing queue.
+        3. Concatenate per-BAM H5AD files into a final AnnData.
+    """
     make_dirs(h5_dir)  # Ensure h5_dir exists
 
     logger.info(f"Starting parallel BAM processing with {num_threads} threads...")
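Steps 1-2 of the docstring describe the fan-out: a Manager queue is shared with pool workers so the parent can track completions. A minimal standalone sketch of that pattern (not the package's exact code; _toy_worker and run_parallel are illustrative names):

    from multiprocessing import Manager, Pool

    def _toy_worker(item, queue):
        result = item * 2          # stand-in for per-BAM processing
        queue.put(item)            # signal completion, as worker_function does
        return result

    def run_parallel(items, num_threads=2):
        with Manager() as manager:
            queue = manager.Queue()
            with Pool(processes=num_threads) as pool:
                results = [pool.apply_async(_toy_worker, (it, queue)) for it in items]
                for _ in items:
                    queue.get()    # one completion signal per item
                pool.close()
                pool.join()
                return [r.get() for r in results]

    if __name__ == "__main__":     # required on spawn-based platforms
        print(run_parallel([1, 2, 3]))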
@@ -543,6 +979,7 @@ def process_bams_parallel(
                 device,
                 deaminase_footprinting,
                 samtools_backend,
+                converted_FASTA_record_seq_map,
                 progress_queue,
                 log_level,
                 log_file,
@@ -583,7 +1020,16 @@ def process_bams_parallel(
 
 
 def _log_async_result_errors(results, bam_path_list):
-    """Log worker failures captured by multiprocessing AsyncResult objects."""
+    """Log worker failures captured by multiprocessing AsyncResult objects.
+
+    Args:
+        results: Iterable of AsyncResult objects from multiprocessing.
+        bam_path_list: List of BAM paths matching the async results.
+
+    Processing Steps:
+        1. Iterate over async results.
+        2. Retrieve results to surface worker exceptions.
+    """
     for bam, result in zip(bam_path_list, results):
         if not result.ready():
             continue
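Retrieving the result is what surfaces a worker's exception in the parent; apply_async otherwise holds it silently inside the AsyncResult. A minimal sketch (_boom is an illustrative name):

    from multiprocessing import Pool

    def _boom(_):
        raise ValueError("worker failed")

    if __name__ == "__main__":
        with Pool(1) as pool:
            res = pool.apply_async(_boom, (0,))
            res.wait()
            try:
                res.get()  # re-raises the worker's ValueError here
            except ValueError as exc:
                print(f"captured worker failure: {exc}")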
@@ -594,6 +1040,15 @@ def _log_async_result_errors(results, bam_path_list):
 
 
 def _get_logger_config() -> tuple[int, Path | None]:
+    """Return the active smftools logger level and optional file path.
+
+    Returns:
+        tuple[int, Path | None]: Log level and log file path (if configured).
+
+    Processing Steps:
+        1. Inspect the smftools logger for configured handlers.
+        2. Extract log level and file handler path.
+    """
     smftools_logger = logging.getLogger("smftools")
     level = smftools_logger.level
     if level == logging.NOTSET:
@@ -607,6 +1062,16 @@ def _get_logger_config() -> tuple[int, Path | None]:
 
 
 def _ensure_worker_logging(log_level: int, log_file: Path | None) -> None:
+    """Ensure worker processes have logging configured.
+
+    Args:
+        log_level: Logging level to configure.
+        log_file: Optional log file path.
+
+    Processing Steps:
+        1. Check if handlers are already configured.
+        2. Initialize logging with the provided level and file path.
+    """
     smftools_logger = logging.getLogger("smftools")
     if not smftools_logger.handlers:
         setup_logging(level=log_level, log_file=log_file)
@@ -619,21 +1084,17 @@ def delete_intermediate_h5ads_and_tmpdir(
     dry_run: bool = False,
     verbose: bool = True,
 ):
-    """
-    Delete intermediate .h5ad files and a temporary directory.
-
-    Parameters
-    ----------
-    h5_dir : str | Path | iterable[str] | None
-        If a directory path is given, all files directly inside it will be considered.
-        If an iterable of file paths is given, those files will be considered.
-        Only files ending with '.h5ad' (and not ending with '.gz') are removed.
-    tmp_dir : str | Path | None
-        Path to a directory to remove recursively (e.g. a temp dir created earlier).
-    dry_run : bool
-        If True, print what *would* be removed but do not actually delete.
-    verbose : bool
-        Print progress / warnings.
+    """Delete intermediate .h5ad files and a temporary directory.
+
+    Args:
+        h5_dir: Directory path or iterable of file paths to inspect for `.h5ad` files.
+        tmp_dir: Optional directory to remove recursively.
+        dry_run: If True, log what would be removed without deleting.
+        verbose: If True, log progress and warnings.
+
+    Processing Steps:
+        1. Remove `.h5ad` files (excluding `.gz`) from the provided directory or list.
+        2. Optionally remove the temporary directory tree.
     """
 
     # Helper: remove a single file path (Path-like or string)