smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0

smftools/machine_learning/evaluation/evaluators.py

@@ -1,15 +1,26 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
 
-from
-
-)
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
+
 
 class ModelEvaluator:
     """
     A model evaluator for consolidating Sklearn and Lightning model evaluation metrics on testing data
     """
+
     def __init__(self):
         self.results = []
         self.pos_freq = None
@@ -21,41 +32,45 @@ class ModelEvaluator:
         """
         if is_torch:
             entry = {
-
-
-
-
-
-
-
-
-
+                "name": name,
+                "f1": model.test_f1,
+                "auc": model.test_roc_auc,
+                "pr_auc": model.test_pr_auc,
+                "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                if model.test_pos_freq > 0
+                else np.nan,
+                "pr_curve": model.test_pr_curve,
+                "roc_curve": model.test_roc_curve,
+                "num_pos": model.test_num_pos,
+                "pos_freq": model.test_pos_freq,
             }
         else:
             entry = {
-
-
-
-
-
-
-
-
-
+                "name": name,
+                "f1": model.test_f1,
+                "auc": model.test_roc_auc,
+                "pr_auc": model.test_pr_auc,
+                "pr_auc_norm": model.test_pr_auc / model.test_pos_freq
+                if model.test_pos_freq > 0
+                else np.nan,
+                "pr_curve": model.test_pr_curve,
+                "roc_curve": model.test_roc_curve,
+                "num_pos": model.test_num_pos,
+                "pos_freq": model.test_pos_freq,
             }
-
+
         self.results.append(entry)
 
         if not self.pos_freq:
-            self.pos_freq = entry[
-            self.num_pos = entry[
+            self.pos_freq = entry["pos_freq"]
+            self.num_pos = entry["num_pos"]
 
     def get_metrics_dataframe(self):
         """
         Return all metrics as pandas DataFrame.
         """
         df = pd.DataFrame(self.results)
-        return df[[
+        return df[["name", "f1", "auc", "pr_auc", "pr_auc_norm", "num_pos", "pos_freq"]]
 
     def plot_all_curves(self):
         """
@@ -66,30 +81,31 @@ class ModelEvaluator:
         # ROC
         plt.subplot(1, 2, 1)
         for res in self.results:
-            fpr, tpr = res[
+            fpr, tpr = res["roc_curve"]
            plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
         plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
         plt.xlabel("False Positive Rate")
         plt.ylabel("True Positive Rate")
-        plt.ylim(0,1.05)
+        plt.ylim(0, 1.05)
         plt.title(f"ROC Curves - {self.num_pos} positive instances")
         plt.legend()
 
         # PR
         plt.subplot(1, 2, 2)
         for res in self.results:
-            rc, pr = res[
+            rc, pr = res["pr_curve"]
            plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
-        plt.ylim(0,1.05)
-        plt.axhline(self.pos_freq, linestyle=
+        plt.ylim(0, 1.05)
+        plt.axhline(self.pos_freq, linestyle="--", color="grey")
         plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
         plt.legend()
 
         plt.tight_layout()
         plt.show()
 
+
 class PostInferenceModelEvaluator:
     def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
         """
@@ -179,12 +195,14 @@ class PostInferenceModelEvaluator:
            "pos_freq": pos_freq,
            "confusion_matrix": cm,
            "pr_rc_curve": (pr, rc),
-            "roc_curve": (tpr, fpr)
+            "roc_curve": (tpr, fpr),
         }
 
         return metrics
-
-    def _subsample_for_fixed_positive_frequency(
+
+    def _subsample_for_fixed_positive_frequency(
+        self, binary_labels, target_freq=0.3, max_positive=None
+    ):
         pos_idx = np.where(binary_labels == 1)[0]
         neg_idx = np.where(binary_labels == 0)[0]
 
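
Several of the rewritten modules in this release stop importing matplotlib, scikit-learn, and torch at module load time and instead go through smftools.optional_imports.require (see the new smftools/optional_imports.py in the file list above). Only the call signature is visible in this diff; the snippet below is a minimal, hypothetical sketch of what such a lazy-import helper could look like, not the actual implementation.

# Hypothetical sketch of a require() helper matching the call sites above;
# the real smftools/optional_imports.py may differ.
import importlib


def require(module_name, extra=None, purpose=None):
    """Import an optional dependency, or fail with a hint naming the extra that provides it."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        why = f" (needed for {purpose})" if purpose else ""
        hint = f" Install it with: pip install 'smftools[{extra}]'" if extra else ""
        raise ImportError(
            f"Optional dependency '{module_name}'{why} is not installed.{hint}"
        ) from err


# Mirrors the usage in evaluators.py above:
# plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")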

smftools/machine_learning/inference/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .lightning_inference import run_lightning_inference
+from .sklearn_inference import run_sklearn_inference
 from .sliding_window_inference import sliding_window_inference
-from .sklearn_inference import run_sklearn_inference

smftools/machine_learning/inference/inference_utils.py

@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import pandas as pd
 
+
 def annotate_split_column(adata, model, split_col="split"):
     """
     Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.
@@ -8,7 +11,7 @@ def annotate_split_column(adata, model, split_col="split"):
     train_set = set(model.train_obs_names)
     val_set = set(model.val_obs_names)
     test_set = set(model.test_obs_names)
-
+
     # Create array for split labels
     split_labels = []
     for obs in adata.obs_names:
@@ -20,8 +23,10 @@ def annotate_split_column(adata, model, split_col="split"):
             split_labels.append("testing")
         else:
             split_labels.append("new")
-
+
     # Store in AnnData.obs
-    adata.obs[split_col] = pd.Categorical(
-
+    adata.obs[split_col] = pd.Categorical(
+        split_labels, categories=["training", "validation", "testing", "new"]
+    )
+
     print(f"Annotated {split_col} column with training/validation/testing/new status.")
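
The annotate_split_column helper above only needs an AnnData object and a model wrapper that recorded train_obs_names, val_obs_names, and test_obs_names. A small, hypothetical usage sketch follows, assuming smftools 0.3.0 and anndata are installed; the read names and the SimpleNamespace stand-in for a trained wrapper are illustrative only.

# Hypothetical usage sketch; the stand-in "model" mimics a trained wrapper that
# stored the obs_names it was fit on, as read in the diff above.
from types import SimpleNamespace

import anndata as ad
import numpy as np
import pandas as pd

from smftools.machine_learning.inference.inference_utils import annotate_split_column

adata = ad.AnnData(
    X=np.zeros((4, 3)),
    obs=pd.DataFrame(index=["read1", "read2", "read3", "read4"]),
)
model = SimpleNamespace(
    train_obs_names=["read1", "read2"],
    val_obs_names=["read3"],
    test_obs_names=[],
)

annotate_split_column(adata, model, split_col="split")
# value_counts includes all four categories; here training=2, validation=1, new=1, testing=0
print(adata.obs["split"].value_counts())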

smftools/machine_learning/inference/lightning_inference.py

@@ -1,17 +1,16 @@
-import
-
+from __future__ import annotations
+
 import numpy as np
-
+import pandas as pd
+
+from smftools.optional_imports import require
+
 from .inference_utils import annotate_split_column
 
-
-
-
-
-    trainer,
-    prefix="model",
-    devices=1
-):
+torch = require("torch", extra="ml-base", purpose="Lightning inference")
+
+
+def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
     """
     Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).
     """
@@ -57,7 +56,9 @@ def run_lightning_inference(
     full_prefix = f"{prefix}_{label_col}"
 
     adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+        pred_class_labels, categories=class_labels
+    )
     adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
 
     for i, class_name in enumerate(class_labels):
@@ -65,4 +66,4 @@ def run_lightning_inference(
 
     adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all
 
-    print(f"Inference complete: stored under prefix '{full_prefix}'")
+    print(f"Inference complete: stored under prefix '{full_prefix}'")

smftools/machine_learning/inference/sklearn_inference.py

@@ -1,14 +1,12 @@
-
+from __future__ import annotations
+
 import numpy as np
+import pandas as pd
+
 from .inference_utils import annotate_split_column
 
 
-def run_sklearn_inference(
-    adata,
-    model,
-    datamodule,
-    prefix="model"
-):
+def run_sklearn_inference(adata, model, datamodule, prefix="model"):
     """
     Run inference on AnnData using SklearnModelWrapper.
     """
@@ -44,7 +42,9 @@ def run_sklearn_inference(
     full_prefix = f"{prefix}_{label_col}"
 
     adata.obs[f"{full_prefix}_pred"] = pred_class_idx
-    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(
+        pred_class_labels, categories=class_labels
+    )
     adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
 
     for i, class_name in enumerate(class_labels):
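
Both inference runners above now build their prediction-label column with pd.Categorical(pred_class_labels, categories=class_labels). The standalone pandas illustration below shows why the categories are pinned: the column keeps a stable category set even when a class never appears among the current predictions. The class names here are made up.

# Standalone pandas illustration of the pd.Categorical pattern used above;
# "inactive"/"active" are hypothetical class labels.
import pandas as pd

class_labels = ["inactive", "active"]
pred_class_labels = ["active", "active", "active"]  # one class absent from predictions

preds = pd.Categorical(pred_class_labels, categories=class_labels)
print(preds.categories.tolist())        # ['inactive', 'active'] -- both classes retained
print(pd.Series(preds).value_counts())  # the missing class still appears, with count 0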

smftools/machine_learning/inference/sliding_window_inference.py

@@ -1,18 +1,21 @@
+from __future__ import annotations
+
 from ..data import AnnDataModule
 from ..evaluation import PostInferenceModelEvaluator
 from .lightning_inference import run_lightning_inference
 from .sklearn_inference import run_sklearn_inference
 
+
 def sliding_window_inference(
-    adata,
-    trained_results,
-    tensor_source=
+    adata,
+    trained_results,
+    tensor_source="X",
     tensor_key=None,
-    label_col=
+    label_col="activity_status",
     batch_size=64,
     cleanup=False,
-    target_eval_freq=None,
-    max_eval_positive=None
+    target_eval_freq=None,
+    max_eval_positive=None,
 ):
     """
     Apply trained sliding window models to an AnnData object (Lightning or Sklearn).
@@ -24,11 +27,11 @@ def sliding_window_inference(
         for window_size, window_data in model_dict.items():
             for center_varname, run in window_data.items():
                 print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")
-
+
                 # Extract window start from varname
                 center_idx = adata.var_names.get_loc(center_varname)
                 window_start = center_idx - window_size // 2
-
+
                 # Build datamodule for window
                 datamodule = AnnDataModule(
                     adata,
@@ -38,31 +41,31 @@
                     batch_size=batch_size,
                     window_start=window_start,
                     window_size=window_size,
-                    inference_mode=True
+                    inference_mode=True,
                 )
                 datamodule.setup()
 
                 # Extract model + detect type
-                model = run[
+                model = run["model"]
 
                 # Lightning models
-                if hasattr(run,
-                    trainer = run[
+                if hasattr(run, "trainer") or "trainer" in run:
+                    trainer = run["trainer"]
                    run_lightning_inference(
                        adata,
                        model=model,
                        datamodule=datamodule,
                        trainer=trainer,
-                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}",
                    )
-
+
                 # Sklearn models
                 else:
                    run_sklearn_inference(
                        adata,
                        model=model,
                        datamodule=datamodule,
-                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}",
                    )
 
     print("Inference complete across all models.")
@@ -77,27 +80,36 @@ def sliding_window_inference(
                 prefix = f"{model_name}_w{window_size}_c{center_varname}"
                 # Use full key for uniqueness
                 key = prefix
-                model_wrappers[key] = run[
+                model_wrappers[key] = run["model"]
 
     # Run evaluator
-    evaluator = PostInferenceModelEvaluator(
+    evaluator = PostInferenceModelEvaluator(
+        adata,
+        model_wrappers,
+        target_eval_freq=target_eval_freq,
+        max_eval_positive=max_eval_positive,
+    )
     evaluator.evaluate_all()
 
     # Get results
     df = evaluator.to_dataframe()
 
-    df[[
+    df[["model_name", "window_size", "center"]] = df["model"].str.extract(
+        r"(\w+)_w(\d+)_c(\d+)_activity_status"
+    )
 
     # Cast window_size and center to integers for plotting
-    df[
-    df[
+    df["window_size"] = df["window_size"].astype(int)
+    df["center"] = df["center"].astype(int)
 
     ## Optional cleanup:
     if cleanup:
-        prefixes = [
-
-
-
+        prefixes = [
+            f"{model_name}_w{window_size}_c{center_varname}"
+            for model_name, model_dict in trained_results.items()
+            for window_size, window_data in model_dict.items()
+            for center_varname in window_data.keys()
+        ]
 
        # Remove matching obs columns
        for prefix in prefixes:
@@ -111,4 +123,4 @@ def sliding_window_inference(
 
        print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")
 
-    return df
+    return df
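
The sliding-window evaluation above parses its result keys back into model name, window size, and center position with a single str.extract call. The standalone pandas sketch below shows that step in isolation; the key values are made up, while the regex and column names are taken from the diff.

# Standalone illustration of the key-parsing step introduced above: keys of the form
# "<model>_w<window>_c<center>_activity_status" are split into separate typed columns.
import pandas as pd

df = pd.DataFrame({"model": ["cnn_w10_c250_activity_status", "mlp_w20_c500_activity_status"]})
df[["model_name", "window_size", "center"]] = df["model"].str.extract(
    r"(\w+)_w(\d+)_c(\d+)_activity_status"
)
df["window_size"] = df["window_size"].astype(int)
df["center"] = df["center"].astype(int)
print(df[["model_name", "window_size", "center"]])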

smftools/machine_learning/models/__init__.py

@@ -1,9 +1,16 @@
+from __future__ import annotations
+
 from .base import BaseTorchModel
-from .mlp import MLPClassifier
 from .cnn import CNNClassifier
-from .
-from .
+from .lightning_base import TorchClassifierWrapper
+from .mlp import MLPClassifier
 from .positional import PositionalEncoding
+from .rnn import RNNClassifier
+from .sklearn_models import SklearnModelWrapper
+from .transformer import (
+    BaseTransformer,
+    DANNTransformerClassifier,
+    MaskedTransformerPretrainer,
+    TransformerClassifier,
+)
 from .wrappers import ScaledModel
-from .lightning_base import TorchClassifierWrapper
-from .sklearn_models import SklearnModelWrapper

smftools/machine_learning/models/base.py

@@ -1,17 +1,25 @@
-import
-
+from __future__ import annotations
+
 import numpy as np
+
+from smftools.optional_imports import require
+
 from ..utils.device import detect_device
 
+torch = require("torch", extra="ml-base", purpose="ML base models")
+nn = torch.nn
+
+
 class BaseTorchModel(nn.Module):
     """
     Minimal base class for torch models that:
     - Stores device and dropout regularization
     """
+
     def __init__(self, dropout_rate=0.0):
         super().__init__()
-        self.device = detect_device()
-        self.dropout_rate = dropout_rate
+        self.device = detect_device()  # detects available devices
+        self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
 
     def compute_saliency(
         self,
@@ -21,11 +29,11 @@ class BaseTorchModel(nn.Module):
         smoothgrad=False,
         smooth_samples=25,
         smooth_noise=0.1,
-        signed=True
+        signed=True,
     ):
         """
         Compute vanilla saliency or SmoothGrad saliency.
-
+
         Arguments:
         ----------
         x : torch.Tensor
@@ -43,7 +51,7 @@ class BaseTorchModel(nn.Module):
         """
         self.eval()
         x = x.clone().detach().requires_grad_(True)
-
+
         if smoothgrad:
             saliency_accum = torch.zeros_like(x)
             for i in range(smooth_samples):
@@ -56,7 +64,7 @@ class BaseTorchModel(nn.Module):
                 if logits.shape[1] == 1:
                     scores = logits.squeeze(1)
                 else:
-                    scores = logits[torch.arange(x.shape[0]), target_class]
+                    scores = logits[torch.arange(x.shape[0]), target_class]
                 scores.sum().backward()
                 saliency_accum += x_noisy.grad.detach()
             saliency = saliency_accum / smooth_samples
@@ -69,17 +77,17 @@ class BaseTorchModel(nn.Module):
             scores = logits[torch.arange(x.shape[0]), target_class]
             scores.sum().backward()
             saliency = x.grad.detach()
-
+
         if not signed:
             saliency = saliency.abs()
-
+
         if reduction == "sum" and x.ndim == 3:
             return saliency.sum(dim=-1)
         elif reduction == "mean" and x.ndim == 3:
             return saliency.mean(dim=-1)
         else:
             return saliency
-
+
     def compute_gradient_x_input(self, x, target_class=None):
         """
         Computes gradient × input attribution.
@@ -118,22 +126,11 @@ class BaseTorchModel(nn.Module):
            baseline = torch.zeros_like(x)
 
         attributions, delta = ig.attribute(
-            x,
-            baselines=baseline,
-            target=target_class,
-            n_steps=steps,
-            return_convergence_delta=True
+            x, baselines=baseline, target=target_class, n_steps=steps, return_convergence_delta=True
         )
         return attributions, delta
 
-    def compute_deeplift(
-        self,
-        x,
-        baseline=None,
-        target_class=None,
-        reduction="sum",
-        signed=True
-    ):
+    def compute_deeplift(self, x, baseline=None, target_class=None, reduction="sum", signed=True):
         """
         Compute DeepLIFT scores using captum.
 
@@ -158,21 +155,15 @@ class BaseTorchModel(nn.Module):
 
         if not signed:
             attr = attr.abs()
-
+
         if reduction == "sum" and x.ndim == 3:
             return attr.sum(dim=-1)
         elif reduction == "mean" and x.ndim == 3:
             return attr.mean(dim=-1)
         else:
             return attr
-
-    def compute_occlusion(
-        self,
-        x,
-        target_class=None,
-        window_size=5,
-        baseline=None
-    ):
+
+    def compute_occlusion(self, x, target_class=None, window_size=5, baseline=None):
         """
         Computes per-sample occlusion attribution.
         Supports 2D [B, S] or 3D [B, S, D] inputs.
@@ -208,9 +199,7 @@ class BaseTorchModel(nn.Module):
            x_occluded[left:right, :] = baseline[left:right, :]
 
            x_tensor = torch.tensor(
-                x_occluded,
-                device=self.device,
-                dtype=torch.float32
+                x_occluded, device=self.device, dtype=torch.float32
            ).unsqueeze(0)
 
            logits = self.forward(x_tensor)
@@ -235,7 +224,7 @@ class BaseTorchModel(nn.Module):
        device="cpu",
        target_class=None,
        normalize=True,
-        signed=True
+        signed=True,
    ):
        """
        Apply a chosen attribution method to a dataloader and store results in adata.
@@ -252,7 +241,9 @@ class BaseTorchModel(nn.Module):
            attr = model.compute_saliency(x, target_class=target_class, signed=signed)
 
        elif method == "smoothgrad":
-            attr = model.compute_saliency(
+            attr = model.compute_saliency(
+                x, smoothgrad=True, target_class=target_class, signed=signed
+            )
 
        elif method == "IG":
            attributions, delta = model.compute_integrated_gradients(
@@ -261,15 +252,15 @@ class BaseTorchModel(nn.Module):
            attr = attributions
 
        elif method == "deeplift":
-            attr = model.compute_deeplift(
+            attr = model.compute_deeplift(
+                x, baseline=baseline, target_class=target_class, signed=signed
+            )
 
        elif method == "gradxinput":
            attr = model.compute_gradient_x_input(x, target_class=target_class)
 
        elif method == "occlusion":
-            attr = model.compute_occlusion(
-                x, target_class=target_class, baseline=baseline
-            )
+            attr = model.compute_occlusion(x, target_class=target_class, baseline=baseline)
 
        else:
            raise ValueError(f"Unknown method {method}")
@@ -292,4 +283,4 @@ class BaseTorchModel(nn.Module):
            return target_class
        if logits.shape[1] == 1:
            return (logits > 0).long().squeeze(1)
-        return logits.argmax(dim=1)
+        return logits.argmax(dim=1)
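
BaseTorchModel.compute_saliency above follows the standard vanilla-saliency recipe: clone the input with requires_grad, run a forward pass, select the per-sample class scores, backpropagate, and read the input gradients. The snippet below is a standalone PyTorch sketch of that recipe on a toy linear classifier; it is generic illustration code mirroring the steps visible in the diff, not smftools code.

# Generic vanilla-saliency sketch mirroring the steps visible in compute_saliency above.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 2))  # toy classifier: 8 features -> 2 classes
model.eval()

x = torch.randn(4, 8).clone().detach().requires_grad_(True)  # batch of 4 inputs
logits = model(x)
target_class = logits.argmax(dim=1)  # fall back to the predicted class, as in the diff's helper
scores = logits[torch.arange(x.shape[0]), target_class]
scores.sum().backward()

saliency = x.grad.detach()  # signed saliency; take .abs() for unsigned maps
print(saliency.shape)  # torch.Size([4, 8])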