smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/modkit_functions.py
CHANGED

@@ -1,10 +1,19 @@
-import os
 import subprocess
-import glob
-import zipfile
-from pathlib import Path
 
-def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+def extract_mods(
+    thresholds,
+    mod_tsv_dir,
+    split_dir,
+    bam_suffix,
+    skip_unclassified=True,
+    modkit_summary=False,
+    threads=None,
+):
     """
     Takes all of the aligned, sorted, split modified BAM files and runs Nanopore Modkit Extract to load the modification data into zipped TSV files
 
@@ -23,10 +32,12 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
 
     """
    filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
-    bam_files = sorted(p for p in split_dir.iterdir() if bam_suffix in p.name and '.bai' not in p.name)
+    bam_files = sorted(
+        p for p in split_dir.iterdir() if bam_suffix in p.name and ".bai" not in p.name
+    )
     if skip_unclassified:
         bam_files = [p for p in bam_files if "unclassified" not in p.name]
-    print(f'Running modkit extract for the following bam files: {bam_files}')
+    logger.info(f"Running modkit extract for the following bam files: {bam_files}")
 
     if threads:
         threads = str(threads)
@@ -34,14 +45,14 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
         pass
 
     for input_file in bam_files:
-        print(input_file)
+        logger.debug(input_file)
         # Construct the output TSV file path
         output_tsv = mod_tsv_dir / (input_file.stem + "_extract.tsv")
-        output_tsv_gz = output_tsv.parent / (output_tsv.name + '.gz')
+        output_tsv_gz = output_tsv.parent / (output_tsv.name + ".gz")
         if output_tsv_gz.exists():
-            print(f'{output_tsv_gz} already exists, skipping modkit extract')
+            logger.debug(f"{output_tsv_gz} already exists, skipping modkit extract")
         else:
-            print(f'Extracting modification data from {input_file}')
+            logger.info(f"Extracting modification data from {input_file}")
             if modkit_summary:
                 # Run modkit summary
                 subprocess.run(["modkit", "summary", str(input_file)])
@@ -50,28 +61,43 @@ def extract_mods(thresholds, mod_tsv_dir, split_dir, bam_suffix, skip_unclassified=True, modkit_summary=False, threads=None):
             # Run modkit extract
             if threads:
                 extract_command = [
-                    "modkit",
-                    "
-                    "
-                    "--
-                    "--
-
-                    "-
-
-
+                    "modkit",
+                    "extract",
+                    "calls",
+                    "--mapped-only",
+                    "--filter-threshold",
+                    f"{filter_threshold}",
+                    "--mod-thresholds",
+                    f"m:{m5C_threshold}",
+                    "--mod-thresholds",
+                    f"a:{m6A_threshold}",
+                    "--mod-thresholds",
+                    f"h:{hm5C_threshold}",
+                    "-t",
+                    threads,
+                    str(input_file),
+                    str(output_tsv),
+                ]
             else:
                 extract_command = [
-                    "modkit",
-                    "
-                    "
-                    "--
-                    "--
-
-
-
+                    "modkit",
+                    "extract",
+                    "calls",
+                    "--mapped-only",
+                    "--filter-threshold",
+                    f"{filter_threshold}",
+                    "--mod-thresholds",
+                    f"m:{m5C_threshold}",
+                    "--mod-thresholds",
+                    f"a:{m6A_threshold}",
+                    "--mod-thresholds",
+                    f"h:{hm5C_threshold}",
+                    str(input_file),
+                    str(output_tsv),
+                ]
             subprocess.run(extract_command)
             # Zip the output TSV
-            print(f'zipping {output_tsv}')
+            logger.info(f"zipping {output_tsv}")
             if threads:
                 zip_command = ["pigz", "-f", "-p", threads, str(output_tsv)]
             else:
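The threaded branch above ends up handing modkit an argv like the following (a minimal sketch with placeholder thresholds, thread count, and file names; only the tokens shown in the hunk come from the package). Note the pairing of mod codes to thresholds: `m:` carries the m5C threshold, `a:` the m6A threshold, and `h:` the hm5C threshold.

```python
# Sketch of the assembled `modkit extract calls` invocation (placeholder values).
filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = 0.8, 0.9, 0.9, 0.9
extract_command = [
    "modkit", "extract", "calls",
    "--mapped-only",
    "--filter-threshold", f"{filter_threshold}",
    "--mod-thresholds", f"m:{m5C_threshold}",   # m5C
    "--mod-thresholds", f"a:{m6A_threshold}",   # m6A
    "--mod-thresholds", f"h:{hm5C_threshold}",  # hm5C
    "-t", "8",                                  # only in the threaded branch
    "barcode01.bam",                            # input BAM (placeholder)
    "barcode01_extract.tsv",                    # output TSV (placeholder)
]
# extract_mods then runs this via subprocess.run(extract_command) and
# compresses the TSV with pigz (threaded) or gzip.
```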
@@ -79,30 +105,39 @@
             subprocess.run(zip_command, check=True)
     return
 
+
 def make_modbed(aligned_sorted_output, thresholds, mod_bed_dir):
     """
     Generating position methylation summaries for each barcoded sample starting from the overall BAM file that was direct output of dorado aligner.
     Parameters:
         aligned_sorted_output (str): A string representing the file path to the aligned_sorted non-split BAM file.
-
+
     Returns:
         None
     """
-    import os
     import subprocess
-
+
     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
     command = [
-        "modkit",
-        "
+        "modkit",
+        "pileup",
+        str(aligned_sorted_output),
+        str(mod_bed_dir),
+        "--partition-tag",
+        "BC",
         "--only-tabs",
-        "--filter-threshold",
-
-        "--mod-thresholds",
-
+        "--filter-threshold",
+        f"{filter_threshold}",
+        "--mod-thresholds",
+        f"m:{m5C_threshold}",
+        "--mod-thresholds",
+        f"a:{m6A_threshold}",
+        "--mod-thresholds",
+        f"h:{hm5C_threshold}",
     ]
     subprocess.run(command)
 
+
 def modQC(aligned_sorted_output, thresholds):
     """
     Output the percentile of bases falling at a call threshold (threshold is a probability between 0-1) for the overall BAM file.
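make_modbed therefore shells out to `modkit pileup` partitioned on the barcode BAM tag, which yields one bedMethyl output per barcode. A sketch of the resulting call (paths and thresholds are placeholders, not smftools defaults):

```python
command = [
    "modkit", "pileup",
    "aligned_sorted.bam",          # placeholder: BAM straight from the dorado aligner
    "mod_beds",                    # placeholder: output location for bedMethyl files
    "--partition-tag", "BC",       # split output by the BC (barcode) tag value
    "--only-tabs",
    "--filter-threshold", "0.8",   # placeholder thresholds
    "--mod-thresholds", "m:0.9",
    "--mod-thresholds", "a:0.9",
    "--mod-thresholds", "h:0.9",
]
```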
@@ -120,10 +155,16 @@ def modQC(aligned_sorted_output, thresholds):
     filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold = thresholds
     subprocess.run(["modkit", "sample-probs", str(aligned_sorted_output)])
     command = [
-        "modkit",
-        "
-
-        "--
-
+        "modkit",
+        "summary",
+        str(aligned_sorted_output),
+        "--filter-threshold",
+        f"{filter_threshold}",
+        "--mod-thresholds",
+        f"m:{m5C_threshold}",
+        "--mod-thresholds",
+        f"a:{m6A_threshold}",
+        "--mod-thresholds",
+        f"h:{hm5C_threshold}",
     ]
-    subprocess.run(command)
+    subprocess.run(command)
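All three helpers unpack the same four-element thresholds list in the order filter, m6A, m5C, hm5C. A hedged usage sketch (paths and values are placeholders; the import location assumes the module path shown above):

```python
from pathlib import Path

# Assumed import location; adjust to your install.
# from smftools.informatics.modkit_functions import extract_mods, make_modbed, modQC

thresholds = [0.8, 0.9, 0.9, 0.9]  # [filter, m6A, m5C, hm5C] -- order matters
modQC("aligned_sorted.bam", thresholds)                  # modkit sample-probs + summary
make_modbed("aligned_sorted.bam", thresholds, "mod_beds")  # per-barcode bedMethyl
extract_mods(
    thresholds,
    mod_tsv_dir=Path("mod_tsvs"),    # receives <bam stem>_extract.tsv.gz files
    split_dir=Path("split_bams"),    # directory of aligned, sorted, split BAMs
    bam_suffix=".bam",
    threads=8,
)
```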
smftools/informatics/ohe.py
CHANGED
@@ -1,10 +1,15 @@
-import
+import concurrent.futures
+import os
+
 import anndata as ad
+import numpy as np
 
-import
-
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
 
-def one_hot_encode(sequence, device='auto'):
+
+def one_hot_encode(sequence, device="auto"):
     """
     One-hot encodes a DNA sequence.
 
@@ -14,7 +19,7 @@ def one_hot_encode(sequence, device='auto'):
     Returns:
         ndarray: Flattened one-hot encoded representation of the input sequence.
     """
-    mapping = np.array(['A', 'C', 'G', 'T', 'N'])
+    mapping = np.array(["A", "C", "G", "T", "N"])
 
     # Ensure input is a list of characters
     if not isinstance(sequence, list):
@@ -22,14 +27,14 @@ def one_hot_encode(sequence, device='auto'):
 
     # Handle empty sequences
     if len(sequence) == 0:
-        print('Empty sequence encountered in one_hot_encode()')
+        logger.warning("Empty sequence encountered in one_hot_encode()")
         return np.zeros(len(mapping))  # Return empty encoding instead of failing
 
     # Convert sequence to NumPy array
-    seq_array = np.array(sequence, dtype='<U1')
+    seq_array = np.array(sequence, dtype="<U1")
 
     # Replace invalid bases with 'N'
-    seq_array = np.where(np.isin(seq_array, mapping), seq_array, 'N')
+    seq_array = np.where(np.isin(seq_array, mapping), seq_array, "N")
 
     # Create one-hot encoding matrix
     one_hot_matrix = (seq_array[:, None] == mapping).astype(int)
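To make the encoding concrete: each position becomes a 5-wide indicator over (A, C, G, T, N), rows are concatenated by flatten(), and any character outside the alphabet collapses to N. A worked example of the same operations:

```python
import numpy as np

mapping = np.array(["A", "C", "G", "T", "N"])
seq = np.array(list("ACGX"), dtype="<U1")
seq = np.where(np.isin(seq, mapping), seq, "N")         # "X" -> "N"
flat = (seq[:, None] == mapping).astype(int).flatten()  # 4 positions x 5 bases = 20 values
print(flat)
# Grouped by position for readability:
# A -> 1 0 0 0 0 | C -> 0 1 0 0 0 | G -> 0 0 1 0 0 | N (was X) -> 0 0 0 0 1
```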
@@ -37,6 +42,7 @@ def one_hot_encode(sequence, device='auto'):
     # Flatten and return
     return one_hot_matrix.flatten()
 
+
 def one_hot_decode(ohe_array):
     """
     Takes a flattened one hot encoded array and returns the sequence string from that array.
@@ -47,20 +53,21 @@ def one_hot_decode(ohe_array):
     sequence (str): Sequence string of the one hot encoded array
     """
     # Define the mapping of one-hot encoded indices to DNA bases
-    mapping = ['A', 'C', 'G', 'T', 'N']
-
+    mapping = ["A", "C", "G", "T", "N"]
+
     # Reshape the flattened array into a 2D matrix with 5 columns (one for each base)
     one_hot_matrix = ohe_array.reshape(-1, 5)
-
+
     # Get the index of the maximum value (which will be 1) in each row
     decoded_indices = np.argmax(one_hot_matrix, axis=1)
-
+
     # Map the indices back to the corresponding bases
     sequence_list = [mapping[i] for i in decoded_indices]
-    sequence = ''.join(sequence_list)
-
+    sequence = "".join(sequence_list)
+
     return sequence
 
+
 def ohe_layers_decode(adata, obs_names):
     """
     Takes an anndata object and a list of observation names. Returns a list of sequence strings for the reads of interest.
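Decoding reverses the layout: reshape to (-1, 5), argmax each row, and map indices back to bases, so encode and decode round-trip. A quick sketch (assuming both functions are imported from this module):

```python
# from smftools.informatics.ohe import one_hot_encode, one_hot_decode

flat = one_hot_encode(list("ACGTN"))    # 5 positions -> 25-element flat vector
assert one_hot_decode(flat) == "ACGTN"  # argmax per 5-wide row recovers each base
```

One caveat of argmax decoding: an all-zero row (for example, the empty-sequence fallback above) silently decodes to "A", since argmax over zeros returns index 0.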
@@ -72,7 +79,7 @@ def ohe_layers_decode(adata, obs_names):
     sequences (list of str): List of strings of the one hot encoded array
     """
     # Define the mapping of one-hot encoded indices to DNA bases
-    mapping = ['A', 'C', 'G', 'T', 'N']
+    mapping = ["A", "C", "G", "T", "N"]
 
     ohe_layers = [f"{base}_binary_encoding" for base in mapping]
     sequences = []
@@ -85,9 +92,10 @@ def ohe_layers_decode(adata, obs_names):
         ohe_array = np.array(ohe_list)
         sequence = one_hot_decode(ohe_array)
         sequences.append(sequence)
-
+
     return sequences
 
+
 def _encode_sequence(args):
     """Parallel helper function for one-hot encoding."""
     read_name, seq, device = args
@@ -97,18 +105,29 @@ def _encode_sequence(args):
     except Exception:
         return None  # Skip invalid sequences
 
+
 def _encode_and_save_batch(batch_data, tmp_dir, prefix, record, batch_number):
     """Encodes a batch and writes to disk immediately."""
     batch = {read_name: matrix for read_name, matrix in batch_data if matrix is not None}
 
     if batch:
-        save_name = os.path.join(tmp_dir, f'tmp_{prefix}_{record}_{batch_number}.h5ad')
+        save_name = os.path.join(tmp_dir, f"tmp_{prefix}_{record}_{batch_number}.h5ad")
         tmp_ad = ad.AnnData(X=np.zeros((1, 1)), uns=batch)  # Placeholder X
         tmp_ad.write_h5ad(save_name)
         return save_name
     return None
 
-def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
+
+def ohe_batching(
+    base_identities,
+    tmp_dir,
+    record,
+    prefix="",
+    batch_size=100000,
+    progress_bar=None,
+    device="auto",
+    threads=None,
+):
     """
     Efficient version of ohe_batching: one-hot encodes sequences in parallel and writes batches immediately.
 
@@ -131,7 +150,9 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
     file_names = []
 
     # Step 1: Prepare Data for Parallel Encoding
-    encoding_args = [(read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None]
+    encoding_args = [
+        (read_name, seq, device) for read_name, seq in base_identities.items() if seq is not None
+    ]
 
     # Step 2: Parallel One-Hot Encoding using threads (to avoid nested processes)
     with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
@@ -141,7 +162,9 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
 
         if len(batch_data) >= batch_size:
             # Step 3: Process and Write Batch Immediately
-            file_name = _encode_and_save_batch(batch_data.copy(), tmp_dir, prefix, record, batch_number)
+            file_name = _encode_and_save_batch(
+                batch_data.copy(), tmp_dir, prefix, record, batch_number
+            )
             if file_name:
                 file_names.append(file_name)
 
@@ -157,4 +180,4 @@ def ohe_batching(base_identities, tmp_dir, record, prefix='', batch_size=100000, progress_bar=None, device='auto', threads=None):
     if file_name:
         file_names.append(file_name)
 
-    return file_names
+    return file_names
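A hedged usage sketch of the reformatted ohe_batching signature (the inputs are invented placeholders; tmp_dir must already exist, since _encode_and_save_batch writes straight into it):

```python
# Hypothetical upstream product: read name -> list of base identities.
base_identities = {
    "read_0001": list("ACGTACGT"),
    "read_0002": list("ACGNNCGT"),
}

tmp_files = ohe_batching(
    base_identities,
    tmp_dir="tmp_ohe",
    record="reference_1",
    prefix="sample1",
    batch_size=100000,
    threads=4,
)
# Each returned path follows tmp_{prefix}_{record}_{batch_number}.h5ad; the
# one-hot matrices ride along in .uns keyed by read name, with a 1x1 placeholder X.
```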
smftools/informatics/pod5_functions.py
CHANGED

@@ -1,26 +1,29 @@
-from
-from ..readwrite import make_dirs
+from __future__ import annotations
 
 import os
 import subprocess
 from pathlib import Path
+from typing import Iterable
 
 import pod5 as p5
 
-from
+from smftools.logging_utils import get_logger
 
-
-def basecall_pod5s(config_path):
-    """
+from ..config import LoadExperimentConfig
+from ..informatics.basecalling import canoncall, modcall
+from ..readwrite import make_dirs
 
-    Parameters:
-        config_path (str): File path to the basecall configuration file
+logger = get_logger(__name__)
 
-    Returns:
-        None
+
+def basecall_pod5s(config_path: str | Path) -> None:
+    """Basecall POD5 inputs using a configuration file.
+
+    Args:
+        config_path: Path to the basecall configuration file.
     """
     # Default params
-    bam_suffix = '.bam'  # If different, change from here.
+    bam_suffix = ".bam"  # If different, change from here.
 
     # Load experiment config parameters into global variables
     experiment_config = LoadExperimentConfig(config_path)
@@ -30,66 +33,89 @@ def basecall_pod5s(config_path):
     default_value = None
 
     # General config variable init
-    input_data_path = Path(var_dict.get('input_data_path', default_value))  # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
-    output_directory = Path(var_dict.get('output_directory', default_value))  # Path to the output directory to make for the analysis. Necessary.
-    model = var_dict.get('model', default_value)  # needed for dorado basecaller
-    model_dir = Path(var_dict.get('model_dir', default_value))  # model directory
-    barcode_kit = var_dict.get('barcode_kit', default_value)  # needed for dorado basecaller
-    barcode_both_ends = var_dict.get('barcode_both_ends', default_value)  # dorado demultiplexing
-    trim = var_dict.get('trim', default_value)  # dorado adapter and barcode removal
-    device = var_dict.get('device', 'auto')
+    input_data_path = Path(
+        var_dict.get("input_data_path", default_value)
+    )  # Path to a directory of POD5s/FAST5s or to a BAM/FASTQ file. Necessary.
+    output_directory = Path(
+        var_dict.get("output_directory", default_value)
+    )  # Path to the output directory to make for the analysis. Necessary.
+    model = var_dict.get("model", default_value)  # needed for dorado basecaller
+    model_dir = Path(var_dict.get("model_dir", default_value))  # model directory
+    barcode_kit = var_dict.get("barcode_kit", default_value)  # needed for dorado basecaller
+    barcode_both_ends = var_dict.get("barcode_both_ends", default_value)  # dorado demultiplexing
+    trim = var_dict.get("trim", default_value)  # dorado adapter and barcode removal
+    device = var_dict.get("device", "auto")
 
     # Modified basecalling specific variable init
-    filter_threshold = var_dict.get('filter_threshold', default_value)
-    m6A_threshold = var_dict.get('m6A_threshold', default_value)
-    m5C_threshold = var_dict.get('m5C_threshold', default_value)
-    hm5C_threshold = var_dict.get('hm5C_threshold', default_value)
+    filter_threshold = var_dict.get("filter_threshold", default_value)
+    m6A_threshold = var_dict.get("m6A_threshold", default_value)
+    m5C_threshold = var_dict.get("m5C_threshold", default_value)
+    hm5C_threshold = var_dict.get("hm5C_threshold", default_value)
     thresholds = [filter_threshold, m6A_threshold, m5C_threshold, hm5C_threshold]
-    mod_list = var_dict.get('mod_list', default_value)
-
+    mod_list = var_dict.get("mod_list", default_value)
+
     # Make initial output directory
     make_dirs([output_directory])
 
     # Get the input filetype
     if input_data_path.is_file():
         input_data_filetype = input_data_path.suffixes[0]
-        input_is_pod5 = input_data_filetype in ['.pod5', '.p5']
-        input_is_fast5 = input_data_filetype in ['.fast5', '.f5']
+        input_is_pod5 = input_data_filetype in [".pod5", ".p5"]
+        input_is_fast5 = input_data_filetype in [".fast5", ".f5"]
 
     elif input_data_path.is_dir():
         # Get the file names in the input data dir
         input_files = input_data_path.iterdir()
-        input_is_pod5 = sum([True for file in input_files if '.pod5' in file or '.p5' in file])
-        input_is_fast5 = sum([True for file in input_files if '.fast5' in file or '.f5' in file])
+        input_is_pod5 = sum([True for file in input_files if ".pod5" in file or ".p5" in file])
+        input_is_fast5 = sum([True for file in input_files if ".fast5" in file or ".f5" in file])
 
     # If the input files are not pod5 files, and they are fast5 files, convert the files to a pod5 file before proceeding.
     if input_is_fast5 and not input_is_pod5:
         # take the input directory of fast5 files and write out a single pod5 file into the output directory.
-        output_pod5 = output_directory / 'FAST5s_to_POD5.pod5'
-        print(f'Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}')
+        output_pod5 = output_directory / "FAST5s_to_POD5.pod5"
+        logger.info(
+            f"Input directory contains fast5 files, converting them and concatenating into a single pod5 file in the {output_pod5}"
+        )
         fast5_to_pod5(input_data_path, output_pod5)
         # Reassign the pod5_dir variable to point to the new pod5 file.
         input_data_path = output_pod5
 
     model_basename = model.name
-    model_basename = model_basename.replace('.', '_')
+    model_basename = model_basename.replace(".", "_")
 
     if mod_list:
         mod_string = "_".join(mod_list)
         bam = output_directory / f"{model_basename}_{mod_string}_calls"
-        modcall(model, input_data_path, barcode_kit, mod_list, bam, bam_suffix, barcode_both_ends, trim, device)
+        modcall(
+            model,
+            input_data_path,
+            barcode_kit,
+            mod_list,
+            bam,
+            bam_suffix,
+            barcode_both_ends,
+            trim,
+            device,
+        )
     else:
         bam = output_directory / f"{model_basename}_canonical_basecalls"
-        canoncall(model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device)
+        canoncall(
+            model, input_data_path, barcode_kit, bam, bam_suffix, barcode_both_ends, trim, device
+        )
 
 
 def fast5_to_pod5(
-    fast5_dir:
-    output_pod5:
+    fast5_dir: str | Path | Iterable[str | Path],
+    output_pod5: str | Path = "FAST5s_to_POD5.pod5",
 ) -> None:
-    """
-
-
+    """Convert FAST5 inputs into a single POD5 file.
+
+    Args:
+        fast5_dir: FAST5 file path, directory, or iterable of file paths to convert.
+        output_pod5: Output POD5 file path.
+
+    Raises:
+        FileNotFoundError: If no FAST5 files are found or the input path is invalid.
     """
 
     output_pod5 = str(output_pod5)  # ensure string
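basecall_pod5s pulls everything from the loaded experiment config via var_dict.get(...). The keys it consumes in the hunk above are summarized below as a Python dict (a sketch only: values are placeholders, the model and kit names are hypothetical, and only the key names and the "auto" device default come from the diff):

```python
# Config keys read by basecall_pod5s (placeholder values, not smftools defaults).
var_dict = {
    "input_data_path": "/data/run1/pod5",      # POD5/FAST5 dir or BAM/FASTQ file; necessary
    "output_directory": "/data/run1/smf_out",  # analysis output directory; necessary
    "model": "sup@v5.0.0",                     # dorado basecalling model (hypothetical name)
    "model_dir": "/data/models",               # model directory
    "barcode_kit": "SQK-NBD114-24",            # dorado demultiplexing kit (hypothetical)
    "barcode_both_ends": False,                # dorado demultiplexing option
    "trim": True,                              # dorado adapter/barcode removal
    "device": "auto",                          # default when the key is absent
    # Modified-basecalling keys; these feed the thresholds list in this order:
    "filter_threshold": 0.8,
    "m6A_threshold": 0.9,
    "m5C_threshold": 0.9,
    "hm5C_threshold": 0.9,
    "mod_list": ["6mA", "5mC_5hmC"],           # empty/None -> canonical basecalling path
}
```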
@@ -122,45 +148,51 @@ def fast5_to_pod5(
 
     raise FileNotFoundError(f"Input path invalid: {fast5_dir}")
 
-def subsample_pod5(pod5_path, read_name_path, output_directory):
-    """
-    Takes a POD5 file and a text file containing read names of interest and writes out a subsampled POD5 for just those reads.
-    This is a useful function when you have a list of read names that mapped to a region of interest that you want to reanalyze from the pod5 level.
 
-    Parameters:
-
-
-
+def subsample_pod5(
+    pod5_path: str | Path,
+    read_name_path: str | int,
+    output_directory: str | Path,
+) -> None:
+    """Write a subsampled POD5 containing selected reads.
 
-    Returns:
-        None
+    Args:
+        pod5_path: POD5 file path or directory of POD5 files to subsample.
+        read_name_path: Path to a text file of read names (one per line) or an integer
+            specifying a random subset size.
+        output_directory: Directory to write the subsampled POD5 file.
     """
 
     if os.path.isdir(pod5_path):
         pod5_path_is_dir = True
-        input_pod5_base = 'input_pod5s.pod5'
+        input_pod5_base = "input_pod5s.pod5"
         files = os.listdir(pod5_path)
-        pod5_files = [os.path.join(pod5_path, file) for file in files if '.pod5' in file]
+        pod5_files = [os.path.join(pod5_path, file) for file in files if ".pod5" in file]
         pod5_files.sort()
-        print(f'Found input pod5s: {pod5_files}')
-
+        logger.info(f"Found input pod5s: {pod5_files}")
+
     elif os.path.exists(pod5_path):
         pod5_path_is_dir = False
         input_pod5_base = os.path.basename(pod5_path)
 
     else:
-        print('pod5_path passed does not exist')
+        logger.error("pod5_path passed does not exist")
         return None
 
-    if type(read_name_path) == str:
+    if type(read_name_path) is str:
         input_read_name_base = os.path.basename(read_name_path)
-        output_base = input_pod5_base.split('.pod5')[0] + '_' + input_read_name_base.split('.txt')[0] + '_subsampled.pod5'
+        output_base = (
+            input_pod5_base.split(".pod5")[0]
+            + "_"
+            + input_read_name_base.split(".txt")[0]
+            + "_subsampled.pod5"
+        )
 
         # extract read names into a list of strings
-        with open(read_name_path, 'r') as file:
+        with open(read_name_path, "r") as file:
             read_names = [line.strip() for line in file]
 
-        print(f'Looking for read_ids: {read_names}')
+        logger.info(f"Looking for read_ids: {read_names}")
         read_records = []
 
         if pod5_path_is_dir:
@@ -168,22 +200,25 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
                 with p5.Reader(input_pod5) as reader:
                     try:
                         for read_record in reader.reads(selection=read_names, missing_ok=True):
-                            read_records.append(read_record.to_read())
-                            print(f'Found read in {input_pod5}: {read_record.read_id}')
-                    except:
-                        print('Skipping pod5, could not find reads')
-        else:
+                            read_records.append(read_record.to_read())
+                            logger.info(f"Found read in {input_pod5}: {read_record.read_id}")
+                    except Exception:
+                        logger.warning("Skipping pod5, could not find reads")
+        else:
             with p5.Reader(pod5_path) as reader:
                 try:
                     for read_record in reader.reads(selection=read_names):
                         read_records.append(read_record.to_read())
-                        print(f'Found read in {input_pod5}: {read_record}')
-                except:
-                    print('Could not find reads')
+                        logger.info(f"Found read in {input_pod5}: {read_record}")
+                except Exception:
+                    logger.warning("Could not find reads")
 
-    elif type(read_name_path) == int:
+    elif type(read_name_path) is int:
         import random
-        output_base = input_pod5_base.split('.pod5')[0] + f'_{read_name_path}_randomly_subsampled.pod5'
+
+        output_base = (
+            input_pod5_base.split(".pod5")[0] + f"_{read_name_path}_randomly_subsampled.pod5"
+        )
         all_read_records = []
 
         if pod5_path_is_dir:
@@ -191,7 +226,7 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
             random.shuffle(pod5_files)
             for input_pod5 in pod5_files:
                 # iterate over the input pod5s
-                print(f'Opening pod5 file {input_pod5}')
+                logger.info(f"Opening pod5 file {input_pod5}")
                 with p5.Reader(pod5_path) as reader:
                     for read_record in reader.reads():
                         all_read_records.append(read_record.to_read())
@@ -202,9 +237,11 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
             if read_name_path <= len(all_read_records):
                 read_records = random.sample(all_read_records, read_name_path)
             else:
-                print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                logger.info(
+                    "Trying to sample more reads than are contained in the input pod5s, taking all reads"
+                )
                 read_records = all_read_records
-
+
         else:
             with p5.Reader(pod5_path) as reader:
                 for read_record in reader.reads():
@@ -214,11 +251,13 @@ def subsample_pod5(pod5_path, read_name_path, output_directory):
                 # if the subsampling amount is less than the record amount in the file, randomly subsample the reads
                 read_records = random.sample(all_read_records, read_name_path)
             else:
-                print('Trying to sample more reads than are contained in the input pod5s, taking all reads')
+                logger.info(
+                    "Trying to sample more reads than are contained in the input pod5s, taking all reads"
+                )
                 read_records = all_read_records
 
     output_pod5 = os.path.join(output_directory, output_base)
 
     # Write the subsampled POD5
     with p5.Writer(output_pod5) as writer:
-        writer.add_reads(read_records)
+        writer.add_reads(read_records)