smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/classifiers.py
CHANGED
|
@@ -1,35 +1,53 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
1
2
|
|
|
2
|
-
import numpy as np
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
import torch
|
|
5
3
|
import os
|
|
6
4
|
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from smftools.optional_imports import require
|
|
8
|
+
|
|
9
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="model plots")
|
|
10
|
+
torch = require("torch", extra="ml-base", purpose="model saliency plots")
|
|
11
|
+
|
|
12
|
+
|
|
7
13
|
def plot_model_performance(metrics, save_path=None):
|
|
8
|
-
|
|
14
|
+
"""Plot ROC and precision-recall curves for model metrics.
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
metrics: Dictionary of model metrics by reference.
|
|
18
|
+
save_path: Optional path to save plots.
|
|
19
|
+
"""
|
|
9
20
|
import os
|
|
21
|
+
|
|
10
22
|
for ref in metrics.keys():
|
|
11
23
|
plt.figure(figsize=(12, 5))
|
|
12
24
|
|
|
13
25
|
# ROC Curve
|
|
14
26
|
plt.subplot(1, 2, 1)
|
|
15
27
|
for model_name, vals in metrics[ref].items():
|
|
16
|
-
model_type = model_name.split(
|
|
28
|
+
model_type = model_name.split("_")[0]
|
|
17
29
|
data_type = model_name.split(f"{model_type}_")[1]
|
|
18
|
-
plt.plot(
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
plt.
|
|
30
|
+
plt.plot(
|
|
31
|
+
vals["fpr"], vals["tpr"], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}"
|
|
32
|
+
)
|
|
33
|
+
plt.xlabel("False Positive Rate")
|
|
34
|
+
plt.ylabel("True Positive Rate")
|
|
35
|
+
plt.title(f"{data_type} ROC Curve ({ref})")
|
|
22
36
|
plt.legend()
|
|
23
37
|
|
|
24
38
|
# PR Curve
|
|
25
39
|
plt.subplot(1, 2, 2)
|
|
26
40
|
for model_name, vals in metrics[ref].items():
|
|
27
|
-
model_type = model_name.split(
|
|
41
|
+
model_type = model_name.split("_")[0]
|
|
28
42
|
data_type = model_name.split(f"{model_type}_")[1]
|
|
29
|
-
plt.plot(
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
43
|
+
plt.plot(
|
|
44
|
+
vals["recall"],
|
|
45
|
+
vals["precision"],
|
|
46
|
+
label=f"{model_type.upper()} - F1: {vals['f1']:.4f}",
|
|
47
|
+
)
|
|
48
|
+
plt.xlabel("Recall")
|
|
49
|
+
plt.ylabel("Precision")
|
|
50
|
+
plt.title(f"{data_type} Precision-Recall Curve ({ref})")
|
|
33
51
|
plt.legend()
|
|
34
52
|
|
|
35
53
|
plt.tight_layout()
|
|
@@ -42,13 +60,14 @@ def plot_model_performance(metrics, save_path=None):
|
|
|
42
60
|
plt.savefig(out_file, dpi=300)
|
|
43
61
|
print(f"📁 Saved: {out_file}")
|
|
44
62
|
plt.show()
|
|
45
|
-
|
|
63
|
+
|
|
46
64
|
# Confusion Matrices
|
|
47
65
|
for model_name, vals in metrics[ref].items():
|
|
48
66
|
print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
|
|
49
|
-
print(vals[
|
|
67
|
+
print(vals["confusion_matrix"])
|
|
50
68
|
print()
|
|
51
69
|
|
|
70
|
+
|
|
52
71
|
def plot_feature_importances_or_saliency(
|
|
53
72
|
models,
|
|
54
73
|
positions,
|
|
@@ -57,18 +76,31 @@ def plot_feature_importances_or_saliency(
|
|
|
57
76
|
adata=None,
|
|
58
77
|
layer_name=None,
|
|
59
78
|
save_path=None,
|
|
60
|
-
shaded_regions=None
|
|
79
|
+
shaded_regions=None,
|
|
61
80
|
):
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
81
|
+
"""Plot feature importances or saliency for trained models.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
models: Mapping of trained models.
|
|
85
|
+
positions: Mapping of positions per reference.
|
|
86
|
+
tensors: Mapping of input tensors per reference.
|
|
87
|
+
site_config: Site configuration mapping.
|
|
88
|
+
adata: Optional AnnData object.
|
|
89
|
+
layer_name: Optional layer name for plotting.
|
|
90
|
+
save_path: Optional path to save plots.
|
|
91
|
+
shaded_regions: Optional list of regions to highlight.
|
|
92
|
+
"""
|
|
65
93
|
import os
|
|
66
94
|
|
|
95
|
+
import numpy as np
|
|
96
|
+
|
|
67
97
|
# Select device for NN models
|
|
68
98
|
device = (
|
|
69
|
-
torch.device(
|
|
70
|
-
|
|
71
|
-
torch.device(
|
|
99
|
+
torch.device("cuda")
|
|
100
|
+
if torch.cuda.is_available()
|
|
101
|
+
else torch.device("mps")
|
|
102
|
+
if torch.backends.mps.is_available()
|
|
103
|
+
else torch.device("cpu")
|
|
72
104
|
)
|
|
73
105
|
|
|
74
106
|
for ref, model_dict in models.items():
|
|
@@ -90,7 +122,9 @@ def plot_feature_importances_or_saliency(
|
|
|
90
122
|
other_sites = set()
|
|
91
123
|
|
|
92
124
|
if adata is None:
|
|
93
|
-
print(
|
|
125
|
+
print(
|
|
126
|
+
"⚠️ AnnData object is required to classify site types. Skipping site type markers."
|
|
127
|
+
)
|
|
94
128
|
else:
|
|
95
129
|
gpc_col = f"{ref}_GpC_site"
|
|
96
130
|
cpg_col = f"{ref}_CpG_site"
|
|
@@ -146,20 +180,46 @@ def plot_feature_importances_or_saliency(
|
|
|
146
180
|
plt.figure(figsize=(12, 4))
|
|
147
181
|
for pos, imp in zip(positions_sorted, importances_sorted):
|
|
148
182
|
if pos in cpg_sites:
|
|
149
|
-
plt.plot(
|
|
150
|
-
|
|
183
|
+
plt.plot(
|
|
184
|
+
pos,
|
|
185
|
+
imp,
|
|
186
|
+
marker="*",
|
|
187
|
+
color="black",
|
|
188
|
+
markersize=10,
|
|
189
|
+
linestyle="None",
|
|
190
|
+
label="CpG site"
|
|
191
|
+
if "CpG site" not in plt.gca().get_legend_handles_labels()[1]
|
|
192
|
+
else "",
|
|
193
|
+
)
|
|
151
194
|
elif pos in gpc_sites:
|
|
152
|
-
plt.plot(
|
|
153
|
-
|
|
195
|
+
plt.plot(
|
|
196
|
+
pos,
|
|
197
|
+
imp,
|
|
198
|
+
marker="o",
|
|
199
|
+
color="blue",
|
|
200
|
+
markersize=6,
|
|
201
|
+
linestyle="None",
|
|
202
|
+
label="GpC site"
|
|
203
|
+
if "GpC site" not in plt.gca().get_legend_handles_labels()[1]
|
|
204
|
+
else "",
|
|
205
|
+
)
|
|
154
206
|
else:
|
|
155
|
-
plt.plot(
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
207
|
+
plt.plot(
|
|
208
|
+
pos,
|
|
209
|
+
imp,
|
|
210
|
+
marker=".",
|
|
211
|
+
color="gray",
|
|
212
|
+
linestyle="None",
|
|
213
|
+
label="Other"
|
|
214
|
+
if "Other" not in plt.gca().get_legend_handles_labels()[1]
|
|
215
|
+
else "",
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
plt.plot(positions_sorted, importances_sorted, linestyle="-", alpha=0.5, color="black")
|
|
159
219
|
|
|
160
220
|
if shaded_regions:
|
|
161
|
-
for
|
|
162
|
-
plt.axvspan(start, end, color=
|
|
221
|
+
for start, end in shaded_regions:
|
|
222
|
+
plt.axvspan(start, end, color="gray", alpha=0.3)
|
|
163
223
|
|
|
164
224
|
plt.xlabel("Genomic Position")
|
|
165
225
|
plt.ylabel(y_label)
|
|
@@ -170,31 +230,50 @@ def plot_feature_importances_or_saliency(
|
|
|
170
230
|
|
|
171
231
|
if save_path:
|
|
172
232
|
os.makedirs(save_path, exist_ok=True)
|
|
173
|
-
safe_name =
|
|
233
|
+
safe_name = (
|
|
234
|
+
plot_title.replace("=", "")
|
|
235
|
+
.replace("__", "_")
|
|
236
|
+
.replace(",", "_")
|
|
237
|
+
.replace(" ", "_")
|
|
238
|
+
)
|
|
174
239
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
175
240
|
plt.savefig(out_file, dpi=300)
|
|
176
241
|
print(f"📁 Saved: {out_file}")
|
|
177
242
|
|
|
178
243
|
plt.show()
|
|
179
244
|
|
|
245
|
+
|
|
180
246
|
def plot_model_curves_from_adata(
|
|
181
|
-
adata,
|
|
182
|
-
label_col=
|
|
183
|
-
model_names
|
|
184
|
-
suffix=
|
|
185
|
-
omit_training=True,
|
|
186
|
-
save_path=None,
|
|
187
|
-
ylim_roc=(0.0, 1.05),
|
|
188
|
-
ylim_pr=(0.0, 1.05)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
247
|
+
adata,
|
|
248
|
+
label_col="activity_status",
|
|
249
|
+
model_names=["cnn", "mlp", "rf"],
|
|
250
|
+
suffix="GpC_site_CpG_site",
|
|
251
|
+
omit_training=True,
|
|
252
|
+
save_path=None,
|
|
253
|
+
ylim_roc=(0.0, 1.05),
|
|
254
|
+
ylim_pr=(0.0, 1.05),
|
|
255
|
+
):
|
|
256
|
+
"""Plot ROC and PR curves using AnnData model outputs.
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
adata: AnnData containing model outputs.
|
|
260
|
+
label_col: Ground-truth label column.
|
|
261
|
+
model_names: Model name prefixes.
|
|
262
|
+
suffix: Prediction column suffix.
|
|
263
|
+
omit_training: Whether to omit training rows.
|
|
264
|
+
save_path: Optional path to save the plot.
|
|
265
|
+
ylim_roc: Y-axis limits for ROC curve.
|
|
266
|
+
ylim_pr: Y-axis limits for PR curve.
|
|
267
|
+
"""
|
|
268
|
+
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
|
|
269
|
+
auc = sklearn_metrics.auc
|
|
270
|
+
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
271
|
+
roc_curve = sklearn_metrics.roc_curve
|
|
193
272
|
|
|
194
273
|
if omit_training:
|
|
195
|
-
subset = adata[adata.obs[
|
|
274
|
+
subset = adata[~adata.obs["used_for_training"].astype(bool)]
|
|
196
275
|
|
|
197
|
-
label = subset.obs[label_col].map({
|
|
276
|
+
label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
|
|
198
277
|
|
|
199
278
|
positive_ratio = np.sum(label.astype(int)) / len(label)
|
|
200
279
|
|
|
@@ -210,7 +289,7 @@ def plot_model_curves_from_adata(
|
|
|
210
289
|
roc_auc = auc(fpr, tpr)
|
|
211
290
|
plt.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
|
|
212
291
|
|
|
213
|
-
plt.plot([0, 1], [0, 1],
|
|
292
|
+
plt.plot([0, 1], [0, 1], "k--", alpha=0.5)
|
|
214
293
|
plt.xlabel("False Positive Rate")
|
|
215
294
|
plt.ylabel("True Positive Rate")
|
|
216
295
|
plt.title("ROC Curve")
|
|
@@ -230,13 +309,13 @@ def plot_model_curves_from_adata(
|
|
|
230
309
|
plt.xlabel("Recall")
|
|
231
310
|
plt.ylabel("Precision")
|
|
232
311
|
plt.ylim(*ylim_pr)
|
|
233
|
-
plt.axhline(y=positive_ratio, linestyle=
|
|
312
|
+
plt.axhline(y=positive_ratio, linestyle="--", color="gray", label="Random Baseline")
|
|
234
313
|
plt.title("Precision-Recall Curve")
|
|
235
314
|
plt.legend()
|
|
236
315
|
|
|
237
316
|
plt.tight_layout()
|
|
238
317
|
if save_path:
|
|
239
|
-
save_name =
|
|
318
|
+
save_name = "ROC_PR_curves"
|
|
240
319
|
os.makedirs(save_path, exist_ok=True)
|
|
241
320
|
safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
242
321
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
@@ -244,11 +323,12 @@ def plot_model_curves_from_adata(
|
|
|
244
323
|
print(f"📁 Saved: {out_file}")
|
|
245
324
|
plt.show()
|
|
246
325
|
|
|
326
|
+
|
|
247
327
|
def plot_model_curves_from_adata_with_frequency_grid(
|
|
248
328
|
adata,
|
|
249
|
-
label_col=
|
|
329
|
+
label_col="activity_status",
|
|
250
330
|
model_names=["cnn", "mlp", "rf"],
|
|
251
|
-
suffix=
|
|
331
|
+
suffix="GpC_site_CpG_site",
|
|
252
332
|
omit_training=True,
|
|
253
333
|
save_path=None,
|
|
254
334
|
ylim_roc=(0.0, 1.05),
|
|
@@ -256,22 +336,42 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
256
336
|
pos_sample_count=500,
|
|
257
337
|
pos_freq_list=[0.01, 0.05, 0.1],
|
|
258
338
|
show_f1_iso_curves=False,
|
|
259
|
-
f1_levels=None
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
339
|
+
f1_levels=None,
|
|
340
|
+
):
|
|
341
|
+
"""Plot ROC/PR curves with frequency grid overlays.
|
|
342
|
+
|
|
343
|
+
Args:
|
|
344
|
+
adata: AnnData containing model outputs.
|
|
345
|
+
label_col: Ground-truth label column.
|
|
346
|
+
model_names: Model name prefixes.
|
|
347
|
+
suffix: Prediction column suffix.
|
|
348
|
+
omit_training: Whether to omit training rows.
|
|
349
|
+
save_path: Optional path to save the plot.
|
|
350
|
+
ylim_roc: Y-axis limits for ROC curve.
|
|
351
|
+
ylim_pr: Y-axis limits for PR curve.
|
|
352
|
+
pos_sample_count: Sample count for positive baseline.
|
|
353
|
+
pos_freq_list: List of positive class frequencies to plot.
|
|
354
|
+
show_f1_iso_curves: Whether to show F1 iso-curves.
|
|
355
|
+
f1_levels: F1 levels to plot if enabled.
|
|
356
|
+
"""
|
|
263
357
|
import os
|
|
264
|
-
|
|
265
|
-
|
|
358
|
+
|
|
359
|
+
import numpy as np
|
|
360
|
+
|
|
361
|
+
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
|
|
362
|
+
auc = sklearn_metrics.auc
|
|
363
|
+
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
364
|
+
roc_curve = sklearn_metrics.roc_curve
|
|
365
|
+
|
|
266
366
|
if f1_levels is None:
|
|
267
367
|
f1_levels = np.linspace(0.2, 0.9, 8)
|
|
268
|
-
|
|
368
|
+
|
|
269
369
|
if omit_training:
|
|
270
|
-
subset = adata[adata.obs[
|
|
370
|
+
subset = adata[~adata.obs["used_for_training"].astype(bool)]
|
|
271
371
|
else:
|
|
272
372
|
subset = adata
|
|
273
373
|
|
|
274
|
-
label = subset.obs[label_col].map({
|
|
374
|
+
label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
|
|
275
375
|
subset = subset.copy()
|
|
276
376
|
subset.obs["__label__"] = label
|
|
277
377
|
|
|
@@ -280,7 +380,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
280
380
|
|
|
281
381
|
n_rows = len(pos_freq_list)
|
|
282
382
|
fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
|
|
283
|
-
fig.suptitle(f
|
|
383
|
+
fig.suptitle(f"{suffix} Performance metrics")
|
|
284
384
|
|
|
285
385
|
for row_idx, pos_freq in enumerate(pos_freq_list):
|
|
286
386
|
desired_total = int(pos_sample_count / pos_freq)
|
|
@@ -308,14 +408,14 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
308
408
|
fpr, tpr, _ = roc_curve(y_true, probs)
|
|
309
409
|
roc_auc = auc(fpr, tpr)
|
|
310
410
|
ax_roc.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
|
|
311
|
-
ax_roc.plot([0, 1], [0, 1],
|
|
411
|
+
ax_roc.plot([0, 1], [0, 1], "k--", alpha=0.5)
|
|
312
412
|
ax_roc.set_xlabel("False Positive Rate")
|
|
313
413
|
ax_roc.set_ylabel("True Positive Rate")
|
|
314
414
|
ax_roc.set_ylim(*ylim_roc)
|
|
315
415
|
ax_roc.set_title(f"ROC Curve (Pos Freq: {pos_freq:.2%})")
|
|
316
416
|
ax_roc.legend()
|
|
317
|
-
ax_roc.spines[
|
|
318
|
-
ax_roc.spines[
|
|
417
|
+
ax_roc.spines["top"].set_visible(False)
|
|
418
|
+
ax_roc.spines["right"].set_visible(False)
|
|
319
419
|
|
|
320
420
|
# PR Curve
|
|
321
421
|
for model in model_names:
|
|
@@ -325,26 +425,28 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
325
425
|
precision, recall, _ = precision_recall_curve(y_true, probs)
|
|
326
426
|
pr_auc = auc(recall, precision)
|
|
327
427
|
ax_pr.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")
|
|
328
|
-
ax_pr.axhline(y=pos_freq, linestyle=
|
|
428
|
+
ax_pr.axhline(y=pos_freq, linestyle="--", color="gray", label="Random Baseline")
|
|
329
429
|
|
|
330
430
|
if show_f1_iso_curves:
|
|
331
431
|
recall_vals = np.linspace(0.01, 1, 500)
|
|
332
432
|
for f1 in f1_levels:
|
|
333
433
|
precision_vals = (f1 * recall_vals) / (2 * recall_vals - f1)
|
|
334
434
|
precision_vals[precision_vals < 0] = np.nan # Avoid plotting invalid values
|
|
335
|
-
ax_pr.plot(
|
|
435
|
+
ax_pr.plot(
|
|
436
|
+
recall_vals, precision_vals, color="gray", linestyle=":", linewidth=1, alpha=0.6
|
|
437
|
+
)
|
|
336
438
|
x_val = 0.9
|
|
337
439
|
y_val = (f1 * x_val) / (2 * x_val - f1)
|
|
338
440
|
if 0 < y_val < 1:
|
|
339
|
-
ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color=
|
|
441
|
+
ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color="gray")
|
|
340
442
|
|
|
341
443
|
ax_pr.set_xlabel("Recall")
|
|
342
444
|
ax_pr.set_ylabel("Precision")
|
|
343
445
|
ax_pr.set_ylim(*ylim_pr)
|
|
344
446
|
ax_pr.set_title(f"PR Curve (Pos Freq: {pos_freq:.2%})")
|
|
345
447
|
ax_pr.legend()
|
|
346
|
-
ax_pr.spines[
|
|
347
|
-
ax_pr.spines[
|
|
448
|
+
ax_pr.spines["top"].set_visible(False)
|
|
449
|
+
ax_pr.spines["right"].set_visible(False)
|
|
348
450
|
|
|
349
451
|
plt.tight_layout(rect=[0, 0, 1, 0.97])
|
|
350
452
|
if save_path:
|
|
@@ -352,4 +454,4 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
352
454
|
out_file = os.path.join(save_path, "ROC_PR_grid.png")
|
|
353
455
|
plt.savefig(out_file, dpi=300)
|
|
354
456
|
print(f"📁 Saved: {out_file}")
|
|
355
|
-
plt.show()
|
|
457
|
+
plt.show()
|