smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +32 -6
- smftools/cli/hmm_adata.py +232 -31
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +77 -73
- smftools/cli/preprocess_adata.py +178 -53
- smftools/cli/spatial_adata.py +149 -101
- smftools/cli_entry.py +12 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +38 -1
- smftools/config/experiment_config.py +53 -1
- smftools/constants.py +65 -0
- smftools/hmm/HMM.py +88 -0
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/converted_BAM_to_adata.py +584 -163
- smftools/informatics/h5ad_functions.py +115 -2
- smftools/informatics/modkit_extract_to_adata.py +1003 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +9 -0
- smftools/plotting/general_plotting.py +2411 -628
- smftools/plotting/hmm_plotting.py +85 -7
- smftools/preprocessing/__init__.py +1 -0
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +4 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +91 -8
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,13 +4,48 @@ import concurrent.futures
 import gc
 import re
 import shutil
+from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Iterable, Mapping, Optional, Union

 import numpy as np
 import pandas as pd
 from tqdm import tqdm

+from smftools.constants import (
+    BARCODE,
+    BASE_QUALITY_SCORES,
+    DATASET,
+    DEMUX_TYPE,
+    H5_DIR,
+    MISMATCH_INTEGER_ENCODING,
+    MODKIT_EXTRACT_CALL_CODE_CANONICAL,
+    MODKIT_EXTRACT_CALL_CODE_MODIFIED,
+    MODKIT_EXTRACT_MODIFIED_BASE_A,
+    MODKIT_EXTRACT_MODIFIED_BASE_C,
+    MODKIT_EXTRACT_REF_STRAND_MINUS,
+    MODKIT_EXTRACT_REF_STRAND_PLUS,
+    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    MODKIT_EXTRACT_SEQUENCE_BASES,
+    MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+    MODKIT_EXTRACT_SEQUENCE_PADDING_BASE,
+    MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE,
+    MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB,
+    MODKIT_EXTRACT_TSV_COLUMN_CHROM,
+    MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE,
+    MODKIT_EXTRACT_TSV_COLUMN_READ_ID,
+    MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION,
+    MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND,
+    READ_MAPPING_DIRECTION,
+    READ_SPAN_MASK,
+    REFERENCE,
+    REFERENCE_DATASET_STRAND,
+    REFERENCE_STRAND,
+    SAMPLE,
+    SEQUENCE_INTEGER_DECODING,
+    SEQUENCE_INTEGER_ENCODING,
+    STRAND,
+)
 from smftools.logging_utils import get_logger

 from .bam_functions import count_aligned_reads
@@ -18,8 +53,78 @@ from .bam_functions import count_aligned_reads
 logger = get_logger(__name__)


+@dataclass
+class ModkitBatchDictionaries:
+    """Container for per-batch modification dictionaries.
+
+    Attributes:
+        dict_total: Raw TSV DataFrames keyed by record and sample index.
+        dict_a: Adenine modification DataFrames.
+        dict_a_bottom: Adenine minus-strand DataFrames.
+        dict_a_top: Adenine plus-strand DataFrames.
+        dict_c: Cytosine modification DataFrames.
+        dict_c_bottom: Cytosine minus-strand DataFrames.
+        dict_c_top: Cytosine plus-strand DataFrames.
+        dict_combined_bottom: Combined minus-strand methylation arrays.
+        dict_combined_top: Combined plus-strand methylation arrays.
+    """
+
+    dict_total: dict = field(default_factory=dict)
+    dict_a: dict = field(default_factory=dict)
+    dict_a_bottom: dict = field(default_factory=dict)
+    dict_a_top: dict = field(default_factory=dict)
+    dict_c: dict = field(default_factory=dict)
+    dict_c_bottom: dict = field(default_factory=dict)
+    dict_c_top: dict = field(default_factory=dict)
+    dict_combined_bottom: dict = field(default_factory=dict)
+    dict_combined_top: dict = field(default_factory=dict)
+
+    @property
+    def sample_types(self) -> list[str]:
+        """Return ordered labels for the dictionary list."""
+        return [
+            "total",
+            "m6A",
+            "m6A_bottom_strand",
+            "m6A_top_strand",
+            "5mC",
+            "5mC_bottom_strand",
+            "5mC_top_strand",
+            "combined_bottom_strand",
+            "combined_top_strand",
+        ]
+
+    def as_list(self) -> list[dict]:
+        """Return the dictionaries in the expected list ordering."""
+        return [
+            self.dict_total,
+            self.dict_a,
+            self.dict_a_bottom,
+            self.dict_a_top,
+            self.dict_c,
+            self.dict_c_bottom,
+            self.dict_c_top,
+            self.dict_combined_bottom,
+            self.dict_combined_top,
+        ]
+
+
 def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
-    """
+    """Identify reference records that exceed a mapping threshold in one BAM.
+
+    Args:
+        bam (Path | str): BAM file to inspect.
+        mapping_threshold (float): Minimum fraction of mapped reads required to keep a record.
+        samtools_backend (str | None): Samtools backend selection.
+
+    Returns:
+        set[str]: Record names that pass the mapping threshold.
+
+    Processing Steps:
+        1. Count aligned/unaligned reads per record.
+        2. Compute percent aligned and per-record mapping percentages.
+        3. Return records whose mapping fraction meets the threshold.
+    """
     aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
         bam, samtools_backend
     )
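The new ModkitBatchDictionaries dataclass replaces nine separately-tracked dictionary variables; sample_types and as_list() are meant to stay index-aligned so downstream loops can pair each label with its dictionary. A minimal usage sketch follows — the import path is an assumption inferred from the file list above, since this hunk does not name its module:

from smftools.informatics.modkit_extract_to_adata import ModkitBatchDictionaries  # assumed module path

batch_dicts = ModkitBatchDictionaries()
# Each label in sample_types describes the dictionary at the same index of as_list().
for label, mod_dict in zip(batch_dicts.sample_types, batch_dicts.as_list()):
    print(label, len(mod_dict))  # all nine start empty until TSV batches are loaded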
@@ -40,7 +145,21 @@ def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "a


 def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
-    """
+    """Aggregate mapping-threshold records across BAM files in parallel.
+
+    Args:
+        bam_path_list (list[Path | str]): BAM files to scan.
+        mapping_threshold (float): Minimum fraction of mapped reads required to keep a record.
+        samtools_backend (str | None): Samtools backend selection.
+
+    Returns:
+        set[str]: Union of all record names passing the threshold in any BAM.
+
+    Processing Steps:
+        1. Spawn workers to compute passing records per BAM.
+        2. Merge all passing records into a single set.
+        3. Log the final record set.
+    """
     records_to_analyze = set()

     with concurrent.futures.ProcessPoolExecutor() as executor:
@@ -60,8 +179,21 @@ def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str


 def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):
-    """
-
+    """Load and filter a modkit TSV file for relevant records and positions.
+
+    Args:
+        tsv (Path | str): TSV file produced by modkit extract.
+        records_to_analyze (Iterable[str]): Record names to keep.
+        reference_dict (dict[str, tuple[int, str]]): Mapping of record to (length, sequence).
+        sample_index (int): Sample index to attach to the filtered results.
+
+    Returns:
+        dict[str, dict[int, pd.DataFrame]]: Filtered data keyed by record and sample index.
+
+    Processing Steps:
+        1. Read the TSV into a DataFrame.
+        2. Filter rows for each record to valid reference positions.
+        3. Emit per-record DataFrames keyed by the provided sample index.
     """
     temp_df = pd.read_csv(tsv, sep="\t", header=0)
     filtered_records = {}
@@ -72,9 +204,9 @@ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):

         ref_length = reference_dict[record][0]
         filtered_df = temp_df[
-            (temp_df[
-            & (temp_df[
-            & (temp_df[
+            (temp_df[MODKIT_EXTRACT_TSV_COLUMN_CHROM] == record)
+            & (temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION] >= 0)
+            & (temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION] < ref_length)
         ]

         if not filtered_df.empty:
@@ -84,19 +216,23 @@ def process_tsv(tsv, records_to_analyze, reference_dict, sample_index):


 def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, batch_size, threads=4):
-    """
-    Loads and filters TSV files in parallel.
+    """Load and filter a batch of TSVs in parallel.

-
-        tsv_batch (list):
-        records_to_analyze (
-        reference_dict (dict):
-        batch (int):
-        batch_size (int):
-        threads (int):
+    Args:
+        tsv_batch (list[Path | str]): TSV file paths for the batch.
+        records_to_analyze (Iterable[str]): Record names to keep.
+        reference_dict (dict[str, tuple[int, str]]): Mapping of record to (length, sequence).
+        batch (int): Batch number for progress logging.
+        batch_size (int): Number of TSVs in the batch.
+        threads (int): Parallel worker count.

     Returns:
-        dict:
+        dict[str, dict[int, pd.DataFrame]]: Per-record DataFrames keyed by sample index.
+
+    Processing Steps:
+        1. Submit each TSV to a worker via `process_tsv`.
+        2. Merge per-record outputs into a single dictionary.
+        3. Return the aggregated per-record dictionary for the batch.
     """
     dict_total = {record: {} for record in records_to_analyze}

@@ -121,15 +257,19 @@ def parallel_load_tsvs(tsv_batch, records_to_analyze, reference_dict, batch, bat


 def update_dict_to_skip(dict_to_skip, detected_modifications):
-    """
-    Updates the dict_to_skip set based on the detected modifications.
+    """Update dictionary skip indices based on modifications in the batch.

-
-        dict_to_skip (set):
-        detected_modifications (
+    Args:
+        dict_to_skip (set[int]): Initial set of dictionary indices to skip.
+        detected_modifications (Iterable[str]): Modification labels present (e.g., ["6mA", "5mC"]).

     Returns:
-        set:
+        set[int]: Updated skip set after considering present modifications.
+
+    Processing Steps:
+        1. Define indices for A- and C-stranded dictionaries.
+        2. Remove indices for modifications that are present.
+        3. Return the updated skip set.
     """
     # Define which indices correspond to modification-specific or strand-specific dictionaries
     A_stranded_dicts = {2, 3}  # m6A bottom and top strand dictionaries
@@ -150,31 +290,49 @@ def update_dict_to_skip(dict_to_skip, detected_modifications):


 def process_modifications_for_sample(args):
-    """
-    Processes a single (record, sample) pair to extract modification-specific data.
+    """Extract modification-specific subsets for one record/sample pair.

-
-        args: (record, sample_index, sample_df, mods, max_reference_length)
+    Args:
+        args (tuple): (record, sample_index, sample_df, mods, max_reference_length).

     Returns:
-
-
-
+        tuple[str, int, dict[str, pd.DataFrame | list]]:
+            Record, sample index, and a dict of modification-specific DataFrames
+            (with optional combined placeholders).
+
+    Processing Steps:
+        1. Filter by modified base (A/C) when requested.
+        2. Split filtered rows by strand where needed.
+        3. Add empty combined placeholders when both modifications are present.
     """
     record, sample_index, sample_df, mods, max_reference_length = args
     result = {}
     if "6mA" in mods:
-        m6a_df = sample_df[
+        m6a_df = sample_df[
+            sample_df[MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE]
+            == MODKIT_EXTRACT_MODIFIED_BASE_A
+        ]
         result["m6A"] = m6a_df
-        result["m6A_minus"] = m6a_df[
-
+        result["m6A_minus"] = m6a_df[
+            m6a_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_MINUS
+        ]
+        result["m6A_plus"] = m6a_df[
+            m6a_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_PLUS
+        ]
         m6a_df = None
         gc.collect()
     if "5mC" in mods:
-        m5c_df = sample_df[
+        m5c_df = sample_df[
+            sample_df[MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE]
+            == MODKIT_EXTRACT_MODIFIED_BASE_C
+        ]
         result["5mC"] = m5c_df
-        result["5mC_minus"] = m5c_df[
-
+        result["5mC_minus"] = m5c_df[
+            m5c_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_MINUS
+        ]
+        result["5mC_plus"] = m5c_df[
+            m5c_df[MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND] == MODKIT_EXTRACT_REF_STRAND_PLUS
+        ]
         m5c_df = None
         gc.collect()
     if "6mA" in mods and "5mC" in mods:
@@ -184,11 +342,22 @@ def process_modifications_for_sample(args):


 def parallel_process_modifications(dict_total, mods, max_reference_length, threads=4):
-    """
-
+    """Parallelize modification extraction across records and samples.
+
+    Args:
+        dict_total (dict[str, dict[int, pd.DataFrame]]): Raw TSV DataFrames per record/sample.
+        mods (list[str]): Modification labels to process.
+        max_reference_length (int): Maximum reference length in the dataset.
+        threads (int): Parallel worker count.

     Returns:
-
+        dict[str, dict[int, dict[str, pd.DataFrame | list]]]: Processed results keyed by
+            record and sample index.
+
+    Processing Steps:
+        1. Build a task list of (record, sample) pairs.
+        2. Submit tasks to a process pool.
+        3. Collect and store results in a nested dictionary.
     """
     tasks = []
     for record, sample_dict in dict_total.items():
@@ -208,11 +377,20 @@ def parallel_process_modifications(dict_total, mods, max_reference_length, threa


 def merge_modification_results(processed_results, mods):
-    """
-
+    """Merge per-sample modification outputs into global dictionaries.
+
+    Args:
+        processed_results (dict[str, dict[int, dict]]): Output of parallel modification extraction.
+        mods (list[str]): Modification labels to include.

     Returns:
-
+        tuple[dict, dict, dict, dict, dict, dict, dict, dict]:
+            Global dictionaries for each modification/strand combination.
+
+    Processing Steps:
+        1. Initialize empty output dictionaries per modification category.
+        2. Populate each dictionary using the processed sample results.
+        3. Return the ordered tuple for downstream processing.
     """
     m6A_dict = {}
     m6A_minus = {}
@@ -254,18 +432,18 @@ def merge_modification_results(processed_results, mods):


 def process_stranded_methylation(args):
-    """
-    Processes a single (dict_index, record, sample) task.
-
-    For combined dictionaries (indices 7 or 8), it merges the corresponding A-stranded and C-stranded data.
-    For other dictionaries, it converts the DataFrame into a nested dictionary mapping read names to a
-    NumPy methylation array (of float type). Non-numeric values (e.g. '-') are coerced to NaN.
+    """Convert modification DataFrames into per-read methylation arrays.

-
-        args: (dict_index, record, sample, dict_list, max_reference_length)
+    Args:
+        args (tuple): (dict_index, record, sample, dict_list, max_reference_length).

     Returns:
-
+        tuple[int, str, int, dict[str, np.ndarray]]: Updated dictionary entries for the task.
+
+    Processing Steps:
+        1. For combined dictionaries (indices 7/8), merge A- and C-strand arrays.
+        2. For other dictionaries, compute methylation probabilities per read/position.
+        3. Return per-read arrays keyed by read name.
     """
     dict_index, record, sample, dict_list, max_reference_length = args
     processed_data = {}
@@ -329,13 +507,15 @@ def process_stranded_methylation(args):
         temp_df = dict_list[dict_index][record][sample]
         processed_data = {}
         # Extract columns and convert probabilities to float (coercing errors)
-        read_ids = temp_df[
-        positions = temp_df[
-        call_codes = temp_df[
-        probabilities = pd.to_numeric(
+        read_ids = temp_df[MODKIT_EXTRACT_TSV_COLUMN_READ_ID].values
+        positions = temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION].values
+        call_codes = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE].values
+        probabilities = pd.to_numeric(
+            temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB].values, errors="coerce"
+        )

-        modified_codes =
-        canonical_codes =
+        modified_codes = MODKIT_EXTRACT_CALL_CODE_MODIFIED
+        canonical_codes = MODKIT_EXTRACT_CALL_CODE_CANONICAL

         # Compute methylation probabilities (vectorized)
         methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)
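This hunk replaces hard-coded column names with the MODKIT_EXTRACT_TSV_COLUMN_* constants while keeping the same extraction pattern: pull the columns as arrays, coerce probabilities to float, and start from an all-NaN methylation array. A small self-contained sketch of that coercion/initialization pattern — the column names and values below are invented for illustration and are not the package's constants:

import numpy as np
import pandas as pd

# Hypothetical stand-in for one filtered modkit extract DataFrame.
temp_df = pd.DataFrame(
    {"read_id": ["r1", "r2"], "ref_position": [3, 7], "call_code": ["a", "-"], "call_prob": ["0.91", "-"]}
)
probabilities = pd.to_numeric(temp_df["call_prob"].values, errors="coerce")  # non-numeric "-" becomes NaN
methylation_prob = np.full(probabilities.shape, np.nan, dtype=float)  # positions stay NaN until a call fills them
print(probabilities)      # [0.91  nan]
print(methylation_prob)   # [ nan  nan]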
@@ -363,11 +543,21 @@ def process_stranded_methylation(args):


 def parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=4):
-    """
-
+    """Parallelize per-read methylation extraction over all dictionary entries.
+
+    Args:
+        dict_list (list[dict]): List of modification/strand dictionaries.
+        dict_to_skip (set[int]): Dictionary indices to exclude from processing.
+        max_reference_length (int): Maximum reference length for array sizing.
+        threads (int): Parallel worker count.

     Returns:
-        Updated
+        list[dict]: Updated dictionary list with per-read methylation arrays.
+
+    Processing Steps:
+        1. Build tasks for every (dict_index, record, sample) to process.
+        2. Execute tasks in a process pool.
+        3. Replace DataFrames with per-read arrays in-place.
     """
     tasks = []
     for dict_index, current_dict in enumerate(dict_list):
@@ -393,21 +583,20 @@ def delete_intermediate_h5ads_and_tmpdir(
     dry_run: bool = False,
     verbose: bool = True,
 ):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-        Print progress / warnings.
+    """Delete intermediate .h5ad files and optionally a temporary directory.
+
+    Args:
+        h5_dir (str | Path | Iterable[str] | None): Directory or iterable of h5ad paths.
+        tmp_dir (str | Path | None): Temporary directory to remove recursively.
+        dry_run (bool): If True, log deletions without performing them.
+        verbose (bool): If True, log progress and warnings.
+
+    Returns:
+        None: This function performs deletions in-place.
+
+    Processing Steps:
+        1. Iterate over .h5ad file candidates and delete them (if not dry-run).
+        2. Remove the temporary directory tree if requested.
     """

     # Helper: remove a single file path (Path-like or string)
@@ -478,6 +667,455 @@ def delete_intermediate_h5ads_and_tmpdir(
         logger.warning(f"[error] failed to remove tmp dir {td}: {e}")


+def _collect_input_paths(mod_tsv_dir: Path, bam_dir: Path) -> tuple[list[Path], list[Path]]:
+    """Collect sorted TSV and BAM paths for processing.
+
+    Args:
+        mod_tsv_dir (Path): Directory containing modkit extract TSVs.
+        bam_dir (Path): Directory containing aligned BAM files.
+
+    Returns:
+        tuple[list[Path], list[Path]]: Sorted TSV paths and BAM paths.
+
+    Processing Steps:
+        1. Filter TSVs for extract outputs and exclude unclassified entries.
+        2. Filter BAMs for aligned files and exclude indexes/unclassified entries.
+        3. Sort both lists for deterministic processing.
+    """
+    tsvs = sorted(
+        p
+        for p in mod_tsv_dir.iterdir()
+        if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
+    )
+    bams = sorted(
+        p
+        for p in bam_dir.iterdir()
+        if p.is_file()
+        and p.suffix == ".bam"
+        and "unclassified" not in p.name
+        and ".bai" not in p.name
+    )
+    return tsvs, bams
+
+
+def _build_sample_maps(bam_path_list: list[Path]) -> tuple[dict[int, str], dict[int, str]]:
+    """Build sample name and barcode maps from BAM filenames.
+
+    Args:
+        bam_path_list (list[Path]): Paths to BAM files in sample order.
+
+    Returns:
+        tuple[dict[int, str], dict[int, str]]: Maps of sample index to sample name and barcode.
+
+    Processing Steps:
+        1. Parse the BAM stem for barcode suffixes.
+        2. Build a standardized sample name with barcode suffix.
+        3. Store mappings for downstream metadata annotations.
+    """
+    sample_name_map: dict[int, str] = {}
+    barcode_map: dict[int, str] = {}
+
+    for idx, bam_path in enumerate(bam_path_list):
+        stem = bam_path.stem
+        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
+        if m:
+            sample_name = m.group(1) or stem
+            barcode = m.group(2)
+        else:
+            sample_name = stem
+            barcode = stem
+
+        sample_name = f"{sample_name}_{barcode}"
+        barcode_id = int(barcode.split("barcode")[1])
+
+        sample_name_map[idx] = sample_name
+        barcode_map[idx] = str(barcode_id)
+
+    return sample_name_map, barcode_map
+
+
+def _encode_sequence_array(
+    read_sequence: np.ndarray,
+    valid_length: int,
+    base_to_int: Mapping[str, int],
+    padding_value: int,
+) -> np.ndarray:
+    """Convert a base-identity array into integer encoding with padding.
+
+    Args:
+        read_sequence (np.ndarray): Array of base calls (dtype "<U1").
+        valid_length (int): Number of valid reference positions for this record.
+        base_to_int (Mapping[str, int]): Base-to-integer mapping for A/C/G/T/N/PAD.
+        padding_value (int): Integer value to use for padding.
+
+    Returns:
+        np.ndarray: Integer-encoded sequence with padding applied.
+
+    Processing Steps:
+        1. Initialize an integer array filled with the N value.
+        2. Overwrite values for known bases (A/C/G/T/N).
+        3. Replace positions beyond valid_length with padding.
+    """
+    read_sequence = np.asarray(read_sequence, dtype="<U1")
+    encoded = np.full(read_sequence.shape, base_to_int["N"], dtype=np.int16)
+    for base in MODKIT_EXTRACT_SEQUENCE_BASES:
+        encoded[read_sequence == base] = base_to_int[base]
+    if valid_length < encoded.size:
+        encoded[valid_length:] = padding_value
+    return encoded
+
+
+def _write_sequence_batches(
+    base_identities: Mapping[str, np.ndarray],
+    tmp_dir: Path,
+    record: str,
+    prefix: str,
+    base_to_int: Mapping[str, int],
+    valid_length: int,
+    batch_size: int,
+) -> list[str]:
+    """Encode base identities into integer arrays and write batched H5AD files.
+
+    Args:
+        base_identities (Mapping[str, np.ndarray]): Read name to base identity arrays.
+        tmp_dir (Path): Directory for temporary H5AD files.
+        record (str): Reference record identifier.
+        prefix (str): Prefix used to name batch files.
+        base_to_int (Mapping[str, int]): Base-to-integer mapping.
+        valid_length (int): Valid reference length for padding determination.
+        batch_size (int): Number of reads per H5AD batch file.
+
+    Returns:
+        list[str]: Paths to written H5AD batch files.
+
+    Processing Steps:
+        1. Encode each read sequence to integer values.
+        2. Accumulate encoded reads into batches.
+        3. Persist each batch as an H5AD with the dictionary stored in `.uns`.
+    """
+    import anndata as ad
+
+    padding_value = base_to_int[MODKIT_EXTRACT_SEQUENCE_PADDING_BASE]
+    batch_files: list[str] = []
+    batch: dict[str, np.ndarray] = {}
+    batch_number = 0
+
+    for read_name, sequence in base_identities.items():
+        if sequence is None:
+            continue
+        batch[read_name] = _encode_sequence_array(
+            sequence, valid_length, base_to_int, padding_value
+        )
+        if len(batch) >= batch_size:
+            save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+            batch_files.append(str(save_name))
+            batch = {}
+            batch_number += 1
+
+    if batch:
+        save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+        batch_files.append(str(save_name))
+
+    return batch_files
+
+
+def _write_integer_batches(
+    sequences: Mapping[str, np.ndarray],
+    tmp_dir: Path,
+    record: str,
+    prefix: str,
+    batch_size: int,
+) -> list[str]:
+    """Write integer-encoded sequences into batched H5AD files.
+
+    Args:
+        sequences (Mapping[str, np.ndarray]): Read name to integer arrays.
+        tmp_dir (Path): Directory for temporary H5AD files.
+        record (str): Reference record identifier.
+        prefix (str): Prefix used to name batch files.
+        batch_size (int): Number of reads per H5AD batch file.
+
+    Returns:
+        list[str]: Paths to written H5AD batch files.
+
+    Processing Steps:
+        1. Accumulate integer arrays into batches.
+        2. Persist each batch as an H5AD with the dictionary stored in `.uns`.
+    """
+    import anndata as ad
+
+    batch_files: list[str] = []
+    batch: dict[str, np.ndarray] = {}
+    batch_number = 0
+
+    for read_name, sequence in sequences.items():
+        if sequence is None:
+            continue
+        batch[read_name] = np.asarray(sequence, dtype=np.int16)
+        if len(batch) >= batch_size:
+            save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+            ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+            batch_files.append(str(save_name))
+            batch = {}
+            batch_number += 1
+
+    if batch:
+        save_name = tmp_dir / f"tmp_{prefix}_{record}_{batch_number}.h5ad"
+        ad.AnnData(X=np.zeros((1, 1)), uns=batch).write_h5ad(save_name)
+        batch_files.append(str(save_name))
+
+    return batch_files
+
+
+def _load_sequence_batches(
+    batch_files: list[Path | str],
+) -> tuple[dict[str, np.ndarray], set[str], set[str]]:
+    """Load integer-encoded sequence batches from H5AD files.
+
+    Args:
+        batch_files (list[Path | str]): H5AD paths containing encoded sequences in `.uns`.
+
+    Returns:
+        tuple[dict[str, np.ndarray], set[str], set[str]]:
+            Read-to-sequence mapping and sets of forward/reverse mapped reads.
+
+    Processing Steps:
+        1. Read each H5AD file.
+        2. Merge `.uns` dictionaries into a single mapping.
+        3. Track forward/reverse read IDs based on the filename marker.
+    """
+    import anndata as ad
+
+    sequences: dict[str, np.ndarray] = {}
+    fwd_reads: set[str] = set()
+    rev_reads: set[str] = set()
+    for batch_file in batch_files:
+        batch_path = Path(batch_file)
+        batch_sequences = ad.read_h5ad(batch_path).uns
+        sequences.update(batch_sequences)
+        if "_fwd_" in batch_path.name:
+            fwd_reads.update(batch_sequences.keys())
+        elif "_rev_" in batch_path.name:
+            rev_reads.update(batch_sequences.keys())
+    return sequences, fwd_reads, rev_reads
+
+
+def _load_integer_batches(batch_files: list[Path | str]) -> dict[str, np.ndarray]:
+    """Load integer arrays from batched H5AD files.
+
+    Args:
+        batch_files (list[Path | str]): H5AD paths containing arrays in `.uns`.
+
+    Returns:
+        dict[str, np.ndarray]: Read-to-array mapping.
+
+    Processing Steps:
+        1. Read each H5AD file.
+        2. Merge `.uns` dictionaries into a single mapping.
+    """
+    import anndata as ad
+
+    sequences: dict[str, np.ndarray] = {}
+    for batch_file in batch_files:
+        batch_path = Path(batch_file)
+        sequences.update(ad.read_h5ad(batch_path).uns)
+    return sequences
+
+
+def _normalize_sequence_batch_files(batch_files: object) -> list[Path]:
+    """Normalize cached batch file entries into a list of Paths.
+
+    Args:
+        batch_files (object): Cached batch file entry from AnnData `.uns`.
+
+    Returns:
+        list[Path]: Paths to batch files, filtered to non-empty values.
+
+    Processing Steps:
+        1. Convert numpy arrays and scalars into Python lists.
+        2. Filter out empty/placeholder values.
+        3. Cast remaining entries to Path objects.
+    """
+    if batch_files is None:
+        return []
+    if isinstance(batch_files, np.ndarray):
+        batch_files = batch_files.tolist()
+    if isinstance(batch_files, (str, Path)):
+        batch_files = [batch_files]
+    if not isinstance(batch_files, list):
+        batch_files = list(batch_files)
+    normalized: list[Path] = []
+    for entry in batch_files:
+        if entry is None:
+            continue
+        entry_str = str(entry).strip()
+        if not entry_str or entry_str == ".":
+            continue
+        normalized.append(Path(entry_str))
+    return normalized
+
+
+def _build_modification_dicts(
+    dict_total: dict,
+    mods: list[str],
+) -> tuple[ModkitBatchDictionaries, set[int]]:
+    """Build modification/strand dictionaries from the raw TSV batch dictionary.
+
+    Args:
+        dict_total (dict): Raw TSV DataFrames keyed by record and sample index.
+        mods (list[str]): Modification labels to include (e.g., ["6mA", "5mC"]).
+
+    Returns:
+        tuple[ModkitBatchDictionaries, set[int]]: Batch dictionaries and indices to skip.
+
+    Processing Steps:
+        1. Initialize modification dictionaries and skip-set.
+        2. Filter TSV rows per record/sample into modification and strand subsets.
+        3. Populate combined dict placeholders when both modifications are present.
+    """
+    batch_dicts = ModkitBatchDictionaries(dict_total=dict_total)
+    dict_to_skip = {0, 1, 4}
+    combined_dicts = {7, 8}
+    A_stranded_dicts = {2, 3}
+    C_stranded_dicts = {5, 6}
+    dict_to_skip.update(combined_dicts | A_stranded_dicts | C_stranded_dicts)
+
+    for record in dict_total.keys():
+        for sample_index in dict_total[record].keys():
+            if "6mA" in mods:
+                dict_to_skip.difference_update(A_stranded_dicts)
+                if (
+                    record not in batch_dicts.dict_a.keys()
+                    and record not in batch_dicts.dict_a_bottom.keys()
+                    and record not in batch_dicts.dict_a_top.keys()
+                ):
+                    (
+                        batch_dicts.dict_a[record],
+                        batch_dicts.dict_a_bottom[record],
+                        batch_dicts.dict_a_top[record],
+                    ) = ({}, {}, {})
+
+                batch_dicts.dict_a[record][sample_index] = dict_total[record][sample_index][
+                    dict_total[record][sample_index][
+                        MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE
+                    ]
+                    == MODKIT_EXTRACT_MODIFIED_BASE_A
+                ]
+                logger.debug(
+                    "Successfully loaded a methyl-adenine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_a_bottom[record][sample_index] = batch_dicts.dict_a[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_a[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_MINUS
+                ]
+                logger.debug(
+                    "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_a_top[record][sample_index] = batch_dicts.dict_a[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_a[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_PLUS
+                ]
+                logger.debug(
+                    "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_a[record][sample_index] = None
+                gc.collect()
+
+            if "5mC" in mods:
+                dict_to_skip.difference_update(C_stranded_dicts)
+                if (
+                    record not in batch_dicts.dict_c.keys()
+                    and record not in batch_dicts.dict_c_bottom.keys()
+                    and record not in batch_dicts.dict_c_top.keys()
+                ):
+                    (
+                        batch_dicts.dict_c[record],
+                        batch_dicts.dict_c_bottom[record],
+                        batch_dicts.dict_c_top[record],
+                    ) = ({}, {}, {})
+
+                batch_dicts.dict_c[record][sample_index] = dict_total[record][sample_index][
+                    dict_total[record][sample_index][
+                        MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE
+                    ]
+                    == MODKIT_EXTRACT_MODIFIED_BASE_C
+                ]
+                logger.debug(
+                    "Successfully loaded a methyl-cytosine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_c_bottom[record][sample_index] = batch_dicts.dict_c[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_c[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_MINUS
+                ]
+                logger.debug(
+                    "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_c_top[record][sample_index] = batch_dicts.dict_c[record][
+                    sample_index
+                ][
+                    batch_dicts.dict_c[record][sample_index][MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND]
+                    == MODKIT_EXTRACT_REF_STRAND_PLUS
+                ]
+                logger.debug(
+                    "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+
+                batch_dicts.dict_c[record][sample_index] = None
+                gc.collect()
+
+            if "6mA" in mods and "5mC" in mods:
+                dict_to_skip.difference_update(combined_dicts)
+                if (
+                    record not in batch_dicts.dict_combined_bottom.keys()
+                    and record not in batch_dicts.dict_combined_top.keys()
+                ):
+                    (
+                        batch_dicts.dict_combined_bottom[record],
+                        batch_dicts.dict_combined_top[record],
+                    ) = ({}, {})
+
+                logger.debug(
+                    "Successfully created a minus strand combined methylation dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_combined_bottom[record][sample_index] = []
+                logger.debug(
+                    "Successfully created a plus strand combined methylation dictionary for {}".format(
+                        str(sample_index)
+                    )
+                )
+                batch_dicts.dict_combined_top[record][sample_index] = []
+
+            dict_total[record][sample_index] = None
+            gc.collect()
+
+    return batch_dicts, dict_to_skip
+
+
 def modkit_extract_to_adata(
     fasta,
     bam_dir,
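The integer sequence encoding performed by _encode_sequence_array can be illustrated on its own. The mapping values below are placeholders chosen for the example, not the package's MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT constants:

import numpy as np

base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}  # example values only
read_sequence = np.array(list("ACGTN?"), dtype="<U1")
valid_length = 4  # positions at or beyond this index are treated as padding

encoded = np.full(read_sequence.shape, base_to_int["N"], dtype=np.int16)  # unknown characters default to N
for base in "ACGTN":
    encoded[read_sequence == base] = base_to_int[base]
encoded[valid_length:] = base_to_int["PAD"]  # padding overrides whatever was encoded there
print(encoded)  # [0 1 2 3 5 5]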
@@ -493,24 +1131,32 @@ def modkit_extract_to_adata(
     double_barcoded_path=None,
     samtools_backend: str | None = "auto",
 ):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-        double_barcoded_path (Path):
+    """Convert modkit extract TSVs and BAMs into an AnnData object.
+
+    Args:
+        fasta (Path): Reference FASTA path.
+        bam_dir (Path): Directory with aligned BAM files.
+        out_dir (Path): Output directory for intermediate and final H5ADs.
+        input_already_demuxed (bool): Whether reads were already demultiplexed.
+        mapping_threshold (float): Minimum fraction of mapped reads to keep a record.
+        experiment_name (str): Experiment name used in output file naming.
+        mods (list[str]): Modification labels to analyze (e.g., ["6mA", "5mC"]).
+        batch_size (int): Number of TSVs to process per batch.
+        mod_tsv_dir (Path): Directory containing modkit extract TSVs.
+        delete_batch_hdfs (bool): Remove batch H5ADs after concatenation.
+        threads (int | None): Thread count for parallel operations.
+        double_barcoded_path (Path | None): Dorado demux summary directory for double barcodes.
+        samtools_backend (str | None): Samtools backend selection.

     Returns:
-
+        tuple[ad.AnnData | None, Path]: The final AnnData (if created) and its H5AD path.
+
+    Processing Steps:
+        1. Discover input TSV/BAM files and derive sample metadata.
+        2. Identify records that pass mapping thresholds and build reference metadata.
+        3. Encode read sequences into integer arrays and cache them.
+        4. Process TSV batches into per-read methylation matrices.
+        5. Concatenate batch H5ADs into a final AnnData with consensus sequences.
     """
     ###################################################
     # Package imports
@@ -527,12 +1173,11 @@ def modkit_extract_to_adata(
     from ..readwrite import make_dirs
     from .bam_functions import extract_base_identities
     from .fasta_functions import get_native_references
-    from .ohe import ohe_batching
     ###################################################

     ################## Get input tsv and bam file names into a sorted list ################
     # Make output dirs
-    h5_dir = out_dir /
+    h5_dir = out_dir / H5_DIR
     tmp_dir = out_dir / "tmp"
     make_dirs([h5_dir, tmp_dir])
@@ -546,55 +1191,14 @@ def modkit_extract_to_adata(
         logger.debug(f"{final_adata_path} already exists. Using existing adata")
         return final_adata, final_adata_path

-
-
-
-        for p in mod_tsv_dir.iterdir()
-        if p.is_file() and "unclassified" not in p.name and "extract.tsv" in p.name
-    )
-    bams = sorted(
-        p
-        for p in bam_dir.iterdir()
-        if p.is_file()
-        and p.suffix == ".bam"
-        and "unclassified" not in p.name
-        and ".bai" not in p.name
-    )
-
-    tsv_path_list = [tsv for tsv in tsvs]
-    bam_path_list = [bam for bam in bams]
+    tsvs, bams = _collect_input_paths(mod_tsv_dir, bam_dir)
+    tsv_path_list = list(tsvs)
+    bam_path_list = list(bams)
     logger.info(f"{len(tsvs)} sample tsv files found: {tsvs}")
     logger.info(f"{len(bams)} sample bams found: {bams}")

     # Map global sample index (bami / final_sample_index) -> sample name / barcode
-    sample_name_map =
-    barcode_map = {}
-
-    for idx, bam_path in enumerate(bam_path_list):
-        stem = bam_path.stem
-
-        # Try to peel off a "barcode..." suffix if present.
-        # This handles things like:
-        # "mySample_barcode01" -> sample="mySample", barcode="barcode01"
-        # "run1-s1_barcode05" -> sample="run1-s1", barcode="barcode05"
-        # "barcode01" -> sample="barcode01", barcode="barcode01"
-        m = re.search(r"^(.*?)[_\-\.]?(barcode[0-9A-Za-z\-]+)$", stem)
-        if m:
-            sample_name = m.group(1) or stem
-            barcode = m.group(2)
-        else:
-            # Fallback: treat the whole stem as both sample & barcode
-            sample_name = stem
-            barcode = stem
-
-        # make sample name of the format of the bam file stem
-        sample_name = sample_name + f"_{barcode}"
-
-        # Clean the barcode name to be an integer
-        barcode = int(barcode.split("barcode")[1])
-
-        sample_name_map[idx] = sample_name
-        barcode_map[idx] = str(barcode)
+    sample_name_map, barcode_map = _build_sample_maps(bam_path_list)
     ##########################################################################################

     ######### Get Record names that have over a passed threshold of mapped reads #############
@@ -619,59 +1223,154 @@ def modkit_extract_to_adata(
     ##########################################################################################

     ##########################################################################################
-    #
-
-
-
-
-
-
-    if
-
-
-
-
+    # Encode read sequences into integer arrays and cache in tmp_dir.
+    sequence_batch_files: dict[str, list[str]] = {}
+    mismatch_batch_files: dict[str, list[str]] = {}
+    quality_batch_files: dict[str, list[str]] = {}
+    read_span_batch_files: dict[str, list[str]] = {}
+    sequence_cache_path = tmp_dir / "tmp_sequence_int_file_dict.h5ad"
+    cache_needs_rebuild = True
+    if sequence_cache_path.exists():
+        cached_uns = ad.read_h5ad(sequence_cache_path).uns
+        if "sequence_batch_files" in cached_uns:
+            sequence_batch_files = cached_uns.get("sequence_batch_files", {})
+            mismatch_batch_files = cached_uns.get("mismatch_batch_files", {})
+            quality_batch_files = cached_uns.get("quality_batch_files", {})
+            read_span_batch_files = cached_uns.get("read_span_batch_files", {})
+            cache_needs_rebuild = not (
+                quality_batch_files and read_span_batch_files and sequence_batch_files
+            )
+        else:
+            sequence_batch_files = cached_uns
+            cache_needs_rebuild = True
+        if cache_needs_rebuild:
+            logger.info(
+                "Cached sequence batches missing quality or read-span data; rebuilding cache."
+            )
+        else:
+            logger.debug("Found existing integer-encoded reads, using these")
+    if cache_needs_rebuild:
         for bami, bam in enumerate(bam_path_list):
-
+            logger.info(
+                f"Extracting base level sequences, qualities, reference spans, and mismatches per read for bam {bami}"
+            )
             for record in records_to_analyze:
                 current_reference_length = reference_dict[record][0]
                 positions = range(current_reference_length)
                 ref_seq = reference_dict[record][1]
-                # Extract the base identities of reads aligned to the record
                 (
                     fwd_base_identities,
                     rev_base_identities,
-
-
+                    _mismatch_counts_per_read,
+                    _mismatch_trend_per_read,
+                    mismatch_base_identities,
+                    base_quality_scores,
+                    read_span_masks,
                 ) = extract_base_identities(
                     bam, record, positions, max_reference_length, ref_seq, samtools_backend
                 )
-
-
-
-
-
+                mismatch_fwd = {
+                    read_name: mismatch_base_identities[read_name]
+                    for read_name in fwd_base_identities
+                }
+                mismatch_rev = {
+                    read_name: mismatch_base_identities[read_name]
+                    for read_name in rev_base_identities
+                }
+                quality_fwd = {
+                    read_name: base_quality_scores[read_name] for read_name in fwd_base_identities
+                }
+                quality_rev = {
+                    read_name: base_quality_scores[read_name] for read_name in rev_base_identities
+                }
+                read_span_fwd = {
+                    read_name: read_span_masks[read_name] for read_name in fwd_base_identities
+                }
+                read_span_rev = {
+                    read_name: read_span_masks[read_name] for read_name in rev_base_identities
+                }
+                fwd_sequence_files = _write_sequence_batches(
                     fwd_base_identities,
                     tmp_dir,
                     record,
                     f"{bami}_fwd",
+                    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+                    current_reference_length,
                     batch_size=100000,
-                    threads=threads,
                 )
-
+                rev_sequence_files = _write_sequence_batches(
                     rev_base_identities,
                     tmp_dir,
                     record,
                     f"{bami}_rev",
+                    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+                    current_reference_length,
+                    batch_size=100000,
+                )
+                sequence_batch_files[f"{bami}_{record}"] = fwd_sequence_files + rev_sequence_files
+                mismatch_fwd_files = _write_integer_batches(
+                    mismatch_fwd,
+                    tmp_dir,
+                    record,
+                    f"{bami}_mismatch_fwd",
                     batch_size=100000,
-                    threads=threads,
                 )
-
-
-
-
-
-
+                mismatch_rev_files = _write_integer_batches(
+                    mismatch_rev,
+                    tmp_dir,
+                    record,
+                    f"{bami}_mismatch_rev",
+                    batch_size=100000,
+                )
+                mismatch_batch_files[f"{bami}_{record}"] = mismatch_fwd_files + mismatch_rev_files
+                quality_fwd_files = _write_integer_batches(
+                    quality_fwd,
+                    tmp_dir,
+                    record,
+                    f"{bami}_quality_fwd",
+                    batch_size=100000,
+                )
+                quality_rev_files = _write_integer_batches(
+                    quality_rev,
+                    tmp_dir,
+                    record,
+                    f"{bami}_quality_rev",
+                    batch_size=100000,
+                )
+                quality_batch_files[f"{bami}_{record}"] = quality_fwd_files + quality_rev_files
+                read_span_fwd_files = _write_integer_batches(
+                    read_span_fwd,
+                    tmp_dir,
+                    record,
+                    f"{bami}_read_span_fwd",
+                    batch_size=100000,
+                )
+                read_span_rev_files = _write_integer_batches(
+                    read_span_rev,
+                    tmp_dir,
+                    record,
+                    f"{bami}_read_span_rev",
+                    batch_size=100000,
+                )
+                read_span_batch_files[f"{bami}_{record}"] = (
+                    read_span_fwd_files + read_span_rev_files
+                )
+                del (
+                    fwd_base_identities,
+                    rev_base_identities,
+                    mismatch_base_identities,
+                    base_quality_scores,
+                    read_span_masks,
+                )
+        ad.AnnData(
+            X=np.random.rand(1, 1),
+            uns={
+                "sequence_batch_files": sequence_batch_files,
+                "mismatch_batch_files": mismatch_batch_files,
+                "quality_batch_files": quality_batch_files,
+                "read_span_batch_files": read_span_batch_files,
+            },
+        ).write_h5ad(sequence_cache_path)
     ##########################################################################################

     ##########################################################################################
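The caching block above writes a single index H5AD whose .uns maps each batch-file category to per-(bam, record) lists, and later runs rebuild the cache only when the quality or read-span entries are missing. A quick inspection sketch using the path and keys shown in the diff (it assumes a previous run has already written the cache under the output directory's tmp/ folder):

import anndata as ad

cached_uns = ad.read_h5ad("tmp/tmp_sequence_int_file_dict.h5ad").uns
for key in ("sequence_batch_files", "mismatch_batch_files", "quality_batch_files", "read_span_batch_files"):
    entries = cached_uns.get(key, {})
    print(key, len(entries), "bam/record groups")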
@@ -713,47 +1412,9 @@ def modkit_extract_to_adata(
    ###################################################
    ### Add the tsvs as dataframes to a dictionary (dict_total) keyed by integer index. Also make modification specific dictionaries and strand specific dictionaries.
    # # Initialize dictionaries and place them in a list
-    (
-
-
-        dict_a_bottom,
-        dict_a_top,
-        dict_c,
-        dict_c_bottom,
-        dict_c_top,
-        dict_combined_bottom,
-        dict_combined_top,
-    ) = {}, {}, {}, {}, {}, {}, {}, {}, {}
-    dict_list = [
-        dict_total,
-        dict_a,
-        dict_a_bottom,
-        dict_a_top,
-        dict_c,
-        dict_c_bottom,
-        dict_c_top,
-        dict_combined_bottom,
-        dict_combined_top,
-    ]
-    # Give names to represent each dictionary in the list
-    sample_types = [
-        "total",
-        "m6A",
-        "m6A_bottom_strand",
-        "m6A_top_strand",
-        "5mC",
-        "5mC_bottom_strand",
-        "5mC_top_strand",
-        "combined_bottom_strand",
-        "combined_top_strand",
-    ]
-    # Give indices of dictionaries to skip for analysis and final dictionary saving.
-    dict_to_skip = [0, 1, 4]
-    combined_dicts = [7, 8]
-    A_stranded_dicts = [2, 3]
-    C_stranded_dicts = [5, 6]
-    dict_to_skip = dict_to_skip + combined_dicts + A_stranded_dicts + C_stranded_dicts
-    dict_to_skip = set(dict_to_skip)
+    batch_dicts = ModkitBatchDictionaries()
+    dict_list = batch_dicts.as_list()
+    sample_types = batch_dicts.sample_types

    # # Step 1):Load the dict_total dictionary with all of the batch tsv files as dataframes.
    dict_total = parallel_load_tsvs(
@@ -765,140 +1426,9 @@ def modkit_extract_to_adata(
  threads=threads,
  )
 
-
-
-
- # c5m_dict, c5m_minus_strand, c5m_plus_strand,
- # combined_minus_strand, combined_plus_strand) = merge_modification_results(processed_mod_results, mods)
-
- # # Create dict_list with the desired ordering:
- # # 0: dict_total, 1: m6A, 2: m6A_minus, 3: m6A_plus, 4: 5mC, 5: 5mC_minus, 6: 5mC_plus, 7: combined_minus, 8: combined_plus
- # dict_list = [dict_total, m6A_dict, m6A_minus_strand, m6A_plus_strand,
- # c5m_dict, c5m_minus_strand, c5m_plus_strand,
- # combined_minus_strand, combined_plus_strand]
-
- # # Initialize dict_to_skip (default skip all mod-specific indices)
- # dict_to_skip = set([0, 1, 4, 7, 8, 2, 3, 5, 6])
- # # Update dict_to_skip based on modifications present in mods
- # dict_to_skip = update_dict_to_skip(dict_to_skip, mods)
-
- # # Step 3: Process stranded methylation data in parallel
- # dict_list = parallel_extract_stranded_methylation(dict_list, dict_to_skip, max_reference_length, threads=threads or 4)
-
- # Iterate over dict_total of all the tsv files and extract the modification specific and strand specific dataframes into dictionaries
- for record in dict_total.keys():
- for sample_index in dict_total[record].keys():
- if "6mA" in mods:
- # Remove Adenine stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(A_stranded_dicts))
-
- if (
- record not in dict_a.keys()
- and record not in dict_a_bottom.keys()
- and record not in dict_a_top.keys()
- ):
- dict_a[record], dict_a_bottom[record], dict_a_top[record] = {}, {}, {}
-
- # get a dictionary of dataframes that only contain methylated adenine positions
- dict_a[record][sample_index] = dict_total[record][sample_index][
- dict_total[record][sample_index]["modified_primary_base"] == "A"
- ]
- logger.debug(
- "Successfully loaded a methyl-adenine dictionary for {}".format(
- str(sample_index)
- )
- )
-
- # Stratify the adenine dictionary into two strand specific dictionaries.
- dict_a_bottom[record][sample_index] = dict_a[record][sample_index][
- dict_a[record][sample_index]["ref_strand"] == "-"
- ]
- logger.debug(
- "Successfully loaded a minus strand methyl-adenine dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_a_top[record][sample_index] = dict_a[record][sample_index][
- dict_a[record][sample_index]["ref_strand"] == "+"
- ]
- logger.debug(
- "Successfully loaded a plus strand methyl-adenine dictionary for ".format(
- str(sample_index)
- )
- )
-
- # Reassign pointer for dict_a to None and delete the original value that it pointed to in order to decrease memory usage.
- dict_a[record][sample_index] = None
- gc.collect()
-
- if "5mC" in mods:
- # Remove Cytosine stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(C_stranded_dicts))
-
- if (
- record not in dict_c.keys()
- and record not in dict_c_bottom.keys()
- and record not in dict_c_top.keys()
- ):
- dict_c[record], dict_c_bottom[record], dict_c_top[record] = {}, {}, {}
-
- # get a dictionary of dataframes that only contain methylated cytosine positions
- dict_c[record][sample_index] = dict_total[record][sample_index][
- dict_total[record][sample_index]["modified_primary_base"] == "C"
- ]
- logger.debug(
- "Successfully loaded a methyl-cytosine dictionary for {}".format(
- str(sample_index)
- )
- )
- # Stratify the cytosine dictionary into two strand specific dictionaries.
- dict_c_bottom[record][sample_index] = dict_c[record][sample_index][
- dict_c[record][sample_index]["ref_strand"] == "-"
- ]
- logger.debug(
- "Successfully loaded a minus strand methyl-cytosine dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_c_top[record][sample_index] = dict_c[record][sample_index][
- dict_c[record][sample_index]["ref_strand"] == "+"
- ]
- logger.debug(
- "Successfully loaded a plus strand methyl-cytosine dictionary for {}".format(
- str(sample_index)
- )
- )
- # Reassign pointer for dict_c to None and delete the original value that it pointed to in order to decrease memory usage.
- dict_c[record][sample_index] = None
- gc.collect()
-
- if "6mA" in mods and "5mC" in mods:
- # Remove combined stranded dicts from the dicts to skip set
- dict_to_skip.difference_update(set(combined_dicts))
- # Initialize the sample keys for the combined dictionaries
-
- if (
- record not in dict_combined_bottom.keys()
- and record not in dict_combined_top.keys()
- ):
- dict_combined_bottom[record], dict_combined_top[record] = {}, {}
-
- logger.debug(
- "Successfully created a minus strand combined methylation dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_combined_bottom[record][sample_index] = []
- logger.debug(
- "Successfully created a plus strand combined methylation dictionary for {}".format(
- str(sample_index)
- )
- )
- dict_combined_top[record][sample_index] = []
-
- # Reassign pointer for dict_total to None and delete the original value that it pointed to in order to decrease memory usage.
- dict_total[record][sample_index] = None
- gc.collect()
+ batch_dicts, dict_to_skip = _build_modification_dicts(dict_total, mods)
+ dict_list = batch_dicts.as_list()
+ sample_types = batch_dicts.sample_types
 
  # Iterate over the stranded modification dictionaries and replace the dataframes with a dictionary of read names pointing to a list of values from the dataframe
  for dict_index, dict_type in enumerate(dict_list):
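The helper _build_modification_dicts is likewise defined outside this hunk. Judging from the deleted loop, it partitions each per-record, per-sample dataframe by modified_primary_base and then by ref_strand; a simplified sketch of that partitioning (the helper name and return shape here are illustrative, not the package's actual signature):

import pandas as pd

def split_by_mod_and_strand(df: pd.DataFrame, base: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    # Keep only calls whose modified primary base matches (e.g. "A" for 6mA, "C" for 5mC),
    # then separate minus- and plus-strand calls, mirroring the deleted inline logic.
    mod_df = df[df["modified_primary_base"] == base]
    bottom_strand = mod_df[mod_df["ref_strand"] == "-"]
    top_strand = mod_df[mod_df["ref_strand"] == "+"]
    return bottom_strand, top_strand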
@@ -984,14 +1514,14 @@ def modkit_extract_to_adata(
  mod_strand_record_sample_dict[sample] = {}
 
  # Get relevant columns as NumPy arrays
- read_ids = temp_df[
- positions = temp_df[
- call_codes = temp_df[
- probabilities = temp_df[
+ read_ids = temp_df[MODKIT_EXTRACT_TSV_COLUMN_READ_ID].values
+ positions = temp_df[MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION].values
+ call_codes = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE].values
+ probabilities = temp_df[MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB].values
 
  # Define valid call code categories
- modified_codes =
- canonical_codes =
+ modified_codes = MODKIT_EXTRACT_CALL_CODE_MODIFIED
+ canonical_codes = MODKIT_EXTRACT_CALL_CODE_CANONICAL
 
  # Vectorized methylation calculation with NaN for other codes
  methylation_prob = np.full_like(
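The context lines above point to a vectorized call-code computation (np.full_like with NaN for codes that are neither modified nor canonical), now driven by MODKIT_EXTRACT_CALL_CODE_MODIFIED and MODKIT_EXTRACT_CALL_CODE_CANONICAL. A toy sketch of that masking pattern, using made-up call codes and a simple modified=1 / canonical=0 convention that may differ from the package's exact scoring:

import numpy as np

call_codes = np.array(["m", "-", "?", "m"])  # hypothetical modkit call codes for one read
modified_codes = ("m", "a")                  # stand-in for MODKIT_EXTRACT_CALL_CODE_MODIFIED
canonical_codes = ("-",)                     # stand-in for MODKIT_EXTRACT_CALL_CODE_CANONICAL

# Positions with unrecognized codes stay NaN; modified and canonical calls get numeric values.
methylation = np.full_like(call_codes, np.nan, dtype=float)
methylation[np.isin(call_codes, modified_codes)] = 1.0
methylation[np.isin(call_codes, canonical_codes)] = 0.0
# methylation -> array([ 1.,  0., nan,  1.])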
@@ -1087,39 +1617,63 @@ def modkit_extract_to_adata(
  final_sample_index,
  )
  )
- temp_adata.obs[
+ temp_adata.obs[SAMPLE] = [
  sample_name_map[final_sample_index]
  ] * len(temp_adata)
- temp_adata.obs[
+ temp_adata.obs[BARCODE] = [barcode_map[final_sample_index]] * len(
  temp_adata
  )
- temp_adata.obs[
- temp_adata.obs[
- temp_adata.obs[
- temp_adata.obs[
+ temp_adata.obs[REFERENCE] = [f"{record}"] * len(temp_adata)
+ temp_adata.obs[STRAND] = [strand] * len(temp_adata)
+ temp_adata.obs[DATASET] = [dataset] * len(temp_adata)
+ temp_adata.obs[REFERENCE_DATASET_STRAND] = [
  f"{record}_{dataset}_{strand}"
  ] * len(temp_adata)
- temp_adata.obs[
+ temp_adata.obs[REFERENCE_STRAND] = [f"{record}_{strand}"] * len(
  temp_adata
  )
 
- # Load
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ # Load integer-encoded reads for the current sample/record
+ sequence_files = _normalize_sequence_batch_files(
+ sequence_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ mismatch_files = _normalize_sequence_batch_files(
+ mismatch_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ quality_files = _normalize_sequence_batch_files(
+ quality_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ read_span_files = _normalize_sequence_batch_files(
+ read_span_batch_files.get(f"{final_sample_index}_{record}", [])
+ )
+ if not sequence_files:
+ logger.warning(
+ "No encoded sequence batches found for sample %s record %s",
+ final_sample_index,
+ record,
+ )
+ continue
+ logger.info(f"Loading encoded sequences from {sequence_files}")
+ (
+ encoded_reads,
+ fwd_mapped_reads,
+ rev_mapped_reads,
+ ) = _load_sequence_batches(sequence_files)
+ mismatch_reads: dict[str, np.ndarray] = {}
+ if mismatch_files:
+ (
+ mismatch_reads,
+ _mismatch_fwd_reads,
+ _mismatch_rev_reads,
+ ) = _load_sequence_batches(mismatch_files)
+ quality_reads: dict[str, np.ndarray] = {}
+ if quality_files:
+ quality_reads = _load_integer_batches(quality_files)
+ read_span_reads: dict[str, np.ndarray] = {}
+ if read_span_files:
+ read_span_reads = _load_integer_batches(read_span_files)
+
+ read_names = list(encoded_reads.keys())
 
  read_mapping_direction = []
  for read_id in temp_adata.obs_names:
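The batches consumed by _load_sequence_batches and _load_integer_batches hold one int16 value per reference position for each read, padded out to the reference length. The real mapping lives in smftools.constants as MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT; the sketch below uses a toy map purely to illustrate the encoding shape:

import numpy as np

base_to_int = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4, "PAD": 5}  # illustrative values only

def encode_read(sequence: str, target_length: int) -> np.ndarray:
    # Start from an all-padding vector, then write the observed bases;
    # anything outside the A/C/G/T alphabet is recorded as N.
    encoded = np.full(target_length, base_to_int["PAD"], dtype=np.int16)
    for i, base in enumerate(sequence[:target_length]):
        encoded[i] = base_to_int.get(base.upper(), base_to_int["N"])
    return encoded

encode_read("ACGTN", 8)  # -> array([0, 1, 2, 3, 4, 5, 5, 5], dtype=int16)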
@@ -1130,57 +1684,69 @@ def modkit_extract_to_adata(
  else:
  read_mapping_direction.append("unk")
 
- temp_adata.obs[
+ temp_adata.obs[READ_MAPPING_DIRECTION] = read_mapping_direction
 
  del temp_df
 
-
-
-
+ padding_value = MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[
+ MODKIT_EXTRACT_SEQUENCE_PADDING_BASE
+ ]
+ sequence_length = encoded_reads[read_names[0]].shape[0]
+ encoded_matrix = np.full(
+ (len(sorted_index), sequence_length),
+ padding_value,
+ dtype=np.int16,
  )
- df_A = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_C = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_G = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_T = np.zeros((len(sorted_index), sequence_length), dtype=int)
- df_N = np.zeros((len(sorted_index), sequence_length), dtype=int)
-
- # Process one-hot data into dictionaries
- dict_A, dict_C, dict_G, dict_T, dict_N = {}, {}, {}, {}, {}
- for read_name, one_hot_array in one_hot_reads.items():
- one_hot_array = one_hot_array.reshape(n_rows_OHE, -1)
- dict_A[read_name] = one_hot_array[0, :]
- dict_C[read_name] = one_hot_array[1, :]
- dict_G[read_name] = one_hot_array[2, :]
- dict_T[read_name] = one_hot_array[3, :]
- dict_N[read_name] = one_hot_array[4, :]
-
- del one_hot_reads
- gc.collect()
 
- # Fill the arrays
  for j, read_name in tqdm(
  enumerate(sorted_index),
- desc="Loading
+ desc="Loading integer-encoded reads",
  total=len(sorted_index),
  ):
-
- df_C[j, :] = dict_C[read_name]
- df_G[j, :] = dict_G[read_name]
- df_T[j, :] = dict_T[read_name]
- df_N[j, :] = dict_N[read_name]
+ encoded_matrix[j, :] = encoded_reads[read_name]
 
- del
+ del encoded_reads
  gc.collect()
 
-
-
-
-
-
+ temp_adata.layers[SEQUENCE_INTEGER_ENCODING] = encoded_matrix
+ if mismatch_reads:
+ current_reference_length = reference_dict[record][0]
+ default_mismatch_sequence = np.full(
+ sequence_length,
+ MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["N"],
+ dtype=np.int16,
  )
-
-
+ if current_reference_length < sequence_length:
+ default_mismatch_sequence[current_reference_length:] = (
+ padding_value
+ )
+ mismatch_matrix = np.vstack(
+ [
+ mismatch_reads.get(read_name, default_mismatch_sequence)
+ for read_name in sorted_index
+ ]
+ )
+ temp_adata.layers[MISMATCH_INTEGER_ENCODING] = mismatch_matrix
+ if quality_reads:
+ default_quality_sequence = np.full(
+ sequence_length, -1, dtype=np.int16
+ )
+ quality_matrix = np.vstack(
+ [
+ quality_reads.get(read_name, default_quality_sequence)
+ for read_name in sorted_index
+ ]
  )
+ temp_adata.layers[BASE_QUALITY_SCORES] = quality_matrix
+ if read_span_reads:
+ default_read_span = np.zeros(sequence_length, dtype=np.int16)
+ read_span_matrix = np.vstack(
+ [
+ read_span_reads.get(read_name, default_read_span)
+ for read_name in sorted_index
+ ]
+ )
+ temp_adata.layers[READ_SPAN_MASK] = read_span_matrix
 
  # If final adata object already has a sample loaded, concatenate the current sample into the existing adata object
  if adata:
@@ -1273,8 +1839,14 @@ def modkit_extract_to_adata(
  for col in final_adata.obs.columns:
  final_adata.obs[col] = final_adata.obs[col].astype("category")
 
-
-
+ final_adata.uns[f"{SEQUENCE_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+ final_adata.uns[f"{MISMATCH_INTEGER_ENCODING}_map"] = dict(MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT)
+ final_adata.uns[f"{SEQUENCE_INTEGER_DECODING}_map"] = {
+ str(key): value for key, value in MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE.items()
+ }
+
+ consensus_bases = MODKIT_EXTRACT_SEQUENCE_BASES[:4] # ignore N/PAD for consensus
+ consensus_base_ints = [MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT[base] for base in consensus_bases]
  final_adata.uns["References"] = {}
  for record in records_to_analyze:
  # Add FASTA sequence to the object
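Since the layers now store integers instead of one-hot planes, the encode/decode maps are written to .uns so downstream code can translate a layer row back into sequence. A small usage sketch, with illustrative integer values standing in for the real MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE contents:

import numpy as np

int_to_base = {"0": "A", "1": "C", "2": "G", "3": "T", "4": "N", "5": "PAD"}  # illustrative map
encoded_row = np.array([0, 1, 2, 3, 4, 5, 5], dtype=np.int16)                 # one encoded read

# Translate back to a string, dropping padding positions.
decoded = "".join(
    int_to_base[str(int(value))]
    for value in encoded_row
    if int_to_base[str(int(value))] != "PAD"
)
# decoded == "ACGTN"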
@@ -1285,27 +1857,33 @@ def modkit_extract_to_adata(
  final_adata.uns[f"{record}_FASTA_sequence"] = sequence
  final_adata.uns["References"][f"{record}_FASTA_sequence"] = sequence
  # Add consensus sequence of samples mapped to the record to the object
- record_subset = final_adata[final_adata.obs[
- for strand in record_subset.obs[
- strand_subset = record_subset[record_subset.obs[
- for mapping_dir in strand_subset.obs[
+ record_subset = final_adata[final_adata.obs[REFERENCE] == record]
+ for strand in record_subset.obs[STRAND].cat.categories:
+ strand_subset = record_subset[record_subset.obs[STRAND] == strand]
+ for mapping_dir in strand_subset.obs[READ_MAPPING_DIRECTION].cat.categories:
  mapping_dir_subset = strand_subset[
- strand_subset.obs[
+ strand_subset.obs[READ_MAPPING_DIRECTION] == mapping_dir
+ ]
+ encoded_sequences = mapping_dir_subset.layers[SEQUENCE_INTEGER_ENCODING]
+ layer_counts = [
+ np.sum(encoded_sequences == base_int, axis=0)
+ for base_int in consensus_base_ints
  ]
- layer_map, layer_counts = {}, []
- for i, layer in enumerate(ohe_layers):
- layer_map[i] = layer.split("_")[0]
- layer_counts.append(np.sum(mapping_dir_subset.layers[layer], axis=0))
  count_array = np.array(layer_counts)
  nucleotide_indexes = np.argmax(count_array, axis=0)
- consensus_sequence_list = [
+ consensus_sequence_list = [consensus_bases[i] for i in nucleotide_indexes]
+ no_calls_mask = np.sum(count_array, axis=0) == 0
+ if np.any(no_calls_mask):
+ consensus_sequence_list = np.array(consensus_sequence_list, dtype=object)
+ consensus_sequence_list[no_calls_mask] = "N"
+ consensus_sequence_list = consensus_sequence_list.tolist()
  final_adata.var[
  f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
  ] = consensus_sequence_list
 
  if input_already_demuxed:
- final_adata.obs[
- final_adata.obs[
+ final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
+ final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
  else:
  from .h5ad_functions import add_demux_type_annotation
 
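The consensus block above counts each of the four encoded bases across all reads at every position, keeps the most frequent base, and falls back to "N" wherever no read contributes a call. The same logic on a tiny hand-made matrix, with integers 0-3 standing in for the package's actual base encoding:

import numpy as np

consensus_bases = ["A", "C", "G", "T"]
encoded = np.array([[0, 1, 3], [0, 1, 3], [0, 2, 3]], dtype=np.int16)  # 3 reads x 3 positions

# Per-base coverage counts at each position, then majority vote.
count_array = np.array([np.sum(encoded == i, axis=0) for i in range(len(consensus_bases))])
consensus = [consensus_bases[i] for i in np.argmax(count_array, axis=0)]
consensus = [
    "N" if total == 0 else base
    for base, total in zip(consensus, count_array.sum(axis=0))
]
# consensus == ["A", "C", "T"]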