smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
smftools/machine_learning/evaluation/evaluators.py
@@ -0,0 +1,223 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from sklearn.metrics import (
+    roc_auc_score, precision_recall_curve, auc, f1_score, confusion_matrix, roc_curve
+)
+
+class ModelEvaluator:
+    """
+    A model evaluator for consolidating Sklearn and Lightning model evaluation metrics on testing data
+    """
+    def __init__(self):
+        self.results = []
+        self.pos_freq = None
+        self.num_pos = None
+
+    def add_model(self, name, model, is_torch=True):
+        """
+        Add a trained model with its evaluation metrics.
+        """
+        if is_torch:
+            entry = {
+                'name': name,
+                'f1': model.test_f1,
+                'auc': model.test_roc_auc,
+                'pr_auc': model.test_pr_auc,
+                'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
+                'pr_curve': model.test_pr_curve,
+                'roc_curve': model.test_roc_curve,
+                'num_pos': model.test_num_pos,
+                'pos_freq': model.test_pos_freq
+            }
+        else:
+            entry = {
+                'name': name,
+                'f1': model.test_f1,
+                'auc': model.test_roc_auc,
+                'pr_auc': model.test_pr_auc,
+                'pr_auc_norm': model.test_pr_auc / model.test_pos_freq if model.test_pos_freq > 0 else np.nan,
+                'pr_curve': model.test_pr_curve,
+                'roc_curve': model.test_roc_curve,
+                'num_pos': model.test_num_pos,
+                'pos_freq': model.test_pos_freq
+            }
+
+        self.results.append(entry)
+
+        if not self.pos_freq:
+            self.pos_freq = entry['pos_freq']
+            self.num_pos = entry['num_pos']
+
+    def get_metrics_dataframe(self):
+        """
+        Return all metrics as pandas DataFrame.
+        """
+        df = pd.DataFrame(self.results)
+        return df[['name', 'f1', 'auc', 'pr_auc', 'pr_auc_norm', 'num_pos', 'pos_freq']]
+
+    def plot_all_curves(self):
+        """
+        Plot unified ROC and PR curves across all models.
+        """
+        plt.figure(figsize=(12, 5))
+
+        # ROC
+        plt.subplot(1, 2, 1)
+        for res in self.results:
+            fpr, tpr = res['roc_curve']
+            plt.plot(fpr, tpr, label=f"{res['name']} (AUC={res['auc']:.3f})")
+        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
+        plt.xlabel("False Positive Rate")
+        plt.ylabel("True Positive Rate")
+        plt.ylim(0,1.05)
+        plt.title(f"ROC Curves - {self.num_pos} positive instances")
+        plt.legend()
+
+        # PR
+        plt.subplot(1, 2, 2)
+        for res in self.results:
+            rc, pr = res['pr_curve']
+            plt.plot(rc, pr, label=f"{res['name']} (AUPRC={res['pr_auc']:.3f})")
+        plt.xlabel("Recall")
+        plt.ylabel("Precision")
+        plt.ylim(0,1.05)
+        plt.axhline(self.pos_freq, linestyle='--', color='grey')
+        plt.title(f"Precision-Recall Curves - {self.num_pos} positive instances")
+        plt.legend()
+
+        plt.tight_layout()
+        plt.show()
+
+class PostInferenceModelEvaluator:
+    def __init__(self, adata, models, target_eval_freq=None, max_eval_positive=None):
+        """
+        Initialize evaluator.
+
+        Parameters:
+        -----------
+        adata : AnnData
+            The annotated dataset where predictions are stored in obs/obsm.
+        models : dict
+            Dictionary of models: {model_name: model_instance}.
+            Supports TorchClassifierWrapper and SklearnModelWrapper.
+        """
+        self.adata = adata
+        self.models = models
+        self.target_eval_freq = target_eval_freq
+        self.max_eval_positive = max_eval_positive
+        self.results = {}
+
+    def evaluate_all(self):
+        """
+        Evaluate all models and store results.
+        """
+        for name, model in self.models.items():
+            print(f"Evaluating {name}...")
+            label_col = model.label_col
+            full_prefix = f"{name}_{label_col}"
+            self.results[full_prefix] = self._evaluate_model(name, model)
+
+    def _evaluate_model(self, model_name, model):
+        """
+        Evaluate one model and return metrics.
+        """
+        label_col = model.label_col
+        num_classes = model.num_classes
+        class_names = model.class_names
+        focus_class = model.focus_class
+
+        full_prefix = f"{model_name}_{label_col}"
+
+        # Extract ground truth + predictions
+        y_true = self.adata.obs[label_col].cat.codes.to_numpy()
+        y_pred = self.adata.obs[f"{full_prefix}_pred"].to_numpy()
+        probs_all = self.adata.obsm[f"{full_prefix}_pred_prob_all"]
+
+        binary_focus = (y_true == focus_class).astype(int)
+
+        # OPTIONAL SUBSAMPLING
+        if self.target_eval_freq is not None:
+            indices = self._subsample_for_fixed_positive_frequency(
+                binary_focus, target_freq=self.target_eval_freq, max_positive=self.max_eval_positive
+            )
+            y_true = y_true[indices]
+            y_pred = y_pred[indices]
+            probs_all = probs_all[indices]
+            binary_focus = (y_true == focus_class).astype(int)
+
+        acc = np.mean(y_true == y_pred)
+
+        if num_classes == 2:
+            focus_probs = probs_all[:, focus_class]
+            f1 = f1_score(binary_focus, (y_pred == focus_class).astype(int))
+            roc_auc = roc_auc_score(binary_focus, focus_probs)
+            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
+            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
+            pr_auc = auc(rc, pr)
+            pos_freq = binary_focus.mean()
+            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+        else:
+            f1 = f1_score(y_true, y_pred, average="macro")
+            roc_auc = roc_auc_score(y_true, probs_all, multi_class="ovr", average="macro")
+            focus_probs = probs_all[:, focus_class]
+            pr, rc, _ = precision_recall_curve(binary_focus, focus_probs)
+            fpr, tpr, _ = roc_curve(binary_focus, focus_probs)
+            pr_auc = auc(rc, pr)
+            pos_freq = binary_focus.mean()
+            pr_auc_norm = pr_auc / pos_freq if pos_freq > 0 else np.nan
+
+        cm = confusion_matrix(y_true, y_pred)
+
+        metrics = {
+            "accuracy": acc,
+            "f1": f1,
+            "roc_auc": roc_auc,
+            "pr_auc": pr_auc,
+            "pr_auc_norm": pr_auc_norm,
+            "pos_freq": pos_freq,
+            "confusion_matrix": cm,
+            "pr_rc_curve": (pr, rc),
+            "roc_curve": (tpr, fpr)
+        }
+
+        return metrics
+
+    def _subsample_for_fixed_positive_frequency(self, binary_labels, target_freq=0.3, max_positive=None):
+        pos_idx = np.where(binary_labels == 1)[0]
+        neg_idx = np.where(binary_labels == 0)[0]
+
+        max_pos = len(pos_idx)
+        max_neg = len(neg_idx)
+
+        max_possible_freq = max_pos / (max_pos + max_neg)
+        if target_freq > max_possible_freq:
+            target_freq = max_possible_freq
+
+        num_pos_target = int(target_freq * max_neg / (1 - target_freq))
+        num_pos_target = min(num_pos_target, max_pos)
+        if max_positive is not None:
+            num_pos_target = min(num_pos_target, max_positive)
+
+        num_neg_target = int(num_pos_target * (1 - target_freq) / target_freq)
+        num_neg_target = min(num_neg_target, max_neg)
+
+        pos_sampled = np.random.choice(pos_idx, size=num_pos_target, replace=False)
+        neg_sampled = np.random.choice(neg_idx, size=num_neg_target, replace=False)
+        sampled_idx = np.concatenate([pos_sampled, neg_sampled])
+        np.random.shuffle(sampled_idx)
+        return sampled_idx
+
+    def to_dataframe(self):
+        """
+        Convert results to pandas DataFrame (excluding confusion matrices).
+        """
+        records = []
+        for model_name, metrics in self.results.items():
+            row = {"model": model_name}
+            for k, v in metrics.items():
+                if k not in ["confusion_matrix", "pr_rc_curve", "roc_curve"]:
+                    row[k] = v
+            records.append(row)
+        return pd.DataFrame(records)
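
For orientation, a minimal sketch of how the new ModelEvaluator could be driven after training. The cnn_wrapper and logreg_wrapper names are hypothetical stand-ins for trained smftools wrappers; any object exposing the test_f1, test_roc_auc, test_pr_auc, test_pr_curve, test_roc_curve, test_num_pos, and test_pos_freq attributes read by add_model works. The import assumes the class is re-exported from smftools.machine_learning.evaluation, whose __init__.py is added in this release.

    from smftools.machine_learning.evaluation import ModelEvaluator

    evaluator = ModelEvaluator()
    # Hypothetical trained wrappers exposing the test_* attributes above.
    evaluator.add_model("cnn", cnn_wrapper, is_torch=True)
    evaluator.add_model("logreg", logreg_wrapper, is_torch=False)

    print(evaluator.get_metrics_dataframe())  # one summary row per model
    evaluator.plot_all_curves()               # paired ROC / PR panels
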
smftools/machine_learning/inference/inference_utils.py
@@ -0,0 +1,27 @@
+import pandas as pd
+
+def annotate_split_column(adata, model, split_col="split"):
+    """
+    Annotate adata.obs with train/val/test/new labels based on model's stored obs_names.
+    """
+    # Get sets for fast lookup
+    train_set = set(model.train_obs_names)
+    val_set = set(model.val_obs_names)
+    test_set = set(model.test_obs_names)
+
+    # Create array for split labels
+    split_labels = []
+    for obs in adata.obs_names:
+        if obs in train_set:
+            split_labels.append("training")
+        elif obs in val_set:
+            split_labels.append("validation")
+        elif obs in test_set:
+            split_labels.append("testing")
+        else:
+            split_labels.append("new")
+
+    # Store in AnnData.obs
+    adata.obs[split_col] = pd.Categorical(split_labels, categories=["training", "validation", "testing", "new"])
+
+    print(f"Annotated {split_col} column with training/validation/testing/new status.")
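
annotate_split_column is the shared helper both inference paths call before writing predictions. A brief hedged sketch; the wrapper object here is hypothetical and only needs the three *_obs_names attributes the helper reads:

    from smftools.machine_learning.inference.inference_utils import annotate_split_column

    # 'wrapper' is a hypothetical trained model wrapper exposing
    # train_obs_names, val_obs_names, and test_obs_names.
    annotate_split_column(adata, wrapper, split_col="cnn_training_split")

    # Each read is now labeled by its role during training; reads the
    # model never saw are marked "new".
    print(adata.obs["cnn_training_split"].value_counts())
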
smftools/machine_learning/inference/lightning_inference.py
@@ -0,0 +1,68 @@
+import torch
+import pandas as pd
+import numpy as np
+from pytorch_lightning import Trainer
+from .inference_utils import annotate_split_column
+
+def run_lightning_inference(
+    adata,
+    model,
+    datamodule,
+    trainer,
+    prefix="model",
+    devices=1
+):
+    """
+    Run inference on AnnData using TorchClassifierWrapper + AnnDataModule (in inference mode).
+    """
+
+    # Device logic
+    if torch.cuda.is_available():
+        accelerator = "gpu"
+    elif torch.backends.mps.is_available():
+        accelerator = "mps"
+        devices = 1
+    else:
+        accelerator = "cpu"
+        devices = 1
+
+    label_col = model.label_col
+    num_classes = model.num_classes
+    class_labels = model.class_names
+    focus_class = model.focus_class
+    focus_class_name = model.focus_class_name
+
+    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")
+
+    # Run predictions
+    outputs = trainer.predict(model, datamodule=datamodule)
+
+    preds_list, probs_list = zip(*outputs)
+    preds = torch.cat(preds_list, dim=0).cpu().numpy()
+    probs = torch.cat(probs_list, dim=0).cpu().numpy()
+
+    # Handle binary vs multiclass formats
+    if num_classes == 2:
+        # probs shape: (N,) from sigmoid
+        pred_class_idx = (probs >= 0.5).astype(int)
+        probs_all = np.vstack([1 - probs, probs]).T  # shape (N, 2)
+        pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]
+    else:
+        pred_class_idx = probs.argmax(axis=1)
+        probs_all = probs
+        pred_class_probs = probs_all[np.arange(len(probs_all)), pred_class_idx]
+
+    pred_class_labels = [class_labels[i] for i in pred_class_idx]
+
+    full_prefix = f"{prefix}_{label_col}"
+
+    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
+
+    for i, class_name in enumerate(class_labels):
+        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]
+
+    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all
+
+    print(f"Inference complete: stored under prefix '{full_prefix}'")
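
A hedged end-to-end sketch of the Lightning path. The AnnDataModule keyword arguments mirror the call in sliding_window_inference below; wrapper is a hypothetical trained TorchClassifierWrapper, and the imports assume the data and inference packages re-export these names:

    import pytorch_lightning as pl
    from smftools.machine_learning.data import AnnDataModule
    from smftools.machine_learning.inference import run_lightning_inference

    # inference_mode=True serves every read for prediction rather than
    # only the held-out split.
    dm = AnnDataModule(adata, tensor_source="X", label_col="activity_status",
                       batch_size=64, inference_mode=True)
    trainer = pl.Trainer(accelerator="auto", devices=1, logger=False)

    run_lightning_inference(adata, model=wrapper, datamodule=dm,
                            trainer=trainer, prefix="cnn")
    # Writes adata.obs["cnn_<label_col>_pred"], "..._pred_label",
    # "..._pred_prob", one "..._prob_<class>" column per class, and the
    # full probability matrix in adata.obsm["cnn_<label_col>_pred_prob_all"].
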
smftools/machine_learning/inference/sklearn_inference.py
@@ -0,0 +1,55 @@
+import pandas as pd
+import numpy as np
+from .inference_utils import annotate_split_column
+
+
+def run_sklearn_inference(
+    adata,
+    model,
+    datamodule,
+    prefix="model"
+):
+    """
+    Run inference on AnnData using SklearnModelWrapper.
+    """
+
+    label_col = model.label_col
+    num_classes = model.num_classes
+    class_labels = model.class_names
+    focus_class_name = model.focus_class_name
+
+    annotate_split_column(adata, model, split_col=f"{prefix}_training_split")
+
+    datamodule.setup()
+
+    X_infer = datamodule.to_numpy()
+
+    # Run predictions
+    preds = model.predict(X_infer)
+    probs = model.predict_proba(X_infer)
+
+    # Handle binary vs multiclass formats
+    if num_classes == 2:
+        # probs shape: (N, 2) from predict_proba
+        pred_class_idx = preds
+        probs_all = probs
+        pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]
+    else:
+        pred_class_idx = preds
+        probs_all = probs
+        pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]
+
+    pred_class_labels = [class_labels[i] for i in pred_class_idx]
+
+    full_prefix = f"{prefix}_{label_col}"
+
+    adata.obs[f"{full_prefix}_pred"] = pred_class_idx
+    adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+    adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
+
+    for i, class_name in enumerate(class_labels):
+        adata.obs[f"{full_prefix}_prob_{class_name}"] = probs_all[:, i]
+
+    adata.obsm[f"{full_prefix}_pred_prob_all"] = probs_all
+
+    print(f"Inference complete: stored under prefix '{full_prefix}'")
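
The sklearn path writes the same obs/obsm keys as the Lightning path, so downstream evaluation is framework-agnostic. A matching hedged sketch; sk_wrapper is a hypothetical trained SklearnModelWrapper, and the import paths are assumed as above:

    from smftools.machine_learning.data import AnnDataModule
    from smftools.machine_learning.inference import run_sklearn_inference

    dm = AnnDataModule(adata, tensor_source="X", label_col="activity_status",
                       batch_size=64, inference_mode=True)
    # No Trainer needed: setup() and to_numpy() are called inside
    # run_sklearn_inference itself.
    run_sklearn_inference(adata, model=sk_wrapper, datamodule=dm, prefix="logreg")
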
smftools/machine_learning/inference/sliding_window_inference.py
@@ -0,0 +1,114 @@
+from ..data import AnnDataModule
+from ..evaluation import PostInferenceModelEvaluator
+from .lightning_inference import run_lightning_inference
+from .sklearn_inference import run_sklearn_inference
+
+def sliding_window_inference(
+    adata,
+    trained_results,
+    tensor_source='X',
+    tensor_key=None,
+    label_col='activity_status',
+    batch_size=64,
+    cleanup=False,
+    target_eval_freq=None,
+    max_eval_positive=None
+):
+    """
+    Apply trained sliding window models to an AnnData object (Lightning or Sklearn).
+    Evaluate model performance and return a df.
+    Optionally remove the appended inference columns from AnnData to clean up obs namespace.
+    """
+    ## Inference using trained models
+    for model_name, model_dict in trained_results.items():
+        for window_size, window_data in model_dict.items():
+            for center_varname, run in window_data.items():
+                print(f"\nEvaluating {model_name} window {window_size} around {center_varname}")
+
+                # Extract window start from varname
+                center_idx = adata.var_names.get_loc(center_varname)
+                window_start = center_idx - window_size // 2
+
+                # Build datamodule for window
+                datamodule = AnnDataModule(
+                    adata,
+                    tensor_source=tensor_source,
+                    tensor_key=tensor_key,
+                    label_col=label_col,
+                    batch_size=batch_size,
+                    window_start=window_start,
+                    window_size=window_size,
+                    inference_mode=True
+                )
+                datamodule.setup()
+
+                # Extract model + detect type
+                model = run['model']
+
+                # Lightning models
+                if hasattr(run, 'trainer') or 'trainer' in run:
+                    trainer = run['trainer']
+                    run_lightning_inference(
+                        adata,
+                        model=model,
+                        datamodule=datamodule,
+                        trainer=trainer,
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                    )
+
+                # Sklearn models
+                else:
+                    run_sklearn_inference(
+                        adata,
+                        model=model,
+                        datamodule=datamodule,
+                        prefix=f"{model_name}_w{window_size}_c{center_varname}"
+                    )
+
+    print("Inference complete across all models.")
+
+    ## Post-inference model evaluation
+    model_wrappers = {}
+
+    for model_name, model_dict in trained_results.items():
+        for window_size, window_data in model_dict.items():
+            for center_varname, run in window_data.items():
+                # Reconstruct the prefix string you used in inference
+                prefix = f"{model_name}_w{window_size}_c{center_varname}"
+                # Use full key for uniqueness
+                key = prefix
+                model_wrappers[key] = run['model']
+
+    # Run evaluator
+    evaluator = PostInferenceModelEvaluator(adata, model_wrappers, target_eval_freq=target_eval_freq, max_eval_positive=max_eval_positive)
+    evaluator.evaluate_all()
+
+    # Get results
+    df = evaluator.to_dataframe()
+
+    df[['model_name', 'window_size', 'center']] = df['model'].str.extract(r'(\w+)_w(\d+)_c(\d+)_activity_status')
+
+    # Cast window_size and center to integers for plotting
+    df['window_size'] = df['window_size'].astype(int)
+    df['center'] = df['center'].astype(int)
+
+    ## Optional cleanup:
+    if cleanup:
+        prefixes = [f"{model_name}_w{window_size}_c{center_varname}"
+                    for model_name, model_dict in trained_results.items()
+                    for window_size, window_data in model_dict.items()
+                    for center_varname in window_data.keys()]
+
+        # Remove matching obs columns
+        for prefix in prefixes:
+            to_remove = [col for col in adata.obs.columns if col.startswith(prefix)]
+            adata.obs.drop(columns=to_remove, inplace=True)
+
+            # Remove obsm entries if any
+            obsm_key = f"{prefix}_pred_prob_all"
+            if obsm_key in adata.obsm:
+                del adata.obsm[obsm_key]
+
+        print(f"Cleaned up {len(prefixes)} model prefixes from AnnData.")
+
+    return df
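
The nesting of trained_results is implied by the loops above: model name → window size → center variable → run dict, where the presence of a "trainer" key routes a run down the Lightning path. A hedged sketch with hypothetical wrapper/trainer names; note the result-parsing regex above assumes numeric center names and the activity_status label:

    # Shape implied by the loops above; contents are illustrative.
    trained_results = {
        "cnn":    {50: {"1200": {"model": cnn_wrapper, "trainer": cnn_trainer}}},
        "logreg": {50: {"1200": {"model": logreg_wrapper}}},  # no trainer -> sklearn path
    }
    df = sliding_window_inference(adata, trained_results,
                                  label_col="activity_status", cleanup=True)
    # df: one row per (model_name, window_size, center) with accuracy, f1,
    # roc_auc, pr_auc, pr_auc_norm, and pos_freq.
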