smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,1585 +1,64 @@
1
1
  from __future__ import annotations
2
2
 
3
- import math
4
- import os
5
- from pathlib import Path
6
- from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
7
-
8
- import numpy as np
9
- import pandas as pd
10
- import scipy.cluster.hierarchy as sch
11
-
12
- from smftools.optional_imports import require
13
-
14
- gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
15
- plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
16
- sns = require("seaborn", extra="plotting", purpose="plot styling")
17
-
18
-
19
- def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
20
- """
21
- Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
22
- Always includes 0 and n_positions-1 when possible.
23
- """
24
- n_ticks = int(max(2, n_ticks))
25
- if n_positions <= n_ticks:
26
- return np.arange(n_positions)
27
-
28
- # linspace gives fixed count
29
- pos = np.linspace(0, n_positions - 1, n_ticks)
30
- return np.unique(np.round(pos).astype(int))
31
-
32
-
33
- def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
34
- """
35
- Select tick labels for the heatmap axis.
36
-
37
- Parameters
38
- ----------
39
- subset : AnnData view
40
- The per-bin subset of the AnnData.
41
- sites : np.ndarray[int]
42
- Indices of the subset.var positions to annotate.
43
- reference : str
44
- Reference name (e.g., '6B6_top').
45
- index_col_suffix : None or str
46
- If None → use subset.var_names
47
- Else → use subset.var[f"{reference}_{index_col_suffix}"]
48
-
49
- Returns
50
- -------
51
- np.ndarray[str]
52
- The labels to use for tick positions.
53
- """
54
- if sites.size == 0:
55
- return np.array([])
56
-
57
- # Default behavior: use var_names
58
- if index_col_suffix is None:
59
- return subset.var_names[sites].astype(str)
60
-
61
- # Otherwise: use a computed column adata.var[f"{reference}_{suffix}"]
62
- colname = f"{reference}_{index_col_suffix}"
63
-
64
- if colname not in subset.var:
65
- raise KeyError(
66
- f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
67
- f"but it is not present in adata.var."
68
- )
69
-
70
- labels = subset.var[colname].astype(str).values
71
- return labels[sites]
72
-
73
-
74
- def normalized_mean(matrix: np.ndarray) -> np.ndarray:
75
- """Compute normalized column means for a matrix.
76
-
77
- Args:
78
- matrix: Input matrix.
79
-
80
- Returns:
81
- 1D array of normalized means.
82
- """
83
- mean = np.nanmean(matrix, axis=0)
84
- denom = (mean.max() - mean.min()) + 1e-9
85
- return (mean - mean.min()) / denom
86
-
87
-
88
- def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
89
- """
90
- Fraction methylated per column.
91
- Methylated = 1
92
- Valid = finite AND not 0
93
- """
94
- matrix = np.asarray(matrix)
95
- valid_mask = np.isfinite(matrix) & (matrix != 0)
96
- methyl_mask = (matrix == 1) & np.isfinite(matrix)
97
-
98
- methylated = methyl_mask.sum(axis=0)
99
- valid = valid_mask.sum(axis=0)
100
-
101
- return np.divide(
102
- methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
103
- )
104
-
105
-
106
- def clean_barplot(ax, mean_values, title):
107
- """Format a barplot with consistent axes and labels.
108
-
109
- Args:
110
- ax: Matplotlib axes.
111
- mean_values: Values to plot.
112
- title: Plot title.
113
- """
114
- x = np.arange(len(mean_values))
115
- ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
116
- ax.set_xlim(0, len(mean_values))
117
- ax.set_ylim(0, 1)
118
- ax.set_yticks([0.0, 0.5, 1.0])
119
- ax.set_ylabel("Mean")
120
- ax.set_title(title, fontsize=12, pad=2)
121
-
122
- # Hide all spines except left
123
- for spine_name, spine in ax.spines.items():
124
- spine.set_visible(spine_name == "left")
125
-
126
- ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
127
-
128
-
129
- # def combined_hmm_raw_clustermap(
130
- # adata,
131
- # sample_col='Sample_Names',
132
- # reference_col='Reference_strand',
133
- # hmm_feature_layer="hmm_combined",
134
- # layer_gpc="nan0_0minus1",
135
- # layer_cpg="nan0_0minus1",
136
- # layer_any_c="nan0_0minus1",
137
- # cmap_hmm="tab10",
138
- # cmap_gpc="coolwarm",
139
- # cmap_cpg="viridis",
140
- # cmap_any_c='coolwarm',
141
- # min_quality=20,
142
- # min_length=200,
143
- # min_mapped_length_to_reference_length_ratio=0.8,
144
- # min_position_valid_fraction=0.5,
145
- # sample_mapping=None,
146
- # save_path=None,
147
- # normalize_hmm=False,
148
- # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
149
- # bins=None,
150
- # deaminase=False,
151
- # min_signal=0
152
- # ):
153
-
154
- # results = []
155
- # if deaminase:
156
- # signal_type = 'deamination'
157
- # else:
158
- # signal_type = 'methylation'
159
-
160
- # for ref in adata.obs[reference_col].cat.categories:
161
- # for sample in adata.obs[sample_col].cat.categories:
162
- # try:
163
- # subset = adata[
164
- # (adata.obs[reference_col] == ref) &
165
- # (adata.obs[sample_col] == sample) &
166
- # (adata.obs['read_quality'] >= min_quality) &
167
- # (adata.obs['read_length'] >= min_length) &
168
- # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
169
- # ]
170
-
171
- # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
172
- # subset = subset[:, mask]
173
-
174
- # if subset.shape[0] == 0:
175
- # print(f" No reads left after filtering for {sample} - {ref}")
176
- # continue
177
-
178
- # if bins:
179
- # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
180
- # bins_temp = bins
181
- # else:
182
- # print(f"Using all reads for clustermap for {sample} - {ref}")
183
- # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
184
-
185
- # # Get column positions (not var_names!) of site masks
186
- # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
187
- # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
188
- # any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
189
- # num_gpc = len(gpc_sites)
190
- # num_cpg = len(cpg_sites)
191
- # num_c = len(any_c_sites)
192
- # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
193
-
194
- # # Use var_names for x-axis tick labels
195
- # gpc_labels = subset.var_names[gpc_sites].astype(int)
196
- # cpg_labels = subset.var_names[cpg_sites].astype(int)
197
- # any_c_labels = subset.var_names[any_c_sites].astype(int)
198
-
199
- # stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
200
- # row_labels, bin_labels = [], []
201
- # bin_boundaries = []
202
-
203
- # total_reads = subset.shape[0]
204
- # percentages = {}
205
- # last_idx = 0
206
-
207
- # for bin_label, bin_filter in bins_temp.items():
208
- # subset_bin = subset[bin_filter].copy()
209
- # num_reads = subset_bin.shape[0]
210
- # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
211
- # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
212
- # percentages[bin_label] = percent_reads
213
-
214
- # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
215
- # # Determine sorting order
216
- # if sort_by.startswith("obs:"):
217
- # colname = sort_by.split("obs:")[1]
218
- # order = np.argsort(subset_bin.obs[colname].values)
219
- # elif sort_by == "gpc":
220
- # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
221
- # order = sch.leaves_list(linkage)
222
- # elif sort_by == "cpg":
223
- # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
224
- # order = sch.leaves_list(linkage)
225
- # elif sort_by == "gpc_cpg":
226
- # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
227
- # order = sch.leaves_list(linkage)
228
- # elif sort_by == "none":
229
- # order = np.arange(num_reads)
230
- # elif sort_by == "any_c":
231
- # linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
232
- # order = sch.leaves_list(linkage)
233
- # else:
234
- # raise ValueError(f"Unsupported sort_by option: {sort_by}")
235
-
236
- # stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
237
- # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
238
- # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
239
- # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
240
-
241
- # row_labels.extend([bin_label] * num_reads)
242
- # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
243
- # last_idx += num_reads
244
- # bin_boundaries.append(last_idx)
245
-
246
- # if stacked_hmm_feature:
247
- # hmm_matrix = np.vstack(stacked_hmm_feature)
248
- # gpc_matrix = np.vstack(stacked_gpc)
249
- # cpg_matrix = np.vstack(stacked_cpg)
250
- # any_c_matrix = np.vstack(stacked_any_c)
251
-
252
- # if hmm_matrix.size > 0:
253
- # def normalized_mean(matrix):
254
- # mean = np.nanmean(matrix, axis=0)
255
- # normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
256
- # return normalized
257
-
258
- # def methylation_fraction(matrix):
259
- # methylated = (matrix == 1).sum(axis=0)
260
- # valid = (matrix != 0).sum(axis=0)
261
- # return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
262
-
263
- # if normalize_hmm:
264
- # mean_hmm = normalized_mean(hmm_matrix)
265
- # else:
266
- # mean_hmm = np.nanmean(hmm_matrix, axis=0)
267
- # mean_gpc = methylation_fraction(gpc_matrix)
268
- # mean_cpg = methylation_fraction(cpg_matrix)
269
- # mean_any_c = methylation_fraction(any_c_matrix)
270
-
271
- # fig = plt.figure(figsize=(18, 12))
272
- # gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
273
- # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
274
-
275
- # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
276
- # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
277
-
278
- # clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
279
- # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
280
- # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
281
- # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
282
-
283
- # hmm_labels = subset.var_names.astype(int)
284
- # hmm_label_spacing = 150
285
- # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
286
- # axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
287
- # axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
288
- # for boundary in bin_boundaries[:-1]:
289
- # axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
290
-
291
- # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
292
- # axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
293
- # axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
294
- # for boundary in bin_boundaries[:-1]:
295
- # axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
296
-
297
- # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
298
- # axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
299
- # for boundary in bin_boundaries[:-1]:
300
- # axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
301
-
302
- # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
303
- # axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
304
- # axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
305
- # for boundary in bin_boundaries[:-1]:
306
- # axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
307
-
308
- # plt.tight_layout()
309
-
310
- # if save_path:
311
- # save_name = f"{ref} — {sample}"
312
- # os.makedirs(save_path, exist_ok=True)
313
- # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
314
- # out_file = os.path.join(save_path, f"{safe_name}.png")
315
- # plt.savefig(out_file, dpi=300)
316
- # print(f"Saved: {out_file}")
317
- # plt.close()
318
- # else:
319
- # plt.show()
320
-
321
- # print(f"Summary for {sample} - {ref}:")
322
- # for bin_label, percent in percentages.items():
323
- # print(f" - {bin_label}: {percent:.1f}%")
324
-
325
- # results.append({
326
- # "sample": sample,
327
- # "ref": ref,
328
- # "hmm_matrix": hmm_matrix,
329
- # "gpc_matrix": gpc_matrix,
330
- # "cpg_matrix": cpg_matrix,
331
- # "row_labels": row_labels,
332
- # "bin_labels": bin_labels,
333
- # "bin_boundaries": bin_boundaries,
334
- # "percentages": percentages
335
- # })
336
-
337
- # #adata.uns['clustermap_results'] = results
338
-
339
- # except Exception as e:
340
- # import traceback
341
- # traceback.print_exc()
342
- # continue
343
-
344
-
345
- def combined_hmm_raw_clustermap(
346
- adata,
347
- sample_col: str = "Sample_Names",
348
- reference_col: str = "Reference_strand",
349
- hmm_feature_layer: str = "hmm_combined",
350
- layer_gpc: str = "nan0_0minus1",
351
- layer_cpg: str = "nan0_0minus1",
352
- layer_c: str = "nan0_0minus1",
353
- layer_a: str = "nan0_0minus1",
354
- cmap_hmm: str = "tab10",
355
- cmap_gpc: str = "coolwarm",
356
- cmap_cpg: str = "viridis",
357
- cmap_c: str = "coolwarm",
358
- cmap_a: str = "coolwarm",
359
- min_quality: int = 20,
360
- min_length: int = 200,
361
- min_mapped_length_to_reference_length_ratio: float = 0.8,
362
- min_position_valid_fraction: float = 0.5,
363
- demux_types: Sequence[str] = ("single", "double", "already"),
364
- sample_mapping: Optional[Mapping[str, str]] = None,
365
- save_path: str | Path | None = None,
366
- normalize_hmm: bool = False,
367
- sort_by: str = "gpc",
368
- bins: Optional[Dict[str, Any]] = None,
369
- deaminase: bool = False,
370
- min_signal: float = 0.0,
371
- # ---- fixed tick label controls (counts, not spacing)
372
- n_xticks_hmm: int = 10,
373
- n_xticks_any_c: int = 8,
374
- n_xticks_gpc: int = 8,
375
- n_xticks_cpg: int = 8,
376
- n_xticks_a: int = 8,
377
- index_col_suffix: str | None = None,
378
- ):
379
- """
380
- Makes a multi-panel clustermap per (sample, reference):
381
- HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.
382
-
383
- Panels are added only if the corresponding site mask exists AND has >0 sites.
384
-
385
- sort_by options:
386
- 'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
387
- """
388
-
389
- def pick_xticks(labels: np.ndarray, n_ticks: int):
390
- """Pick tick indices/labels from an array."""
391
- if labels.size == 0:
392
- return [], []
393
- idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
394
- idx = np.unique(idx)
395
- return idx.tolist(), labels[idx].tolist()
396
-
397
- # Helper: build a True mask if filter is inactive or column missing
398
- def _mask_or_true(series_name: str, predicate):
399
- """Return a mask from predicate or an all-True mask."""
400
- if series_name not in adata.obs:
401
- return pd.Series(True, index=adata.obs.index)
402
- s = adata.obs[series_name]
403
- try:
404
- return predicate(s)
405
- except Exception:
406
- # Fallback: all True if bad dtype / predicate failure
407
- return pd.Series(True, index=adata.obs.index)
408
-
409
- results = []
410
- signal_type = "deamination" if deaminase else "methylation"
411
-
412
- for ref in adata.obs[reference_col].cat.categories:
413
- for sample in adata.obs[sample_col].cat.categories:
414
- # Optionally remap sample label for display
415
- display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
416
- # Row-level masks (obs)
417
- qmask = _mask_or_true(
418
- "read_quality",
419
- (lambda s: s >= float(min_quality))
420
- if (min_quality is not None)
421
- else (lambda s: pd.Series(True, index=s.index)),
422
- )
423
- lm_mask = _mask_or_true(
424
- "mapped_length",
425
- (lambda s: s >= float(min_length))
426
- if (min_length is not None)
427
- else (lambda s: pd.Series(True, index=s.index)),
428
- )
429
- lrr_mask = _mask_or_true(
430
- "mapped_length_to_reference_length_ratio",
431
- (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
432
- if (min_mapped_length_to_reference_length_ratio is not None)
433
- else (lambda s: pd.Series(True, index=s.index)),
434
- )
435
-
436
- demux_mask = _mask_or_true(
437
- "demux_type",
438
- (lambda s: s.astype("string").isin(list(demux_types)))
439
- if (demux_types is not None)
440
- else (lambda s: pd.Series(True, index=s.index)),
441
- )
442
-
443
- ref_mask = adata.obs[reference_col] == ref
444
- sample_mask = adata.obs[sample_col] == sample
445
-
446
- row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
447
-
448
- if not bool(row_mask.any()):
449
- print(
450
- f"No reads for {display_sample} - {ref} after read quality and length filtering"
451
- )
452
- continue
453
-
454
- try:
455
- # ---- subset reads ----
456
- subset = adata[row_mask, :].copy()
457
-
458
- # Column-level mask (var)
459
- if min_position_valid_fraction is not None:
460
- valid_key = f"{ref}_valid_fraction"
461
- if valid_key in subset.var:
462
- v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
463
- col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
464
- if col_mask.any():
465
- subset = subset[:, col_mask].copy()
466
- else:
467
- print(
468
- f"No positions left after valid_fraction filter for {display_sample} - {ref}"
469
- )
470
- continue
471
-
472
- if subset.shape[0] == 0:
473
- print(f"No reads left after filtering for {display_sample} - {ref}")
474
- continue
475
-
476
- # ---- bins ----
477
- if bins is None:
478
- bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
479
- else:
480
- bins_temp = bins
481
-
482
- # ---- site masks (robust) ----
483
- def _sites(*keys):
484
- """Return indices for the first matching site key."""
485
- for k in keys:
486
- if k in subset.var:
487
- return np.where(subset.var[k].values)[0]
488
- return np.array([], dtype=int)
489
-
490
- gpc_sites = _sites(f"{ref}_GpC_site")
491
- cpg_sites = _sites(f"{ref}_CpG_site")
492
- any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
493
- any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
494
-
495
- # ---- labels via _select_labels ----
496
- # HMM uses *all* columns
497
- hmm_sites = np.arange(subset.n_vars, dtype=int)
498
- hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
499
- gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
500
- cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
501
- any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
502
- any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
503
-
504
- # storage
505
- stacked_hmm = []
506
- stacked_any_c = []
507
- stacked_gpc = []
508
- stacked_cpg = []
509
- stacked_any_a = []
510
-
511
- row_labels, bin_labels, bin_boundaries = [], [], []
512
- total_reads = subset.n_obs
513
- percentages = {}
514
- last_idx = 0
515
-
516
- # ---------------- process bins ----------------
517
- for bin_label, bin_filter in bins_temp.items():
518
- sb = subset[bin_filter].copy()
519
- n = sb.n_obs
520
- if n == 0:
521
- continue
522
-
523
- pct = (n / total_reads) * 100 if total_reads else 0
524
- percentages[bin_label] = pct
525
-
526
- # ---- sorting ----
527
- if sort_by.startswith("obs:"):
528
- colname = sort_by.split("obs:")[1]
529
- order = np.argsort(sb.obs[colname].values)
530
-
531
- elif sort_by == "gpc" and gpc_sites.size:
532
- linkage = sch.linkage(sb[:, gpc_sites].layers[layer_gpc], method="ward")
533
- order = sch.leaves_list(linkage)
534
-
535
- elif sort_by == "cpg" and cpg_sites.size:
536
- linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
537
- order = sch.leaves_list(linkage)
538
-
539
- elif sort_by == "c" and any_c_sites.size:
540
- linkage = sch.linkage(sb[:, any_c_sites].layers[layer_c], method="ward")
541
- order = sch.leaves_list(linkage)
542
-
543
- elif sort_by == "a" and any_a_sites.size:
544
- linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
545
- order = sch.leaves_list(linkage)
546
-
547
- elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
548
- linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
549
- order = sch.leaves_list(linkage)
550
-
551
- elif sort_by == "hmm" and hmm_sites.size:
552
- linkage = sch.linkage(
553
- sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
554
- )
555
- order = sch.leaves_list(linkage)
556
-
557
- else:
558
- order = np.arange(n)
559
-
560
- sb = sb[order]
561
-
562
- # ---- collect matrices ----
563
- stacked_hmm.append(sb.layers[hmm_feature_layer])
564
- if any_c_sites.size:
565
- stacked_any_c.append(sb[:, any_c_sites].layers[layer_c])
566
- if gpc_sites.size:
567
- stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
568
- if cpg_sites.size:
569
- stacked_cpg.append(sb[:, cpg_sites].layers[layer_cpg])
570
- if any_a_sites.size:
571
- stacked_any_a.append(sb[:, any_a_sites].layers[layer_a])
572
-
573
- row_labels.extend([bin_label] * n)
574
- bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
575
- last_idx += n
576
- bin_boundaries.append(last_idx)
577
-
578
- # ---------------- stack ----------------
579
- hmm_matrix = np.vstack(stacked_hmm)
580
- mean_hmm = (
581
- normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
582
- )
583
-
584
- panels = [
585
- (
586
- f"HMM - {hmm_feature_layer}",
587
- hmm_matrix,
588
- hmm_labels,
589
- cmap_hmm,
590
- mean_hmm,
591
- n_xticks_hmm,
592
- ),
593
- ]
594
-
595
- if stacked_any_c:
596
- m = np.vstack(stacked_any_c)
597
- panels.append(
598
- ("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
599
- )
600
-
601
- if stacked_gpc:
602
- m = np.vstack(stacked_gpc)
603
- panels.append(
604
- ("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
605
- )
606
-
607
- if stacked_cpg:
608
- m = np.vstack(stacked_cpg)
609
- panels.append(
610
- ("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
611
- )
612
-
613
- if stacked_any_a:
614
- m = np.vstack(stacked_any_a)
615
- panels.append(
616
- ("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
617
- )
618
-
619
- # ---------------- plotting ----------------
620
- n_panels = len(panels)
621
- fig = plt.figure(figsize=(4.5 * n_panels, 10))
622
- gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
623
- fig.suptitle(
624
- f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
625
- )
626
-
627
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
628
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
629
-
630
- for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
631
- # ---- your clean barplot ----
632
- clean_barplot(axes_bar[i], mean_vec, name)
633
-
634
- # ---- heatmap ----
635
- sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
636
-
637
- # ---- xticks ----
638
- xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
639
- axes_heat[i].set_xticks(xtick_pos)
640
- axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
641
-
642
- for boundary in bin_boundaries[:-1]:
643
- axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
644
-
645
- plt.tight_layout()
646
-
647
- if save_path:
648
- save_path = Path(save_path)
649
- save_path.mkdir(parents=True, exist_ok=True)
650
- safe_name = f"{ref}__{sample}".replace("/", "_")
651
- out_file = save_path / f"{safe_name}.png"
652
- plt.savefig(out_file, dpi=300)
653
- plt.close(fig)
654
- else:
655
- plt.show()
656
-
657
- except Exception:
658
- import traceback
659
-
660
- traceback.print_exc()
661
- continue
662
-
663
-
664
- # def combined_raw_clustermap(
665
- # adata,
666
- # sample_col='Sample_Names',
667
- # reference_col='Reference_strand',
668
- # mod_target_bases=['GpC', 'CpG'],
669
- # layer_any_c="nan0_0minus1",
670
- # layer_gpc="nan0_0minus1",
671
- # layer_cpg="nan0_0minus1",
672
- # layer_a="nan0_0minus1",
673
- # cmap_any_c="coolwarm",
674
- # cmap_gpc="coolwarm",
675
- # cmap_cpg="viridis",
676
- # cmap_a="coolwarm",
677
- # min_quality=20,
678
- # min_length=200,
679
- # min_mapped_length_to_reference_length_ratio=0.8,
680
- # min_position_valid_fraction=0.5,
681
- # sample_mapping=None,
682
- # save_path=None,
683
- # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
684
- # bins=None,
685
- # deaminase=False,
686
- # min_signal=0
687
- # ):
688
-
689
- # results = []
690
-
691
- # for ref in adata.obs[reference_col].cat.categories:
692
- # for sample in adata.obs[sample_col].cat.categories:
693
- # try:
694
- # subset = adata[
695
- # (adata.obs[reference_col] == ref) &
696
- # (adata.obs[sample_col] == sample) &
697
- # (adata.obs['read_quality'] >= min_quality) &
698
- # (adata.obs['mapped_length'] >= min_length) &
699
- # (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
700
- # ]
701
-
702
- # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
703
- # subset = subset[:, mask]
704
-
705
- # if subset.shape[0] == 0:
706
- # print(f" No reads left after filtering for {sample} - {ref}")
707
- # continue
708
-
709
- # if bins:
710
- # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
711
- # bins_temp = bins
712
- # else:
713
- # print(f"Using all reads for clustermap for {sample} - {ref}")
714
- # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
715
-
716
- # num_any_c = 0
717
- # num_gpc = 0
718
- # num_cpg = 0
719
- # num_any_a = 0
720
-
721
- # # Get column positions (not var_names!) of site masks
722
- # if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
723
- # any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
724
- # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
725
- # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
726
- # num_any_c = len(any_c_sites)
727
- # num_gpc = len(gpc_sites)
728
- # num_cpg = len(cpg_sites)
729
- # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
730
-
731
- # # Use var_names for x-axis tick labels
732
- # gpc_labels = subset.var_names[gpc_sites].astype(int)
733
- # cpg_labels = subset.var_names[cpg_sites].astype(int)
734
- # any_c_labels = subset.var_names[any_c_sites].astype(int)
735
- # stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
736
-
737
- # if "A" in mod_target_bases:
738
- # any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
739
- # num_any_a = len(any_a_sites)
740
- # print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
741
- # any_a_labels = subset.var_names[any_a_sites].astype(int)
742
- # stacked_any_a = []
743
-
744
- # row_labels, bin_labels = [], []
745
- # bin_boundaries = []
746
-
747
- # total_reads = subset.shape[0]
748
- # percentages = {}
749
- # last_idx = 0
750
-
751
- # for bin_label, bin_filter in bins_temp.items():
752
- # subset_bin = subset[bin_filter].copy()
753
- # num_reads = subset_bin.shape[0]
754
- # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
755
- # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
756
- # percentages[bin_label] = percent_reads
757
-
758
- # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
759
- # # Determine sorting order
760
- # if sort_by.startswith("obs:"):
761
- # colname = sort_by.split("obs:")[1]
762
- # order = np.argsort(subset_bin.obs[colname].values)
763
- # elif sort_by == "gpc":
764
- # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
765
- # order = sch.leaves_list(linkage)
766
- # elif sort_by == "cpg":
767
- # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
768
- # order = sch.leaves_list(linkage)
769
- # elif sort_by == "any_c":
770
- # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
771
- # order = sch.leaves_list(linkage)
772
- # elif sort_by == "gpc_cpg":
773
- # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
774
- # order = sch.leaves_list(linkage)
775
- # elif sort_by == "none":
776
- # order = np.arange(num_reads)
777
- # elif sort_by == "any_a":
778
- # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
779
- # order = sch.leaves_list(linkage)
780
- # else:
781
- # raise ValueError(f"Unsupported sort_by option: {sort_by}")
782
-
783
- # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
784
- # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
785
- # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
786
-
787
- # if num_reads > 0 and num_any_a > 0:
788
- # # Determine sorting order
789
- # if sort_by.startswith("obs:"):
790
- # colname = sort_by.split("obs:")[1]
791
- # order = np.argsort(subset_bin.obs[colname].values)
792
- # elif sort_by == "gpc":
793
- # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
794
- # order = sch.leaves_list(linkage)
795
- # elif sort_by == "cpg":
796
- # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
797
- # order = sch.leaves_list(linkage)
798
- # elif sort_by == "any_c":
799
- # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
800
- # order = sch.leaves_list(linkage)
801
- # elif sort_by == "gpc_cpg":
802
- # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
803
- # order = sch.leaves_list(linkage)
804
- # elif sort_by == "none":
805
- # order = np.arange(num_reads)
806
- # elif sort_by == "any_a":
807
- # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
808
- # order = sch.leaves_list(linkage)
809
- # else:
810
- # raise ValueError(f"Unsupported sort_by option: {sort_by}")
811
-
812
- # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
813
-
814
-
815
- # row_labels.extend([bin_label] * num_reads)
816
- # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
817
- # last_idx += num_reads
818
- # bin_boundaries.append(last_idx)
819
-
820
- # gs_dim = 0
821
-
822
- # if stacked_any_c:
823
- # any_c_matrix = np.vstack(stacked_any_c)
824
- # gpc_matrix = np.vstack(stacked_gpc)
825
- # cpg_matrix = np.vstack(stacked_cpg)
826
- # if any_c_matrix.size > 0:
827
- # mean_gpc = methylation_fraction(gpc_matrix)
828
- # mean_cpg = methylation_fraction(cpg_matrix)
829
- # mean_any_c = methylation_fraction(any_c_matrix)
830
- # gs_dim += 3
831
-
832
- # if stacked_any_a:
833
- # any_a_matrix = np.vstack(stacked_any_a)
834
- # if any_a_matrix.size > 0:
835
- # mean_any_a = methylation_fraction(any_a_matrix)
836
- # gs_dim += 1
837
-
838
-
839
- # fig = plt.figure(figsize=(18, 12))
840
- # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
841
- # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
842
- # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
843
- # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
844
-
845
- # current_ax = 0
846
-
847
- # if stacked_any_c:
848
- # if any_c_matrix.size > 0:
849
- # clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
850
- # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
851
- # axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
852
- # axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
853
- # for boundary in bin_boundaries[:-1]:
854
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
855
- # current_ax +=1
856
-
857
- # clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
858
- # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
859
- # axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
860
- # axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
861
- # for boundary in bin_boundaries[:-1]:
862
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
863
- # current_ax +=1
864
-
865
- # clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
866
- # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
867
- # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
868
- # for boundary in bin_boundaries[:-1]:
869
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
870
- # current_ax +=1
871
-
872
- # results.append({
873
- # "sample": sample,
874
- # "ref": ref,
875
- # "any_c_matrix": any_c_matrix,
876
- # "gpc_matrix": gpc_matrix,
877
- # "cpg_matrix": cpg_matrix,
878
- # "row_labels": row_labels,
879
- # "bin_labels": bin_labels,
880
- # "bin_boundaries": bin_boundaries,
881
- # "percentages": percentages
882
- # })
883
-
884
- # if stacked_any_a:
885
- # if any_a_matrix.size > 0:
886
- # clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
887
- # sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
888
- # axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
889
- # axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
890
- # for boundary in bin_boundaries[:-1]:
891
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
892
- # current_ax +=1
893
-
894
- # results.append({
895
- # "sample": sample,
896
- # "ref": ref,
897
- # "any_a_matrix": any_a_matrix,
898
- # "row_labels": row_labels,
899
- # "bin_labels": bin_labels,
900
- # "bin_boundaries": bin_boundaries,
901
- # "percentages": percentages
902
- # })
903
-
904
- # plt.tight_layout()
905
-
906
- # if save_path:
907
- # save_name = f"{ref} — {sample}"
908
- # os.makedirs(save_path, exist_ok=True)
909
- # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
910
- # out_file = os.path.join(save_path, f"{safe_name}.png")
911
- # plt.savefig(out_file, dpi=300)
912
- # print(f"Saved: {out_file}")
913
- # plt.close()
914
- # else:
915
- # plt.show()
916
-
917
- # print(f"Summary for {sample} - {ref}:")
918
- # for bin_label, percent in percentages.items():
919
- # print(f" - {bin_label}: {percent:.1f}%")
920
-
921
- # adata.uns['clustermap_results'] = results
922
-
923
- # except Exception as e:
924
- # import traceback
925
- # traceback.print_exc()
926
- # continue
927
-
928
-
929
- def combined_raw_clustermap(
930
- adata,
931
- sample_col: str = "Sample_Names",
932
- reference_col: str = "Reference_strand",
933
- mod_target_bases: Sequence[str] = ("GpC", "CpG"),
934
- layer_c: str = "nan0_0minus1",
935
- layer_gpc: str = "nan0_0minus1",
936
- layer_cpg: str = "nan0_0minus1",
937
- layer_a: str = "nan0_0minus1",
938
- cmap_c: str = "coolwarm",
939
- cmap_gpc: str = "coolwarm",
940
- cmap_cpg: str = "viridis",
941
- cmap_a: str = "coolwarm",
942
- min_quality: float | None = 20,
943
- min_length: int | None = 200,
944
- min_mapped_length_to_reference_length_ratio: float | None = 0,
945
- min_position_valid_fraction: float | None = 0,
946
- demux_types: Sequence[str] = ("single", "double", "already"),
947
- sample_mapping: Optional[Mapping[str, str]] = None,
948
- save_path: str | Path | None = None,
949
- sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
950
- bins: Optional[Dict[str, Any]] = None,
951
- deaminase: bool = False,
952
- min_signal: float = 0,
953
- n_xticks_any_c: int = 10,
954
- n_xticks_gpc: int = 10,
955
- n_xticks_cpg: int = 10,
956
- n_xticks_any_a: int = 10,
957
- xtick_rotation: int = 90,
958
- xtick_fontsize: int = 9,
959
- index_col_suffix: str | None = None,
960
- ):
961
- """
962
- Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
963
-
964
- Key fixes vs old version:
965
- - order computed ONCE per bin, applied to all matrices
966
- - no hard-coded axes indices
967
- - NaNs excluded from methylation denominators
968
- - var_names not forced to int
969
- - fixed count of x tick labels per block (controllable)
970
- - adata.uns updated once at end
971
-
972
- Returns
973
- -------
974
- results : list[dict]
975
- One entry per (sample, ref) plot with matrices + bin metadata.
976
- """
977
-
978
- # Helper: build a True mask if filter is inactive or column missing
979
- def _mask_or_true(series_name: str, predicate):
980
- """Return a mask from predicate or an all-True mask."""
981
- if series_name not in adata.obs:
982
- return pd.Series(True, index=adata.obs.index)
983
- s = adata.obs[series_name]
984
- try:
985
- return predicate(s)
986
- except Exception:
987
- # Fallback: all True if bad dtype / predicate failure
988
- return pd.Series(True, index=adata.obs.index)
989
-
990
- results: List[Dict[str, Any]] = []
991
- save_path = Path(save_path) if save_path is not None else None
992
- if save_path is not None:
993
- save_path.mkdir(parents=True, exist_ok=True)
994
-
995
- # Ensure categorical
996
- for col in (sample_col, reference_col):
997
- if col not in adata.obs:
998
- raise KeyError(f"{col} not in adata.obs")
999
- if not pd.api.types.is_categorical_dtype(adata.obs[col]):
1000
- adata.obs[col] = adata.obs[col].astype("category")
1001
-
1002
- base_set = set(mod_target_bases)
1003
- include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
1004
- include_any_a = "A" in base_set
1005
-
1006
- for ref in adata.obs[reference_col].cat.categories:
1007
- for sample in adata.obs[sample_col].cat.categories:
1008
- # Optionally remap sample label for display
1009
- display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
1010
-
1011
- # Row-level masks (obs)
1012
- qmask = _mask_or_true(
1013
- "read_quality",
1014
- (lambda s: s >= float(min_quality))
1015
- if (min_quality is not None)
1016
- else (lambda s: pd.Series(True, index=s.index)),
1017
- )
1018
- lm_mask = _mask_or_true(
1019
- "mapped_length",
1020
- (lambda s: s >= float(min_length))
1021
- if (min_length is not None)
1022
- else (lambda s: pd.Series(True, index=s.index)),
1023
- )
1024
- lrr_mask = _mask_or_true(
1025
- "mapped_length_to_reference_length_ratio",
1026
- (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
1027
- if (min_mapped_length_to_reference_length_ratio is not None)
1028
- else (lambda s: pd.Series(True, index=s.index)),
1029
- )
1030
-
1031
- demux_mask = _mask_or_true(
1032
- "demux_type",
1033
- (lambda s: s.astype("string").isin(list(demux_types)))
1034
- if (demux_types is not None)
1035
- else (lambda s: pd.Series(True, index=s.index)),
1036
- )
1037
-
1038
- ref_mask = adata.obs[reference_col] == ref
1039
- sample_mask = adata.obs[sample_col] == sample
1040
-
1041
- row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
1042
-
1043
- if not bool(row_mask.any()):
1044
- print(
1045
- f"No reads for {display_sample} - {ref} after read quality and length filtering"
1046
- )
1047
- continue
1048
-
1049
- try:
1050
- subset = adata[row_mask, :].copy()
1051
-
1052
- # Column-level mask (var)
1053
- if min_position_valid_fraction is not None:
1054
- valid_key = f"{ref}_valid_fraction"
1055
- if valid_key in subset.var:
1056
- v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
1057
- col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
1058
- if col_mask.any():
1059
- subset = subset[:, col_mask].copy()
1060
- else:
1061
- print(
1062
- f"No positions left after valid_fraction filter for {display_sample} - {ref}"
1063
- )
1064
- continue
1065
-
1066
- if subset.shape[0] == 0:
1067
- print(f"No reads left after filtering for {display_sample} - {ref}")
1068
- continue
1069
-
1070
- # bins mode
1071
- if bins is None:
1072
- bins_temp = {"All": (subset.obs[reference_col] == ref)}
1073
- else:
1074
- bins_temp = bins
1075
-
1076
- # find sites (positions)
1077
- any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
1078
- any_a_sites = np.array([], dtype=int)
1079
-
1080
- num_any_c = num_gpc = num_cpg = num_any_a = 0
1081
-
1082
- if include_any_c:
1083
- any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
1084
- gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
1085
- cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
1086
-
1087
- num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
1088
-
1089
- any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
1090
- gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
1091
- cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
1092
-
1093
- if include_any_a:
1094
- any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
1095
- num_any_a = len(any_a_sites)
1096
- any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
1097
-
1098
- stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
1099
- row_labels, bin_labels, bin_boundaries = [], [], []
1100
- percentages = {}
1101
- last_idx = 0
1102
- total_reads = subset.shape[0]
1103
-
1104
- # ----------------------------
1105
- # per-bin stacking
1106
- # ----------------------------
1107
- for bin_label, bin_filter in bins_temp.items():
1108
- subset_bin = subset[bin_filter].copy()
1109
- num_reads = subset_bin.shape[0]
1110
- if num_reads == 0:
1111
- percentages[bin_label] = 0.0
1112
- continue
1113
-
1114
- percent_reads = (num_reads / total_reads) * 100
1115
- percentages[bin_label] = percent_reads
1116
-
1117
- # compute order ONCE
1118
- if sort_by.startswith("obs:"):
1119
- colname = sort_by.split("obs:")[1]
1120
- order = np.argsort(subset_bin.obs[colname].values)
1121
-
1122
- elif sort_by == "gpc" and num_gpc > 0:
1123
- linkage = sch.linkage(
1124
- subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
1125
- )
1126
- order = sch.leaves_list(linkage)
1127
-
1128
- elif sort_by == "cpg" and num_cpg > 0:
1129
- linkage = sch.linkage(
1130
- subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
1131
- )
1132
- order = sch.leaves_list(linkage)
1133
-
1134
- elif sort_by == "c" and num_any_c > 0:
1135
- linkage = sch.linkage(
1136
- subset_bin[:, any_c_sites].layers[layer_c], method="ward"
1137
- )
1138
- order = sch.leaves_list(linkage)
1139
-
1140
- elif sort_by == "gpc_cpg":
1141
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
1142
- order = sch.leaves_list(linkage)
1143
-
1144
- elif sort_by == "a" and num_any_a > 0:
1145
- linkage = sch.linkage(
1146
- subset_bin[:, any_a_sites].layers[layer_a], method="ward"
1147
- )
1148
- order = sch.leaves_list(linkage)
1149
-
1150
- elif sort_by == "none":
1151
- order = np.arange(num_reads)
1152
-
1153
- else:
1154
- order = np.arange(num_reads)
1155
-
1156
- subset_bin = subset_bin[order]
1157
-
1158
- # stack consistently
1159
- if include_any_c and num_any_c > 0:
1160
- stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_c])
1161
- if include_any_c and num_gpc > 0:
1162
- stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
1163
- if include_any_c and num_cpg > 0:
1164
- stacked_cpg.append(subset_bin[:, cpg_sites].layers[layer_cpg])
1165
- if include_any_a and num_any_a > 0:
1166
- stacked_any_a.append(subset_bin[:, any_a_sites].layers[layer_a])
1167
-
1168
- row_labels.extend([bin_label] * num_reads)
1169
- bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
1170
- last_idx += num_reads
1171
- bin_boundaries.append(last_idx)
1172
-
1173
- # ----------------------------
1174
- # build matrices + means
1175
- # ----------------------------
1176
- blocks = [] # list of dicts describing what to plot in order
1177
-
1178
- if include_any_c and stacked_any_c:
1179
- any_c_matrix = np.vstack(stacked_any_c)
1180
- gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1181
- cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1182
-
1183
- mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
1184
- mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1185
- mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1186
-
1187
- if any_c_matrix.size:
1188
- blocks.append(
1189
- dict(
1190
- name="c",
1191
- matrix=any_c_matrix,
1192
- mean=mean_any_c,
1193
- labels=any_c_labels,
1194
- cmap=cmap_c,
1195
- n_xticks=n_xticks_any_c,
1196
- title="any C site Modification Signal",
1197
- )
1198
- )
1199
- if gpc_matrix.size:
1200
- blocks.append(
1201
- dict(
1202
- name="gpc",
1203
- matrix=gpc_matrix,
1204
- mean=mean_gpc,
1205
- labels=gpc_labels,
1206
- cmap=cmap_gpc,
1207
- n_xticks=n_xticks_gpc,
1208
- title="GpC Modification Signal",
1209
- )
1210
- )
1211
- if cpg_matrix.size:
1212
- blocks.append(
1213
- dict(
1214
- name="cpg",
1215
- matrix=cpg_matrix,
1216
- mean=mean_cpg,
1217
- labels=cpg_labels,
1218
- cmap=cmap_cpg,
1219
- n_xticks=n_xticks_cpg,
1220
- title="CpG Modification Signal",
1221
- )
1222
- )
1223
-
1224
- if include_any_a and stacked_any_a:
1225
- any_a_matrix = np.vstack(stacked_any_a)
1226
- mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1227
- if any_a_matrix.size:
1228
- blocks.append(
1229
- dict(
1230
- name="a",
1231
- matrix=any_a_matrix,
1232
- mean=mean_any_a,
1233
- labels=any_a_labels,
1234
- cmap=cmap_a,
1235
- n_xticks=n_xticks_any_a,
1236
- title="any A site Modification Signal",
1237
- )
1238
- )
1239
-
1240
- if not blocks:
1241
- print(f"No matrices to plot for {display_sample} - {ref}")
1242
- continue
1243
-
1244
- gs_dim = len(blocks)
1245
- fig = plt.figure(figsize=(5.5 * gs_dim, 11))
1246
- gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
1247
- fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1248
-
1249
- axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1250
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1251
-
1252
- # ----------------------------
1253
- # plot blocks
1254
- # ----------------------------
1255
- for i, blk in enumerate(blocks):
1256
- mat = blk["matrix"]
1257
- mean = blk["mean"]
1258
- labels = np.asarray(blk["labels"], dtype=str)
1259
- n_xticks = blk["n_xticks"]
1260
-
1261
- # barplot
1262
- clean_barplot(axes_bar[i], mean, blk["title"])
1263
-
1264
- # heatmap
1265
- sns.heatmap(
1266
- mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
1267
- )
1268
-
1269
- # fixed tick labels
1270
- tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1271
- axes_heat[i].set_xticks(tick_pos)
1272
- axes_heat[i].set_xticklabels(
1273
- labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
1274
- )
1275
-
1276
- # bin separators
1277
- for boundary in bin_boundaries[:-1]:
1278
- axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
1279
-
1280
- axes_heat[i].set_xlabel("Position", fontsize=9)
1281
-
1282
- plt.tight_layout()
1283
-
1284
- # save or show
1285
- if save_path is not None:
1286
- safe_name = (
1287
- f"{ref}__{display_sample}".replace("=", "")
1288
- .replace("__", "_")
1289
- .replace(",", "_")
1290
- .replace(" ", "_")
1291
- )
1292
- out_file = save_path / f"{safe_name}.png"
1293
- fig.savefig(out_file, dpi=300)
1294
- plt.close(fig)
1295
- print(f"Saved: {out_file}")
1296
- else:
1297
- plt.show()
1298
-
1299
- # record results
1300
- rec = {
1301
- "sample": str(sample),
1302
- "ref": str(ref),
1303
- "row_labels": row_labels,
1304
- "bin_labels": bin_labels,
1305
- "bin_boundaries": bin_boundaries,
1306
- "percentages": percentages,
1307
- }
1308
- for blk in blocks:
1309
- rec[f"{blk['name']}_matrix"] = blk["matrix"]
1310
- rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
1311
- results.append(rec)
1312
-
1313
- print(f"Summary for {display_sample} - {ref}:")
1314
- for bin_label, percent in percentages.items():
1315
- print(f" - {bin_label}: {percent:.1f}%")
1316
-
1317
- except Exception:
1318
- import traceback
1319
-
1320
- traceback.print_exc()
1321
- continue
1322
-
1323
- return results
1324
-
1325
-
1326
- def plot_hmm_layers_rolling_by_sample_ref(
1327
- adata,
1328
- layers: Optional[Sequence[str]] = None,
1329
- sample_col: str = "Barcode",
1330
- ref_col: str = "Reference_strand",
1331
- samples: Optional[Sequence[str]] = None,
1332
- references: Optional[Sequence[str]] = None,
1333
- window: int = 51,
1334
- min_periods: int = 1,
1335
- center: bool = True,
1336
- rows_per_page: int = 6,
1337
- figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
1338
- dpi: int = 160,
1339
- output_dir: Optional[str] = None,
1340
- save: bool = True,
1341
- show_raw: bool = False,
1342
- cmap: str = "tab20",
1343
- use_var_coords: bool = True,
1344
- ):
1345
- """
1346
- For each sample (row) and reference (col) plot the rolling average of the
1347
- positional mean (mean across reads) for each layer listed.
1348
-
1349
- Parameters
1350
- ----------
1351
- adata : AnnData
1352
- Input annotated data (expects obs columns sample_col and ref_col).
1353
- layers : list[str] | None
1354
- Which adata.layers to plot. If None, attempts to autodetect layers whose
1355
- matrices look like "HMM" outputs (else will error). If None and layers
1356
- cannot be found, user must pass a list.
1357
- sample_col, ref_col : str
1358
- obs columns used to group rows.
1359
- samples, references : optional lists
1360
- explicit ordering of samples / references. If None, categories in adata.obs are used.
1361
- window : int
1362
- rolling window size (odd recommended). If window <= 1, no smoothing applied.
1363
- min_periods : int
1364
- min periods param for pd.Series.rolling.
1365
- center : bool
1366
- center the rolling window.
1367
- rows_per_page : int
1368
- paginate rows per page into multiple figures if needed.
1369
- figsize_per_cell : (w,h)
1370
- per-subplot size in inches.
1371
- dpi : int
1372
- figure dpi when saving.
1373
- output_dir : str | None
1374
- directory to save pages; created if necessary. If None and save=True, uses cwd.
1375
- save : bool
1376
- whether to save PNG files.
1377
- show_raw : bool
1378
- draw unsmoothed mean as faint line under smoothed curve.
1379
- cmap : str
1380
- matplotlib colormap for layer lines.
1381
- use_var_coords : bool
1382
- if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
1383
-
1384
- Returns
1385
- -------
1386
- saved_files : list[str]
1387
- list of saved filenames (may be empty if save=False).
1388
- """
1389
-
1390
- # --- basic checks / defaults ---
1391
- if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
1392
- raise ValueError(
1393
- f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
1394
- )
1395
-
1396
- # canonicalize samples / refs
1397
- if samples is None:
1398
- sseries = adata.obs[sample_col]
1399
- if not pd.api.types.is_categorical_dtype(sseries):
1400
- sseries = sseries.astype("category")
1401
- samples_all = list(sseries.cat.categories)
1402
- else:
1403
- samples_all = list(samples)
1404
-
1405
- if references is None:
1406
- rseries = adata.obs[ref_col]
1407
- if not pd.api.types.is_categorical_dtype(rseries):
1408
- rseries = rseries.astype("category")
1409
- refs_all = list(rseries.cat.categories)
1410
- else:
1411
- refs_all = list(references)
1412
-
1413
- # choose layers: if not provided, try a sensible default: all layers
1414
- if layers is None:
1415
- layers = list(adata.layers.keys())
1416
- if len(layers) == 0:
1417
- raise ValueError(
1418
- "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
1419
- )
1420
- layers = list(layers)
1421
-
1422
- # x coordinates (positions)
1423
- try:
1424
- if use_var_coords:
1425
- x_coords = np.array([int(v) for v in adata.var_names])
1426
- else:
1427
- raise Exception("user disabled var coords")
1428
- except Exception:
1429
- # fallback to 0..n_vars-1
1430
- x_coords = np.arange(adata.shape[1], dtype=int)
1431
-
1432
- # make output dir
1433
- if save:
1434
- outdir = output_dir or os.getcwd()
1435
- os.makedirs(outdir, exist_ok=True)
1436
- else:
1437
- outdir = None
1438
-
1439
- n_samples = len(samples_all)
1440
- n_refs = len(refs_all)
1441
- total_pages = math.ceil(n_samples / rows_per_page)
1442
- saved_files = []
1443
-
1444
- # color cycle for layers
1445
- cmap_obj = plt.get_cmap(cmap)
1446
- n_layers = max(1, len(layers))
1447
- colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
1448
-
1449
- for page in range(total_pages):
1450
- start = page * rows_per_page
1451
- end = min(start + rows_per_page, n_samples)
1452
- chunk = samples_all[start:end]
1453
- nrows = len(chunk)
1454
- ncols = n_refs
1455
-
1456
- fig_w = figsize_per_cell[0] * ncols
1457
- fig_h = figsize_per_cell[1] * nrows
1458
- fig, axes = plt.subplots(
1459
- nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
1460
- )
1461
-
1462
- for r_idx, sample_name in enumerate(chunk):
1463
- for c_idx, ref_name in enumerate(refs_all):
1464
- ax = axes[r_idx][c_idx]
1465
-
1466
- # subset adata
1467
- mask = (adata.obs[sample_col].values == sample_name) & (
1468
- adata.obs[ref_col].values == ref_name
1469
- )
1470
- sub = adata[mask]
1471
- if sub.n_obs == 0:
1472
- ax.text(
1473
- 0.5,
1474
- 0.5,
1475
- "No reads",
1476
- ha="center",
1477
- va="center",
1478
- transform=ax.transAxes,
1479
- color="gray",
1480
- )
1481
- ax.set_xticks([])
1482
- ax.set_yticks([])
1483
- if r_idx == 0:
1484
- ax.set_title(str(ref_name), fontsize=9)
1485
- if c_idx == 0:
1486
- total_reads = int((adata.obs[sample_col] == sample_name).sum())
1487
- ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
1488
- continue
1489
-
1490
- # for each layer, compute positional mean across reads (ignore NaNs)
1491
- plotted_any = False
1492
- for li, layer in enumerate(layers):
1493
- if layer in sub.layers:
1494
- mat = sub.layers[layer]
1495
- else:
1496
- # fallback: try .X only for the first layer if layer not present
1497
- if layer == layers[0] and getattr(sub, "X", None) is not None:
1498
- mat = sub.X
1499
- else:
1500
- # layer not present for this subset
1501
- continue
1502
-
1503
- # convert matrix to numpy 2D
1504
- if hasattr(mat, "toarray"):
1505
- try:
1506
- arr = mat.toarray()
1507
- except Exception:
1508
- arr = np.asarray(mat)
1509
- else:
1510
- arr = np.asarray(mat)
1511
-
1512
- if arr.size == 0 or arr.shape[1] == 0:
1513
- continue
1514
-
1515
- # compute column-wise mean ignoring NaNs
1516
- # if arr is boolean or int, convert to float to support NaN
1517
- arr = arr.astype(float)
1518
- with np.errstate(all="ignore"):
1519
- col_mean = np.nanmean(arr, axis=0)
1520
-
1521
- # If all-NaN, skip
1522
- if np.all(np.isnan(col_mean)):
1523
- continue
1524
-
1525
- # smooth via pandas rolling (centered)
1526
- if (window is None) or (window <= 1):
1527
- smoothed = col_mean
1528
- else:
1529
- ser = pd.Series(col_mean)
1530
- smoothed = (
1531
- ser.rolling(window=window, min_periods=min_periods, center=center)
1532
- .mean()
1533
- .to_numpy()
1534
- )
1535
-
1536
- # x axis: x_coords (trim/pad to match length)
1537
- L = len(col_mean)
1538
- x = x_coords[:L]
1539
-
1540
- # optionally plot raw faint line first
1541
- if show_raw:
1542
- ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
1543
-
1544
- ax.plot(
1545
- x,
1546
- smoothed[:L],
1547
- label=layer,
1548
- color=colors[li],
1549
- linewidth=1.2,
1550
- alpha=0.95,
1551
- zorder=2,
1552
- )
1553
- plotted_any = True
1554
-
1555
- # labels / titles
1556
- if r_idx == 0:
1557
- ax.set_title(str(ref_name), fontsize=9)
1558
- if c_idx == 0:
1559
- total_reads = int((adata.obs[sample_col] == sample_name).sum())
1560
- ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
1561
- if r_idx == nrows - 1:
1562
- ax.set_xlabel("position", fontsize=8)
1563
-
1564
- # legend (only show in top-left plot to reduce clutter)
1565
- if (r_idx == 0 and c_idx == 0) and plotted_any:
1566
- ax.legend(fontsize=7, loc="upper right")
1567
-
1568
- ax.grid(True, alpha=0.2)
1569
-
1570
- fig.suptitle(
1571
- f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
1572
- fontsize=11,
1573
- y=0.995,
1574
- )
1575
- fig.tight_layout(rect=[0, 0, 1, 0.97])
1576
-
1577
- if save:
1578
- fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
1579
- plt.savefig(fname, bbox_inches="tight", dpi=dpi)
1580
- saved_files.append(fname)
1581
- else:
1582
- plt.show()
1583
- plt.close(fig)
1584
-
1585
- return saved_files
3
+ from smftools.logging_utils import get_logger
4
+ from smftools.plotting.chimeric_plotting import (
5
+ plot_delta_hamming_summary,
6
+ plot_rolling_nn_and_layer,
7
+ plot_rolling_nn_and_two_layers,
8
+ plot_segment_length_histogram,
9
+ plot_span_length_distributions,
10
+ plot_zero_hamming_pair_counts,
11
+ plot_zero_hamming_span_and_layer,
12
+ )
13
+ from smftools.plotting.hmm_plotting import (
14
+ combined_hmm_length_clustermap,
15
+ combined_hmm_raw_clustermap,
16
+ plot_hmm_layers_rolling_by_sample_ref,
17
+ )
18
+ from smftools.plotting.latent_plotting import (
19
+ plot_cp_sequence_components,
20
+ plot_embedding,
21
+ plot_embedding_grid,
22
+ plot_nmf_components,
23
+ plot_pca,
24
+ plot_pca_components,
25
+ plot_pca_explained_variance,
26
+ plot_pca_grid,
27
+ plot_umap,
28
+ plot_umap_grid,
29
+ )
30
+ from smftools.plotting.preprocess_plotting import (
31
+ plot_read_span_quality_clustermaps,
32
+ )
33
+ from smftools.plotting.spatial_plotting import (
34
+ combined_raw_clustermap,
35
+ )
36
+ from smftools.plotting.variant_plotting import (
37
+ plot_sequence_integer_encoding_clustermaps,
38
+ )
39
+
40
+ logger = get_logger(__name__)
41
+
42
+ __all__ = [
43
+ "combined_hmm_length_clustermap",
44
+ "combined_hmm_raw_clustermap",
45
+ "combined_raw_clustermap",
46
+ "plot_rolling_nn_and_layer",
47
+ "plot_rolling_nn_and_two_layers",
48
+ "plot_segment_length_histogram",
49
+ "plot_zero_hamming_pair_counts",
50
+ "plot_zero_hamming_span_and_layer",
51
+ "plot_hmm_layers_rolling_by_sample_ref",
52
+ "plot_nmf_components",
53
+ "plot_pca_components",
54
+ "plot_cp_sequence_components",
55
+ "plot_embedding",
56
+ "plot_embedding_grid",
57
+ "plot_read_span_quality_clustermaps",
58
+ "plot_pca",
59
+ "plot_pca_grid",
60
+ "plot_pca_explained_variance",
61
+ "plot_sequence_integer_encoding_clustermaps",
62
+ "plot_umap",
63
+ "plot_umap_grid",
64
+ ]