smftools 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +39 -7
- smftools/_settings.py +2 -0
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +2 -0
- smftools/cli/hmm_adata.py +7 -2
- smftools/cli/load_adata.py +130 -98
- smftools/cli/preprocess_adata.py +2 -0
- smftools/cli/spatial_adata.py +5 -1
- smftools/cli_entry.py +26 -1
- smftools/config/__init__.py +2 -0
- smftools/config/default.yaml +4 -1
- smftools/config/experiment_config.py +6 -0
- smftools/datasets/__init__.py +2 -0
- smftools/hmm/HMM.py +9 -3
- smftools/hmm/__init__.py +24 -13
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +2 -0
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +5 -2
- smftools/hmm/display_hmm.py +4 -1
- smftools/hmm/hmm_readwrite.py +7 -2
- smftools/hmm/nucleosome_hmm_refinement.py +2 -0
- smftools/informatics/__init__.py +53 -34
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +2 -0
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +737 -170
- smftools/informatics/basecalling.py +2 -0
- smftools/informatics/bed_functions.py +271 -61
- smftools/informatics/binarize_converted_base_identities.py +3 -0
- smftools/informatics/complement_base_list.py +2 -0
- smftools/informatics/converted_BAM_to_adata.py +66 -22
- smftools/informatics/fasta_functions.py +94 -10
- smftools/informatics/h5ad_functions.py +8 -2
- smftools/informatics/modkit_extract_to_adata.py +16 -6
- smftools/informatics/modkit_functions.py +2 -0
- smftools/informatics/ohe.py +2 -0
- smftools/informatics/pod5_functions.py +3 -2
- smftools/machine_learning/__init__.py +22 -6
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +18 -4
- smftools/machine_learning/data/preprocessing.py +2 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +2 -0
- smftools/machine_learning/evaluation/evaluators.py +14 -9
- smftools/machine_learning/inference/__init__.py +2 -0
- smftools/machine_learning/inference/inference_utils.py +2 -0
- smftools/machine_learning/inference/lightning_inference.py +6 -1
- smftools/machine_learning/inference/sklearn_inference.py +2 -0
- smftools/machine_learning/inference/sliding_window_inference.py +2 -0
- smftools/machine_learning/models/__init__.py +2 -0
- smftools/machine_learning/models/base.py +7 -2
- smftools/machine_learning/models/cnn.py +7 -2
- smftools/machine_learning/models/lightning_base.py +16 -11
- smftools/machine_learning/models/mlp.py +5 -1
- smftools/machine_learning/models/positional.py +7 -2
- smftools/machine_learning/models/rnn.py +5 -1
- smftools/machine_learning/models/sklearn_models.py +14 -9
- smftools/machine_learning/models/transformer.py +7 -2
- smftools/machine_learning/models/wrappers.py +6 -2
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +13 -3
- smftools/machine_learning/training/train_sklearn_model.py +2 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +5 -1
- smftools/machine_learning/utils/grl.py +5 -1
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -31
- smftools/plotting/autocorrelation_plotting.py +9 -5
- smftools/plotting/classifiers.py +16 -4
- smftools/plotting/general_plotting.py +6 -3
- smftools/plotting/hmm_plotting.py +12 -2
- smftools/plotting/position_stats.py +15 -7
- smftools/plotting/qc_plotting.py +6 -1
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/archived/calculate_complexity.py +2 -0
- smftools/preprocessing/archived/mark_duplicates.py +2 -0
- smftools/preprocessing/archived/preprocessing.py +2 -0
- smftools/preprocessing/archived/remove_duplicates.py +2 -0
- smftools/preprocessing/binary_layers_to_ohe.py +2 -1
- smftools/preprocessing/calculate_complexity_II.py +4 -1
- smftools/preprocessing/calculate_pairwise_differences.py +2 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
- smftools/preprocessing/calculate_position_Youden.py +9 -2
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
- smftools/preprocessing/flag_duplicate_reads.py +42 -54
- smftools/preprocessing/make_dirs.py +2 -1
- smftools/preprocessing/min_non_diagonal.py +2 -0
- smftools/preprocessing/recipes.py +2 -0
- smftools/tools/__init__.py +26 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +2 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +2 -0
- smftools/tools/archived/subset_adata_v2.py +2 -0
- smftools/tools/calculate_umap.py +3 -1
- smftools/tools/cluster_adata_on_methylation.py +7 -1
- smftools/tools/position_stats.py +17 -27
- {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/METADATA +67 -33
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.5.dist-info/RECORD +0 -181
- {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/fasta_functions.py CHANGED

@@ -1,23 +1,93 @@
 from __future__ import annotations
 
 import gzip
+import shutil
+import subprocess
 from concurrent.futures import ProcessPoolExecutor
+from importlib.util import find_spec
 from pathlib import Path
-from typing import Dict, Iterable, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Tuple
 
 import numpy as np
-import pysam
 from Bio import SeqIO
 from Bio.Seq import Seq
 from Bio.SeqRecord import SeqRecord
-from pyfaidx import Fasta
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 from ..readwrite import time_string
 
 logger = get_logger(__name__)
 
+if TYPE_CHECKING:
+    import pysam as pysam_module
+
+
+def _require_pysam() -> "pysam_module":
+    if pysam_types is not None:
+        return pysam_types
+    return require("pysam", extra="pysam", purpose="FASTA access")
+
+
+pysam_types = None
+if find_spec("pysam") is not None:
+    pysam_types = require("pysam", extra="pysam", purpose="FASTA access")
+
+
+def _resolve_fasta_backend() -> str:
+    """Resolve the backend to use for FASTA access."""
+    if pysam_types is not None:
+        return "python"
+    if shutil is not None and shutil.which("samtools"):
+        return "cli"
+    raise RuntimeError("FASTA access requires pysam or samtools in PATH.")
+
+
+def _ensure_fasta_index(fasta: Path) -> None:
+    fai = fasta.with_suffix(fasta.suffix + ".fai")
+    if fai.exists():
+        return
+    if subprocess is None or shutil is None or not shutil.which("samtools"):
+        pysam_mod = _require_pysam()
+        pysam_mod.faidx(str(fasta))
+        return
+    cp = subprocess.run(
+        ["samtools", "faidx", str(fasta)],
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    if cp.returncode != 0:
+        raise RuntimeError(f"samtools faidx failed (exit {cp.returncode}):\n{cp.stderr}")
+
+
+def _bed_to_faidx_region(chrom: str, start: int, end: int) -> str:
+    """Convert 0-based half-open BED coords to samtools faidx region."""
+    start1 = start + 1
+    end1 = end
+    if start1 > end1:
+        start1, end1 = end1, start1
+    return f"{chrom}:{start1}-{end1}"
+
+
+def _fetch_sequence_with_samtools(fasta: Path, chrom: str, start: int, end: int) -> str:
+    if subprocess is None or shutil is None:
+        raise RuntimeError("samtools backend is unavailable.")
+    if not shutil.which("samtools"):
+        raise RuntimeError("samtools is required but not available in PATH.")
+    region = _bed_to_faidx_region(chrom, start, end)
+    cp = subprocess.run(
+        ["samtools", "faidx", str(fasta), region],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    if cp.returncode != 0:
+        raise RuntimeError(f"samtools faidx failed (exit {cp.returncode}):\n{cp.stderr}")
+    lines = [line.strip() for line in cp.stdout.splitlines() if line and not line.startswith(">")]
+    return "".join(lines)
+
 
 def _convert_FASTA_record(
     record: SeqRecord,
@@ -160,7 +230,7 @@ def index_fasta(fasta: str | Path, write_chrom_sizes: bool = True) -> Path:
         Path: Path to the index file or chromosome sizes file.
     """
     fasta = Path(fasta)
-
+    _require_pysam().faidx(str(fasta))  # creates <fasta>.fai
 
     fai = fasta.with_suffix(fasta.suffix + ".fai")
     if write_chrom_sizes:
@@ -307,8 +377,13 @@ def subsample_fasta_from_bed(
     # Ensure output directory exists
     output_directory.mkdir(parents=True, exist_ok=True)
 
-
-
+    backend = _resolve_fasta_backend()
+    _ensure_fasta_index(input_FASTA)
+
+    fasta_handle = None
+    if backend == "python":
+        pysam_mod = _require_pysam()
+        fasta_handle = pysam_mod.FastaFile(str(input_FASTA))
 
     # Open BED + output FASTA
     with input_bed.open("r") as bed, output_FASTA.open("w") as out_fasta:
@@ -319,15 +394,24 @@ def subsample_fasta_from_bed(
             end = int(fields[2])  # BED is 0-based and end is exclusive
             desc = " ".join(fields[3:]) if len(fields) > 3 else ""
 
-            if
+            if backend == "python":
+                assert fasta_handle is not None
+                if chrom not in fasta_handle.references:
+                    logger.warning(f"{chrom} not found in FASTA")
+                    continue
+                sequence = fasta_handle.fetch(chrom, start, end)
+            else:
+                sequence = _fetch_sequence_with_samtools(input_FASTA, chrom, start, end)
+
+            if not sequence:
                 logger.warning(f"{chrom} not found in FASTA")
                 continue
 
-            # pyfaidx is 1-based indexing internally, but [start:end] works with BED coords
-            sequence = fasta[chrom][start:end].seq
-
             header = f">{chrom}:{start}-{end}"
             if desc:
                 header += f" {desc}"
 
             out_fasta.write(f"{header}\n{sequence}\n")
+
+    if fasta_handle is not None:
+        fasta_handle.close()
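The new `_bed_to_faidx_region` helper bridges two coordinate conventions: BED intervals are 0-based and half-open, while `samtools faidx` regions are 1-based and closed, so the start gains one and the end is unchanged. A minimal standalone sketch of the same arithmetic (the function name here is illustrative, not part of smftools):

    def bed_to_region(chrom: str, start: int, end: int) -> str:
        """Map a 0-based half-open BED interval to a 1-based closed faidx region."""
        return f"{chrom}:{start + 1}-{end}"

    # The first ten bases of chr1 are BED (chr1, 0, 10) -> faidx "chr1:1-10".
    assert bed_to_region("chr1", 0, 10) == "chr1:1-10"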
smftools/informatics/h5ad_functions.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import glob
 import os
 from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -7,9 +9,9 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp
-from pod5 import Reader
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 logger = get_logger(__name__)
 
@@ -90,6 +92,7 @@ def add_read_length_and_mapping_qc(
     extract_read_features_from_bam_callable=None,
     bypass: bool = False,
     force_redo: bool = True,
+    samtools_backend: str | None = "auto",
 ):
     """
     Populate adata.obs with read/mapping QC columns.
@@ -133,7 +136,7 @@ def add_read_length_and_mapping_qc(
             "No `read_metrics` provided and `extract_read_features_from_bam` not found."
         )
     for bam in bam_files:
-        bam_read_metrics = extractor(bam)
+        bam_read_metrics = extractor(bam, samtools_backend)
         if not isinstance(bam_read_metrics, dict):
             raise ValueError(f"extract_read_features_from_bam returned non-dict for {bam}")
         read_metrics.update(bam_read_metrics)
@@ -228,6 +231,9 @@ def _collect_read_origins_from_pod5(pod5_path: str, target_ids: set[str]) -> dict
     Worker function: scan one POD5 file and return a mapping
     {read_id: pod5_basename} only for read_ids in `target_ids`.
     """
+    p5 = require("pod5", extra="ont", purpose="POD5 metadata")
+    Reader = p5.Reader
+
     basename = os.path.basename(pod5_path)
     mapping: dict[str, str] = {}
 
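Both files above route soft dependencies through `smftools.optional_imports.require` (new in this release, +31 lines), and `_collect_read_origins_from_pod5` defers the call into the worker body so that merely importing the module never demands pod5. The real implementation is not shown in this diff; a plausible minimal sketch of such a helper, matching the call sites seen here, would be:

    from importlib import import_module

    def require(module: str, *, extra: str, purpose: str):
        """Import `module`, or raise an actionable error naming the pip extra."""
        try:
            return import_module(module)
        except ImportError as exc:
            raise ImportError(
                f"{module} is required for {purpose}; "
                f"install it with `pip install smftools[{extra}]`."
            ) from exc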
smftools/informatics/modkit_extract_to_adata.py CHANGED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent.futures
 import gc
 import re
@@ -16,9 +18,11 @@ from .bam_functions import count_aligned_reads
 logger = get_logger(__name__)
 
 
-def filter_bam_records(bam, mapping_threshold):
+def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
     """Processes a single BAM file, counts reads, and determines records to analyze."""
-    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
+    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
+        bam, samtools_backend
+    )
 
     total_reads = aligned_reads_count + unaligned_reads_count
     percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
@@ -35,13 +39,16 @@ def filter_bam_records(bam, mapping_threshold):
     return set(records)
 
 
-def parallel_filter_bams(bam_path_list, mapping_threshold):
+def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
     """Parallel processing for multiple BAM files."""
     records_to_analyze = set()
 
     with concurrent.futures.ProcessPoolExecutor() as executor:
         results = executor.map(
-            filter_bam_records,
+            filter_bam_records,
+            bam_path_list,
+            [mapping_threshold] * len(bam_path_list),
+            [samtools_backend] * len(bam_path_list),
         )
 
     # Aggregate results
@@ -484,6 +491,7 @@ def modkit_extract_to_adata(
     delete_batch_hdfs=False,
     threads=None,
     double_barcoded_path=None,
+    samtools_backend: str | None = "auto",
 ):
     """
     Takes modkit extract outputs and organizes it into an adata object
@@ -591,7 +599,7 @@ def modkit_extract_to_adata(
 
     ######### Get Record names that have over a passed threshold of mapped reads #############
     # get all records that are above a certain mapping threshold in at least one sample bam
-    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)
+    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend)
 
     ##########################################################################################
 
@@ -635,7 +643,9 @@ def modkit_extract_to_adata(
         rev_base_identities,
         mismatch_counts_per_read,
         mismatch_trend_per_read,
-    ) = extract_base_identities(
+    ) = extract_base_identities(
+        bam, record, positions, max_reference_length, ref_seq, samtools_backend
+    )
     # Store read names of fwd and rev mapped reads
     fwd_mapped_reads.update(fwd_base_identities.keys())
     rev_mapped_reads.update(rev_base_identities.keys())
smftools/informatics/ohe.py CHANGED

smftools/informatics/pod5_functions.py CHANGED

@@ -5,9 +5,8 @@ import subprocess
 from pathlib import Path
 from typing import Iterable
 
-import pod5 as p5
-
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 from ..config import LoadExperimentConfig
 from ..informatics.basecalling import canoncall, modcall
@@ -15,6 +14,8 @@ from ..readwrite import make_dirs
 
 logger = get_logger(__name__)
 
+p5 = require("pod5", extra="ont", purpose="POD5 IO")
+
 
 def basecall_pod5s(config_path: str | Path) -> None:
     """Basecall POD5 inputs using a configuration file.
smftools/machine_learning/__init__.py CHANGED

@@ -1,7 +1,23 @@
-from
+from __future__ import annotations
 
-
-
-
-"
-
+from importlib import import_module
+
+_LAZY_MODULES = {
+    "data": "smftools.machine_learning.data",
+    "evaluation": "smftools.machine_learning.evaluation",
+    "inference": "smftools.machine_learning.inference",
+    "models": "smftools.machine_learning.models",
+    "training": "smftools.machine_learning.training",
+    "utils": "smftools.machine_learning.utils",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_MODULES:
+        module = import_module(_LAZY_MODULES[name])
+        globals()[name] = module
+        return module
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_MODULES.keys())
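The rewritten package `__init__` relies on module-level `__getattr__` (PEP 562): submodules are imported on first attribute access and cached in `globals()`, so later lookups bypass `__getattr__` entirely. Assuming a fresh interpreter with the ML extras installed (the submodules themselves require torch), usage is unchanged for callers:

    import sys

    import smftools.machine_learning as ml

    print("smftools.machine_learning.models" in sys.modules)  # False: not imported yet
    models = ml.models  # first access fires __getattr__ and import_module(...)
    print(ml.models is models)  # True: now cached in the module's globals()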
smftools/machine_learning/data/anndata_data_module.py CHANGED

@@ -1,12 +1,26 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-
-import
-from sklearn.utils.class_weight import compute_class_weight
-from torch.utils.data import DataLoader, Dataset, Subset
+
+from smftools.optional_imports import require
 
 from .preprocessing import random_fill_nans
 
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning data modules")
+torch = require("torch", extra="ml-base", purpose="ML data loading")
+sklearn_class_weight = require(
+    "sklearn.utils.class_weight",
+    extra="ml-base",
+    purpose="class weighting",
+)
+torch_utils_data = require("torch.utils.data", extra="ml-base", purpose="ML data loading")
+
+compute_class_weight = sklearn_class_weight.compute_class_weight
+DataLoader = torch_utils_data.DataLoader
+Dataset = torch_utils_data.Dataset
+Subset = torch_utils_data.Subset
+
 
 class AnnDataDataset(Dataset):
     """
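`compute_class_weight` is aliased above for class-imbalance handling in the data module. For reference, the standard sklearn call computes `n_samples / (n_classes * count_per_class)` when asked for balanced weights (the toy labels below are illustrative):

    import numpy as np
    from sklearn.utils.class_weight import compute_class_weight

    y = np.array([0, 0, 0, 1])  # 3:1 imbalance
    weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
    print(weights)  # ~[0.667 2.0] = 4 / (2 * [3, 1])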
smftools/machine_learning/evaluation/evaluators.py CHANGED

@@ -1,14 +1,19 @@
-
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-
-
-
-
-
-
-
-
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class ModelEvaluator:
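The metric aliases above behave exactly like direct sklearn imports. A quick sanity check of the ROC helpers on toy scores (standard sklearn API; the data is illustrative):

    import numpy as np
    from sklearn.metrics import auc, roc_auc_score, roc_curve

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    fpr, tpr, _ = roc_curve(y_true, y_score)
    assert np.isclose(auc(fpr, tpr), roc_auc_score(y_true, y_score))  # both 0.75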
smftools/machine_learning/inference/lightning_inference.py CHANGED

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-
+
+from smftools.optional_imports import require
 
 from .inference_utils import annotate_split_column
 
+torch = require("torch", extra="ml-base", purpose="Lightning inference")
+
 
 def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
     """
smftools/machine_learning/models/base.py CHANGED

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
-
-
+
+from smftools.optional_imports import require
 
 from ..utils.device import detect_device
 
+torch = require("torch", extra="ml-base", purpose="ML base models")
+nn = torch.nn
+
 
 class BaseTorchModel(nn.Module):
     """
smftools/machine_learning/models/cnn.py CHANGED

@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
-
-
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+torch = require("torch", extra="ml-base", purpose="CNN models")
+nn = torch.nn
+
 
 class CNNClassifier(BaseTorchModel):
     def __init__(
smftools/machine_learning/models/lightning_base.py CHANGED

@@ -1,15 +1,20 @@
-
+from __future__ import annotations
+
 import numpy as np
-
-import
-
-
-
-
-
-
-
-
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning models")
+torch = require("torch", extra="ml-base", purpose="Lightning models")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class TorchClassifierWrapper(pl.LightningModule):
smftools/machine_learning/models/mlp.py CHANGED

@@ -1,7 +1,11 @@
-
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="MLP models")
+
 
 class MLPClassifier(BaseTorchModel):
     def __init__(
smftools/machine_learning/models/positional.py CHANGED

@@ -1,6 +1,11 @@
+from __future__ import annotations
+
 import numpy as np
-
-
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="positional encoding")
+nn = torch.nn
 
 
 class PositionalEncoding(nn.Module):
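`PositionalEncoding` itself is unchanged apart from the import shuffle, and its body is not shown in this diff. For orientation only, a generic sinusoidal module of the kind such a class typically implements (assumes an even `d_model`; this is not the smftools code):

    import math

    import torch
    import torch.nn as nn

    class SinusoidalEncoding(nn.Module):
        """Add fixed sin/cos position signals to a (batch, seq, d_model) tensor."""

        def __init__(self, d_model: int, max_len: int = 5000):
            super().__init__()
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
            )
            pe = torch.zeros(max_len, d_model)
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Trim the cached table to the input's sequence length.
            return x + self.pe[: x.size(1)]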
smftools/machine_learning/models/rnn.py CHANGED

@@ -1,7 +1,11 @@
-
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="RNN models")
+
 
 class RNNClassifier(BaseTorchModel):
    def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
smftools/machine_learning/models/sklearn_models.py CHANGED

@@ -1,13 +1,18 @@
-
+from __future__ import annotations
+
 import numpy as np
-
-
-
-
-
-
-
-
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class SklearnModelWrapper:
smftools/machine_learning/models/transformer.py CHANGED

@@ -1,11 +1,16 @@
+from __future__ import annotations
+
 import numpy as np
-
-
+
+from smftools.optional_imports import require
 
 from ..utils.grl import grad_reverse
 from .base import BaseTorchModel
 from .positional import PositionalEncoding
 
+torch = require("torch", extra="ml-base", purpose="Transformer models")
+nn = torch.nn
+
 
 class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
     def __init__(self, *args, **kwargs):
smftools/machine_learning/training/train_lightning_model.py CHANGED

@@ -1,10 +1,20 @@
-import
-
-from
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from ..data import AnnDataModule
 from ..models import TorchClassifierWrapper
 
+torch = require("torch", extra="ml-base", purpose="Lightning training")
+pytorch_lightning = require("pytorch_lightning", extra="ml-extended", purpose="Lightning training")
+pl_callbacks = require(
+    "pytorch_lightning.callbacks", extra="ml-extended", purpose="Lightning training"
+)
+
+Trainer = pytorch_lightning.Trainer
+EarlyStopping = pl_callbacks.EarlyStopping
+ModelCheckpoint = pl_callbacks.ModelCheckpoint
+
 
 def train_lightning_model(
     model,