smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
Only a handful of the changed files have hunks rendered below; judging by the class names, they come from the `smftools/machine_learning` models. Nearly every change is mechanical Black-style formatting: imports sorted and split one per line, keyword defaults spaced as `arg = default`, long calls wrapped with trailing commas, and quotes normalized to double quotes.

Hunks for the `TorchClassifierWrapper` Lightning module:

@@ -1,10 +1,16 @@
-import torch
-import pytorch_lightning as pl
 import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
 from sklearn.metrics import (
-    auc, confusion_matrix, f1_score, precision_recall_curve, roc_auc_score, roc_curve
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+    roc_curve,
 )
-
+
 
 class TorchClassifierWrapper(pl.LightningModule):
     """
@@ -16,25 +22,26 @@ class TorchClassifierWrapper(pl.LightningModule):
     - Can pass the index of the class label to use as the focus class when calculating precision/recall.
     - Contains a prediction step to run inference with.
     """
+
     def __init__(
         self,
         model: torch.nn.Module,
         label_col: str,
         num_classes: int,
-        class_names: list=None,
+        class_names: list = None,
         optimizer_cls=torch.optim.AdamW,
         optimizer_kwargs=None,
         criterion_kwargs=None,
         lr: float = 1e-3,
         focus_class: int = 1,  # used for binary or multiclass precision-recall
         class_weights=None,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive: int=None
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive: int = None,
    ):
        super().__init__()
        self.model = model
-        self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
+        self.save_hyperparameters(ignore=["model"])  # logs all except actual model instance
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs or {"weight_decay": 1e-4}
        self.criterion = None
@@ -57,14 +64,17 @@
                 if torch.is_tensor(class_weights[self.focus_class]):
                     self.criterion_kwargs["pos_weight"] = class_weights[self.focus_class]
                 else:
-                    self.criterion_kwargs["pos_weight"] = torch.tensor(class_weights[self.focus_class], dtype=torch.float32, device=self.device)
+                    self.criterion_kwargs["pos_weight"] = torch.tensor(
+                        class_weights[self.focus_class], dtype=torch.float32, device=self.device
+                    )
             else:
                 # CrossEntropyLoss expects weight tensor of size C
                 if torch.is_tensor(class_weights):
                     self.criterion_kwargs["weight"] = class_weights
                 else:
-                    self.criterion_kwargs["weight"] = torch.tensor(
-                        class_weights, dtype=torch.float32)
+                    self.criterion_kwargs["weight"] = torch.tensor(
+                        class_weights, dtype=torch.float32
+                    )
 
         self._val_outputs = []
         self._test_outputs = []
@@ -78,12 +88,20 @@
 
     def _init_criterion(self):
         if self.num_classes == 2:
-            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["pos_weight"]):
-                self.criterion_kwargs["pos_weight"] = torch.tensor(self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device)
+            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["pos_weight"]
+            ):
+                self.criterion_kwargs["pos_weight"] = torch.tensor(
+                    self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.BCEWithLogitsLoss(**self.criterion_kwargs)
         else:
-            if "weight" in self.criterion_kwargs and not torch.is_tensor(self.criterion_kwargs["weight"]):
-                self.criterion_kwargs["weight"] = torch.tensor(self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device)
+            if "weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["weight"]
+            ):
+                self.criterion_kwargs["weight"] = torch.tensor(
+                    self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.CrossEntropyLoss(**self.criterion_kwargs)
 
     def _resolve_focus_class(self, focus_class):
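A note on the loss setup these hunks reformat: for binary heads the wrapper passes a scalar `pos_weight` to `BCEWithLogitsLoss` (a multiplier on the positive-class term), while multiclass heads pass a size-C `weight` tensor to `CrossEntropyLoss`. A minimal standalone sketch of those two conventions, with made-up class counts:

```python
import torch

# Binary head: pos_weight multiplies the positive-class term of the loss.
# With 300 negatives and 100 positives, pos_weight = 300 / 100 = 3.0 (illustrative).
binary_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))
logits = torch.randn(8)                        # raw scores; sigmoid is applied inside the loss
targets = torch.randint(0, 2, (8,)).float()
print(binary_criterion(logits, targets))

# Multiclass head: CrossEntropyLoss expects one weight per class (size C).
class_counts = torch.tensor([500.0, 300.0, 200.0])               # illustrative
weights = class_counts.sum() / (len(class_counts) * class_counts)
multi_criterion = torch.nn.CrossEntropyLoss(weight=weights)
logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
print(multi_criterion(logits, targets))
```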
`TorchClassifierWrapper`, continued:

@@ -93,11 +111,13 @@
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
-
+
     def set_training_indices(self, datamodule):
         """
         Store obs_names for train/val/test subsets used during training.
@@ -140,7 +160,7 @@
         self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=False)
         self._val_outputs.append((logits.detach(), y.detach()))
         return loss
-
+
     def test_step(self, batch, batch_idx):
         """
         Test step for a batch through the Lightning Trainer.
@@ -189,7 +209,7 @@
             return self.criterion(logits.view(-1, 1), y)
         else:
             return self.criterion(logits, y)
-
+
     def _get_probs(self, logits):
         """
         A helper function for getting class probabilities for binary vs multiclass classifications.
@@ -207,8 +227,10 @@
             return (torch.sigmoid(logits.view(-1)) >= 0.5).long()
         else:
             return logits.argmax(dim=1)
-
-    def _subsample_for_fixed_positive_frequency(self, y_true, probs, target_freq=0.3, max_positive=None):
+
+    def _subsample_for_fixed_positive_frequency(
+        self, y_true, probs, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(y_true == self.focus_class)[0]
         neg_idx = np.where(y_true != self.focus_class)[0]
 
@@ -216,16 +238,20 @@
         max_positives_possible = len(pos_idx)
 
         # maximum achievable positive class frequency
-        max_possible_freq = max_positives_possible / (max_positives_possible + max_negatives_possible)
+        max_possible_freq = max_positives_possible / (
+            max_positives_possible + max_negatives_possible
+        )
 
         if target_freq > max_possible_freq:
             target_freq = max_possible_freq  # clip if you ask for impossible freq
 
         # now calculate positive count
-        num_pos_target = min(int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible)
+        num_pos_target = min(
+            int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible
+        )
         num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
         num_neg_target = min(num_neg_target, max_negatives_possible)
-
+
         pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
         neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
 
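The subsampling hunk above solves freq = P / (P + N) for the positive count: `num_pos_target = int(target_freq * N / (1 - target_freq))`, clipped to what is available. For example, with 1,000 negatives and `target_freq=0.3`, that gives `int(0.3 * 1000 / 0.7) = 428` positives and `int(428 * 0.7 / 0.3) = 998` negatives, an actual frequency of 428/1426 ≈ 0.300. The same arithmetic as a standalone sketch (illustrative counts only):

```python
def fixed_freq_counts(n_pos_avail, n_neg_avail, target_freq):
    """Pick counts so that sampled positives / total ≈ target_freq (mirrors the hunk's formulas)."""
    max_freq = n_pos_avail / (n_pos_avail + n_neg_avail)
    target_freq = min(target_freq, max_freq)  # clip impossible requests
    num_pos = min(int(target_freq * n_neg_avail / (1 - target_freq)), n_pos_avail)
    num_neg = min(int(num_pos * (1 - target_freq) / target_freq), n_neg_avail)
    return num_pos, num_neg

num_pos, num_neg = fixed_freq_counts(n_pos_avail=500, n_neg_avail=1000, target_freq=0.3)
print(num_pos, num_neg, num_pos / (num_pos + num_neg))  # 428 998 0.3001...
```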
`TorchClassifierWrapper`, continued:

@@ -235,7 +261,7 @@
         actual_freq = len(pos_sampled) / len(sampled_idx)
 
         return sampled_idx
-
+
     def _log_classification_metrics(self, logits, targets, prefix="val"):
         """
         A helper function for logging validation and testing split model evaluations.
@@ -252,9 +278,12 @@
         num_pos = binary_focus.sum()
 
         # Subsample if you want to enforce a fixed proportion of the positive class
-        if prefix == 'test' and self.enforce_eval_balance:
+        if prefix == "test" and self.enforce_eval_balance:
             sampled_idx = self._subsample_for_fixed_positive_frequency(
-                y_true, probs, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
+                y_true,
+                probs,
+                target_freq=self.target_eval_freq,
+                max_positive=self.max_eval_positive,
             )
             y_true = y_true[sampled_idx]
             probs = probs[sampled_idx]
@@ -289,7 +318,7 @@
         cm = confusion_matrix(y_true, preds)
 
         # Save attributes for later plotting
-        if prefix == 'test':
+        if prefix == "test":
             self.test_roc_curve = (fpr, tpr)
             self.test_pr_curve = (rc, pr)
             self.test_roc_auc = roc_auc
@@ -298,19 +327,21 @@
             self.test_num_pos = num_pos
             self.test_acc = acc
             self.test_f1 = f1
-        elif prefix == 'val':
+        elif prefix == "val":
             pass
 
         # Logging
-        self.log_dict({
-            f"{prefix}_acc": acc,
-            f"{prefix}_f1": f1,
-            f"{prefix}_auc": roc_auc,
-            f"{prefix}_pr_auc": pr_auc,
-            f"{prefix}_pr_auc_norm": pr_auc_norm,
-            f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
-        })
+        self.log_dict(
+            {
+                f"{prefix}_acc": acc,
+                f"{prefix}_f1": f1,
+                f"{prefix}_auc": roc_auc,
+                f"{prefix}_pr_auc": pr_auc,
+                f"{prefix}_pr_auc_norm": pr_auc_norm,
+                f"{prefix}_pos_freq": pos_freq,
+                f"{prefix}_num_pos": num_pos,
+            }
+        )
         setattr(self, f"{prefix}_confusion_matrix", cm)
 
     def _plot_roc_pr_curves(self, logits, targets):
@@ -334,7 +365,7 @@
         pos_freq = self.test_pos_freq
         plt.subplot(1, 2, 2)
         plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
-        plt.axhline(pos_freq, linestyle='--', color='gray')
+        plt.axhline(pos_freq, linestyle="--", color="gray")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
         plt.ylim(0, 1.05)
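Taken together, the `__init__` hunk implies usage along the following lines. This is a sketch, not the package's documented API: the import path, the backbone module, and the label and class names are all assumptions.

```python
import torch

from smftools.machine_learning.models.lightning_base import TorchClassifierWrapper  # assumed path

# Hypothetical logits-producing backbone; any nn.Module would do.
backbone = torch.nn.Sequential(
    torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)

wrapper = TorchClassifierWrapper(
    model=backbone,
    label_col="label",                 # assumed obs column name
    num_classes=2,
    class_names=["unbound", "bound"],  # illustrative
    focus_class="bound",               # resolved to an index via class_names
    class_weights=[1.0, 3.0],          # binary case: converted to a pos_weight tensor (per the hunk)
    enforce_eval_balance=True,         # test-time subsampling to a fixed positive frequency
    target_eval_freq=0.3,
)
# pl.Trainer(max_epochs=10).fit(wrapper, datamodule=...)  # datamodule elided
```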
Hunks for `MLPClassifier`:

@@ -1,9 +1,18 @@
-import torch
 import torch.nn as nn
+
 from .base import BaseTorchModel
-
+
+
 class MLPClassifier(BaseTorchModel):
-    def __init__(self, input_dim, num_classes=2, hidden_dims=[64, 64], dropout=0.2, use_batchnorm=True, **kwargs):
+    def __init__(
+        self,
+        input_dim,
+        num_classes=2,
+        hidden_dims=[64, 64],
+        dropout=0.2,
+        use_batchnorm=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         layers = []
         in_dim = input_dim
@@ -23,4 +32,4 @@ class MLPClassifier(BaseTorchModel):
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
-        return self.model(x)
+        return self.model(x)
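From the signature in this hunk, constructing the MLP might look as follows. The import path and the feature dimension are assumptions, and the head's output shape for `num_classes=2` is not visible in the diff:

```python
import torch

from smftools.machine_learning.models.mlp import MLPClassifier  # assumed path

model = MLPClassifier(
    input_dim=256,          # features per read; illustrative
    num_classes=2,
    hidden_dims=[64, 64],   # two hidden layers, matching the default
    dropout=0.2,
    use_batchnorm=True,
)
logits = model(torch.randn(32, 256))  # [batch, input_dim] -> logits
```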
Hunks for `PositionalEncoding`:

@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 
+
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=5000):
         super().__init__()
@@ -14,5 +15,5 @@ class PositionalEncoding(nn.Module):
         self.register_buffer("pe", pe)
 
     def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return x
+        x = x + self.pe[:, : x.size(1)]
+        return x
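The `forward` above adds a precomputed table `pe` to the first `x.size(1)` positions. For reference, a self-contained version of the standard sinusoidal construction consistent with the buffer registered above (a sketch, not necessarily byte-identical to the module):

```python
import math

import torch

def sinusoidal_pe(max_len: int, d_model: int) -> torch.Tensor:
    """Standard sin/cos positional table of shape [1, max_len, d_model]."""
    position = torch.arange(max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe.unsqueeze(0)

x = torch.randn(8, 100, 64)          # [batch, seq_len, d_model]
pe = sinusoidal_pe(max_len=5000, d_model=64)
x = x + pe[:, : x.size(1)]           # same slicing as the diff's forward()
```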
Hunks for `RNNClassifier`:

@@ -1,7 +1,8 @@
-import torch
 import torch.nn as nn
+
 from .base import BaseTorchModel
 
+
 class RNNClassifier(BaseTorchModel):
     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
         super().__init__(**kwargs)
@@ -14,4 +15,4 @@ class RNNClassifier(BaseTorchModel):
     def forward(self, x):
         x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
         _, (h_n, _) = self.lstm(x)  # h_n: [1, B, H]
-        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
+        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
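The shape comments in `forward` spell out the flow: a flat `[B, L]` batch is unsqueezed to `[B, 1, L]`, so the LSTM (with `batch_first=True`) sees a single timestep of `L` features, and its final hidden state feeds the classifier head. A minimal shape check under those assumptions:

```python
import torch
import torch.nn as nn

B, L, H, C = 4, 128, 32, 2
lstm = nn.LSTM(input_size=L, hidden_size=H, batch_first=True)
fc = nn.Linear(H, C)

x = torch.randn(B, L).unsqueeze(1)  # [B, L] -> [B, 1, L]: one timestep of L features
_, (h_n, _) = lstm(x)               # h_n: [num_layers=1, B, H]
logits = fc(h_n.squeeze(0))         # [B, H] -> [B, C]
print(logits.shape)                 # torch.Size([4, 2])
```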
Hunks for `SklearnModelWrapper`:

@@ -1,23 +1,30 @@
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 from sklearn.metrics import (
-    auc, confusion_matrix, f1_score, precision_recall_curve, roc_auc_score, roc_curve
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_recall_curve,
+    roc_auc_score,
+    roc_curve,
 )
 
+
 class SklearnModelWrapper:
     """
     Unified sklearn wrapper matching TorchClassifierWrapper interface.
     """
+
     def __init__(
-        self,
-        model,
+        self,
+        model,
         label_col: str,
-        num_classes: int,
-        class_names=None,
-        focus_class: int=1,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive=None
+        num_classes: int,
+        class_names=None,
+        focus_class: int = 1,
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive=None,
     ):
         self.model = model
         self.label_col = label_col
@@ -37,7 +44,9 @@ class SklearnModelWrapper:
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(f"focus_class '{focus_class}' not found in class_names {self.class_names}.")
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
@@ -130,7 +139,7 @@
             f"{prefix}_pr_auc": pr_auc,
             f"{prefix}_pr_auc_norm": pr_auc_norm,
             f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
+            f"{prefix}_num_pos": num_pos,
         }
 
         return self.metrics
@@ -166,7 +175,10 @@
 
     def fit_from_datamodule(self, datamodule):
         datamodule.setup()
-        X_tensor, y_tensor = datamodule.train_set.dataset.X_tensor, datamodule.train_set.dataset.y_tensor
+        X_tensor, y_tensor = (
+            datamodule.train_set.dataset.X_tensor,
+            datamodule.train_set.dataset.y_tensor,
+        )
         indices = datamodule.train_set.indices
         X_train = X_tensor[indices].numpy()
         y_train = y_tensor[indices].numpy()
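`fit_from_datamodule` is unchanged in behavior: it pulls the full tensors off the underlying dataset, slices them with the train-split indices, and hands numpy arrays to the sklearn estimator. The equivalent flow written against plain objects (the `Subset` and `LogisticRegression` here are stand-ins for `datamodule.train_set` and the wrapped model):

```python
import torch
from sklearn.linear_model import LogisticRegression
from torch.utils.data import Subset, TensorDataset

# Stand-ins for datamodule.train_set.dataset.X_tensor / .y_tensor
X_tensor = torch.randn(100, 16)
y_tensor = torch.randint(0, 2, (100,))
train_set = Subset(TensorDataset(X_tensor, y_tensor), indices=list(range(80)))

# Mirrors the hunk: index the full tensors with the subset's indices, then convert.
X_train = X_tensor[train_set.indices].numpy()
y_train = y_tensor[train_set.indices].numpy()

clf = LogisticRegression().fit(X_train, y_train)
```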
`SklearnModelWrapper`, continued:

@@ -190,11 +202,11 @@
         y_eval = y_tensor[indices].numpy()
 
         return self.evaluate(X_eval, y_eval, prefix=split)
-
+
     def compute_shap(self, X, background=None, nsamples=100, target_class=None):
         """
         Compute SHAP values on input X, optionally for a specified target class.
-
+
         Parameters
         ----------
         X : array-like
@@ -225,7 +237,7 @@
             shap_values = explainer.shap_values(X)
         else:
             shap_values = explainer.shap_values(X, nsamples=nsamples)
-
+
         if isinstance(shap_values, np.ndarray):
             if shap_values.ndim == 3:
                 if isinstance(target_class, int):
@@ -234,10 +246,7 @@
                 # target_class is per-sample
                 if np.any(target_class >= shap_values.shape[2]):
                     raise ValueError(f"target_class values exceed {shap_values.shape[2]}")
-                selected = np.array([
-                    shap_values[i, :, c]
-                    for i, c in enumerate(target_class)
-                ])
+                selected = np.array([shap_values[i, :, c] for i, c in enumerate(target_class)])
                 return selected
             else:
                 # fallback to class 0
@@ -246,7 +255,15 @@
             # 2D shape (samples, features), no class dimension
             return shap_values
 
-    def apply_shap_to_adata(self, dataloader, adata, background=None, adata_key="shap_values", target_class=None, normalize=True):
+    def apply_shap_to_adata(
+        self,
+        dataloader,
+        adata,
+        background=None,
+        adata_key="shap_values",
+        target_class=None,
+        normalize=True,
+    ):
         """
         Compute SHAP from a DataLoader and store in AnnData if provided.
         """
@@ -270,4 +287,4 @@
         row_max[row_max == 0] = 1  # avoid divide by zero
         normalized = arr / row_max
 
-        adata.obsm[f"{adata_key}_normalized"] = normalized
+        adata.obsm[f"{adata_key}_normalized"] = normalized
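The final hunk stores a row-scaled copy of the SHAP matrix in `adata.obsm`: each read's vector is divided by its own row maximum, with zero rows guarded against divide-by-zero. A toy version of that scaling (whether the maximum is taken over absolute values is not visible in the hunk; the sketch assumes it is):

```python
import numpy as np

arr = np.array([[0.2, -0.8, 0.4],
                [0.0,  0.0, 0.0]])  # an all-zero row would otherwise divide by zero

row_max = np.abs(arr).max(axis=1, keepdims=True)  # assumption: max of |SHAP| per row
row_max[row_max == 0] = 1                         # same guard as the hunk
normalized = arr / row_max
print(normalized)  # each row scaled so its largest-magnitude entry is ±1
```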