smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. smftools/_version.py +1 -1
  2. smftools/cli/helpers.py +32 -6
  3. smftools/cli/hmm_adata.py +232 -31
  4. smftools/cli/latent_adata.py +318 -0
  5. smftools/cli/load_adata.py +77 -73
  6. smftools/cli/preprocess_adata.py +178 -53
  7. smftools/cli/spatial_adata.py +149 -101
  8. smftools/cli_entry.py +12 -0
  9. smftools/config/conversion.yaml +11 -1
  10. smftools/config/default.yaml +38 -1
  11. smftools/config/experiment_config.py +53 -1
  12. smftools/constants.py +65 -0
  13. smftools/hmm/HMM.py +88 -0
  14. smftools/informatics/__init__.py +6 -0
  15. smftools/informatics/bam_functions.py +358 -8
  16. smftools/informatics/converted_BAM_to_adata.py +584 -163
  17. smftools/informatics/h5ad_functions.py +115 -2
  18. smftools/informatics/modkit_extract_to_adata.py +1003 -425
  19. smftools/informatics/sequence_encoding.py +72 -0
  20. smftools/logging_utils.py +21 -2
  21. smftools/metadata.py +1 -1
  22. smftools/plotting/__init__.py +9 -0
  23. smftools/plotting/general_plotting.py +2411 -628
  24. smftools/plotting/hmm_plotting.py +85 -7
  25. smftools/preprocessing/__init__.py +1 -0
  26. smftools/preprocessing/append_base_context.py +17 -17
  27. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  28. smftools/preprocessing/calculate_consensus.py +1 -1
  29. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  30. smftools/readwrite.py +53 -17
  31. smftools/schema/anndata_schema_v1.yaml +15 -1
  32. smftools/tools/__init__.py +4 -0
  33. smftools/tools/calculate_leiden.py +57 -0
  34. smftools/tools/calculate_nmf.py +119 -0
  35. smftools/tools/calculate_umap.py +91 -8
  36. smftools/tools/rolling_nn_distance.py +235 -0
  37. smftools/tools/tensor_factorization.py +169 -0
  38. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
  39. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
  40. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  41. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  42. {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,13 +4,48 @@ import concurrent.futures
  import gc
  import re
  import shutil
+ from dataclasses import dataclass, field
  from pathlib import Path
- from typing import Iterable, Optional, Union
+ from typing import Iterable, Mapping, Optional, Union

  import numpy as np
  import pandas as pd
  from tqdm import tqdm

+ from smftools.constants import (
+     BARCODE,
+     BASE_QUALITY_SCORES,
+     DATASET,
+     DEMUX_TYPE,
+     H5_DIR,
+     MISMATCH_INTEGER_ENCODING,
+     MODKIT_EXTRACT_CALL_CODE_CANONICAL,
+     MODKIT_EXTRACT_CALL_CODE_MODIFIED,
+     MODKIT_EXTRACT_MODIFIED_BASE_A,
+     MODKIT_EXTRACT_MODIFIED_BASE_C,
+     MODKIT_EXTRACT_REF_STRAND_MINUS,
+     MODKIT_EXTRACT_REF_STRAND_PLUS,
+     MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+     MODKIT_EXTRACT_SEQUENCE_BASES,
+     MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+     MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+     MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE,
+     MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB,
+     MODKIT_EXTRACT_TSV_COLUMN_CHROM,
+     MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE,
+     MODKIT_EXTRACT_TSV_COLUMN_READ_ID,
+     MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION,
+     MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND,
+     READ_MAPPING_DIRECTION,
+     READ_SPAN_MASK,
+     REFERENCE,
+     REFERENCE_DATASET_STRAND,
+     REFERENCE_STRAND,
+     SAMPLE,
+     SEQUENCE_INTEGER_DECODING,
+     SEQUENCE_INTEGER_ENCODING,
+     STRAND,
+ )
  from smftools.logging_utils import get_logger

  from .bam_functions import count_aligned_reads
@@ -18,8 +53,78 @@ from .bam_functions import count_aligned_reads
  logger = get_logger(__name__)


+ @dataclass
+ class ModkitBatchDictionaries:
+     """Container for per-batch modification dictionaries.
+
+     Attributes:
+         dict_total: Raw TSV DataFrames keyed by record and sample index.
+         dict_a: Adenine modification DataFrames.
+         dict_a_bottom: Adenine minus-strand DataFrames.
+         dict_a_top: Adenine plus-strand DataFrames.
+         dict_c: Cytosine modification DataFrames.
+         dict_c_bottom: Cytosine minus-strand DataFrames.
+         dict_c_top: Cytosine plus-strand DataFrames.
+         dict_combined_bottom: Combined minus-strand methylation arrays.
+         dict_combined_top: Combined plus-strand methylation arrays.
+     """
+
+     dict_total: dict = field(default_factory=dict)
+     dict_a: dict = field(default_factory=dict)
+     dict_a_bottom: dict = field(default_factory=dict)
+     dict_a_top: dict = field(default_factory=dict)
+     dict_c: dict = field(default_factory=dict)
+     dict_c_bottom: dict = field(default_factory=dict)
+     dict_c_top: dict = field(default_factory=dict)
+     dict_combined_bottom: dict = field(default_factory=dict)
+     dict_combined_top: dict = field(default_factory=dict)
+
+     @property
+     def sample_types(self) -> list[str]:
+         """Return ordered labels for the dictionary list."""
+         return [
+             "total",
+             "m6A",
+             "m6A_bottom_strand",
+             "m6A_top_strand",
+             "5mC",
+             "5mC_bottom_strand",
+             "5mC_top_strand",
+             "combined_bottom_strand",
+             "combined_top_strand",
+         ]
+
+     def as_list(self) -> list[dict]:
+         """Return the dictionaries in the expected list ordering."""
+         return [
+             self.dict_total,
+             self.dict_a,
+             self.dict_a_bottom,
+             self.dict_a_top,
+             self.dict_c,
+             self.dict_c_bottom,
+             self.dict_c_top,
+             self.dict_combined_bottom,
+             self.dict_combined_top,
+         ]
+
+
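The ordering contract is the real interface of this container: `sample_types` and `as_list()` are parallel, so index 1 ("m6A") always pairs with `dict_a`, index 5 ("5mC_bottom_strand") with `dict_c_bottom`, and so on. A minimal usage sketch (illustrative, not code from the package):

    from smftools.informatics.modkit_extract_to_adata import ModkitBatchDictionaries

    batch_dicts = ModkitBatchDictionaries()
    # Walk every dictionary together with its label; all nine are empty on a fresh instance.
    for index, (label, mod_dict) in enumerate(zip(batch_dicts.sample_types, batch_dicts.as_list())):
        print(index, label, len(mod_dict))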
  def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
-     """Processes a single BAM file, counts reads, and determines records to analyze."""
+     """Identify reference records that exceed a mapping threshold in one BAM.
+
+     Args:
+         bam (Path | str): BAM file to inspect.
+         mapping_threshold (float): Minimum fraction of mapped reads required to keep a record.
+         samtools_backend (str | None): Samtools backend selection.
+
+     Returns:
+         set[str]: Record names that pass the mapping threshold.
+
+     Processing Steps:
+         1. Count aligned/unaligned reads per record.
+         2. Compute percent aligned and per-record mapping percentages.
+         3. Return records whose mapping fraction meets the threshold.
+     """
      aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
          bam, samtools_backend
      )
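The docstring above reduces to a simple per-record fraction test; a self-contained sketch of that thresholding (the real function delegates the counting to count_aligned_reads and may track the percentages slightly differently):

    def records_passing_threshold(record_counts: dict[str, int], mapping_threshold: float) -> set[str]:
        """Keep records whose share of mapped reads meets the threshold."""
        total_mapped = sum(record_counts.values())
        if total_mapped == 0:
            return set()
        return {
            record
            for record, count in record_counts.items()
            if count / total_mapped >= mapping_threshold
        }

    # With 80 of 100 mapped reads on "chrI", a 0.5 threshold keeps only "chrI".
    records_passing_threshold({"chrI": 80, "chrII": 20}, 0.5)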
@@ -40,7 +145,21 @@ def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "a


  def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
-     """Parallel processing for multiple BAM files."""
+     """Aggregate mapping-threshold records across BAM files in parallel.
+
+     Args:
+         bam_path_list (list[Path | str]): BAM files to scan.
+         mapping_threshold (float): Minimum fraction of mapped reads required to keep a record.
+         samtools_backend (str | None): Samtools backend selection.
+
+     Returns:
+         set[str]: Union of all record names passing the threshold in any BAM.
+
+     Processing Steps:
+         1. Spawn workers to compute passing records per BAM.
+         2. Merge all passing records into a single set.
+         3. Log the final record set.
+     """
      records_to_analyze = set()

      with concurrent.futures.ProcessPoolExecutor() as executor:
@@ -60,8 +179,21 @@ def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str
60
179
 
61
180
 
62
181
  def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
63
- """
64
- Loads and filters a single TSV file based on chromosome and position criteria.
182
+ """Load and filter a modkit TSV file for relevant records and positions.
183
+
184
+ Args:
185
+ tsv (Path | str): TSV file produced by modkit extract.
186
+ records_to_analyze (Iterable[str]): Record names to keep.
187
+ reference_dict (dict[str, tuple[int, str]]): Mapping of record to (length, sequence).
188
+ sample_index (int): Sample index to attach to the filtered results.
189
+
190
+ Returns:
191
+ dict[str, dict[int, pd.DataFrame]]: Filtered data keyed by record and sample index.
192
+
193
+ Processing Steps:
194
+ 1. Read the TSV into a DataFrame.
195
+ 2. Filter rows for each record to valid reference positions.
196
+ 3. Emit per-record DataFrames keyed by the provided sample index.
65
197
  """
66
198
  temp_df = pd.read_csv(tsv, sep="\t", header=0)
67
199
  filtered_records = {}
@@ -72,9 +204,9 @@ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
72
204
 
73
205
  ref_length = reference_dict[record][0]
74
206
  filtered_df = temp_df[
75
- (temp_df["chrom"] == record)
76
- & (temp_df["ref_position"] >= 0)
77
- & (temp_df["ref_position"] < ref_length)
207
+ (temp_df[MODKIT_EXTRACT_TSV_COLUMN_CHROM] == record)
208
+ & (temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION] >= 0)
209
+ & (temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION] < ref_length)
78
210
  ]
79
211
 
80
212
  if not filtered_df.empty:
@@ -84,19 +216,23 @@ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
84
216
 
85
217
 
86
218
  def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
87
- """
88
- Loads and filters TSV files in parallel.
219
+ """Load and filter a batch of TSVs in parallel.
89
220
 
90
- Parameters:
91
- tsv_batch (list): List of TSV file paths.
92
- records_to_analyze (list): Chromosome records to analyze.
93
- reference_dict (dict): Dictionary containing reference lengths.
94
- batch (int): Current batch number.
95
- batch_size (int): Total files in the batch.
96
- threads (int): Number of parallel workers.
221
+ Args:
222
+ tsv_batch (list[Path | str]): TSV file paths for the batch.
223
+ records_to_analyze (Iterable[str]): Record names to keep.
224
+ reference_dict (dict[str, tuple[int, str]]): Mapping of record to (length, sequence).
225
+ batch (int): Batch number for progress logging.
226
+ batch_size (int): Number of TSVs in the batch.
227
+ threads (int): Parallel worker count.
97
228
 
98
229
  Returns:
99
- dict: Processed `dict_total` dictionary.
230
+ dict[str, dict[int, pd.DataFrame]]: Per-record DataFrames keyed by sample index.
231
+
232
+ Processing Steps:
233
+ 1. Submit each TSV to a worker via `process_tsv`.
234
+ 2. Merge per-record outputs into a single dictionary.
235
+ 3. Return the aggregated per-record dictionary for the batch.
100
236
  """
101
237
  dict_total = {record: {} for record in records_to_analyze}
102
238
 
@@ -121,15 +257,19 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, bat
121
257
 
122
258
 
123
259
  def update_dict_to_skip(dict_to_skip, detected_modifications):
124
- """
125
- Updates the dict_to_skip set based on the detected modifications.
260
+ """Update dictionary skip indices based on modifications in the batch.
126
261
 
127
- Parameters:
128
- dict_to_skip (set): The initial set of dictionary indices to skip.
129
- detected_modifications (list or set): The modifications (e.g. ['6mA', '5mC']) present.
262
+ Args:
263
+ dict_to_skip (set[int]): Initial set of dictionary indices to skip.
264
+ detected_modifications (Iterable[str]): Modification labels present (e.g., ["6mA", "5mC"]).
130
265
 
131
266
  Returns:
132
- set: The updated dict_to_skip set.
267
+ set[int]: Updated skip set after considering present modifications.
268
+
269
+ Processing Steps:
270
+ 1. Define indices for A- and C-stranded dictionaries.
271
+ 2. Remove indices for modifications that are present.
272
+ 3. Return the updated skip set.
133
273
  """
134
274
  # Define which indices correspond to modification-specific or strand-specific dictionaries
135
275
  A_stranded_dicts = {2, 3} # m6A bottom and top strand dictionaries
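Using the index layout noted in the comment above (2-3 for the m6A strand dictionaries, 5-6 for the 5mC strand dictionaries, 7-8 for the combined ones), the skip-set update can be sketched as:

    A_STRANDED = {2, 3}
    C_STRANDED = {5, 6}
    COMBINED = {7, 8}

    def updated_skip_set(detected_modifications: set[str]) -> set[int]:
        # All mod/strand-specific indices start out skipped; 0, 1 and 4 always stay skipped.
        skip = {0, 1, 4} | A_STRANDED | C_STRANDED | COMBINED
        if "6mA" in detected_modifications:
            skip -= A_STRANDED
        if "5mC" in detected_modifications:
            skip -= C_STRANDED
        if {"6mA", "5mC"} <= detected_modifications:
            skip -= COMBINED
        return skip

    updated_skip_set({"6mA"})           # activates only the m6A strand dictionaries
    updated_skip_set({"6mA", "5mC"})    # also activates the 5mC and combined dictionaries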
@@ -150,31 +290,49 @@ def update_dict_to_skip(dict_to_skip, detected_modifications):
150
290
 
151
291
 
152
292
  def process_modifications_for_sample(args):
153
- """
154
- Processes a single (record, sample) pair to extract modification-specific data.
293
+ """Extract modification-specific subsets for one record/sample pair.
155
294
 
156
- Parameters:
157
- args: (record, sample_index, sample_df, mods, max_reference_length)
295
+ Args:
296
+ args (tuple): (record, sample_index, sample_df, mods, max_reference_length).
158
297
 
159
298
  Returns:
160
- (record, sample_index, result) where result is a dict with keys:
161
- 'm6A', 'm6A_minus', 'm6A_plus', '5mC', '5mC_minus', '5mC_plus', and
162
- optionally 'combined_minus' and 'combined_plus' (initialized as empty lists).
299
+ tuple[str, int, dict[str, pd.DataFrame | list]]:
300
+ Record, sample index, and a dict of modification-specific DataFrames
301
+ (with optional combined placeholders).
302
+
303
+ Processing Steps:
304
+ 1. Filter by modified base (A/C) when requested.
305
+ 2. Split filtered rows by strand where needed.
306
+ 3. Add empty combined placeholders when both modifications are present.
163
307
  """
164
308
  record, sample_index, sample_df, mods, max_reference_length = args
165
309
  result = {}
166
310
  if "6mA" in mods:
167
- m6a_df = sample_df[sample_df["modified_primary_base"] == "A"]
311
+ m6a_df = sample_df[
312
+ sample_df[MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE]
313
+ == MODKIT_EXTRACT_MODIFIED_BASE_A
314
+ ]
168
315
  result["m6A"] = m6a_df
169
- result["m6A_minus"] = m6a_df[m6a_df["ref_strand"] == "-"]
170
- result["m6A_plus"] = m6a_df[m6a_df["ref_strand"] == "+"]
316
+ result["m6A_minus"] = m6a_df[
317
+ m6a_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_MINUS
318
+ ]
319
+ result["m6A_plus"] = m6a_df[
320
+ m6a_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_PLUS
321
+ ]
171
322
  m6a_df = None
172
323
  gc.collect()
173
324
  if "5mC" in mods:
174
- m5c_df = sample_df[sample_df["modified_primary_base"] == "C"]
325
+ m5c_df = sample_df[
326
+ sample_df[MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE]
327
+ == MODKIT_EXTRACT_MODIFIED_BASE_C
328
+ ]
175
329
  result["5mC"] = m5c_df
176
- result["5mC_minus"] = m5c_df[m5c_df["ref_strand"] == "-"]
177
- result["5mC_plus"] = m5c_df[m5c_df["ref_strand"] == "+"]
330
+ result["5mC_minus"] = m5c_df[
331
+ m5c_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_MINUS
332
+ ]
333
+ result["5mC_plus"] = m5c_df[
334
+ m5c_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_PLUS
335
+ ]
178
336
  m5c_df = None
179
337
  gc.collect()
180
338
  if "6mA" in mods and "5mC" in mods:
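The extraction above is two pandas filters per modification; a hedged standalone equivalent using the literal column names from the old code (the new code reads the same names from smftools.constants):

    import pandas as pd

    def split_by_strand(sample_df: pd.DataFrame, primary_base: str) -> dict[str, pd.DataFrame]:
        """Rows for one modified primary base, split by reference strand."""
        mod_df = sample_df[sample_df["modified_primary_base"] == primary_base]
        return {
            "all": mod_df,
            "minus": mod_df[mod_df["ref_strand"] == "-"],
            "plus": mod_df[mod_df["ref_strand"] == "+"],
        }

    # split_by_strand(sample_df, "A") mirrors the m6A / m6A_minus / m6A_plus entries built above.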
@@ -184,11 +342,22 @@ def process_modifications_for_sample(args):
184
342
 
185
343
 
186
344
  def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
187
- """
188
- Processes each (record, sample) pair in dict_total in parallel to extract modification-specific data.
345
+ """Parallelize modification extraction across records and samples.
346
+
347
+ Args:
348
+ dict_total (dict[str, dict[int, pd.DataFrame]]): Raw TSV DataFrames per record/sample.
349
+ mods (list[str]): Modification labels to process.
350
+ max_reference_length (int): Maximum reference length in the dataset.
351
+ threads (int): Parallel worker count.
189
352
 
190
353
  Returns:
191
- processed_results: Dict keyed by record, with sub-dict keyed by sample index and the processed results.
354
+ dict[str, dict[int, dict[str, pd.DataFrame | list]]]: Processed results keyed by
355
+ record and sample index.
356
+
357
+ Processing Steps:
358
+ 1. Build a task list of (record, sample) pairs.
359
+ 2. Submit tasks to a process pool.
360
+ 3. Collect and store results in a nested dictionary.
192
361
  """
193
362
  tasks = []
194
363
  for record, sample_dict in dict_total.items():
@@ -208,11 +377,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threa
208
377
 
209
378
 
210
379
  def merge_modification_results(processed_results, mods):
211
- """
212
- Merges individual sample results into global dictionaries.
380
+ """Merge per-sample modification outputs into global dictionaries.
381
+
382
+ Args:
383
+ processed_results (dict[str, dict[int, dict]]): Output of parallel modification extraction.
384
+ mods (list[str]): Modification labels to include.
213
385
 
214
386
  Returns:
215
- A tuple: (m6A_dict, m6A_minus, m6A_plus, c5m_dict, c5m_minus, c5m_plus, combined_minus, combined_plus)
387
+ tuple[dict, dict, dict, dict, dict, dict, dict, dict]:
388
+ Global dictionaries for each modification/strand combination.
389
+
390
+ Processing Steps:
391
+ 1. Initialize empty output dictionaries per modification category.
392
+ 2. Populate each dictionary using the processed sample results.
393
+ 3. Return the ordered tuple for downstream processing.
216
394
  """
217
395
  m6A_dict = {}
218
396
  m6A_minus = {}
@@ -254,18 +432,18 @@ def merge_modification_results(processed_results, mods):
254
432
 
255
433
 
256
434
  def process_stranded_methylation(args):
257
- """
258
- Processes a single (dict_index, record, sample) task.
259
-
260
- For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
261
- For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
262
- NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
435
+ """Convert modification DataFrames into per-read methylation arrays.
263
436
 
264
- Parameters:
265
- args: (dict_index, record, sample, dict_list, max_reference_length)
437
+ Args:
438
+ args (tuple): (dict_index, record, sample, dict_list, max_reference_length).
266
439
 
267
440
  Returns:
268
- (dict_index, record, sample, processed_data)
441
+ tuple[int, str, int, dict[str, np.ndarray]]: Updated dictionary entries for the task.
442
+
443
+ Processing Steps:
444
+ 1. For combined dictionaries (indices 7/8), merge A- and C-strand arrays.
445
+ 2. For other dictionaries, compute methylation probabilities per read/position.
446
+ 3. Return per-read arrays keyed by read name.
269
447
  """
270
448
  dict_index, record, sample, dict_list, max_reference_length = args
271
449
  processed_data = {}
@@ -329,13 +507,15 @@ def process_stranded_methylation(args):
329
507
  temp_df = dict_list[dict_index][record][sample]
330
508
  processed_data = {}
331
509
  # Extract columns and convert probabilities to float (coercing errors)
332
- read_ids = temp_df["read_id"].values
333
- positions = temp_df["ref_position"].values
334
- call_codes = temp_df["call_code"].values
335
- probabilities = pd.to_numeric(temp_df["call_prob"].values, errors="coerce")
510
+ read_ids = temp_df[MODKIT_EXTRACT_TSV_COLUMN_READ_ID].values
511
+ positions = temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION].values
512
+ call_codes = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE].values
513
+ probabilities = pd.to_numeric(
514
+ temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB].values, errors="coerce"
515
+ )
336
516
 
337
- modified_codes = {"a", "h", "m"}
338
- canonical_codes = {"-"}
517
+ modified_codes = MODKIT_EXTRACT_CALL_CODE_MODIFIED
518
+ canonical_codes = MODKIT_EXTRACT_CALL_CODE_CANONICAL
339
519
 
340
520
  # Compute methylation probabilities (vectorized)
341
521
  methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
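The vectorized step that follows is only partially visible in this hunk; by the usual modkit convention, modified call codes keep the call probability and canonical codes get its complement. A sketch under that assumption:

    import numpy as np

    call_codes = np.array(["m", "-", "a", "?"], dtype=object)
    probabilities = np.array([0.95, 0.90, 0.80, 0.70])
    modified_codes = {"a", "h", "m"}
    canonical_codes = {"-"}

    methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
    is_modified = np.isin(call_codes, list(modified_codes))
    is_canonical = np.isin(call_codes, list(canonical_codes))
    methylation_prob[is_modified] = probabilities[is_modified]            # P(modified)
    methylation_prob[is_canonical] = 1.0 - probabilities[is_canonical]    # 1 - P(canonical)
    # -> [0.95, 0.1, 0.8, nan]; unrecognized codes stay NaN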
@@ -363,11 +543,21 @@ def process_stranded_methylation(args):
363
543
 
364
544
 
365
545
  def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
366
- """
367
- Processes all (dict_index, record, sample) tasks in dict_list (excluding indices in dict_to_skip) in parallel.
546
+ """Parallelize per-read methylation extraction over all dictionary entries.
547
+
548
+ Args:
549
+ dict_list (list[dict]): List of modification/strand dictionaries.
550
+ dict_to_skip (set[int]): Dictionary indices to exclude from processing.
551
+ max_reference_length (int): Maximum reference length for array sizing.
552
+ threads (int): Parallel worker count.
368
553
 
369
554
  Returns:
370
- Updated dict_list with processed (nested) dictionaries.
555
+ list[dict]: Updated dictionary list with per-read methylation arrays.
556
+
557
+ Processing Steps:
558
+ 1. Build tasks for every (dict_index, record, sample) to process.
559
+ 2. Execute tasks in a process pool.
560
+ 3. Replace DataFrames with per-read arrays in-place.
371
561
  """
372
562
  tasks = []
373
563
  for dict_index, current_dict in enumerate(dict_list):
@@ -393,21 +583,20 @@ def delete_intermediate_h5ads_and_tmpdir(
393
583
  dry_run: bool = False,
394
584
  verbose: bool = True,
395
585
  ):
396
- """
397
- Delete intermediate .h5ad files and a temporary directory.
398
-
399
- Parameters
400
- ----------
401
- h5_dir : str | Path | iterable[str] | None
402
- If a directory path is given, all files directly inside it will be considered.
403
- If an iterable of file paths is given, those files will be considered.
404
- Only files ending with '.h5ad' (and not ending with '.gz') are removed.
405
- tmp_dir : str | Path | None
406
- Path to a directory to remove recursively (e.g. a temp dir created earlier).
407
- dry_run : bool
408
- If True, print what *would* be removed but do not actually delete.
409
- verbose : bool
410
- Print progress / warnings.
586
+ """Delete intermediate .h5ad files and optionally a temporary directory.
587
+
588
+ Args:
589
+ h5_dir (str | Path | Iterable[str] | None): Directory or iterable of h5ad paths.
590
+ tmp_dir (str | Path | None): Temporary directory to remove recursively.
591
+ dry_run (bool): If True, log deletions without performing them.
592
+ verbose (bool): If True, log progress and warnings.
593
+
594
+ Returns:
595
+ None: This function performs deletions in-place.
596
+
597
+ Processing Steps:
598
+ 1. Iterate over .h5ad file candidates and delete them (if not dry-run).
599
+ 2. Remove the temporary directory tree if requested.
411
600
  """
412
601
 
413
602
  # Helper: remove a single file path (Path-like or string)
@@ -478,6 +667,455 @@ def delete_intermediate_h5ads_and_tmpdir(
478
667
  logger.warning(f"[error] failed to remove tmp dir {td}: {e}")
479
668
 
480
669
 
670
+ def _collect_input_paths(mod_tsv_dir: Path, bam_dir: Path) -> tuple[list[Path], list[Path]]:
671
+ """Collect sorted TSV and BAM paths for processing.
672
+
673
+ Args:
674
+ mod_tsv_dir (Path): Directory containing modkit extract TSVs.
675
+ bam_dir (Path): Directory containing aligned BAM files.
676
+
677
+ Returns:
678
+ tuple[list[Path], list[Path]]: Sorted TSV paths and BAM paths.
679
+
680
+ Processing Steps:
681
+ 1. Filter TSVs for extract outputs and exclude unclassified entries.
682
+ 2. Filter BAMs for aligned files and exclude indexes/unclassified entries.
683
+ 3. Sort both lists for deterministic processing.
684
+ """
685
+ tsvs = sorted(
686
+ p
687
+ for p in mod_tsv_dir.iterdir()
688
+ if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
689
+ )
690
+ bams = sorted(
691
+ p
692
+ for p in bam_dir.iterdir()
693
+ if p.is_file()
694
+ and p.suffix == ".bam"
695
+ and "unclassified" not in p.name
696
+ and ".bai" not in p.name
697
+ )
698
+ return tsvs, bams
699
+
700
+
701
+ def _build_sample_maps(bam_path_list: list[Path]) -> tuple[dict[int, str], dict[int, str]]:
702
+ """Build sample name and barcode maps from BAM filenames.
703
+
704
+ Args:
705
+ bam_path_list (list[Path]): Paths to BAM files in sample order.
706
+
707
+ Returns:
708
+ tuple[dict[int, str], dict[int, str]]: Maps of sample index to sample name and barcode.
709
+
710
+ Processing Steps:
711
+ 1. Parse the BAM stem for barcode suffixes.
712
+ 2. Build a standardized sample name with barcode suffix.
713
+ 3. Store mappings for downstream metadata annotations.
714
+ """
715
+ sample_name_map: dict[int, str] = {}
716
+ barcode_map: dict[int, str] = {}
717
+
718
+ for idx, bam_path in enumerate(bam_path_list):
719
+ stem = bam_path.stem
720
+ m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
721
+ if m:
722
+ sample_name = m.group(1) or stem
723
+ barcode = m.group(2)
724
+ else:
725
+ sample_name = stem
726
+ barcode = stem
727
+
728
+ sample_name = f"{sample_name}_{barcode}"
729
+ barcode_id = int(barcode.split("barcode")[1])
730
+
731
+ sample_name_map[idx] = sample_name
732
+ barcode_map[idx] = str(barcode_id)
733
+
734
+ return sample_name_map, barcode_map
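
What the regex in _build_sample_maps yields for a few hypothetical BAM stems:

    import re

    for stem in ["mySample_barcode01", "run1-s1_barcode05", "barcode07"]:
        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
        sample_name, barcode = (m.group(1) or stem, m.group(2)) if m else (stem, stem)
        print(f"{sample_name}_{barcode}", int(barcode.split("barcode")[1]))
    # -> mySample_barcode01 1, run1-s1_barcode05 5, barcode07_barcode07 7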
735
+
736
+
737
+ def _encode_sequence_array(
738
+ read_sequence: np.ndarray,
739
+ valid_length: int,
740
+ base_to_int: Mapping[str, int],
741
+ padding_value: int,
742
+ ) -> np.ndarray:
743
+ """Convert a base-identity array into integer encoding with padding.
744
+
745
+ Args:
746
+ read_sequence (np.ndarray): Array of base calls (dtype "<U1").
747
+ valid_length (int): Number of valid reference positions for this record.
748
+ base_to_int (Mapping[str, int]): Base-to-integer mapping for A/C/G/T/N/PAD.
749
+ padding_value (int): Integer value to use for padding.
750
+
751
+ Returns:
752
+ np.ndarray: Integer-encoded sequence with padding applied.
753
+
754
+ Processing Steps:
755
+ 1. Initialize an integer array filled with the N value.
756
+ 2. Overwrite values for known bases (A/C/G/T/N).
757
+ 3. Replace positions beyond valid_length with padding.
758
+ """
759
+ read_sequence = np.asarray(read_sequence, dtype="<U1")
760
+ encoded = np.full(read_sequence.shape, base_to_int["N"], dtype=np.int16)
761
+ for base in MODKIT_EXTRACT_SEQUENCE_BASES:
762
+ encoded[read_sequence == base] = base_to_int[base]
763
+ if valid_length < encoded.size:
764
+ encoded[valid_length:] = padding_value
765
+ return encoded
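
With a hypothetical base-to-integer map standing in for MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT (whose actual values live in smftools.constants), the encoder behaves like:

    import numpy as np

    base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}   # assumed values

    read_sequence = np.array(["A", "C", "x", "T", "G"], dtype="<U1")
    encoded = np.full(read_sequence.shape, base_to_int["N"], dtype=np.int16)
    for base in "ACGTN":
        encoded[read_sequence == base] = base_to_int[base]
    valid_length = 4
    encoded[valid_length:] = base_to_int["PAD"]   # positions past the record length become padding
    # -> [0, 1, 4, 3, 5]: the unknown "x" falls back to N, the tail is padded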
766
+
767
+
768
+ def _write_sequence_batches(
769
+ base_identities: Mapping[str, np.ndarray],
770
+ tmp_dir: Path,
771
+ record: str,
772
+ prefix: str,
773
+ base_to_int: Mapping[str, int],
774
+ valid_length: int,
775
+ batch_size: int,
776
+ ) -> list[str]:
777
+ """Encode base identities into integer arrays and write batched H5AD files.
778
+
779
+ Args:
780
+ base_identities (Mapping[str, np.ndarray]): Read name to base identity arrays.
781
+ tmp_dir (Path): Directory for temporary H5AD files.
782
+ record (str): Reference record identifier.
783
+ prefix (str): Prefix used to name batch files.
784
+ base_to_int (Mapping[str, int]): Base-to-integer mapping.
785
+ valid_length (int): Valid reference length for padding determination.
786
+ batch_size (int): Number of reads per H5AD batch file.
787
+
788
+ Returns:
789
+ list[str]: Paths to written H5AD batch files.
790
+
791
+ Processing Steps:
792
+ 1. Encode each read sequence to integer values.
793
+ 2. Accumulate encoded reads into batches.
794
+ 3. Persist each batch as an H5AD with the dictionary stored in `.uns`.
795
+ """
796
+ import anndata as ad
797
+
798
+ padding_value = base_to_int[MODKIT_EXTRACT_SEQUENCE_PADDING_BASE]
799
+ batch_files: list[str] = []
800
+ batch: dict[str, np.ndarray] = {}
801
+ batch_number = 0
802
+
803
+ for read_name, sequence in base_identities.items():
804
+ if sequence is None:
805
+ continue
806
+ batch[read_name] = _encode_sequence_array(
807
+ sequence, valid_length, base_to_int, padding_value
808
+ )
809
+ if len(batch) >= batch_size:
810
+ save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
811
+ ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
812
+ batch_files.append(str(save_name))
813
+ batch = {}
814
+ batch_number += 1
815
+
816
+ if batch:
817
+ save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
818
+ ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
819
+ batch_files.append(str(save_name))
820
+
821
+ return batch_files
822
+
823
+
824
+ def _write_integer_batches(
825
+ sequences: Mapping[str, np.ndarray],
826
+ tmp_dir: Path,
827
+ record: str,
828
+ prefix: str,
829
+ batch_size: int,
830
+ ) -> list[str]:
831
+ """Write integer-encoded sequences into batched H5AD files.
832
+
833
+ Args:
834
+ sequences (Mapping[str, np.ndarray]): Read name to integer arrays.
835
+ tmp_dir (Path): Directory for temporary H5AD files.
836
+ record (str): Reference record identifier.
837
+ prefix (str): Prefix used to name batch files.
838
+ batch_size (int): Number of reads per H5AD batch file.
839
+
840
+ Returns:
841
+ list[str]: Paths to written H5AD batch files.
842
+
843
+ Processing Steps:
844
+ 1. Accumulate integer arrays into batches.
845
+ 2. Persist each batch as an H5AD with the dictionary stored in `.uns`.
846
+ """
847
+ import anndata as ad
848
+
849
+ batch_files: list[str] = []
850
+ batch: dict[str, np.ndarray] = {}
851
+ batch_number = 0
852
+
853
+ for read_name, sequence in sequences.items():
854
+ if sequence is None:
855
+ continue
856
+ batch[read_name] = np.asarray(sequence, dtype=np.int16)
857
+ if len(batch) >= batch_size:
858
+ save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
859
+ ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
860
+ batch_files.append(str(save_name))
861
+ batch = {}
862
+ batch_number += 1
863
+
864
+ if batch:
865
+ save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
866
+ ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
867
+ batch_files.append(str(save_name))
868
+
869
+ return batch_files
870
+
871
+
872
+ def _load_sequence_batches(
873
+ batch_files: list[Path | str],
874
+ ) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
875
+ """Load integer-encoded sequence batches from H5AD files.
876
+
877
+ Args:
878
+ batch_files (list[Path | str]): H5AD paths containing encoded sequences in `.uns`.
879
+
880
+ Returns:
881
+ tuple[dict[str, np.ndarray], set[str], set[str]]:
882
+ Read-to-sequence mapping and sets of forward/reverse mapped reads.
883
+
884
+ Processing Steps:
885
+ 1. Read each H5AD file.
886
+ 2. Merge `.uns` dictionaries into a single mapping.
887
+ 3. Track forward/reverse read IDs based on the filename marker.
888
+ """
889
+ import anndata as ad
890
+
891
+ sequences: dict[str, np.ndarray] = {}
892
+ fwd_reads: set[str] = set()
893
+ rev_reads: set[str] = set()
894
+ for batch_file in batch_files:
895
+ batch_path = Path(batch_file)
896
+ batch_sequences = ad.read_h5ad(batch_path).uns
897
+ sequences.update(batch_sequences)
898
+ if "_fwd_" in batch_path.name:
899
+ fwd_reads.update(batch_sequences.keys())
900
+ elif "_rev_" in batch_path.name:
901
+ rev_reads.update(batch_sequences.keys())
902
+ return sequences, fwd_reads, rev_reads
903
+
904
+
905
+ def _load_integer_batches(batch_files: list[Path | str]) -> dict[str, np.ndarray]:
906
+ """Load integer arrays from batched H5AD files.
907
+
908
+ Args:
909
+ batch_files (list[Path | str]): H5AD paths containing arrays in `.uns`.
910
+
911
+ Returns:
912
+ dict[str, np.ndarray]: Read-to-array mapping.
913
+
914
+ Processing Steps:
915
+ 1. Read each H5AD file.
916
+ 2. Merge `.uns` dictionaries into a single mapping.
917
+ """
918
+ import anndata as ad
919
+
920
+ sequences: dict[str, np.ndarray] = {}
921
+ for batch_file in batch_files:
922
+ batch_path = Path(batch_file)
923
+ sequences.update(ad.read_h5ad(batch_path).uns)
924
+ return sequences
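
The write/load helpers above share one pattern: a read-name to array dict is chunked and parked in the `.uns` slot of a throwaway AnnData. A self-contained round-trip sketch (batch size and file names are illustrative):

    import anndata as ad
    import numpy as np

    arrays = {f"read_{i}": np.arange(5, dtype=np.int16) for i in range(250)}
    batch_size, batch_number = 100, 0
    batch_files, batch = [], {}
    for name, values in arrays.items():
        batch[name] = values
        if len(batch) >= batch_size:
            path = f"tmp_example_{batch_number}.h5ad"
            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(path)
            batch_files.append(path)
            batch, batch_number = {}, batch_number + 1
    if batch:
        path = f"tmp_example_{batch_number}.h5ad"
        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(path)
        batch_files.append(path)

    # Loading simply merges every .uns dict back together, as in _load_integer_batches.
    merged = {}
    for path in batch_files:
        merged.update(ad.read_h5ad(path).uns)
    assert len(merged) == 250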
925
+
926
+
927
+ def _normalize_sequence_batch_files(batch_files: object) -> list[Path]:
928
+ """Normalize cached batch file entries into a list of Paths.
929
+
930
+ Args:
931
+ batch_files (object): Cached batch file entry from AnnData `.uns`.
932
+
933
+ Returns:
934
+ list[Path]: Paths to batch files, filtered to non-empty values.
935
+
936
+ Processing Steps:
937
+ 1. Convert numpy arrays and scalars into Python lists.
938
+ 2. Filter out empty/placeholder values.
939
+ 3. Cast remaining entries to Path objects.
940
+ """
941
+ if batch_files is None:
942
+ return []
943
+ if isinstance(batch_files, np.ndarray):
944
+ batch_files = batch_files.tolist()
945
+ if isinstance(batch_files, (str, Path)):
946
+ batch_files = [batch_files]
947
+ if not isinstance(batch_files, list):
948
+ batch_files = list(batch_files)
949
+ normalized: list[Path] = []
950
+ for entry in batch_files:
951
+ if entry is None:
952
+ continue
953
+ entry_str = str(entry).strip()
954
+ if not entry_str or entry_str == ".":
955
+ continue
956
+ normalized.append(Path(entry_str))
957
+ return normalized
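
The normalizer tolerates the shapes an AnnData `.uns` round trip can hand back; illustrative calls (the helper is module-private, imported here only for the demonstration):

    import numpy as np
    from smftools.informatics.modkit_extract_to_adata import _normalize_sequence_batch_files

    _normalize_sequence_batch_files(np.array(["a.h5ad", "b.h5ad"]))   # [Path("a.h5ad"), Path("b.h5ad")]
    _normalize_sequence_batch_files("a.h5ad")                          # [Path("a.h5ad")]
    _normalize_sequence_batch_files(["a.h5ad", "", ".", None])         # [Path("a.h5ad")]
    _normalize_sequence_batch_files(None)                              # []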
958
+
959
+
960
+ def _build_modification_dicts(
961
+ dict_total: dict,
962
+ mods: list[str],
963
+ ) -> tuple[ModkitBatchDictionaries, set[int]]:
964
+ """Build modification/strand dictionaries from the raw TSV batch dictionary.
965
+
966
+ Args:
967
+ dict_total (dict): Raw TSV DataFrames keyed by record and sample index.
968
+ mods (list[str]): Modification labels to include (e.g., ["6mA", "5mC"]).
969
+
970
+ Returns:
971
+ tuple[ModkitBatchDictionaries, set[int]]: Batch dictionaries and indices to skip.
972
+
973
+ Processing Steps:
974
+ 1. Initialize modification dictionaries and skip-set.
975
+ 2. Filter TSV rows per record/sample into modification and strand subsets.
976
+ 3. Populate combined dict placeholders when both modifications are present.
977
+ """
978
+ batch_dicts = ModkitBatchDictionaries(dict_total=dict_total)
979
+ dict_to_skip = {0, 1, 4}
980
+ combined_dicts = {7, 8}
981
+ A_stranded_dicts = {2, 3}
982
+ C_stranded_dicts = {5, 6}
983
+ dict_to_skip.update(combined_dicts | A_stranded_dicts | C_stranded_dicts)
984
+
985
+ for record in dict_total.keys():
986
+ for sample_index in dict_total[record].keys():
987
+ if "6mA" in mods:
988
+ dict_to_skip.difference_update(A_stranded_dicts)
989
+ if (
990
+ record not in batch_dicts.dict_a.keys()
991
+ and record not in batch_dicts.dict_a_bottom.keys()
992
+ and record not in batch_dicts.dict_a_top.keys()
993
+ ):
994
+ (
995
+ batch_dicts.dict_a[record],
996
+ batch_dicts.dict_a_bottom[record],
997
+ batch_dicts.dict_a_top[record],
998
+ ) = ({}, {}, {})
999
+
1000
+ batch_dicts.dict_a[record][sample_index] = dict_total[record][sample_index][
1001
+ dict_total[record][sample_index][
1002
+ MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE
1003
+ ]
1004
+ == MODKIT_EXTRACT_MODIFIED_BASE_A
1005
+ ]
1006
+ logger.debug(
1007
+ "Successfully loaded a methyl-adenine dictionary for {}".format(
1008
+ str(sample_index)
1009
+ )
1010
+ )
1011
+
1012
+ batch_dicts.dict_a_bottom[record][sample_index] = batch_dicts.dict_a[record][
1013
+ sample_index
1014
+ ][
1015
+ batch_dicts.dict_a[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
1016
+ == MODKIT_EXTRACT_REF_STRAND_MINUS
1017
+ ]
1018
+ logger.debug(
1019
+ "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
1020
+ str(sample_index)
1021
+ )
1022
+ )
1023
+ batch_dicts.dict_a_top[record][sample_index] = batch_dicts.dict_a[record][
1024
+ sample_index
1025
+ ][
1026
+ batch_dicts.dict_a[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
1027
+ == MODKIT_EXTRACT_REF_STRAND_PLUS
1028
+ ]
1029
+ logger.debug(
1030
+ "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
1031
+ str(sample_index)
1032
+ )
1033
+ )
1034
+
1035
+ batch_dicts.dict_a[record][sample_index] = None
1036
+ gc.collect()
1037
+
1038
+ if "5mC" in mods:
1039
+ dict_to_skip.difference_update(C_stranded_dicts)
1040
+ if (
1041
+ record not in batch_dicts.dict_c.keys()
1042
+ and record not in batch_dicts.dict_c_bottom.keys()
1043
+ and record not in batch_dicts.dict_c_top.keys()
1044
+ ):
1045
+ (
1046
+ batch_dicts.dict_c[record],
1047
+ batch_dicts.dict_c_bottom[record],
1048
+ batch_dicts.dict_c_top[record],
1049
+ ) = ({}, {}, {})
1050
+
1051
+ batch_dicts.dict_c[record][sample_index] = dict_total[record][sample_index][
1052
+ dict_total[record][sample_index][
1053
+ MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE
1054
+ ]
1055
+ == MODKIT_EXTRACT_MODIFIED_BASE_C
1056
+ ]
1057
+ logger.debug(
1058
+ "Successfully loaded a methyl-cytosine dictionary for {}".format(
1059
+ str(sample_index)
1060
+ )
1061
+ )
1062
+
1063
+ batch_dicts.dict_c_bottom[record][sample_index] = batch_dicts.dict_c[record][
1064
+ sample_index
1065
+ ][
1066
+ batch_dicts.dict_c[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
1067
+ == MODKIT_EXTRACT_REF_STRAND_MINUS
1068
+ ]
1069
+ logger.debug(
1070
+ "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
1071
+ str(sample_index)
1072
+ )
1073
+ )
1074
+ batch_dicts.dict_c_top[record][sample_index] = batch_dicts.dict_c[record][
1075
+ sample_index
1076
+ ][
1077
+ batch_dicts.dict_c[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
1078
+ == MODKIT_EXTRACT_REF_STRAND_PLUS
1079
+ ]
1080
+ logger.debug(
1081
+ "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
1082
+ str(sample_index)
1083
+ )
1084
+ )
1085
+
1086
+ batch_dicts.dict_c[record][sample_index] = None
1087
+ gc.collect()
1088
+
1089
+ if "6mA" in mods and "5mC" in mods:
1090
+ dict_to_skip.difference_update(combined_dicts)
1091
+ if (
1092
+ record not in batch_dicts.dict_combined_bottom.keys()
1093
+ and record not in batch_dicts.dict_combined_top.keys()
1094
+ ):
1095
+ (
1096
+ batch_dicts.dict_combined_bottom[record],
1097
+ batch_dicts.dict_combined_top[record],
1098
+ ) = ({}, {})
1099
+
1100
+ logger.debug(
1101
+ "Successfully created a minus strand combined methylation dictionary for {}".format(
1102
+ str(sample_index)
1103
+ )
1104
+ )
1105
+ batch_dicts.dict_combined_bottom[record][sample_index] = []
1106
+ logger.debug(
1107
+ "Successfully created a plus strand combined methylation dictionary for {}".format(
1108
+ str(sample_index)
1109
+ )
1110
+ )
1111
+ batch_dicts.dict_combined_top[record][sample_index] = []
1112
+
1113
+ dict_total[record][sample_index] = None
1114
+ gc.collect()
1115
+
1116
+ return batch_dicts, dict_to_skip
1117
+
1118
+
481
1119
  def modkit_extract_to_adata(
482
1120
  fasta,
483
1121
  bam_dir,
@@ -493,24 +1131,32 @@ def modkit_extract_to_adata(
493
1131
  double_barcoded_path=None,
494
1132
  samtools_backend: str | None = "auto",
495
1133
  ):
496
- """
497
- Takes modkit extract outputs and organizes it into an adata object
498
-
499
- Parameters:
500
- fasta (Path): File path to the reference genome to align to.
501
- bam_dir (Path): File path to the directory containing the aligned_sorted split modified BAM files
502
- out_dir (Path): File path to output directory
503
- input_already_demuxed (bool): Whether input reads were originally demuxed
504
- mapping_threshold (float): A value in between 0 and 1 to threshold the minimal fraction of aligned reads which map to the reference region. References with values above the threshold are included in the output adata.
505
- experiment_name (str): A string to provide an experiment name to the output adata file.
506
- mods (list): A list of strings of the modification types to use in the analysis.
507
- batch_size (int): An integer number of TSV files to analyze in memory at once while loading the final adata object.
508
- mod_tsv_dir (Path): path to the mod TSV directory
509
- delete_batch_hdfs (bool): Whether to delete the batch hdfs after writing out the final concatenated hdf. Default is False
510
- double_barcoded_path (Path): Path to dorado demux summary file of double ended barcodes
1134
+ """Convert modkit extract TSVs and BAMs into an AnnData object.
1135
+
1136
+ Args:
1137
+ fasta (Path): Reference FASTA path.
1138
+ bam_dir (Path): Directory with aligned BAM files.
1139
+ out_dir (Path): Output directory for intermediate and final H5ADs.
1140
+ input_already_demuxed (bool): Whether reads were already demultiplexed.
1141
+ mapping_threshold (float): Minimum fraction of mapped reads to keep a record.
1142
+ experiment_name (str): Experiment name used in output file naming.
1143
+ mods (list[str]): Modification labels to analyze (e.g., ["6mA", "5mC"]).
1144
+ batch_size (int): Number of TSVs to process per batch.
1145
+ mod_tsv_dir (Path): Directory containing modkit extract TSVs.
1146
+ delete_batch_hdfs (bool): Remove batch H5ADs after concatenation.
1147
+ threads (int | None): Thread count for parallel operations.
1148
+ double_barcoded_path (Path | None): Dorado demux summary directory for double barcodes.
1149
+ samtools_backend (str | None): Samtools backend selection.
511
1150
 
512
1151
  Returns:
513
- final_adata_path (Path): Path to the final adata
1152
+ tuple[ad.AnnData | None, Path]: The final AnnData (if created) and its H5AD path.
1153
+
1154
+ Processing Steps:
1155
+ 1. Discover input TSV/BAM files and derive sample metadata.
1156
+ 2. Identify records that pass mapping thresholds and build reference metadata.
1157
+ 3. Encode read sequences into integer arrays and cache them.
1158
+ 4. Process TSV batches into per-read methylation matrices.
1159
+ 5. Concatenate batch H5ADs into a final AnnData with consensus sequences.
514
1160
  """
515
1161
  ###################################################
516
1162
  # Package imports
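A hedged sketch of driving the rewritten entry point; the argument names follow the signature documented above, while the paths and values are placeholders:

    from pathlib import Path
    from smftools.informatics.modkit_extract_to_adata import modkit_extract_to_adata

    final_adata, final_adata_path = modkit_extract_to_adata(
        fasta=Path("reference.fasta"),
        bam_dir=Path("aligned_bams"),
        out_dir=Path("output"),
        input_already_demuxed=False,
        mapping_threshold=0.05,
        experiment_name="example_run",
        mods=["6mA", "5mC"],
        batch_size=4,
        mod_tsv_dir=Path("mod_tsvs"),
        delete_batch_hdfs=True,
        threads=8,
        double_barcoded_path=None,
        samtools_backend="auto",
    )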
@@ -527,12 +1173,11 @@ def modkit_extract_to_adata(
527
1173
  from ..readwrite import make_dirs
528
1174
  from .bam_functions import extract_base_identities
529
1175
  from .fasta_functions import get_native_references
530
- from .ohe import ohe_batching
531
1176
  ###################################################
532
1177
 
533
1178
  ################## Get input tsv and bam file names into a sorted list ################
534
1179
  # Make output dirs
535
- h5_dir = out_dir / "h5ads"
1180
+ h5_dir = out_dir / H5_DIR
536
1181
  tmp_dir = out_dir / "tmp"
537
1182
  make_dirs([h5_dir, tmp_dir])
538
1183
 
@@ -546,55 +1191,14 @@ def modkit_extract_to_adata(
546
1191
  logger.debug(f"{final_adata_path} already exists. Using existing adata")
547
1192
  return final_adata, final_adata_path
548
1193
 
549
- # List all files in the directory
550
- tsvs = sorted(
551
- p
552
- for p in mod_tsv_dir.iterdir()
553
- if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
554
- )
555
- bams = sorted(
556
- p
557
- for p in bam_dir.iterdir()
558
- if p.is_file()
559
- and p.suffix == ".bam"
560
- and "unclassified" not in p.name
561
- and ".bai" not in p.name
562
- )
563
-
564
- tsv_path_list = [tsv for tsv in tsvs]
565
- bam_path_list = [bam for bam in bams]
1194
+ tsvs, bams = _collect_input_paths(mod_tsv_dir, bam_dir)
1195
+ tsv_path_list = list(tsvs)
1196
+ bam_path_list = list(bams)
566
1197
  logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
567
1198
  logger.info(f"{len(bams)} sample bams found: {bams}")
568
1199
 
569
1200
  # Map global sample index (bami / final_sample_index) -> sample name / barcode
570
- sample_name_map = {}
571
- barcode_map = {}
572
-
573
- for idx, bam_path in enumerate(bam_path_list):
574
- stem = bam_path.stem
575
-
576
- # Try to peel off a "barcode..." suffix if present.
577
- # This handles things like:
578
- # "mySample_barcode01" -> sample="mySample", barcode="barcode01"
579
- # "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
580
- # "barcode01" -> sample="barcode01", barcode="barcode01"
581
- m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
582
- if m:
583
- sample_name = m.group(1) or stem
584
- barcode = m.group(2)
585
- else:
586
- # Fallback: treat the whole stem as both sample & barcode
587
- sample_name = stem
588
- barcode = stem
589
-
590
- # make sample name of the format of the bam file stem
591
- sample_name = sample_name + f"_{barcode}"
592
-
593
- # Clean the barcode name to be an integer
594
- barcode = int(barcode.split("barcode")[1])
595
-
596
- sample_name_map[idx] = sample_name
597
- barcode_map[idx] = str(barcode)
1201
+ sample_name_map, barcode_map = _build_sample_maps(bam_path_list)
598
1202
  ##########################################################################################
599
1203
 
600
1204
  ######### Get Record names that have over a passed threshold of mapped reads #############
@@ -619,59 +1223,154 @@ def modkit_extract_to_adata(
619
1223
  ##########################################################################################
620
1224
 
621
1225
  ##########################################################################################
622
- # One hot encode read sequences and write them out into the tmp_dir as h5ad files.
623
- # Save the file paths in the bam_record_ohe_files dict.
624
- bam_record_ohe_files = {}
625
- bam_record_save = tmp_dir / "tmp_file_dict.h5ad"
626
- fwd_mapped_reads = set()
627
- rev_mapped_reads = set()
628
- # If this step has already been performed, read in the tmp_dile_dict
629
- if bam_record_save.exists():
630
- bam_record_ohe_files = ad.read_h5ad(bam_record_save).uns
631
- logger.debug("Found existing OHE reads, using these")
632
- else:
633
- # Iterate over split bams
1226
+ # Encode read sequences into integer arrays and cache in tmp_dir.
1227
+ sequence_batch_files: dict[str, list[str]] = {}
1228
+ mismatch_batch_files: dict[str, list[str]] = {}
1229
+ quality_batch_files: dict[str, list[str]] = {}
1230
+ read_span_batch_files: dict[str, list[str]] = {}
1231
+ sequence_cache_path = tmp_dir / "tmp_sequence_int_file_dict.h5ad"
1232
+ cache_needs_rebuild = True
1233
+ if sequence_cache_path.exists():
1234
+ cached_uns = ad.read_h5ad(sequence_cache_path).uns
1235
+ if "sequence_batch_files" in cached_uns:
1236
+ sequence_batch_files = cached_uns.get("sequence_batch_files", {})
1237
+ mismatch_batch_files = cached_uns.get("mismatch_batch_files", {})
1238
+ quality_batch_files = cached_uns.get("quality_batch_files", {})
1239
+ read_span_batch_files = cached_uns.get("read_span_batch_files", {})
1240
+ cache_needs_rebuild = not (
1241
+ quality_batch_files and read_span_batch_files and sequence_batch_files
1242
+ )
1243
+ else:
1244
+ sequence_batch_files = cached_uns
1245
+ cache_needs_rebuild = True
1246
+ if cache_needs_rebuild:
1247
+ logger.info(
1248
+ "Cached sequence batches missing quality or read-span data; rebuilding cache."
1249
+ )
1250
+ else:
1251
+ logger.debug("Found existing integer-encoded reads, using these")
1252
+ if cache_needs_rebuild:
634
1253
  for bami, bam in enumerate(bam_path_list):
635
- # Iterate over references to process
1254
+ logger.info(
1255
+ f"Extracting base level sequences, qualities, reference spans, and mismatches per read for bam {bami}"
1256
+ )
636
1257
  for record in records_to_analyze:
637
1258
  current_reference_length = reference_dict[record][0]
638
1259
  positions = range(current_reference_length)
639
1260
  ref_seq = reference_dict[record][1]
640
- # Extract the base identities of reads aligned to the record
641
1261
  (
642
1262
  fwd_base_identities,
643
1263
  rev_base_identities,
644
- mismatch_counts_per_read,
645
- mismatch_trend_per_read,
1264
+ _mismatch_counts_per_read,
1265
+ _mismatch_trend_per_read,
1266
+ mismatch_base_identities,
1267
+ base_quality_scores,
1268
+ read_span_masks,
646
1269
  ) = extract_base_identities(
647
1270
  bam, record, positions, max_reference_length, ref_seq, samtools_backend
648
1271
  )
649
- # Store read names of fwd and rev mapped reads
650
- fwd_mapped_reads.update(fwd_base_identities.keys())
651
- rev_mapped_reads.update(rev_base_identities.keys())
652
- # One hot encode the sequence string of the reads
653
- fwd_ohe_files = ohe_batching(
1272
+ mismatch_fwd = {
1273
+ read_name: mismatch_base_identities[read_name]
1274
+ for read_name in fwd_base_identities
1275
+ }
1276
+ mismatch_rev = {
1277
+ read_name: mismatch_base_identities[read_name]
1278
+ for read_name in rev_base_identities
1279
+ }
1280
+ quality_fwd = {
1281
+ read_name: base_quality_scores[read_name] for read_name in fwd_base_identities
1282
+ }
1283
+ quality_rev = {
1284
+ read_name: base_quality_scores[read_name] for read_name in rev_base_identities
1285
+ }
1286
+ read_span_fwd = {
1287
+ read_name: read_span_masks[read_name] for read_name in fwd_base_identities
1288
+ }
1289
+ read_span_rev = {
1290
+ read_name: read_span_masks[read_name] for read_name in rev_base_identities
1291
+ }
1292
+ fwd_sequence_files = _write_sequence_batches(
654
1293
  fwd_base_identities,
655
1294
  tmp_dir,
656
1295
  record,
657
1296
  f"{bami}_fwd",
1297
+ MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
1298
+ current_reference_length,
658
1299
  batch_size=100000,
659
- threads=threads,
660
1300
  )
661
- rev_ohe_files = ohe_batching(
1301
+ rev_sequence_files = _write_sequence_batches(
662
1302
  rev_base_identities,
663
1303
  tmp_dir,
664
1304
  record,
665
1305
  f"{bami}_rev",
1306
+ MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
1307
+ current_reference_length,
1308
+ batch_size=100000,
1309
+ )
1310
+ sequence_batch_files[f"{bami}_{record}"] = fwd_sequence_files + rev_sequence_files
1311
+ mismatch_fwd_files = _write_integer_batches(
1312
+ mismatch_fwd,
1313
+ tmp_dir,
1314
+ record,
1315
+ f"{bami}_mismatch_fwd",
666
1316
  batch_size=100000,
667
- threads=threads,
668
1317
  )
669
- bam_record_ohe_files[f"{bami}_{record}"] = fwd_ohe_files + rev_ohe_files
670
- del fwd_base_identities, rev_base_identities
671
- # Save out the ohe file paths
672
- X = np.random.rand(1, 1)
673
- tmp_ad = ad.AnnData(X=X, uns=bam_record_ohe_files)
674
- tmp_ad.write_h5ad(bam_record_save)
1318
+ mismatch_rev_files = _write_integer_batches(
1319
+ mismatch_rev,
1320
+ tmp_dir,
1321
+ record,
1322
+ f"{bami}_mismatch_rev",
1323
+ batch_size=100000,
1324
+ )
1325
+ mismatch_batch_files[f"{bami}_{record}"] = mismatch_fwd_files + mismatch_rev_files
1326
+ quality_fwd_files = _write_integer_batches(
1327
+ quality_fwd,
1328
+ tmp_dir,
1329
+ record,
1330
+ f"{bami}_quality_fwd",
1331
+ batch_size=100000,
1332
+ )
1333
+ quality_rev_files = _write_integer_batches(
1334
+ quality_rev,
1335
+ tmp_dir,
1336
+ record,
1337
+ f"{bami}_quality_rev",
1338
+ batch_size=100000,
1339
+ )
1340
+ quality_batch_files[f"{bami}_{record}"] = quality_fwd_files + quality_rev_files
1341
+ read_span_fwd_files = _write_integer_batches(
1342
+ read_span_fwd,
1343
+ tmp_dir,
1344
+ record,
1345
+ f"{bami}_read_span_fwd",
1346
+ batch_size=100000,
1347
+ )
1348
+ read_span_rev_files = _write_integer_batches(
1349
+ read_span_rev,
1350
+ tmp_dir,
1351
+ record,
1352
+ f"{bami}_read_span_rev",
1353
+ batch_size=100000,
1354
+ )
1355
+ read_span_batch_files[f"{bami}_{record}"] = (
1356
+ read_span_fwd_files + read_span_rev_files
1357
+ )
1358
+ del (
1359
+ fwd_base_identities,
1360
+ rev_base_identities,
1361
+ mismatch_base_identities,
1362
+ base_quality_scores,
1363
+ read_span_masks,
1364
+ )
1365
+ ad.AnnData(
1366
+ X=np.random.rand(1, 1),
1367
+ uns={
1368
+ "sequence_batch_files": sequence_batch_files,
1369
+ "mismatch_batch_files": mismatch_batch_files,
1370
+ "quality_batch_files": quality_batch_files,
1371
+ "read_span_batch_files": read_span_batch_files,
1372
+ },
1373
+ ).write_h5ad(sequence_cache_path)
675
1374
  ##########################################################################################
676
1375
 
677
1376
  ##########################################################################################
@@ -713,47 +1412,9 @@ def modkit_extract_to_adata(
713
1412
  ###################################################
714
1413
  ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
715
1414
  # # Initialize dictionaries and place them in a list
716
- (
717
- dict_total,
718
- dict_a,
719
- dict_a_bottom,
720
- dict_a_top,
721
- dict_c,
722
- dict_c_bottom,
723
- dict_c_top,
724
- dict_combined_bottom,
725
- dict_combined_top,
726
- ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
727
- dict_list = [
728
- dict_total,
729
- dict_a,
730
- dict_a_bottom,
731
- dict_a_top,
732
- dict_c,
733
- dict_c_bottom,
734
- dict_c_top,
735
- dict_combined_bottom,
736
- dict_combined_top,
737
- ]
738
- # Give names to represent each dictionary in the list
739
- sample_types = [
740
- "total",
741
- "m6A",
742
- "m6A_bottom_strand",
743
- "m6A_top_strand",
744
- "5mC",
745
- "5mC_bottom_strand",
746
- "5mC_top_strand",
747
- "combined_bottom_strand",
748
- "combined_top_strand",
749
- ]
750
- # Give indices of dictionaries to skip for analysis and final dictionary saving.
751
- dict_to_skip = [0, 1, 4]
752
- combined_dicts = [7, 8]
753
- A_stranded_dicts = [2, 3]
754
- C_stranded_dicts = [5, 6]
755
- dict_to_skip = dict_to_skip + combined_dicts + A_stranded_dicts + C_stranded_dicts
756
- dict_to_skip = set(dict_to_skip)
1415
+ batch_dicts = ModkitBatchDictionaries()
1416
+ dict_list = batch_dicts.as_list()
1417
+ sample_types = batch_dicts.sample_types
757
1418
 
758
1419
  # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
759
1420
  dict_total = parallel_load_tsvs(
@@ -765,140 +1426,9 @@ def modkit_extract_to_adata(
765
1426
  threads=threads,
766
1427
  )
767
1428
 
768
- # # Step 2: Extract modification-specific data (per (record,sample)) in parallel
- # processed_mod_results = parallel_process_modifications(dict_total, mods, max_reference_length, threads=threads or 4)
- # (m6A_dict, m6A_minus_strand, m6A_plus_strand,
- # c5m_dict, c5m_minus_strand, c5m_plus_strand,
- # combined_minus_strand, combined_plus_strand) = merge_modification_results(processed_mod_results, mods)
-
- # # Create dict_list with the desired ordering:
- # # 0: dict_total, 1: m6A, 2: m6A_minus, 3: m6A_plus, 4: 5mC, 5: 5mC_minus, 6: 5mC_plus, 7: combined_minus, 8: combined_plus
- # dict_list = [dict_total, m6A_dict, m6A_minus_strand, m6A_plus_strand,
- # c5m_dict, c5m_minus_strand, c5m_plus_strand,
- # combined_minus_strand, combined_plus_strand]
-
- # # Initialize dict_to_skip (default skip all mod-specific indices)
- # dict_to_skip = set([0, 1, 4, 7, 8, 2, 3, 5, 6])
- # # Update dict_to_skip based on modifications present in mods
- # dict_to_skip = update_dict_to_skip(dict_to_skip, mods)
-
- # # Step 3: Process stranded methylation data in parallel
- # dict_list = parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=threads or 4)
-
- # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
- for record in dict_total.keys():
- for sample_index in dict_total[record].keys():
- if "6mA" in mods:
- # Remove Adenine stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(A_stranded_dicts))
-
- if (
- record not in dict_a.keys()
- and record not in dict_a_bottom.keys()
- and record not in dict_a_top.keys()
- ):
- dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
-
- # get a dictionary of dataframes that only contain methylated adenine positions
- dict_a[record][sample_index] = dict_total[record][sample_index][
- dict_total[record][sample_index]["modified_primary_base"] == "A"
- ]
- logger.debug(
- "Successfully loaded a methyl-adenine dictionary for {}".format(
- str(sample_index)
- )
- )
-
- # Stratify the adenine dictionary into two strand specific dictionaries.
- dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
- dict_a[record][sample_index]["ref_strand"] == "-"
- ]
- logger.debug(
- "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_a_top[record][sample_index] = dict_a[record][sample_index][
- dict_a[record][sample_index]["ref_strand"] == "+"
- ]
- logger.debug(
- "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
- str(sample_index)
- )
- )
-
- # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
- dict_a[record][sample_index] = None
- gc.collect()
-
- if "5mC" in mods:
- # Remove Cytosine stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(C_stranded_dicts))
-
- if (
- record not in dict_c.keys()
- and record not in dict_c_bottom.keys()
- and record not in dict_c_top.keys()
- ):
- dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
-
- # get a dictionary of dataframes that only contain methylated cytosine positions
- dict_c[record][sample_index] = dict_total[record][sample_index][
- dict_total[record][sample_index]["modified_primary_base"] == "C"
- ]
- logger.debug(
- "Successfully loaded a methyl-cytosine dictionary for {}".format(
- str(sample_index)
- )
- )
- # Stratify the cytosine dictionary into two strand specific dictionaries.
- dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
- dict_c[record][sample_index]["ref_strand"] == "-"
- ]
- logger.debug(
- "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_c_top[record][sample_index] = dict_c[record][sample_index][
- dict_c[record][sample_index]["ref_strand"] == "+"
- ]
- logger.debug(
- "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
- str(sample_index)
- )
- )
- # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
- dict_c[record][sample_index] = None
- gc.collect()
-
- if "6mA" in mods and "5mC" in mods:
- # Remove combined stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(combined_dicts))
- # Initialize the sample keys for the combined dictionaries
-
- if (
- record not in dict_combined_bottom.keys()
- and record not in dict_combined_top.keys()
- ):
- dict_combined_bottom[record], dict_combined_top[record] = {}, {}
-
- logger.debug(
- "Successfully created a minus strand combined methylation dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_combined_bottom[record][sample_index] = []
- logger.debug(
- "Successfully created a plus strand combined methylation dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_combined_top[record][sample_index] = []
-
- # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
- dict_total[record][sample_index] = None
- gc.collect()
+ batch_dicts, dict_to_skip = _build_modification_dicts(dict_total, mods)
+ dict_list = batch_dicts.as_list()
+ sample_types = batch_dicts.sample_types

  # Iterate over the stranded modification dictionaries and replace the dataframes with a dictionary of read names pointing to a list of values from the dataframe
  for dict_index, dict_type in enumerate(dict_list):
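For orientation: the removed block above did the per-(record, sample) stratification inline, filtering each modkit-extract table first by `modified_primary_base` and then by `ref_strand`; 0.3.1 folds that into the new `_build_modification_dicts` helper, whose body is outside this hunk. Below is a minimal sketch of the same stratification against a plain pandas DataFrame. The column names follow the removed code, while the function name and return shape are illustrative only and not the helper's real signature.

```python
import pandas as pd

def stratify_by_mod_and_strand(df: pd.DataFrame, mods: list[str]) -> dict[str, pd.DataFrame]:
    """Split a modkit-extract style table into per-modification, per-strand frames.

    Mirrors the removed inline loop: 6mA rows are those whose modified primary
    base is "A", 5mC rows those whose modified primary base is "C"; each subset
    is further split by reference strand ("+" top, "-" bottom).
    """
    out: dict[str, pd.DataFrame] = {}
    mod_to_base = {"6mA": "A", "5mC": "C"}
    for mod, base in mod_to_base.items():
        if mod not in mods:
            continue
        mod_df = df[df["modified_primary_base"] == base]
        out[f"{mod}_bottom"] = mod_df[mod_df["ref_strand"] == "-"]
        out[f"{mod}_top"] = mod_df[mod_df["ref_strand"] == "+"]
    return out

# Example: one 6mA call on the plus strand and one 5mC call on the minus strand.
demo = pd.DataFrame(
    {
        "modified_primary_base": ["A", "C"],
        "ref_strand": ["+", "-"],
        "read_id": ["read1", "read2"],
    }
)
subsets = stratify_by_mod_and_strand(demo, mods=["6mA", "5mC"])
print({name: len(frame) for name, frame in subsets.items()})
```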
@@ -984,14 +1514,14 @@ def modkit_extract_to_adata(
  mod_strand_record_sample_dict[sample] = {}

  # Get relevant columns as NumPy arrays
- read_ids = temp_df["read_id"].values
- positions = temp_df["ref_position"].values
- call_codes = temp_df["call_code"].values
- probabilities = temp_df["call_prob"].values
+ read_ids = temp_df[MODKIT_EXTRACT_TSV_COLUMN_READ_ID].values
+ positions = temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION].values
+ call_codes = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE].values
+ probabilities = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB].values

  # Define valid call code categories
- modified_codes = {"a", "h", "m"}
- canonical_codes = {"-"}
+ modified_codes = MODKIT_EXTRACT_CALL_CODE_MODIFIED
+ canonical_codes = MODKIT_EXTRACT_CALL_CODE_CANONICAL

  # Vectorized methylation calculation with NaN for other codes
  methylation_prob = np.full_like(
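The hunk above swaps the hard-coded call-code sets for the shared constants and keeps the vectorized probability fill that starts at `np.full_like(...)` (the rest of that statement is outside the hunk). A self-contained sketch of that pattern follows: NaN-initialise, then fill only recognised call codes. The 1 − p treatment of canonical calls is an assumption, since the actual assignment is truncated here.

```python
import numpy as np

# Toy inputs shaped like the modkit-extract columns used above.
call_codes = np.array(["a", "-", "m", "x", "h"], dtype=object)
probabilities = np.array([0.91, 0.85, 0.72, 0.50, 0.66])

modified_codes = {"a", "h", "m"}   # modified-base calls (e.g. 6mA, 5hmC, 5mC)
canonical_codes = {"-"}            # unmodified/canonical calls

# NaN everywhere, then fill only the codes we recognise.
methylation_prob = np.full_like(probabilities, np.nan, dtype=float)

is_modified = np.isin(call_codes, list(modified_codes))
is_canonical = np.isin(call_codes, list(canonical_codes))

# Assumed convention: a modified call keeps its call probability, a canonical
# call contributes 1 - call probability; anything else stays NaN.
methylation_prob[is_modified] = probabilities[is_modified]
methylation_prob[is_canonical] = 1.0 - probabilities[is_canonical]

print(methylation_prob)  # approximately [0.91 0.15 0.72  nan 0.66]
```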
@@ -1087,39 +1617,63 @@ def modkit_extract_to_adata(
  final_sample_index,
  )
  )
- temp_adata.obs["Sample"] = [
+ temp_adata.obs[SAMPLE] = [
  sample_name_map[final_sample_index]
  ] * len(temp_adata)
- temp_adata.obs["Barcode"] = [barcode_map[final_sample_index]] * len(
+ temp_adata.obs[BARCODE] = [barcode_map[final_sample_index]] * len(
  temp_adata
  )
- temp_adata.obs["Reference"] = [f"{record}"] * len(temp_adata)
- temp_adata.obs["Strand"] = [strand] * len(temp_adata)
- temp_adata.obs["Dataset"] = [dataset] * len(temp_adata)
- temp_adata.obs["Reference_dataset_strand"] = [
+ temp_adata.obs[REFERENCE] = [f"{record}"] * len(temp_adata)
+ temp_adata.obs[STRAND] = [strand] * len(temp_adata)
+ temp_adata.obs[DATASET] = [dataset] * len(temp_adata)
+ temp_adata.obs[REFERENCE_DATASET_STRAND] = [
  f"{record}_{dataset}_{strand}"
  ] * len(temp_adata)
- temp_adata.obs["Reference_strand"] = [f"{record}_{strand}"] * len(
+ temp_adata.obs[REFERENCE_STRAND] = [f"{record}_{strand}"] * len(
  temp_adata
  )

- # Load in the one hot encoded reads from the current sample and record
- one_hot_reads = {}
- n_rows_OHE = 5
- ohe_files = bam_record_ohe_files[f"{final_sample_index}_{record}"]
- logger.info(f"Loading OHEs from {ohe_files}")
- fwd_mapped_reads = set()
- rev_mapped_reads = set()
- for ohe_file in ohe_files:
- tmp_ohe_dict = ad.read_h5ad(ohe_file).uns
- one_hot_reads.update(tmp_ohe_dict)
- if "_fwd_" in ohe_file:
- fwd_mapped_reads.update(tmp_ohe_dict.keys())
- elif "_rev_" in ohe_file:
- rev_mapped_reads.update(tmp_ohe_dict.keys())
- del tmp_ohe_dict
-
- read_names = list(one_hot_reads.keys())
+ # Load integer-encoded reads for the current sample/record
+ sequence_files = _normalize_sequence_batch_files(
+ sequence_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ mismatch_files = _normalize_sequence_batch_files(
+ mismatch_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ quality_files = _normalize_sequence_batch_files(
+ quality_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ read_span_files = _normalize_sequence_batch_files(
+ read_span_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ if not sequence_files:
+ logger.warning(
+ "No encoded sequence batches found for sample %s record %s",
+ final_sample_index,
+ record,
+ )
+ continue
+ logger.info(f"Loading encoded sequences from {sequence_files}")
+ (
+ encoded_reads,
+ fwd_mapped_reads,
+ rev_mapped_reads,
+ ) = _load_sequence_batches(sequence_files)
+ mismatch_reads: dict[str, np.ndarray] = {}
+ if mismatch_files:
+ (
+ mismatch_reads,
+ _mismatch_fwd_reads,
+ _mismatch_rev_reads,
+ ) = _load_sequence_batches(mismatch_files)
+ quality_reads: dict[str, np.ndarray] = {}
+ if quality_files:
+ quality_reads = _load_integer_batches(quality_files)
+ read_span_reads: dict[str, np.ndarray] = {}
+ if read_span_files:
+ read_span_reads = _load_integer_batches(read_span_files)
+
+ read_names = list(encoded_reads.keys())

  read_mapping_direction = []
  for read_id in temp_adata.obs_names:
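The loading step above now delegates to `_normalize_sequence_batch_files`, `_load_sequence_batches`, and `_load_integer_batches`, none of which appear in this diff. As a rough sketch of the shape those helpers replace, here is a 0.3.0-style loader reconstructed from the removed lines: per-read arrays pulled from batch `.h5ad` files, with mapping direction inferred from a `_fwd_`/`_rev_` tag in the filename. The 0.3.1 helpers may read a different on-disk format, so treat this as illustrative only.

```python
from pathlib import Path

import anndata as ad
import numpy as np

def load_read_batches(batch_files: list[str]) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
    """Collect per-read arrays from batch files and split read ids by mapping direction.

    Follows the 0.3.0 pattern visible in the removed lines: each batch is an
    .h5ad file whose .uns maps read id -> encoded array, and the filename
    carries a "_fwd_" or "_rev_" tag.
    """
    reads: dict[str, np.ndarray] = {}
    fwd_reads: set[str] = set()
    rev_reads: set[str] = set()
    for batch_file in batch_files:
        batch = ad.read_h5ad(batch_file).uns
        reads.update(batch)
        name = Path(batch_file).name
        if "_fwd_" in name:
            fwd_reads.update(batch.keys())
        elif "_rev_" in name:
            rev_reads.update(batch.keys())
    return reads, fwd_reads, rev_reads
```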
@@ -1130,57 +1684,69 @@ def modkit_extract_to_adata(
  else:
  read_mapping_direction.append("unk")

- temp_adata.obs["Read_mapping_direction"] = read_mapping_direction
+ temp_adata.obs[READ_MAPPING_DIRECTION] = read_mapping_direction

  del temp_df

- # Initialize NumPy arrays
- sequence_length = (
- one_hot_reads[read_names[0]].reshape(n_rows_OHE, -1).shape[1]
+ padding_value = MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[
+ MODKIT_EXTRACT_SEQUENCE_PADDING_BASE
+ ]
+ sequence_length = encoded_reads[read_names[0]].shape[0]
+ encoded_matrix = np.full(
+ (len(sorted_index), sequence_length),
+ padding_value,
+ dtype=np.int16,
  )
- df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_T = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_N = np.zeros((len(sorted_index), sequence_length), dtype=int)
-
- # Process one-hot data into dictionaries
- dict_A, dict_C, dict_G, dict_T, dict_N = {}, {}, {}, {}, {}
- for read_name, one_hot_array in one_hot_reads.items():
- one_hot_array = one_hot_array.reshape(n_rows_OHE, -1)
- dict_A[read_name] = one_hot_array[0, :]
- dict_C[read_name] = one_hot_array[1, :]
- dict_G[read_name] = one_hot_array[2, :]
- dict_T[read_name] = one_hot_array[3, :]
- dict_N[read_name] = one_hot_array[4, :]
-
- del one_hot_reads
- gc.collect()

- # Fill the arrays
  for j, read_name in tqdm(
  enumerate(sorted_index),
- desc="Loading dataframes of OHE reads",
+ desc="Loading integer-encoded reads",
  total=len(sorted_index),
  ):
- df_A[j, :] = dict_A[read_name]
- df_C[j, :] = dict_C[read_name]
- df_G[j, :] = dict_G[read_name]
- df_T[j, :] = dict_T[read_name]
- df_N[j, :] = dict_N[read_name]
+ encoded_matrix[j, :] = encoded_reads[read_name]

- del dict_A, dict_C, dict_G, dict_T, dict_N
+ del encoded_reads
  gc.collect()

- # Store the results in AnnData layers
- ohe_df_map = {0: df_A, 1: df_C, 2: df_G, 3: df_T, 4: df_N}
- for j, base in enumerate(["A", "C", "G", "T", "N"]):
- temp_adata.layers[f"{base}_binary_sequence_encoding"] = (
- ohe_df_map[j]
+ temp_adata.layers[SEQUENCE_INTEGER_ENCODING] = encoded_matrix
+ if mismatch_reads:
+ current_reference_length = reference_dict[record][0]
+ default_mismatch_sequence = np.full(
+ sequence_length,
+ MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["N"],
+ dtype=np.int16,
  )
- ohe_df_map[j] = (
- None # Reassign pointer for memory usage purposes
+ if current_reference_length < sequence_length:
+ default_mismatch_sequence[current_reference_length:] = (
+ padding_value
+ )
+ mismatch_matrix = np.vstack(
+ [
+ mismatch_reads.get(read_name, default_mismatch_sequence)
+ for read_name in sorted_index
+ ]
+ )
+ temp_adata.layers[MISMATCH_INTEGER_ENCODING] = mismatch_matrix
+ if quality_reads:
+ default_quality_sequence = np.full(
+ sequence_length, -1, dtype=np.int16
+ )
+ quality_matrix = np.vstack(
+ [
+ quality_reads.get(read_name, default_quality_sequence)
+ for read_name in sorted_index
+ ]
  )
+ temp_adata.layers[BASE_QUALITY_SCORES] = quality_matrix
+ if read_span_reads:
+ default_read_span = np.zeros(sequence_length, dtype=np.int16)
+ read_span_matrix = np.vstack(
+ [
+ read_span_reads.get(read_name, default_read_span)
+ for read_name in sorted_index
+ ]
+ )
+ temp_adata.layers[READ_SPAN_MASK] = read_span_matrix

  # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
  if adata:
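The hunk above builds one padded int16 matrix per (sample, record) instead of five per-base binary matrices, and adds optional mismatch, base-quality, and read-span layers that fall back to per-layer default rows for reads missing from a batch. A compact sketch of that fill-and-stack pattern with toy data follows; the integer code used for padding is a placeholder, since the real value comes from `MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[MODKIT_EXTRACT_SEQUENCE_PADDING_BASE]`.

```python
import numpy as np

# Hypothetical per-read arrays keyed by read id, all aligned to one reference window.
PAD = 5  # assumed integer code for the padding base
encoded_reads = {
    "read1": np.array([0, 1, 2, 3], dtype=np.int16),
    "read2": np.array([3, 2, 1, 0], dtype=np.int16),
}
quality_reads = {"read1": np.array([30, 31, 12, 40], dtype=np.int16)}  # read2 has no quality batch

sorted_index = ["read1", "read2"]
sequence_length = next(iter(encoded_reads.values())).shape[0]

# Pre-fill with the padding code, then overwrite row by row.
encoded_matrix = np.full((len(sorted_index), sequence_length), PAD, dtype=np.int16)
for j, read_name in enumerate(sorted_index):
    encoded_matrix[j, :] = encoded_reads[read_name]

# Optional layers fall back to a per-layer default when a read is missing.
default_quality = np.full(sequence_length, -1, dtype=np.int16)
quality_matrix = np.vstack(
    [quality_reads.get(read_name, default_quality) for read_name in sorted_index]
)

print(encoded_matrix)
print(quality_matrix)  # read2's row is all -1
```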
@@ -1273,8 +1839,14 @@ def modkit_extract_to_adata(
  for col in final_adata.obs.columns:
  final_adata.obs[col] = final_adata.obs[col].astype("category")

- ohe_bases = ["A", "C", "G", "T"] # ignore N bases for consensus
- ohe_layers = [f"{ohe_base}_binary_sequence_encoding" for ohe_base in ohe_bases]
+ final_adata.uns[f"{SEQUENCE_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+ final_adata.uns[f"{MISMATCH_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+ final_adata.uns[f"{SEQUENCE_INTEGER_DECODING}_map"] = {
+ str(key): value for key, value in MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.items()
+ }
+
+ consensus_bases = MODKIT_EXTRACT_SEQUENCE_BASES[:4] # ignore N/PAD for consensus
+ consensus_base_ints = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in consensus_bases]
  final_adata.uns["References"] = {}
  for record in records_to_analyze:
  # Add FASTA sequence to the object
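The hunk above registers base-to-integer and integer-to-base maps in `final_adata.uns` (with integer keys stringified), so downstream code can decode the integer layers without importing the constants. A sketch of how a consumer might use such a `_map` entry to turn one row of the sequence layer back into a string; the key names mirror the constants referenced above, while the example integer codes and the "PAD" label are placeholders rather than the package's actual values.

```python
import numpy as np

# Placeholder map shaped like the *_map entries written to final_adata.uns;
# the real values come from MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.
int_to_base = {"0": "A", "1": "C", "2": "G", "3": "T", "4": "N", "5": "PAD"}

def decode_row(encoded_row: np.ndarray, decoding_map: dict[str, str]) -> str:
    """Decode one integer-encoded row back into a base string, dropping padding."""
    bases = [decoding_map[str(int(value))] for value in encoded_row]
    return "".join(base for base in bases if base != "PAD")

row = np.array([0, 1, 3, 3, 5, 5], dtype=np.int16)
print(decode_row(row, int_to_base))  # "ACTT"
```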
@@ -1285,27 +1857,33 @@ def modkit_extract_to_adata(
  final_adata.uns[f"{record}_FASTA_sequence"] = sequence
  final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
  # Add consensus sequence of samples mapped to the record to the object
- record_subset = final_adata[final_adata.obs["Reference"] == record]
- for strand in record_subset.obs["Strand"].cat.categories:
- strand_subset = record_subset[record_subset.obs["Strand"] == strand]
- for mapping_dir in strand_subset.obs["Read_mapping_direction"].cat.categories:
+ record_subset = final_adata[final_adata.obs[REFERENCE] == record]
+ for strand in record_subset.obs[STRAND].cat.categories:
+ strand_subset = record_subset[record_subset.obs[STRAND] == strand]
+ for mapping_dir in strand_subset.obs[READ_MAPPING_DIRECTION].cat.categories:
  mapping_dir_subset = strand_subset[
- strand_subset.obs["Read_mapping_direction"] == mapping_dir
+ strand_subset.obs[READ_MAPPING_DIRECTION] == mapping_dir
+ ]
+ encoded_sequences = mapping_dir_subset.layers[SEQUENCE_INTEGER_ENCODING]
+ layer_counts = [
+ np.sum(encoded_sequences == base_int, axis=0)
+ for base_int in consensus_base_ints
  ]
- layer_map, layer_counts = {}, []
- for i, layer in enumerate(ohe_layers):
- layer_map[i] = layer.split("_")[0]
- layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
  count_array = np.array(layer_counts)
  nucleotide_indexes = np.argmax(count_array, axis=0)
- consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
+ consensus_sequence_list = [consensus_bases[i] for i in nucleotide_indexes]
+ no_calls_mask = np.sum(count_array, axis=0) == 0
+ if np.any(no_calls_mask):
+ consensus_sequence_list = np.array(consensus_sequence_list, dtype=object)
+ consensus_sequence_list[no_calls_mask] = "N"
+ consensus_sequence_list = consensus_sequence_list.tolist()
  final_adata.var[
  f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
  ] = consensus_sequence_list

  if input_already_demuxed:
- final_adata.obs["demux_type"] = ["already"] * final_adata.shape[0]
- final_adata.obs["demux_type"] = final_adata.obs["demux_type"].astype("category")
+ final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
+ final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
  else:
  from .h5ad_functions import add_demux_type_annotation
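The consensus step in the hunk above replaces per-base one-hot layers with direct equality counts against the integer-encoded sequence layer, and now writes "N" wherever no read contributes an A/C/G/T call. A small self-contained sketch of the same counting, argmax, and no-call masking, using placeholder integer codes since the real map lives in `MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT`:

```python
import numpy as np

# Toy integer-encoded reads (rows) over 5 reference positions, with assumed
# codes A=0, C=1, G=2, T=3 and 5 for padding.
encoded_sequences = np.array(
    [
        [0, 1, 2, 5, 5],
        [0, 1, 3, 5, 5],
        [0, 3, 2, 5, 5],
    ],
    dtype=np.int16,
)
consensus_bases = ["A", "C", "G", "T"]
consensus_base_ints = [0, 1, 2, 3]

# Count how many reads support each base at each position, take the winner,
# and blank out positions where no read made any A/C/G/T call.
count_array = np.array(
    [np.sum(encoded_sequences == base_int, axis=0) for base_int in consensus_base_ints]
)
consensus = np.array(consensus_bases, dtype=object)[np.argmax(count_array, axis=0)]
consensus[np.sum(count_array, axis=0) == 0] = "N"
print("".join(consensus))  # "ACGNN"
```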