smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109):
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,355 @@
1
+
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ import os
6
+
7
def plot_model_performance(metrics, save_path=None):
    """Plot ROC and PR curves for every model, one figure per reference.

    Parameters
    ----------
    metrics : dict
        Nested mapping ``{ref: {model_name: vals}}`` where ``vals`` holds
        'fpr', 'tpr', 'auc', 'recall', 'precision', 'f1' and
        'confusion_matrix'. Model names are expected to look like
        ``"<model_type>_<data_type>"`` (e.g. ``"cnn_GpC_site"``).
    save_path : str, optional
        Directory to write one PNG per reference; created if missing.
    """
    import matplotlib.pyplot as plt
    import os

    for ref in metrics.keys():
        plt.figure(figsize=(12, 5))

        # Bug fix: `data_type` used to be referenced after the loop without a
        # default, raising NameError when metrics[ref] was empty (and
        # IndexError for names without an underscore). partition() is the
        # robust equivalent of the old split-twice dance.
        data_type = ""

        # ROC Curve
        plt.subplot(1, 2, 1)
        for model_name, vals in metrics[ref].items():
            model_type, _, data_type = model_name.partition('_')
            plt.plot(vals['fpr'], vals['tpr'], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}")
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        # NOTE: title shows the data type of the last model plotted, matching
        # the historical behavior (all models typically share one data type).
        plt.title(f'{data_type} ROC Curve ({ref})')
        plt.legend()

        # PR Curve
        plt.subplot(1, 2, 2)
        for model_name, vals in metrics[ref].items():
            model_type, _, data_type = model_name.partition('_')
            plt.plot(vals['recall'], vals['precision'], label=f"{model_type.upper()} - F1: {vals['f1']:.4f}")
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(f'{data_type} Precision-Recall Curve ({ref})')
        plt.legend()

        plt.tight_layout()

        if save_path:
            save_name = f"{ref}"
            os.makedirs(save_path, exist_ok=True)
            # Strip characters that are awkward in filenames.
            safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
            out_file = os.path.join(save_path, f"{safe_name}.png")
            plt.savefig(out_file, dpi=300)
            print(f"📁 Saved: {out_file}")
        plt.show()

        # Confusion Matrices (printed, not plotted)
        for model_name, vals in metrics[ref].items():
            print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
            print(vals['confusion_matrix'])
            print()
