smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +5 -1
- smftools/_version.py +1 -1
- smftools/informatics/__init__.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +29 -0
- smftools/informatics/basecall_pod5s.py +80 -0
- smftools/informatics/conversion_smf.py +63 -10
- smftools/informatics/direct_smf.py +66 -18
- smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
- smftools/informatics/helpers/__init__.py +16 -2
- smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
- smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
- smftools/informatics/helpers/bam_qc.py +66 -0
- smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
- smftools/informatics/helpers/canoncall.py +12 -3
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
- smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
- smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
- smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
- smftools/informatics/helpers/extract_base_identities.py +33 -46
- smftools/informatics/helpers/extract_mods.py +55 -23
- smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
- smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
- smftools/informatics/helpers/find_conversion_sites.py +33 -44
- smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
- smftools/informatics/helpers/modcall.py +13 -5
- smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
- smftools/informatics/helpers/ohe_batching.py +65 -41
- smftools/informatics/helpers/ohe_layers_decode.py +32 -0
- smftools/informatics/helpers/one_hot_decode.py +27 -0
- smftools/informatics/helpers/one_hot_encode.py +45 -9
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
- smftools/informatics/helpers/run_multiqc.py +28 -0
- smftools/informatics/helpers/split_and_index_BAM.py +3 -8
- smftools/informatics/load_adata.py +58 -3
- smftools/plotting/__init__.py +15 -0
- smftools/plotting/classifiers.py +355 -0
- smftools/plotting/general_plotting.py +205 -0
- smftools/plotting/position_stats.py +462 -0
- smftools/preprocessing/__init__.py +6 -7
- smftools/preprocessing/append_C_context.py +22 -9
- smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
- smftools/preprocessing/binarize_on_Youden.py +35 -32
- smftools/preprocessing/binary_layers_to_ohe.py +13 -3
- smftools/preprocessing/calculate_complexity.py +3 -2
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
- smftools/preprocessing/calculate_coverage.py +26 -25
- smftools/preprocessing/calculate_pairwise_differences.py +49 -0
- smftools/preprocessing/calculate_position_Youden.py +18 -7
- smftools/preprocessing/calculate_read_length_stats.py +39 -46
- smftools/preprocessing/clean_NaN.py +33 -25
- smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
- smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
- smftools/preprocessing/filter_reads_on_length.py +14 -4
- smftools/preprocessing/flag_duplicate_reads.py +149 -0
- smftools/preprocessing/invert_adata.py +18 -11
- smftools/preprocessing/load_sample_sheet.py +30 -16
- smftools/preprocessing/recipes.py +22 -20
- smftools/preprocessing/subsample_adata.py +58 -0
- smftools/readwrite.py +105 -13
- smftools/tools/__init__.py +49 -0
- smftools/tools/apply_hmm.py +202 -0
- smftools/tools/apply_hmm_batched.py +241 -0
- smftools/tools/archived/classify_methylated_features.py +66 -0
- smftools/tools/archived/classify_non_methylated_features.py +75 -0
- smftools/tools/archived/subset_adata_v1.py +32 -0
- smftools/tools/archived/subset_adata_v2.py +46 -0
- smftools/tools/calculate_distances.py +18 -0
- smftools/tools/calculate_umap.py +62 -0
- smftools/tools/call_hmm_peaks.py +105 -0
- smftools/tools/classifiers.py +787 -0
- smftools/tools/cluster_adata_on_methylation.py +105 -0
- smftools/tools/data/__init__.py +2 -0
- smftools/tools/data/anndata_data_module.py +90 -0
- smftools/tools/data/preprocessing.py +6 -0
- smftools/tools/display_hmm.py +18 -0
- smftools/tools/general_tools.py +69 -0
- smftools/tools/hmm_readwrite.py +16 -0
- smftools/tools/inference/__init__.py +1 -0
- smftools/tools/inference/lightning_inference.py +41 -0
- smftools/tools/models/__init__.py +9 -0
- smftools/tools/models/base.py +14 -0
- smftools/tools/models/cnn.py +34 -0
- smftools/tools/models/lightning_base.py +41 -0
- smftools/tools/models/mlp.py +17 -0
- smftools/tools/models/positional.py +17 -0
- smftools/tools/models/rnn.py +16 -0
- smftools/tools/models/sklearn_models.py +40 -0
- smftools/tools/models/transformer.py +133 -0
- smftools/tools/models/wrappers.py +20 -0
- smftools/tools/nucleosome_hmm_refinement.py +104 -0
- smftools/tools/position_stats.py +239 -0
- smftools/tools/read_stats.py +70 -0
- smftools/tools/subset_adata.py +19 -23
- smftools/tools/train_hmm.py +78 -0
- smftools/tools/training/__init__.py +1 -0
- smftools/tools/training/train_lightning_model.py +47 -0
- smftools/tools/utils/__init__.py +2 -0
- smftools/tools/utils/device.py +10 -0
- smftools/tools/utils/grl.py +14 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
- smftools-0.1.7.dist-info/RECORD +136 -0
- smftools/tools/apply_HMM.py +0 -1
- smftools/tools/read_HMM.py +0 -1
- smftools/tools/train_HMM.py +0 -43
- smftools-0.1.3.dist-info/RECORD +0 -84
- /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
- /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
- {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import torch
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
def plot_model_performance(metrics, save_path=None):
    """Plot ROC and precision-recall curves, then print confusion matrices, per reference.

    Parameters
    ----------
    metrics : dict
        Nested mapping ``{ref: {model_name: vals}}`` where ``vals`` holds
        'fpr', 'tpr', 'auc', 'recall', 'precision', 'f1' and
        'confusion_matrix'. ``model_name`` is expected to look like
        ``"<model_type>_<data_type>"`` (e.g. ``"rf_GpC_site"``).
    save_path : str, optional
        Directory in which to save one PNG per reference; created if missing.
    """
    import matplotlib.pyplot as plt
    import os
    for ref in metrics.keys():
        plt.figure(figsize=(12, 5))

        # BUG FIX: the titles below used the loop variable `data_type` after the
        # for-loops ended (loop-variable leakage). If metrics[ref] is empty that
        # raised a NameError; provide a safe fallback.
        data_type = ""

        # ROC Curve
        plt.subplot(1, 2, 1)
        for model_name, vals in metrics[ref].items():
            model_type = model_name.split('_')[0]
            data_type = model_name.split(f"{model_type}_")[1]
            plt.plot(vals['fpr'], vals['tpr'], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}")
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        # NOTE: data_type reflects the *last* model iterated; all models for a
        # ref are assumed to share one data type.
        plt.title(f'{data_type} ROC Curve ({ref})')
        plt.legend()

        # PR Curve
        plt.subplot(1, 2, 2)
        for model_name, vals in metrics[ref].items():
            model_type = model_name.split('_')[0]
            data_type = model_name.split(f"{model_type}_")[1]
            plt.plot(vals['recall'], vals['precision'], label=f"{model_type.upper()} - F1: {vals['f1']:.4f}")
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(f'{data_type} Precision-Recall Curve ({ref})')
        plt.legend()

        plt.tight_layout()

        if save_path:
            save_name = f"{ref}"
            os.makedirs(save_path, exist_ok=True)
            safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
            out_file = os.path.join(save_path, f"{safe_name}.png")
            plt.savefig(out_file, dpi=300)
            print(f"📁 Saved: {out_file}")
        plt.show()

        # Confusion Matrices
        for model_name, vals in metrics[ref].items():
            print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
            print(vals['confusion_matrix'])
            print()
|
|
51
|
+
|
|
52
|
+
def plot_feature_importances_or_saliency(
    models,
    positions,
    tensors,
    site_config,
    adata=None,
    layer_name=None,
    save_path=None,
    shaded_regions=None
):
    """Plot per-position RF feature importances or NN gradient saliency.

    For each reference in ``models``, models whose key ends with the resolved
    suffix are plotted: Random-Forest-style models (keys starting with "rf")
    use ``feature_importances_``; all other models are treated as torch
    modules and a gradient saliency is computed from the class-1 logit.
    Positions are marked as CpG / GpC / Other using ``adata.var`` columns
    ``{ref}_GpC_site`` / ``{ref}_CpG_site``.

    Parameters
    ----------
    models : dict
        ``{ref: {model_key: model}}``.
    positions : dict
        ``{ref: {suffix: index}}`` — index of var names (assumed to be
        integer-parseable genomic coordinates — TODO confirm).
    tensors : dict or None
        ``{ref: {suffix: torch.Tensor}}`` inputs for NN saliency; ignored for
        RF models.
    site_config : dict
        ``{ref: [site_type, ...]}`` used to build the suffix when
        ``layer_name`` is not given.
    adata : AnnData, optional
        Required to classify site types; without it no site markers are drawn.
    layer_name : str, optional
        Overrides the suffix derived from ``site_config``.
    save_path : str, optional
        Directory for PNG output; created if missing.
    shaded_regions : list[tuple], optional
        ``(start, end)`` x-spans to shade on each plot.
    """
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    import os

    # Select device for NN models (CUDA > Apple MPS > CPU).
    device = (
        torch.device('cuda') if torch.cuda.is_available() else
        torch.device('mps') if torch.backends.mps.is_available() else
        torch.device('cpu')
    )

    for ref, model_dict in models.items():
        # Suffix selects which models/positions/tensors belong together.
        if layer_name:
            suffix = layer_name
        else:
            suffix = "_".join(site_config[ref]) if ref in site_config else "full"

        if ref not in positions or suffix not in positions[ref]:
            print(f"Positions not found for {ref} with suffix {suffix}. Skipping {ref}.")
            continue

        coords_index = positions[ref][suffix]
        coords = coords_index.astype(int)

        # Classify positions using adata.var columns
        cpg_sites = set()
        gpc_sites = set()
        other_sites = set()

        if adata is None:
            print("⚠️ AnnData object is required to classify site types. Skipping site type markers.")
        else:
            gpc_col = f"{ref}_GpC_site"
            cpg_col = f"{ref}_CpG_site"
            for idx_str in coords_index:
                try:
                    # Missing columns default the flag to False rather than raising.
                    gpc = adata.var.at[idx_str, gpc_col] if gpc_col in adata.var.columns else False
                    cpg = adata.var.at[idx_str, cpg_col] if cpg_col in adata.var.columns else False
                    coord_int = int(idx_str)
                    if gpc and not cpg:
                        gpc_sites.add(coord_int)
                    elif cpg and not gpc:
                        cpg_sites.add(coord_int)
                    else:
                        # Both flags set, or neither: treated as "Other".
                        other_sites.add(coord_int)
                except KeyError:
                    print(f"⚠️ Index '{idx_str}' not found in adata.var. Skipping.")
                    continue

        for model_key, model in model_dict.items():
            # Only plot models trained on this suffix's feature set.
            if not model_key.endswith(suffix):
                continue

            if model_key.startswith("rf"):
                if hasattr(model, "feature_importances_"):
                    importances = model.feature_importances_
                else:
                    print(f"Random Forest model {model_key} has no feature_importances_. Skipping.")
                    continue
                plot_title = f"RF Feature Importances for {ref} ({model_key})"
                y_label = "Feature Importance"
            else:
                if tensors is None or ref not in tensors or suffix not in tensors[ref]:
                    print(f"No input data provided for NN saliency for {model_key}. Skipping.")
                    continue
                input_tensor = tensors[ref][suffix]
                model.eval()
                input_tensor = input_tensor.to(device)
                input_tensor.requires_grad_()

                # Saliency = mean over the batch of |d(class-1 logit)/d(input)|.
                # Assumes the model outputs (batch, >=2) logits — TODO confirm.
                with torch.enable_grad():
                    logits = model(input_tensor)
                    score = logits[:, 1].sum()
                    score.backward()
                saliency = input_tensor.grad.abs().mean(dim=0).cpu().numpy()
                importances = saliency
                plot_title = f"Feature Saliency for {ref} ({model_key})"
                y_label = "Feature Saliency"

            # Plot in genomic-coordinate order regardless of feature order.
            sorted_idx = np.argsort(coords)
            positions_sorted = coords[sorted_idx]
            importances_sorted = np.array(importances)[sorted_idx]

            plt.figure(figsize=(12, 4))
            for pos, imp in zip(positions_sorted, importances_sorted):
                # The legend-label dance adds each label only once per figure.
                if pos in cpg_sites:
                    plt.plot(pos, imp, marker='*', color='black', markersize=10, linestyle='None',
                             label='CpG site' if 'CpG site' not in plt.gca().get_legend_handles_labels()[1] else "")
                elif pos in gpc_sites:
                    plt.plot(pos, imp, marker='o', color='blue', markersize=6, linestyle='None',
                             label='GpC site' if 'GpC site' not in plt.gca().get_legend_handles_labels()[1] else "")
                else:
                    plt.plot(pos, imp, marker='.', color='gray', linestyle='None',
                             label='Other' if 'Other' not in plt.gca().get_legend_handles_labels()[1] else "")

            # Connect all points with a faint line for readability.
            plt.plot(positions_sorted, importances_sorted, linestyle='-', alpha=0.5, color='black')

            if shaded_regions:
                for (start, end) in shaded_regions:
                    plt.axvspan(start, end, color='gray', alpha=0.3)

            plt.xlabel("Genomic Position")
            plt.ylabel(y_label)
            plt.title(plot_title)
            plt.grid(True)
            plt.legend()
            plt.tight_layout()

            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = plot_title.replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
                out_file = os.path.join(save_path, f"{safe_name}.png")
                plt.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
|
|
179
|
+
|
|
180
|
+
def plot_model_curves_from_adata(
    adata,
    label_col='activity_status',
    model_names = ["cnn", "mlp", "rf"],
    suffix='GpC_site_CpG_site',
    omit_training=True,
    save_path=None,
    ylim_roc=(0.0, 1.05),
    ylim_pr=(0.0, 1.05)):
    """Plot ROC and precision-recall curves from per-read model probabilities.

    Probabilities are read from ``adata.obs["{model}_active_prob_{suffix}"]``
    columns; ground truth comes from ``label_col`` mapped
    'Active' -> 1, 'Silent' -> 0.

    Parameters
    ----------
    adata : AnnData
        Must carry the probability columns and, when ``omit_training`` is True,
        a boolean-castable ``used_for_training`` obs column.
    label_col : str
        Obs column with 'Active'/'Silent' labels.
    model_names : list[str]
        Model prefixes to look up; missing columns are silently skipped.
    suffix : str
        Feature-set suffix in the probability column names.
    omit_training : bool
        If True, restrict evaluation to reads not used for training.
    save_path : str, optional
        Directory for the output PNG; created if missing.
    ylim_roc, ylim_pr : tuple
        Y-axis limits for the two panels.
    """
    from sklearn.metrics import precision_recall_curve, roc_curve, auc
    import matplotlib.pyplot as plt
    import os

    if omit_training:
        subset = adata[adata.obs['used_for_training'].astype(bool) == False]
    else:
        # BUG FIX: `subset` was previously left undefined when
        # omit_training=False, raising a NameError below.
        subset = adata

    label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values

    # Class prevalence — drawn as the random baseline on the PR panel.
    positive_ratio = np.sum(label.astype(int)) / len(label)

    plt.figure(figsize=(12, 5))

    # ROC curve
    plt.subplot(1, 2, 1)
    for model in model_names:
        prob_col = f"{model}_active_prob_{suffix}"
        if prob_col in subset.obs.columns:
            probs = subset.obs[prob_col].astype(float).values
            fpr, tpr, _ = roc_curve(label, probs)
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")

    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.ylim(*ylim_roc)
    plt.legend()

    # PR curve
    plt.subplot(1, 2, 2)
    for model in model_names:
        prob_col = f"{model}_active_prob_{suffix}"
        if prob_col in subset.obs.columns:
            probs = subset.obs[prob_col].astype(float).values
            precision, recall, _ = precision_recall_curve(label, probs)
            pr_auc = auc(recall, precision)
            plt.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")

    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.ylim(*ylim_pr)
    plt.axhline(y=positive_ratio, linestyle='--', color='gray', label='Random Baseline')
    plt.title("Precision-Recall Curve")
    plt.legend()

    plt.tight_layout()
    if save_path:
        save_name = f"ROC_PR_curves"
        os.makedirs(save_path, exist_ok=True)
        safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
        out_file = os.path.join(save_path, f"{safe_name}.png")
        plt.savefig(out_file, dpi=300)
        print(f"📁 Saved: {out_file}")
    plt.show()
|
|
246
|
+
|
|
247
|
+
def plot_model_curves_from_adata_with_frequency_grid(
    adata,
    label_col='activity_status',
    model_names=["cnn", "mlp", "rf"],
    suffix='GpC_site_CpG_site',
    omit_training=True,
    save_path=None,
    ylim_roc=(0.0, 1.05),
    ylim_pr=(0.0, 1.05),
    pos_sample_count=500,
    pos_freq_list=[0.01, 0.05, 0.1],
    show_f1_iso_curves=False,
    f1_levels=None):
    """Plot ROC/PR curves at several synthetic positive-class frequencies.

    For each frequency in ``pos_freq_list``, a subsample is drawn with exactly
    ``pos_sample_count`` positives and enough negatives to reach that
    frequency, and one ROC/PR row is drawn in a grid.

    NOTE(review): sampling uses the global numpy RNG (np.random.choice without
    a seed), so repeated calls produce different curves.

    Parameters
    ----------
    adata : AnnData
        Must carry ``{model}_active_prob_{suffix}`` obs columns and, when
        ``omit_training`` is True, a ``used_for_training`` obs column.
    label_col : str
        Obs column mapped 'Active' -> 1, 'Silent' -> 0.
    model_names : list[str]
        Model prefixes to evaluate; missing columns are skipped.
    suffix : str
        Feature-set suffix in the probability column names.
    omit_training : bool
        If True, restrict to reads not used for training.
    save_path : str, optional
        Directory for the grid PNG; created if missing.
    ylim_roc, ylim_pr : tuple
        Y-axis limits for the panels.
    pos_sample_count : int
        Number of positives drawn per frequency row.
    pos_freq_list : list[float]
        Positive-class frequencies to simulate (one grid row each).
    show_f1_iso_curves : bool
        If True, overlay iso-F1 contours on the PR panels.
    f1_levels : array-like, optional
        F1 values for the iso-curves; defaults to 8 levels in [0.2, 0.9].
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import os
    from sklearn.metrics import precision_recall_curve, roc_curve, auc

    if f1_levels is None:
        f1_levels = np.linspace(0.2, 0.9, 8)

    if omit_training:
        subset = adata[adata.obs['used_for_training'].astype(bool) == False]
    else:
        subset = adata

    label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values
    # Copy before mutating obs so the caller's AnnData is untouched.
    subset = subset.copy()
    subset.obs["__label__"] = label

    pos_indices = np.where(label == 1)[0]
    neg_indices = np.where(label == 0)[0]

    n_rows = len(pos_freq_list)
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
    fig.suptitle(f'{suffix} Performance metrics')

    for row_idx, pos_freq in enumerate(pos_freq_list):
        # Total sample size that yields the target positive frequency.
        desired_total = int(pos_sample_count / pos_freq)
        neg_sample_count = desired_total - pos_sample_count

        if pos_sample_count > len(pos_indices) or neg_sample_count > len(neg_indices):
            print(f"⚠️ Skipping frequency {pos_freq:.3f}: not enough samples.")
            continue

        sampled_pos = np.random.choice(pos_indices, size=pos_sample_count, replace=False)
        sampled_neg = np.random.choice(neg_indices, size=neg_sample_count, replace=False)
        sampled_indices = np.concatenate([sampled_pos, sampled_neg])

        data_sampled = subset[sampled_indices]
        y_true = data_sampled.obs["__label__"].values

        # plt.subplots returns a 1-D axes array when n_rows == 1.
        ax_roc = axes[row_idx, 0] if n_rows > 1 else axes[0]
        ax_pr = axes[row_idx, 1] if n_rows > 1 else axes[1]

        # ROC Curve
        for model in model_names:
            prob_col = f"{model}_active_prob_{suffix}"
            if prob_col in data_sampled.obs.columns:
                probs = data_sampled.obs[prob_col].astype(float).values
                fpr, tpr, _ = roc_curve(y_true, probs)
                roc_auc = auc(fpr, tpr)
                ax_roc.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
        ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        ax_roc.set_xlabel("False Positive Rate")
        ax_roc.set_ylabel("True Positive Rate")
        ax_roc.set_ylim(*ylim_roc)
        ax_roc.set_title(f"ROC Curve (Pos Freq: {pos_freq:.2%})")
        ax_roc.legend()
        ax_roc.spines['top'].set_visible(False)
        ax_roc.spines['right'].set_visible(False)

        # PR Curve
        for model in model_names:
            prob_col = f"{model}_active_prob_{suffix}"
            if prob_col in data_sampled.obs.columns:
                probs = data_sampled.obs[prob_col].astype(float).values
                precision, recall, _ = precision_recall_curve(y_true, probs)
                pr_auc = auc(recall, precision)
                ax_pr.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")
        # Random-classifier baseline for PR is the positive frequency.
        ax_pr.axhline(y=pos_freq, linestyle='--', color='gray', label='Random Baseline')

        if show_f1_iso_curves:
            recall_vals = np.linspace(0.01, 1, 500)
            for f1 in f1_levels:
                # Iso-F1 contour: P = F1*R / (2R - F1).
                precision_vals = (f1 * recall_vals) / (2 * recall_vals - f1)
                precision_vals[precision_vals < 0] = np.nan  # Avoid plotting invalid values
                ax_pr.plot(recall_vals, precision_vals, color='gray', linestyle=':', linewidth=1, alpha=0.6)
                # Annotate each contour near recall = 0.9 when it fits in-axes.
                x_val = 0.9
                y_val = (f1 * x_val) / (2 * x_val - f1)
                if 0 < y_val < 1:
                    ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color='gray')

        ax_pr.set_xlabel("Recall")
        ax_pr.set_ylabel("Precision")
        ax_pr.set_ylim(*ylim_pr)
        ax_pr.set_title(f"PR Curve (Pos Freq: {pos_freq:.2%})")
        ax_pr.legend()
        ax_pr.spines['top'].set_visible(False)
        ax_pr.spines['right'].set_visible(False)

    # Leave headroom for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        out_file = os.path.join(save_path, "ROC_PR_grid.png")
        plt.savefig(out_file, dpi=300)
        print(f"📁 Saved: {out_file}")
    plt.show()
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import seaborn as sns
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
|
|
5
|
+
def clean_barplot(ax, mean_values, title):
    """Render a minimalist bar track of per-position means on *ax*.

    Bars are gray, edge-aligned, full-width, and drawn against a fixed
    [0, 1] y-range; only the left spine and y-axis remain visible so the
    track sits cleanly above a heatmap sharing the same x-axis.
    """
    n_positions = len(mean_values)
    ax.bar(np.arange(n_positions), mean_values, color="gray", width=1.0, align='edge')

    ax.set_xlim(0, n_positions)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_ylabel("Mean")
    ax.set_title(title, fontsize=12, pad=2)

    # Keep only the left spine; hide right/top/bottom.
    for name, spine in ax.spines.items():
        spine.set_visible(name == 'left')

    # Drop x ticks and labels entirely — the companion heatmap carries them.
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def combined_hmm_raw_clustermap(
    adata,
    sample_col='Sample_Names',
    hmm_feature_layer="hmm_combined",
    layer_gpc="nan0_0minus1",
    layer_cpg="nan0_0minus1",
    cmap_hmm="tab10",
    cmap_gpc="coolwarm",
    cmap_cpg="viridis",
    min_quality=20,
    min_length=2700,
    sample_mapping=None,
    save_path=None,
    normalize_hmm=False,
    sort_by="gpc",  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
    bins=None
):
    """Per (sample, reference) pair, draw side-by-side heatmaps of HMM features,
    GpC methylation, and CpG methylation, each topped by a mean-signal bar track.

    Reads are filtered on quality, length, and raw methylation signal, grouped
    into read bins, within-bin ordered by hierarchical clustering (or an obs
    column), and stacked. Results matrices/labels are collected into
    ``adata.uns['clustermap_results']``.

    Parameters
    ----------
    adata : AnnData
        Requires categorical obs columns 'Reference_strand' and ``sample_col``,
        obs columns 'query_read_quality', 'read_length',
        'Raw_methylation_signal', var columns 'position_in_{ref}',
        '{ref}_GpC_site', '{ref}_CpG_site', and the named layers.
    sample_col : str
        Categorical obs column identifying samples.
    hmm_feature_layer, layer_gpc, layer_cpg : str
        Layer names for the three heatmaps.
    cmap_hmm, cmap_gpc, cmap_cpg : str
        Colormaps for the three heatmaps.
    min_quality, min_length : numeric
        Read filters (Raw_methylation_signal >= 20 is additionally hard-coded).
    sample_mapping : optional
        Currently unused; kept for interface compatibility.
    save_path : str, optional
        Directory for per-(ref, sample) PNGs; created if missing.
    normalize_hmm : bool
        Min-max normalize the mean HMM track instead of the raw mean.
    sort_by : str
        Within-bin read ordering strategy.
    bins : dict, optional
        ``{label: boolean mask over reads}``; defaults to a single "All" bin.

    Notes
    -----
    Exceptions per (sample, ref) pair are printed and skipped so one bad
    combination does not abort the whole sweep.
    """
    import scipy.cluster.hierarchy as sch
    import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import os

    results = []

    for ref in adata.obs["Reference_strand"].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            try:
                subset = adata[
                    (adata.obs['Reference_strand'] == ref) &
                    (adata.obs[sample_col] == sample) &
                    (adata.obs['query_read_quality'] >= min_quality) &
                    (adata.obs['read_length'] >= min_length) &
                    (adata.obs['Raw_methylation_signal'] >= 20)
                ]

                # Restrict columns to positions belonging to this reference.
                subset = subset[:, subset.var[f'position_in_{ref}'] == True]

                if subset.shape[0] == 0:
                    print(f" ❌ No reads left after filtering for {sample} - {ref}")
                    continue

                # BUG FIX: the original rebound the `bins` parameter here with a
                # mask built from the first surviving subset, so every later
                # (sample, ref) pair reused a stale mask with mismatched
                # indices/length. Build the default "All" bin fresh per subset
                # without touching the parameter.
                if bins:
                    subset_bins = bins
                else:
                    subset_bins = {"All": np.ones(subset.shape[0], dtype=bool)}

                # Get column positions (not var_names!) of site masks
                gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
                cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]

                # Use var_names for x-axis tick labels
                gpc_labels = subset.var_names[gpc_sites].astype(int)
                cpg_labels = subset.var_names[cpg_sites].astype(int)

                stacked_hmm_feature, stacked_gpc, stacked_cpg = [], [], []
                row_labels, bin_labels = [], []
                bin_boundaries = []

                total_reads = subset.shape[0]
                percentages = {}
                last_idx = 0

                for bin_label, bin_filter in subset_bins.items():
                    subset_bin = subset[bin_filter].copy()
                    num_reads = subset_bin.shape[0]
                    percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
                    percentages[bin_label] = percent_reads

                    if num_reads > 0:
                        # Determine sorting order
                        if sort_by.startswith("obs:"):
                            colname = sort_by.split("obs:")[1]
                            order = np.argsort(subset_bin.obs[colname].values)
                        elif sort_by == "gpc":
                            linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "cpg":
                            linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "gpc_cpg":
                            linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "none":
                            order = np.arange(num_reads)
                        else:
                            raise ValueError(f"Unsupported sort_by option: {sort_by}")

                        stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
                        stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
                        stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])

                        row_labels.extend([bin_label] * num_reads)
                        bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
                        last_idx += num_reads
                        bin_boundaries.append(last_idx)

                if stacked_hmm_feature:
                    hmm_matrix = np.vstack(stacked_hmm_feature)
                    gpc_matrix = np.vstack(stacked_gpc)
                    cpg_matrix = np.vstack(stacked_cpg)

                    def normalized_mean(matrix):
                        # Min-max scale the column means into [0, 1].
                        mean = np.nanmean(matrix, axis=0)
                        normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
                        return normalized

                    def methylation_fraction(matrix):
                        # Fraction methylated (==1) among informative (!=0) calls.
                        methylated = (matrix == 1).sum(axis=0)
                        valid = (matrix != 0).sum(axis=0)
                        return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)

                    if normalize_hmm:
                        mean_hmm = normalized_mean(hmm_matrix)
                    else:
                        mean_hmm = np.nanmean(hmm_matrix, axis=0)
                    mean_gpc = methylation_fraction(gpc_matrix)
                    mean_cpg = methylation_fraction(cpg_matrix)

                    fig = plt.figure(figsize=(18, 12))
                    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
                    fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)

                    axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
                    axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]

                    clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
                    clean_barplot(axes_bar[1], mean_gpc, f"GpC Methylation")
                    clean_barplot(axes_bar[2], mean_cpg, f"CpG Methylation")

                    hmm_labels = subset.var_names.astype(int)
                    hmm_label_spacing = 150
                    sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
                    axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
                    axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
                    # Horizontal separators between read bins.
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[0].axhline(y=boundary, color="black", linewidth=2)

                    sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
                    axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
                    axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[1].axhline(y=boundary, color="black", linewidth=2)

                    sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
                    axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[2].axhline(y=boundary, color="black", linewidth=2)

                    plt.tight_layout()

                    if save_path:
                        save_name = f"{ref} — {sample}"
                        os.makedirs(save_path, exist_ok=True)
                        safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
                        out_file = os.path.join(save_path, f"{safe_name}.png")
                        plt.savefig(out_file, dpi=300)
                        print(f"📁 Saved: {out_file}")

                    plt.show()

                    print(f"📊 Summary for {sample} - {ref}:")
                    for bin_label, percent in percentages.items():
                        print(f" - {bin_label}: {percent:.1f}%")

                    results.append({
                        "sample": sample,
                        "ref": ref,
                        "hmm_matrix": hmm_matrix,
                        "gpc_matrix": gpc_matrix,
                        "cpg_matrix": cpg_matrix,
                        "row_labels": row_labels,
                        "bin_labels": bin_labels,
                        "bin_boundaries": bin_boundaries,
                        "percentages": percentages
                    })

                    adata.uns['clustermap_results'] = results

            except Exception as e:
                import traceback
                traceback.print_exc()
                continue
|