smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/classifiers.py
CHANGED
@@ -1,35 +1,48 @@
+import os
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 import torch
-
+
 
 def plot_model_performance(metrics, save_path=None):
-
+    """Plot ROC and precision-recall curves for model metrics.
+
+    Args:
+        metrics: Dictionary of model metrics by reference.
+        save_path: Optional path to save plots.
+    """
     import os
+
     for ref in metrics.keys():
         plt.figure(figsize=(12, 5))
 
         # ROC Curve
         plt.subplot(1, 2, 1)
         for model_name, vals in metrics[ref].items():
-            model_type = model_name.split(
+            model_type = model_name.split("_")[0]
             data_type = model_name.split(f"{model_type}_")[1]
-            plt.plot(
-
-
-            plt.
+            plt.plot(
+                vals["fpr"], vals["tpr"], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}"
+            )
+            plt.xlabel("False Positive Rate")
+            plt.ylabel("True Positive Rate")
+            plt.title(f"{data_type} ROC Curve ({ref})")
         plt.legend()
 
         # PR Curve
         plt.subplot(1, 2, 2)
         for model_name, vals in metrics[ref].items():
-            model_type = model_name.split(
+            model_type = model_name.split("_")[0]
             data_type = model_name.split(f"{model_type}_")[1]
-            plt.plot(
-
-
-
+            plt.plot(
+                vals["recall"],
+                vals["precision"],
+                label=f"{model_type.upper()} - F1: {vals['f1']:.4f}",
+            )
+            plt.xlabel("Recall")
+            plt.ylabel("Precision")
+            plt.title(f"{data_type} Precision-Recall Curve ({ref})")
         plt.legend()
 
         plt.tight_layout()
@@ -42,13 +55,14 @@ def plot_model_performance(metrics, save_path=None):
             plt.savefig(out_file, dpi=300)
             print(f"📁 Saved: {out_file}")
         plt.show()
-
+
         # Confusion Matrices
         for model_name, vals in metrics[ref].items():
             print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
-            print(vals[
+            print(vals["confusion_matrix"])
             print()
 
+
 def plot_feature_importances_or_saliency(
     models,
     positions,
@@ -57,18 +71,31 @@ def plot_feature_importances_or_saliency(
     adata=None,
     layer_name=None,
     save_path=None,
-    shaded_regions=None
+    shaded_regions=None,
 ):
-
-
-
+    """Plot feature importances or saliency for trained models.
+
+    Args:
+        models: Mapping of trained models.
+        positions: Mapping of positions per reference.
+        tensors: Mapping of input tensors per reference.
+        site_config: Site configuration mapping.
+        adata: Optional AnnData object.
+        layer_name: Optional layer name for plotting.
+        save_path: Optional path to save plots.
+        shaded_regions: Optional list of regions to highlight.
+    """
     import os
 
+    import numpy as np
+
     # Select device for NN models
     device = (
-        torch.device(
-
-        torch.device(
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("mps")
+        if torch.backends.mps.is_available()
+        else torch.device("cpu")
     )
 
     for ref, model_dict in models.items():
@@ -90,7 +117,9 @@ def plot_feature_importances_or_saliency(
         other_sites = set()
 
         if adata is None:
-            print(
+            print(
+                "⚠️ AnnData object is required to classify site types. Skipping site type markers."
+            )
         else:
             gpc_col = f"{ref}_GpC_site"
             cpg_col = f"{ref}_CpG_site"
@@ -146,20 +175,46 @@ def plot_feature_importances_or_saliency(
         plt.figure(figsize=(12, 4))
         for pos, imp in zip(positions_sorted, importances_sorted):
             if pos in cpg_sites:
-                plt.plot(
-
+                plt.plot(
+                    pos,
+                    imp,
+                    marker="*",
+                    color="black",
+                    markersize=10,
+                    linestyle="None",
+                    label="CpG site"
+                    if "CpG site" not in plt.gca().get_legend_handles_labels()[1]
+                    else "",
+                )
             elif pos in gpc_sites:
-                plt.plot(
-
+                plt.plot(
+                    pos,
+                    imp,
+                    marker="o",
+                    color="blue",
+                    markersize=6,
+                    linestyle="None",
+                    label="GpC site"
+                    if "GpC site" not in plt.gca().get_legend_handles_labels()[1]
+                    else "",
+                )
             else:
-                plt.plot(
-
-
-
+                plt.plot(
+                    pos,
+                    imp,
+                    marker=".",
+                    color="gray",
+                    linestyle="None",
+                    label="Other"
+                    if "Other" not in plt.gca().get_legend_handles_labels()[1]
+                    else "",
+                )
+
+        plt.plot(positions_sorted, importances_sorted, linestyle="-", alpha=0.5, color="black")
 
         if shaded_regions:
-            for
-            plt.axvspan(start, end, color=
+            for start, end in shaded_regions:
+                plt.axvspan(start, end, color="gray", alpha=0.3)
 
         plt.xlabel("Genomic Position")
         plt.ylabel(y_label)
@@ -170,31 +225,47 @@ def plot_feature_importances_or_saliency(
 
         if save_path:
             os.makedirs(save_path, exist_ok=True)
-            safe_name =
+            safe_name = (
+                plot_title.replace("=", "")
+                .replace("__", "_")
+                .replace(",", "_")
+                .replace(" ", "_")
+            )
             out_file = os.path.join(save_path, f"{safe_name}.png")
             plt.savefig(out_file, dpi=300)
             print(f"📁 Saved: {out_file}")
 
         plt.show()
 
+
 def plot_model_curves_from_adata(
-    adata,
-    label_col=
-    model_names
-    suffix=
-    omit_training=True,
-    save_path=None,
-    ylim_roc=(0.0, 1.05),
-    ylim_pr=(0.0, 1.05)
-
-
-
-
+    adata,
+    label_col="activity_status",
+    model_names=["cnn", "mlp", "rf"],
+    suffix="GpC_site_CpG_site",
+    omit_training=True,
+    save_path=None,
+    ylim_roc=(0.0, 1.05),
+    ylim_pr=(0.0, 1.05),
+):
+    """Plot ROC and PR curves using AnnData model outputs.
+
+    Args:
+        adata: AnnData containing model outputs.
+        label_col: Ground-truth label column.
+        model_names: Model name prefixes.
+        suffix: Prediction column suffix.
+        omit_training: Whether to omit training rows.
+        save_path: Optional path to save the plot.
+        ylim_roc: Y-axis limits for ROC curve.
+        ylim_pr: Y-axis limits for PR curve.
+    """
+    from sklearn.metrics import auc, precision_recall_curve, roc_curve
 
     if omit_training:
-        subset = adata[adata.obs[
+        subset = adata[~adata.obs["used_for_training"].astype(bool)]
 
-    label = subset.obs[label_col].map({
+    label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
 
     positive_ratio = np.sum(label.astype(int)) / len(label)
 
@@ -210,7 +281,7 @@ def plot_model_curves_from_adata(
         roc_auc = auc(fpr, tpr)
         plt.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
 
-    plt.plot([0, 1], [0, 1],
+    plt.plot([0, 1], [0, 1], "k--", alpha=0.5)
     plt.xlabel("False Positive Rate")
     plt.ylabel("True Positive Rate")
     plt.title("ROC Curve")
@@ -230,13 +301,13 @@ def plot_model_curves_from_adata(
     plt.xlabel("Recall")
     plt.ylabel("Precision")
     plt.ylim(*ylim_pr)
-    plt.axhline(y=positive_ratio, linestyle=
+    plt.axhline(y=positive_ratio, linestyle="--", color="gray", label="Random Baseline")
     plt.title("Precision-Recall Curve")
     plt.legend()
 
     plt.tight_layout()
     if save_path:
-        save_name =
+        save_name = "ROC_PR_curves"
        os.makedirs(save_path, exist_ok=True)
        safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
        out_file = os.path.join(save_path, f"{safe_name}.png")
@@ -244,11 +315,12 @@ def plot_model_curves_from_adata(
        print(f"📁 Saved: {out_file}")
    plt.show()
 
+
 def plot_model_curves_from_adata_with_frequency_grid(
     adata,
-    label_col=
+    label_col="activity_status",
     model_names=["cnn", "mlp", "rf"],
-    suffix=
+    suffix="GpC_site_CpG_site",
     omit_training=True,
     save_path=None,
     ylim_roc=(0.0, 1.05),
@@ -256,22 +328,38 @@ def plot_model_curves_from_adata_with_frequency_grid(
     pos_sample_count=500,
     pos_freq_list=[0.01, 0.05, 0.1],
     show_f1_iso_curves=False,
-    f1_levels=None
-
-
-
+    f1_levels=None,
+):
+    """Plot ROC/PR curves with frequency grid overlays.
+
+    Args:
+        adata: AnnData containing model outputs.
+        label_col: Ground-truth label column.
+        model_names: Model name prefixes.
+        suffix: Prediction column suffix.
+        omit_training: Whether to omit training rows.
+        save_path: Optional path to save the plot.
+        ylim_roc: Y-axis limits for ROC curve.
+        ylim_pr: Y-axis limits for PR curve.
+        pos_sample_count: Sample count for positive baseline.
+        pos_freq_list: List of positive class frequencies to plot.
+        show_f1_iso_curves: Whether to show F1 iso-curves.
+        f1_levels: F1 levels to plot if enabled.
+    """
     import os
-
-
+
+    import numpy as np
+    from sklearn.metrics import auc, precision_recall_curve, roc_curve
+
     if f1_levels is None:
         f1_levels = np.linspace(0.2, 0.9, 8)
-
+
     if omit_training:
-        subset = adata[adata.obs[
+        subset = adata[~adata.obs["used_for_training"].astype(bool)]
     else:
         subset = adata
 
-    label = subset.obs[label_col].map({
+    label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
     subset = subset.copy()
     subset.obs["__label__"] = label
 
@@ -280,7 +368,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
 
     n_rows = len(pos_freq_list)
     fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
-    fig.suptitle(f
+    fig.suptitle(f"{suffix} Performance metrics")
 
     for row_idx, pos_freq in enumerate(pos_freq_list):
         desired_total = int(pos_sample_count / pos_freq)
@@ -308,14 +396,14 @@ def plot_model_curves_from_adata_with_frequency_grid(
             fpr, tpr, _ = roc_curve(y_true, probs)
             roc_auc = auc(fpr, tpr)
             ax_roc.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
-        ax_roc.plot([0, 1], [0, 1],
+        ax_roc.plot([0, 1], [0, 1], "k--", alpha=0.5)
        ax_roc.set_xlabel("False Positive Rate")
        ax_roc.set_ylabel("True Positive Rate")
        ax_roc.set_ylim(*ylim_roc)
        ax_roc.set_title(f"ROC Curve (Pos Freq: {pos_freq:.2%})")
        ax_roc.legend()
-        ax_roc.spines[
-        ax_roc.spines[
+        ax_roc.spines["top"].set_visible(False)
+        ax_roc.spines["right"].set_visible(False)
 
        # PR Curve
        for model in model_names:
@@ -325,26 +413,28 @@ def plot_model_curves_from_adata_with_frequency_grid(
            precision, recall, _ = precision_recall_curve(y_true, probs)
            pr_auc = auc(recall, precision)
            ax_pr.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")
-        ax_pr.axhline(y=pos_freq, linestyle=
+        ax_pr.axhline(y=pos_freq, linestyle="--", color="gray", label="Random Baseline")
 
        if show_f1_iso_curves:
            recall_vals = np.linspace(0.01, 1, 500)
            for f1 in f1_levels:
                precision_vals = (f1 * recall_vals) / (2 * recall_vals - f1)
                precision_vals[precision_vals < 0] = np.nan  # Avoid plotting invalid values
-                ax_pr.plot(
+                ax_pr.plot(
+                    recall_vals, precision_vals, color="gray", linestyle=":", linewidth=1, alpha=0.6
+                )
                x_val = 0.9
                y_val = (f1 * x_val) / (2 * x_val - f1)
                if 0 < y_val < 1:
-                    ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color=
+                    ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color="gray")
 
        ax_pr.set_xlabel("Recall")
        ax_pr.set_ylabel("Precision")
        ax_pr.set_ylim(*ylim_pr)
        ax_pr.set_title(f"PR Curve (Pos Freq: {pos_freq:.2%})")
        ax_pr.legend()
-        ax_pr.spines[
-        ax_pr.spines[
+        ax_pr.spines["top"].set_visible(False)
+        ax_pr.spines["right"].set_visible(False)
 
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    if save_path:
@@ -352,4 +442,4 @@ def plot_model_curves_from_adata_with_frequency_grid(
        out_file = os.path.join(save_path, "ROC_PR_grid.png")
        plt.savefig(out_file, dpi=300)
        print(f"📁 Saved: {out_file}")
-    plt.show()
+    plt.show()