smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +34 -6
  7. smftools/cli/hmm_adata.py +239 -33
  8. smftools/cli/latent_adata.py +318 -0
  9. smftools/cli/load_adata.py +167 -131
  10. smftools/cli/preprocess_adata.py +180 -53
  11. smftools/cli/spatial_adata.py +152 -100
  12. smftools/cli_entry.py +38 -1
  13. smftools/config/__init__.py +2 -0
  14. smftools/config/conversion.yaml +11 -1
  15. smftools/config/default.yaml +42 -2
  16. smftools/config/experiment_config.py +59 -1
  17. smftools/constants.py +65 -0
  18. smftools/datasets/__init__.py +2 -0
  19. smftools/hmm/HMM.py +97 -3
  20. smftools/hmm/__init__.py +24 -13
  21. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  22. smftools/hmm/archived/calculate_distances.py +2 -0
  23. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  24. smftools/hmm/archived/train_hmm.py +2 -0
  25. smftools/hmm/call_hmm_peaks.py +5 -2
  26. smftools/hmm/display_hmm.py +4 -1
  27. smftools/hmm/hmm_readwrite.py +7 -2
  28. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  29. smftools/informatics/__init__.py +59 -34
  30. smftools/informatics/archived/bam_conversion.py +2 -0
  31. smftools/informatics/archived/bam_direct.py +2 -0
  32. smftools/informatics/archived/basecall_pod5s.py +2 -0
  33. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  34. smftools/informatics/archived/conversion_smf.py +2 -0
  35. smftools/informatics/archived/deaminase_smf.py +1 -0
  36. smftools/informatics/archived/direct_smf.py +2 -0
  37. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  38. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  39. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  40. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  41. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  42. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  43. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  44. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  45. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  48. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  49. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  50. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  52. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  53. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  54. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  55. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  56. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  57. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  58. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  59. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  60. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  61. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  62. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  63. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  64. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  65. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  66. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  67. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  68. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  69. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  70. smftools/informatics/archived/subsample_pod5.py +2 -0
  71. smftools/informatics/bam_functions.py +1093 -176
  72. smftools/informatics/basecalling.py +2 -0
  73. smftools/informatics/bed_functions.py +271 -61
  74. smftools/informatics/binarize_converted_base_identities.py +3 -0
  75. smftools/informatics/complement_base_list.py +2 -0
  76. smftools/informatics/converted_BAM_to_adata.py +641 -176
  77. smftools/informatics/fasta_functions.py +94 -10
  78. smftools/informatics/h5ad_functions.py +123 -4
  79. smftools/informatics/modkit_extract_to_adata.py +1019 -431
  80. smftools/informatics/modkit_functions.py +2 -0
  81. smftools/informatics/ohe.py +2 -0
  82. smftools/informatics/pod5_functions.py +3 -2
  83. smftools/informatics/sequence_encoding.py +72 -0
  84. smftools/logging_utils.py +21 -2
  85. smftools/machine_learning/__init__.py +22 -6
  86. smftools/machine_learning/data/__init__.py +2 -0
  87. smftools/machine_learning/data/anndata_data_module.py +18 -4
  88. smftools/machine_learning/data/preprocessing.py +2 -0
  89. smftools/machine_learning/evaluation/__init__.py +2 -0
  90. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  91. smftools/machine_learning/evaluation/evaluators.py +14 -9
  92. smftools/machine_learning/inference/__init__.py +2 -0
  93. smftools/machine_learning/inference/inference_utils.py +2 -0
  94. smftools/machine_learning/inference/lightning_inference.py +6 -1
  95. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  96. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  97. smftools/machine_learning/models/__init__.py +2 -0
  98. smftools/machine_learning/models/base.py +7 -2
  99. smftools/machine_learning/models/cnn.py +7 -2
  100. smftools/machine_learning/models/lightning_base.py +16 -11
  101. smftools/machine_learning/models/mlp.py +5 -1
  102. smftools/machine_learning/models/positional.py +7 -2
  103. smftools/machine_learning/models/rnn.py +5 -1
  104. smftools/machine_learning/models/sklearn_models.py +14 -9
  105. smftools/machine_learning/models/transformer.py +7 -2
  106. smftools/machine_learning/models/wrappers.py +6 -2
  107. smftools/machine_learning/training/__init__.py +2 -0
  108. smftools/machine_learning/training/train_lightning_model.py +13 -3
  109. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  110. smftools/machine_learning/utils/__init__.py +2 -0
  111. smftools/machine_learning/utils/device.py +5 -1
  112. smftools/machine_learning/utils/grl.py +5 -1
  113. smftools/metadata.py +1 -1
  114. smftools/optional_imports.py +31 -0
  115. smftools/plotting/__init__.py +41 -31
  116. smftools/plotting/autocorrelation_plotting.py +9 -5
  117. smftools/plotting/classifiers.py +16 -4
  118. smftools/plotting/general_plotting.py +2415 -629
  119. smftools/plotting/hmm_plotting.py +97 -9
  120. smftools/plotting/position_stats.py +15 -7
  121. smftools/plotting/qc_plotting.py +6 -1
  122. smftools/preprocessing/__init__.py +36 -37
  123. smftools/preprocessing/append_base_context.py +17 -17
  124. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  125. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  126. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  127. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  128. smftools/preprocessing/archived/preprocessing.py +2 -0
  129. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  130. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  131. smftools/preprocessing/calculate_complexity_II.py +4 -1
  132. smftools/preprocessing/calculate_consensus.py +1 -1
  133. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  134. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  135. smftools/preprocessing/calculate_position_Youden.py +9 -2
  136. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  137. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  138. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  139. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  140. smftools/preprocessing/make_dirs.py +2 -1
  141. smftools/preprocessing/min_non_diagonal.py +2 -0
  142. smftools/preprocessing/recipes.py +2 -0
  143. smftools/readwrite.py +53 -17
  144. smftools/schema/anndata_schema_v1.yaml +15 -1
  145. smftools/tools/__init__.py +30 -18
  146. smftools/tools/archived/apply_hmm.py +2 -0
  147. smftools/tools/archived/classifiers.py +2 -0
  148. smftools/tools/archived/classify_methylated_features.py +2 -0
  149. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  150. smftools/tools/archived/subset_adata_v1.py +2 -0
  151. smftools/tools/archived/subset_adata_v2.py +2 -0
  152. smftools/tools/calculate_leiden.py +57 -0
  153. smftools/tools/calculate_nmf.py +119 -0
  154. smftools/tools/calculate_umap.py +93 -8
  155. smftools/tools/cluster_adata_on_methylation.py +7 -1
  156. smftools/tools/position_stats.py +17 -27
  157. smftools/tools/rolling_nn_distance.py +235 -0
  158. smftools/tools/tensor_factorization.py +169 -0
  159. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
  160. smftools-0.3.1.dist-info/RECORD +189 -0
  161. smftools-0.2.5.dist-info/RECORD +0 -181
  162. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  163. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  164. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import subprocess
 
 from smftools.logging_utils import get_logger
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent.futures
 import os
 
@@ -5,9 +5,8 @@ import subprocess
 from pathlib import Path
 from typing import Iterable
 
-import pod5 as p5
-
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 from ..config import LoadExperimentConfig
 from ..informatics.basecalling import canoncall, modcall
@@ -15,6 +14,8 @@ from ..readwrite import make_dirs
 
 logger = get_logger(__name__)
 
+p5 = require("pod5", extra="ont", purpose="POD5 IO")
+
 
 def basecall_pod5s(config_path: str | Path) -> None:
     """Basecall POD5 inputs using a configuration file.
smftools/informatics/sequence_encoding.py ADDED
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from typing import Iterable, Mapping
+
+import numpy as np
+
+from smftools.constants import (
+    MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+)
+
+
+def encode_sequence_to_int(
+    sequence: str | Iterable[str],
+    *,
+    base_to_int: Mapping[str, int] = MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT,
+    unknown_base: str = "N",
+) -> np.ndarray:
+    """Encode a base sequence into integer values using constant mappings.
+
+    Args:
+        sequence: Sequence string or iterable of base characters.
+        base_to_int: Mapping of base characters to integer encodings.
+        unknown_base: Base to use when a character is not in the encoding map.
+
+    Returns:
+        np.ndarray: Integer-encoded sequence array.
+
+    Raises:
+        ValueError: If an unknown base is encountered and ``unknown_base`` is not mapped.
+    """
+    if unknown_base not in base_to_int:
+        raise ValueError(f"Unknown base '{unknown_base}' not present in encoding map.")
+
+    if isinstance(sequence, str):
+        sequence_iter = sequence
+    else:
+        sequence_iter = list(sequence)
+
+    fallback = base_to_int[unknown_base]
+    encoded = np.fromiter(
+        (base_to_int.get(base, fallback) for base in sequence_iter),
+        dtype=np.int16,
+        count=len(sequence_iter),
+    )
+    return encoded
+
+
+def decode_int_sequence(
+    encoded_sequence: Iterable[int] | np.ndarray,
+    *,
+    int_to_base: Mapping[int, str] = MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE,
+    unknown_base: str = "N",
+) -> list[str]:
+    """Decode integer-encoded bases into characters using constant mappings.
+
+    Args:
+        encoded_sequence: Iterable of integer-encoded bases.
+        int_to_base: Mapping of integer encodings to base characters.
+        unknown_base: Base to use when an integer is not in the decoding map.
+
+    Returns:
+        list[str]: Decoded base characters.
+
+    Raises:
+        ValueError: If ``unknown_base`` is not available for fallback.
+    """
+    if unknown_base not in int_to_base.values():
+        raise ValueError(f"Unknown base '{unknown_base}' not present in decoding map.")
+
+    fallback = unknown_base
+    return [int_to_base.get(int(value), fallback) for value in encoded_sequence]
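A minimal round-trip sketch of the new helpers, assuming the MODKIT_EXTRACT_SEQUENCE_* constants are mutually inverse base-to-integer maps (their exact values are not shown in this diff):

from smftools.informatics.sequence_encoding import (
    decode_int_sequence,
    encode_sequence_to_int,
)

encoded = encode_sequence_to_int("ACGTN")  # one int16 per base; unmapped characters fall back to the "N" code
decoded = decode_int_sequence(encoded)     # ["A", "C", "G", "T", "N"] when the two maps are inverses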
smftools/logging_utils.py CHANGED
@@ -15,18 +15,37 @@ def setup_logging(
     fmt: str = DEFAULT_LOG_FORMAT,
     datefmt: str = DEFAULT_DATE_FORMAT,
     log_file: Optional[Union[str, Path]] = None,
+    reconfigure: bool = False,
 ) -> None:
     """
     Configure logging for smftools.
 
     Should be called once by the CLI entrypoint.
-    Safe to call multiple times.
+    Safe to call multiple times, with optional reconfiguration.
     """
     logger = logging.getLogger("smftools")
 
-    if logger.handlers:
+    if logger.handlers and not reconfigure:
+        if log_file is not None:
+            log_path = Path(log_file)
+            has_file_handler = any(
+                isinstance(handler, logging.FileHandler)
+                and Path(getattr(handler, "baseFilename", "")) == log_path
+                for handler in logger.handlers
+            )
+            if not has_file_handler:
+                log_path.parent.mkdir(parents=True, exist_ok=True)
+                file_handler = logging.FileHandler(log_path)
+                file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))
+                logger.addHandler(file_handler)
+        logger.setLevel(level)
         return
 
+    if logger.handlers and reconfigure:
+        for handler in list(logger.handlers):
+            logger.removeHandler(handler)
+            handler.close()
+
     formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
 
     # Console handler (stderr)
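A usage sketch of the reconfiguration behavior added above; it assumes the parts of setup_logging not shown in this hunk (the level default and the first-call file-handler creation) behave as before:

from smftools.logging_utils import setup_logging

setup_logging(log_file="first.log")                     # first call: builds console (and file) handlers
setup_logging(log_file="second.log")                    # handlers already exist: only a second file handler is appended
setup_logging(log_file="second.log", reconfigure=True)  # existing handlers are removed and closed, logging is rebuilt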
@@ -1,7 +1,23 @@
-from . import data, evaluation, inference, models, training, utils
+from __future__ import annotations
 
-__all__ = [
-    "calculate_relative_risk_on_activity",
-    "evaluate_models_by_subgroup",
-    "prepare_melted_model_data",
-]
+from importlib import import_module
+
+_LAZY_MODULES = {
+    "data": "smftools.machine_learning.data",
+    "evaluation": "smftools.machine_learning.evaluation",
+    "inference": "smftools.machine_learning.inference",
+    "models": "smftools.machine_learning.models",
+    "training": "smftools.machine_learning.training",
+    "utils": "smftools.machine_learning.utils",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_MODULES:
+        module = import_module(_LAZY_MODULES[name])
+        globals()[name] = module
+        return module
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_MODULES.keys())
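The subpackages are now resolved lazily through a module-level __getattr__ (PEP 562) rather than imported eagerly; a small illustration of the resulting behavior:

import smftools.machine_learning as ml

models = ml.models   # first access triggers import_module("smftools.machine_learning.models")
models = ml.models   # later accesses reuse the module cached in globals()
# ml.missing would raise AttributeError instead of importing anything at package load time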
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .anndata_data_module import AnnDataModule, build_anndata_loader
 from .preprocessing import random_fill_nans
@@ -1,12 +1,26 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-import pytorch_lightning as pl
-import torch
-from sklearn.utils.class_weight import compute_class_weight
-from torch.utils.data import DataLoader, Dataset, Subset
+
+from smftools.optional_imports import require
 
 from .preprocessing import random_fill_nans
 
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning data modules")
+torch = require("torch", extra="ml-base", purpose="ML data loading")
+sklearn_class_weight = require(
+    "sklearn.utils.class_weight",
+    extra="ml-base",
+    purpose="class weighting",
+)
+torch_utils_data = require("torch.utils.data", extra="ml-base", purpose="ML data loading")
+
+compute_class_weight = sklearn_class_weight.compute_class_weight
+DataLoader = torch_utils_data.DataLoader
+Dataset = torch_utils_data.Dataset
+Subset = torch_utils_data.Subset
+
 
 class AnnDataDataset(Dataset):
     """
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 
 
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .eval_utils import flatten_sliding_window_results
 from .evaluators import ModelEvaluator, PostInferenceModelEvaluator
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pandas as pd
 
 
@@ -1,14 +1,19 @@
-import matplotlib.pyplot as plt
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-from sklearn.metrics import (
-    auc,
-    confusion_matrix,
-    f1_score,
-    precision_recall_curve,
-    roc_auc_score,
-    roc_curve,
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class ModelEvaluator:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .lightning_inference import run_lightning_inference
 from .sklearn_inference import run_sklearn_inference
 from .sliding_window_inference import sliding_window_inference
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pandas as pd
 
 
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-import torch
+
+from smftools.optional_imports import require
 
 from .inference_utils import annotate_split_column
 
+torch = require("torch", extra="ml-base", purpose="Lightning inference")
+
 
 def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
     """
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from ..data import AnnDataModule
 from ..evaluation import PostInferenceModelEvaluator
 from .lightning_inference import run_lightning_inference
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .base import BaseTorchModel
 from .cnn import CNNClassifier
 from .lightning_base import TorchClassifierWrapper
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
 
 from ..utils.device import detect_device
 
+torch = require("torch", extra="ml-base", purpose="ML base models")
+nn = torch.nn
+
 
 class BaseTorchModel(nn.Module):
     """
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+torch = require("torch", extra="ml-base", purpose="CNN models")
+nn = torch.nn
+
 
 class CNNClassifier(BaseTorchModel):
     def __init__(
@@ -1,15 +1,20 @@
-import matplotlib.pyplot as plt
+from __future__ import annotations
+
 import numpy as np
-import pytorch_lightning as pl
-import torch
-from sklearn.metrics import (
-    auc,
-    confusion_matrix,
-    f1_score,
-    precision_recall_curve,
-    roc_auc_score,
-    roc_curve,
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning models")
+torch = require("torch", extra="ml-base", purpose="Lightning models")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class TorchClassifierWrapper(pl.LightningModule):
@@ -1,7 +1,11 @@
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="MLP models")
+
 
 class MLPClassifier(BaseTorchModel):
     def __init__(
@@ -1,6 +1,11 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="positional encoding")
+nn = torch.nn
 
 
 class PositionalEncoding(nn.Module):
@@ -1,7 +1,11 @@
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="RNN models")
+
 
 class RNNClassifier(BaseTorchModel):
     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
@@ -1,13 +1,18 @@
-import matplotlib.pyplot as plt
+from __future__ import annotations
+
 import numpy as np
-from sklearn.metrics import (
-    auc,
-    confusion_matrix,
-    f1_score,
-    precision_recall_curve,
-    roc_auc_score,
-    roc_curve,
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class SklearnModelWrapper:
@@ -1,11 +1,16 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
 
 from ..utils.grl import grad_reverse
 from .base import BaseTorchModel
 from .positional import PositionalEncoding
 
+torch = require("torch", extra="ml-base", purpose="Transformer models")
+nn = torch.nn
+
 
 class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
     def __init__(self, *args, **kwargs):
@@ -1,5 +1,9 @@
-import torch
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="model wrappers")
+nn = torch.nn
 
 
 class ScaledModel(nn.Module):
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .train_lightning_model import run_sliding_window_lightning_training, train_lightning_model
 from .train_sklearn_model import run_sliding_window_sklearn_training, train_sklearn_model
@@ -1,10 +1,20 @@
-import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from ..data import AnnDataModule
 from ..models import TorchClassifierWrapper
 
+torch = require("torch", extra="ml-base", purpose="Lightning training")
+pytorch_lightning = require("pytorch_lightning", extra="ml-extended", purpose="Lightning training")
+pl_callbacks = require(
+    "pytorch_lightning.callbacks", extra="ml-extended", purpose="Lightning training"
+)
+
+Trainer = pytorch_lightning.Trainer
+EarlyStopping = pl_callbacks.EarlyStopping
+ModelCheckpoint = pl_callbacks.ModelCheckpoint
+
 
 def train_lightning_model(
     model,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from ..data import AnnDataModule
 from ..models import SklearnModelWrapper
 
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .device import detect_device
 from .grl import GradReverse
@@ -1,4 +1,8 @@
-import torch
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="device selection")
 
 
 def detect_device():
@@ -1,4 +1,8 @@
-import torch
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="gradient reversal layers")
 
 
 class GradReverse(torch.autograd.Function):
smftools/metadata.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Iterable, Optional
 from ._version import __version__
 from .schema import SCHEMA_REGISTRY_RESOURCE, SCHEMA_REGISTRY_VERSION
 
-_DEPENDENCIES = ("anndata", "numpy", "pandas", "scanpy", "torch")
+_DEPENDENCIES = ("anndata", "numpy", "pandas", "umap-learn", "pynndescent", "torch")
 
 
 def _iso_timestamp() -> str:
smftools/optional_imports.py ADDED
@@ -0,0 +1,31 @@
+"""Utilities for optional dependency handling."""
+
+from __future__ import annotations
+
+from importlib import import_module
+from typing import Any
+
+
+def require(package: str, *, extra: str, purpose: str | None = None) -> Any:
+    """Import an optional dependency with a helpful error message.
+
+    Args:
+        package: Importable module name (e.g., "torch", "scanpy").
+        extra: Extra name users should install (e.g., "ml", "omics").
+        purpose: Optional context describing the feature needing the dependency.
+
+    Returns:
+        The imported module.
+
+    Raises:
+        ModuleNotFoundError: If the package is not installed.
+    """
+    try:
+        return import_module(package)
+    except ModuleNotFoundError as exc:  # pragma: no cover - depends on env
+        reason = f" for {purpose}" if purpose else ""
+        message = (
+            f"Optional dependency '{package}' is required{reason}. "
+            f"Install it with: pip install 'smftools[{extra}]'"
+        )
+        raise ModuleNotFoundError(message) from exc
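This helper is how the torch, pytorch_lightning, sklearn, matplotlib, and pod5 bindings above are now resolved. The failure mode, sketched for the case where the dependency is absent:

from smftools.optional_imports import require

torch = require("torch", extra="ml-base", purpose="ML data loading")
# If torch is not installed, this raises:
# ModuleNotFoundError: Optional dependency 'torch' is required for ML data loading.
# Install it with: pip install 'smftools[ml-base]'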