52
def plot_feature_importances_or_saliency(
    models,
    positions,
    tensors,
    site_config,
    adata=None,
    layer_name=None,
    save_path=None,
    shaded_regions=None
):
    """Plot per-position RF feature importances or NN input saliency.

    For every reference in ``models`` and every model whose key ends with the
    active suffix, draw one importance/saliency value per genomic position.
    Models keyed ``"rf*"`` use ``feature_importances_``; all other models are
    treated as torch modules and scored by the mean absolute gradient of the
    positive-class logit with respect to the input.

    Parameters
    ----------
    models : dict
        Nested mapping ``{ref: {model_key: model}}``.
    positions : dict
        Nested mapping ``{ref: {suffix: index}}`` of int-castable position
        labels (assumes an array-like with ``.astype`` — TODO confirm).
    tensors : dict or None
        Nested mapping ``{ref: {suffix: torch.Tensor}}`` of NN inputs.
    site_config : dict
        ``{ref: [site_type, ...]}`` used to build the suffix when
        ``layer_name`` is not given.
    adata : AnnData, optional
        Required to mark GpC/CpG sites via ``adata.var`` columns.
    layer_name : str, optional
        Overrides the suffix derived from ``site_config``.
    save_path : str, optional
        Directory for one PNG per model; created if missing.
    shaded_regions : iterable of (start, end), optional
        x-spans to shade on each plot.
    """
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    import os

    # Select device for NN models
    device = (
        torch.device('cuda') if torch.cuda.is_available() else
        torch.device('mps') if torch.backends.mps.is_available() else
        torch.device('cpu')
    )

    for ref, model_dict in models.items():
        if layer_name:
            suffix = layer_name
        else:
            suffix = "_".join(site_config[ref]) if ref in site_config else "full"

        if ref not in positions or suffix not in positions[ref]:
            print(f"Positions not found for {ref} with suffix {suffix}. Skipping {ref}.")
            continue

        coords_index = positions[ref][suffix]
        coords = coords_index.astype(int)

        # Classify positions as GpC-only / CpG-only / other using adata.var
        cpg_sites = set()
        gpc_sites = set()
        other_sites = set()

        if adata is None:
            print("⚠️ AnnData object is required to classify site types. Skipping site type markers.")
        else:
            gpc_col = f"{ref}_GpC_site"
            cpg_col = f"{ref}_CpG_site"
            for idx_str in coords_index:
                try:
                    gpc = adata.var.at[idx_str, gpc_col] if gpc_col in adata.var.columns else False
                    cpg = adata.var.at[idx_str, cpg_col] if cpg_col in adata.var.columns else False
                    coord_int = int(idx_str)
                    if gpc and not cpg:
                        gpc_sites.add(coord_int)
                    elif cpg and not gpc:
                        cpg_sites.add(coord_int)
                    else:
                        other_sites.add(coord_int)
                except KeyError:
                    print(f"⚠️ Index '{idx_str}' not found in adata.var. Skipping.")
                    continue

        for model_key, model in model_dict.items():
            if not model_key.endswith(suffix):
                continue

            if model_key.startswith("rf"):
                if hasattr(model, "feature_importances_"):
                    importances = model.feature_importances_
                else:
                    print(f"Random Forest model {model_key} has no feature_importances_. Skipping.")
                    continue
                plot_title = f"RF Feature Importances for {ref} ({model_key})"
                y_label = "Feature Importance"
            else:
                if tensors is None or ref not in tensors or suffix not in tensors[ref]:
                    print(f"No input data provided for NN saliency for {model_key}. Skipping.")
                    continue
                # Bug fix: ``.to(device)`` returns the SAME tensor when it is
                # already on the target device, so reusing it across models in
                # this loop accumulated ``.grad`` from earlier backward()
                # calls, polluting every model's saliency after the first.
                # A detached clone gives each model a fresh leaf tensor.
                input_tensor = tensors[ref][suffix].detach().clone().to(device)
                model.eval()
                input_tensor.requires_grad_()

                with torch.enable_grad():
                    logits = model(input_tensor)
                    # Positive-class logit summed over the batch; its input
                    # gradient is the saliency signal.
                    score = logits[:, 1].sum()
                    score.backward()
                saliency = input_tensor.grad.abs().mean(dim=0).cpu().numpy()
                importances = saliency
                plot_title = f"Feature Saliency for {ref} ({model_key})"
                y_label = "Feature Saliency"

            # Plot in genomic-coordinate order.
            sorted_idx = np.argsort(coords)
            positions_sorted = coords[sorted_idx]
            importances_sorted = np.array(importances)[sorted_idx]

            plt.figure(figsize=(12, 4))
            for pos, imp in zip(positions_sorted, importances_sorted):
                # Only attach each legend label once (empty label thereafter).
                if pos in cpg_sites:
                    plt.plot(pos, imp, marker='*', color='black', markersize=10, linestyle='None',
                             label='CpG site' if 'CpG site' not in plt.gca().get_legend_handles_labels()[1] else "")
                elif pos in gpc_sites:
                    plt.plot(pos, imp, marker='o', color='blue', markersize=6, linestyle='None',
                             label='GpC site' if 'GpC site' not in plt.gca().get_legend_handles_labels()[1] else "")
                else:
                    plt.plot(pos, imp, marker='.', color='gray', linestyle='None',
                             label='Other' if 'Other' not in plt.gca().get_legend_handles_labels()[1] else "")

            # Connecting line under the markers.
            plt.plot(positions_sorted, importances_sorted, linestyle='-', alpha=0.5, color='black')

            if shaded_regions:
                for (start, end) in shaded_regions:
                    plt.axvspan(start, end, color='gray', alpha=0.3)

            plt.xlabel("Genomic Position")
            plt.ylabel(y_label)
            plt.title(plot_title)
            plt.grid(True)
            plt.legend()
            plt.tight_layout()

            if save_path:
                os.makedirs(save_path, exist_ok=True)
                safe_name = plot_title.replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
                out_file = os.path.join(save_path, f"{safe_name}.png")
                plt.savefig(out_file, dpi=300)
                print(f"📁 Saved: {out_file}")

            plt.show()
