smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +7 -1
  5. smftools/cli/hmm_adata.py +902 -244
  6. smftools/cli/load_adata.py +318 -198
  7. smftools/cli/preprocess_adata.py +285 -171
  8. smftools/cli/spatial_adata.py +137 -53
  9. smftools/cli_entry.py +94 -178
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +5 -1
  12. smftools/config/deaminase.yaml +1 -1
  13. smftools/config/default.yaml +22 -17
  14. smftools/config/direct.yaml +8 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +505 -276
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2125 -1426
  21. smftools/hmm/__init__.py +2 -3
  22. smftools/hmm/archived/call_hmm_peaks.py +16 -1
  23. smftools/hmm/call_hmm_peaks.py +173 -193
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +379 -156
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +195 -29
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +347 -168
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +145 -85
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +8 -8
  84. smftools/preprocessing/append_base_context.py +105 -79
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  86. smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +127 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +44 -22
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +103 -55
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +688 -271
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +93 -27
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +264 -109
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.4.dist-info/RECORD +0 -176
  128. /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
  129. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  130. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  131. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  132. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  133. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/fasta_functions.py
@@ -1,36 +1,57 @@
-from ..readwrite import make_dirs, time_string
+from __future__ import annotations

-import os
-import subprocess
+import gzip
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
-
-from typing import Union, List, Dict, Tuple
+from typing import Dict, Iterable, Tuple

 import numpy as np
-import gzip
-
+import pysam
 from Bio import SeqIO
-from Bio.SeqRecord import SeqRecord
 from Bio.Seq import Seq
+from Bio.SeqRecord import SeqRecord
 from pyfaidx import Fasta
-import pysam

-from concurrent.futures import ProcessPoolExecutor
-from itertools import chain
+from smftools.logging_utils import get_logger
+
+from ..readwrite import time_string
+
+logger = get_logger(__name__)
+
+
+def _convert_FASTA_record(
+    record: SeqRecord,
+    modification_type: str,
+    strand: str,
+    unconverted: str,
+) -> SeqRecord:
+    """Convert a FASTA record based on modification type and strand.
+
+    Args:
+        record: Input FASTA record.
+        modification_type: Modification type (e.g., ``5mC`` or ``6mA``).
+        strand: Strand label (``top`` or ``bottom``).
+        unconverted: Label for the unconverted record type.
+
+    Returns:
+        Bio.SeqRecord.SeqRecord: Converted FASTA record.

-def _convert_FASTA_record(record, modification_type, strand, unconverted):
-    """ Converts a FASTA record based on modification type and strand. """
+    Raises:
+        ValueError: If the modification type/strand combination is invalid.
+    """
     conversion_maps = {
-        ('5mC', 'top'): ('C', 'T'),
-        ('5mC', 'bottom'): ('G', 'A'),
-        ('6mA', 'top'): ('A', 'G'),
-        ('6mA', 'bottom'): ('T', 'C')
+        ("5mC", "top"): ("C", "T"),
+        ("5mC", "bottom"): ("G", "A"),
+        ("6mA", "top"): ("A", "G"),
+        ("6mA", "bottom"): ("T", "C"),
     }

     sequence = str(record.seq).upper()

     if modification_type == unconverted:
-        return SeqRecord(Seq(sequence), id=f"{record.id}_{modification_type}_top", description=record.description)
+        return SeqRecord(
+            Seq(sequence), id=f"{record.id}_{modification_type}_top", description=record.description
+        )

     if (modification_type, strand) not in conversion_maps:
         raise ValueError(f"Invalid combination: {modification_type}, {strand}")
@@ -38,62 +59,80 @@ def _convert_FASTA_record(record, modification_type, strand, unconverted):
     original_base, converted_base = conversion_maps[(modification_type, strand)]
     new_seq = sequence.replace(original_base, converted_base)

-    return SeqRecord(Seq(new_seq), id=f"{record.id}_{modification_type}_{strand}", description=record.description)
+    return SeqRecord(
+        Seq(new_seq), id=f"{record.id}_{modification_type}_{strand}", description=record.description
+    )
+
+
+def _process_fasta_record(
+    args: tuple[SeqRecord, Iterable[str], Iterable[str], str],
+) -> list[SeqRecord]:
+    """Process a single FASTA record for parallel conversion.

-def _process_fasta_record(args):
-    """
-    Processes a single FASTA record for parallel execution.
     Args:
-        args (tuple): (record, modification_types, strands, unconverted)
+        args: Tuple containing ``(record, modification_types, strands, unconverted)``.
+
     Returns:
-        list of modified SeqRecord objects.
+        list[Bio.SeqRecord.SeqRecord]: Converted FASTA records.
     """
     record, modification_types, strands, unconverted = args
     modified_records = []
-
+
     for modification_type in modification_types:
         for i, strand in enumerate(strands):
             if i > 0 and modification_type == unconverted:
                 continue  # Ensure unconverted is added only once

-            modified_records.append(_convert_FASTA_record(record, modification_type, strand, unconverted))
+            modified_records.append(
+                _convert_FASTA_record(record, modification_type, strand, unconverted)
+            )

     return modified_records

-def generate_converted_FASTA(input_fasta, modification_types, strands, output_fasta, num_threads=4, chunk_size=500):
-    """
-    Converts an input FASTA file and writes a new converted FASTA file efficiently.

-    Parameters:
-        input_fasta (str): Path to the unconverted FASTA file.
-        modification_types (list): List of modification types ('5mC', '6mA', or unconverted).
-        strands (list): List of strands ('top', 'bottom').
-        output_fasta (str): Path to the converted FASTA output file.
-        num_threads (int): Number of parallel threads to use.
-        chunk_size (int): Number of records to process per write batch.
+def generate_converted_FASTA(
+    input_fasta: str | Path,
+    modification_types: list[str],
+    strands: list[str],
+    output_fasta: str | Path,
+    num_threads: int = 4,
+    chunk_size: int = 500,
+) -> None:
+    """Convert a FASTA file and write converted records to disk.

-    Returns:
-        None (Writes the converted FASTA file).
+    Args:
+        input_fasta: Path to the unconverted FASTA file.
+        modification_types: List of modification types (``5mC``, ``6mA``, or unconverted).
+        strands: List of strands (``top``, ``bottom``).
+        output_fasta: Path to the converted FASTA output file.
+        num_threads: Number of parallel workers to use.
+        chunk_size: Number of records to process per write batch.
     """
     unconverted = modification_types[0]
     input_fasta = str(input_fasta)
     output_fasta = str(output_fasta)

     # Detect if input is gzipped
-    open_func = gzip.open if input_fasta.endswith('.gz') else open
-    file_mode = 'rt' if input_fasta.endswith('.gz') else 'r'
+    open_func = gzip.open if input_fasta.endswith(".gz") else open
+    file_mode = "rt" if input_fasta.endswith(".gz") else "r"

     def _fasta_record_generator():
-        """ Lazily yields FASTA records from file. """
+        """Lazily yields FASTA records from file."""
         with open_func(input_fasta, file_mode) as handle:
-            for record in SeqIO.parse(handle, 'fasta'):
+            for record in SeqIO.parse(handle, "fasta"):
                 yield record

-    with open(output_fasta, 'w') as output_handle, ProcessPoolExecutor(max_workers=num_threads) as executor:
+    with (
+        open(output_fasta, "w") as output_handle,
+        ProcessPoolExecutor(max_workers=num_threads) as executor,
+    ):
         # Process records in parallel using a named function (avoiding lambda)
         results = executor.map(
             _process_fasta_record,
-            ((record, modification_types, strands, unconverted) for record in _fasta_record_generator())
+            (
+                (record, modification_types, strands, unconverted)
+                for record in _fasta_record_generator()
+            ),
         )

         buffer = []
@@ -102,14 +141,24 @@ def generate_converted_FASTA(input_fasta, modification_types, strands, output_fa

             # Write out in chunks to save memory
             if len(buffer) >= chunk_size:
-                SeqIO.write(buffer, output_handle, 'fasta')
+                SeqIO.write(buffer, output_handle, "fasta")
                 buffer.clear()

         # Write any remaining records
         if buffer:
-            SeqIO.write(buffer, output_handle, 'fasta')
+            SeqIO.write(buffer, output_handle, "fasta")
+

 def index_fasta(fasta: str | Path, write_chrom_sizes: bool = True) -> Path:
+    """Index a FASTA file and optionally write chromosome sizes.
+
+    Args:
+        fasta: Path to the FASTA file.
+        write_chrom_sizes: Whether to write a ``.chrom.sizes`` file.
+
+    Returns:
+        Path: Path to the index file or chromosome sizes file.
+    """
     fasta = Path(fasta)
     pysam.faidx(str(fasta))  # creates <fasta>.fai

@@ -123,9 +172,15 @@ def index_fasta(fasta: str | Path, write_chrom_sizes: bool = True) -> Path:
         return chrom_sizes
     return fai

+
 def get_chromosome_lengths(fasta: str | Path) -> Path:
-    """
-    Create (or reuse) <fasta>.chrom.sizes, derived from the FASTA index.
+    """Create or reuse ``<fasta>.chrom.sizes`` derived from the FASTA index.
+
+    Args:
+        fasta: Path to the FASTA file.
+
+    Returns:
+        Path: Path to the chromosome sizes file.
     """
     fasta = Path(fasta)
     fai = fasta.with_suffix(fasta.suffix + ".fai")
@@ -133,7 +188,7 @@ def get_chromosome_lengths(fasta: str | Path) -> Path:
         index_fasta(fasta, write_chrom_sizes=True)  # will also create .chrom.sizes
     chrom_sizes = fasta.with_suffix(".chrom.sizes")
     if chrom_sizes.exists():
-        print(f"Using existing chrom length file: {chrom_sizes}")
+        logger.debug(f"Using existing chrom length file: {chrom_sizes}")
         return chrom_sizes

     # Build chrom.sizes from .fai
@@ -143,10 +198,15 @@ def get_chromosome_lengths(fasta: str | Path) -> Path:
             out.write(f"{chrom}\t{size}\n")
     return chrom_sizes

+
 def get_native_references(fasta_file: str | Path) -> Dict[str, Tuple[int, str]]:
-    """
-    Return {record_id: (length, sequence)} from a FASTA.
-    Direct methylation specific
+    """Return record lengths and sequences from a FASTA file.
+
+    Args:
+        fasta_file: Path to the FASTA file.
+
+    Returns:
+        dict[str, tuple[int, str]]: Mapping of record ID to ``(length, sequence)``.
     """
     fasta_file = Path(fasta_file)
     print(f"{time_string()}: Opening FASTA file {fasta_file}")
@@ -157,28 +217,35 @@ def get_native_references(fasta_file: str | Path) -> Dict[str, Tuple[int, str]]:
         record_dict[rec.id] = (len(seq), seq)
     return record_dict

-def find_conversion_sites(fasta_file, modification_type, conversions, deaminase_footprinting=False):
-    """
-    Finds genomic coordinates of modified bases (5mC or 6mA) in a reference FASTA file.
-
-    Parameters:
-        fasta_file (str): Path to the converted reference FASTA.
-        modification_type (str): Modification type ('5mC' or '6mA') or 'unconverted'.
-        conversions (list): List of conversion types. The first element is the unconverted record type.
-        deaminase_footprinting (bool): Whether the footprinting was done with a direct deamination chemistry.
-
-    Returns:
-        dict: Dictionary where keys are **both unconverted & converted record names**.
-              Values contain:
-              [sequence length, top strand coordinates, bottom strand coordinates, sequence, complement sequence].
+
+def find_conversion_sites(
+    fasta_file: str | Path,
+    modification_type: str,
+    conversions: list[str],
+    deaminase_footprinting: bool = False,
+) -> dict[str, list]:
+    """Find genomic coordinates of modified bases in a reference FASTA.
+
+    Args:
+        fasta_file: Path to the converted reference FASTA.
+        modification_type: Modification type (``5mC``, ``6mA``, or ``unconverted``).
+        conversions: List of conversion types (first entry is the unconverted record type).
+        deaminase_footprinting: Whether the footprinting used direct deamination chemistry.
+
+    Returns:
+        dict[str, list]: Mapping of record name to
+        ``[sequence length, top strand coordinates, bottom strand coordinates, sequence, complement]``.
+
+    Raises:
+        ValueError: If the modification type is invalid.
     """
     unconverted = conversions[0]
     record_dict = {}

     # Define base mapping based on modification type
     base_mappings = {
-        '5mC': ('C', 'G'),  # Cytosine and Guanine
-        '6mA': ('A', 'T')  # Adenine and Thymine
+        "5mC": ("C", "G"),  # Cytosine and Guanine
+        "6mA": ("A", "T"),  # Adenine and Thymine
     }

     # Read FASTA file and process records
@@ -200,22 +267,35 @@ def find_conversion_sites(fasta_file, modification_type, conversions, deaminase_
             top_strand_coordinates = np.where(seq_array == top_base)[0].tolist()
             bottom_strand_coordinates = np.where(seq_array == bottom_base)[0].tolist()

-            record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement]
+            record_dict[record.id] = [
+                sequence_length,
+                top_strand_coordinates,
+                bottom_strand_coordinates,
+                sequence,
+                complement,
+            ]

         else:
-            raise ValueError(f"Invalid modification_type: {modification_type}. Choose '5mC', '6mA', or 'unconverted'.")
+            raise ValueError(
+                f"Invalid modification_type: {modification_type}. Choose '5mC', '6mA', or 'unconverted'."
+            )

     return record_dict

+
 def subsample_fasta_from_bed(
     input_FASTA: str | Path,
     input_bed: str | Path,
     output_directory: str | Path,
-    output_FASTA: str | Path
+    output_FASTA: str | Path,
 ) -> None:
-    """
-    Take a genome-wide FASTA file and a BED file containing
-    coordinate windows of interest. Outputs a subsampled FASTA.
+    """Subsample a FASTA using BED coordinates.
+
+    Args:
+        input_FASTA: Genome-wide FASTA path.
+        input_bed: BED file path containing coordinate windows of interest.
+        output_directory: Directory to write the subsampled FASTA.
+        output_FASTA: Output FASTA path.
     """

     # Normalize everything to Path
@@ -227,22 +307,20 @@ def subsample_fasta_from_bed(
     # Ensure output directory exists
     output_directory.mkdir(parents=True, exist_ok=True)

-    output_FASTA_path = output_directory / output_FASTA
-
     # Load the FASTA file using pyfaidx
-    fasta = Fasta(str(input_FASTA)) # pyfaidx requires string paths
+    fasta = Fasta(str(input_FASTA))  # pyfaidx requires string paths

     # Open BED + output FASTA
-    with input_bed.open("r") as bed, output_FASTA_path.open("w") as out_fasta:
+    with input_bed.open("r") as bed, output_FASTA.open("w") as out_fasta:
         for line in bed:
             fields = line.strip().split()
             chrom = fields[0]
-            start = int(fields[1]) # BED is 0-based
-            end = int(fields[2]) # BED is 0-based and end is exclusive
-            desc = " ".join(fields[3:]) if len(fields) > 3 else ""
+            start = int(fields[1])  # BED is 0-based
+            end = int(fields[2])  # BED is 0-based and end is exclusive
+            desc = " ".join(fields[3:]) if len(fields) > 3 else ""

             if chrom not in fasta:
-                print(f"Warning: {chrom} not found in FASTA")
+                logger.warning(f"{chrom} not found in FASTA")
                 continue

             # pyfaidx is 1-based indexing internally, but [start:end] works with BED coords
@@ -252,4 +330,4 @@
             if desc:
                 header += f" {desc}"

-            out_fasta.write(f"{header}\n{sequence}\n")
+            out_fasta.write(f"{header}\n{sequence}\n")
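The hunks above cover smftools/informatics/fasta_functions.py. For orientation, here is a minimal usage sketch of the helpers whose new signatures appear in this diff, assuming they are imported directly from that module; the file paths and argument values are hypothetical, not taken from the package.

```python
from pathlib import Path

from smftools.informatics.fasta_functions import (
    find_conversion_sites,
    generate_converted_FASTA,
    get_chromosome_lengths,
    index_fasta,
)

# Hypothetical inputs; substitute your own reference and conversion labels.
reference = Path("reference.fa")
converted = Path("reference_converted.fa")

# Write a converted reference for 5mC on both strands; the first entry of
# modification_types is treated as the unconverted record type.
generate_converted_FASTA(
    input_fasta=reference,
    modification_types=["unconverted", "5mC"],
    strands=["top", "bottom"],
    output_fasta=converted,
    num_threads=4,
    chunk_size=500,
)

# Index the converted FASTA and derive a .chrom.sizes file next to it.
index_fasta(converted, write_chrom_sizes=True)
chrom_sizes = get_chromosome_lengths(converted)

# Per-record conversion-site coordinates keyed by record name:
# [length, top-strand coords, bottom-strand coords, sequence, complement].
sites = find_conversion_sites(
    fasta_file=converted,
    modification_type="5mC",
    conversions=["unconverted", "5mC"],
)
```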
smftools/informatics/h5ad_functions.py
@@ -1,8 +1,18 @@
+import glob
+import os
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from pathlib import Path
-import pandas as pd
+from typing import Dict, List, Optional, Union
+
 import numpy as np
+import pandas as pd
 import scipy.sparse as sp
-from typing import Optional, List, Dict, Union
+from pod5 import Reader
+
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+

 def add_demux_type_annotation(
     adata,
@@ -71,14 +81,15 @@ add_demux_type_annotation(

     return adata

+
 def add_read_length_and_mapping_qc(
     adata,
     bam_files: Optional[List[str]] = None,
     read_metrics: Optional[Dict[str, Union[list, tuple]]] = None,
     uns_flag: str = "add_read_length_and_mapping_qc_performed",
-    extract_read_features_from_bam_callable = None,
+    extract_read_features_from_bam_callable=None,
     bypass: bool = False,
-    force_redo: bool = True
+    force_redo: bool = True,
 ):
     """
     Populate adata.obs with read/mapping QC columns.
@@ -98,6 +109,7 @@
         Optional callable(bam_path) -> dict mapping read_name -> list/tuple of metrics.
         If not provided and bam_files is given, function will attempt to call `extract_read_features_from_bam`
         from the global namespace (your existing helper).
+
     Returns
     -------
     None (mutates final_adata in-place)
@@ -113,9 +125,13 @@
     if read_metrics is None:
         read_metrics = {}
         if bam_files:
-            extractor = extract_read_features_from_bam_callable or globals().get("extract_read_features_from_bam")
+            extractor = extract_read_features_from_bam_callable or globals().get(
+                "extract_read_features_from_bam"
+            )
             if extractor is None:
-                raise ValueError("No `read_metrics` provided and `extract_read_features_from_bam` not found.")
+                raise ValueError(
+                    "No `read_metrics` provided and `extract_read_features_from_bam` not found."
+                )
             for bam in bam_files:
                 bam_read_metrics = extractor(bam)
                 if not isinstance(bam_read_metrics, dict):
@@ -130,11 +146,11 @@
     if len(read_metrics) == 0:
         # fill with NaNs
         n = adata.n_obs
-        adata.obs['read_length'] = np.full(n, np.nan)
-        adata.obs['mapped_length'] = np.full(n, np.nan)
-        adata.obs['reference_length'] = np.full(n, np.nan)
-        adata.obs['read_quality'] = np.full(n, np.nan)
-        adata.obs['mapping_quality'] = np.full(n, np.nan)
+        adata.obs["read_length"] = np.full(n, np.nan)
+        adata.obs["mapped_length"] = np.full(n, np.nan)
+        adata.obs["reference_length"] = np.full(n, np.nan)
+        adata.obs["read_quality"] = np.full(n, np.nan)
+        adata.obs["mapping_quality"] = np.full(n, np.nan)
     else:
         # Build DF robustly
         # Convert values to lists where possible, else to [val, val, val...]
@@ -151,35 +167,45 @@
             vals = vals + [np.nan] * (max_cols - len(vals))
             rows[k] = vals[:max_cols]

-        df = pd.DataFrame.from_dict(rows, orient='index', columns=[
-            'read_length', 'read_quality', 'reference_length', 'mapped_length', 'mapping_quality'
-        ])
+        df = pd.DataFrame.from_dict(
+            rows,
+            orient="index",
+            columns=[
+                "read_length",
+                "read_quality",
+                "reference_length",
+                "mapped_length",
+                "mapping_quality",
+            ],
+        )

         # Reindex to final_adata.obs_names so order matches adata
         # If obs_names are not present as keys in df, the results will be NaN
         df_reindexed = df.reindex(adata.obs_names).astype(float)

-        adata.obs['read_length'] = df_reindexed['read_length'].values
-        adata.obs['mapped_length'] = df_reindexed['mapped_length'].values
-        adata.obs['reference_length'] = df_reindexed['reference_length'].values
-        adata.obs['read_quality'] = df_reindexed['read_quality'].values
-        adata.obs['mapping_quality'] = df_reindexed['mapping_quality'].values
+        adata.obs["read_length"] = df_reindexed["read_length"].values
+        adata.obs["mapped_length"] = df_reindexed["mapped_length"].values
+        adata.obs["reference_length"] = df_reindexed["reference_length"].values
+        adata.obs["read_quality"] = df_reindexed["read_quality"].values
+        adata.obs["mapping_quality"] = df_reindexed["mapping_quality"].values

     # Compute ratio columns safely (avoid divide-by-zero and preserve NaN)
     # read_length_to_reference_length_ratio
-    rl = pd.to_numeric(adata.obs['read_length'], errors='coerce').to_numpy(dtype=float)
-    ref_len = pd.to_numeric(adata.obs['reference_length'], errors='coerce').to_numpy(dtype=float)
-    mapped_len = pd.to_numeric(adata.obs['mapped_length'], errors='coerce').to_numpy(dtype=float)
+    rl = pd.to_numeric(adata.obs["read_length"], errors="coerce").to_numpy(dtype=float)
+    ref_len = pd.to_numeric(adata.obs["reference_length"], errors="coerce").to_numpy(dtype=float)
+    mapped_len = pd.to_numeric(adata.obs["mapped_length"], errors="coerce").to_numpy(dtype=float)

     # safe divisions: use np.where to avoid warnings and replace inf with nan
-    with np.errstate(divide='ignore', invalid='ignore'):
+    with np.errstate(divide="ignore", invalid="ignore"):
         rl_to_ref = np.where((ref_len != 0) & np.isfinite(ref_len), rl / ref_len, np.nan)
-        mapped_to_ref = np.where((ref_len != 0) & np.isfinite(ref_len), mapped_len / ref_len, np.nan)
+        mapped_to_ref = np.where(
+            (ref_len != 0) & np.isfinite(ref_len), mapped_len / ref_len, np.nan
+        )
         mapped_to_read = np.where((rl != 0) & np.isfinite(rl), mapped_len / rl, np.nan)

-    adata.obs['read_length_to_reference_length_ratio'] = rl_to_ref
-    adata.obs['mapped_length_to_reference_length_ratio'] = mapped_to_ref
-    adata.obs['mapped_length_to_read_length_ratio'] = mapped_to_read
+    adata.obs["read_length_to_reference_length_ratio"] = rl_to_ref
+    adata.obs["mapped_length_to_reference_length_ratio"] = mapped_to_ref
+    adata.obs["mapped_length_to_read_length_ratio"] = mapped_to_read

     # Add read level raw modification signal: sum over X rows
     X = adata.X
@@ -189,9 +215,149 @@
     else:
         raw_sig = np.asarray(X.sum(axis=1)).ravel()

-    adata.obs['Raw_modification_signal'] = raw_sig
+    adata.obs["Raw_modification_signal"] = raw_sig

     # mark as done
     adata.uns[uns_flag] = True

-    return None
+    return None
+
+
+def _collect_read_origins_from_pod5(pod5_path: str, target_ids: set[str]) -> dict[str, str]:
+    """
+    Worker function: scan one POD5 file and return a mapping
+    {read_id: pod5_basename} only for read_ids in `target_ids`.
+    """
+    basename = os.path.basename(pod5_path)
+    mapping: dict[str, str] = {}
+
+    with Reader(pod5_path) as reader:
+        for read in reader.reads():
+            # Cast read id to string
+            rid = str(read.read_id)
+            if rid in target_ids:
+                mapping[rid] = basename
+
+    return mapping
+
+
+def annotate_pod5_origin(
+    adata,
+    pod5_path_or_dir: str | Path,
+    pattern: str = "*.pod5",
+    n_jobs: int | None = None,
+    fill_value: str | None = "unknown",
+    verbose: bool = True,
+    csv_path: str | None = None,
+):
+    """
+    Add `pod5_origin` column to `adata.obs`, containing the POD5 basename
+    each read came from.
+
+    Parameters
+    ----------
+    adata
+        AnnData with obs_names == read_ids (as strings).
+    pod5_path_or_dir
+        Directory containing POD5 files or path to a single POD5 file.
+    pattern
+        Glob pattern for POD5 files inside `pod5_dir`.
+    n_jobs
+        Number of worker processes. If None or <=1, runs serially.
+    fill_value
+        Value to use when a read_id is not found in any POD5 file.
+        If None, leaves missing as NaN.
+    verbose
+        Print progress info.
+    csv_path
+        Path to a csv of the read to pod5 origin mapping
+
+    Returns
+    -------
+    None (modifies `adata` in-place).
+    """
+    pod5_path_or_dir = Path(pod5_path_or_dir)
+
+    # --- Resolve input into a list of pod5 files ---
+    if pod5_path_or_dir.is_dir():
+        pod5_files = sorted(str(p) for p in pod5_path_or_dir.glob(pattern))
+        if not pod5_files:
+            raise FileNotFoundError(
+                f"No POD5 files matching {pattern!r} in {str(pod5_path_or_dir)!r}"
+            )
+    elif pod5_path_or_dir.is_file():
+        if pod5_path_or_dir.suffix.lower() != ".pod5":
+            raise ValueError(f"Expected a .pod5 file, got: {pod5_path_or_dir}")
+        pod5_files = [str(pod5_path_or_dir)]
+    else:
+        raise FileNotFoundError(f"Path does not exist: {pod5_path_or_dir}")
+
+    # Make sure obs_names are strings
+    obs_names = adata.obs_names.astype(str)
+    target_ids = set(obs_names)  # only these are interesting
+
+    if verbose:
+        logger.info(f"Found {len(pod5_files)} POD5 files.")
+        logger.info(f"Tracking {len(target_ids)} read IDs from AnnData.")
+
+    # --- Collect mappings (possibly multiprocessed) ---
+    global_mapping: dict[str, str] = {}
+
+    if n_jobs is None or n_jobs <= 1:
+        # Serial version (less overhead, useful for debugging)
+        if verbose:
+            logger.debug("Running in SERIAL mode.")
+        for f in pod5_files:
+            if verbose:
+                logger.debug(f"  Scanning {os.path.basename(f)} ...")
+            part = _collect_read_origins_from_pod5(f, target_ids)
+            global_mapping.update(part)
+    else:
+        if verbose:
+            logger.debug(f"Running in PARALLEL mode with {n_jobs} workers.")
+        with ProcessPoolExecutor(max_workers=n_jobs) as ex:
+            futures = {
+                ex.submit(_collect_read_origins_from_pod5, f, target_ids): f for f in pod5_files
+            }
+            for fut in as_completed(futures):
+                f = futures[fut]
+                try:
+                    part = fut.result()
+                except Exception as e:
+                    logger.warning(f"Error while processing {f}: {e}")
+                    continue
+                global_mapping.update(part)
+                if verbose:
+                    logger.info(f"  Finished {os.path.basename(f)} ({len(part)} matching reads)")

+    if verbose:
+        logger.info(f"Total reads matched: {len(global_mapping)}")
+
+    # --- Populate obs['pod5_origin'] in AnnData order, memory-efficiently ---
+    origin = np.empty(adata.n_obs, dtype=object)
+    default = None if fill_value is None else fill_value
+    for i, rid in enumerate(obs_names):
+        origin[i] = global_mapping.get(rid, default)
+
+    adata.obs["pod5_origin"] = origin
+    if verbose:
+        logger.info("Assigned `pod5_origin` to adata.obs.")
+
+    # --- Optionally write a CSV ---
+    if csv_path is not None:
+        if verbose:
+            logger.info(f"Writing CSV mapping to: {csv_path}")
+
+        # Create DataFrame in AnnData order for easier cross-referencing
+        df = pd.DataFrame(
+            {
+                "read_id": obs_names,
+                "pod5_origin": origin,
+            }
+        )
+        df.to_csv(csv_path, index=False)
+
+        if verbose:
+            logger.info("CSV saved.")
+
+    return global_mapping
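The new annotate_pod5_origin helper added above tags each read in an AnnData object with the basename of the POD5 file it came from, and optionally writes the read-to-POD5 mapping to CSV. A minimal usage sketch, assuming an existing .h5ad whose obs_names are read IDs and a directory of POD5 files; the paths and parameter values below are hypothetical.

```python
import anndata as ad

from smftools.informatics.h5ad_functions import annotate_pod5_origin

# Hypothetical inputs; substitute your own dataset and POD5 directory.
adata = ad.read_h5ad("experiment.h5ad")

mapping = annotate_pod5_origin(
    adata,
    pod5_path_or_dir="pod5_pass/",  # directory of .pod5 files or a single file
    pattern="*.pod5",
    n_jobs=4,                        # scan POD5 files in parallel worker processes
    fill_value="unknown",            # label for reads not found in any POD5
    csv_path="read_to_pod5.csv",     # optional read_id -> pod5_origin table
)

# Per-read origin is now available alongside the other obs annotations.
print(adata.obs["pod5_origin"].value_counts())
```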