smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import pytorch_lightning as pl
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from sklearn.metrics import (
|
|
5
|
+
roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
|
|
6
|
+
)
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
class TorchClassifierWrapper(pl.LightningModule):
    """
    A PyTorch Lightning wrapper for PyTorch classifiers.

    - Takes a PyTorch model as input.
    - Number of classes should be passed.
    - Optimizer defaults to AdamW; when ``optimizer_kwargs`` is not given,
      ``{"weight_decay": 1e-4}`` is used.
    - Loss criterion is chosen from the number of classes:
      ``BCEWithLogitsLoss`` for 2 classes, ``CrossEntropyLoss`` otherwise.
    - ``focus_class`` (an index, or a name from ``class_names``) selects the
      class treated as "positive" for precision/recall, ROC and binary F1.
    - Contains a prediction step to run inference with.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        label_col: str,
        num_classes: int,
        class_names: list=None,
        optimizer_cls=torch.optim.AdamW,
        optimizer_kwargs=None,
        criterion_kwargs=None,
        lr: float = 1e-3,
        focus_class: int = 1,  # used for binary or multiclass precision-recall
        class_weights=None,
        enforce_eval_balance: bool=False,
        target_eval_freq: float=0.3,
        max_eval_positive: int=None
    ):
        """
        Args:
            model: the wrapped ``torch.nn.Module``; its output is used as logits.
            label_col: name of the label column (stored for reference).
            num_classes: number of target classes; ``2`` selects the binary
                (single-logit, sigmoid) path.
            class_names: optional list of class names; required when
                ``focus_class`` is given as a string.
            optimizer_cls: optimizer class built in ``configure_optimizers``.
            optimizer_kwargs: extra optimizer kwargs
                (default ``{"weight_decay": 1e-4}``).
            criterion_kwargs: extra kwargs for the loss criterion. The dict is
                copied, so the caller's object is never modified.
            lr: learning rate.
            focus_class: index or name of the class used for focus-class metrics.
            class_weights: per-class weights. For binary models only the
                focus-class entry is used, as ``pos_weight``; otherwise the
                weights are passed as ``weight`` to ``CrossEntropyLoss``.
            enforce_eval_balance: if True, subsample the test split to a fixed
                positive-class frequency before computing test metrics.
            target_eval_freq: requested positive-class frequency for that subsample.
            max_eval_positive: optional cap on the number of positives kept in
                that subsample.

        Raises:
            ValueError: if ``focus_class`` cannot be resolved to a class index.
        """
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs or {"weight_decay": 1e-4}
        self.criterion = None  # built in setup() / training_step once num_classes is known
        self.lr = lr
        self.label_col = label_col
        self.num_classes = num_classes
        self.class_names = class_names  # must be set before resolving a string focus class
        self.focus_class = self._resolve_focus_class(focus_class)
        self.focus_class_name = focus_class
        self.enforce_eval_balance = enforce_eval_balance
        self.target_eval_freq = target_eval_freq
        self.max_eval_positive = max_eval_positive

        # Copy so adding pos_weight/weight below never mutates the caller's dict
        # (or the dict already captured by save_hyperparameters).
        self.criterion_kwargs = dict(criterion_kwargs) if criterion_kwargs else {}

        # Handle class weights
        if class_weights is not None:
            if num_classes == 2:
                # BCEWithLogits uses pos_weight, expects a scalar or tensor
                if torch.is_tensor(class_weights[self.focus_class]):
                    self.criterion_kwargs["pos_weight"] = class_weights[self.focus_class]
                else:
                    self.criterion_kwargs["pos_weight"] = torch.tensor(class_weights[self.focus_class], dtype=torch.float32, device=self.device)
            else:
                # CrossEntropyLoss expects weight tensor of size C
                if torch.is_tensor(class_weights):
                    self.criterion_kwargs["weight"] = class_weights
                else:
                    self.criterion_kwargs["weight"] = torch.tensor(class_weights, dtype=torch.float32)

        # Per-batch (logits, targets) pairs, consumed by the epoch-end hooks.
        self._val_outputs = []
        self._test_outputs = []

    def setup(self, stage=None):
        """
        Sets the loss criterion.
        """
        if self.criterion is None and self.num_classes is not None:
            self._init_criterion()

    def _init_criterion(self):
        """Build the loss: BCEWithLogitsLoss for 2 classes, CrossEntropyLoss otherwise."""
        if self.num_classes == 2:
            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["pos_weight"]):
                self.criterion_kwargs["pos_weight"] = torch.tensor(self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device)
            self.criterion = torch.nn.BCEWithLogitsLoss(**self.criterion_kwargs)
        else:
            if "weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["weight"]):
                self.criterion_kwargs["weight"] = torch.tensor(self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device)
            self.criterion = torch.nn.CrossEntropyLoss(**self.criterion_kwargs)

    def _resolve_focus_class(self, focus_class):
        """
        Convert ``focus_class`` (int index or class name) into an int index.

        Raises:
            ValueError: for a string without ``class_names``, an unknown name,
                or a value that is neither int nor str.
        """
        if isinstance(focus_class, int):
            return focus_class
        elif isinstance(focus_class, str):
            if self.class_names is None:
                raise ValueError("class_names must be provided if focus_class is a string.")
            if focus_class not in self.class_names:
                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
            return self.class_names.index(focus_class)
        else:
            raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")

    def set_training_indices(self, datamodule):
        """
        Store obs_names for train/val/test subsets used during training.

        NOTE(review): assumes ``datamodule`` exposes ``adata`` plus
        ``train_set``/``val_set``/``test_set`` objects with an ``indices``
        attribute (e.g. ``torch.utils.data.Subset``) — confirm against the datamodule.
        """
        self.train_obs_names = datamodule.adata.obs_names[datamodule.train_set.indices].tolist()
        self.val_obs_names = datamodule.adata.obs_names[datamodule.val_set.indices].tolist()
        self.test_obs_names = datamodule.adata.obs_names[datamodule.test_set.indices].tolist()

    def configure_optimizers(self):
        """Instantiate ``optimizer_cls`` over all parameters with ``lr`` and ``optimizer_kwargs``."""
        return self.optimizer_cls(self.parameters(), lr=self.lr, **self.optimizer_kwargs)

    def forward(self, x):
        """
        Forward pass through the model.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Training step for a batch through the Lightning Trainer.
        """
        x, y = batch
        if self.num_classes is None:
            # Infer the class count from the first batch's labels when not given.
            self.num_classes = int(torch.max(y).item()) + 1
            self._init_criterion()
        logits = self(x)
        loss = self._compute_loss(logits, y)
        self.log("train_loss", loss, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step for a batch through the Lightning Trainer.
        """
        x, y = batch
        logits = self(x)
        loss = self._compute_loss(logits, y)
        preds = self._get_preds(logits)
        acc = (preds == y).float().mean()
        self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=False)
        self._val_outputs.append((logits.detach(), y.detach()))
        return loss

    def test_step(self, batch, batch_idx):
        """
        Test step for a batch through the Lightning Trainer.
        """
        x, y = batch
        logits = self(x)
        self._test_outputs.append((logits.detach(), y.detach()))

    def predict_step(self, batch, batch_idx):
        """
        Gets predictions and prediction probabilities for the batch using the trained Lightning model.
        """
        x = batch[0]
        logits = self(x)
        probs = self._get_probs(logits)
        preds = self._get_preds(logits)
        return preds, probs

    def on_validation_epoch_end(self):
        """
        Final logging of all validation steps
        """
        if not self._val_outputs:
            return
        logits, targets = zip(*self._val_outputs)
        self._val_outputs.clear()
        self._log_classification_metrics(logits, targets, prefix="val")

    def on_test_epoch_end(self):
        """
        Final logging of all testing steps
        """
        if not self._test_outputs:
            return
        logits, targets = zip(*self._test_outputs)
        self._test_outputs.clear()
        self._log_classification_metrics(logits, targets, prefix="test")
        self._plot_roc_pr_curves(logits, targets)

    def _compute_loss(self, logits, y):
        """
        A helper function for computing loss for binary vs multiclass classifications.
        """
        if self.num_classes == 2:
            y = y.float().view(-1, 1)  # shape [B, 1]
            return self.criterion(logits.view(-1, 1), y)
        else:
            return self.criterion(logits, y)

    def _get_probs(self, logits):
        """
        A helper function for getting class probabilities for binary vs multiclass classifications.

        Binary: sigmoid of the single logit, i.e. P(label == 1), shape [B].
        Multiclass: softmax over dim 1.
        """
        if self.num_classes == 2:
            return torch.sigmoid(logits.view(-1))
        else:
            return torch.softmax(logits, dim=1)

    def _get_preds(self, logits):
        """
        A helper function for getting class predictions for binary vs multiclass classifications.
        """
        if self.num_classes == 2:
            return (torch.sigmoid(logits.view(-1)) >= 0.5).long()
        else:
            return logits.argmax(dim=1)

    def _subsample_for_fixed_positive_frequency(self, y_true, probs, target_freq=0.3, max_positive=None):
        """
        Pick indices of ``y_true`` so the focus class appears at ``target_freq``.

        Args:
            y_true: 1-D numpy array of integer labels.
            probs: unused; kept for signature compatibility.
            target_freq: requested fraction of focus-class samples. It is
                clipped to the natural frequency when fewer positives exist.
            max_positive: optional cap on positives. Negatives are derived from
                the capped positive count, so the requested frequency still holds.

        Returns:
            Shuffled numpy array of sampled indices. When rebalancing is
            undefined (no positives, no negatives, or ``target_freq <= 0``),
            all indices are returned instead of raising ZeroDivisionError.
            At least one positive and one negative are always kept.
        """
        pos_idx = np.where(y_true == self.focus_class)[0]
        neg_idx = np.where(y_true != self.focus_class)[0]
        n_pos, n_neg = len(pos_idx), len(neg_idx)

        if n_pos == 0 or n_neg == 0 or target_freq <= 0:
            return np.arange(len(y_true))

        # Clip impossible requests to the maximum achievable positive frequency
        # (strictly < 1 here because n_neg > 0).
        target_freq = min(target_freq, n_pos / (n_pos + n_neg))

        num_pos_target = min(int(target_freq * n_neg / (1 - target_freq)), n_pos)
        if max_positive is not None:
            num_pos_target = min(num_pos_target, max_positive)
        num_pos_target = max(num_pos_target, 1)
        # Derive negatives after any positive cap so the ratio is preserved.
        num_neg_target = min(max(int(num_pos_target * (1 - target_freq) / target_freq), 1), n_neg)

        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)

        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
        np.random.shuffle(sampled_idx)
        return sampled_idx

    @staticmethod
    def _safe_roc_auc(y_true, scores, **kwargs):
        """
        ``roc_auc_score`` that returns NaN instead of raising when it is undefined.

        ``roc_auc_score`` raises ``ValueError`` when ``y_true`` does not contain
        every scored class, e.g. a tiny sanity-check validation split with a
        single label. That would otherwise abort training.
        """
        expected_classes = 2 if np.ndim(scores) == 1 else np.shape(scores)[1]
        if np.unique(y_true).size < expected_classes:
            return float("nan")
        return roc_auc_score(y_true, scores, **kwargs)

    def _log_classification_metrics(self, logits, targets, prefix="val"):
        """
        A helper function for logging validation and testing split model evaluations.

        Args:
            logits: sequence of per-batch logit tensors collected by the steps.
            targets: sequence of per-batch label tensors aligned with ``logits``.
            prefix: ``"val"`` or ``"test"``. ``"test"`` additionally allows
                balance subsampling and stores curves/metrics for plotting.
        """
        logits = torch.cat(logits).cpu()
        y_true = torch.cat(targets).cpu().numpy()

        probs = self._get_probs(logits).numpy()
        preds = self._get_preds(logits).numpy()

        # Subsample if you want to enforce a fixed proportion of the positive class
        if prefix == 'test' and self.enforce_eval_balance:
            sampled_idx = self._subsample_for_fixed_positive_frequency(
                y_true, probs, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
            )
            y_true = y_true[sampled_idx]
            probs = probs[sampled_idx]
            preds = preds[sampled_idx]

        # One-vs-rest view of labels and predictions for the focus class.
        binary_focus = (y_true == self.focus_class).astype(int)
        preds_focus = (preds == self.focus_class).astype(int)
        num_pos = binary_focus.sum()

        # Accuracy
        acc = np.mean(preds == y_true)

        # F1 & ROC-AUC
        if self.num_classes == 2:
            # Sigmoid output is P(label == 1); flip it when the focus class is 0.
            focus_probs = probs if self.focus_class == 1 else 1 - probs
            # Score F1 with the focus class as the positive label, consistent
            # with the ROC/PR metrics below (identical to before for focus_class == 1).
            f1 = f1_score(binary_focus, preds_focus)
            roc_auc = self._safe_roc_auc(binary_focus, focus_probs)
        else:
            f1 = f1_score(y_true, preds, average="macro")
            roc_auc = self._safe_roc_auc(y_true, probs, multi_class="ovr", average="macro")
            focus_probs = probs[:, self.focus_class]
        fpr, tpr, _ = roc_curve(binary_focus, focus_probs)

        # PR AUC for focus class
        pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
        pr_auc = auc(rc, pr)
        pos_freq = binary_focus.mean()
        # PR-AUC relative to the random-classifier baseline (the positive frequency).
        pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan

        cm = confusion_matrix(y_true, preds)

        # Save attributes for later plotting
        if prefix == 'test':
            self.test_roc_curve = (fpr, tpr)
            self.test_pr_curve = (rc, pr)
            self.test_roc_auc = roc_auc
            self.test_pr_auc = pr_auc
            self.test_pos_freq = pos_freq
            self.test_num_pos = num_pos
            self.test_acc = acc
            self.test_f1 = f1

        # Logging
        self.log_dict({
            f"{prefix}_acc": acc,
            f"{prefix}_f1": f1,
            f"{prefix}_auc": roc_auc,
            f"{prefix}_pr_auc": pr_auc,
            f"{prefix}_pr_auc_norm": pr_auc_norm,
            f"{prefix}_pos_freq": pos_freq,
            f"{prefix}_num_pos": num_pos
        })
        setattr(self, f"{prefix}_confusion_matrix", cm)

    def _plot_roc_pr_curves(self, logits, targets):
        """
        Plot the stored test ROC and precision-recall curves side by side.

        Reads the ``test_*`` attributes written by ``_log_classification_metrics``.
        ``logits`` and ``targets`` are accepted but unused.
        """
        plt.figure(figsize=(12, 5))

        # ROC Curve
        fpr, tpr = self.test_roc_curve
        roc_auc = self.test_roc_auc
        plt.subplot(1, 2, 1)
        plt.plot(fpr, tpr, label=f"ROC AUC={roc_auc:.3f}")
        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.ylim(0, 1.05)
        plt.title(f"Test ROC Curve - {self.test_num_pos} positive class instances")
        plt.legend()

        # PR Curve
        rc, pr = self.test_pr_curve
        pr_auc = self.test_pr_auc
        pos_freq = self.test_pos_freq
        plt.subplot(1, 2, 2)
        plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
        # Dashed baseline: precision of a random classifier equals the positive frequency.
        plt.axhline(pos_freq, linestyle='--', color="gray")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.ylim(0, 1.05)
        plt.title(f"Test Precision-Recall Curve - {self.test_num_pos} positive class instances")
        plt.legend()

        plt.tight_layout()
        plt.show()
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .base import BaseTorchModel
|
|
4
|
+
|
|
5
|
+
class MLPClassifier(BaseTorchModel):
    """
    Fully connected feed-forward classifier.

    Each hidden layer is ``Linear`` → optional ``BatchNorm1d`` → ``ReLU`` →
    optional ``Dropout``. The output layer has 1 unit for binary problems
    (``num_classes == 2``, single logit) and ``num_classes`` units otherwise.
    """

    def __init__(self, input_dim, num_classes=2, hidden_dims=(64, 64), dropout=0.2, use_batchnorm=True, **kwargs):
        """
        Args:
            input_dim: number of input features per sample.
            num_classes: number of target classes; 2 produces a single output logit.
            hidden_dims: sizes of the hidden layers, in order. The default is an
                immutable tuple, avoiding a shared mutable default argument.
            dropout: dropout probability after each hidden activation (0 disables).
            use_batchnorm: insert ``BatchNorm1d`` after each hidden ``Linear``.
            **kwargs: forwarded to ``BaseTorchModel.__init__``.
        """
        super().__init__(**kwargs)
        layers = []
        in_dim = input_dim

        for h in hidden_dims:
            layers.append(nn.Linear(in_dim, h))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(h))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = h

        # Binary problems use a single logit (paired with a BCE-with-logits style loss).
        output_size = 1 if num_classes == 2 else num_classes

        layers.append(nn.Linear(in_dim, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits for input ``x``."""
        return self.model(x)
|
|
@@ -10,8 +10,9 @@ class PositionalEncoding(nn.Module):
|
|
|
10
10
|
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
|
|
11
11
|
pe[:, 0::2] = torch.sin(position * div_term)
|
|
12
12
|
pe[:, 1::2] = torch.cos(position * div_term)
|
|
13
|
-
|
|
13
|
+
pe = pe.unsqueeze(0) # (1, max_len, d_model)
|
|
14
|
+
self.register_buffer("pe", pe)
|
|
14
15
|
|
|
15
16
|
def forward(self, x):
|
|
16
|
-
x = x + self.pe[:, :x.size(1)]
|
|
17
|
+
x = x + self.pe[:, :x.size(1)]
|
|
17
18
|
return x
|
|
@@ -8,7 +8,8 @@ class RNNClassifier(BaseTorchModel):
|
|
|
8
8
|
# Define LSTM layer
|
|
9
9
|
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
|
|
10
10
|
# Define fully connected output layer
|
|
11
|
-
|
|
11
|
+
output_size = 1 if num_classes == 2 else num_classes
|
|
12
|
+
self.fc = nn.Linear(hidden_dim, output_size)
|
|
12
13
|
|
|
13
14
|
def forward(self, x):
|
|
14
15
|
x = x.unsqueeze(1) # [B, 1, L] → for LSTM expecting batch_first
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from sklearn.metrics import (
|
|
4
|
+
roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
class SklearnModelWrapper:
|
|
8
|
+
"""
|
|
9
|
+
Unified sklearn wrapper matching TorchClassifierWrapper interface.
|
|
10
|
+
"""
|
|
11
|
+
def __init__(
    self,
    model,
    label_col: str,
    num_classes: int,
    class_names=None,
    focus_class: int=1,
    enforce_eval_balance: bool=False,
    target_eval_freq: float=0.3,
    max_eval_positive=None
):
    """
    Wrap a scikit-learn style estimator with evaluation bookkeeping.

    Args:
        model: estimator exposing ``fit``, ``predict`` and ``predict_proba``.
        label_col: name of the label column (stored for reference).
        num_classes: number of target classes.
        class_names: optional class names; required if ``focus_class`` is a string.
        focus_class: index or name of the class treated as positive in metrics.
        enforce_eval_balance: subsample evaluation data to a fixed positive frequency.
        target_eval_freq: requested positive frequency for that subsample.
        max_eval_positive: optional cap on positives in that subsample.

    Raises:
        ValueError: if ``focus_class`` cannot be resolved.
    """
    # Estimator and label metadata
    self.model = model
    self.label_col = label_col
    self.num_classes = num_classes
    self.class_names = class_names

    # Resolving a name requires class_names, which is assigned just above;
    # the raw value is kept for display purposes.
    self.focus_class = self._resolve_focus_class(focus_class)
    self.focus_class_name = focus_class

    # Evaluation-time rebalancing options
    self.enforce_eval_balance = enforce_eval_balance
    self.target_eval_freq = target_eval_freq
    self.max_eval_positive = max_eval_positive

    # Container for computed metrics
    self.metrics = {}
|
|
32
|
+
|
|
33
|
+
def _resolve_focus_class(self, focus_class):
    """
    Translate a focus class given as an index or a class name into an index.

    Raises:
        ValueError: if a name is given without ``class_names``, if the name is
            unknown, or if the value is neither int nor str.
    """
    if isinstance(focus_class, int):
        return focus_class
    if not isinstance(focus_class, str):
        raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")

    names = self.class_names
    if names is None:
        raise ValueError("class_names must be provided if focus_class is a string.")
    if focus_class not in names:
        raise ValueError(f"focus_class '{focus_class}' not found in class_names {names}.")
    return names.index(focus_class)
|
|
44
|
+
|
|
45
|
+
def fit(self, X, y):
    """Train the wrapped estimator in place on features ``X`` and labels ``y``; returns None."""
    estimator = self.model
    estimator.fit(X, y)
|
|
47
|
+
|
|
48
|
+
def predict(self, X):
    """Return the wrapped estimator's class predictions for ``X``."""
    estimator = self.model
    predicted = estimator.predict(X)
    return predicted
|
|
50
|
+
|
|
51
|
+
def predict_proba(self, X):
    """Return the wrapped estimator's per-class probabilities for ``X``."""
    estimator = self.model
    probabilities = estimator.predict_proba(X)
    return probabilities
|
|
53
|
+
|
|
54
|
+
def _subsample_for_fixed_positive_frequency(self, y_true):
    """
    Pick indices of ``y_true`` so the focus class appears at ``self.target_eval_freq``.

    The target frequency is clipped to the natural frequency when there are
    too few positives. ``self.max_eval_positive`` caps the number of positives.
    Negatives are derived from the *capped* positive count, so the requested
    frequency still holds. At least one positive and one negative are kept.

    Args:
        y_true: 1-D numpy array of integer labels.

    Returns:
        Shuffled numpy array of sampled indices. When rebalancing is undefined
        (no positives, no negatives, or a non-positive target frequency), all
        indices are returned instead of raising ZeroDivisionError.
    """
    pos_idx = np.where(y_true == self.focus_class)[0]
    neg_idx = np.where(y_true != self.focus_class)[0]
    max_pos = len(pos_idx)
    max_neg = len(neg_idx)

    if max_pos == 0 or max_neg == 0 or self.target_eval_freq <= 0:
        return np.arange(len(y_true))

    # Maximum achievable positive frequency; strictly < 1 because max_neg > 0.
    max_possible_freq = max_pos / (max_pos + max_neg)
    target_freq = min(self.target_eval_freq, max_possible_freq)

    num_pos_target = min(int(target_freq * max_neg / (1 - target_freq)), max_pos)
    if self.max_eval_positive is not None:
        num_pos_target = min(num_pos_target, self.max_eval_positive)
    num_pos_target = max(num_pos_target, 1)
    # Derive negatives AFTER the positive cap so the target ratio is preserved.
    num_neg_target = min(max(int(num_pos_target * (1 - target_freq) / target_freq), 1), max_neg)

    pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
    neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)

    sampled_idx = np.concatenate([pos_sampled, neg_sampled])
    np.random.shuffle(sampled_idx)

    return sampled_idx
|
|
77
|
+
|
|
78
|
+
def evaluate(self, X, y, prefix="test"):
    """
    Evaluate the wrapped model on ``(X, y)``, focusing on ``self.focus_class``.

    When ``self.enforce_eval_balance`` is true, samples are first subsampled
    with ``self._subsample_for_fixed_positive_frequency`` before scoring.

    Each metric is stored as an attribute named ``f"{prefix}_<metric>"``:
    ``_f1``, ``_roc_curve``, ``_pr_curve``, ``_roc_auc``, ``_pr_auc``,
    ``_pos_freq``, ``_num_pos``, ``_confusion_matrix`` and ``_acc``.
    ``plot_roc_pr_curves(prefix)`` reads these attributes later.

    Parameters
    ----------
    X : array-like
        Features passed to ``predict`` and ``predict_proba``.
    y : array-like
        True labels. Converted to a NumPy array, so plain lists are accepted.
    prefix : str
        Prefix for the stored attributes and the returned metric keys.

    Returns
    -------
    dict
        Keys ``{prefix}_acc``, ``{prefix}_f1``, ``{prefix}_auc`` (ROC AUC),
        ``{prefix}_pr_auc``, ``{prefix}_pr_auc_norm``, ``{prefix}_pos_freq``
        and ``{prefix}_num_pos``. ``pr_auc_norm`` is the PR AUC divided by the
        positive frequency, or NaN when there are no positives. The same dict
        also replaces ``self.metrics``.
    """
    # np.asarray so list labels support fancy indexing and elementwise ==.
    y_true = np.asarray(y)
    y_prob = np.asarray(self.predict_proba(X))
    y_pred = np.asarray(self.predict(X))

    if self.enforce_eval_balance:
        sampled_idx = self._subsample_for_fixed_positive_frequency(y_true)
        y_true = y_true[sampled_idx]
        y_prob = y_prob[sampled_idx]
        y_pred = y_pred[sampled_idx]

    # One-vs-rest view of the focus class.
    binary_focus = (y_true == self.focus_class).astype(int)
    num_pos = binary_focus.sum()

    # Probability column for the focus class. This assumes predict_proba
    # columns are ordered by integer label 0..num_classes-1 (TODO confirm for
    # the wrapped model). Indexing directly (binary and multiclass alike)
    # raises IndexError for an out-of-range focus_class instead of silently
    # scoring column 0.
    focus_probs = y_prob[:, self.focus_class]
    preds_focus = (y_pred == self.focus_class).astype(int)

    f1 = f1_score(binary_focus, preds_focus)
    roc_auc = roc_auc_score(binary_focus, focus_probs)
    pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
    # NOTE(review): trapezoidal area under the PR curve; average_precision_score
    # is the non-interpolated alternative if that is preferred.
    pr_auc = auc(rc, pr)
    pos_freq = binary_focus.mean()
    pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
    fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
    cm = confusion_matrix(y_true, y_pred)
    acc = np.mean(y_pred == y_true)

    # store metrics as attributes for plotting later
    setattr(self, f"{prefix}_f1", f1)
    setattr(self, f"{prefix}_roc_curve", (fpr, tpr))
    setattr(self, f"{prefix}_pr_curve", (rc, pr))
    setattr(self, f"{prefix}_roc_auc", roc_auc)
    setattr(self, f"{prefix}_pr_auc", pr_auc)
    setattr(self, f"{prefix}_pos_freq", pos_freq)
    setattr(self, f"{prefix}_num_pos", num_pos)
    setattr(self, f"{prefix}_confusion_matrix", cm)
    setattr(self, f"{prefix}_acc", acc)

    # also store a metrics dict
    self.metrics = {
        f"{prefix}_acc": acc,
        f"{prefix}_f1": f1,
        f"{prefix}_auc": roc_auc,
        f"{prefix}_pr_auc": pr_auc,
        f"{prefix}_pr_auc_norm": pr_auc_norm,
        f"{prefix}_pos_freq": pos_freq,
        f"{prefix}_num_pos": num_pos,
    }

    return self.metrics
|
|
137
|
+
|
|
138
|
+
def plot_roc_pr_curves(self, prefix="test"):
    """
    Draw the stored ROC and precision-recall curves side by side.

    Reads the ``{prefix}_roc_curve``, ``{prefix}_roc_auc``, ``{prefix}_num_pos``,
    ``{prefix}_pr_curve``, ``{prefix}_pr_auc`` and ``{prefix}_pos_freq``
    attributes. ``evaluate`` must therefore have been called with the same
    prefix. The figure is shown with ``plt.show()``.
    """
    plt.figure(figsize=(12, 5))

    # Left panel: ROC curve plus the chance diagonal.
    fpr, tpr = getattr(self, f"{prefix}_roc_curve")
    roc_auc = getattr(self, f"{prefix}_roc_auc")
    roc_ax = plt.subplot(1, 2, 1)
    roc_ax.plot(fpr, tpr, label=f"ROC AUC={roc_auc:.3f}")
    roc_ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
    roc_ax.set_xlabel("False Positive Rate")
    roc_ax.set_ylabel("True Positive Rate")
    roc_ax.set_ylim(0, 1.05)
    roc_ax.set_title(f"ROC Curve - {getattr(self, f'{prefix}_num_pos')} positives")
    roc_ax.legend()

    # Right panel: precision-recall curve; the dashed line is the positive rate.
    rc, pr = getattr(self, f"{prefix}_pr_curve")
    pr_auc = getattr(self, f"{prefix}_pr_auc")
    pos_freq = getattr(self, f"{prefix}_pos_freq")
    pr_ax = plt.subplot(1, 2, 2)
    pr_ax.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
    pr_ax.axhline(pos_freq, linestyle="--", color="gray")
    pr_ax.set_xlabel("Recall")
    pr_ax.set_ylabel("Precision")
    pr_ax.set_ylim(0, 1.05)
    pr_ax.set_title(f"PR Curve - {getattr(self, f'{prefix}_num_pos')} positives")
    pr_ax.legend()

    plt.tight_layout()
    plt.show()
|
|
166
|
+
|
|
167
|
+
def fit_from_datamodule(self, datamodule):
    """
    Fit the wrapped model on the training split of a datamodule.

    The method calls ``datamodule.setup()``, then gathers the training rows
    from the train subset's underlying dataset tensors and fits with
    ``self.fit``. Afterwards it records the ``obs_names`` of the
    train/val/test splits from ``datamodule.adata`` on
    ``self.train_obs_names``, ``self.val_obs_names`` and
    ``self.test_obs_names``.

    Parameters
    ----------
    datamodule : object
        Must expose ``setup()``, ``adata`` (with ``obs_names``) and
        ``train_set`` / ``val_set`` / ``test_set`` subsets, each with
        ``.indices``. The train subset must also have ``.dataset`` holding
        ``X_tensor`` and ``y_tensor``.
    """
    datamodule.setup()
    train_set = datamodule.train_set
    dataset = train_set.dataset
    indices = train_set.indices

    # detach()/cpu() so tensors that require grad or live on an accelerator can
    # still be converted to NumPy; both are no-ops for plain CPU tensors.
    X_train = dataset.X_tensor[indices].detach().cpu().numpy()
    y_train = dataset.y_tensor[indices].detach().cpu().numpy()
    self.fit(X_train, y_train)

    obs_names = datamodule.adata.obs_names
    self.train_obs_names = obs_names[train_set.indices].tolist()
    self.val_obs_names = obs_names[datamodule.val_set.indices].tolist()
    self.test_obs_names = obs_names[datamodule.test_set.indices].tolist()
|
|
177
|
+
|
|
178
|
+
def evaluate_from_datamodule(self, datamodule, split="test"):
    """
    Evaluate the wrapped model on one split of a datamodule.

    Parameters
    ----------
    datamodule : object
        Must expose ``setup()`` and ``train_set`` / ``val_set`` / ``test_set``
        subsets, each with ``.dataset`` (holding ``X_tensor`` and
        ``y_tensor``) and ``.indices``.
    split : {"test", "val", "train"}
        Subset to score. It is also passed to ``self.evaluate`` as the metric
        ``prefix``.

    Returns
    -------
    dict
        The metrics dict returned by ``self.evaluate``.

    Raises
    ------
    ValueError
        If ``split`` is not one of the supported names.
    """
    datamodule.setup()
    subset_attr = {"train": "train_set", "val": "val_set", "test": "test_set"}.get(split)
    if subset_attr is None:
        raise ValueError(f"Invalid split '{split}'")
    subset = getattr(datamodule, subset_attr)

    indices = subset.indices
    # detach()/cpu() so tensors that require grad or live on an accelerator can
    # still be converted to NumPy; both are no-ops for plain CPU tensors.
    X_eval = subset.dataset.X_tensor[indices].detach().cpu().numpy()
    y_eval = subset.dataset.y_tensor[indices].detach().cpu().numpy()

    return self.evaluate(X_eval, y_eval, prefix=split)
|
|
193
|
+
|
|
194
|
+
def compute_shap(self, X, background=None, nsamples=100, target_class=None):
    """
    Compute SHAP values on input X, optionally for a specified target class.

    Parameters
    ----------
    X : array-like
        Input features, one row per sample.
    background : array-like, optional
        SHAP background data. For non-tree models it defaults to
        ``shap.kmeans`` over ``X`` with up to 10 centroids.
    nsamples : int
        Number of samples for the KernelExplainer approximation. Not used for
        tree models.
    target_class : int, array-like of int, or None
        Class column to return. An integer (builtin or NumPy) selects the same
        class for every sample. An array, list or tuple selects one class per
        sample. If None, the model's predictions on ``X`` are used per sample.

    Returns
    -------
    np.ndarray
        If the SHAP output has a class axis, the selected class slice with
        shape (n_samples, n_features). Otherwise the SHAP output is returned
        unchanged.

    Raises
    ------
    ValueError
        If a per-sample ``target_class`` value is >= the number of classes in
        the SHAP output.
    """
    import shap

    # Trees / tree ensembles use TreeExplainer; anything else falls back to the
    # model-agnostic KernelExplainer on predict_proba.
    if hasattr(self.model, "tree_") or hasattr(self.model, "estimators_"):
        explainer = shap.TreeExplainer(self.model, data=background)
    else:
        if background is None:
            # Number of k-means centroids cannot exceed the number of rows.
            background = shap.kmeans(X, min(10, len(X)))
        explainer = shap.KernelExplainer(self.model.predict_proba, background)

    # determine class
    if target_class is None:
        target_class = self.model.predict(X)
    elif isinstance(target_class, (list, tuple)):
        # Treat a list/tuple like a per-sample array instead of falling back to class 0.
        target_class = np.asarray(target_class)

    if isinstance(explainer, shap.TreeExplainer):
        shap_values = explainer.shap_values(X)
    else:
        shap_values = explainer.shap_values(X, nsamples=nsamples)

    # Some shap versions return a list with one (n_samples, n_features) array
    # per class; stack it to (n_samples, n_features, n_classes) so both output
    # layouts are handled instead of returning None.
    if isinstance(shap_values, list):
        shap_values = np.stack(shap_values, axis=-1)
    shap_values = np.asarray(shap_values)

    if shap_values.ndim != 3:
        # no class dimension (e.g. 2D samples x features)
        return shap_values

    # np.integer so NumPy scalar labels do not silently fall back to class 0.
    if isinstance(target_class, (int, np.integer)):
        return shap_values[:, :, target_class]
    if isinstance(target_class, np.ndarray):
        # target_class is per-sample
        if np.any(target_class >= shap_values.shape[2]):
            raise ValueError(f"target_class values exceed {shap_values.shape[2]}")
        rows = np.arange(shap_values.shape[0])
        # Advanced indexing picks shap_values[i, :, target_class[i]] for every row i.
        return shap_values[rows, :, target_class]
    # fallback to class 0
    return shap_values[:, :, 0]
|
|
248
|
+
|
|
249
|
+
def apply_shap_to_adata(self, dataloader, adata, background=None, adata_key="shap_values", target_class=None, normalize=True):
|
|
250
|
+
"""
|
|
251
|
+
Compute SHAP from a DataLoader and store in AnnData if provided.
|
|
252
|
+
"""
|
|
253
|
+
X_batches = []
|
|
254
|
+
|
|
255
|
+
for batch in dataloader:
|
|
256
|
+
X = batch[0].detach().cpu().numpy()
|
|
257
|
+
X_batches.append(X)
|
|
258
|
+
|
|
259
|
+
X_full = np.concatenate(X_batches, axis=0)
|
|
260
|
+
|
|
261
|
+
shap_values = self.compute_shap(X_full, background=background, target_class=target_class)
|
|
262
|
+
|
|
263
|
+
if adata is not None:
|
|
264
|
+
adata.obsm[adata_key] = shap_values
|
|
265
|
+
|
|
266
|
+
if normalize:
|
|
267
|
+
arr = shap_values
|
|
268
|
+
# row-wise normalization
|
|
269
|
+
row_max = np.max(np.abs(arr), axis=1, keepdims=True)
|
|
270
|
+
row_max[row_max == 0] = 1 # avoid divide by zero
|
|
271
|
+
normalized = arr / row_max
|
|
272
|
+
|
|
273
|
+
adata.obsm[f"{adata_key}_normalized"] = normalized
|