180
def plot_model_curves_from_adata(
    adata,
    label_col='activity_status',
    model_names=None,
    suffix='GpC_site_CpG_site',
    omit_training=True,
    save_path=None,
    ylim_roc=(0.0, 1.05),
    ylim_pr=(0.0, 1.05)):
    """Plot ROC and PR curves from classifier probabilities stored in adata.obs.

    Probabilities are read from ``adata.obs[f"{model}_active_prob_{suffix}"]``
    for each model; labels come from ``adata.obs[label_col]`` mapped as
    Active→1 / Silent→0.

    Parameters
    ----------
    adata : AnnData
        Must carry the probability columns and, when ``omit_training`` is
        True, a boolean-castable ``used_for_training`` column.
    model_names : list of str, optional
        Defaults to ``["cnn", "mlp", "rf"]``; missing columns are skipped.
    omit_training : bool
        If True, restrict evaluation to reads not used for training.
    save_path : str, optional
        Directory to write ``ROC_PR_curves.png``; created if missing.
    """
    import os
    import numpy as np
    from sklearn.metrics import precision_recall_curve, roc_curve, auc
    import matplotlib.pyplot as plt

    # Mutable-default fix: keep the historical default without a shared list.
    if model_names is None:
        model_names = ["cnn", "mlp", "rf"]

    # Bug fix: `subset` used to be assigned only inside `if omit_training:`,
    # so omit_training=False raised NameError. Mirror the sibling
    # frequency-grid function's else branch.
    if omit_training:
        subset = adata[adata.obs['used_for_training'].astype(bool) == False]
    else:
        subset = adata

    label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values

    # Baseline precision for a random classifier on this label balance.
    positive_ratio = np.sum(label.astype(int)) / len(label)

    plt.figure(figsize=(12, 5))

    # ROC curve
    plt.subplot(1, 2, 1)
    for model in model_names:
        prob_col = f"{model}_active_prob_{suffix}"
        if prob_col in subset.obs.columns:
            probs = subset.obs[prob_col].astype(float).values
            fpr, tpr, _ = roc_curve(label, probs)
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")

    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.ylim(*ylim_roc)
    plt.legend()

    # PR curve
    plt.subplot(1, 2, 2)
    for model in model_names:
        prob_col = f"{model}_active_prob_{suffix}"
        if prob_col in subset.obs.columns:
            probs = subset.obs[prob_col].astype(float).values
            precision, recall, _ = precision_recall_curve(label, probs)
            pr_auc = auc(recall, precision)
            plt.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")

    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.ylim(*ylim_pr)
    plt.axhline(y=positive_ratio, linestyle='--', color='gray', label='Random Baseline')
    plt.title("Precision-Recall Curve")
    plt.legend()

    plt.tight_layout()
    if save_path:
        save_name = f"ROC_PR_curves"
        os.makedirs(save_path, exist_ok=True)
        safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
        out_file = os.path.join(save_path, f"{safe_name}.png")
        plt.savefig(out_file, dpi=300)
        print(f"📁 Saved: {out_file}")
    plt.show()
