smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares the published contents of two package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/models/cnn.py
@@ -1,8 +1,15 @@
-import torch
-import torch.nn as nn
-from .base import BaseTorchModel
+from __future__ import annotations
+
 import numpy as np
 
+from smftools.optional_imports import require
+
+from .base import BaseTorchModel
+
+torch = require("torch", extra="ml-base", purpose="CNN models")
+nn = torch.nn
+
+
 class CNNClassifier(BaseTorchModel):
     def __init__(
         self,
@@ -15,7 +22,7 @@ class CNNClassifier(BaseTorchModel):
         use_pooling=False,
         dropout=0.2,
         gradcam_layer_idx=-1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.name = "CNNClassifier"
@@ -30,7 +37,9 @@ class CNNClassifier(BaseTorchModel):
 
         # Build conv layers
         for out_channels, ksize in zip(conv_channels, kernel_sizes):
-            layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
+            layers.append(
+                nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2)
+            )
             if use_batchnorm:
                 layers.append(nn.BatchNorm1d(out_channels))
             layers.append(nn.ReLU())
@@ -76,7 +85,7 @@ class CNNClassifier(BaseTorchModel):
         x = self.conv(x)
         x = x.view(x.size(0), -1)
         return self.fc(x)
-
+
     def _register_gradcam_hooks(self):
         def forward_hook(module, input, output):
             self.gradcam_activations = output.detach()
@@ -97,15 +106,15 @@ class CNNClassifier(BaseTorchModel):
         self.eval()  # disable dropout etc.
 
         output = self.forward(x)  # shape (B, C) or (B, 1)
-
+
         if class_idx is None:
             class_idx = output.argmax(dim=1)
-
+
         if output.shape[1] == 1:
             target = output.view(-1)  # shape (B,)
         else:
             target = output[torch.arange(output.shape[0]), class_idx]
-
+
         target.sum().backward(retain_graph=True)
 
         # restore training mode
@@ -114,16 +123,16 @@ class CNNClassifier(BaseTorchModel):
 
         # get activations and gradients (set these via forward hook!)
         activations = self.gradcam_activations  # (B, C, L)
-        gradients = self.gradcam_gradients # (B, C, L)
+        gradients = self.gradcam_gradients  # (B, C, L)
 
         weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
-        cam = (weights * activations).sum(dim=1) # (B, L)
+        cam = (weights * activations).sum(dim=1)  # (B, L)
 
         cam = torch.relu(cam)
         cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)
 
         return cam
-
+
     def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
         self.to(device)
         self.eval()
@@ -135,4 +144,4 @@ class CNNClassifier(BaseTorchModel):
             cams.append(cam_batch.cpu().numpy())
 
         cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
-        adata.obsm[obsm_key] = cams
+        adata.obsm[obsm_key] = cams
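
Note on the require(...) pattern introduced above: it comes from the new smftools/optional_imports.py module (+31 lines in this release) and replaces hard top-level imports of optional dependencies such as torch. The helper's own source is not part of these hunks; a minimal sketch of what such a guard typically looks like, assuming an eager import (the real helper may instead defer or proxy the import):

    import importlib

    def require(module_name, extra=None, purpose=None):
        """Hypothetical sketch: import an optional dependency or raise a helpful error."""
        try:
            return importlib.import_module(module_name)
        except ImportError as err:  # the dependency (or its extra) is not installed
            why = f" It is required for {purpose}." if purpose else ""
            hint = f" Try: pip install 'smftools[{extra}]'" if extra else ""
            raise ImportError(f"Missing optional dependency '{module_name}'.{why}{hint}") from err

With a guard like this, require("torch", extra="ml-base", purpose="CNN models") fails with an actionable install hint instead of a bare ImportError when torch is absent.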
smftools/machine_learning/models/lightning_base.py
@@ -1,11 +1,22 @@
-import torch
-import pytorch_lightning as pl
-import matplotlib.pyplot as plt
-from sklearn.metrics import (
-    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
-)
+from __future__ import annotations
+
 import numpy as np
 
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning models")
+torch = require("torch", extra="ml-base", purpose="Lightning models")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
+
+
 class TorchClassifierWrapper(pl.LightningModule):
     """
     A Pytorch Lightning wrapper for PyTorch classifiers.
@@ -16,25 +27,26 @@ class TorchClassifierWrapper(pl.LightningModule):
     - Can pass the index of the class label to use as the focus class when calculating precision/recall.
     - Contains a prediction step to run inference with.
     """
+
     def __init__(
         self,
         model: torch.nn.Module,
         label_col: str,
         num_classes: int,
-        class_names: list=None,
+        class_names: list = None,
         optimizer_cls=torch.optim.AdamW,
         optimizer_kwargs=None,
         criterion_kwargs=None,
         lr: float = 1e-3,
         focus_class: int = 1,  # used for binary or multiclass precision-recall
         class_weights=None,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive: int=None
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive: int = None,
     ):
         super().__init__()
         self.model = model
-        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
+        self.save_hyperparameters(ignore=["model"])  # logs all except actual model instance
         self.optimizer_cls = optimizer_cls
         self.optimizer_kwargs = optimizer_kwargs or {"weight_decay": 1e-4}
         self.criterion = None
@@ -57,14 +69,17 @@ class TorchClassifierWrapper(pl.LightningModule):
             if torch.is_tensor(class_weights[self.focus_class]):
                 self.criterion_kwargs["pos_weight"] = class_weights[self.focus_class]
             else:
-                self.criterion_kwargs["pos_weight"] = torch.tensor(class_weights[self.focus_class], dtype=torch.float32, device=self.device)
+                self.criterion_kwargs["pos_weight"] = torch.tensor(
+                    class_weights[self.focus_class], dtype=torch.float32, device=self.device
+                )
         else:
             # CrossEntropyLoss expects weight tensor of size C
             if torch.is_tensor(class_weights):
                 self.criterion_kwargs["weight"] = class_weights
             else:
-                self.criterion_kwargs["weight"] = torch.tensor(class_weights, dtype=torch.float32)
-
+                self.criterion_kwargs["weight"] = torch.tensor(
+                    class_weights, dtype=torch.float32
+                )
 
         self._val_outputs = []
         self._test_outputs = []
@@ -78,12 +93,20 @@ class TorchClassifierWrapper(pl.LightningModule):
 
     def _init_criterion(self):
         if self.num_classes == 2:
-            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["pos_weight"]):
-                self.criterion_kwargs["pos_weight"] = torch.tensor(self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device)
+            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["pos_weight"]
+            ):
+                self.criterion_kwargs["pos_weight"] = torch.tensor(
+                    self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.BCEWithLogitsLoss(**self.criterion_kwargs)
         else:
-            if "weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["weight"]):
-                self.criterion_kwargs["weight"] = torch.tensor(self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device)
+            if "weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["weight"]
+            ):
+                self.criterion_kwargs["weight"] = torch.tensor(
+                    self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.CrossEntropyLoss(**self.criterion_kwargs)
 
     def _resolve_focus_class(self, focus_class):
@@ -93,11 +116,13 @@ class TorchClassifierWrapper(pl.LightningModule):
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
-
+
     def set_training_indices(self, datamodule):
         """
         Store obs_names for train/val/test subsets used during training.
@@ -140,7 +165,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=False)
         self._val_outputs.append((logits.detach(), y.detach()))
         return loss
-
+
     def test_step(self, batch, batch_idx):
         """
         Test step for a batch through the Lightning Trainer.
@@ -189,7 +214,7 @@ class TorchClassifierWrapper(pl.LightningModule):
             return self.criterion(logits.view(-1, 1), y)
         else:
             return self.criterion(logits, y)
-
+
     def _get_probs(self, logits):
         """
         A helper function for getting class probabilities for binary vs multiclass classifications.
@@ -207,8 +232,10 @@ class TorchClassifierWrapper(pl.LightningModule):
             return (torch.sigmoid(logits.view(-1)) >= 0.5).long()
         else:
             return logits.argmax(dim=1)
-
-    def _subsample_for_fixed_positive_frequency(self, y_true, probs, target_freq=0.3, max_positive=None):
+
+    def _subsample_for_fixed_positive_frequency(
+        self, y_true, probs, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(y_true == self.focus_class)[0]
         neg_idx = np.where(y_true != self.focus_class)[0]
 
@@ -216,16 +243,20 @@ class TorchClassifierWrapper(pl.LightningModule):
         max_positives_possible = len(pos_idx)
 
         # maximum achievable positive class frequency
-        max_possible_freq = max_positives_possible / (max_positives_possible + max_negatives_possible)
+        max_possible_freq = max_positives_possible / (
+            max_positives_possible + max_negatives_possible
+        )
 
         if target_freq > max_possible_freq:
             target_freq = max_possible_freq  # clip if you ask for impossible freq
 
         # now calculate positive count
-        num_pos_target = min(int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible)
+        num_pos_target = min(
+            int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible
+        )
         num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
         num_neg_target = min(num_neg_target, max_negatives_possible)
-
+
         pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
         neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
 
@@ -235,7 +266,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         actual_freq = len(pos_sampled) / len(sampled_idx)
 
         return sampled_idx
-
+
     def _log_classification_metrics(self, logits, targets, prefix="val"):
         """
         A helper function for logging validation and testing split model evaluations.
@@ -252,9 +283,12 @@ class TorchClassifierWrapper(pl.LightningModule):
         num_pos = binary_focus.sum()
 
         # Subsample if you want to enforce a fixed proportion of the positive class
-        if prefix == 'test' and self.enforce_eval_balance:
+        if prefix == "test" and self.enforce_eval_balance:
             sampled_idx = self._subsample_for_fixed_positive_frequency(
-                y_true, probs, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
+                y_true,
+                probs,
+                target_freq=self.target_eval_freq,
+                max_positive=self.max_eval_positive,
             )
             y_true = y_true[sampled_idx]
             probs = probs[sampled_idx]
@@ -289,7 +323,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         cm = confusion_matrix(y_true, preds)
 
         # Save attributes for later plotting
-        if prefix == 'test':
+        if prefix == "test":
             self.test_roc_curve = (fpr, tpr)
             self.test_pr_curve = (rc, pr)
             self.test_roc_auc = roc_auc
@@ -298,19 +332,21 @@ class TorchClassifierWrapper(pl.LightningModule):
             self.test_num_pos = num_pos
             self.test_acc = acc
             self.test_f1 = f1
-        elif prefix == 'val':
+        elif prefix == "val":
             pass
 
         # Logging
-        self.log_dict({
-            f"{prefix}_acc": acc,
-            f"{prefix}_f1": f1,
-            f"{prefix}_auc": roc_auc,
-            f"{prefix}_pr_auc": pr_auc,
-            f"{prefix}_pr_auc_norm": pr_auc_norm,
-            f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
-        })
+        self.log_dict(
+            {
+                f"{prefix}_acc": acc,
+                f"{prefix}_f1": f1,
+                f"{prefix}_auc": roc_auc,
+                f"{prefix}_pr_auc": pr_auc,
+                f"{prefix}_pr_auc_norm": pr_auc_norm,
+                f"{prefix}_pos_freq": pos_freq,
+                f"{prefix}_num_pos": num_pos,
+            }
+        )
         setattr(self, f"{prefix}_confusion_matrix", cm)
 
     def _plot_roc_pr_curves(self, logits, targets):
@@ -334,7 +370,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         pos_freq = self.test_pos_freq
         plt.subplot(1, 2, 2)
         plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
-        plt.axhline(pos_freq, linestyle='--', color="gray")
+        plt.axhline(pos_freq, linestyle="--", color="gray")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
         plt.ylim(0, 1.05)
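
The TorchClassifierWrapper hunks above are formatting and import-guard changes; the public interface is unchanged. For orientation, a usage sketch built from the constructor signature shown in the diff; the backbone network, import path, and trainer settings are illustrative assumptions rather than documented smftools API:

    import pytorch_lightning as pl
    import torch

    # Import path inferred from the package layout (models/lightning_base.py).
    from smftools.machine_learning.models.lightning_base import TorchClassifierWrapper

    backbone = torch.nn.Sequential(
        torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2)
    )
    wrapper = TorchClassifierWrapper(
        model=backbone,
        label_col="label",
        num_classes=2,
        class_names=["unbound", "bound"],
        focus_class="bound",        # a string is resolved to an index via class_names
        enforce_eval_balance=True,  # test-time subsampling toward target_eval_freq
        target_eval_freq=0.3,
    )
    trainer = pl.Trainer(max_epochs=10)
    # trainer.fit(wrapper, datamodule=...)  # any LightningDataModule yielding (x, y) batches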
smftools/machine_learning/models/mlp.py
@@ -1,9 +1,22 @@
-import torch
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
 from .base import BaseTorchModel
-
+
+nn = require("torch.nn", extra="ml-base", purpose="MLP models")
+
+
 class MLPClassifier(BaseTorchModel):
-    def __init__(self, input_dim, num_classes=2, hidden_dims=[64, 64], dropout=0.2, use_batchnorm=True, **kwargs):
+    def __init__(
+        self,
+        input_dim,
+        num_classes=2,
+        hidden_dims=[64, 64],
+        dropout=0.2,
+        use_batchnorm=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         layers = []
         in_dim = input_dim
@@ -23,4 +36,4 @@ class MLPClassifier(BaseTorchModel):
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
-        return self.model(x)
+        return self.model(x)
smftools/machine_learning/models/positional.py
@@ -1,6 +1,12 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="positional encoding")
+nn = torch.nn
+
 
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=5000):
@@ -14,5 +20,5 @@ class PositionalEncoding(nn.Module):
         self.register_buffer("pe", pe)
 
     def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return x
+        x = x + self.pe[:, : x.size(1)]
+        return x
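
Only the buffer registration and the forward-pass slice of PositionalEncoding appear in these hunks; the construction of the encoding table sits between them and is not shown. A standard sinusoidal construction consistent with the visible pieces (a buffer of shape (1, max_len, d_model), sliced to the sequence length in forward) would look like the following sketch; the actual smftools code may differ:

    import numpy as np
    import torch
    import torch.nn as nn

    class SinusoidalPositionalEncoding(nn.Module):
        """Sketch of the classic construction; assumes an even d_model."""

        def __init__(self, d_model, max_len=5000):
            super().__init__()
            position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
            div_term = torch.exp(
                torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
            )
            pe = torch.zeros(1, max_len, d_model)
            pe[0, :, 0::2] = torch.sin(position * div_term)  # even dimensions
            pe[0, :, 1::2] = torch.cos(position * div_term)  # odd dimensions
            self.register_buffer("pe", pe)  # moves with the module; not a trainable parameter

        def forward(self, x):  # x: (batch, seq_len, d_model)
            # same slicing as the reformatted forward above
            return x + self.pe[:, : x.size(1)]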
smftools/machine_learning/models/rnn.py
@@ -1,7 +1,12 @@
-import torch
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="RNN models")
+
+
 class RNNClassifier(BaseTorchModel):
     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
         super().__init__(**kwargs)
@@ -14,4 +19,4 @@ class RNNClassifier(BaseTorchModel):
     def forward(self, x):
         x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
         _, (h_n, _) = self.lstm(x)  # h_n: [1, B, H]
-        return self.fc(h_n.squeeze(0)) # [B, H] → [B, num_classes]
+        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
smftools/machine_learning/models/sklearn_models.py
@@ -1,23 +1,35 @@
+from __future__ import annotations
+
 import numpy as np
-import matplotlib.pyplot as plt
-from sklearn.metrics import (
-    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
+
 
 class SklearnModelWrapper:
     """
     Unified sklearn wrapper matching TorchClassifierWrapper interface.
     """
+
     def __init__(
-        self, 
-        model, 
+        self,
+        model,
         label_col: str,
-        num_classes: int, 
-        class_names=None,
-        focus_class: int=1,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive=None
+        num_classes: int,
+        class_names=None,
+        focus_class: int = 1,
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive=None,
     ):
         self.model = model
         self.label_col = label_col
@@ -37,7 +49,9 @@ class SklearnModelWrapper:
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
@@ -130,7 +144,7 @@ class SklearnModelWrapper:
             f"{prefix}_pr_auc": pr_auc,
             f"{prefix}_pr_auc_norm": pr_auc_norm,
             f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
+            f"{prefix}_num_pos": num_pos,
         }
 
         return self.metrics
@@ -166,7 +180,10 @@ class SklearnModelWrapper:
 
     def fit_from_datamodule(self, datamodule):
         datamodule.setup()
-        X_tensor, y_tensor = datamodule.train_set.dataset.X_tensor, datamodule.train_set.dataset.y_tensor
+        X_tensor, y_tensor = (
+            datamodule.train_set.dataset.X_tensor,
+            datamodule.train_set.dataset.y_tensor,
+        )
         indices = datamodule.train_set.indices
         X_train = X_tensor[indices].numpy()
         y_train = y_tensor[indices].numpy()
@@ -190,11 +207,11 @@ class SklearnModelWrapper:
         y_eval = y_tensor[indices].numpy()
 
         return self.evaluate(X_eval, y_eval, prefix=split)
-
+
     def compute_shap(self, X, background=None, nsamples=100, target_class=None):
         """
         Compute SHAP values on input X, optionally for a specified target class.
-
+
         Parameters
         ----------
         X : array-like
@@ -225,7 +242,7 @@ class SklearnModelWrapper:
             shap_values = explainer.shap_values(X)
         else:
             shap_values = explainer.shap_values(X, nsamples=nsamples)
-
+
         if isinstance(shap_values, np.ndarray):
             if shap_values.ndim == 3:
                 if isinstance(target_class, int):
@@ -234,10 +251,7 @@ class SklearnModelWrapper:
                     # target_class is per-sample
                     if np.any(target_class >= shap_values.shape[2]):
                         raise ValueError(f"target_class values exceed {shap_values.shape[2]}")
-                    selected = np.array([
-                        shap_values[i, :, c]
-                        for i, c in enumerate(target_class)
-                    ])
+                    selected = np.array([shap_values[i, :, c] for i, c in enumerate(target_class)])
                     return selected
             else:
                 # fallback to class 0
@@ -246,7 +260,15 @@ class SklearnModelWrapper:
             # 2D shape (samples, features), no class dimension
             return shap_values
 
-    def apply_shap_to_adata(self, dataloader, adata, background=None, adata_key="shap_values", target_class=None, normalize=True):
+    def apply_shap_to_adata(
+        self,
+        dataloader,
+        adata,
+        background=None,
+        adata_key="shap_values",
+        target_class=None,
+        normalize=True,
+    ):
         """
         Compute SHAP from a DataLoader and store in AnnData if provided.
         """
@@ -270,4 +292,4 @@ class SklearnModelWrapper:
             row_max[row_max == 0] = 1  # avoid divide by zero
             normalized = arr / row_max
 
-            adata.obsm[f"{adata_key}_normalized"] = normalized
+            adata.obsm[f"{adata_key}_normalized"] = normalized
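
As with the Lightning wrapper, the SklearnModelWrapper hunks change formatting and import guards rather than behavior. A usage sketch assembled from the constructor signature and method names visible above (fit_from_datamodule, evaluate, compute_shap, apply_shap_to_adata); the estimator choice and the datamodule/AnnData objects are illustrative assumptions:

    from sklearn.ensemble import RandomForestClassifier

    # Import path inferred from the package layout (models/sklearn_models.py).
    from smftools.machine_learning.models.sklearn_models import SklearnModelWrapper

    wrapper = SklearnModelWrapper(
        model=RandomForestClassifier(n_estimators=200),
        label_col="label",
        num_classes=2,
        class_names=["unbound", "bound"],
        focus_class="bound",
        enforce_eval_balance=True,
        target_eval_freq=0.3,
    )
    # wrapper.fit_from_datamodule(datamodule)  # pulls train X/y tensors from the datamodule
    # metrics = wrapper.evaluate(X_eval, y_eval, prefix="test")
    # shap_vals = wrapper.compute_shap(X_eval, target_class=1)
    # wrapper.apply_shap_to_adata(dataloader, adata, adata_key="shap_values")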