smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. smftools/_version.py +1 -1
  2. smftools/cli/helpers.py +32 -6
  3. smftools/cli/hmm_adata.py +232 -31
  4. smftools/cli/latent_adata.py +318 -0
  5. smftools/cli/load_adata.py +77 -73
  6. smftools/cli/preprocess_adata.py +178 -53
  7. smftools/cli/spatial_adata.py +149 -101
  8. smftools/cli_entry.py +12 -0
  9. smftools/config/conversion.yaml +11 -1
  10. smftools/config/default.yaml +38 -1
  11. smftools/config/experiment_config.py +53 -1
  12. smftools/constants.py +65 -0
  13. smftools/hmm/HMM.py +88 -0
  14. smftools/informatics/__init__.py +6 -0
  15. smftools/informatics/bam_functions.py +358 -8
  16. smftools/informatics/converted_BAM_to_adata.py +584 -163
  17. smftools/informatics/h5ad_functions.py +115 -2
  18. smftools/informatics/modkit_extract_to_adata.py +1003 -425
  19. smftools/informatics/sequence_encoding.py +72 -0
  20. smftools/logging_utils.py +21 -2
  21. smftools/metadata.py +1 -1
  22. smftools/plotting/__init__.py +9 -0
  23. smftools/plotting/general_plotting.py +2411 -628
  24. smftools/plotting/hmm_plotting.py +85 -7
  25. smftools/preprocessing/__init__.py +1 -0
  26. smftools/preprocessing/append_base_context.py +17 -17
  27. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  28. smftools/preprocessing/calculate_consensus.py +1 -1
  29. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  30. smftools/readwrite.py +53 -17
  31. smftools/schema/anndata_schema_v1.yaml +15 -1
  32. smftools/tools/__init__.py +4 -0
  33. smftools/tools/calculate_leiden.py +57 -0
  34. smftools/tools/calculate_nmf.py +119 -0
  35. smftools/tools/calculate_umap.py +91 -8
  36. smftools/tools/rolling_nn_distance.py +235 -0
  37. smftools/tools/tensor_factorization.py +169 -0
  38. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
  39. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
  40. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  41. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  42. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -5,22 +5,45 @@ import logging
 import shutil
 import time
 import traceback
+from dataclasses import dataclass
 from multiprocessing import Manager, Pool, current_process
 from pathlib import Path
-from typing import TYPE_CHECKING, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Union
 
 import anndata as ad
 import numpy as np
 import pandas as pd
 
+from smftools.constants import (
+    BAM_SUFFIX,
+    BARCODE,
+    BASE_QUALITY_SCORES,
+    DATASET,
+    DEMUX_TYPE,
+    H5_DIR,
+    MISMATCH_INTEGER_ENCODING,
+    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    MODKIT_EXTRACT_SEQUENCE_BASES,
+    MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+    MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+    READ_MAPPING_DIRECTION,
+    READ_MISMATCH_TREND,
+    READ_SPAN_MASK,
+    REFERENCE,
+    REFERENCE_DATASET_STRAND,
+    REFERENCE_STRAND,
+    SAMPLE,
+    SEQUENCE_INTEGER_DECODING,
+    SEQUENCE_INTEGER_ENCODING,
+    STRAND,
+)
 from smftools.logging_utils import get_logger, setup_logging
 from smftools.optional_imports import require
 
 from ..readwrite import make_dirs
 from .bam_functions import count_aligned_reads, extract_base_identities
 from .binarize_converted_base_identities import binarize_converted_base_identities
-from .fasta_functions import find_conversion_sites
-from .ohe import ohe_batching
+from .fasta_functions import find_conversion_sites, get_native_references
 
 logger = get_logger(__name__)
 
@@ -30,6 +53,67 @@ if TYPE_CHECKING:
 torch = require("torch", extra="torch", purpose="converted BAM processing")
 
 
+@dataclass(frozen=True)
+class RecordFastaInfo:
+    """Structured FASTA metadata for a single converted record.
+
+    Attributes:
+        sequence: Padded top-strand sequence for the record.
+        complement: Padded bottom-strand sequence for the record.
+        chromosome: Canonical chromosome name for the record.
+        unconverted_name: FASTA record name for the unconverted reference.
+        sequence_length: Length of the unpadded reference sequence.
+        padding_length: Number of padded bases applied to reach max length.
+        conversion: Conversion label (e.g., "unconverted", "5mC").
+        strand: Strand label ("top" or "bottom").
+        max_reference_length: Maximum reference length across all records.
+    """
+
+    sequence: str
+    complement: str
+    chromosome: str
+    unconverted_name: str
+    sequence_length: int
+    padding_length: int
+    conversion: str
+    strand: str
+    max_reference_length: int
+
+
+@dataclass(frozen=True)
+class SequenceEncodingConfig:
+    """Configuration for integer sequence encoding.
+
+    Attributes:
+        base_to_int: Mapping of base characters to integer encodings.
+        bases: Valid base characters used for encoding.
+        padding_base: Base token used for padding.
+        batch_size: Number of reads per temporary batch file.
+    """
+
+    base_to_int: Mapping[str, int]
+    bases: tuple[str, ...]
+    padding_base: str
+    batch_size: int = 100000
+
+    @property
+    def padding_value(self) -> int:
+        """Return the integer value used for padding positions."""
+        return self.base_to_int[self.padding_base]
+
+    @property
+    def unknown_value(self) -> int:
+        """Return the integer value used for unknown bases."""
+        return self.base_to_int["N"]
+
+
+SEQUENCE_ENCODING_CONFIG = SequenceEncodingConfig(
+    base_to_int=MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    bases=MODKIT_EXTRACT_SEQUENCE_BASES,
+    padding_base=MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+)
+
+
 def converted_BAM_to_adata(
     converted_FASTA: str | Path,
     split_dir: Path,
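
The frozen `SequenceEncodingConfig` above derives its padding and unknown codes from the base-to-integer map at property-access time. A minimal sketch of that behavior, using a hypothetical six-symbol mapping in place of the real `MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT` values (which live in smftools.constants):

    from dataclasses import dataclass
    from typing import Mapping

    @dataclass(frozen=True)
    class _DemoEncodingConfig:
        base_to_int: Mapping[str, int]
        padding_base: str

        @property
        def padding_value(self) -> int:
            return self.base_to_int[self.padding_base]

        @property
        def unknown_value(self) -> int:
            return self.base_to_int["N"]

    # Hypothetical mapping; the real values are defined in smftools.constants.
    demo = _DemoEncodingConfig(
        base_to_int={"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5},
        padding_base="PAD",
    )
    assert demo.padding_value == 5
    assert demo.unknown_value == 4
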
@@ -46,7 +130,7 @@ def converted_BAM_to_adata(
     double_barcoded_path: Path | None = None,
     samtools_backend: str | None = "auto",
 ) -> tuple[ad.AnnData | None, Path]:
-    """Convert BAM files into an AnnData object by binarizing modified base identities.
+    """Convert converted BAM files into an AnnData object with integer sequence encoding.
 
     Args:
         converted_FASTA: Path to the converted FASTA reference.
@@ -62,9 +146,18 @@ def converted_BAM_to_adata(
         deaminase_footprinting: Whether the footprinting used direct deamination chemistry.
         delete_intermediates: Whether to remove intermediate files after processing.
         double_barcoded_path: Path to dorado demux summary file of double-ended barcodes.
+        samtools_backend: Samtools backend choice for alignment parsing.
 
     Returns:
         tuple[anndata.AnnData | None, Path]: The AnnData object (if generated) and its path.
+
+    Processing Steps:
+        1. Resolve the best available torch device and create output directories.
+        2. Load converted FASTA records and compute conversion sites.
+        3. Filter BAMs based on mapping thresholds.
+        4. Process each BAM in parallel, building per-sample H5AD files.
+        5. Concatenate per-sample AnnData objects and attach reference metadata.
+        6. Add demultiplexing annotations and clean intermediate artifacts.
     """
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -76,7 +169,7 @@ def converted_BAM_to_adata(
     logger.debug(f"Using device: {device}")
 
     ## Set Up Directories and File Paths
-    h5_dir = output_dir / "h5ads"
+    h5_dir = output_dir / H5_DIR
     tmp_dir = output_dir / "tmp"
     final_adata = None
     final_adata_path = h5_dir / f"{experiment_name}.h5ad.gz"
@@ -90,7 +183,7 @@ def converted_BAM_to_adata(
     bam_files = sorted(
         p
         for p in split_dir.iterdir()
-        if p.is_file() and p.suffix == ".bam" and "unclassified" not in p.name
+        if p.is_file() and p.suffix == BAM_SUFFIX and "unclassified" not in p.name
     )
 
     bam_path_list = bam_files
@@ -108,6 +201,16 @@ def converted_BAM_to_adata(
         bam_path_list, bam_files, mapping_threshold, samtools_backend
     )
 
+    # Get converted record sequences:
+    converted_FASTA_record_seq_map = get_native_references(converted_FASTA)
+    # Pad the record sequences
+    for record, [record_length, seq] in converted_FASTA_record_seq_map.items():
+        if max_reference_length > record_length:
+            pad_number = max_reference_length - record_length
+            record_length += pad_number
+            seq += "N" * pad_number
+            converted_FASTA_record_seq_map[record] = [record_length, seq]
+
     ## Process BAMs in Parallel
     final_adata = process_bams_parallel(
         bam_path_list,
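
The padding loop above right-pads every converted record to the longest reference with "N" so all records share one coordinate width. The same logic on a toy map (record names and lengths invented for illustration):

    # Toy record map: name -> [length, sequence]; values here are illustrative only.
    record_seq_map = {"rec_a": [4, "ACGT"], "rec_b": [2, "AC"]}
    max_reference_length = 4

    for record, (record_length, seq) in record_seq_map.items():
        if max_reference_length > record_length:
            pad_number = max_reference_length - record_length
            record_seq_map[record] = [record_length + pad_number, seq + "N" * pad_number]

    assert record_seq_map["rec_b"] == [4, "ACNN"]
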
@@ -121,8 +224,15 @@ def converted_BAM_to_adata(
         device,
         deaminase_footprinting,
         samtools_backend,
+        converted_FASTA_record_seq_map,
     )
 
+    final_adata.uns[f"{SEQUENCE_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+    final_adata.uns[f"{MISMATCH_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+    final_adata.uns[f"{SEQUENCE_INTEGER_DECODING}_map"] = {
+        str(key): value for key, value in MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.items()
+    }
+
     final_adata.uns["References"] = {}
     for chromosome, [seq, comp] in chromosome_FASTA_dict.items():
         final_adata.var[f"{chromosome}_top_strand_FASTA_base"] = list(seq)
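
Storing both the encoding and decoding maps in `.uns` lets downstream code turn a row of the integer layer back into bases. A sketch of that round trip, with a hypothetical int-to-base map standing in for the real `MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE` (keys are stringified, as in the `.uns` entry above):

    import numpy as np

    # Hypothetical decoding map; real values come from smftools.constants.
    decoding_map = {"0": "A", "1": "C", "2": "G", "3": "T", "4": "N"}

    encoded_row = np.array([0, 1, 2, 3, 4], dtype=np.int16)
    decoded = "".join(decoding_map[str(v)] for v in encoded_row)
    assert decoded == "ACGTN"
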
@@ -130,6 +240,11 @@ def converted_BAM_to_adata(
         final_adata.uns[f"{chromosome}_FASTA_sequence"] = seq
         final_adata.uns["References"][f"{chromosome}_FASTA_sequence"] = seq
 
+    if not deaminase_footprinting:
+        for record, [_length, seq] in converted_FASTA_record_seq_map.items():
+            if "unconverted" not in record:
+                final_adata.var[f"{record}_top_strand_FASTA_base"] = list(seq)
+
     final_adata.obs_names_make_unique()
     cols = final_adata.obs.columns
 
@@ -137,9 +252,29 @@ def converted_BAM_to_adata(
     for col in cols:
         final_adata.obs[col] = final_adata.obs[col].astype("category")
 
+    consensus_bases = MODKIT_EXTRACT_SEQUENCE_BASES[:4]  # ignore N/PAD for consensus
+    consensus_base_ints = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in consensus_bases]
+    for ref_group in final_adata.obs[REFERENCE_DATASET_STRAND].cat.categories:
+        group_subset = final_adata[final_adata.obs[REFERENCE_DATASET_STRAND] == ref_group]
+        encoded_sequences = group_subset.layers[SEQUENCE_INTEGER_ENCODING]
+        layer_counts = [
+            np.sum(encoded_sequences == base_int, axis=0) for base_int in consensus_base_ints
+        ]
+        count_array = np.array(layer_counts)
+        nucleotide_indexes = np.argmax(count_array, axis=0)
+        consensus_sequence_list = [consensus_bases[i] for i in nucleotide_indexes]
+        no_calls_mask = np.sum(count_array, axis=0) == 0
+        if np.any(no_calls_mask):
+            consensus_sequence_list = np.array(consensus_sequence_list, dtype=object)
+            consensus_sequence_list[no_calls_mask] = "N"
+            consensus_sequence_list = consensus_sequence_list.tolist()
+        final_adata.var[f"{ref_group}_consensus_sequence_from_all_samples"] = (
+            consensus_sequence_list
+        )
+
     if input_already_demuxed:
-        final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
-        final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
+        final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
+        final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
     else:
         from .h5ad_functions import add_demux_type_annotation
 
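
The consensus block above takes, at each position, the base with the highest count across all reads in a reference group, falling back to "N" where no read has a usable call. The same argmax logic in a self-contained sketch (toy encodings, with 4 standing in for the unknown code):

    import numpy as np

    consensus_bases = ["A", "C", "G", "T"]  # N/PAD excluded, as above
    base_ints = [0, 1, 2, 3]

    # Three toy reads over four positions; 4 encodes "N" (no usable call).
    encoded = np.array([[0, 1, 4, 4],
                        [0, 2, 4, 4],
                        [0, 1, 3, 4]], dtype=np.int16)

    counts = np.array([np.sum(encoded == b, axis=0) for b in base_ints])
    consensus = [consensus_bases[i] for i in np.argmax(counts, axis=0)]
    no_calls = np.sum(counts, axis=0) == 0
    consensus = ["N" if masked else base for base, masked in zip(consensus, no_calls)]
    assert consensus == ["A", "C", "T", "N"]
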
@@ -156,23 +291,31 @@ def converted_BAM_to_adata(
 
 
 def process_conversion_sites(
-    converted_FASTA, conversions=["unconverted", "5mC"], deaminase_footprinting=False
-):
-    """
-    Extracts conversion sites and determines the max reference length.
+    converted_FASTA: str | Path,
+    conversions: list[str] | None = None,
+    deaminase_footprinting: bool = False,
+) -> tuple[int, dict[str, RecordFastaInfo], dict[str, tuple[str, str]]]:
+    """Extract conversion sites and FASTA metadata for converted references.
 
-    Parameters:
-        converted_FASTA (str): Path to the converted reference FASTA.
-        conversions (list): List of modification types (e.g., ['unconverted', '5mC', '6mA']).
-        deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
+    Args:
+        converted_FASTA: Path to the converted reference FASTA.
+        conversions: List of modification types (e.g., ["unconverted", "5mC", "6mA"]).
+        deaminase_footprinting: Whether the footprinting was done with direct deamination chemistry.
 
     Returns:
-        max_reference_length (int): The length of the longest sequence.
-        record_FASTA_dict (dict): Dictionary of sequence information for **both converted & unconverted** records.
+        tuple[int, dict[str, RecordFastaInfo], dict[str, tuple[str, str]]]:
+            Maximum reference length, record metadata, and chromosome sequences.
+
+    Processing Steps:
+        1. Parse unconverted FASTA records to determine the max reference length.
+        2. Build record metadata for unconverted and converted strands.
+        3. Cache chromosome-level sequences for downstream annotation.
     """
-    modification_dict = {}
-    record_FASTA_dict = {}
-    chromosome_FASTA_dict = {}
+    if conversions is None:
+        conversions = ["unconverted", "5mC"]
+    modification_dict: dict[str, dict] = {}
+    record_FASTA_dict: dict[str, RecordFastaInfo] = {}
+    chromosome_FASTA_dict: dict[str, tuple[str, str]] = {}
     max_reference_length = 0
     unconverted = conversions[0]
     conversion_types = conversions[1:]
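
The signature change replaces the mutable default `conversions=["unconverted", "5mC"]` with `None` plus an in-body default, sidestepping Python's shared-default pitfall (latent here, since the old default was never mutated). The pitfall in miniature:

    def risky(items=[]):          # one list object shared across every call
        items.append(1)
        return items

    def safe(items=None):         # fresh list per call
        if items is None:
            items = []
        items.append(1)
        return items

    assert risky() == [1] and risky() == [1, 1]   # state leaks between calls
    assert safe() == [1] and safe() == [1]        # no leakage
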
@@ -196,22 +339,23 @@ def process_conversion_sites(
         chromosome = record
 
         # Store **original sequence**
-        record_FASTA_dict[record] = [
-            sequence + "N" * (max_reference_length - sequence_length),
-            complement + "N" * (max_reference_length - sequence_length),
-            chromosome,
-            record,
-            sequence_length,
-            max_reference_length - sequence_length,
-            unconverted,
-            "top",
-        ]
+        record_FASTA_dict[record] = RecordFastaInfo(
+            sequence=sequence + "N" * (max_reference_length - sequence_length),
+            complement=complement + "N" * (max_reference_length - sequence_length),
+            chromosome=chromosome,
+            unconverted_name=record,
+            sequence_length=sequence_length,
+            padding_length=max_reference_length - sequence_length,
+            conversion=unconverted,
+            strand="top",
+            max_reference_length=max_reference_length,
+        )
 
         if chromosome not in chromosome_FASTA_dict:
-            chromosome_FASTA_dict[chromosome] = [
+            chromosome_FASTA_dict[chromosome] = (
                 sequence + "N" * (max_reference_length - sequence_length),
                 complement + "N" * (max_reference_length - sequence_length),
-            ]
+            )
 
     # Process converted records
     for conversion in conversion_types:
@@ -233,24 +377,44 @@ def process_conversion_sites(
             converted_name = f"{chromosome}_{conversion}_{strand}"
             unconverted_name = f"{chromosome}_{unconverted}_top"
 
-            record_FASTA_dict[converted_name] = [
-                sequence + "N" * (max_reference_length - sequence_length),
-                complement + "N" * (max_reference_length - sequence_length),
-                chromosome,
-                unconverted_name,
-                sequence_length,
-                max_reference_length - sequence_length,
-                conversion,
-                strand,
-            ]
-
-    logger.debug("Updated record_FASTA_dict Keys:", list(record_FASTA_dict.keys()))
+            record_FASTA_dict[converted_name] = RecordFastaInfo(
+                sequence=sequence + "N" * (max_reference_length - sequence_length),
+                complement=complement + "N" * (max_reference_length - sequence_length),
+                chromosome=chromosome,
+                unconverted_name=unconverted_name,
+                sequence_length=sequence_length,
+                padding_length=max_reference_length - sequence_length,
+                conversion=conversion,
+                strand=strand,
+                max_reference_length=max_reference_length,
+            )
+
+    logger.debug("Updated record_FASTA_dict keys: %s", list(record_FASTA_dict.keys()))
     return max_reference_length, record_FASTA_dict, chromosome_FASTA_dict
 
 
-def filter_bams_by_mapping_threshold(bam_path_list, bam_files, mapping_threshold, samtools_backend):
-    """Filters BAM files based on mapping threshold."""
-    records_to_analyze = set()
+def filter_bams_by_mapping_threshold(
+    bam_path_list: list[Path],
+    bam_files: list[Path],
+    mapping_threshold: float,
+    samtools_backend: str | None,
+) -> set[str]:
+    """Filter FASTA records based on per-BAM mapping thresholds.
+
+    Args:
+        bam_path_list: Ordered list of BAM paths to evaluate.
+        bam_files: Matching list of BAM paths used for reporting.
+        mapping_threshold: Minimum percentage of aligned reads to include a record.
+        samtools_backend: Samtools backend choice for alignment parsing.
+
+    Returns:
+        set[str]: FASTA record IDs that pass the mapping threshold.
+
+    Processing Steps:
+        1. Count aligned/unaligned reads and per-record percentages.
+        2. Collect record IDs that meet the mapping threshold.
+    """
+    records_to_analyze: set[str] = set()
 
     for i, bam in enumerate(bam_path_list):
         aligned_reads, unaligned_reads, record_counts = count_aligned_reads(bam, samtools_backend)
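
The filter keeps a record when its share of aligned reads in a BAM clears `mapping_threshold`. A minimal sketch of the percentage test with invented counts (whether the threshold is a fraction or a 0-100 percentage depends on the caller; the comparison shape is the same):

    # Invented per-record aligned-read counts for one BAM.
    record_counts = {"chrA_unconverted_top": 900, "chrB_unconverted_top": 5}
    aligned_reads = sum(record_counts.values())
    mapping_threshold = 0.05  # assumed fractional units for this sketch

    records_to_analyze = {
        record
        for record, count in record_counts.items()
        if aligned_reads and count / aligned_reads >= mapping_threshold
    }
    assert records_to_analyze == {"chrA_unconverted_top"}
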
@@ -265,33 +429,179 @@ def filter_bams_by_mapping_threshold(
     return records_to_analyze
 
 
+def _encode_sequence_array(
+    read_sequence: np.ndarray,
+    valid_length: int,
+    config: SequenceEncodingConfig,
+) -> np.ndarray:
+    """Encode a base-identity array into integer values with padding.
+
+    Args:
+        read_sequence: Array of base calls (dtype "<U1").
+        valid_length: Number of valid reference positions for this record.
+        config: Integer encoding configuration.
+
+    Returns:
+        np.ndarray: Integer-encoded sequence with padding applied.
+
+    Processing Steps:
+        1. Initialize an array filled with the unknown base encoding.
+        2. Map A/C/G/T/N bases into integer values.
+        3. Mark positions beyond valid_length with the padding value.
+    """
+    read_sequence = np.asarray(read_sequence, dtype="<U1")
+    encoded = np.full(read_sequence.shape, config.unknown_value, dtype=np.int16)
+    for base in config.bases:
+        encoded[read_sequence == base] = config.base_to_int[base]
+    if valid_length < encoded.size:
+        encoded[valid_length:] = config.padding_value
+    return encoded
+
+
+def _write_sequence_batches(
+    base_identities: Mapping[str, np.ndarray],
+    tmp_dir: Path,
+    record: str,
+    prefix: str,
+    valid_length: int,
+    config: SequenceEncodingConfig,
+) -> list[str]:
+    """Encode base identities into integer arrays and write batched H5AD files.
+
+    Args:
+        base_identities: Mapping of read name to base identity arrays.
+        tmp_dir: Directory for temporary H5AD files.
+        record: Reference record identifier.
+        prefix: Prefix used to name batch files.
+        valid_length: Valid reference length for padding determination.
+        config: Integer encoding configuration.
+
+    Returns:
+        list[str]: Paths to written H5AD batch files.
+
+    Processing Steps:
+        1. Encode each read sequence into integers.
+        2. Accumulate encoded reads into batches.
+        3. Persist each batch to an H5AD file with `.uns` storage.
+    """
+    batch_files: list[str] = []
+    batch: dict[str, np.ndarray] = {}
+    batch_number = 0
+
+    for read_name, sequence in base_identities.items():
+        if sequence is None:
+            continue
+        batch[read_name] = _encode_sequence_array(sequence, valid_length, config)
+        if len(batch) >= config.batch_size:
+            save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+            batch_files.append(str(save_name))
+            batch = {}
+            batch_number += 1
+
+    if batch:
+        save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+        batch_files.append(str(save_name))
+
+    return batch_files
+
+
+def _load_sequence_batches(
+    batch_files: list[Path | str],
+) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
+    """Load integer-encoded sequence batches from H5AD files.
+
+    Args:
+        batch_files: H5AD paths containing encoded sequences in `.uns`.
+
+    Returns:
+        tuple[dict[str, np.ndarray], set[str], set[str]]:
+            Read-to-sequence mapping and sets of forward/reverse mapped reads.
+
+    Processing Steps:
+        1. Read each H5AD file.
+        2. Merge `.uns` dictionaries into a single mapping.
+        3. Track forward/reverse read IDs based on filename markers.
+    """
+    sequences: dict[str, np.ndarray] = {}
+    fwd_reads: set[str] = set()
+    rev_reads: set[str] = set()
+    for batch_file in batch_files:
+        batch_path = Path(batch_file)
+        batch_sequences = ad.read_h5ad(batch_path).uns
+        sequences.update(batch_sequences)
+        if "_fwd_" in batch_path.name:
+            fwd_reads.update(batch_sequences.keys())
+        elif "_rev_" in batch_path.name:
+            rev_reads.update(batch_sequences.keys())
+    return sequences, fwd_reads, rev_reads
+
+
 def process_single_bam(
-    bam_index,
-    bam,
-    records_to_analyze,
-    record_FASTA_dict,
-    chromosome_FASTA_dict,
-    tmp_dir,
-    max_reference_length,
-    device,
-    deaminase_footprinting,
-    samtools_backend,
-):
-    """Worker function to process a single BAM file (must be at top-level for multiprocessing)."""
-    adata_list = []
+    bam_index: int,
+    bam: Path,
+    records_to_analyze: set[str],
+    record_FASTA_dict: dict[str, RecordFastaInfo],
+    chromosome_FASTA_dict: dict[str, tuple[str, str]],
+    tmp_dir: Path,
+    max_reference_length: int,
+    device: torch.device,
+    deaminase_footprinting: bool,
+    samtools_backend: str | None,
+    converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
+) -> ad.AnnData | None:
+    """Process a single BAM file into per-record AnnData objects.
+
+    Args:
+        bam_index: Index of the BAM within the processing batch.
+        bam: Path to the BAM file.
+        records_to_analyze: FASTA record IDs that passed the mapping threshold.
+        record_FASTA_dict: FASTA metadata keyed by record ID.
+        chromosome_FASTA_dict: Chromosome sequences for annotations.
+        tmp_dir: Directory for temporary batch files.
+        max_reference_length: Maximum reference length for padding.
+        device: Torch device used for binarization.
+        deaminase_footprinting: Whether direct deamination chemistry was used.
+        samtools_backend: Samtools backend choice for alignment parsing.
+        converted_FASTA_record_seq_map: Mapping of record name to its (length, sequence) pair.
+
+    Returns:
+        anndata.AnnData | None: Concatenated AnnData object or None if no data.
+
+    Processing Steps:
+        1. Extract base identities and mismatch profiles for each record.
+        2. Binarize modified base identities into feature matrices.
+        3. Encode read sequences into integer arrays and cache batches.
+        4. Build AnnData layers/obs metadata for each record and concatenate.
+    """
+    adata_list: list[ad.AnnData] = []
 
     for record in records_to_analyze:
         sample = bam.stem
-        chromosome = record_FASTA_dict[record][2]
-        current_length = record_FASTA_dict[record][4]
-        mod_type, strand = record_FASTA_dict[record][6], record_FASTA_dict[record][7]
-        sequence = chromosome_FASTA_dict[chromosome][0]
+        record_info = record_FASTA_dict[record]
+        chromosome = record_info.chromosome
+        current_length = record_info.sequence_length
+        mod_type, strand = record_info.conversion, record_info.strand
+        non_converted_sequence = chromosome_FASTA_dict[chromosome][0]
+        record_sequence = converted_FASTA_record_seq_map[record][1]
 
         # Extract Base Identities
-        fwd_bases, rev_bases, mismatch_counts_per_read, mismatch_trend_per_read = (
-            extract_base_identities(
-                bam, record, range(current_length), max_reference_length, sequence, samtools_backend
-            )
+        (
+            fwd_bases,
+            rev_bases,
+            mismatch_counts_per_read,
+            mismatch_trend_per_read,
+            mismatch_base_identities,
+            base_quality_scores,
+            read_span_masks,
+        ) = extract_base_identities(
+            bam,
+            record,
+            range(current_length),
+            max_reference_length,
+            record_sequence,
+            samtools_backend,
         )
         mismatch_trend_series = pd.Series(mismatch_trend_per_read)
 
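
`_encode_sequence_array` fills an int16 array with the unknown code, overwrites recognized bases, then marks everything past `valid_length` as padding. The same three steps on a toy read, with an assumed mapping in place of `SEQUENCE_ENCODING_CONFIG`:

    import numpy as np

    # Assumed mapping; real values come from SEQUENCE_ENCODING_CONFIG.
    base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}

    read = np.array(list("ACGX??"), dtype="<U1")  # "X"/"?" are unrecognized calls
    encoded = np.full(read.shape, base_to_int["N"], dtype=np.int16)
    for base in ("A", "C", "G", "T", "N"):
        encoded[read == base] = base_to_int[base]
    valid_length = 4
    encoded[valid_length:] = base_to_int["PAD"]  # positions beyond the reference
    assert encoded.tolist() == [0, 1, 2, 4, 5, 5]
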
@@ -343,83 +653,115 @@ def process_single_bam(
         sorted_index = sorted(bin_df.index)
         bin_df = bin_df.reindex(sorted_index)
 
-        # One-Hot Encode Reads if there is valid data
-        one_hot_reads = {}
-
+        # Integer-encode reads if there is valid data
+        batch_files: list[str] = []
         if fwd_bases:
-            fwd_ohe_files = ohe_batching(
-                fwd_bases, tmp_dir, record, f"{bam_index}_fwd", batch_size=100000
+            batch_files.extend(
+                _write_sequence_batches(
+                    fwd_bases,
+                    tmp_dir,
+                    record,
+                    f"{bam_index}_fwd",
+                    current_length,
+                    SEQUENCE_ENCODING_CONFIG,
+                )
             )
-            for ohe_file in fwd_ohe_files:
-                tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
-                one_hot_reads.update(tmp_ohe_dict)
-                del tmp_ohe_dict
 
         if rev_bases:
-            rev_ohe_files = ohe_batching(
-                rev_bases, tmp_dir, record, f"{bam_index}_rev", batch_size=100000
+            batch_files.extend(
+                _write_sequence_batches(
+                    rev_bases,
+                    tmp_dir,
+                    record,
+                    f"{bam_index}_rev",
+                    current_length,
+                    SEQUENCE_ENCODING_CONFIG,
+                )
             )
-            for ohe_file in rev_ohe_files:
-                tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
-                one_hot_reads.update(tmp_ohe_dict)
-                del tmp_ohe_dict
 
-        # Skip if one_hot_reads is empty
-        if not one_hot_reads:
+        if not batch_files:
             logger.debug(
-                f"[Worker {current_process().pid}] Skipping {sample} - No valid one-hot encoded data for {record}."
+                f"[Worker {current_process().pid}] Skipping {sample} - No valid encoded data for {record}."
             )
             continue
 
         gc.collect()
 
-        # Convert One-Hot Encodings to Numpy Arrays
-        n_rows_OHE = 5
-        read_names = list(one_hot_reads.keys())
-
-        # Skip if no read names exist
-        if not read_names:
+        encoded_reads, fwd_reads, rev_reads = _load_sequence_batches(batch_files)
+        if not encoded_reads:
             logger.debug(
-                f"[Worker {current_process().pid}] Skipping {sample} - No reads found in one-hot encoded data for {record}."
+                f"[Worker {current_process().pid}] Skipping {sample} - No reads found in encoded data for {record}."
            )
            continue
 
-        sequence_length = one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
-        df_A, df_C, df_G, df_T, df_N = [
-            np.zeros((len(sorted_index), sequence_length), dtype=int) for _ in range(5)
-        ]
+        sequence_length = max_reference_length
+        default_sequence = np.full(
+            sequence_length, SEQUENCE_ENCODING_CONFIG.unknown_value, dtype=np.int16
+        )
+        if current_length < sequence_length:
+            default_sequence[current_length:] = SEQUENCE_ENCODING_CONFIG.padding_value
 
-        # Populate One-Hot Arrays
-        for j, read_name in enumerate(sorted_index):
-            if read_name in one_hot_reads:
-                one_hot_array = one_hot_reads[read_name].reshape(n_rows_OHE, -1)
-                df_A[j], df_C[j], df_G[j], df_T[j], df_N[j] = one_hot_array
+        encoded_matrix = np.vstack(
+            [encoded_reads.get(read_name, default_sequence) for read_name in sorted_index]
+        )
+        default_mismatch_sequence = np.full(
+            sequence_length, SEQUENCE_ENCODING_CONFIG.unknown_value, dtype=np.int16
+        )
+        if current_length < sequence_length:
+            default_mismatch_sequence[current_length:] = SEQUENCE_ENCODING_CONFIG.padding_value
+        mismatch_encoded_matrix = np.vstack(
+            [
+                mismatch_base_identities.get(read_name, default_mismatch_sequence)
+                for read_name in sorted_index
+            ]
+        )
+        default_quality_sequence = np.full(sequence_length, -1, dtype=np.int16)
+        quality_matrix = np.vstack(
+            [
+                base_quality_scores.get(read_name, default_quality_sequence)
+                for read_name in sorted_index
+            ]
+        )
+        default_read_span = np.zeros(sequence_length, dtype=np.int16)
+        read_span_matrix = np.vstack(
+            [read_span_masks.get(read_name, default_read_span) for read_name in sorted_index]
+        )
 
         # Convert to AnnData
         X = bin_df.values.astype(np.float32)
         adata = ad.AnnData(X)
         adata.obs_names = bin_df.index.astype(str)
         adata.var_names = bin_df.columns.astype(str)
-        adata.obs["Sample"] = [sample] * len(adata)
+        adata.obs[SAMPLE] = [sample] * len(adata)
         try:
             barcode = sample.split("barcode")[1]
         except Exception:
             barcode = np.nan
-        adata.obs["Barcode"] = [int(barcode)] * len(adata)
-        adata.obs["Barcode"] = adata.obs["Barcode"].astype(str)
-        adata.obs["Reference"] = [chromosome] * len(adata)
-        adata.obs["Strand"] = [strand] * len(adata)
-        adata.obs["Dataset"] = [mod_type] * len(adata)
-        adata.obs["Reference_dataset_strand"] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
-        adata.obs["Reference_strand"] = [f"{chromosome}_{strand}"] * len(adata)
-        adata.obs["Read_mismatch_trend"] = adata.obs_names.map(mismatch_trend_series)
-
-        # Attach One-Hot Encodings to Layers
-        adata.layers["A_binary_sequence_encoding"] = df_A
-        adata.layers["C_binary_sequence_encoding"] = df_C
-        adata.layers["G_binary_sequence_encoding"] = df_G
-        adata.layers["T_binary_sequence_encoding"] = df_T
-        adata.layers["N_binary_sequence_encoding"] = df_N
+        adata.obs[BARCODE] = [int(barcode)] * len(adata)
+        adata.obs[BARCODE] = adata.obs[BARCODE].astype(str)
+        adata.obs[REFERENCE] = [chromosome] * len(adata)
+        adata.obs[STRAND] = [strand] * len(adata)
+        adata.obs[DATASET] = [mod_type] * len(adata)
+        adata.obs[REFERENCE_DATASET_STRAND] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
+        adata.obs[REFERENCE_STRAND] = [f"{chromosome}_{strand}"] * len(adata)
+        adata.obs[READ_MISMATCH_TREND] = adata.obs_names.map(mismatch_trend_series)
+
+        read_mapping_direction = []
+        for read_id in adata.obs_names:
+            if read_id in fwd_reads:
+                read_mapping_direction.append("fwd")
+            elif read_id in rev_reads:
+                read_mapping_direction.append("rev")
+            else:
+                read_mapping_direction.append("unk")
+
+        adata.obs[READ_MAPPING_DIRECTION] = read_mapping_direction
+
+        # Attach integer sequence encoding to layers
+        adata.layers[SEQUENCE_INTEGER_ENCODING] = encoded_matrix
+        adata.layers[MISMATCH_INTEGER_ENCODING] = mismatch_encoded_matrix
+        adata.layers[BASE_QUALITY_SCORES] = quality_matrix
+        adata.layers[READ_SPAN_MASK] = read_span_matrix
 
         adata_list.append(adata)
 
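
Each per-read matrix above is assembled with `dict.get` plus a precomputed default row, so reads missing from a layer still occupy a correctly padded row at their sorted position. The pattern in isolation (toy values; 4 stands in for the unknown code):

    import numpy as np

    sorted_index = ["read1", "read2", "read3"]
    encoded_reads = {"read1": np.array([0, 1], dtype=np.int16),
                     "read3": np.array([3, 2], dtype=np.int16)}
    default_row = np.full(2, 4, dtype=np.int16)

    matrix = np.vstack([encoded_reads.get(name, default_row) for name in sorted_index])
    assert matrix.shape == (3, 2)
    assert matrix[1].tolist() == [4, 4]  # read2 had no data; default row used
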
@@ -427,27 +769,56 @@ def process_single_bam(
 
 
 def timestamp():
-    """Returns a formatted timestamp for logging."""
+    """Return a formatted timestamp for logging.
+
+    Returns:
+        str: Timestamp string in the format ``[YYYY-MM-DD HH:MM:SS]``.
+    """
     return time.strftime("[%Y-%m-%d %H:%M:%S]")
 
 
 def worker_function(
-    bam_index,
-    bam,
-    records_to_analyze,
-    shared_record_FASTA_dict,
-    chromosome_FASTA_dict,
-    tmp_dir,
-    h5_dir,
-    max_reference_length,
-    device,
-    deaminase_footprinting,
-    samtools_backend,
+    bam_index: int,
+    bam: Path,
+    records_to_analyze: set[str],
+    shared_record_FASTA_dict: dict[str, RecordFastaInfo],
+    chromosome_FASTA_dict: dict[str, tuple[str, str]],
+    tmp_dir: Path,
+    h5_dir: Path,
+    max_reference_length: int,
+    device: torch.device,
+    deaminase_footprinting: bool,
+    samtools_backend: str | None,
+    converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
     progress_queue,
-    log_level,
-    log_file,
+    log_level: int,
+    log_file: Path | None,
 ):
-    """Worker function that processes a single BAM and writes the output to an H5AD file."""
+    """Process a single BAM and write the output to an H5AD file.
+
+    Args:
+        bam_index: Index of the BAM within the processing batch.
+        bam: Path to the BAM file.
+        records_to_analyze: FASTA record IDs that passed the mapping threshold.
+        shared_record_FASTA_dict: Shared FASTA metadata keyed by record ID.
+        chromosome_FASTA_dict: Chromosome sequences for annotations.
+        tmp_dir: Directory for temporary batch files.
+        h5_dir: Directory for per-BAM H5AD outputs.
+        max_reference_length: Maximum reference length for padding.
+        device: Torch device used for binarization.
+        deaminase_footprinting: Whether direct deamination chemistry was used.
+        samtools_backend: Samtools backend choice for alignment parsing.
+        converted_FASTA_record_seq_map: Mapping of record name to its (length, sequence) pair.
+        progress_queue: Queue used to signal completion.
+        log_level: Logging level to configure in workers.
+        log_file: Optional log file path.
+
+    Processing Steps:
+        1. Skip processing if an output H5AD already exists.
+        2. Filter records to those present in the FASTA metadata.
+        3. Run per-record processing and write AnnData output.
+        4. Signal completion via the progress queue.
+    """
     _ensure_worker_logging(log_level, log_file)
     worker_id = current_process().pid  # Get worker process ID
     sample = bam.stem
@@ -485,6 +856,7 @@ def worker_function(
         device,
         deaminase_footprinting,
         samtools_backend,
+        converted_FASTA_record_seq_map,
     )
 
     if adata is not None:
@@ -505,19 +877,43 @@ def worker_function(
 
 
 def process_bams_parallel(
-    bam_path_list,
-    records_to_analyze,
-    record_FASTA_dict,
-    chromosome_FASTA_dict,
-    tmp_dir,
-    h5_dir,
-    num_threads,
-    max_reference_length,
-    device,
-    deaminase_footprinting,
-    samtools_backend,
-):
-    """Processes BAM files in parallel, writes each H5AD to disk, and concatenates them at the end."""
+    bam_path_list: list[Path],
+    records_to_analyze: set[str],
+    record_FASTA_dict: dict[str, RecordFastaInfo],
+    chromosome_FASTA_dict: dict[str, tuple[str, str]],
+    tmp_dir: Path,
+    h5_dir: Path,
+    num_threads: int,
+    max_reference_length: int,
+    device: torch.device,
+    deaminase_footprinting: bool,
+    samtools_backend: str | None,
+    converted_FASTA_record_seq_map: dict[str, tuple[int, str]],
+) -> ad.AnnData | None:
+    """Process BAM files in parallel and concatenate the resulting AnnData.
+
+    Args:
+        bam_path_list: List of BAM files to process.
+        records_to_analyze: FASTA record IDs that passed the mapping threshold.
+        record_FASTA_dict: FASTA metadata keyed by record ID.
+        chromosome_FASTA_dict: Chromosome sequences for annotations.
+        tmp_dir: Directory for temporary batch files.
+        h5_dir: Directory for per-BAM H5AD outputs.
+        num_threads: Number of worker processes.
+        max_reference_length: Maximum reference length for padding.
+        device: Torch device used for binarization.
+        deaminase_footprinting: Whether direct deamination chemistry was used.
+        samtools_backend: Samtools backend choice for alignment parsing.
+        converted_FASTA_record_seq_map: Mapping of converted record name to the converted reference length and sequence.
+
+    Returns:
+        anndata.AnnData | None: Concatenated AnnData or None if no H5ADs produced.
+
+    Processing Steps:
+        1. Spawn worker processes to handle each BAM.
+        2. Track completion via a multiprocessing queue.
+        3. Concatenate per-BAM H5AD files into a final AnnData.
+    """
     make_dirs(h5_dir)  # Ensure h5_dir exists
 
     logger.info(f"Starting parallel BAM processing with {num_threads} threads...")
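
The parallel driver pairs a `Pool` with a `Manager` queue so workers can signal completion independently of their return values (a manager queue proxy, unlike a bare `multiprocessing.Queue`, can be passed to pool workers). A stripped-down sketch of that pattern, not the smftools implementation itself:

    from multiprocessing import Manager, Pool

    def _work(item, queue):
        # ... process one item ...
        queue.put(item)  # signal completion

    if __name__ == "__main__":
        items = ["a.bam", "b.bam", "c.bam"]
        with Manager() as manager:
            queue = manager.Queue()
            with Pool(processes=2) as pool:
                results = [pool.apply_async(_work, (item, queue)) for item in items]
                for _ in items:
                    done = queue.get()   # blocks until a worker reports in
                    print(f"finished {done}")
                for r in results:
                    r.get()              # re-raise any worker exception
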
@@ -543,6 +939,7 @@ def process_bams_parallel(
                 device,
                 deaminase_footprinting,
                 samtools_backend,
+                converted_FASTA_record_seq_map,
                 progress_queue,
                 log_level,
                 log_file,
@@ -583,7 +980,16 @@ def process_bams_parallel(
 
 
 def _log_async_result_errors(results, bam_path_list):
-    """Log worker failures captured by multiprocessing AsyncResult objects."""
+    """Log worker failures captured by multiprocessing AsyncResult objects.
+
+    Args:
+        results: Iterable of AsyncResult objects from multiprocessing.
+        bam_path_list: List of BAM paths matching the async results.
+
+    Processing Steps:
+        1. Iterate over async results.
+        2. Retrieve results to surface worker exceptions.
+    """
     for bam, result in zip(bam_path_list, results):
         if not result.ready():
             continue
@@ -594,6 +1000,15 @@ def _log_async_result_errors(results, bam_path_list):
 
 
 def _get_logger_config() -> tuple[int, Path | None]:
+    """Return the active smftools logger level and optional file path.
+
+    Returns:
+        tuple[int, Path | None]: Log level and log file path (if configured).
+
+    Processing Steps:
+        1. Inspect the smftools logger for configured handlers.
+        2. Extract log level and file handler path.
+    """
     smftools_logger = logging.getLogger("smftools")
     level = smftools_logger.level
     if level == logging.NOTSET:
@@ -607,6 +1022,16 @@ def _get_logger_config() -> tuple[int, Path | None]:
 
 
 def _ensure_worker_logging(log_level: int, log_file: Path | None) -> None:
+    """Ensure worker processes have logging configured.
+
+    Args:
+        log_level: Logging level to configure.
+        log_file: Optional log file path.
+
+    Processing Steps:
+        1. Check if handlers are already configured.
+        2. Initialize logging with the provided level and file path.
+    """
    smftools_logger = logging.getLogger("smftools")
     if not smftools_logger.handlers:
         setup_logging(level=log_level, log_file=log_file)
@@ -619,21 +1044,17 @@ def delete_intermediate_h5ads_and_tmpdir(
     dry_run: bool = False,
     verbose: bool = True,
 ):
-    """
-    Delete intermediate .h5ad files and a temporary directory.
-
-    Parameters
-    ----------
-    h5_dir : str | Path | iterable[str] | None
-        If a directory path is given, all files directly inside it will be considered.
-        If an iterable of file paths is given, those files will be considered.
-        Only files ending with '.h5ad' (and not ending with '.gz') are removed.
-    tmp_dir : str | Path | None
-        Path to a directory to remove recursively (e.g. a temp dir created earlier).
-    dry_run : bool
-        If True, print what *would* be removed but do not actually delete.
-    verbose : bool
-        Print progress / warnings.
+    """Delete intermediate .h5ad files and a temporary directory.
+
+    Args:
+        h5_dir: Directory path or iterable of file paths to inspect for `.h5ad` files.
+        tmp_dir: Optional directory to remove recursively.
+        dry_run: If True, log what would be removed without deleting.
+        verbose: If True, log progress and warnings.
+
+    Processing Steps:
+        1. Remove `.h5ad` files (excluding `.gz`) from the provided directory or list.
+        2. Optionally remove the temporary directory tree.
     """
 
     # Helper: remove a single file path (Path-like or string)
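
The dry-run flag above follows a common guard pattern: compute the candidate set, report it, and only touch the filesystem when `dry_run` is false. A compact sketch of that pattern under the same keep-compressed-outputs rule (a hypothetical helper, not the smftools function):

    from pathlib import Path

    def remove_h5ads(h5_dir: Path, dry_run: bool = False) -> list[Path]:
        # Keep compressed outputs; only plain intermediate .h5ad files qualify.
        candidates = [p for p in h5_dir.iterdir()
                      if p.is_file() and p.name.endswith(".h5ad") and not p.name.endswith(".gz")]
        for path in candidates:
            if dry_run:
                print(f"would remove {path}")
            else:
                path.unlink()
        return candidates
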