smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
from ..utils.device import detect_device
|
|
5
|
+
|
|
6
|
+
class BaseTorchModel(nn.Module):
    """
    Minimal base class for torch models that:
    - Stores device and dropout regularization
    - Provides gradient-based attribution utilities (vanilla saliency,
      SmoothGrad, gradient x input, Integrated Gradients, DeepLIFT,
      occlusion) plus a convenience method to write attributions into an
      AnnData object.
    """
    def __init__(self, dropout_rate=0.0):
        super().__init__()
        self.device = detect_device()  # detects available devices
        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.

    def compute_saliency(
        self,
        x,
        target_class=None,
        reduction="sum",
        smoothgrad=False,
        smooth_samples=25,
        smooth_noise=0.1,
        signed=True
    ):
        """
        Compute vanilla saliency or SmoothGrad saliency.

        Arguments:
        ----------
        x : torch.Tensor
            Input tensor [B, S, D].
        target_class : int or list, optional
            If None, uses model predicted class.
        reduction : str
            'sum' or 'mean' across channels
        smoothgrad : bool
            Whether to apply SmoothGrad.
        smooth_samples : int
            Number of noisy samples for SmoothGrad
        smooth_noise : float
            Standard deviation of noise added to input
        signed : bool
            If False, absolute gradients are returned.
        """
        self.eval()
        x = x.clone().detach().requires_grad_(True)

        if smoothgrad:
            saliency_accum = torch.zeros_like(x)
            # Keep resolution local so the caller's target_class argument is
            # never clobbered; resolved once (from the first noisy pass) and
            # then held fixed so every sample attributes the same class.
            resolved = target_class
            for _ in range(smooth_samples):
                noise = torch.normal(mean=0.0, std=smooth_noise, size=x.shape).to(x.device)
                # Detach so each noisy copy is its own leaf: gradients land in
                # x_noisy.grad and do not accumulate into x.grad.
                x_noisy = (x + noise).detach().requires_grad_(True)
                logits = self.forward(x_noisy)
                if resolved is None:
                    resolved = self._resolve_target_class(logits, target_class)
                scores = self._selected_scores(logits, resolved)
                scores.sum().backward()
                saliency_accum += x_noisy.grad.detach()
            saliency = saliency_accum / smooth_samples
        else:
            logits = self.forward(x)
            resolved = self._resolve_target_class(logits, target_class)
            scores = self._selected_scores(logits, resolved)
            scores.sum().backward()
            saliency = x.grad.detach()

        if not signed:
            saliency = saliency.abs()

        if reduction == "sum" and x.ndim == 3:
            return saliency.sum(dim=-1)
        elif reduction == "mean" and x.ndim == 3:
            return saliency.mean(dim=-1)
        else:
            return saliency

    def compute_gradient_x_input(self, x, target_class=None):
        """
        Computes gradient × input attribution.

        Returns a tensor shaped like x: elementwise product of the input
        gradient (w.r.t. the selected class score) and the input itself.
        """
        x = x.clone().detach().requires_grad_(True)
        logits = self.forward(x)
        resolved = self._resolve_target_class(logits, target_class)
        scores = self._selected_scores(logits, resolved)
        scores.sum().backward()
        return x.grad * x

    def compute_integrated_gradients(self, x, target_class=None, steps=50, baseline=None):
        """
        Compute Integrated Gradients for a batch of x.
        If target_class=None, uses the predicted class.
        Returns: ([B, seq_len, channels] attribution tensor, convergence delta)
        """
        from captum.attr import IntegratedGradients

        ig = IntegratedGradients(self)
        self.eval()
        # Clone first: the original called requires_grad_ in place, mutating
        # the caller's tensor (inconsistent with the clone pattern used by the
        # other attribution methods).
        x = x.clone().detach().requires_grad_(True)

        with torch.no_grad():
            logits = self.forward(x)
        if logits.shape[1] == 1:
            target_class = 0  # only one column exists, representing class 1 logit
        else:
            target_class = self._resolve_target_class(logits, target_class)

        if baseline is None:
            baseline = torch.zeros_like(x)

        attributions, delta = ig.attribute(
            x,
            baselines=baseline,
            target=target_class,
            n_steps=steps,
            return_convergence_delta=True
        )
        return attributions, delta

    def compute_deeplift(
        self,
        x,
        baseline=None,
        target_class=None,
        reduction="sum",
        signed=True
    ):
        """
        Compute DeepLIFT scores using captum.

        baseline:
            reference input for DeepLIFT; defaults to all zeros.
        reduction:
            'sum' or 'mean' over the channel dimension for 3D input.
        signed:
            if False, absolute attributions are returned.
        """
        from captum.attr import DeepLift

        self.eval()
        deeplift = DeepLift(self)

        logits = self.forward(x)
        if logits.shape[1] == 1:
            target_class = 0  # only one column exists, representing class 1 logit
        else:
            target_class = self._resolve_target_class(logits, target_class)

        if baseline is None:
            baseline = torch.zeros_like(x)

        attr = deeplift.attribute(x, target=target_class, baselines=baseline)

        if not signed:
            attr = attr.abs()

        if reduction == "sum" and x.ndim == 3:
            return attr.sum(dim=-1)
        elif reduction == "mean" and x.ndim == 3:
            return attr.mean(dim=-1)
        else:
            return attr

    def compute_occlusion(
        self,
        x,
        target_class=None,
        window_size=5,
        baseline=None
    ):
        """
        Computes per-sample occlusion attribution.
        Supports 2D [B, S] or 3D [B, S, D] inputs.
        Returns: [B, S] numpy array of model scores with each window occluded.
        """
        self.eval()

        x_np = x.detach().cpu().numpy()
        ndim = x_np.ndim
        if ndim == 2:
            B, S = x_np.shape
        elif ndim == 3:
            B, S, _ = x_np.shape
        else:
            raise ValueError(f"Unsupported input shape {x_np.shape}")

        # if no baseline provided, fall back to the per-position batch mean
        if baseline is None:
            baseline = np.mean(x_np, axis=0)

        occlusion_scores = np.zeros((B, S))

        # Occlusion only needs forward scores, never gradients.
        with torch.no_grad():
            for b in range(B):
                for i in range(S):
                    x_occluded = x_np[b].copy()
                    left = max(0, i - window_size // 2)
                    right = min(S, i + window_size // 2)

                    if ndim == 2:
                        x_occluded[left:right] = baseline[left:right]
                    else:
                        x_occluded[left:right, :] = baseline[left:right, :]

                    x_tensor = torch.tensor(
                        x_occluded,
                        device=self.device,
                        dtype=torch.float32
                    ).unsqueeze(0)

                    logits = self.forward(x_tensor)
                    # Local resolution: the original reassigned target_class
                    # here, clobbering the caller's argument across iterations.
                    resolved = self._resolve_target_class(logits, target_class)
                    # BUGFIX: _selected_scores indexes by logits' own batch
                    # size (1 here). The original indexed with
                    # torch.arange(x.shape[0]) — the full batch size B —
                    # which raised for B > 1 in the multi-class branch.
                    scores = self._selected_scores(logits, resolved)

                    occlusion_scores[b, i] = scores.mean().item()

        return occlusion_scores

    def apply_attributions_to_adata(
        model,
        dataloader,
        adata,
        method="saliency",  # saliency, smoothgrad, IG, deeplift, gradxinput, occlusion
        adata_key="attributions",
        baseline=None,
        device="cpu",
        target_class=None,
        normalize=True,
        signed=True
    ):
        """
        Apply a chosen attribution method to a dataloader and store results in adata.

        Results are concatenated over batches into adata.obsm[adata_key].
        If normalize is True, a per-row max-abs-normalized copy is also
        written to adata.obsm[adata_key + "_normalized"].

        NOTE: the receiver parameter is named `model` (it plays the role of
        `self`); kept as-is for interface compatibility.
        """
        results = []
        model.to(device)
        model.eval()

        for batch in dataloader:
            x = batch[0].to(device)

            if method == "saliency":
                attr = model.compute_saliency(x, target_class=target_class, signed=signed)
            elif method == "smoothgrad":
                attr = model.compute_saliency(x, smoothgrad=True, target_class=target_class, signed=signed)
            elif method == "IG":
                # IG also returns a convergence delta; only the attributions
                # are stored.
                attr, _delta = model.compute_integrated_gradients(
                    x, target_class=target_class, baseline=baseline
                )
            elif method == "deeplift":
                attr = model.compute_deeplift(x, baseline=baseline, target_class=target_class, signed=signed)
            elif method == "gradxinput":
                attr = model.compute_gradient_x_input(x, target_class=target_class)
            elif method == "occlusion":
                attr = model.compute_occlusion(
                    x, target_class=target_class, baseline=baseline
                )
            else:
                raise ValueError(f"Unknown method {method}")

            # ensure numpy
            attr = attr.detach().cpu().numpy() if torch.is_tensor(attr) else attr
            results.append(attr)

        results_stacked = np.concatenate(results, axis=0)
        adata.obsm[adata_key] = results_stacked

        if normalize:
            row_max = np.max(np.abs(results_stacked), axis=1, keepdims=True)
            row_max[row_max == 0] = 1  # avoid divide by zero
            results_normalized = results_stacked / row_max
            adata.obsm[adata_key + "_normalized"] = results_normalized

    def _resolve_target_class(self, logits, target_class):
        """Return the class index/indices to attribute against.

        A caller-supplied target_class is returned unchanged. Otherwise the
        model prediction is used: sign of the logit for single-column binary
        heads, argmax for multi-class heads.
        """
        if target_class is not None:
            return target_class
        if logits.shape[1] == 1:
            return (logits > 0).long().squeeze(1)
        return logits.argmax(dim=1)

    def _selected_scores(self, logits, target_class):
        """Per-sample scalar scores used for backprop / scoring.

        Single-logit heads ([N, 1]) use their lone column; multi-class heads
        select the logit at target_class for each row. Indexing uses
        logits.shape[0] so batch-1 forward passes (occlusion) work too.
        """
        if logits.shape[1] == 1:
            return logits.squeeze(1)
        return logits[torch.arange(logits.shape[0]), target_class]
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .base import BaseTorchModel
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
class CNNClassifier(BaseTorchModel):
    """1-D CNN classifier with Grad-CAM attribution support.

    forward() expects 2D input [B, L] and adds a channel dimension itself.
    With num_classes == 2 the head emits a single logit; otherwise one logit
    per class. Forward/backward hooks on one conv-stack layer cache the
    activations and gradients needed for Grad-CAM.
    """
    def __init__(
        self,
        input_size,
        num_classes=2,
        conv_channels=[16, 32],
        kernel_sizes=[3, 3],
        fc_dims=[64],
        use_batchnorm=False,
        use_pooling=False,
        dropout=0.2,
        gradcam_layer_idx=-1,
        **kwargs
    ):
        # input_size: sequence length L, used to size the flattened FC input.
        # gradcam_layer_idx: index into self.conv's flat child-module list
        # selecting the layer whose activations/gradients feed Grad-CAM.
        # NOTE(review): conv_channels/kernel_sizes/fc_dims are mutable default
        # arguments (shared across calls); harmless here only because they are
        # never mutated — confirm before refactoring.
        super().__init__(**kwargs)
        self.name = "CNNClassifier"

        # Normalize input: allow a single int kernel size for all conv layers.
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * len(conv_channels)
        assert len(conv_channels) == len(kernel_sizes)

        layers = []
        in_channels = 1

        # Build conv layers: Conv1d -> [BatchNorm] -> ReLU -> [MaxPool] -> [Dropout]
        for out_channels, ksize in zip(conv_channels, kernel_sizes):
            # padding = ksize // 2 preserves sequence length for odd kernels.
            layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU())
            if use_pooling:
                layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_channels = out_channels

        self.conv = nn.Sequential(*layers)

        # Determine flattened size by tracing a dummy input through the conv
        # stack (robust to pooling/padding choices above).
        with torch.no_grad():
            dummy = torch.zeros(1, 1, input_size)
            conv_out = self.conv(dummy)
            flattened_size = conv_out.view(1, -1).shape[1]

        # Build FC layers
        fc_layers = []
        in_dim = flattened_size
        for dim in fc_dims:
            fc_layers.append(nn.Linear(in_dim, dim))
            fc_layers.append(nn.ReLU())
            if dropout > 0:
                fc_layers.append(nn.Dropout(dropout))
            in_dim = dim

        # Binary classification uses a single-logit head.
        output_size = 1 if num_classes == 2 else num_classes
        fc_layers.append(nn.Linear(in_dim, output_size))

        self.fc = nn.Sequential(*fc_layers)

        # Build gradcam hooks.
        # NOTE(review): with the default gradcam_layer_idx=-1 and dropout > 0,
        # the hooked child is the trailing nn.Dropout, not a Conv1d (identity
        # in eval mode, so activations equal the preceding layer's output) —
        # confirm this is the intended Grad-CAM target layer.
        self.gradcam_layer_idx = gradcam_layer_idx
        self.gradcam_activations = None
        self.gradcam_gradients = None
        # Guard against double registration (no-op on a fresh instance since
        # __init__ runs once, but protects subclasses that re-enter).
        if not hasattr(self, "_hooks_registered"):
            self._register_gradcam_hooks()
            self._hooks_registered = True

    def forward(self, x):
        """Forward pass: x [B, L] -> logits [B, 1] or [B, num_classes]."""
        x = x.unsqueeze(1)  # [B, 1, L]
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten per-sample conv features
        return self.fc(x)

    def _register_gradcam_hooks(self):
        """Attach forward/backward hooks on the Grad-CAM target layer that
        cache its activations and output-gradients on every pass."""
        def forward_hook(module, input, output):
            self.gradcam_activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradcam_gradients = grad_output[0].detach()

        target_layer = list(self.conv.children())[self.gradcam_layer_idx]
        target_layer.register_forward_hook(forward_hook)
        target_layer.register_full_backward_hook(backward_hook)

    def compute_gradcam(self, x, class_idx=None):
        """Compute per-sample Grad-CAM maps.

        x : [B, L] input batch; class_idx optionally fixes the class whose
        score is backpropagated (default: predicted class).
        Returns a [B, L'] tensor, ReLU'd and max-normalized per sample, where
        L' is the hooked layer's output length (shorter than L with pooling).
        """
        self.zero_grad()

        x = x.detach().clone().requires_grad_().to(self.device)

        was_training = self.training
        self.eval()  # disable dropout etc.

        output = self.forward(x)  # shape (B, C) or (B, 1)

        if class_idx is None:
            # For a (B, 1) head this yields zeros, but the value is unused in
            # that branch below.
            class_idx = output.argmax(dim=1)

        if output.shape[1] == 1:
            target = output.view(-1)  # shape (B,)
        else:
            target = output[torch.arange(output.shape[0]), class_idx]

        # Backward pass populates gradcam_gradients via the registered hook.
        target.sum().backward(retain_graph=True)

        # restore training mode
        if was_training:
            self.train()

        # get activations and gradients (set these via forward hook!)
        activations = self.gradcam_activations  # (B, C, L)
        gradients = self.gradcam_gradients  # (B, C, L)

        # Grad-CAM channel weights: gradients global-average-pooled over length.
        weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
        cam = (weights * activations).sum(dim=1)  # (B, L)

        cam = torch.relu(cam)
        # Per-sample max normalization; epsilon guards all-zero maps.
        cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)

        return cam

    def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
        """Run Grad-CAM over a dataloader and store the stacked maps in
        adata.obsm[obsm_key]."""
        self.to(device)
        self.eval()
        cams = []

        for batch in dataloader:
            x = batch[0].to(device)
            cam_batch = self.compute_gradcam(x)
            cams.append(cam_batch.cpu().numpy())

        cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
        adata.obsm[obsm_key] = cams
|