smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +39 -7
- smftools/_settings.py +2 -0
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +34 -6
- smftools/cli/hmm_adata.py +239 -33
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +167 -131
- smftools/cli/preprocess_adata.py +180 -53
- smftools/cli/spatial_adata.py +152 -100
- smftools/cli_entry.py +38 -1
- smftools/config/__init__.py +2 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +42 -2
- smftools/config/experiment_config.py +59 -1
- smftools/constants.py +65 -0
- smftools/datasets/__init__.py +2 -0
- smftools/hmm/HMM.py +97 -3
- smftools/hmm/__init__.py +24 -13
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +2 -0
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +5 -2
- smftools/hmm/display_hmm.py +4 -1
- smftools/hmm/hmm_readwrite.py +7 -2
- smftools/hmm/nucleosome_hmm_refinement.py +2 -0
- smftools/informatics/__init__.py +59 -34
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +2 -0
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1093 -176
- smftools/informatics/basecalling.py +2 -0
- smftools/informatics/bed_functions.py +271 -61
- smftools/informatics/binarize_converted_base_identities.py +3 -0
- smftools/informatics/complement_base_list.py +2 -0
- smftools/informatics/converted_BAM_to_adata.py +641 -176
- smftools/informatics/fasta_functions.py +94 -10
- smftools/informatics/h5ad_functions.py +123 -4
- smftools/informatics/modkit_extract_to_adata.py +1019 -431
- smftools/informatics/modkit_functions.py +2 -0
- smftools/informatics/ohe.py +2 -0
- smftools/informatics/pod5_functions.py +3 -2
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/machine_learning/__init__.py +22 -6
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +18 -4
- smftools/machine_learning/data/preprocessing.py +2 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +2 -0
- smftools/machine_learning/evaluation/evaluators.py +14 -9
- smftools/machine_learning/inference/__init__.py +2 -0
- smftools/machine_learning/inference/inference_utils.py +2 -0
- smftools/machine_learning/inference/lightning_inference.py +6 -1
- smftools/machine_learning/inference/sklearn_inference.py +2 -0
- smftools/machine_learning/inference/sliding_window_inference.py +2 -0
- smftools/machine_learning/models/__init__.py +2 -0
- smftools/machine_learning/models/base.py +7 -2
- smftools/machine_learning/models/cnn.py +7 -2
- smftools/machine_learning/models/lightning_base.py +16 -11
- smftools/machine_learning/models/mlp.py +5 -1
- smftools/machine_learning/models/positional.py +7 -2
- smftools/machine_learning/models/rnn.py +5 -1
- smftools/machine_learning/models/sklearn_models.py +14 -9
- smftools/machine_learning/models/transformer.py +7 -2
- smftools/machine_learning/models/wrappers.py +6 -2
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +13 -3
- smftools/machine_learning/training/train_sklearn_model.py +2 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +5 -1
- smftools/machine_learning/utils/grl.py +5 -1
- smftools/metadata.py +1 -1
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +41 -31
- smftools/plotting/autocorrelation_plotting.py +9 -5
- smftools/plotting/classifiers.py +16 -4
- smftools/plotting/general_plotting.py +2415 -629
- smftools/plotting/hmm_plotting.py +97 -9
- smftools/plotting/position_stats.py +15 -7
- smftools/plotting/qc_plotting.py +6 -1
- smftools/preprocessing/__init__.py +36 -37
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/archived/calculate_complexity.py +2 -0
- smftools/preprocessing/archived/mark_duplicates.py +2 -0
- smftools/preprocessing/archived/preprocessing.py +2 -0
- smftools/preprocessing/archived/remove_duplicates.py +2 -0
- smftools/preprocessing/binary_layers_to_ohe.py +2 -1
- smftools/preprocessing/calculate_complexity_II.py +4 -1
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_pairwise_differences.py +2 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
- smftools/preprocessing/calculate_position_Youden.py +9 -2
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
- smftools/preprocessing/flag_duplicate_reads.py +42 -54
- smftools/preprocessing/make_dirs.py +2 -1
- smftools/preprocessing/min_non_diagonal.py +2 -0
- smftools/preprocessing/recipes.py +2 -0
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +30 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +2 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +2 -0
- smftools/tools/archived/subset_adata_v2.py +2 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +93 -8
- smftools/tools/cluster_adata_on_methylation.py +7 -1
- smftools/tools/position_stats.py +17 -27
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
- smftools-0.3.1.dist-info/RECORD +189 -0
- smftools-0.2.5.dist-info/RECORD +0 -181
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
smftools/informatics/ohe.py
CHANGED
|
@@ -5,9 +5,8 @@ import subprocess
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import Iterable
|
|
7
7
|
|
|
8
|
-
import pod5 as p5
|
|
9
|
-
|
|
10
8
|
from smftools.logging_utils import get_logger
|
|
9
|
+
from smftools.optional_imports import require
|
|
11
10
|
|
|
12
11
|
from ..config import LoadExperimentConfig
|
|
13
12
|
from ..informatics.basecalling import canoncall, modcall
|
|
@@ -15,6 +14,8 @@ from ..readwrite import make_dirs
|
|
|
15
14
|
|
|
16
15
|
logger = get_logger(__name__)
|
|
17
16
|
|
|
17
|
+
p5 = require("pod5", extra="ont", purpose="POD5 IO")
|
|
18
|
+
|
|
18
19
|
|
|
19
20
|
def basecall_pod5s(config_path: str | Path) -> None:
|
|
20
21
|
"""Basecall POD5 inputs using a configuration file.
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Iterable, Mapping
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from smftools.constants import (
|
|
8
|
+
MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
|
|
9
|
+
MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def encode_sequence_to_int(
    sequence: str | Iterable[str],
    *,
    base_to_int: Mapping[str, int] = MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
    unknown_base: str = "N",
) -> np.ndarray:
    """Translate a sequence of bases into an ``int16`` array of codes.

    Each base is looked up in ``base_to_int``; any base missing from the
    mapping is encoded with the code assigned to ``unknown_base``.

    Args:
        sequence: A string of bases, or any iterable yielding base characters.
        base_to_int: Lookup table from base character to integer code.
        unknown_base: Base whose code substitutes for unmapped characters.

    Returns:
        np.ndarray: One-dimensional ``np.int16`` array, one code per base.

    Raises:
        ValueError: If ``unknown_base`` is not a key of ``base_to_int``.
    """
    if unknown_base not in base_to_int:
        raise ValueError(f"Unknown base '{unknown_base}' not present in encoding map.")

    # Strings already support len(); other iterables are materialised so the
    # element count can be handed to np.fromiter up front.
    bases = sequence if isinstance(sequence, str) else list(sequence)

    default_code = base_to_int[unknown_base]
    lookup = base_to_int.get
    codes = (lookup(symbol, default_code) for symbol in bases)
    return np.fromiter(codes, dtype=np.int16, count=len(bases))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def decode_int_sequence(
    encoded_sequence: Iterable[int] | np.ndarray,
    *,
    int_to_base: Mapping[int, str] = MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
    unknown_base: str = "N",
) -> list[str]:
    """Translate integer codes back into base characters.

    Every element is converted with ``int()`` before the lookup; codes that
    are absent from ``int_to_base`` decode to ``unknown_base``.

    Args:
        encoded_sequence: Iterable (or array) of integer base codes.
        int_to_base: Lookup table from integer code to base character.
        unknown_base: Character emitted for codes missing from the table.

    Returns:
        list[str]: Decoded bases, in input order.

    Raises:
        ValueError: If ``unknown_base`` is not among the values of
            ``int_to_base``.
    """
    known_bases = int_to_base.values()
    if unknown_base not in known_bases:
        raise ValueError(f"Unknown base '{unknown_base}' not present in decoding map.")

    decoded: list[str] = []
    for code in encoded_sequence:
        decoded.append(int_to_base.get(int(code), unknown_base))
    return decoded
|
smftools/logging_utils.py
CHANGED
|
@@ -15,18 +15,37 @@ def setup_logging(
|
|
|
15
15
|
fmt: str = DEFAULT_LOG_FORMAT,
|
|
16
16
|
datefmt: str = DEFAULT_DATE_FORMAT,
|
|
17
17
|
log_file: Optional[Union[str, Path]] = None,
|
|
18
|
+
reconfigure: bool = False,
|
|
18
19
|
) -> None:
|
|
19
20
|
"""
|
|
20
21
|
Configure logging for smftools.
|
|
21
22
|
|
|
22
23
|
Should be called once by the CLI entrypoint.
|
|
23
|
-
Safe to call multiple times.
|
|
24
|
+
Safe to call multiple times, with optional reconfiguration.
|
|
24
25
|
"""
|
|
25
26
|
logger = logging.getLogger("smftools")
|
|
26
27
|
|
|
27
|
-
if logger.handlers:
|
|
28
|
+
if logger.handlers and not reconfigure:
|
|
29
|
+
if log_file is not None:
|
|
30
|
+
log_path = Path(log_file)
|
|
31
|
+
has_file_handler = any(
|
|
32
|
+
isinstance(handler, logging.FileHandler)
|
|
33
|
+
and Path(getattr(handler, "baseFilename", "")) == log_path
|
|
34
|
+
for handler in logger.handlers
|
|
35
|
+
)
|
|
36
|
+
if not has_file_handler:
|
|
37
|
+
log_path.parent.mkdir(parents=True, exist_ok=True)
|
|
38
|
+
file_handler = logging.FileHandler(log_path)
|
|
39
|
+
file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))
|
|
40
|
+
logger.addHandler(file_handler)
|
|
41
|
+
logger.setLevel(level)
|
|
28
42
|
return
|
|
29
43
|
|
|
44
|
+
if logger.handlers and reconfigure:
|
|
45
|
+
for handler in list(logger.handlers):
|
|
46
|
+
logger.removeHandler(handler)
|
|
47
|
+
handler.close()
|
|
48
|
+
|
|
30
49
|
formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
|
|
31
50
|
|
|
32
51
|
# Console handler (stderr)
|
|
@@ -1,7 +1,23 @@
|
|
|
1
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
"
|
|
7
|
-
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
|
|
5
|
+
_LAZY_MODULES = {
|
|
6
|
+
"data": "smftools.machine_learning.data",
|
|
7
|
+
"evaluation": "smftools.machine_learning.evaluation",
|
|
8
|
+
"inference": "smftools.machine_learning.inference",
|
|
9
|
+
"models": "smftools.machine_learning.models",
|
|
10
|
+
"training": "smftools.machine_learning.training",
|
|
11
|
+
"utils": "smftools.machine_learning.utils",
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def __getattr__(name: str):
    """Resolve a lazily imported submodule on first attribute access.

    Names listed in ``_LAZY_MODULES`` are imported on demand and stored in
    the module globals, so later lookups find them without calling this hook.

    Args:
        name: Attribute requested on this module.

    Returns:
        The imported submodule registered under ``name``.

    Raises:
        AttributeError: If ``name`` is not a registered lazy submodule.
    """
    if name not in _LAZY_MODULES:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    loaded = import_module(_LAZY_MODULES[name])
    # Cache in globals so the import runs only once per name.
    globals()[name] = loaded
    return loaded
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
__all__ = list(_LAZY_MODULES.keys())
|
|
@@ -1,12 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
import pandas as pd
|
|
3
|
-
|
|
4
|
-
import
|
|
5
|
-
from sklearn.utils.class_weight import compute_class_weight
|
|
6
|
-
from torch.utils.data import DataLoader, Dataset, Subset
|
|
5
|
+
|
|
6
|
+
from smftools.optional_imports import require
|
|
7
7
|
|
|
8
8
|
from .preprocessing import random_fill_nans
|
|
9
9
|
|
|
10
|
+
pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning data modules")
|
|
11
|
+
torch = require("torch", extra="ml-base", purpose="ML data loading")
|
|
12
|
+
sklearn_class_weight = require(
|
|
13
|
+
"sklearn.utils.class_weight",
|
|
14
|
+
extra="ml-base",
|
|
15
|
+
purpose="class weighting",
|
|
16
|
+
)
|
|
17
|
+
torch_utils_data = require("torch.utils.data", extra="ml-base", purpose="ML data loading")
|
|
18
|
+
|
|
19
|
+
compute_class_weight = sklearn_class_weight.compute_class_weight
|
|
20
|
+
DataLoader = torch_utils_data.DataLoader
|
|
21
|
+
Dataset = torch_utils_data.Dataset
|
|
22
|
+
Subset = torch_utils_data.Subset
|
|
23
|
+
|
|
10
24
|
|
|
11
25
|
class AnnDataDataset(Dataset):
|
|
12
26
|
"""
|
|
@@ -1,14 +1,19 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
3
4
|
import pandas as pd
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
5
|
+
|
|
6
|
+
from smftools.optional_imports import require
|
|
7
|
+
|
|
8
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")
|
|
9
|
+
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
|
|
10
|
+
|
|
11
|
+
auc = sklearn_metrics.auc
|
|
12
|
+
confusion_matrix = sklearn_metrics.confusion_matrix
|
|
13
|
+
f1_score = sklearn_metrics.f1_score
|
|
14
|
+
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
15
|
+
roc_auc_score = sklearn_metrics.roc_auc_score
|
|
16
|
+
roc_curve = sklearn_metrics.roc_curve
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
class ModelEvaluator:
|
|
@@ -1,9 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
import pandas as pd
|
|
3
|
-
|
|
5
|
+
|
|
6
|
+
from smftools.optional_imports import require
|
|
4
7
|
|
|
5
8
|
from .inference_utils import annotate_split_column
|
|
6
9
|
|
|
10
|
+
torch = require("torch", extra="ml-base", purpose="Lightning inference")
|
|
11
|
+
|
|
7
12
|
|
|
8
13
|
def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
|
|
9
14
|
"""
|
|
@@ -1,9 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
|
-
|
|
3
|
-
|
|
4
|
+
|
|
5
|
+
from smftools.optional_imports import require
|
|
4
6
|
|
|
5
7
|
from ..utils.device import detect_device
|
|
6
8
|
|
|
9
|
+
torch = require("torch", extra="ml-base", purpose="ML base models")
|
|
10
|
+
nn = torch.nn
|
|
11
|
+
|
|
7
12
|
|
|
8
13
|
class BaseTorchModel(nn.Module):
|
|
9
14
|
"""
|
|
@@ -1,9 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
|
-
|
|
3
|
-
|
|
4
|
+
|
|
5
|
+
from smftools.optional_imports import require
|
|
4
6
|
|
|
5
7
|
from .base import BaseTorchModel
|
|
6
8
|
|
|
9
|
+
torch = require("torch", extra="ml-base", purpose="CNN models")
|
|
10
|
+
nn = torch.nn
|
|
11
|
+
|
|
7
12
|
|
|
8
13
|
class CNNClassifier(BaseTorchModel):
|
|
9
14
|
def __init__(
|
|
@@ -1,15 +1,20 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
3
|
-
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
4
|
+
|
|
5
|
+
from smftools.optional_imports import require
|
|
6
|
+
|
|
7
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
|
|
8
|
+
pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning models")
|
|
9
|
+
torch = require("torch", extra="ml-base", purpose="Lightning models")
|
|
10
|
+
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
|
|
11
|
+
|
|
12
|
+
auc = sklearn_metrics.auc
|
|
13
|
+
confusion_matrix = sklearn_metrics.confusion_matrix
|
|
14
|
+
f1_score = sklearn_metrics.f1_score
|
|
15
|
+
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
16
|
+
roc_auc_score = sklearn_metrics.roc_auc_score
|
|
17
|
+
roc_curve = sklearn_metrics.roc_curve
|
|
13
18
|
|
|
14
19
|
|
|
15
20
|
class TorchClassifierWrapper(pl.LightningModule):
|
|
@@ -1,7 +1,11 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from smftools.optional_imports import require
|
|
2
4
|
|
|
3
5
|
from .base import BaseTorchModel
|
|
4
6
|
|
|
7
|
+
nn = require("torch.nn", extra="ml-base", purpose="MLP models")
|
|
8
|
+
|
|
5
9
|
|
|
6
10
|
class MLPClassifier(BaseTorchModel):
|
|
7
11
|
def __init__(
|
|
@@ -1,6 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
|
-
|
|
3
|
-
|
|
4
|
+
|
|
5
|
+
from smftools.optional_imports import require
|
|
6
|
+
|
|
7
|
+
torch = require("torch", extra="ml-base", purpose="positional encoding")
|
|
8
|
+
nn = torch.nn
|
|
4
9
|
|
|
5
10
|
|
|
6
11
|
class PositionalEncoding(nn.Module):
|
|
@@ -1,7 +1,11 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from smftools.optional_imports import require
|
|
2
4
|
|
|
3
5
|
from .base import BaseTorchModel
|
|
4
6
|
|
|
7
|
+
nn = require("torch.nn", extra="ml-base", purpose="RNN models")
|
|
8
|
+
|
|
5
9
|
|
|
6
10
|
class RNNClassifier(BaseTorchModel):
|
|
7
11
|
def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
|
|
@@ -1,13 +1,18 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
4
|
+
|
|
5
|
+
from smftools.optional_imports import require
|
|
6
|
+
|
|
7
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
|
|
8
|
+
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
|
|
9
|
+
|
|
10
|
+
auc = sklearn_metrics.auc
|
|
11
|
+
confusion_matrix = sklearn_metrics.confusion_matrix
|
|
12
|
+
f1_score = sklearn_metrics.f1_score
|
|
13
|
+
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
14
|
+
roc_auc_score = sklearn_metrics.roc_auc_score
|
|
15
|
+
roc_curve = sklearn_metrics.roc_curve
|
|
11
16
|
|
|
12
17
|
|
|
13
18
|
class SklearnModelWrapper:
|
|
@@ -1,11 +1,16 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
|
-
|
|
3
|
-
|
|
4
|
+
|
|
5
|
+
from smftools.optional_imports import require
|
|
4
6
|
|
|
5
7
|
from ..utils.grl import grad_reverse
|
|
6
8
|
from .base import BaseTorchModel
|
|
7
9
|
from .positional import PositionalEncoding
|
|
8
10
|
|
|
11
|
+
torch = require("torch", extra="ml-base", purpose="Transformer models")
|
|
12
|
+
nn = torch.nn
|
|
13
|
+
|
|
9
14
|
|
|
10
15
|
class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
|
|
11
16
|
def __init__(self, *args, **kwargs):
|
|
@@ -1,10 +1,20 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
from
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from smftools.optional_imports import require
|
|
4
4
|
|
|
5
5
|
from ..data import AnnDataModule
|
|
6
6
|
from ..models import TorchClassifierWrapper
|
|
7
7
|
|
|
8
|
+
torch = require("torch", extra="ml-base", purpose="Lightning training")
|
|
9
|
+
pytorch_lightning = require("pytorch_lightning", extra="ml-extended", purpose="Lightning training")
|
|
10
|
+
pl_callbacks = require(
|
|
11
|
+
"pytorch_lightning.callbacks", extra="ml-extended", purpose="Lightning training"
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
Trainer = pytorch_lightning.Trainer
|
|
15
|
+
EarlyStopping = pl_callbacks.EarlyStopping
|
|
16
|
+
ModelCheckpoint = pl_callbacks.ModelCheckpoint
|
|
17
|
+
|
|
8
18
|
|
|
9
19
|
def train_lightning_model(
|
|
10
20
|
model,
|
smftools/metadata.py
CHANGED
|
@@ -12,7 +12,7 @@ from typing import Any, Iterable, Optional
|
|
|
12
12
|
from ._version import __version__
|
|
13
13
|
from .schema import SCHEMA_REGISTRY_RESOURCE, SCHEMA_REGISTRY_VERSION
|
|
14
14
|
|
|
15
|
-
_DEPENDENCIES = ("anndata", "numpy", "pandas", "
|
|
15
|
+
_DEPENDENCIES = ("anndata", "numpy", "pandas", "umap-learn", "pynndescent", "torch")
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def _iso_timestamp() -> str:
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Utilities for optional dependency handling."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from importlib import import_module
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def require(package: str, *, extra: str, purpose: str | None = None) -> Any:
    """Import an optional dependency with a helpful error message.

    Args:
        package: Importable module name (e.g., "torch", "scanpy").
        extra: Extra name users should install (e.g., "ml", "omics").
        purpose: Optional context describing the feature needing the dependency.

    Returns:
        The imported module.

    Raises:
        ModuleNotFoundError: If the import fails because a module is missing.
            The raised error carries the install hint as its message, keeps
            the ``name`` and ``path`` of the underlying error, and chains the
            original exception as ``__cause__``.
    """
    try:
        return import_module(package)
    except ModuleNotFoundError as exc:  # pragma: no cover - depends on env
        reason = f" for {purpose}" if purpose else ""
        message = (
            f"Optional dependency '{package}' is required{reason}. "
            f"Install it with: pip install 'smftools[{extra}]'"
        )
        # Forward name/path so callers inspecting the exception still see
        # which module was actually missing.
        raise ModuleNotFoundError(message, name=exc.name, path=exc.path) from exc
|