smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- smftools/__init__.py +9 -4
- smftools/_version.py +1 -1
- smftools/cli.py +184 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +33 -0
- smftools/config/deaminase.yaml +56 -0
- smftools/config/default.yaml +253 -0
- smftools/config/direct.yaml +17 -0
- smftools/config/experiment_config.py +1191 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +0 -2
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/fast5_to_pod5.py +4 -1
- smftools/informatics/helpers/__init__.py +3 -4
- smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
- smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
- smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
- smftools/informatics/helpers/discover_input_files.py +100 -0
- smftools/informatics/helpers/extract_base_identities.py +29 -3
- smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
- smftools/informatics/helpers/find_conversion_sites.py +5 -4
- smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
- smftools/informatics/helpers/plot_bed_histograms.py +269 -0
- smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
- smftools/informatics/helpers/split_and_index_BAM.py +1 -5
- smftools/load_adata.py +1346 -0
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +611 -0
- smftools/plotting/general_plotting.py +566 -89
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +13 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +849 -43
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
- smftools-0.2.1.dist-info/RECORD +161 -0
- smftools-0.2.1.dist-info/entry_points.txt +2 -0
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/evaluation/__init__.py +0 -0
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/inference/sliding_window_inference.py (new file)
@@ -0,0 +1,114 @@
+from ..data import AnnDataModule
+from ..evaluation import PostInferenceModelEvaluator
+from .lightning_inference import run_lightning_inference
+from .sklearn_inference import run_sklearn_inference
+
+def sliding_window_inference(
+    adata,
+    trained_results,
+    tensor_source='X',
+    tensor_key=None,
+    label_col='activity_status',
+    batch_size=64,
+    cleanup=False,
+    target_eval_freq=None,
+    max_eval_positive=None
+):
+    """
+    Apply trained sliding window models to an AnnData object (Lightning or Sklearn).
+    Evaluate model performance and return a df.
+    Optionally remove the appended inference columns from AnnData to clean up obs namespace.
+    """
+    ## Inference using trained models
+    for model_name, model_dict in trained_results.items():
+        for window_size, window_data in model_dict.items():
+            for center_varname, run in window_data.items():
+                print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")
+
+                # Extract window start from varname
+                center_idx = adata.var_names.get_loc(center_varname)
+                window_start = center_idx - window_size // 2
+
+                # Build datamodule for window
+                datamodule = AnnDataModule(
+                    adata,
+                    tensor_source=tensor_source,
+                    tensor_key=tensor_key,
+                    label_col=label_col,
+                    batch_size=batch_size,
+                    window_start=window_start,
+                    window_size=window_size,
+                    inference_mode=True
+                )
+                datamodule.setup()
+
+                # Extract model + detect type
+                model = run['model']
+
+                # Lightning models
+                if hasattr(run, 'trainer') or 'trainer' in run:
+                    trainer = run['trainer']
+                    run_lightning_inference(
+                        adata,
+                        model=model,
+                        datamodule=datamodule,
+                        trainer=trainer,
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                    )
+
+                # Sklearn models
+                else:
+                    run_sklearn_inference(
+                        adata,
+                        model=model,
+                        datamodule=datamodule,
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                    )
+
+    print("Inference complete across all models.")
+
+    ## Post-inference model evaluation
+    model_wrappers = {}
+
+    for model_name, model_dict in trained_results.items():
+        for window_size, window_data in model_dict.items():
+            for center_varname, run in window_data.items():
+                # Reconstruct the prefix string used in inference
+                prefix = f"{model_name}_w{window_size}_c{center_varname}"
+                # Use full key for uniqueness
+                key = prefix
+                model_wrappers[key] = run['model']
+
+    # Run evaluator
+    evaluator = PostInferenceModelEvaluator(adata, model_wrappers, target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive)
+    evaluator.evaluate_all()
+
+    # Get results
+    df = evaluator.to_dataframe()
+
+    df[['model_name', 'window_size', 'center']] = df['model'].str.extract(r'(\w+)_w(\d+)_c(\d+)_activity_status')
+
+    # Cast window_size and center to integers for plotting
+    df['window_size'] = df['window_size'].astype(int)
+    df['center'] = df['center'].astype(int)
+
+    ## Optional cleanup:
+    if cleanup:
+        prefixes = [f"{model_name}_w{window_size}_c{center_varname}"
+                    for model_name, model_dict in trained_results.items()
+                    for window_size, window_data in model_dict.items()
+                    for center_varname in window_data.keys()]
+
+        # Remove matching obs columns
+        for prefix in prefixes:
+            to_remove = [col for col in adata.obs.columns if col.startswith(prefix)]
+            adata.obs.drop(columns=to_remove, inplace=True)

+            # Remove obsm entries if any
+            obsm_key = f"{prefix}_pred_prob_all"
+            if obsm_key in adata.obsm:
+                del adata.obsm[obsm_key]
+
+        print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")
+
+    return df
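For orientation, a hypothetical usage sketch (not part of the diff). The nested `trained_results` layout is inferred from the loops above: model_name → window_size → center var_name → run dict, where a `'trainer'` key in the run dict routes inference through Lightning and its absence through sklearn. Note the result-parsing regex (`_c(\d+)_activity_status`) expects numeric center var_names and the default `label_col`. `cnn_model` and `trainer` below are placeholders for objects produced by a prior training run:

import anndata as ad
from smftools.machine_learning.inference.sliding_window_inference import sliding_window_inference

adata = ad.read_h5ad("reads.h5ad")  # placeholder dataset

# Layout inferred from the loops above; leaf values are placeholders.
trained_results = {
    "cnn": {                        # model_name
        15: {                       # window_size, in positions
            "1000": {               # center var_name (numeric, per the regex above)
                "model": cnn_model,    # trained model object (placeholder)
                "trainer": trainer,    # present -> Lightning inference path
            },
        },
    },
}

df = sliding_window_inference(adata, trained_results, cleanup=True)
print(df[["model_name", "window_size", "center"]])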
smftools/machine_learning/models/base.py (new file)
@@ -0,0 +1,295 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from ..utils.device import detect_device
+
+class BaseTorchModel(nn.Module):
+    """
+    Minimal base class for torch models that:
+    - Stores device and dropout regularization
+    """
+    def __init__(self, dropout_rate=0.0):
+        super().__init__()
+        self.device = detect_device()  # detects available devices
+        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
+
+    def compute_saliency(
+        self,
+        x,
+        target_class=None,
+        reduction="sum",
+        smoothgrad=False,
+        smooth_samples=25,
+        smooth_noise=0.1,
+        signed=True
+    ):
+        """
+        Compute vanilla saliency or SmoothGrad saliency.
+
+        Arguments:
+        ----------
+        x : torch.Tensor
+            Input tensor [B, S, D].
+        target_class : int or list, optional
+            If None, uses model predicted class.
+        reduction : str
+            'sum' or 'mean' across channels
+        smoothgrad : bool
+            Whether to apply SmoothGrad.
+        smooth_samples : int
+            Number of noisy samples for SmoothGrad
+        smooth_noise : float
+            Standard deviation of noise added to input
+        """
+        self.eval()
+        x = x.clone().detach().requires_grad_(True)
+
+        if smoothgrad:
+            saliency_accum = torch.zeros_like(x)
+            for i in range(smooth_samples):
+                noise = torch.normal(mean=0.0, std=smooth_noise, size=x.shape).to(x.device)
+                x_noisy = x + noise
+                x_noisy.requires_grad_(True)
+                x_noisy.retain_grad()  # non-leaf tensor: retain grad so .grad is populated
+                logits = self.forward(x_noisy)
+                target_class = self._resolve_target_class(logits, target_class)
+                if logits.shape[1] == 1:
+                    scores = logits.squeeze(1)
+                else:
+                    scores = logits[torch.arange(x.shape[0]), target_class]
+                scores.sum().backward()
+                saliency_accum += x_noisy.grad.detach()
+            saliency = saliency_accum / smooth_samples
+        else:
+            logits = self.forward(x)
+            target_class = self._resolve_target_class(logits, target_class)
+            if logits.shape[1] == 1:
+                scores = logits.squeeze(1)
+            else:
+                scores = logits[torch.arange(x.shape[0]), target_class]
+            scores.sum().backward()
+            saliency = x.grad.detach()
+
+        if not signed:
+            saliency = saliency.abs()
+
+        if reduction == "sum" and x.ndim == 3:
+            return saliency.sum(dim=-1)
+        elif reduction == "mean" and x.ndim == 3:
+            return saliency.mean(dim=-1)
+        else:
+            return saliency
+
+    def compute_gradient_x_input(self, x, target_class=None):
+        """
+        Computes gradient × input attribution.
+        """
+        x = x.clone().detach().requires_grad_(True)
+        logits = self.forward(x)
+        target_class = self._resolve_target_class(logits, target_class)
+        if logits.shape[1] == 1:
+            scores = logits.squeeze(1)
+        else:
+            scores = logits[torch.arange(x.shape[0]), target_class]
+        scores.sum().backward()
+        grads = x.grad
+        return grads * x
+
+    def compute_integrated_gradients(self, x, target_class=None, steps=50, baseline=None):
+        """
+        Compute Integrated Gradients for a batch of x.
+        If target=None, uses the predicted class.
+        Returns: [B, seq_len, channels] attribution tensor
+        """
+        from captum.attr import IntegratedGradients
+
+        ig = IntegratedGradients(self)
+        self.eval()
+        x = x.requires_grad_(True)
+
+        with torch.no_grad():
+            logits = self.forward(x)
+            if logits.shape[1] == 1:
+                target_class = 0  # only one column exists, representing class 1 logit
+            else:
+                target_class = self._resolve_target_class(logits, target_class)
+
+        if baseline is None:
+            baseline = torch.zeros_like(x)
+
+        attributions, delta = ig.attribute(
+            x,
+            baselines=baseline,
+            target=target_class,
+            n_steps=steps,
+            return_convergence_delta=True
+        )
+        return attributions, delta
+
+    def compute_deeplift(
+        self,
+        x,
+        baseline=None,
+        target_class=None,
+        reduction="sum",
+        signed=True
+    ):
+        """
+        Compute DeepLIFT scores using captum.
+
+        baseline:
+            reference input for DeepLIFT.
+        """
+        from captum.attr import DeepLift
+
+        self.eval()
+        deeplift = DeepLift(self)
+
+        logits = self.forward(x)
+        if logits.shape[1] == 1:
+            target_class = 0  # only one column exists, representing class 1 logit
+        else:
+            target_class = self._resolve_target_class(logits, target_class)
+
+        if baseline is None:
+            baseline = torch.zeros_like(x)
+
+        attr = deeplift.attribute(x, target=target_class, baselines=baseline)
+
+        if not signed:
+            attr = attr.abs()
+
+        if reduction == "sum" and x.ndim == 3:
+            return attr.sum(dim=-1)
+        elif reduction == "mean" and x.ndim == 3:
+            return attr.mean(dim=-1)
+        else:
+            return attr
+
+    def compute_occlusion(
+        self,
+        x,
+        target_class=None,
+        window_size=5,
+        baseline=None
+    ):
+        """
+        Computes per-sample occlusion attribution.
+        Supports 2D [B, S] or 3D [B, S, D] inputs.
+        Returns: [B, S] occlusion scores
+        """
+        self.eval()
+
+        x_np = x.detach().cpu().numpy()
+        ndim = x_np.ndim
+        if ndim == 2:
+            B, S = x_np.shape
+            D = 1
+        elif ndim == 3:
+            B, S, D = x_np.shape
+        else:
+            raise ValueError(f"Unsupported input shape {x_np.shape}")
+
+        # if no baseline provided, fallback to mean
+        if baseline is None:
+            baseline = np.mean(x_np, axis=0)
+
+        occlusion_scores = np.zeros((B, S))
+
+        for b in range(B):
+            for i in range(S):
+                x_occluded = x_np[b].copy()
+                left = max(0, i - window_size // 2)
+                right = min(S, i + window_size // 2)
+
+                if ndim == 2:
+                    x_occluded[left:right] = baseline[left:right]
+                else:
+                    x_occluded[left:right, :] = baseline[left:right, :]
+
+                x_tensor = torch.tensor(
+                    x_occluded,
+                    device=self.device,
+                    dtype=torch.float32
+                ).unsqueeze(0)
+
+                logits = self.forward(x_tensor)
+                target_class = self._resolve_target_class(logits, target_class)
+
+                if logits.shape[1] == 1:
+                    scores = logits.squeeze(1)
+                else:
+                    scores = logits[torch.arange(x_tensor.shape[0]), target_class]  # batch of one occluded sample
+
+                occlusion_scores[b, i] = scores.mean().item()
+
+        return occlusion_scores
+
+    def apply_attributions_to_adata(
+        model,
+        dataloader,
+        adata,
+        method="saliency",  # saliency, smoothgrad, IG, deeplift, gradxinput, occlusion
+        adata_key="attributions",
+        baseline=None,
+        device="cpu",
+        target_class=None,
+        normalize=True,
+        signed=True
+    ):
+        """
+        Apply a chosen attribution method to a dataloader and store results in adata.
+        """
+
+        results = []
+        model.to(device)
+        model.eval()
+
+        for batch in dataloader:
+            x = batch[0].to(device)
+
+            if method == "saliency":
+                attr = model.compute_saliency(x, target_class=target_class, signed=signed)
+
+            elif method == "smoothgrad":
+                attr = model.compute_saliency(x, smoothgrad=True, target_class=target_class, signed=signed)
+
+            elif method == "IG":
+                attributions, delta = model.compute_integrated_gradients(
+                    x, target_class=target_class, baseline=baseline
+                )
+                attr = attributions
+
+            elif method == "deeplift":
+                attr = model.compute_deeplift(x, baseline=baseline, target_class=target_class, signed=signed)
+
+            elif method == "gradxinput":
+                attr = model.compute_gradient_x_input(x, target_class=target_class)
+
+            elif method == "occlusion":
+                attr = model.compute_occlusion(
+                    x, target_class=target_class, baseline=baseline
+                )
+
+            else:
+                raise ValueError(f"Unknown method {method}")
+
+            # ensure numpy
+            attr = attr.detach().cpu().numpy() if torch.is_tensor(attr) else attr
+            results.append(attr)
+
+        results_stacked = np.concatenate(results, axis=0)
+        adata.obsm[adata_key] = results_stacked
+
+        if normalize:
+            row_max = np.max(np.abs(results_stacked), axis=1, keepdims=True)
+            row_max[row_max == 0] = 1  # avoid divide by zero
+            results_normalized = results_stacked / row_max
+            adata.obsm[adata_key + "_normalized"] = results_normalized
+
+    def _resolve_target_class(self, logits, target_class):
+        if target_class is not None:
+            return target_class
+        if logits.shape[1] == 1:
+            return (logits > 0).long().squeeze(1)
+        return logits.argmax(dim=1)
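As a quick illustration of the attribution API above, a minimal sketch assuming the wheel is installed; `TinyClassifier` is a toy subclass written for this example, not part of the package:

import torch
import torch.nn as nn
from smftools.machine_learning.models.base import BaseTorchModel

class TinyClassifier(BaseTorchModel):
    # Toy binary classifier over flat [B, S] inputs (single-logit output).
    def __init__(self, seq_len=20):
        super().__init__()
        self.net = nn.Linear(seq_len, 1)

    def forward(self, x):
        return self.net(x)  # [B, 1]

model = TinyClassifier(seq_len=20)
x = torch.randn(8, 20)  # batch of 8 reads, 20 positions

# SmoothGrad saliency: averages gradients over 10 noisy copies of x.
sal = model.compute_saliency(x, smoothgrad=True, smooth_samples=10)
print(sal.shape)  # torch.Size([8, 20]) -- per-position attribution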
smftools/machine_learning/models/cnn.py (new file)
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+from .base import BaseTorchModel
+import numpy as np
+
+class CNNClassifier(BaseTorchModel):
+    def __init__(
+        self,
+        input_size,
+        num_classes=2,
+        conv_channels=[16, 32],
+        kernel_sizes=[3, 3],
+        fc_dims=[64],
+        use_batchnorm=False,
+        use_pooling=False,
+        dropout=0.2,
+        gradcam_layer_idx=-1,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.name = "CNNClassifier"
+
+        # Normalize input
+        if isinstance(kernel_sizes, int):
+            kernel_sizes = [kernel_sizes] * len(conv_channels)
+        assert len(conv_channels) == len(kernel_sizes)
+
+        layers = []
+        in_channels = 1
+
+        # Build conv layers
+        for out_channels, ksize in zip(conv_channels, kernel_sizes):
+            layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
+            if use_batchnorm:
+                layers.append(nn.BatchNorm1d(out_channels))
+            layers.append(nn.ReLU())
+            if use_pooling:
+                layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
+            if dropout > 0:
+                layers.append(nn.Dropout(dropout))
+            in_channels = out_channels
+
+        self.conv = nn.Sequential(*layers)
+
+        # Determine flattened size
+        with torch.no_grad():
+            dummy = torch.zeros(1, 1, input_size)
+            conv_out = self.conv(dummy)
+            flattened_size = conv_out.view(1, -1).shape[1]
+
+        # Build FC layers
+        fc_layers = []
+        in_dim = flattened_size
+        for dim in fc_dims:
+            fc_layers.append(nn.Linear(in_dim, dim))
+            fc_layers.append(nn.ReLU())
+            if dropout > 0:
+                fc_layers.append(nn.Dropout(dropout))
+            in_dim = dim
+
+        output_size = 1 if num_classes == 2 else num_classes
+        fc_layers.append(nn.Linear(in_dim, output_size))
+
+        self.fc = nn.Sequential(*fc_layers)
+
+        # Build gradcam hooks
+        self.gradcam_layer_idx = gradcam_layer_idx
+        self.gradcam_activations = None
+        self.gradcam_gradients = None
+        if not hasattr(self, "_hooks_registered"):
+            self._register_gradcam_hooks()
+            self._hooks_registered = True
+
+    def forward(self, x):
+        x = x.unsqueeze(1)  # [B, 1, L]
+        x = self.conv(x)
+        x = x.view(x.size(0), -1)
+        return self.fc(x)
+
+    def _register_gradcam_hooks(self):
+        def forward_hook(module, input, output):
+            self.gradcam_activations = output.detach()
+
+        def backward_hook(module, grad_input, grad_output):
+            self.gradcam_gradients = grad_output[0].detach()
+
+        target_layer = list(self.conv.children())[self.gradcam_layer_idx]
+        target_layer.register_forward_hook(forward_hook)
+        target_layer.register_full_backward_hook(backward_hook)
+
+    def compute_gradcam(self, x, class_idx=None):
+        self.zero_grad()
+
+        x = x.detach().clone().requires_grad_().to(self.device)
+
+        was_training = self.training
+        self.eval()  # disable dropout etc.
+
+        output = self.forward(x)  # shape (B, C) or (B, 1)
+
+        if class_idx is None:
+            class_idx = output.argmax(dim=1)
+
+        if output.shape[1] == 1:
+            target = output.view(-1)  # shape (B,)
+        else:
+            target = output[torch.arange(output.shape[0]), class_idx]
+
+        target.sum().backward(retain_graph=True)
+
+        # restore training mode
+        if was_training:
+            self.train()
+
+        # get activations and gradients (set these via forward hook!)
+        activations = self.gradcam_activations  # (B, C, L)
+        gradients = self.gradcam_gradients      # (B, C, L)
+
+        weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
+        cam = (weights * activations).sum(dim=1)       # (B, L)
+
+        cam = torch.relu(cam)
+        cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)
+
+        return cam
+
+    def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
+        self.to(device)
+        self.eval()
+        cams = []
+
+        for batch in dataloader:
+            x = batch[0].to(device)
+            cam_batch = self.compute_gradcam(x)
+            cams.append(cam_batch.cpu().numpy())
+
+        cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
+        adata.obsm[obsm_key] = cams
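A similar sketch for the Grad-CAM hooks above (again assuming the wheel is installed; the input length and channel counts are arbitrary choices for this example):

import torch
from smftools.machine_learning.models.cnn import CNNClassifier

model = CNNClassifier(input_size=50, num_classes=2,
                      conv_channels=[8, 16], kernel_sizes=3)
model.to(model.device)  # compute_gradcam moves inputs to model.device

x = torch.randn(4, 50)           # [B, L]; forward() adds the channel axis
cam = model.compute_gradcam(x)   # [B, L] ReLU'd, per-read max-normalized map
print(cam.shape)                 # torch.Size([4, 50])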