smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +18 -2
  4. smftools/cli/hmm_adata.py +18 -1
  5. smftools/cli/latent_adata.py +522 -67
  6. smftools/cli/load_adata.py +2 -2
  7. smftools/cli/preprocess_adata.py +32 -93
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +23 -109
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +41 -5
  12. smftools/config/conversion.yaml +0 -10
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +49 -13
  15. smftools/config/experiment_config.py +96 -3
  16. smftools/constants.py +4 -0
  17. smftools/hmm/call_hmm_peaks.py +1 -1
  18. smftools/informatics/binarize_converted_base_identities.py +2 -89
  19. smftools/informatics/converted_BAM_to_adata.py +53 -13
  20. smftools/informatics/h5ad_functions.py +83 -0
  21. smftools/informatics/modkit_extract_to_adata.py +4 -0
  22. smftools/plotting/__init__.py +26 -12
  23. smftools/plotting/autocorrelation_plotting.py +22 -4
  24. smftools/plotting/chimeric_plotting.py +1893 -0
  25. smftools/plotting/classifiers.py +28 -14
  26. smftools/plotting/general_plotting.py +58 -3362
  27. smftools/plotting/hmm_plotting.py +1586 -2
  28. smftools/plotting/latent_plotting.py +804 -0
  29. smftools/plotting/plotting_utils.py +243 -0
  30. smftools/plotting/position_stats.py +16 -8
  31. smftools/plotting/preprocess_plotting.py +281 -0
  32. smftools/plotting/qc_plotting.py +8 -3
  33. smftools/plotting/spatial_plotting.py +1134 -0
  34. smftools/plotting/variant_plotting.py +1231 -0
  35. smftools/preprocessing/__init__.py +3 -0
  36. smftools/preprocessing/append_base_context.py +1 -1
  37. smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
  38. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  39. smftools/preprocessing/append_variant_call_layer.py +480 -0
  40. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  41. smftools/preprocessing/invert_adata.py +1 -0
  42. smftools/readwrite.py +109 -85
  43. smftools/tools/__init__.py +6 -0
  44. smftools/tools/calculate_knn.py +121 -0
  45. smftools/tools/calculate_nmf.py +18 -7
  46. smftools/tools/calculate_pca.py +180 -0
  47. smftools/tools/calculate_umap.py +70 -154
  48. smftools/tools/position_stats.py +4 -4
  49. smftools/tools/rolling_nn_distance.py +640 -3
  50. smftools/tools/sequence_alignment.py +140 -0
  51. smftools/tools/tensor_factorization.py +52 -4
  52. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
  53. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
  54. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  55. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  56. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1134 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from math import floor
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Mapping, Optional, Sequence
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import scipy.cluster.hierarchy as sch
11
+
12
+ from smftools.logging_utils import get_logger
13
+ from smftools.optional_imports import require
14
+ from smftools.plotting.plotting_utils import (
15
+ _fixed_tick_positions,
16
+ _layer_to_numpy,
17
+ _methylation_fraction_for_layer,
18
+ _select_labels,
19
+ clean_barplot,
20
+ make_row_colors,
21
+ )
22
+
23
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
24
+ colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
25
+ grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
26
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
27
+
28
+ logger = get_logger(__name__)
29
+
30
+
31
def plot_rolling_nn_and_layer(
    subset,
    obsm_key: str = "rolling_nn_dist",
    layer_key: str = "nan0_0minus1",
    meta_cols: tuple[str, ...] = ("Reference_strand", "Sample"),
    col_cluster: bool = False,
    fill_nn_with_colmax: bool = True,
    fill_layer_value: float = 0.0,
    drop_all_nan_windows: bool = True,
    max_nan_fraction: float | None = None,
    var_valid_fraction_col: str | None = None,
    var_nan_fraction_col: str | None = None,
    read_span_layer: str | None = "read_span_mask",
    outside_read_color: str = "#bdbdbd",
    figsize: tuple[float, float] = (14, 10),
    right_panel_var_mask=None,  # optional boolean mask over subset.var to reduce width
    robust: bool = True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name: str | None = None,
):
    """
    Cluster reads by rolling NN distance and plot them beside a layer heatmap.

    1) Cluster rows by subset.obsm[obsm_key] (rolling NN distances).
    2) Plot two heatmaps side-by-side in the SAME row order, with mean barplots above:
       - left: rolling NN distance matrix
       - right: subset.layers[layer_key] matrix

    Reads whose rolling NN distances are all NaN are excluded from both panels.
    Row color annotations are built by ``make_row_colors`` from ``meta_cols``.

    Args:
        subset: AnnData subset with rolling NN distances stored in ``obsm``.
        obsm_key: Key in ``subset.obsm`` containing rolling NN distances.
        layer_key: Layer name to plot alongside rolling NN distances.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the rolling NN clustermap
            (only honored when there are at least two windows).
        fill_nn_with_colmax: Fill NaNs in rolling NN distances with per-column max values.
        fill_layer_value: Fill NaNs in the layer heatmap with this value.
        drop_all_nan_windows: Drop rolling windows that are all NaN.
        max_nan_fraction: Maximum allowed NaN fraction per position (filtering columns).
        var_valid_fraction_col: ``subset.var`` column with valid fractions (1 - NaN fraction).
        var_nan_fraction_col: ``subset.var`` column with NaN fractions; takes
            precedence over ``var_valid_fraction_col``.
        read_span_layer: Layer name with read span mask; 0 values are treated as outside read.
        outside_read_color: Color used to show positions outside each read.
        figsize: Figure size for the combined plot.
        right_panel_var_mask: Optional boolean mask over ``subset.var`` for the right panel.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels (default: ~10 labels).
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path. When given, the figure is saved and
            then closed; otherwise it is shown with ``plt.show()``.

    Returns:
        Index of obs names (as strings) in the row order used by both heatmaps.

    Raises:
        ValueError: If ``max_nan_fraction`` is outside [0, 1], if no read has a
            non-NaN rolling NN distance, or if ``right_panel_var_mask`` does not
            have one entry per ``subset.var_names``.
    """
    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
        raise ValueError("max_nan_fraction must be between 0 and 1.")

    logger.info("Plotting rolling NN distances with layer '%s'.", layer_key)

    def _apply_xticks(ax, labels, step):
        """Place every ``step``-th label at heatmap cell centers (x + 0.5)."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    def _format_labels(values):
        """Stringify labels, rendering integral floats as ints (e.g. 12.0 -> '12')."""
        values = np.asarray(values)
        if np.issubdtype(values.dtype, np.number):
            if np.all(np.isfinite(values)) and np.all(np.isclose(values, np.round(values))):
                values = np.round(values).astype(int)
        return [str(v) for v in values]

    X = subset.obsm[obsm_key]
    # Reads with no finite rolling distance carry no clustering signal.
    valid = ~np.all(np.isnan(X), axis=1)
    if not np.any(valid):
        raise ValueError(f"No reads with non-NaN values in subset.obsm[{obsm_key!r}].")

    X_df = pd.DataFrame(X[valid], index=subset.obs_names[valid])

    if drop_all_nan_windows:
        X_df = X_df.loc[:, ~X_df.isna().all(axis=0)]

    X_df_filled = X_df.copy()
    if fill_nn_with_colmax:
        # Treat missing distances as "maximally distant" within each window.
        col_max = X_df_filled.max(axis=0, skipna=True)
        X_df_filled = X_df_filled.fillna(col_max)

    X_df_filled.index = X_df_filled.index.astype(str)

    meta = subset.obs.loc[X_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    n_rows, n_cols = X_df_filled.shape
    if n_rows > 1:
        # The clustermap is only used to obtain the dendrogram row order; the
        # figure it creates is discarded immediately.
        g = sns.clustermap(
            X_df_filled,
            cmap="viridis",
            col_cluster=col_cluster and n_cols > 1,
            row_cluster=True,
            row_colors=row_colors,
            xticklabels=False,
            yticklabels=False,
            robust=robust,
        )
        ordered_index = X_df_filled.index[g.dendrogram_row.reordered_ind]
        plt.close(g.fig)
    else:
        # Hierarchical clustering needs at least two observations; a single
        # read already has a trivial order.
        ordered_index = X_df_filled.index

    X_ord = X_df_filled.loc[ordered_index]

    L = subset.layers[layer_key]
    L = L.toarray() if hasattr(L, "toarray") else np.asarray(L)

    L_df = pd.DataFrame(L[valid], index=subset.obs_names[valid], columns=subset.var_names)
    L_df.index = L_df.index.astype(str)

    if right_panel_var_mask is not None:
        if hasattr(right_panel_var_mask, "values"):
            right_panel_var_mask = right_panel_var_mask.values
        right_panel_var_mask = np.asarray(right_panel_var_mask, dtype=bool)
        # Validate before combining with the NaN-fraction mask so a wrong-length
        # mask fails with this message rather than a broadcasting error.
        if right_panel_var_mask.size != L_df.shape[1]:
            raise ValueError("right_panel_var_mask must align with subset.var_names.")

    if max_nan_fraction is not None:
        nan_fraction = None
        if var_nan_fraction_col and var_nan_fraction_col in subset.var:
            nan_fraction = pd.to_numeric(
                subset.var[var_nan_fraction_col], errors="coerce"
            ).to_numpy()
        elif var_valid_fraction_col and var_valid_fraction_col in subset.var:
            valid_fraction = pd.to_numeric(
                subset.var[var_valid_fraction_col], errors="coerce"
            ).to_numpy()
            nan_fraction = 1 - valid_fraction
        if nan_fraction is not None:
            # NaN fractions (unparseable values) compare False and are dropped.
            nan_mask = nan_fraction <= max_nan_fraction
            if right_panel_var_mask is None:
                right_panel_var_mask = nan_mask
            else:
                right_panel_var_mask = right_panel_var_mask & nan_mask

    if right_panel_var_mask is not None:
        L_df = L_df.loc[:, right_panel_var_mask]

    read_span_mask = None
    if read_span_layer and read_span_layer in subset.layers:
        span = subset.layers[read_span_layer]
        span = span.toarray() if hasattr(span, "toarray") else np.asarray(span)
        span_df = pd.DataFrame(span[valid], index=subset.obs_names[valid], columns=subset.var_names)
        span_df.index = span_df.index.astype(str)
        if right_panel_var_mask is not None:
            span_df = span_df.loc[:, right_panel_var_mask]
        # True where the position lies outside the read; rendered as "bad" color.
        read_span_mask = span_df.loc[ordered_index].to_numpy() == 0

    L_ord = L_df.loc[ordered_index]
    L_plot = L_ord.fillna(fill_layer_value)
    if read_span_mask is not None:
        L_plot = L_plot.mask(read_span_mask)

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2,
        4,
        width_ratios=[1, 0.05, 1, 0.05],
        height_ratios=[1, 6],
        wspace=0.2,
        hspace=0.05,
    )

    ax1 = fig.add_subplot(gs[1, 0])
    ax1_cbar = fig.add_subplot(gs[1, 1])
    ax2 = fig.add_subplot(gs[1, 2])
    ax2_cbar = fig.add_subplot(gs[1, 3])
    ax1_bar = fig.add_subplot(gs[0, 0], sharex=ax1)
    ax2_bar = fig.add_subplot(gs[0, 2], sharex=ax2)
    fig.add_subplot(gs[0, 1]).axis("off")
    fig.add_subplot(gs[0, 3]).axis("off")

    mean_nn = np.nanmean(X_ord.to_numpy(), axis=0)
    clean_barplot(
        ax1_bar,
        mean_nn,
        obsm_key,
        y_max=None,
        y_label="Mean distance",
        y_ticks=None,
    )

    sns.heatmap(
        X_ord,
        ax=ax1,
        cmap="viridis",
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax1_cbar,
    )
    # Prefer window centers for labels, falling back to window starts.
    label_source = subset.uns.get(f"{obsm_key}_centers")
    if label_source is None:
        label_source = subset.uns.get(f"{obsm_key}_starts")
    if label_source is not None:
        label_source = np.asarray(label_source)
        window_labels = _format_labels(label_source)
        try:
            # Columns are window positions; remap labels after all-NaN windows
            # were dropped so ticks stay aligned with the remaining columns.
            col_idx = X_ord.columns.to_numpy()
            if np.issubdtype(col_idx.dtype, np.number):
                col_idx = col_idx.astype(int)
            if col_idx.size and col_idx.max() < len(label_source):
                window_labels = _format_labels(label_source[col_idx])
        except (TypeError, ValueError, IndexError):
            window_labels = _format_labels(label_source)
        _apply_xticks(ax1, window_labels, xtick_step)

    methylation_fraction = _methylation_fraction_for_layer(L_ord.to_numpy(), layer_key)
    clean_barplot(
        ax2_bar,
        methylation_fraction,
        layer_key,
        y_max=1.0,
        y_label="Methylation fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    layer_cmap = plt.get_cmap("coolwarm").copy()
    if read_span_mask is not None:
        layer_cmap.set_bad(outside_read_color)

    sns.heatmap(
        L_plot,
        ax=ax2,
        cmap=layer_cmap,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax2_cbar,
    )
    _apply_xticks(ax2, [str(x) for x in L_plot.columns], xtick_step)

    if title:
        fig.suptitle(title)

    if save_name is not None:
        fname = os.fspath(save_name)
        fig.savefig(fname, dpi=200, bbox_inches="tight")
        logger.info("Saved rolling NN/layer plot to %s.", fname)
        # Callers only receive the row order, so release the figure; batch
        # plotting otherwise accumulates open figures.
        plt.close(fig)
    else:
        plt.show()

    return ordered_index
285
+
286
+
287
def plot_zero_hamming_span_and_layer(
    subset,
    span_layer_key: str,
    layer_key: str = "nan0_0minus1",
    meta_cols: tuple[str, ...] = ("Reference_strand", "Sample"),
    col_cluster: bool = False,
    fill_span_value: float = 0.0,
    fill_layer_value: float = 0.0,
    drop_all_nan_positions: bool = True,
    max_nan_fraction: float | None = None,
    var_valid_fraction_col: str | None = None,
    var_nan_fraction_col: str | None = None,
    read_span_layer: str | None = "read_span_mask",
    outside_read_color: str = "#bdbdbd",
    span_color: str = "#2ca25f",
    figsize: tuple[float, float] = (14, 10),
    robust: bool = True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name: str | None = None,
):
    """
    Plot zero-Hamming span clustermap alongside a layer heatmap.

    Rows are clustered on the (NaN-filled) span mask. Both heatmaps are then
    drawn in that row order, with per-position mean barplots above them.

    Args:
        subset: AnnData subset with zero-Hamming span annotations stored in ``layers``.
        span_layer_key: Layer name with the binary zero-Hamming span mask.
        layer_key: Layer name to plot alongside the span mask.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the span mask clustermap
            (only honored when there are at least two positions).
        fill_span_value: Value to fill NaNs in the span mask.
        fill_layer_value: Value to fill NaNs in the layer heatmap.
        drop_all_nan_positions: Drop positions that are all NaN in the span mask
            (left panel only).
        max_nan_fraction: Maximum allowed NaN fraction per position (filtering columns).
        var_valid_fraction_col: ``subset.var`` column with valid fractions (1 - NaN fraction).
        var_nan_fraction_col: ``subset.var`` column with NaN fractions; takes
            precedence over ``var_valid_fraction_col``.
        read_span_layer: Layer name with read span mask; 0 values are treated as outside read.
        outside_read_color: Color used to show positions outside each read.
        span_color: Color for zero-Hamming span mask values.
        figsize: Figure size for the combined plot.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels (default: ~10 labels).
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path. When given, the figure is saved and
            then closed; otherwise it is shown with ``plt.show()``.

    Returns:
        Index of obs names (as strings) in the row order used by both heatmaps.

    Raises:
        ValueError: If ``max_nan_fraction`` is outside [0, 1] or the subset
            has no reads.
    """
    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
        raise ValueError("max_nan_fraction must be between 0 and 1.")

    logger.info(
        "Plotting zero-Hamming span mask '%s' with layer '%s'.",
        span_layer_key,
        layer_key,
    )

    def _apply_xticks(ax, labels, step):
        """Place every ``step``-th label at heatmap cell centers (x + 0.5)."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    def _to_frame(layer_name):
        """Return ``subset.layers[layer_name]`` as a dense DataFrame with str row labels."""
        values = subset.layers[layer_name]
        values = values.toarray() if hasattr(values, "toarray") else np.asarray(values)
        frame = pd.DataFrame(values, index=subset.obs_names, columns=subset.var_names)
        frame.index = frame.index.astype(str)
        return frame

    span_df = _to_frame(span_layer_key)
    if span_df.shape[0] == 0:
        raise ValueError("No reads to plot in subset.")

    # The NaN-fraction mask has one entry per var, so it must be applied to the
    # full-width frame BEFORE all-NaN positions are dropped (applying it after
    # a drop raised a boolean-index length mismatch).
    nan_mask = None
    if max_nan_fraction is not None:
        nan_fraction = None
        if var_nan_fraction_col and var_nan_fraction_col in subset.var:
            nan_fraction = pd.to_numeric(
                subset.var[var_nan_fraction_col], errors="coerce"
            ).to_numpy()
        elif var_valid_fraction_col and var_valid_fraction_col in subset.var:
            valid_fraction = pd.to_numeric(
                subset.var[var_valid_fraction_col], errors="coerce"
            ).to_numpy()
            nan_fraction = 1 - valid_fraction
        if nan_fraction is not None:
            # NaN fractions (unparseable values) compare False and are dropped.
            nan_mask = nan_fraction <= max_nan_fraction
            span_df = span_df.loc[:, nan_mask]

    if drop_all_nan_positions:
        span_df = span_df.loc[:, ~span_df.isna().all(axis=0)]

    span_df_filled = span_df.fillna(fill_span_value)

    meta = subset.obs.loc[span_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    # Two-color map: values in [-0.5, 0.5) -> white, [0.5, 1.5) -> span_color.
    span_cmap = colors.ListedColormap(["white", span_color])
    span_norm = colors.BoundaryNorm([-0.5, 0.5, 1.5], span_cmap.N)

    n_rows, n_cols = span_df_filled.shape
    if n_rows > 1:
        # The clustermap is only used to obtain the dendrogram row order; the
        # figure it creates is discarded immediately.
        g = sns.clustermap(
            span_df_filled,
            cmap=span_cmap,
            norm=span_norm,
            col_cluster=col_cluster and n_cols > 1,
            row_cluster=True,
            row_colors=row_colors,
            xticklabels=False,
            yticklabels=False,
            robust=robust,
        )
        ordered_index = span_df_filled.index[g.dendrogram_row.reordered_ind]
        plt.close(g.fig)
    else:
        # Hierarchical clustering needs at least two observations.
        ordered_index = span_df_filled.index

    span_ord = span_df_filled.loc[ordered_index]

    # The right panel keeps every position that passes the NaN-fraction filter
    # (all-NaN span positions are only dropped from the left panel).
    layer_df = _to_frame(layer_key)
    if nan_mask is not None:
        layer_df = layer_df.loc[:, nan_mask]

    read_span_mask = None
    if read_span_layer and read_span_layer in subset.layers:
        span_mask_df = _to_frame(read_span_layer)
        if nan_mask is not None:
            span_mask_df = span_mask_df.loc[:, nan_mask]
        # True where the position lies outside the read; rendered as "bad" color.
        read_span_mask = span_mask_df.loc[ordered_index].to_numpy() == 0

    layer_ord = layer_df.loc[ordered_index]
    layer_plot = layer_ord.fillna(fill_layer_value)
    if read_span_mask is not None:
        layer_plot = layer_plot.mask(read_span_mask)

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2,
        4,
        width_ratios=[1, 0.05, 1, 0.05],
        height_ratios=[1, 6],
        wspace=0.2,
        hspace=0.05,
    )

    ax1 = fig.add_subplot(gs[1, 0])
    ax1_cbar = fig.add_subplot(gs[1, 1])
    ax2 = fig.add_subplot(gs[1, 2])
    ax2_cbar = fig.add_subplot(gs[1, 3])
    ax1_bar = fig.add_subplot(gs[0, 0], sharex=ax1)
    ax2_bar = fig.add_subplot(gs[0, 2], sharex=ax2)
    fig.add_subplot(gs[0, 1]).axis("off")
    fig.add_subplot(gs[0, 3]).axis("off")

    mean_span = np.nanmean(span_ord.to_numpy(), axis=0)
    clean_barplot(
        ax1_bar,
        mean_span,
        span_layer_key,
        y_max=1.0,
        y_label="Span fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    methylation_fraction = _methylation_fraction_for_layer(layer_ord.to_numpy(), layer_key)
    clean_barplot(
        ax2_bar,
        methylation_fraction,
        layer_key,
        y_max=1.0,
        y_label="Methylation fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    sns.heatmap(
        span_ord,
        ax=ax1,
        cmap=span_cmap,
        norm=span_norm,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax1_cbar,
    )

    layer_cmap = plt.get_cmap("coolwarm").copy()
    if read_span_mask is not None:
        layer_cmap.set_bad(outside_read_color)

    sns.heatmap(
        layer_plot,
        ax=ax2,
        cmap=layer_cmap,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax2_cbar,
    )

    _apply_xticks(ax1, [str(x) for x in span_ord.columns], xtick_step)
    _apply_xticks(ax2, [str(x) for x in layer_plot.columns], xtick_step)

    if title:
        fig.suptitle(title)

    if save_name is not None:
        fname = os.fspath(save_name)
        fig.savefig(fname, dpi=200, bbox_inches="tight")
        logger.info("Saved zero-Hamming span/layer plot to %s.", fname)
        # Callers only receive the row order, so release the figure.
        plt.close(fig)
    else:
        plt.show()

    return ordered_index
511
+
512
+
513
+ def _window_center_labels(var_names: Sequence, starts: np.ndarray, window: int) -> list[str]:
514
+ coords = np.asarray(var_names)
515
+ if coords.size == 0:
516
+ return []
517
+ try:
518
+ coords_numeric = coords.astype(float)
519
+ centers = np.array(
520
+ [floor(np.nanmean(coords_numeric[s : s + window])) for s in starts], dtype=float
521
+ )
522
+ return [str(c) for c in centers]
523
+ except Exception:
524
+ mid = np.clip(starts + (window // 2), 0, coords.size - 1)
525
+ return [str(coords[idx]) for idx in mid]
526
+
527
+
528
def plot_zero_hamming_pair_counts(
    subset,
    zero_pairs_uns_key: str,
    meta_cols: tuple[str, ...] = ("Reference_strand", "Sample"),
    col_cluster: bool = False,
    figsize: tuple[float, float] = (14, 10),
    robust: bool = True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name: str | None = None,
):
    """
    Plot a heatmap of zero-Hamming pair counts per read across rolling windows.

    For every window, each read is counted once per zero-Hamming pair it
    participates in (both members of a pair are credited). Rows are clustered
    hierarchically when there are at least two reads.

    Args:
        subset: AnnData subset containing zero-pair window data in ``.uns``.
        zero_pairs_uns_key: Key in ``subset.uns`` with one entry per window. Each
            entry is ``None``/empty or an ``(n, 2)`` array-like of row indices.
            Optional companion keys ``f"{key}_starts"`` and ``f"{key}_window"``
            are used to label windows by their center coordinate.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the heatmap (only honored
            when there are at least two windows).
        figsize: Figure size for the plot.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels (default: ~10 labels).
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path for saving the plot; otherwise the plot
            is shown with ``plt.show()``.

    Returns:
        The seaborn ``ClusterGrid``. Its figure is left open so callers can
        customize or re-save it.

    Raises:
        KeyError: If ``zero_pairs_uns_key`` is missing from ``subset.uns``.
        ValueError: If a window entry is not shaped ``(n, 2)``.
        IndexError: If a pair references a row outside ``[0, subset.n_obs)``.
    """
    if zero_pairs_uns_key not in subset.uns:
        raise KeyError(f"Missing zero-pair data in subset.uns[{zero_pairs_uns_key!r}].")

    zero_pairs_by_window = subset.uns[zero_pairs_uns_key]
    starts = np.asarray(subset.uns.get(f"{zero_pairs_uns_key}_starts", []))
    window = int(subset.uns.get(f"{zero_pairs_uns_key}_window", 0))

    n_obs = subset.n_obs
    n_windows = len(zero_pairs_by_window)
    counts = np.zeros((n_obs, n_windows), dtype=int)

    for wi, pairs in enumerate(zero_pairs_by_window):
        if pairs is None or len(pairs) == 0:
            continue
        pair_arr = np.asarray(pairs, dtype=int)
        if pair_arr.size == 0:
            continue
        if pair_arr.ndim != 2 or pair_arr.shape[1] != 2:
            raise ValueError("Zero-pair entries must be arrays of shape (n, 2).")
        # np.add.at wraps negative indices silently, which would credit the
        # wrong reads; reject anything outside the valid row range up front.
        if pair_arr.min() < 0 or pair_arr.max() >= n_obs:
            raise IndexError(
                f"Zero-pair indices in window {wi} fall outside [0, {n_obs})."
            )
        # counts[:, wi] is a view, so np.add.at updates counts in place and
        # correctly accumulates repeated indices.
        np.add.at(counts[:, wi], pair_arr[:, 0], 1)
        np.add.at(counts[:, wi], pair_arr[:, 1], 1)

    if starts.size == n_windows and window > 0:
        labels = _window_center_labels(subset.var_names, starts, window)
    else:
        labels = [str(i) for i in range(n_windows)]

    counts_df = pd.DataFrame(counts, index=subset.obs_names.astype(str), columns=labels)
    meta = subset.obs.loc[counts_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    def _apply_xticks(ax, labels, step):
        """Place every ``step``-th label at heatmap cell centers (x + 0.5)."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    # Hierarchical clustering needs at least two observations along an axis.
    g = sns.clustermap(
        counts_df,
        cmap="viridis",
        col_cluster=col_cluster and n_windows > 1,
        row_cluster=n_obs > 1,
        row_colors=row_colors,
        xticklabels=False,
        yticklabels=False,
        figsize=figsize,
        robust=robust,
    )
    _apply_xticks(g.ax_heatmap, labels, xtick_step)

    if title:
        g.fig.suptitle(title)

    if save_name is not None:
        fname = os.fspath(save_name)
        g.fig.savefig(fname, dpi=200, bbox_inches="tight")
        logger.info("Saved zero-Hamming pair count plot to %s.", fname)
    else:
        plt.show()

    return g
626
+
627
+
628
+ def combined_raw_clustermap(
629
+ adata,
630
+ sample_col: str = "Sample_Names",
631
+ reference_col: str = "Reference_strand",
632
+ mod_target_bases: Sequence[str] = ("GpC", "CpG"),
633
+ layer_c: str = "nan0_0minus1",
634
+ layer_gpc: str = "nan0_0minus1",
635
+ layer_cpg: str = "nan0_0minus1",
636
+ layer_a: str = "nan0_0minus1",
637
+ cmap_c: str = "coolwarm",
638
+ cmap_gpc: str = "coolwarm",
639
+ cmap_cpg: str = "viridis",
640
+ cmap_a: str = "coolwarm",
641
+ min_quality: float | None = 20,
642
+ min_length: int | None = 200,
643
+ min_mapped_length_to_reference_length_ratio: float | None = 0,
644
+ min_position_valid_fraction: float | None = 0,
645
+ demux_types: Sequence[str] = ("single", "double", "already"),
646
+ sample_mapping: Optional[Mapping[str, str]] = None,
647
+ save_path: str | Path | None = None,
648
+ sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
649
+ bins: Optional[Dict[str, Any]] = None,
650
+ deaminase: bool = False,
651
+ min_signal: float = 0,
652
+ n_xticks_any_c: int = 10,
653
+ n_xticks_gpc: int = 10,
654
+ n_xticks_cpg: int = 10,
655
+ n_xticks_any_a: int = 10,
656
+ xtick_rotation: int = 90,
657
+ xtick_fontsize: int = 9,
658
+ index_col_suffix: str | None = None,
659
+ fill_nan_strategy: str = "value",
660
+ fill_nan_value: float = -1,
661
+ ):
662
+ """
663
+ Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
664
+
665
+ Key fixes vs old version:
666
+ - order computed ONCE per bin, applied to all matrices
667
+ - no hard-coded axes indices
668
+ - NaNs excluded from methylation denominators
669
+ - var_names not forced to int
670
+ - fixed count of x tick labels per block (controllable)
671
+ - optional NaN fill strategy for clustering/plotting (in-memory only)
672
+ - adata.uns updated once at end
673
+
674
+ Returns
675
+ -------
676
+ results : list[dict]
677
+ One entry per (sample, ref) plot with matrices + bin metadata.
678
+ """
679
+ logger.info("Plotting combined raw clustermaps.")
680
+ if fill_nan_strategy not in {"none", "value", "col_mean"}:
681
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
682
+
683
+ def _mask_or_true(series_name: str, predicate):
684
+ """Return a mask from predicate or an all-True mask."""
685
+ if series_name not in adata.obs:
686
+ return pd.Series(True, index=adata.obs.index)
687
+ s = adata.obs[series_name]
688
+ try:
689
+ return predicate(s)
690
+ except Exception:
691
+ return pd.Series(True, index=adata.obs.index)
692
+
693
+ results: list[Dict[str, Any]] = []
694
+ save_path = Path(save_path) if save_path is not None else None
695
+ if save_path is not None:
696
+ save_path.mkdir(parents=True, exist_ok=True)
697
+
698
+ for col in (sample_col, reference_col):
699
+ if col not in adata.obs:
700
+ raise KeyError(f"{col} not in adata.obs")
701
+ if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
702
+ adata.obs[col] = adata.obs[col].astype("category")
703
+
704
+ base_set = set(mod_target_bases)
705
+ include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
706
+ include_any_a = "A" in base_set
707
+
708
+ for ref in adata.obs[reference_col].cat.categories:
709
+ for sample in adata.obs[sample_col].cat.categories:
710
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
711
+
712
+ qmask = _mask_or_true(
713
+ "read_quality",
714
+ (lambda s: s >= float(min_quality))
715
+ if (min_quality is not None)
716
+ else (lambda s: pd.Series(True, index=s.index)),
717
+ )
718
+ lm_mask = _mask_or_true(
719
+ "mapped_length",
720
+ (lambda s: s >= float(min_length))
721
+ if (min_length is not None)
722
+ else (lambda s: pd.Series(True, index=s.index)),
723
+ )
724
+ lrr_mask = _mask_or_true(
725
+ "mapped_length_to_reference_length_ratio",
726
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
727
+ if (min_mapped_length_to_reference_length_ratio is not None)
728
+ else (lambda s: pd.Series(True, index=s.index)),
729
+ )
730
+
731
+ demux_mask = _mask_or_true(
732
+ "demux_type",
733
+ (lambda s: s.astype("string").isin(list(demux_types)))
734
+ if (demux_types is not None)
735
+ else (lambda s: pd.Series(True, index=s.index)),
736
+ )
737
+
738
+ ref_mask = adata.obs[reference_col] == ref
739
+ sample_mask = adata.obs[sample_col] == sample
740
+
741
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
742
+
743
+ if not bool(row_mask.any()):
744
+ logger.warning(
745
+ "No reads for %s - %s after read quality and length filtering.",
746
+ display_sample,
747
+ ref,
748
+ )
749
+ continue
750
+
751
+ try:
752
+ subset = adata[row_mask, :].copy()
753
+
754
+ if min_position_valid_fraction is not None:
755
+ valid_key = f"{ref}_valid_fraction"
756
+ if valid_key in subset.var:
757
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
758
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
759
+ if col_mask.any():
760
+ subset = subset[:, col_mask].copy()
761
+ else:
762
+ logger.warning(
763
+ "No positions left after valid_fraction filter for %s - %s.",
764
+ display_sample,
765
+ ref,
766
+ )
767
+ continue
768
+
769
+ if subset.shape[0] == 0:
770
+ logger.warning(
771
+ "No reads left after filtering for %s - %s.", display_sample, ref
772
+ )
773
+ continue
774
+
775
+ if bins is None:
776
+ bins_temp = {"All": (subset.obs[reference_col] == ref)}
777
+ else:
778
+ bins_temp = bins
779
+
780
+ any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
781
+ any_a_sites = np.array([], dtype=int)
782
+
783
+ num_any_c = num_gpc = num_cpg = num_any_a = 0
784
+
785
+ if include_any_c:
786
+ any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
787
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
788
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
789
+
790
+ num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
791
+
792
+ any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
793
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
794
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
795
+
796
+ if include_any_a:
797
+ any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
798
+ num_any_a = len(any_a_sites)
799
+ any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
800
+
801
+ stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
802
+ stacked_any_c_raw, stacked_gpc_raw, stacked_cpg_raw, stacked_any_a_raw = (
803
+ [],
804
+ [],
805
+ [],
806
+ [],
807
+ )
808
+ row_labels, bin_labels, bin_boundaries = [], [], []
809
+ percentages = {}
810
+ last_idx = 0
811
+ total_reads = subset.shape[0]
812
+
813
+ for bin_label, bin_filter in bins_temp.items():
814
+ subset_bin = subset[bin_filter].copy()
815
+ num_reads = subset_bin.shape[0]
816
+ if num_reads == 0:
817
+ percentages[bin_label] = 0.0
818
+ continue
819
+
820
+ percent_reads = (num_reads / total_reads) * 100
821
+ percentages[bin_label] = percent_reads
822
+
823
+ if sort_by.startswith("obs:"):
824
+ colname = sort_by.split("obs:")[1]
825
+ order = np.argsort(subset_bin.obs[colname].values)
826
+
827
+ elif sort_by == "gpc" and num_gpc > 0:
828
+ gpc_matrix = _layer_to_numpy(
829
+ subset_bin,
830
+ layer_gpc,
831
+ gpc_sites,
832
+ fill_nan_strategy=fill_nan_strategy,
833
+ fill_nan_value=fill_nan_value,
834
+ )
835
+ linkage = sch.linkage(gpc_matrix, method="ward")
836
+ order = sch.leaves_list(linkage)
837
+
838
+ elif sort_by == "cpg" and num_cpg > 0:
839
+ cpg_matrix = _layer_to_numpy(
840
+ subset_bin,
841
+ layer_cpg,
842
+ cpg_sites,
843
+ fill_nan_strategy=fill_nan_strategy,
844
+ fill_nan_value=fill_nan_value,
845
+ )
846
+ linkage = sch.linkage(cpg_matrix, method="ward")
847
+ order = sch.leaves_list(linkage)
848
+
849
+ elif sort_by == "c" and num_any_c > 0:
850
+ any_c_matrix = _layer_to_numpy(
851
+ subset_bin,
852
+ layer_c,
853
+ any_c_sites,
854
+ fill_nan_strategy=fill_nan_strategy,
855
+ fill_nan_value=fill_nan_value,
856
+ )
857
+ linkage = sch.linkage(any_c_matrix, method="ward")
858
+ order = sch.leaves_list(linkage)
859
+
860
+ elif sort_by == "gpc_cpg":
861
+ gpc_matrix = _layer_to_numpy(
862
+ subset_bin,
863
+ layer_gpc,
864
+ None,
865
+ fill_nan_strategy=fill_nan_strategy,
866
+ fill_nan_value=fill_nan_value,
867
+ )
868
+ linkage = sch.linkage(gpc_matrix, method="ward")
869
+ order = sch.leaves_list(linkage)
870
+
871
+ elif sort_by == "a" and num_any_a > 0:
872
+ any_a_matrix = _layer_to_numpy(
873
+ subset_bin,
874
+ layer_a,
875
+ any_a_sites,
876
+ fill_nan_strategy=fill_nan_strategy,
877
+ fill_nan_value=fill_nan_value,
878
+ )
879
+ linkage = sch.linkage(any_a_matrix, method="ward")
880
+ order = sch.leaves_list(linkage)
881
+
882
+ elif sort_by == "none":
883
+ order = np.arange(num_reads)
884
+
885
+ else:
886
+ order = np.arange(num_reads)
887
+
888
+ subset_bin = subset_bin[order]
889
+
890
+ if include_any_c and num_any_c > 0:
891
+ stacked_any_c.append(
892
+ _layer_to_numpy(
893
+ subset_bin,
894
+ layer_c,
895
+ any_c_sites,
896
+ fill_nan_strategy=fill_nan_strategy,
897
+ fill_nan_value=fill_nan_value,
898
+ )
899
+ )
900
+ stacked_any_c_raw.append(
901
+ _layer_to_numpy(
902
+ subset_bin,
903
+ layer_c,
904
+ any_c_sites,
905
+ fill_nan_strategy="none",
906
+ fill_nan_value=fill_nan_value,
907
+ )
908
+ )
909
+ if include_any_c and num_gpc > 0:
910
+ stacked_gpc.append(
911
+ _layer_to_numpy(
912
+ subset_bin,
913
+ layer_gpc,
914
+ gpc_sites,
915
+ fill_nan_strategy=fill_nan_strategy,
916
+ fill_nan_value=fill_nan_value,
917
+ )
918
+ )
919
+ stacked_gpc_raw.append(
920
+ _layer_to_numpy(
921
+ subset_bin,
922
+ layer_gpc,
923
+ gpc_sites,
924
+ fill_nan_strategy="none",
925
+ fill_nan_value=fill_nan_value,
926
+ )
927
+ )
928
+ if include_any_c and num_cpg > 0:
929
+ stacked_cpg.append(
930
+ _layer_to_numpy(
931
+ subset_bin,
932
+ layer_cpg,
933
+ cpg_sites,
934
+ fill_nan_strategy=fill_nan_strategy,
935
+ fill_nan_value=fill_nan_value,
936
+ )
937
+ )
938
+ stacked_cpg_raw.append(
939
+ _layer_to_numpy(
940
+ subset_bin,
941
+ layer_cpg,
942
+ cpg_sites,
943
+ fill_nan_strategy="none",
944
+ fill_nan_value=fill_nan_value,
945
+ )
946
+ )
947
+ if include_any_a and num_any_a > 0:
948
+ stacked_any_a.append(
949
+ _layer_to_numpy(
950
+ subset_bin,
951
+ layer_a,
952
+ any_a_sites,
953
+ fill_nan_strategy=fill_nan_strategy,
954
+ fill_nan_value=fill_nan_value,
955
+ )
956
+ )
957
+ stacked_any_a_raw.append(
958
+ _layer_to_numpy(
959
+ subset_bin,
960
+ layer_a,
961
+ any_a_sites,
962
+ fill_nan_strategy="none",
963
+ fill_nan_value=fill_nan_value,
964
+ )
965
+ )
966
+
967
+ row_labels.extend([bin_label] * num_reads)
968
+ bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
969
+ last_idx += num_reads
970
+ bin_boundaries.append(last_idx)
971
+
972
+ blocks = []
973
+
974
+ if include_any_c and stacked_any_c:
975
+ any_c_matrix = np.vstack(stacked_any_c)
976
+ any_c_matrix_raw = np.vstack(stacked_any_c_raw)
977
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
978
+ gpc_matrix_raw = (
979
+ np.vstack(stacked_gpc_raw) if stacked_gpc_raw else np.empty((0, 0))
980
+ )
981
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
982
+ cpg_matrix_raw = (
983
+ np.vstack(stacked_cpg_raw) if stacked_cpg_raw else np.empty((0, 0))
984
+ )
985
+
986
+ mean_any_c = (
987
+ _methylation_fraction_for_layer(any_c_matrix_raw, layer_c)
988
+ if any_c_matrix_raw.size
989
+ else None
990
+ )
991
+ mean_gpc = (
992
+ _methylation_fraction_for_layer(gpc_matrix_raw, layer_gpc)
993
+ if gpc_matrix_raw.size
994
+ else None
995
+ )
996
+ mean_cpg = (
997
+ _methylation_fraction_for_layer(cpg_matrix_raw, layer_cpg)
998
+ if cpg_matrix_raw.size
999
+ else None
1000
+ )
1001
+
1002
+ if any_c_matrix.size:
1003
+ blocks.append(
1004
+ dict(
1005
+ name="c",
1006
+ matrix=any_c_matrix,
1007
+ mean=mean_any_c,
1008
+ labels=any_c_labels,
1009
+ cmap=cmap_c,
1010
+ n_xticks=n_xticks_any_c,
1011
+ title="any C site Modification Signal",
1012
+ )
1013
+ )
1014
+ if gpc_matrix.size:
1015
+ blocks.append(
1016
+ dict(
1017
+ name="gpc",
1018
+ matrix=gpc_matrix,
1019
+ mean=mean_gpc,
1020
+ labels=gpc_labels,
1021
+ cmap=cmap_gpc,
1022
+ n_xticks=n_xticks_gpc,
1023
+ title="GpC Modification Signal",
1024
+ )
1025
+ )
1026
+ if cpg_matrix.size:
1027
+ blocks.append(
1028
+ dict(
1029
+ name="cpg",
1030
+ matrix=cpg_matrix,
1031
+ mean=mean_cpg,
1032
+ labels=cpg_labels,
1033
+ cmap=cmap_cpg,
1034
+ n_xticks=n_xticks_cpg,
1035
+ title="CpG Modification Signal",
1036
+ )
1037
+ )
1038
+
1039
+ if include_any_a and stacked_any_a:
1040
+ any_a_matrix = np.vstack(stacked_any_a)
1041
+ any_a_matrix_raw = np.vstack(stacked_any_a_raw)
1042
+ mean_any_a = (
1043
+ _methylation_fraction_for_layer(any_a_matrix_raw, layer_a)
1044
+ if any_a_matrix_raw.size
1045
+ else None
1046
+ )
1047
+ if any_a_matrix.size:
1048
+ blocks.append(
1049
+ dict(
1050
+ name="a",
1051
+ matrix=any_a_matrix,
1052
+ mean=mean_any_a,
1053
+ labels=any_a_labels,
1054
+ cmap=cmap_a,
1055
+ n_xticks=n_xticks_any_a,
1056
+ title="any A site Modification Signal",
1057
+ )
1058
+ )
1059
+
1060
+ if not blocks:
1061
+ logger.warning("No matrices to plot for %s - %s.", display_sample, ref)
1062
+ continue
1063
+
1064
+ gs_dim = len(blocks)
1065
+ fig = plt.figure(figsize=(5.5 * gs_dim, 11))
1066
+ gs = grid_spec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
1067
+ fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1068
+
1069
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1070
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1071
+
1072
+ for i, blk in enumerate(blocks):
1073
+ mat = blk["matrix"]
1074
+ mean = blk["mean"]
1075
+ labels = np.asarray(blk["labels"], dtype=str)
1076
+ n_xticks = blk["n_xticks"]
1077
+
1078
+ clean_barplot(axes_bar[i], mean, blk["title"])
1079
+
1080
+ sns.heatmap(
1081
+ mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
1082
+ )
1083
+
1084
+ tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1085
+ axes_heat[i].set_xticks(tick_pos)
1086
+ axes_heat[i].set_xticklabels(
1087
+ labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
1088
+ )
1089
+
1090
+ for boundary in bin_boundaries[:-1]:
1091
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
1092
+
1093
+ axes_heat[i].set_xlabel("Position", fontsize=9)
1094
+
1095
+ plt.tight_layout()
1096
+
1097
+ if save_path is not None:
1098
+ safe_name = (
1099
+ f"{ref}__{display_sample}".replace("=", "")
1100
+ .replace("__", "_")
1101
+ .replace(",", "_")
1102
+ .replace(" ", "_")
1103
+ )
1104
+ out_file = save_path / f"{safe_name}.png"
1105
+ fig.savefig(out_file, dpi=300)
1106
+ plt.close(fig)
1107
+ logger.info("Saved combined raw clustermap to %s.", out_file)
1108
+ else:
1109
+ plt.show()
1110
+
1111
+ rec = {
1112
+ "sample": str(sample),
1113
+ "ref": str(ref),
1114
+ "row_labels": row_labels,
1115
+ "bin_labels": bin_labels,
1116
+ "bin_boundaries": bin_boundaries,
1117
+ "percentages": percentages,
1118
+ }
1119
+ for blk in blocks:
1120
+ rec[f"{blk['name']}_matrix"] = blk["matrix"]
1121
+ rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
1122
+ results.append(rec)
1123
+
1124
+ logger.info("Summary for %s - %s:", display_sample, ref)
1125
+ for bin_label, percent in percentages.items():
1126
+ logger.info(" - %s: %.1f%%", bin_label, percent)
1127
+
1128
+ except Exception:
1129
+ import traceback
1130
+
1131
+ traceback.print_exc()
1132
+ continue
1133
+
1134
+ return results