smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/smftools/machine_learning/models/cnn.py
+++ b/smftools/machine_learning/models/cnn.py
@@ -1,8 +1,15 @@
-import
-
-from .base import BaseTorchModel
+from __future__ import annotations
+
 import numpy as np
 
+from smftools.optional_imports import require
+
+from .base import BaseTorchModel
+
+torch = require("torch", extra="ml-base", purpose="CNN models")
+nn = torch.nn
+
+
 class CNNClassifier(BaseTorchModel):
     def __init__(
         self,
@@ -15,7 +22,7 @@ class CNNClassifier(BaseTorchModel):
         use_pooling=False,
         dropout=0.2,
         gradcam_layer_idx=-1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.name = "CNNClassifier"
@@ -30,7 +37,9 @@ class CNNClassifier(BaseTorchModel):
 
         # Build conv layers
         for out_channels, ksize in zip(conv_channels, kernel_sizes):
-            layers.append(
+            layers.append(
+                nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2)
+            )
             if use_batchnorm:
                 layers.append(nn.BatchNorm1d(out_channels))
             layers.append(nn.ReLU())
@@ -76,7 +85,7 @@ class CNNClassifier(BaseTorchModel):
         x = self.conv(x)
         x = x.view(x.size(0), -1)
         return self.fc(x)
-
+
     def _register_gradcam_hooks(self):
         def forward_hook(module, input, output):
             self.gradcam_activations = output.detach()
@@ -97,15 +106,15 @@ class CNNClassifier(BaseTorchModel):
         self.eval()  # disable dropout etc.
 
         output = self.forward(x)  # shape (B, C) or (B, 1)
-
+
         if class_idx is None:
             class_idx = output.argmax(dim=1)
-
+
         if output.shape[1] == 1:
             target = output.view(-1)  # shape (B,)
         else:
             target = output[torch.arange(output.shape[0]), class_idx]
-
+
         target.sum().backward(retain_graph=True)
 
         # restore training mode
@@ -114,16 +123,16 @@ class CNNClassifier(BaseTorchModel):
 
         # get activations and gradients (set these via forward hook!)
         activations = self.gradcam_activations  # (B, C, L)
-        gradients = self.gradcam_gradients
+        gradients = self.gradcam_gradients  # (B, C, L)
 
         weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
-        cam = (weights * activations).sum(dim=1)
+        cam = (weights * activations).sum(dim=1)  # (B, L)
 
         cam = torch.relu(cam)
         cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)
 
         return cam
-
+
     def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
         self.to(device)
         self.eval()
@@ -135,4 +144,4 @@ class CNNClassifier(BaseTorchModel):
             cams.append(cam_batch.cpu().numpy())
 
         cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
-        adata.obsm[obsm_key] = cams
+        adata.obsm[obsm_key] = cams
```
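The `require(...)` calls threaded through these model modules come from the new `smftools/optional_imports.py` (`+31 -0` in the file list above), whose body is not part of this diff. A minimal sketch of what the call sites imply — import on demand and, on failure, point at the extras group to install — assuming an importlib-based helper:

```python
# Hypothetical sketch only; the actual smftools/optional_imports.py is not shown in this diff.
from __future__ import annotations

import importlib


def require(module_name: str, extra: str | None = None, purpose: str | None = None):
    """Import module_name now, or raise an ImportError naming the extras group to install."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        for_clause = f" (required for {purpose})" if purpose else ""
        hint = f"pip install 'smftools[{extra}]'" if extra else f"pip install {module_name}"
        raise ImportError(
            f"Optional dependency {module_name!r} is missing{for_clause}. Try: {hint}"
        ) from err
```

This matches how the diffs use the return value: `torch = require("torch", extra="ml-base", purpose="CNN models")` binds the imported module itself, so `torch.nn` and friends work as usual.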
```diff
--- a/smftools/machine_learning/models/lightning_base.py
+++ b/smftools/machine_learning/models/lightning_base.py
@@ -1,11 +1,22 @@
-import
-
-import matplotlib.pyplot as plt
-from sklearn.metrics import (
-    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
-)
+from __future__ import annotations
+
 import numpy as np
 
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning models")
+torch = require("torch", extra="ml-base", purpose="Lightning models")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
+
+
 class TorchClassifierWrapper(pl.LightningModule):
     """
     A Pytorch Lightning wrapper for PyTorch classifiers.
@@ -16,25 +27,26 @@ class TorchClassifierWrapper(pl.LightningModule):
     - Can pass the index of the class label to use as the focus class when calculating precision/recall.
     - Contains a prediction step to run inference with.
     """
+
     def __init__(
         self,
         model: torch.nn.Module,
         label_col: str,
         num_classes: int,
-        class_names: list=None,
+        class_names: list = None,
         optimizer_cls=torch.optim.AdamW,
         optimizer_kwargs=None,
         criterion_kwargs=None,
         lr: float = 1e-3,
         focus_class: int = 1,  # used for binary or multiclass precision-recall
         class_weights=None,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive: int=None
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive: int = None,
     ):
         super().__init__()
         self.model = model
-        self.save_hyperparameters(ignore=[
+        self.save_hyperparameters(ignore=["model"])  # logs all except actual model instance
         self.optimizer_cls = optimizer_cls
         self.optimizer_kwargs = optimizer_kwargs or {"weight_decay": 1e-4}
         self.criterion = None
@@ -57,14 +69,17 @@ class TorchClassifierWrapper(pl.LightningModule):
                 if torch.is_tensor(class_weights[self.focus_class]):
                     self.criterion_kwargs["pos_weight"] = class_weights[self.focus_class]
                 else:
-                    self.criterion_kwargs["pos_weight"] = torch.tensor(
+                    self.criterion_kwargs["pos_weight"] = torch.tensor(
+                        class_weights[self.focus_class], dtype=torch.float32, device=self.device
+                    )
             else:
                 # CrossEntropyLoss expects weight tensor of size C
                 if torch.is_tensor(class_weights):
                     self.criterion_kwargs["weight"] = class_weights
                 else:
-                    self.criterion_kwargs["weight"] = torch.tensor(
-
+                    self.criterion_kwargs["weight"] = torch.tensor(
+                        class_weights, dtype=torch.float32
+                    )
 
         self._val_outputs = []
         self._test_outputs = []
@@ -78,12 +93,20 @@ class TorchClassifierWrapper(pl.LightningModule):
 
     def _init_criterion(self):
         if self.num_classes == 2:
-            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(
-                self.criterion_kwargs["pos_weight"]
+            if "pos_weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["pos_weight"]
+            ):
+                self.criterion_kwargs["pos_weight"] = torch.tensor(
+                    self.criterion_kwargs["pos_weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.BCEWithLogitsLoss(**self.criterion_kwargs)
         else:
-            if "weight" in self.criterion_kwargs and not torch.is_tensor(
-                self.criterion_kwargs["weight"]
+            if "weight" in self.criterion_kwargs and not torch.is_tensor(
+                self.criterion_kwargs["weight"]
+            ):
+                self.criterion_kwargs["weight"] = torch.tensor(
+                    self.criterion_kwargs["weight"], dtype=torch.float32, device=self.device
+                )
             self.criterion = torch.nn.CrossEntropyLoss(**self.criterion_kwargs)
 
     def _resolve_focus_class(self, focus_class):
@@ -93,11 +116,13 @@ class TorchClassifierWrapper(pl.LightningModule):
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
-
+
     def set_training_indices(self, datamodule):
         """
         Store obs_names for train/val/test subsets used during training.
@@ -140,7 +165,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=False)
         self._val_outputs.append((logits.detach(), y.detach()))
         return loss
-
+
     def test_step(self, batch, batch_idx):
         """
         Test step for a batch through the Lightning Trainer.
@@ -189,7 +214,7 @@ class TorchClassifierWrapper(pl.LightningModule):
             return self.criterion(logits.view(-1, 1), y)
         else:
             return self.criterion(logits, y)
-
+
     def _get_probs(self, logits):
         """
         A helper function for getting class probabilities for binary vs multiclass classifications.
@@ -207,8 +232,10 @@ class TorchClassifierWrapper(pl.LightningModule):
             return (torch.sigmoid(logits.view(-1)) >= 0.5).long()
         else:
             return logits.argmax(dim=1)
-
-    def _subsample_for_fixed_positive_frequency(
+
+    def _subsample_for_fixed_positive_frequency(
+        self, y_true, probs, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(y_true == self.focus_class)[0]
         neg_idx = np.where(y_true != self.focus_class)[0]
 
@@ -216,16 +243,20 @@ class TorchClassifierWrapper(pl.LightningModule):
         max_positives_possible = len(pos_idx)
 
         # maximum achievable positive class frequency
-        max_possible_freq = max_positives_possible / (
+        max_possible_freq = max_positives_possible / (
+            max_positives_possible + max_negatives_possible
+        )
 
         if target_freq > max_possible_freq:
             target_freq = max_possible_freq  # clip if you ask for impossible freq
 
         # now calculate positive count
-        num_pos_target = min(
+        num_pos_target = min(
+            int(target_freq * max_negatives_possible / (1 - target_freq)), max_positives_possible
+        )
         num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
         num_neg_target = min(num_neg_target, max_negatives_possible)
-
+
         pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
         neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
 
@@ -235,7 +266,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         actual_freq = len(pos_sampled) / len(sampled_idx)
 
         return sampled_idx
-
+
     def _log_classification_metrics(self, logits, targets, prefix="val"):
         """
         A helper function for logging validation and testing split model evaluations.
@@ -252,9 +283,12 @@ class TorchClassifierWrapper(pl.LightningModule):
         num_pos = binary_focus.sum()
 
         # Subsample if you want to enforce a fixed proportion of the positive class
-        if prefix ==
+        if prefix == "test" and self.enforce_eval_balance:
             sampled_idx = self._subsample_for_fixed_positive_frequency(
-                y_true,
+                y_true,
+                probs,
+                target_freq=self.target_eval_freq,
+                max_positive=self.max_eval_positive,
             )
             y_true = y_true[sampled_idx]
             probs = probs[sampled_idx]
@@ -289,7 +323,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         cm = confusion_matrix(y_true, preds)
 
         # Save attributes for later plotting
-        if prefix ==
+        if prefix == "test":
             self.test_roc_curve = (fpr, tpr)
             self.test_pr_curve = (rc, pr)
             self.test_roc_auc = roc_auc
@@ -298,19 +332,21 @@ class TorchClassifierWrapper(pl.LightningModule):
             self.test_num_pos = num_pos
             self.test_acc = acc
             self.test_f1 = f1
-        elif prefix ==
+        elif prefix == "val":
             pass
 
         # Logging
-        self.log_dict(
-
-
-
-
-
-
-
-
+        self.log_dict(
+            {
+                f"{prefix}_acc": acc,
+                f"{prefix}_f1": f1,
+                f"{prefix}_auc": roc_auc,
+                f"{prefix}_pr_auc": pr_auc,
+                f"{prefix}_pr_auc_norm": pr_auc_norm,
+                f"{prefix}_pos_freq": pos_freq,
+                f"{prefix}_num_pos": num_pos,
+            }
+        )
         setattr(self, f"{prefix}_confusion_matrix", cm)
 
     def _plot_roc_pr_curves(self, logits, targets):
@@ -334,7 +370,7 @@ class TorchClassifierWrapper(pl.LightningModule):
         pos_freq = self.test_pos_freq
         plt.subplot(1, 2, 2)
         plt.plot(rc, pr, label=f"PR AUC={pr_auc:.3f}")
-        plt.axhline(pos_freq, linestyle=
+        plt.axhline(pos_freq, linestyle="--", color="gray")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
         plt.ylim(0, 1.05)
```
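The reflowed `_subsample_for_fixed_positive_frequency` makes the clipping arithmetic easy to check in isolation. A standalone rerun of the same formulas with hypothetical counts (100 positives, 900 negatives, requested frequency 0.3):

```python
# Hypothetical counts; mirrors the arithmetic in the hunk above.
max_pos, max_neg = 100, 900
target_freq = 0.3

# Highest achievable positive frequency given the available reads.
max_possible_freq = max_pos / (max_pos + max_neg)      # 0.1
target_freq = min(target_freq, max_possible_freq)      # clipped: 0.3 -> 0.1

num_pos = min(int(target_freq * max_neg / (1 - target_freq)), max_pos)   # 100
num_neg = min(int(num_pos * (1 - target_freq) / target_freq), max_neg)   # 900

print(num_pos, num_neg, num_pos / (num_pos + num_neg))  # 100 900 0.1
```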
```diff
--- a/smftools/machine_learning/models/mlp.py
+++ b/smftools/machine_learning/models/mlp.py
@@ -1,9 +1,22 @@
-import
-
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
 from .base import BaseTorchModel
-
+
+nn = require("torch.nn", extra="ml-base", purpose="MLP models")
+
+
 class MLPClassifier(BaseTorchModel):
-    def __init__(
+    def __init__(
+        self,
+        input_dim,
+        num_classes=2,
+        hidden_dims=[64, 64],
+        dropout=0.2,
+        use_batchnorm=True,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         layers = []
         in_dim = input_dim
@@ -23,4 +36,4 @@ class MLPClassifier(BaseTorchModel):
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
-        return self.model(x)
+        return self.model(x)
```
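With the `__init__` signature now spelled out one parameter per line, a typical construction (hypothetical sizes; assumes `BaseTorchModel` needs no further kwargs) reads:

```python
# Requires the ml-base extra so require("torch.nn", ...) succeeds.
from smftools.machine_learning.models.mlp import MLPClassifier

model = MLPClassifier(input_dim=512, num_classes=2, hidden_dims=[128, 64], dropout=0.2)
```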
```diff
--- a/smftools/machine_learning/models/positional.py
+++ b/smftools/machine_learning/models/positional.py
@@ -1,6 +1,12 @@
+from __future__ import annotations
+
 import numpy as np
-
-
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="positional encoding")
+nn = torch.nn
+
 
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=5000):
@@ -14,5 +20,5 @@ class PositionalEncoding(nn.Module):
         self.register_buffer("pe", pe)
 
     def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return x
+        x = x + self.pe[:, : x.size(1)]
+        return x
```
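The hunks above only touch imports and `forward`; the construction of the `pe` buffer sits outside them. The standard sinusoidal table consistent with `register_buffer("pe", pe)`, `max_len=5000`, and the batch-first slicing in `forward` would look like this (a sketch, not necessarily the package's exact code):

```python
import math

import torch

d_model, max_len = 16, 5000
position = torch.arange(max_len).unsqueeze(1)                  # (max_len, 1)
div_term = torch.exp(
    torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)                                                              # (d_model // 2,)
pe = torch.zeros(1, max_len, d_model)                          # leading batch dim
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
# forward then adds pe[:, : x.size(1)] to x, as in the hunk above.
```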
```diff
--- a/smftools/machine_learning/models/rnn.py
+++ b/smftools/machine_learning/models/rnn.py
@@ -1,7 +1,12 @@
-import
-
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="RNN models")
+
+
 class RNNClassifier(BaseTorchModel):
     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
         super().__init__(**kwargs)
@@ -14,4 +19,4 @@ class RNNClassifier(BaseTorchModel):
     def forward(self, x):
         x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
         _, (h_n, _) = self.lstm(x)  # h_n: [1, B, H]
-        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
+        return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
```
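A quick shape walk-through of the unchanged `forward` path, with dummy sizes (assumes `self.lstm` is an `nn.LSTM(input_size, hidden_dim, batch_first=True)` and `self.fc` a `nn.Linear(hidden_dim, num_classes)`, which is what the shape comments imply):

```python
import torch
import torch.nn as nn

B, L, H, C = 4, 32, 64, 2                       # batch, features, hidden, classes
lstm = nn.LSTM(input_size=L, hidden_size=H, batch_first=True)
fc = nn.Linear(H, C)

x = torch.randn(B, L).unsqueeze(1)              # [B, 1, L]: one timestep of L features
_, (h_n, _) = lstm(x)                           # h_n: [1, B, H]
print(fc(h_n.squeeze(0)).shape)                 # torch.Size([4, 2])
```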
```diff
--- a/smftools/machine_learning/models/sklearn_models.py
+++ b/smftools/machine_learning/models/sklearn_models.py
@@ -1,23 +1,35 @@
+from __future__ import annotations
+
 import numpy as np
-
-from
-
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
+
 
 class SklearnModelWrapper:
     """
     Unified sklearn wrapper matching TorchClassifierWrapper interface.
     """
+
     def __init__(
-        self,
-        model,
+        self,
+        model,
         label_col: str,
-        num_classes: int,
-        class_names=None,
-        focus_class: int=1,
-        enforce_eval_balance: bool=False,
-        target_eval_freq: float=0.3,
-        max_eval_positive=None
+        num_classes: int,
+        class_names=None,
+        focus_class: int = 1,
+        enforce_eval_balance: bool = False,
+        target_eval_freq: float = 0.3,
+        max_eval_positive=None,
     ):
         self.model = model
         self.label_col = label_col
@@ -37,7 +49,9 @@ class SklearnModelWrapper:
             if self.class_names is None:
                 raise ValueError("class_names must be provided if focus_class is a string.")
             if focus_class not in self.class_names:
-                raise ValueError(
+                raise ValueError(
+                    f"focus_class '{focus_class}' not found in class_names {self.class_names}."
+                )
             return self.class_names.index(focus_class)
         else:
             raise ValueError(f"focus_class must be int or str, got {type(focus_class)}")
@@ -130,7 +144,7 @@ class SklearnModelWrapper:
             f"{prefix}_pr_auc": pr_auc,
             f"{prefix}_pr_auc_norm": pr_auc_norm,
             f"{prefix}_pos_freq": pos_freq,
-            f"{prefix}_num_pos": num_pos
+            f"{prefix}_num_pos": num_pos,
         }
 
         return self.metrics
@@ -166,7 +180,10 @@ class SklearnModelWrapper:
 
     def fit_from_datamodule(self, datamodule):
         datamodule.setup()
-        X_tensor, y_tensor =
+        X_tensor, y_tensor = (
+            datamodule.train_set.dataset.X_tensor,
+            datamodule.train_set.dataset.y_tensor,
+        )
         indices = datamodule.train_set.indices
         X_train = X_tensor[indices].numpy()
         y_train = y_tensor[indices].numpy()
@@ -190,11 +207,11 @@ class SklearnModelWrapper:
         y_eval = y_tensor[indices].numpy()
 
         return self.evaluate(X_eval, y_eval, prefix=split)
-
+
     def compute_shap(self, X, background=None, nsamples=100, target_class=None):
         """
         Compute SHAP values on input X, optionally for a specified target class.
-
+
         Parameters
         ----------
         X : array-like
@@ -225,7 +242,7 @@ class SklearnModelWrapper:
             shap_values = explainer.shap_values(X)
         else:
             shap_values = explainer.shap_values(X, nsamples=nsamples)
-
+
         if isinstance(shap_values, np.ndarray):
             if shap_values.ndim == 3:
                 if isinstance(target_class, int):
@@ -234,10 +251,7 @@ class SklearnModelWrapper:
                 # target_class is per-sample
                 if np.any(target_class >= shap_values.shape[2]):
                     raise ValueError(f"target_class values exceed {shap_values.shape[2]}")
-                selected = np.array([
-                    shap_values[i, :, c]
-                    for i, c in enumerate(target_class)
-                ])
+                selected = np.array([shap_values[i, :, c] for i, c in enumerate(target_class)])
                 return selected
             else:
                 # fallback to class 0
@@ -246,7 +260,15 @@ class SklearnModelWrapper:
             # 2D shape (samples, features), no class dimension
             return shap_values
 
-    def apply_shap_to_adata(
+    def apply_shap_to_adata(
+        self,
+        dataloader,
+        adata,
+        background=None,
+        adata_key="shap_values",
+        target_class=None,
+        normalize=True,
+    ):
         """
         Compute SHAP from a DataLoader and store in AnnData if provided.
         """
@@ -270,4 +292,4 @@ class SklearnModelWrapper:
         row_max[row_max == 0] = 1  # avoid divide by zero
         normalized = arr / row_max
 
-        adata.obsm[f"{adata_key}_normalized"] = normalized
+        adata.obsm[f"{adata_key}_normalized"] = normalized
```
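The final hunk's per-read normalization is plain NumPy and easy to sanity-check on its own (hypothetical array; `row_max` is assumed to be the per-row maximum magnitude, which the surrounding lines suggest):

```python
import numpy as np

arr = np.array([[0.2, -0.4, 0.1],
                [0.0,  0.0, 0.0]])                 # SHAP values per read
row_max = np.abs(arr).max(axis=1, keepdims=True)   # assumed scale: largest |value| per row
row_max[row_max == 0] = 1                          # avoid divide by zero (as in the hunk)
normalized = arr / row_max
print(normalized)                                  # rows scaled into [-1, 1]; zero row unchanged
```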