smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/inference/inference_utils.py

@@ -1,5 +1,6 @@
 import pandas as pd

+
 def annotate_split_column(adata, model, split_col="split"):
     """
     Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.
@@ -8,7 +9,7 @@ def annotate_split_column(adata, model, split_col="split"):
     train_set = set(model.train_obs_names)
     val_set = set(model.val_obs_names)
     test_set = set(model.test_obs_names)
-
+
     # Create array for split labels
     split_labels = []
     for obs in adata.obs_names:
@@ -20,8 +21,10 @@ def annotate_split_column(adata, model, split_col="split"):
             split_labels.append("testing")
         else:
             split_labels.append("new")
-
+
     # Store in AnnData.obs
-    adata.obs[split_col] = pd.Categorical(
-
+    adata.obs[split_col] = pd.Categorical(
+        split_labels, categories=["training", "validation", "testing", "new"]
+    )
+
     print(f"Annotated {split_col} column with training/validation/testing/new status.")
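
A minimal usage sketch for the helper above. It relies only on what the diff shows: the wrapper must expose train_obs_names, val_obs_names, and test_obs_names, and the function writes a categorical split column to adata.obs. The toy AnnData and the SimpleNamespace stand-in are illustrative placeholders, not smftools objects.

    import anndata as ad
    import numpy as np
    from types import SimpleNamespace

    from smftools.machine_learning.inference.inference_utils import annotate_split_column

    # Toy AnnData with six reads; in a real workflow this comes from smftools' load step.
    adata = ad.AnnData(np.zeros((6, 4)))
    adata.obs_names = [f"read{i}" for i in range(6)]

    # Stand-in for a trained wrapper: only the stored obs_names matter here.
    model = SimpleNamespace(
        train_obs_names=["read0", "read1", "read2"],
        val_obs_names=["read3"],
        test_obs_names=["read4"],
    )

    annotate_split_column(adata, model, split_col="split")
    print(adata.obs["split"].value_counts())  # "read5" falls into the "new" category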
smftools/machine_learning/inference/lightning_inference.py

@@ -1,17 +1,11 @@
-import torch
-import pandas as pd
 import numpy as np
-
+import pandas as pd
+import torch
+
 from .inference_utils import annotate_split_column

-def run_lightning_inference(
-    adata,
-    model,
-    datamodule,
-    trainer,
-    prefix="model",
-    devices=1
-):
+
+def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
     """
     Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).
     """
@@ -57,7 +51,9 @@ def run_lightning_inference(
     full_prefix = f"{prefix}_{label_col}"

     adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+        pred_class_labels, categories=class_labels
+    )
     adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

     for i, class_name in enumerate(class_labels):
@@ -65,4 +61,4 @@ def run_lightning_inference(

     adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all

-    print(f"Inference complete: stored under prefix '{full_prefix}'")
+    print(f"Inference complete: stored under prefix '{full_prefix}'")
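
A hedged usage sketch for the consolidated signature above. The wrapper, datamodule, and trainer objects are assumed to already exist (for example from smftools' training step); only the call pattern and the output keys come from the code in the diff, and the "cnn" prefix is an arbitrary example.

    from smftools.machine_learning.inference.lightning_inference import run_lightning_inference

    # Assumed to exist already: `adata`, a trained TorchClassifierWrapper (`model`),
    # an AnnDataModule built with inference_mode=True (`datamodule`), and a Lightning `trainer`.
    run_lightning_inference(
        adata,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        prefix="cnn",
        devices=1,
    )

    # Results land under "{prefix}_{label_col}" keys, e.g. for label_col "activity_status":
    #   adata.obs["cnn_activity_status_pred"], ..._pred_label, ..._pred_prob
    #   adata.obsm["cnn_activity_status_pred_prob_all"]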
smftools/machine_learning/inference/sklearn_inference.py

@@ -1,14 +1,10 @@
-import pandas as pd
 import numpy as np
+import pandas as pd
+
 from .inference_utils import annotate_split_column


-def run_sklearn_inference(
-    adata,
-    model,
-    datamodule,
-    prefix="model"
-):
+def run_sklearn_inference(adata, model, datamodule, prefix="model"):
     """
     Run inference on AnnData using SklearnModelWrapper.
     """
@@ -44,7 +40,9 @@ def run_sklearn_inference(
     full_prefix = f"{prefix}_{label_col}"

     adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+        pred_class_labels, categories=class_labels
+    )
     adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs

     for i, class_name in enumerate(class_labels):
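
The scikit-learn path mirrors the Lightning one but drops the trainer and devices arguments; a minimal sketch, again assuming the wrapped model and datamodule already exist (the "logreg" prefix is just an example string):

    from smftools.machine_learning.inference.sklearn_inference import run_sklearn_inference

    # `model` is assumed to be a fitted SklearnModelWrapper and `datamodule`
    # an AnnDataModule in inference mode; both are placeholders here.
    run_sklearn_inference(adata, model=model, datamodule=datamodule, prefix="logreg")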
smftools/machine_learning/inference/sliding_window_inference.py

@@ -3,16 +3,17 @@ from ..evaluation import PostInferenceModelEvaluator
 from .lightning_inference import run_lightning_inference
 from .sklearn_inference import run_sklearn_inference

+
 def sliding_window_inference(
-    adata,
-    trained_results,
-    tensor_source=
+    adata,
+    trained_results,
+    tensor_source="X",
     tensor_key=None,
-    label_col=
+    label_col="activity_status",
     batch_size=64,
     cleanup=False,
-    target_eval_freq=None,
-    max_eval_positive=None
+    target_eval_freq=None,
+    max_eval_positive=None,
 ):
     """
     Apply trained sliding window models to an AnnData object (Lightning or Sklearn).
@@ -24,11 +25,11 @@ def sliding_window_inference(
         for window_size, window_data in model_dict.items():
             for center_varname, run in window_data.items():
                 print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")
-
+
                 # Extract window start from varname
                 center_idx = adata.var_names.get_loc(center_varname)
                 window_start = center_idx - window_size // 2
-
+
                 # Build datamodule for window
                 datamodule = AnnDataModule(
                     adata,
@@ -38,31 +39,31 @@ def sliding_window_inference(
                     batch_size=batch_size,
                     window_start=window_start,
                     window_size=window_size,
-                    inference_mode=True
+                    inference_mode=True,
                 )
                 datamodule.setup()

                 # Extract model + detect type
-                model = run[
+                model = run["model"]

                 # Lightning models
-                if hasattr(run,
-                    trainer = run[
+                if hasattr(run, "trainer") or "trainer" in run:
+                    trainer = run["trainer"]
                     run_lightning_inference(
                         adata,
                         model=model,
                         datamodule=datamodule,
                         trainer=trainer,
-                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}",
                     )
-
+
                 # Sklearn models
                 else:
                     run_sklearn_inference(
                         adata,
                         model=model,
                         datamodule=datamodule,
-                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}",
                     )

     print("Inference complete across all models.")
@@ -77,27 +78,36 @@ def sliding_window_inference(
                 prefix = f"{model_name}_w{window_size}_c{center_varname}"
                 # Use full key for uniqueness
                 key = prefix
-                model_wrappers[key] = run[
+                model_wrappers[key] = run["model"]

     # Run evaluator
-    evaluator = PostInferenceModelEvaluator(
+    evaluator = PostInferenceModelEvaluator(
+        adata,
+        model_wrappers,
+        target_eval_freq=target_eval_freq,
+        max_eval_positive=max_eval_positive,
+    )
     evaluator.evaluate_all()

     # Get results
     df = evaluator.to_dataframe()

-    df[[
+    df[["model_name", "window_size", "center"]] = df["model"].str.extract(
+        r"(\w+)_w(\d+)_c(\d+)_activity_status"
+    )

     # Cast window_size and center to integers for plotting
-    df[
-    df[
+    df["window_size"] = df["window_size"].astype(int)
+    df["center"] = df["center"].astype(int)

     ## Optional cleanup:
     if cleanup:
-        prefixes = [
-
-
-
+        prefixes = [
+            f"{model_name}_w{window_size}_c{center_varname}"
+            for model_name, model_dict in trained_results.items()
+            for window_size, window_data in model_dict.items()
+            for center_varname in window_data.keys()
+        ]

         # Remove matching obs columns
         for prefix in prefixes:
@@ -111,4 +121,4 @@ def sliding_window_inference(

         print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")

-    return df
+    return df
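
A usage sketch for the sliding-window driver above. The nesting of trained_results follows the loops shown in the diff ({model_name: {window_size: {center_varname: {"model": ..., "trainer": ...}}}}, with scikit-learn entries simply lacking the "trainer" key); where that dictionary comes from is assumed to be smftools' own sliding-window training step, and adata is a placeholder.

    from smftools.machine_learning.inference.sliding_window_inference import sliding_window_inference

    # `adata` and `trained_results` are assumed to exist already.
    df = sliding_window_inference(
        adata,
        trained_results,
        tensor_source="X",
        label_col="activity_status",
        batch_size=64,
        cleanup=True,  # drop the per-window obs/obsm columns after evaluation
    )

    # One row per (model, window, center); window_size and center are already ints,
    # so the frame can go straight into a position-vs-performance plot.
    print(df[["model_name", "window_size", "center"]].head())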
smftools/machine_learning/models/__init__.py

@@ -1,9 +1,14 @@
 from .base import BaseTorchModel
-from .mlp import MLPClassifier
 from .cnn import CNNClassifier
-from .
-from .
+from .lightning_base import TorchClassifierWrapper
+from .mlp import MLPClassifier
 from .positional import PositionalEncoding
+from .rnn import RNNClassifier
+from .sklearn_models import SklearnModelWrapper
+from .transformer import (
+    BaseTransformer,
+    DANNTransformerClassifier,
+    MaskedTransformerPretrainer,
+    TransformerClassifier,
+)
 from .wrappers import ScaledModel
-from .lightning_base import TorchClassifierWrapper
-from .sklearn_models import SklearnModelWrapper
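
For reference, the import surface this reorganized __init__ exposes (names taken directly from the lines added above):

    from smftools.machine_learning.models import (
        BaseTorchModel,
        CNNClassifier,
        MLPClassifier,
        RNNClassifier,
        TransformerClassifier,
        TorchClassifierWrapper,
        SklearnModelWrapper,
        ScaledModel,
    )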
smftools/machine_learning/models/base.py

@@ -1,17 +1,20 @@
+import numpy as np
 import torch
 import torch.nn as nn
-
+
 from ..utils.device import detect_device

+
 class BaseTorchModel(nn.Module):
     """
     Minimal base class for torch models that:
     - Stores device and dropout regularization
     """
+
     def __init__(self, dropout_rate=0.0):
         super().__init__()
-        self.device = detect_device()
-        self.dropout_rate = dropout_rate
+        self.device = detect_device()  # detects available devices
+        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.

     def compute_saliency(
         self,
@@ -21,11 +24,11 @@ class BaseTorchModel(nn.Module):
         smoothgrad=False,
         smooth_samples=25,
         smooth_noise=0.1,
-        signed=True
+        signed=True,
     ):
         """
         Compute vanilla saliency or SmoothGrad saliency.
-
+
         Arguments:
         ----------
         x : torch.Tensor
@@ -43,7 +46,7 @@ class BaseTorchModel(nn.Module):
         """
         self.eval()
         x = x.clone().detach().requires_grad_(True)
-
+
         if smoothgrad:
             saliency_accum = torch.zeros_like(x)
             for i in range(smooth_samples):
@@ -56,7 +59,7 @@ class BaseTorchModel(nn.Module):
                 if logits.shape[1] == 1:
                     scores = logits.squeeze(1)
                 else:
-                    scores = logits[torch.arange(x.shape[0]), target_class]
+                    scores = logits[torch.arange(x.shape[0]), target_class]
                 scores.sum().backward()
                 saliency_accum += x_noisy.grad.detach()
             saliency = saliency_accum / smooth_samples
@@ -69,17 +72,17 @@ class BaseTorchModel(nn.Module):
                 scores = logits[torch.arange(x.shape[0]), target_class]
             scores.sum().backward()
             saliency = x.grad.detach()
-
+
         if not signed:
             saliency = saliency.abs()
-
+
         if reduction == "sum" and x.ndim == 3:
             return saliency.sum(dim=-1)
         elif reduction == "mean" and x.ndim == 3:
             return saliency.mean(dim=-1)
         else:
             return saliency
-
+
     def compute_gradient_x_input(self, x, target_class=None):
         """
         Computes gradient × input attribution.
@@ -118,22 +121,11 @@ class BaseTorchModel(nn.Module):
             baseline = torch.zeros_like(x)

         attributions, delta = ig.attribute(
-            x,
-            baselines=baseline,
-            target=target_class,
-            n_steps=steps,
-            return_convergence_delta=True
+            x, baselines=baseline, target=target_class, n_steps=steps, return_convergence_delta=True
         )
         return attributions, delta

-    def compute_deeplift(
-        self,
-        x,
-        baseline=None,
-        target_class=None,
-        reduction="sum",
-        signed=True
-    ):
+    def compute_deeplift(self, x, baseline=None, target_class=None, reduction="sum", signed=True):
         """
         Compute DeepLIFT scores using captum.

@@ -158,21 +150,15 @@ class BaseTorchModel(nn.Module):

         if not signed:
             attr = attr.abs()
-
+
         if reduction == "sum" and x.ndim == 3:
             return attr.sum(dim=-1)
         elif reduction == "mean" and x.ndim == 3:
             return attr.mean(dim=-1)
         else:
             return attr
-
-    def compute_occlusion(
-        self,
-        x,
-        target_class=None,
-        window_size=5,
-        baseline=None
-    ):
+
+    def compute_occlusion(self, x, target_class=None, window_size=5, baseline=None):
         """
         Computes per-sample occlusion attribution.
         Supports 2D [B, S] or 3D [B, S, D] inputs.
@@ -208,9 +194,7 @@ class BaseTorchModel(nn.Module):
             x_occluded[left:right, :] = baseline[left:right, :]

             x_tensor = torch.tensor(
-                x_occluded,
-                device=self.device,
-                dtype=torch.float32
+                x_occluded, device=self.device, dtype=torch.float32
             ).unsqueeze(0)

             logits = self.forward(x_tensor)
@@ -235,7 +219,7 @@ class BaseTorchModel(nn.Module):
         device="cpu",
         target_class=None,
         normalize=True,
-        signed=True
+        signed=True,
     ):
         """
         Apply a chosen attribution method to a dataloader and store results in adata.
@@ -252,7 +236,9 @@ class BaseTorchModel(nn.Module):
             attr = model.compute_saliency(x, target_class=target_class, signed=signed)

         elif method == "smoothgrad":
-            attr = model.compute_saliency(
+            attr = model.compute_saliency(
+                x, smoothgrad=True, target_class=target_class, signed=signed
+            )

         elif method == "IG":
             attributions, delta = model.compute_integrated_gradients(
@@ -261,15 +247,15 @@ class BaseTorchModel(nn.Module):
             attr = attributions

         elif method == "deeplift":
-            attr = model.compute_deeplift(
+            attr = model.compute_deeplift(
+                x, baseline=baseline, target_class=target_class, signed=signed
+            )

         elif method == "gradxinput":
             attr = model.compute_gradient_x_input(x, target_class=target_class)

         elif method == "occlusion":
-            attr = model.compute_occlusion(
-                x, target_class=target_class, baseline=baseline
-            )
+            attr = model.compute_occlusion(x, target_class=target_class, baseline=baseline)

         else:
             raise ValueError(f"Unknown method {method}")
@@ -292,4 +278,4 @@ class BaseTorchModel(nn.Module):
             return target_class
         if logits.shape[1] == 1:
             return (logits > 0).long().squeeze(1)
-        return logits.argmax(dim=1)
+        return logits.argmax(dim=1)
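
The attribution helpers reformatted above share a common calling pattern; a rough sketch follows, with the trained classifier `clf` and the encoded input `windows` assumed to exist already (the method names and keyword arguments are the ones visible in the diff).

    import torch

    # Placeholders: `clf` is a trained BaseTorchModel subclass (e.g. a CNNClassifier) and
    # `windows` a NumPy array of encoded reads, shape [batch, positions(, features)].
    x = torch.as_tensor(windows, dtype=torch.float32, device=clf.device)

    sal = clf.compute_saliency(x, signed=True)                        # vanilla gradients
    sg = clf.compute_saliency(x, smoothgrad=True, smooth_samples=25)  # SmoothGrad average
    dl = clf.compute_deeplift(x, reduction="sum")                     # captum DeepLIFT
    occ = clf.compute_occlusion(x, window_size=5)                     # occlusion scan

    # Each method returns per-position attributions aligned with the input reads;
    # pass target_class for multi-class heads to pick the class being explained.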
smftools/machine_learning/models/cnn.py

@@ -1,7 +1,9 @@
+import numpy as np
 import torch
 import torch.nn as nn
+
 from .base import BaseTorchModel
-
+

 class CNNClassifier(BaseTorchModel):
     def __init__(
@@ -15,7 +17,7 @@ class CNNClassifier(BaseTorchModel):
         use_pooling=False,
         dropout=0.2,
         gradcam_layer_idx=-1,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.name = "CNNClassifier"
@@ -30,7 +32,9 @@ class CNNClassifier(BaseTorchModel):

         # Build conv layers
         for out_channels, ksize in zip(conv_channels, kernel_sizes):
-            layers.append(
+            layers.append(
+                nn.Conv1d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2)
+            )
             if use_batchnorm:
                 layers.append(nn.BatchNorm1d(out_channels))
             layers.append(nn.ReLU())
@@ -76,7 +80,7 @@ class CNNClassifier(BaseTorchModel):
         x = self.conv(x)
         x = x.view(x.size(0), -1)
         return self.fc(x)
-
+
     def _register_gradcam_hooks(self):
         def forward_hook(module, input, output):
             self.gradcam_activations = output.detach()
@@ -97,15 +101,15 @@ class CNNClassifier(BaseTorchModel):
         self.eval()  # disable dropout etc.

         output = self.forward(x)  # shape (B, C) or (B, 1)
-
+
         if class_idx is None:
             class_idx = output.argmax(dim=1)
-
+
         if output.shape[1] == 1:
             target = output.view(-1)  # shape (B,)
         else:
             target = output[torch.arange(output.shape[0]), class_idx]
-
+
         target.sum().backward(retain_graph=True)

         # restore training mode
@@ -114,16 +118,16 @@ class CNNClassifier(BaseTorchModel):

         # get activations and gradients (set these via forward hook!)
         activations = self.gradcam_activations  # (B, C, L)
-        gradients = self.gradcam_gradients
+        gradients = self.gradcam_gradients  # (B, C, L)

         weights = gradients.mean(dim=2, keepdim=True)  # (B, C, 1)
-        cam = (weights * activations).sum(dim=1)
+        cam = (weights * activations).sum(dim=1)  # (B, L)

         cam = torch.relu(cam)
         cam = cam / (cam.max(dim=1, keepdim=True).values + 1e-6)

         return cam
-
+
     def apply_gradcam_to_adata(self, dataloader, adata, obsm_key="gradcam", device="cpu"):
         self.to(device)
         self.eval()
@@ -135,4 +139,4 @@ class CNNClassifier(BaseTorchModel):
             cams.append(cam_batch.cpu().numpy())

         cams = np.concatenate(cams, axis=0)  # shape: [n_obs, input_len]
-        adata.obsm[obsm_key] = cams
+        adata.obsm[obsm_key] = cams
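
Finally, a sketch of the Grad-CAM convenience path shown above; the trained classifier, dataloader, and AnnData are placeholders, while the method name, the obsm_key default, and the output shape come from the diff.

    # Placeholders: `cnn` is a trained CNNClassifier, `dataloader` yields the same
    # encoded windows used at inference time, and `adata` is the matching AnnData.
    cnn.apply_gradcam_to_adata(dataloader, adata, obsm_key="gradcam", device="cpu")

    cam = adata.obsm["gradcam"]  # per-read class-activation map, shape [n_obs, input_len]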