smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/models/lightning_base.py
@@ -0,0 +1,345 @@
+import torch
+import pytorch_lightning as pl
+import matplotlib.pyplot as plt
+from sklearn.metrics import (
+    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
+)
+import numpy as np
+
+class TorchClassifierWrapper(pl.LightningModule):
+    """
+    A PyTorch Lightning wrapper for PyTorch classifiers.
+    - Takes a PyTorch model as input.
+    - Number of classes should be passed.
+    - Optimizer defaults to AdamW without any keyword arguments.
+    - Loss criterion is selected automatically based on whether the classifier is binary or multi-class.
+    - Can pass the index of the class label to use as the focus class when calculating precision/recall.
+    - Contains a prediction step to run inference with.
+    """
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        label_col: str,
+        num_classes: int,
+        class_names: list=None,
+        optimizer_cls=torch.optim.AdamW,
+        optimizer_kwargs=None,
+        criterion_kwargs=None,
+        lr: float = 1e-3,
+        focus_class: int = 1,  # used for binary or multiclass precision-recall
+        class_weights=None,
+        enforce_eval_balance: bool=False,
+        target_eval_freq: float=0.3,
+        max_eval_positive: int=None
+    ):
+        super().__init__()
+        self.model = model
+        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
+        self.optimizer_cls = optimizer_cls
+        self.optimizer_kwargs = optimizer_kwargs or {"weight_decay": 1e-4}
+        self.criterion = None
+        self.lr = lr
+        self.label_col = label_col
+        self.num_classes = num_classes
+        self.class_names = class_names
+        self.focus_class = self._resolve_focus_class(focus_class)
+        self.focus_class_name = focus_class
+        self.enforce_eval_balance = enforce_eval_balance
+        self.target_eval_freq = target_eval_freq
+        self.max_eval_positive = max_eval_positive
+
+        # Handle class weights
+        self.criterion_kwargs = criterion_kwargs or {}
+
+        if class_weights is not None:
+            if num_classes == 2:
+                # BCEWithLogits uses pos_weight, expects a scalar or tensor
+                if torch.is_tensor(class_weights[self.focus_class]):
+                    self.criterion_kwargs["pos_weight"] = class_weights[self.focus_class]
+                else:
+                    self.criterion_kwargs["pos_weight"] = torch.tensor(class_weights[self.focus_class], dtype=torch.float32, device=self.device)
+            else:
+                # CrossEntropyLoss expects weight tensor of size C
+                if torch.is_tensor(class_weights):
+                    self.criterion_kwargs["weight"] = class_weights
+                else:
+                    self.criterion_kwargs["weight"] = torch.tensor(class_weights, dtype=torch.float32)
+
+
+        self._val_outputs = []
+        self._test_outputs = []
+
+    def setup(self, stage=None):
+        """
+        Sets the loss criterion.
+        """
+        if self.criterion is None and self.num_classes is not None:
+            self._init_criterion()
+
+    def _init_criterion(self):
+        if self.num_classes == 2:
+            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["pos_weight"]):
+                self.criterion_kwargs["pos_weight"] = torch.tensor(self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device)
+            self.criterion = torch.nn.BCEWithLogitsLoss(**self.criterion_kwargs)
+        else:
+            if "weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["weight"]):
+                self.criterion_kwargs["weight"] = torch.tensor(self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device)
+            self.criterion = torch.nn.CrossEntropyLoss(**self.criterion_kwargs)
+
+    def _resolve_focus_class(self, focus_class):
+        if isinstance(focus_class, int):
+            return focus_class
+        elif isinstance(focus_class, str):
+            if self.class_names is None:
+                raise ValueError("class_names must be provided if focus_class is a string.")
+            if focus_class not in self.class_names:
+                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+            return self.class_names.index(focus_class)
+        else:
+            raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
+
+    def set_training_indices(self, datamodule):
+        """
+        Store obs_names for train/val/test subsets used during training.
+        """
+        self.train_obs_names = datamodule.adata.obs_names[datamodule.train_set.indices].tolist()
+        self.val_obs_names = datamodule.adata.obs_names[datamodule.val_set.indices].tolist()
+        self.test_obs_names = datamodule.adata.obs_names[datamodule.test_set.indices].tolist()
+
+    def configure_optimizers(self):
+        return self.optimizer_cls(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
+
+    def forward(self, x):
+        """
+        Forward pass through the model.
+        """
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        """
+        Training step for a batch through the Lightning Trainer.
+        """
+        x, y = batch
+        if self.num_classes is None:
+            self.num_classes = int(torch.max(y).item()) + 1
+            self._init_criterion()
+        logits = self(x)
+        loss = self._compute_loss(logits, y)
+        self.log("train_loss", loss, prog_bar=False)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """
+        Validation step for a batch through the Lightning Trainer.
+        """
+        x, y = batch
+        logits = self(x)
+        loss = self._compute_loss(logits, y)
+        preds = self._get_preds(logits)
+        acc = (preds == y).float().mean()
+        self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=False)
+        self._val_outputs.append((logits.detach(), y.detach()))
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        """
+        Test step for a batch through the Lightning Trainer.
+        """
+        x, y = batch
+        logits = self(x)
+        self._test_outputs.append((logits.detach(), y.detach()))
+
+    def predict_step(self, batch, batch_idx):
+        """
+        Gets predictions and prediction probabilities for the batch using the trained Lightning model.
+        """
+        x = batch[0]
+        logits = self(x)
+        probs = self._get_probs(logits)
+        preds = self._get_preds(logits)
+        return preds, probs
+
+    def on_validation_epoch_end(self):
+        """
+        Final logging of all validation steps.
+        """
+        if not self._val_outputs:
+            return
+        logits, targets = zip(*self._val_outputs)
+        self._val_outputs.clear()
+        self._log_classification_metrics(logits, targets, prefix="val")
+
+    def on_test_epoch_end(self):
+        """
+        Final logging of all testing steps.
+        """
+        if not self._test_outputs:
+            return
+        logits, targets = zip(*self._test_outputs)
+        self._test_outputs.clear()
+        self._log_classification_metrics(logits, targets, prefix="test")
+        self._plot_roc_pr_curves(logits, targets)
+
+    def _compute_loss(self, logits, y):
+        """
+        A helper function for computing loss for binary vs multiclass classifications.
+        """
+        if self.num_classes == 2:
+            y = y.float().view(-1, 1)  # shape [B, 1]
+            return self.criterion(logits.view(-1, 1), y)
+        else:
+            return self.criterion(logits, y)
+
+    def _get_probs(self, logits):
+        """
+        A helper function for getting class probabilities for binary vs multiclass classifications.
+        """
+        if self.num_classes == 2:
+            return torch.sigmoid(logits.view(-1))
+        else:
+            return torch.softmax(logits, dim=1)
+
+    def _get_preds(self, logits):
+        """
+        A helper function for getting class predictions for binary vs multiclass classifications.
+        """
+        if self.num_classes == 2:
+            return (torch.sigmoid(logits.view(-1)) >= 0.5).long()
+        else:
+            return logits.argmax(dim=1)
+
+    def _subsample_for_fixed_positive_frequency(self, y_true, probs, target_freq=0.3, max_positive=None):
+        pos_idx = np.where(y_true == self.focus_class)[0]
+        neg_idx = np.where(y_true != self.focus_class)[0]
+
+        max_negatives_possible = len(neg_idx)
+        max_positives_possible = len(pos_idx)
+
+        # maximum achievable positive class frequency
+        max_possible_freq = max_positives_possible / (max_positives_possible + max_negatives_possible)
+
+        if target_freq > max_possible_freq:
+            target_freq = max_possible_freq  # clip if you ask for an impossible freq
+
+        # now calculate positive count
+        num_pos_target = min(int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible)
+        num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
+        num_neg_target = min(num_neg_target, max_negatives_possible)
+
+        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
+        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
+
+        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
+        np.random.shuffle(sampled_idx)
+
+        actual_freq = len(pos_sampled) / len(sampled_idx)
+
+        return sampled_idx
+
+    def _log_classification_metrics(self, logits, targets, prefix="val"):
+        """
+        A helper function for logging validation and testing split model evaluations.
+        """
+        logits = torch.cat(logits).cpu()
+        y_true = torch.cat(targets).cpu().numpy()
+
+        probs = self._get_probs(logits).numpy()
+        preds = self._get_preds(logits).cpu().numpy()
+
+        # remap binary focus class correctly:
+        binary_focus = (y_true == self.focus_class).astype(int)
+
+        num_pos = binary_focus.sum()
+
+        # Subsample if you want to enforce a fixed proportion of the positive class
+        if prefix == 'test' and self.enforce_eval_balance:
+            sampled_idx = self._subsample_for_fixed_positive_frequency(
+                y_true, probs, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
+            )
+            y_true = y_true[sampled_idx]
+            probs = probs[sampled_idx]
+            preds = preds[sampled_idx]
+            binary_focus = (y_true == self.focus_class).astype(int)
+            num_pos = binary_focus.sum()
+
+        # Accuracy
+        acc = np.mean(preds == y_true)
+
+        # F1 & ROC-AUC
+        if self.num_classes == 2:
+            if self.focus_class == 1:
+                focus_probs = probs
+            else:
+                focus_probs = 1 - probs
+            f1 = f1_score(y_true, preds)
+            fpr, tpr, _ = roc_curve((y_true == self.focus_class).astype(int), focus_probs)
+            roc_auc = roc_auc_score((y_true == self.focus_class).astype(int), focus_probs)
+        else:
+            f1 = f1_score(y_true, preds, average="macro")
+            roc_auc = roc_auc_score(y_true, probs, multi_class="ovr", average="macro")
+            focus_probs = probs[:, self.focus_class]
+            fpr, tpr, _ = roc_curve((y_true == self.focus_class).astype(int), focus_probs)
+
+        # PR AUC for focus class
+        pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
+        pr_auc = auc(rc, pr)
+        pos_freq = binary_focus.mean()
+        pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+
+        cm = confusion_matrix(y_true, preds)
+
+        # Save attributes for later plotting
+        if prefix == 'test':
+            self.test_roc_curve = (fpr, tpr)
+            self.test_pr_curve = (rc, pr)
+            self.test_roc_auc = roc_auc
+            self.test_pr_auc = pr_auc
+            self.test_pos_freq = pos_freq
+            self.test_num_pos = num_pos
+            self.test_acc = acc
+            self.test_f1 = f1
+        elif prefix == 'val':
+            pass
+
+        # Logging
+        self.log_dict({
+            f"{prefix}_acc": acc,
+            f"{prefix}_f1": f1,
+            f"{prefix}_auc": roc_auc,
+            f"{prefix}_pr_auc": pr_auc,
+            f"{prefix}_pr_auc_norm": pr_auc_norm,
+            f"{prefix}_pos_freq": pos_freq,
+            f"{prefix}_num_pos": num_pos
+        })
+        setattr(self, f"{prefix}_confusion_matrix", cm)
+
+    def _plot_roc_pr_curves(self, logits, targets):
+        plt.figure(figsize=(12, 5))
+
+        # ROC Curve
+        fpr, tpr = self.test_roc_curve
+        roc_auc = self.test_roc_auc
+        plt.subplot(1, 2, 1)
+        plt.plot(fpr, tpr, label=f"ROC AUC={roc_auc:.3f}")
+        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
+        plt.xlabel("False Positive Rate")
+        plt.ylabel("True Positive Rate")
+        plt.ylim(0, 1.05)
+        plt.title(f"Test ROC Curve - {self.test_num_pos} positive class instances")
+        plt.legend()
+
+        # PR Curve
+        rc, pr = self.test_pr_curve
+        pr_auc = self.test_pr_auc
+        pos_freq = self.test_pos_freq
+        plt.subplot(1, 2, 2)
+        plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
+        plt.axhline(pos_freq, linestyle='--', color="gray")
+        plt.xlabel("Recall")
+        plt.ylabel("Precision")
+        plt.ylim(0, 1.05)
+        plt.title(f"Test Precision-Recall Curve - {self.test_num_pos} positive class instances")
+        plt.legend()
+
+        plt.tight_layout()
+        plt.show()
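
To make the new interface concrete, here is a minimal, hypothetical usage sketch (the backbone, label column, and class names are invented for illustration; the real data module lives in smftools/machine_learning/data/anndata_data_module.py):

import torch
import pytorch_lightning as pl

# Hypothetical backbone: any torch module emitting one logit for binary tasks.
backbone = torch.nn.Sequential(
    torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)
clf = TorchClassifierWrapper(
    model=backbone,
    label_col="activity",                # assumed obs column name
    num_classes=2,                       # selects BCEWithLogitsLoss
    class_names=["inactive", "active"],
    focus_class="active",                # resolved to an index via class_names
)
trainer = pl.Trainer(max_epochs=10, logger=False)
# datamodule is assumed to yield (x, y) batches, e.g. an AnnData data module:
# trainer.fit(clf, datamodule=datamodule)
# trainer.test(clf, datamodule=datamodule)  # logs test_acc/test_f1/test_auc, plots ROC/PR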
smftools/machine_learning/models/mlp.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+
+class MLPClassifier(BaseTorchModel):
+    def __init__(self, input_dim, num_classes=2, hidden_dims=[64, 64], dropout=0.2, use_batchnorm=True, **kwargs):
+        super().__init__(**kwargs)
+        layers = []
+        in_dim = input_dim
+
+        for h in hidden_dims:
+            layers.append(nn.Linear(in_dim, h))
+            if use_batchnorm:
+                layers.append(nn.BatchNorm1d(h))
+            layers.append(nn.ReLU())
+            if dropout > 0:
+                layers.append(nn.Dropout(dropout))
+            in_dim = h
+
+        output_size = 1 if num_classes == 2 else num_classes
+
+        layers.append(nn.Linear(in_dim, output_size))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
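
A quick head-sizing sanity check (an illustrative sketch, assuming BaseTorchModel requires no constructor arguments): with num_classes=2 the network ends in a single logit for BCEWithLogitsLoss, otherwise in one logit per class.

import torch

model = MLPClassifier(input_dim=100, num_classes=2)  # final layer: nn.Linear(64, 1)
x = torch.randn(8, 100)
assert model(x).shape == (8, 1)

multi = MLPClassifier(input_dim=100, num_classes=4)  # final layer: nn.Linear(64, 4)
assert multi(x).shape == (8, 4)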
smftools/machine_learning/models/positional.py
@@ -10,8 +10,9 @@ class PositionalEncoding(nn.Module):
         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+        self.register_buffer("pe", pe)

     def forward(self, x):
-        x = x + self.pe[:, :x.size(1)].to(x.device)
+        x = x + self.pe[:, :x.size(1)]
         return x
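
The switch to register_buffer is what lets forward drop the explicit .to(x.device): buffers move with Module.to(...) and Lightning's device placement, and they are saved in state_dict without becoming trainable parameters. A toy sketch (not package code) demonstrating the semantics:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pe", torch.zeros(1, 4, 8))  # tracked, not trained

toy = Toy()
print("pe" in toy.state_dict())                    # True: checkpointed with the model
print(any(p is toy.pe for p in toy.parameters()))  # False: receives no gradient updates
# toy.to("cuda") would move toy.pe along with the module, so forward
# can index self.pe directly without a device transfer.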
smftools/machine_learning/models/rnn.py
@@ -8,7 +8,8 @@ class RNNClassifier(BaseTorchModel):
         # Define LSTM layer
         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
         # Define fully connected output layer
-        self.fc = nn.Linear(hidden_dim, num_classes)
+        output_size = 1 if num_classes == 2 else num_classes
+        self.fc = nn.Linear(hidden_dim, output_size)

     def forward(self, x):
         x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
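
For shape orientation (a standalone sketch, not package code): the unsqueeze in forward presents each full-length profile to the LSTM as a single timestep, so input_size must equal the profile length L.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=200, hidden_size=32, batch_first=True)
x = torch.randn(8, 200).unsqueeze(1)  # [B, 1, L]: one timestep per read
out, (h, c) = lstm(x)
print(out.shape)  # torch.Size([8, 1, 32])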
smftools/machine_learning/models/sklearn_models.py
@@ -0,0 +1,273 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import (
+    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
+)
+
+class SklearnModelWrapper:
+    """
+    Unified sklearn wrapper matching the TorchClassifierWrapper interface.
+    """
+    def __init__(
+        self,
+        model,
+        label_col: str,
+        num_classes: int,
+        class_names=None,
+        focus_class: int=1,
+        enforce_eval_balance: bool=False,
+        target_eval_freq: float=0.3,
+        max_eval_positive=None
+    ):
+        self.model = model
+        self.label_col = label_col
+        self.num_classes = num_classes
+        self.class_names = class_names
+        self.focus_class = self._resolve_focus_class(focus_class)
+        self.focus_class_name = focus_class
+        self.enforce_eval_balance = enforce_eval_balance
+        self.target_eval_freq = target_eval_freq
+        self.max_eval_positive = max_eval_positive
+        self.metrics = {}
+
+    def _resolve_focus_class(self, focus_class):
+        if isinstance(focus_class, int):
+            return focus_class
+        elif isinstance(focus_class, str):
+            if self.class_names is None:
+                raise ValueError("class_names must be provided if focus_class is a string.")
+            if focus_class not in self.class_names:
+                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+            return self.class_names.index(focus_class)
+        else:
+            raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
+
+    def fit(self, X, y):
+        self.model.fit(X, y)
+
+    def predict(self, X):
+        return self.model.predict(X)
+
+    def predict_proba(self, X):
+        return self.model.predict_proba(X)
+
+    def _subsample_for_fixed_positive_frequency(self, y_true):
+        pos_idx = np.where(y_true == self.focus_class)[0]
+        neg_idx = np.where(y_true != self.focus_class)[0]
+
+        max_neg = len(neg_idx)
+        max_pos = len(pos_idx)
+        max_possible_freq = max_pos / (max_pos + max_neg)
+
+        target_freq = min(self.target_eval_freq, max_possible_freq)
+        num_pos_target = min(int(target_freq * max_neg / (1 - target_freq)), max_pos)
+        num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
+        num_neg_target = min(num_neg_target, max_neg)
+
+        if self.max_eval_positive is not None:
+            num_pos_target = min(num_pos_target, self.max_eval_positive)
+
+        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
+        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
+
+        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
+        np.random.shuffle(sampled_idx)
+
+        return sampled_idx
+
+    def evaluate(self, X, y, prefix="test"):
+        y_true = y
+        y_prob = self.predict_proba(X)
+        y_pred = self.predict(X)
+
+        if self.enforce_eval_balance:
+            sampled_idx = self._subsample_for_fixed_positive_frequency(y_true)
+            y_true = y_true[sampled_idx]
+            y_prob = y_prob[sampled_idx]
+            y_pred = y_pred[sampled_idx]
+
+        binary_focus = (y_true == self.focus_class).astype(int)
+        num_pos = binary_focus.sum()
+
+        is_binary = self.num_classes == 2
+
+        if is_binary:
+            if self.focus_class == 1:
+                focus_probs = y_prob[:, 1]
+            else:
+                focus_probs = y_prob[:, 0]
+            preds_focus = (y_pred == self.focus_class).astype(int)
+        else:
+            focus_probs = y_prob[:, self.focus_class]
+            preds_focus = (y_pred == self.focus_class).astype(int)
+
+        f1 = f1_score(binary_focus, preds_focus)
+        roc_auc = roc_auc_score(binary_focus, focus_probs)
+        pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
+        pr_auc = auc(rc, pr)
+        pos_freq = binary_focus.mean()
+        pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+        fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
+        cm = confusion_matrix(y_true, y_pred)
+        acc = np.mean(y_pred == y_true)
+
+        # store metrics as attributes for plotting later
+        setattr(self, f"{prefix}_f1", f1)
+        setattr(self, f"{prefix}_roc_curve", (fpr, tpr))
+        setattr(self, f"{prefix}_pr_curve", (rc, pr))
+        setattr(self, f"{prefix}_roc_auc", roc_auc)
+        setattr(self, f"{prefix}_pr_auc", pr_auc)
+        setattr(self, f"{prefix}_pos_freq", pos_freq)
+        setattr(self, f"{prefix}_num_pos", num_pos)
+        setattr(self, f"{prefix}_confusion_matrix", cm)
+        setattr(self, f"{prefix}_acc", acc)
+
+        # also store a metrics dict
+        self.metrics = {
+            f"{prefix}_acc": acc,
+            f"{prefix}_f1": f1,
+            f"{prefix}_auc": roc_auc,
+            f"{prefix}_pr_auc": pr_auc,
+            f"{prefix}_pr_auc_norm": pr_auc_norm,
+            f"{prefix}_pos_freq": pos_freq,
+            f"{prefix}_num_pos": num_pos
+        }
+
+        return self.metrics
+
+    def plot_roc_pr_curves(self, prefix="test"):
+        plt.figure(figsize=(12, 5))
+
+        fpr, tpr = getattr(self, f"{prefix}_roc_curve")
+        roc_auc = getattr(self, f"{prefix}_roc_auc")
+        plt.subplot(1, 2, 1)
+        plt.plot(fpr, tpr, label=f"ROC AUC={roc_auc:.3f}")
+        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
+        plt.xlabel("False Positive Rate")
+        plt.ylabel("True Positive Rate")
+        plt.ylim(0, 1.05)
+        plt.title(f"ROC Curve - {getattr(self, f'{prefix}_num_pos')} positives")
+        plt.legend()
+
+        rc, pr = getattr(self, f"{prefix}_pr_curve")
+        pr_auc = getattr(self, f"{prefix}_pr_auc")
+        pos_freq = getattr(self, f"{prefix}_pos_freq")
+        plt.subplot(1, 2, 2)
+        plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
+        plt.axhline(pos_freq, linestyle="--", color="gray")
+        plt.xlabel("Recall")
+        plt.ylabel("Precision")
+        plt.ylim(0, 1.05)
+        plt.title(f"PR Curve - {getattr(self, f'{prefix}_num_pos')} positives")
+        plt.legend()
+
+        plt.tight_layout()
+        plt.show()
+
+    def fit_from_datamodule(self, datamodule):
+        datamodule.setup()
+        X_tensor, y_tensor = datamodule.train_set.dataset.X_tensor, datamodule.train_set.dataset.y_tensor
+        indices = datamodule.train_set.indices
+        X_train = X_tensor[indices].numpy()
+        y_train = y_tensor[indices].numpy()
+        self.fit(X_train, y_train)
+        self.train_obs_names = datamodule.adata.obs_names[datamodule.train_set.indices].tolist()
+        self.val_obs_names = datamodule.adata.obs_names[datamodule.val_set.indices].tolist()
+        self.test_obs_names = datamodule.adata.obs_names[datamodule.test_set.indices].tolist()
+
+    def evaluate_from_datamodule(self, datamodule, split="test"):
+        datamodule.setup()
+        if split == "val":
+            subset = datamodule.val_set
+        elif split == "test":
+            subset = datamodule.test_set
+        else:
+            raise ValueError(f"Invalid split '{split}'")
+
+        X_tensor, y_tensor = subset.dataset.X_tensor, subset.dataset.y_tensor
+        indices = subset.indices
+        X_eval = X_tensor[indices].numpy()
+        y_eval = y_tensor[indices].numpy()
+
+        return self.evaluate(X_eval, y_eval, prefix=split)
+
+    def compute_shap(self, X, background=None, nsamples=100, target_class=None):
+        """
+        Compute SHAP values on input X, optionally for a specified target class.
+
+        Parameters
+        ----------
+        X : array-like
+            Input features
+        background : array-like
+            SHAP background
+        nsamples : int
+            Number of samples for kernel approximation
+        target_class : int, optional
+            If None, uses model predicted class
+        """
+        import shap
+
+        # choose explainer
+        if hasattr(self.model, "tree_") or hasattr(self.model, "estimators_"):
+            explainer = shap.TreeExplainer(self.model, data=background)
+        else:
+            if background is None:
+                background = shap.kmeans(X, 10)
+            explainer = shap.KernelExplainer(self.model.predict_proba, background)
+
+        # determine class
+        if target_class is None:
+            preds = self.model.predict(X)
+            target_class = preds
+
+        if isinstance(explainer, shap.TreeExplainer):
+            shap_values = explainer.shap_values(X)
+        else:
+            shap_values = explainer.shap_values(X, nsamples=nsamples)
+
+        if isinstance(shap_values, np.ndarray):
+            if shap_values.ndim == 3:
+                if isinstance(target_class, int):
+                    return shap_values[:, :, target_class]
+                elif isinstance(target_class, np.ndarray):
+                    # target_class is per-sample
+                    if np.any(target_class >= shap_values.shape[2]):
+                        raise ValueError(f"target_class values exceed {shap_values.shape[2]}")
+                    selected = np.array([
+                        shap_values[i, :, c]
+                        for i, c in enumerate(target_class)
+                    ])
+                    return selected
+                else:
+                    # fallback to class 0
+                    return shap_values[:, :, 0]
+            else:
+                # 2D shape (samples, features), no class dimension
+                return shap_values
+
+    def apply_shap_to_adata(self, dataloader, adata, background=None, adata_key="shap_values", target_class=None, normalize=True):
+        """
+        Compute SHAP from a DataLoader and store in AnnData if provided.
+        """
+        X_batches = []
+
+        for batch in dataloader:
+            X = batch[0].detach().cpu().numpy()
+            X_batches.append(X)
+
+        X_full = np.concatenate(X_batches, axis=0)
+
+        shap_values = self.compute_shap(X_full, background=background, target_class=target_class)
+
+        if adata is not None:
+            adata.obsm[adata_key] = shap_values
+
+            if normalize:
+                arr = shap_values
+                # row-wise normalization
+                row_max = np.max(np.abs(arr), axis=1, keepdims=True)
+                row_max[row_max == 0] = 1  # avoid divide by zero
+                normalized = arr / row_max
+
+                adata.obsm[f"{adata_key}_normalized"] = normalized
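
Finally, a hypothetical end-to-end sketch of the sklearn wrapper on random data (label column and class names invented; any estimator exposing predict_proba should work):

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.random((200, 50))
y = rng.integers(0, 2, size=200)

wrapper = SklearnModelWrapper(
    model=LogisticRegression(max_iter=1000),
    label_col="activity",                # assumed obs column name
    num_classes=2,
    class_names=["inactive", "active"],
    focus_class="active",
)
wrapper.fit(X, y)
metrics = wrapper.evaluate(X, y, prefix="test")  # returns the metrics dict
print(metrics["test_auc"], metrics["test_pr_auc"])
wrapper.plot_roc_pr_curves(prefix="test")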