247
def plot_model_curves_from_adata_with_frequency_grid(
    adata,
    label_col='activity_status',
    model_names=None,
    suffix='GpC_site_CpG_site',
    omit_training=True,
    save_path=None,
    ylim_roc=(0.0, 1.05),
    ylim_pr=(0.0, 1.05),
    pos_sample_count=500,
    pos_freq_list=None,
    show_f1_iso_curves=False,
    f1_levels=None):
    """Plot ROC/PR curves at several simulated positive-class frequencies.

    For each frequency in ``pos_freq_list``, subsample ``pos_sample_count``
    positives plus enough negatives to reach that frequency, then draw one
    ROC/PR row per frequency in a grid. Probabilities come from
    ``adata.obs[f"{model}_active_prob_{suffix}"]``; labels from
    ``adata.obs[label_col]`` mapped Active→1 / Silent→0.

    Notes
    -----
    Subsampling uses ``np.random.choice`` without a fixed seed, so curves
    vary between calls unless the caller seeds NumPy.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    from sklearn.metrics import precision_recall_curve, roc_curve, auc

    # Mutable-default fix: preserve the historical defaults without sharing
    # list objects across calls.
    if model_names is None:
        model_names = ["cnn", "mlp", "rf"]
    if pos_freq_list is None:
        pos_freq_list = [0.01, 0.05, 0.1]
    if f1_levels is None:
        f1_levels = np.linspace(0.2, 0.9, 8)

    if omit_training:
        subset = adata[adata.obs['used_for_training'].astype(bool) == False]
    else:
        subset = adata

    label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values
    # Copy before mutating obs so the caller's view is untouched.
    subset = subset.copy()
    subset.obs["__label__"] = label

    pos_indices = np.where(label == 1)[0]
    neg_indices = np.where(label == 0)[0]

    n_rows = len(pos_freq_list)
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
    fig.suptitle(f'{suffix} Performance metrics')

    for row_idx, pos_freq in enumerate(pos_freq_list):
        # Total sample size that makes the positives hit pos_freq.
        desired_total = int(pos_sample_count / pos_freq)
        neg_sample_count = desired_total - pos_sample_count

        if pos_sample_count > len(pos_indices) or neg_sample_count > len(neg_indices):
            print(f"⚠️ Skipping frequency {pos_freq:.3f}: not enough samples.")
            continue

        sampled_pos = np.random.choice(pos_indices, size=pos_sample_count, replace=False)
        sampled_neg = np.random.choice(neg_indices, size=neg_sample_count, replace=False)
        sampled_indices = np.concatenate([sampled_pos, sampled_neg])

        data_sampled = subset[sampled_indices]
        y_true = data_sampled.obs["__label__"].values

        # With a single row, plt.subplots returns a 1-D axes array.
        ax_roc = axes[row_idx, 0] if n_rows > 1 else axes[0]
        ax_pr = axes[row_idx, 1] if n_rows > 1 else axes[1]

        # ROC Curve
        for model in model_names:
            prob_col = f"{model}_active_prob_{suffix}"
            if prob_col in data_sampled.obs.columns:
                probs = data_sampled.obs[prob_col].astype(float).values
                fpr, tpr, _ = roc_curve(y_true, probs)
                roc_auc = auc(fpr, tpr)
                ax_roc.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
        ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        ax_roc.set_xlabel("False Positive Rate")
        ax_roc.set_ylabel("True Positive Rate")
        ax_roc.set_ylim(*ylim_roc)
        ax_roc.set_title(f"ROC Curve (Pos Freq: {pos_freq:.2%})")
        ax_roc.legend()
        ax_roc.spines['top'].set_visible(False)
        ax_roc.spines['right'].set_visible(False)

        # PR Curve
        for model in model_names:
            prob_col = f"{model}_active_prob_{suffix}"
            if prob_col in data_sampled.obs.columns:
                probs = data_sampled.obs[prob_col].astype(float).values
                precision, recall, _ = precision_recall_curve(y_true, probs)
                pr_auc = auc(recall, precision)
                ax_pr.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")
        ax_pr.axhline(y=pos_freq, linestyle='--', color='gray', label='Random Baseline')

        if show_f1_iso_curves:
            # Iso-F1 contours: precision = F1*recall / (2*recall - F1).
            recall_vals = np.linspace(0.01, 1, 500)
            for f1 in f1_levels:
                precision_vals = (f1 * recall_vals) / (2 * recall_vals - f1)
                precision_vals[precision_vals < 0] = np.nan  # Avoid plotting invalid values
                ax_pr.plot(recall_vals, precision_vals, color='gray', linestyle=':', linewidth=1, alpha=0.6)
                x_val = 0.9
                y_val = (f1 * x_val) / (2 * x_val - f1)
                if 0 < y_val < 1:
                    ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color='gray')

        ax_pr.set_xlabel("Recall")
        ax_pr.set_ylabel("Precision")
        ax_pr.set_ylim(*ylim_pr)
        ax_pr.set_title(f"PR Curve (Pos Freq: {pos_freq:.2%})")
        ax_pr.legend()
        ax_pr.spines['top'].set_visible(False)
        ax_pr.spines['right'].set_visible(False)

    # Leave headroom for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        out_file = os.path.join(save_path, "ROC_PR_grid.png")
        plt.savefig(out_file, dpi=300)
        print(f"📁 Saved: {out_file}")
    plt.show()
@@ -0,0 +1,205 @@
1
+ import numpy as np
2
+ import seaborn as sns
3
+ import matplotlib.pyplot as plt
4
+
5
def clean_barplot(ax, mean_values, title):
    """Render a minimal mean-signal bar track on *ax*.

    Bars fill the x-range edge-to-edge, the y-axis is pinned to [0, 1] with
    ticks at 0/0.5/1, and every spine except the left one is hidden so the
    track reads like a sparkline above a heatmap column.
    """
    n_positions = len(mean_values)
    ax.bar(np.arange(n_positions), mean_values, color="gray", width=1.0, align='edge')

    # Fixed axes: x spans all positions, y is a fraction in [0, 1].
    ax.set_xlim(0, n_positions)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_ylabel("Mean")
    ax.set_title(title, fontsize=12, pad=2)

    # Keep only the left spine visible.
    for spine_name in ax.spines:
        ax.spines[spine_name].set_visible(spine_name == 'left')

    # No x ticks or labels — the heatmap below carries the position axis.
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
21
def combined_hmm_raw_clustermap(
    adata,
    sample_col='Sample_Names',
    hmm_feature_layer="hmm_combined",
    layer_gpc="nan0_0minus1",
    layer_cpg="nan0_0minus1",
    cmap_hmm="tab10",
    cmap_gpc="coolwarm",
    cmap_cpg="viridis",
    min_quality=20,
    min_length=2700,
    sample_mapping=None,
    save_path=None,
    normalize_hmm=False,
    sort_by="gpc",  # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
    bins=None
):
    """Plot side-by-side HMM / GpC / CpG heatmaps per sample and reference.

    For every (Reference_strand, sample) pair, reads are filtered on quality,
    length and raw methylation signal, optionally partitioned into ``bins``
    (a mapping of label -> boolean read mask), sorted within each bin
    according to ``sort_by``, and rendered as three heatmaps with mean-signal
    bar tracks on top. Matrices and bin metadata are appended to
    ``adata.uns['clustermap_results']``.

    Parameters of note
    ------------------
    bins : dict, optional
        ``{label: boolean mask over reads}``. When omitted, a single "All"
        bin covering the current subset is built per subset.
    sort_by : str
        Read-ordering strategy; ``'obs:<column>'`` sorts by an obs column.
    sample_mapping : unused
        Kept for interface compatibility with existing callers.

    Failures for one (sample, ref) pair are printed and skipped so the
    remaining pairs still render (best-effort batch plotting).
    """
    import scipy.cluster.hierarchy as sch
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import os

    results = []

    for ref in adata.obs["Reference_strand"].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            try:
                subset = adata[
                    (adata.obs['Reference_strand'] == ref) &
                    (adata.obs[sample_col] == sample) &
                    (adata.obs['query_read_quality'] >= min_quality) &
                    (adata.obs['read_length'] >= min_length) &
                    (adata.obs['Raw_methylation_signal'] >= 20)
                ]

                # Restrict to positions annotated for this reference.
                subset = subset[:, subset.var[f'position_in_{ref}'] == True]

                if subset.shape[0] == 0:
                    print(f" ❌ No reads left after filtering for {sample} - {ref}")
                    continue

                # Bug fix: the default bin mask used to be written back into
                # the `bins` parameter, so the mask built from the FIRST
                # subset was reused for every later sample/reference (wrong
                # length, wrong reads). Build the default into a per-subset
                # local instead.
                if bins:
                    bins_for_subset = bins
                else:
                    bins_for_subset = {"All": (subset.obs['Reference_strand'] != None)}

                # Get column positions (not var_names!) of site masks
                gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
                cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]

                # Use var_names for x-axis tick labels
                gpc_labels = subset.var_names[gpc_sites].astype(int)
                cpg_labels = subset.var_names[cpg_sites].astype(int)

                stacked_hmm_feature, stacked_gpc, stacked_cpg = [], [], []
                row_labels, bin_labels = [], []
                bin_boundaries = []

                total_reads = subset.shape[0]
                percentages = {}
                last_idx = 0

                for bin_label, bin_filter in bins_for_subset.items():
                    subset_bin = subset[bin_filter].copy()
                    num_reads = subset_bin.shape[0]
                    percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
                    percentages[bin_label] = percent_reads

                    if num_reads > 0:
                        # Determine sorting order
                        if sort_by.startswith("obs:"):
                            colname = sort_by.split("obs:")[1]
                            order = np.argsort(subset_bin.obs[colname].values)
                        elif sort_by == "gpc":
                            linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "cpg":
                            linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "gpc_cpg":
                            linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
                            order = sch.leaves_list(linkage)
                        elif sort_by == "none":
                            order = np.arange(num_reads)
                        else:
                            raise ValueError(f"Unsupported sort_by option: {sort_by}")

                        stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
                        stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
                        stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])

                        row_labels.extend([bin_label] * num_reads)
                        bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
                        last_idx += num_reads
                        bin_boundaries.append(last_idx)

                if stacked_hmm_feature:
                    hmm_matrix = np.vstack(stacked_hmm_feature)
                    gpc_matrix = np.vstack(stacked_gpc)
                    cpg_matrix = np.vstack(stacked_cpg)

                    def normalized_mean(matrix):
                        # Min-max scale the column means into [0, 1].
                        mean = np.nanmean(matrix, axis=0)
                        normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
                        return normalized

                    def methylation_fraction(matrix):
                        # Fraction methylated (==1) among non-zero (called) entries.
                        methylated = (matrix == 1).sum(axis=0)
                        valid = (matrix != 0).sum(axis=0)
                        return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)

                    if normalize_hmm:
                        mean_hmm = normalized_mean(hmm_matrix)
                    else:
                        mean_hmm = np.nanmean(hmm_matrix, axis=0)
                    mean_gpc = methylation_fraction(gpc_matrix)
                    mean_cpg = methylation_fraction(cpg_matrix)

                    fig = plt.figure(figsize=(18, 12))
                    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 6], hspace=0.01)
                    fig.suptitle(f"{sample} - {ref}", fontsize=14, y=0.95)

                    axes_heat = [fig.add_subplot(gs[1, i]) for i in range(3)]
                    axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(3)]

                    clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
                    clean_barplot(axes_bar[1], mean_gpc, f"GpC Methylation")
                    clean_barplot(axes_bar[2], mean_cpg, f"CpG Methylation")

                    # HMM heatmap: sparse tick labels (every 150th position).
                    hmm_labels = subset.var_names.astype(int)
                    hmm_label_spacing = 150
                    sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
                    axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
                    axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[0].axhline(y=boundary, color="black", linewidth=2)

                    sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
                    axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
                    axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[1].axhline(y=boundary, color="black", linewidth=2)

                    sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
                    axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[2].axhline(y=boundary, color="black", linewidth=2)

                    plt.tight_layout()

                    if save_path:
                        save_name = f"{ref} — {sample}"
                        os.makedirs(save_path, exist_ok=True)
                        safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
                        out_file = os.path.join(save_path, f"{safe_name}.png")
                        plt.savefig(out_file, dpi=300)
                        print(f"📁 Saved: {out_file}")

                    plt.show()

                    print(f"📊 Summary for {sample} - {ref}:")
                    for bin_label, percent in percentages.items():
                        print(f" - {bin_label}: {percent:.1f}%")

                    results.append({
                        "sample": sample,
                        "ref": ref,
                        "hmm_matrix": hmm_matrix,
                        "gpc_matrix": gpc_matrix,
                        "cpg_matrix": cpg_matrix,
                        "row_labels": row_labels,
                        "bin_labels": bin_labels,
                        "bin_boundaries": bin_boundaries,
                        "percentages": percentages
                    })

                    adata.uns['clustermap_results'] = results

            except Exception as e:
                # Best-effort batch plotting: report and move to the next pair.
                import traceback
                traceback.print_exc()
                continue