smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +18 -2
  4. smftools/cli/hmm_adata.py +18 -1
  5. smftools/cli/latent_adata.py +522 -67
  6. smftools/cli/load_adata.py +2 -2
  7. smftools/cli/preprocess_adata.py +32 -93
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +23 -109
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +41 -5
  12. smftools/config/conversion.yaml +0 -10
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +49 -13
  15. smftools/config/experiment_config.py +96 -3
  16. smftools/constants.py +4 -0
  17. smftools/hmm/call_hmm_peaks.py +1 -1
  18. smftools/informatics/binarize_converted_base_identities.py +2 -89
  19. smftools/informatics/converted_BAM_to_adata.py +53 -13
  20. smftools/informatics/h5ad_functions.py +83 -0
  21. smftools/informatics/modkit_extract_to_adata.py +4 -0
  22. smftools/plotting/__init__.py +26 -12
  23. smftools/plotting/autocorrelation_plotting.py +22 -4
  24. smftools/plotting/chimeric_plotting.py +1893 -0
  25. smftools/plotting/classifiers.py +28 -14
  26. smftools/plotting/general_plotting.py +58 -3362
  27. smftools/plotting/hmm_plotting.py +1586 -2
  28. smftools/plotting/latent_plotting.py +804 -0
  29. smftools/plotting/plotting_utils.py +243 -0
  30. smftools/plotting/position_stats.py +16 -8
  31. smftools/plotting/preprocess_plotting.py +281 -0
  32. smftools/plotting/qc_plotting.py +8 -3
  33. smftools/plotting/spatial_plotting.py +1134 -0
  34. smftools/plotting/variant_plotting.py +1231 -0
  35. smftools/preprocessing/__init__.py +3 -0
  36. smftools/preprocessing/append_base_context.py +1 -1
  37. smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
  38. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  39. smftools/preprocessing/append_variant_call_layer.py +480 -0
  40. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  41. smftools/preprocessing/invert_adata.py +1 -0
  42. smftools/readwrite.py +109 -85
  43. smftools/tools/__init__.py +6 -0
  44. smftools/tools/calculate_knn.py +121 -0
  45. smftools/tools/calculate_nmf.py +18 -7
  46. smftools/tools/calculate_pca.py +180 -0
  47. smftools/tools/calculate_umap.py +70 -154
  48. smftools/tools/position_stats.py +4 -4
  49. smftools/tools/rolling_nn_distance.py +640 -3
  50. smftools/tools/sequence_alignment.py +140 -0
  51. smftools/tools/tensor_factorization.py +52 -4
  52. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
  53. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
  54. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  55. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  56. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import math
4
- from typing import Optional, Sequence, Tuple, Union
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
5
7
 
6
8
  import numpy as np
7
9
  import pandas as pd
10
+ import scipy.cluster.hierarchy as sch
8
11
 
12
+ from smftools.logging_utils import get_logger
9
13
  from smftools.optional_imports import require
14
+ from smftools.plotting.plotting_utils import (
15
+ _layer_to_numpy,
16
+ _methylation_fraction_for_layer,
17
+ _select_labels,
18
+ clean_barplot,
19
+ normalized_mean,
20
+ )
21
+
22
+ gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
23
+ patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
24
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
10
25
 
11
26
  plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
12
27
  mpl_colors = require("matplotlib.colors", extra="plotting", purpose="HMM plots")
@@ -17,6 +32,153 @@ pdf_backend = require(
17
32
  )
18
33
  PdfPages = pdf_backend.PdfPages
19
34
 
35
+ logger = get_logger(__name__)
36
+
37
+
38
+ def _local_sites_to_global_indices(
39
+ adata,
40
+ subset,
41
+ local_sites: np.ndarray,
42
+ ) -> np.ndarray:
43
+ """Translate subset-local column indices into global ``adata.var`` indices."""
44
+ local_sites = np.asarray(local_sites, dtype=int)
45
+ if local_sites.size == 0:
46
+ return local_sites
47
+ subset_to_global = adata.var_names.get_indexer(subset.var_names)
48
+ global_sites = subset_to_global[local_sites]
49
+ if np.any(global_sites < 0):
50
+ missing = int(np.sum(global_sites < 0))
51
+ logger.warning(
52
+ "Could not map %d plotted positions back to full var index; skipping those points.",
53
+ missing,
54
+ )
55
+ global_sites = global_sites[global_sites >= 0]
56
+ return global_sites
57
+
58
+
59
+ def _overlay_variant_calls_on_panels(
60
+ adata,
61
+ reference: str,
62
+ ordered_obs_names: list,
63
+ panels_with_indices: list,
64
+ seq1_color: str = "white",
65
+ seq2_color: str = "black",
66
+ marker_size: float = 4.0,
67
+ ) -> bool:
68
+ """
69
+ Overlay variant call circles on heatmap panels using nearest-neighbor mapping.
70
+
71
+ This function maps variant call column indices to the nearest displayed column
72
+ in each panel, using var index space for mapping. This handles both regular
73
+ var_names and reindexed label coordinates.
74
+
75
+ Parameters
76
+ ----------
77
+ adata : AnnData
78
+ The AnnData containing variant call layers (should be full adata, not subset).
79
+ reference : str
80
+ Reference name used to auto-detect the variant call layer.
81
+ ordered_obs_names : list
82
+ Obs names in display order (rows of the heatmap).
83
+ panels_with_indices : list of (ax, site_indices)
84
+ Each entry is a matplotlib axes and the var indices for that panel's columns.
85
+ seq1_color, seq2_color : str
86
+ Colors for seq1 (value 1) and seq2 (value 2) variant calls.
87
+ marker_size : float
88
+ Size of the circle markers.
89
+
90
+ Returns
91
+ -------
92
+ bool
93
+ True if overlay was applied to at least one panel.
94
+ """
95
+ # Auto-detect variant call layer - find any layer ending with _variant_call
96
+ vc_layer_key = None
97
+ for key in adata.layers:
98
+ if key.endswith("_variant_call"):
99
+ vc_layer_key = key
100
+ break
101
+
102
+ if vc_layer_key is None:
103
+ return False
104
+
105
+ # Build row index mapping
106
+ obs_name_to_idx = {str(name): i for i, name in enumerate(adata.obs_names)}
107
+ common_obs = [str(name) for name in ordered_obs_names if str(name) in obs_name_to_idx]
108
+ if not common_obs:
109
+ return False
110
+
111
+ obs_idx = [obs_name_to_idx[name] for name in common_obs]
112
+ row_index_map = {str(name): i for i, name in enumerate(ordered_obs_names)}
113
+ heatmap_row_indices = np.array([row_index_map[name] for name in common_obs])
114
+
115
+ # Get variant call matrix for the ordered obs
116
+ vc_data = adata.layers[vc_layer_key]
117
+ if hasattr(vc_data, "toarray"):
118
+ vc_data = vc_data.toarray()
119
+ vc_matrix = np.asarray(vc_data)[obs_idx, :]
120
+
121
+ # Find columns with actual variant calls (value 1 or 2)
122
+ has_calls = np.isin(vc_matrix, [1, 2]).any(axis=0)
123
+ call_col_indices = np.where(has_calls)[0]
124
+
125
+ if len(call_col_indices) == 0:
126
+ return False
127
+
128
+ call_sub = vc_matrix[:, call_col_indices]
129
+
130
+ applied = False
131
+ for ax, site_indices in panels_with_indices:
132
+ site_indices = np.asarray(site_indices)
133
+ if site_indices.size == 0:
134
+ continue
135
+
136
+ # Use nearest-neighbor mapping in var index space
137
+ # site_indices are the var indices displayed in this panel (sorted by position)
138
+ # call_col_indices are var indices where calls exist
139
+ # Map each call to the nearest displayed site index
140
+
141
+ # Sort site indices for searchsorted
142
+ sorted_order = np.argsort(site_indices)
143
+ sorted_sites = site_indices[sorted_order]
144
+
145
+ # Find nearest site for each call
146
+ insert_idx = np.searchsorted(sorted_sites, call_col_indices)
147
+ insert_idx = np.clip(insert_idx, 0, len(sorted_sites) - 1)
148
+ left_idx = np.clip(insert_idx - 1, 0, len(sorted_sites) - 1)
149
+
150
+ dist_right = np.abs(sorted_sites[insert_idx].astype(float) - call_col_indices.astype(float))
151
+ dist_left = np.abs(sorted_sites[left_idx].astype(float) - call_col_indices.astype(float))
152
+ nearest_sorted = np.where(dist_left < dist_right, left_idx, insert_idx)
153
+
154
+ # Map back to original (unsorted) heatmap column positions
155
+ nearest_heatmap_col = sorted_order[nearest_sorted]
156
+
157
+ # Plot circles for each variant value
158
+ for call_val, color in [(1, seq1_color), (2, seq2_color)]:
159
+ local_rows, local_cols = np.where(call_sub == call_val)
160
+ if len(local_rows) == 0:
161
+ continue
162
+
163
+ plot_y = heatmap_row_indices[local_rows]
164
+ plot_x = nearest_heatmap_col[local_cols]
165
+
166
+ ax.scatter(
167
+ plot_x + 0.5,
168
+ plot_y + 0.5,
169
+ c=color,
170
+ s=marker_size,
171
+ marker="o",
172
+ edgecolors="gray",
173
+ linewidths=0.3,
174
+ zorder=3,
175
+ )
176
+ applied = True
177
+
178
+ if applied:
179
+ logger.info("Overlaid variant calls from layer '%s'.", vc_layer_key)
180
+ return applied
181
+
20
182
 
21
183
  def plot_hmm_size_contours(
22
184
  adata,
@@ -57,6 +219,7 @@ def plot_hmm_size_contours(
57
219
  Other args are the same as prior function.
58
220
  """
59
221
  feature_ranges = tuple(feature_ranges or ())
222
+ logger.info("Plotting HMM size contours%s.", f" -> {save_path}" if save_path else "")
60
223
 
61
224
  def _resolve_length_color(length: int, fallback: str) -> Tuple[float, float, float, float]:
62
225
  for min_len, max_len, color in feature_ranges:
@@ -365,6 +528,7 @@ def plot_hmm_size_contours(
365
528
  fname = f"hmm_size_page_{p + 1:03d}.png"
366
529
  out = os.path.join(save_path, fname)
367
530
  fig.savefig(out, dpi=dpi, bbox_inches="tight")
531
+ logger.info("Saved HMM size contour page to %s.", out)
368
532
 
369
533
  # multipage PDF if requested
370
534
  if save_path is not None and save_pdf:
@@ -372,6 +536,1426 @@ def plot_hmm_size_contours(
372
536
  with PdfPages(pdf_file) as pp:
373
537
  for fig in figs:
374
538
  pp.savefig(fig, bbox_inches="tight")
375
- print(f"Saved multipage PDF: {pdf_file}")
539
+ logger.info("Saved HMM size contour PDF to %s.", pdf_file)
376
540
 
377
541
  return figs
542
+
543
+
544
def _resolve_feature_color(cmap: Any) -> Tuple[float, float, float, float]:
    """Resolve a representative feature color from a colormap or color spec.

    Strings are first tried as colormap names (sampling the top of the ramp);
    otherwise they are treated as plain color specs. Colormap objects yield
    their last listed color when available, else the value at 1.0. Anything
    else falls back to opaque black.
    """
    if isinstance(cmap, str):
        try:
            return mpl_colors.to_rgba(plt.get_cmap(cmap)(1.0))
        except Exception:
            # Not a registered colormap name — treat as a color spec.
            return mpl_colors.to_rgba(cmap)

    if isinstance(cmap, mpl_colors.Colormap):
        listed = getattr(cmap, "colors", None)
        if listed:
            return mpl_colors.to_rgba(listed[-1])
        return mpl_colors.to_rgba(cmap(1.0))

    return mpl_colors.to_rgba("black")
559
+
560
+
561
def _build_hmm_feature_cmap(
    cmap: Any,
    *,
    zero_color: str = "#f5f1e8",
    nan_color: str = "#E6E6E6",
) -> mpl_colors.Colormap:
    """Build a two-color HMM colormap with explicit NaN/under handling.

    The ramp runs from ``zero_color`` up to a representative color resolved
    from ``cmap``; NaN and below-range values both render as ``nan_color``.
    """
    top_color = _resolve_feature_color(cmap)
    ramp = mpl_colors.LinearSegmentedColormap.from_list(
        "hmm_feature_cmap",
        [zero_color, top_color],
    )
    ramp.set_bad(nan_color)
    ramp.set_under(nan_color)
    return ramp
576
+
577
+
578
+ def _map_length_matrix_to_subclasses(
579
+ length_matrix: np.ndarray,
580
+ feature_ranges: Sequence[Tuple[int, int, Any]],
581
+ ) -> np.ndarray:
582
+ """Map length values into subclass integer codes based on feature ranges."""
583
+ mapped = np.zeros_like(length_matrix, dtype=float)
584
+ finite_mask = np.isfinite(length_matrix)
585
+ for idx, (min_len, max_len, _color) in enumerate(feature_ranges, start=1):
586
+ mask = finite_mask & (length_matrix >= min_len) & (length_matrix <= max_len)
587
+ mapped[mask] = float(idx)
588
+ mapped[~finite_mask] = np.nan
589
+ return mapped
590
+
591
+
592
def _build_length_feature_cmap(
    feature_ranges: Sequence[Tuple[int, int, Any]],
    *,
    zero_color: str = "#f5f1e8",
    nan_color: str = "#E6E6E6",
) -> Tuple[mpl_colors.Colormap, mpl_colors.BoundaryNorm]:
    """Build a discrete colormap and norm for length-based subclasses.

    Index 0 (no subclass) maps to ``zero_color``; each feature range
    contributes its color in order. Bin edges sit at half-integers so that
    integer subclass codes land in the middle of their color bin.
    """
    palette = [zero_color, *(spec[2] for spec in feature_ranges)]
    listed = mpl_colors.ListedColormap(palette, name="hmm_length_feature_cmap")
    listed.set_bad(nan_color)
    edges = np.arange(-0.5, len(palette) + 0.5, 1)
    return listed, mpl_colors.BoundaryNorm(edges, listed.N)
605
+
606
+
607
def combined_hmm_raw_clustermap(
    adata,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",
    hmm_feature_layer: str = "hmm_combined",
    layer_gpc: str = "nan0_0minus1",
    layer_cpg: str = "nan0_0minus1",
    layer_c: str = "nan0_0minus1",
    layer_a: str = "nan0_0minus1",
    cmap_hmm: str = "tab10",
    cmap_gpc: str = "coolwarm",
    cmap_cpg: str = "viridis",
    cmap_c: str = "coolwarm",
    cmap_a: str = "coolwarm",
    min_quality: int = 20,
    min_length: int = 200,
    min_mapped_length_to_reference_length_ratio: float = 0.8,
    min_position_valid_fraction: float = 0.5,
    demux_types: Sequence[str] = ("single", "double", "already"),
    sample_mapping: Optional[Mapping[str, str]] = None,
    save_path: str | Path | None = None,
    normalize_hmm: bool = False,
    sort_by: str = "gpc",
    bins: Optional[Dict[str, Any]] = None,
    deaminase: bool = False,
    min_signal: float = 0.0,
    # ---- fixed tick label controls (counts, not spacing)
    n_xticks_hmm: int = 10,
    n_xticks_any_c: int = 8,
    n_xticks_gpc: int = 8,
    n_xticks_cpg: int = 8,
    n_xticks_a: int = 8,
    index_col_suffix: str | None = None,
    fill_nan_strategy: str = "value",
    fill_nan_value: float = -1,
    overlay_variant_calls: bool = False,
    variant_overlay_seq1_color: str = "white",
    variant_overlay_seq2_color: str = "black",
    variant_overlay_marker_size: float = 4.0,
):
    """
    Makes a multi-panel clustermap per (sample, reference):
    HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.

    Panels are added only if the corresponding site mask exists AND has >0 sites.

    sort_by options:
        'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'

    NaN fill strategy is applied in-memory for clustering/plotting only.

    Parameters of note
    ------------------
    sample_mapping : Mapping[str, str], optional
        Remaps raw sample names to display names. The remapped name is used in
        log messages and the figure title; saved filenames keep the raw name.
    min_signal : float
        Currently unused; kept for interface compatibility.

    Notes
    -----
    Failures inside a (sample, reference) iteration are logged and skipped so
    one bad combination does not abort the whole sweep.
    """
    if fill_nan_strategy not in {"none", "value", "col_mean"}:
        raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")

    def pick_xticks(labels: np.ndarray, n_ticks: int):
        """Pick tick indices/labels from an array."""
        if labels.size == 0:
            return [], []
        idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
        idx = np.unique(idx)
        return idx.tolist(), labels[idx].tolist()

    # Helper: build a True mask if filter is inactive or column missing
    def _mask_or_true(series_name: str, predicate):
        """Return a mask from predicate or an all-True mask."""
        if series_name not in adata.obs:
            return pd.Series(True, index=adata.obs.index)
        s = adata.obs[series_name]
        try:
            return predicate(s)
        except Exception:
            # Fallback: all True if bad dtype / predicate failure
            return pd.Series(True, index=adata.obs.index)

    signal_type = "deamination" if deaminase else "methylation"

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            # Optionally remap sample label for display
            display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
            # Row-level masks (obs)
            qmask = _mask_or_true(
                "read_quality",
                (lambda s: s >= float(min_quality))
                if (min_quality is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            lm_mask = _mask_or_true(
                "mapped_length",
                (lambda s: s >= float(min_length))
                if (min_length is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )
            lrr_mask = _mask_or_true(
                "mapped_length_to_reference_length_ratio",
                (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
                if (min_mapped_length_to_reference_length_ratio is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )

            demux_mask = _mask_or_true(
                "demux_type",
                (lambda s: s.astype("string").isin(list(demux_types)))
                if (demux_types is not None)
                else (lambda s: pd.Series(True, index=s.index)),
            )

            ref_mask = adata.obs[reference_col] == ref
            sample_mask = adata.obs[sample_col] == sample

            row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask

            if not bool(row_mask.any()):
                logger.info(
                    "No reads for %s - %s after read quality and length filtering",
                    display_sample,
                    ref,
                )
                continue

            try:
                # ---- subset reads ----
                subset = adata[row_mask, :].copy()

                # Column-level mask (var)
                if min_position_valid_fraction is not None:
                    valid_key = f"{ref}_valid_fraction"
                    if valid_key in subset.var:
                        v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
                        col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
                        if col_mask.any():
                            subset = subset[:, col_mask].copy()
                        else:
                            logger.info(
                                "No positions left after valid_fraction filter for %s - %s",
                                display_sample,
                                ref,
                            )
                            continue

                if subset.shape[0] == 0:
                    logger.info(
                        "No reads left after filtering for %s - %s", display_sample, ref
                    )
                    continue

                # ---- bins ----
                if bins is None:
                    bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
                else:
                    bins_temp = bins

                # ---- site masks (robust) ----
                def _sites(*keys):
                    """Return indices for the first matching site key."""
                    for k in keys:
                        if k in subset.var:
                            return np.where(subset.var[k].values)[0]
                    return np.array([], dtype=int)

                gpc_sites = _sites(f"{ref}_GpC_site")
                cpg_sites = _sites(f"{ref}_CpG_site")
                any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
                any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")

                # ---- labels via _select_labels ----
                # HMM uses *all* columns
                hmm_sites = np.arange(subset.n_vars, dtype=int)
                hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
                gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
                cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
                any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
                any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)

                # storage: filled matrices feed clustering/heatmaps, raw
                # (unfilled) matrices feed mean/fraction summary bars.
                stacked_hmm = []
                stacked_hmm_raw = []
                stacked_any_c = []
                stacked_any_c_raw = []
                stacked_gpc = []
                stacked_gpc_raw = []
                stacked_cpg = []
                stacked_cpg_raw = []
                stacked_any_a = []
                stacked_any_a_raw = []
                ordered_obs_names = []

                row_labels, bin_labels, bin_boundaries = [], [], []
                total_reads = subset.n_obs
                percentages = {}
                last_idx = 0

                # ---------------- process bins ----------------
                for bin_label, bin_filter in bins_temp.items():
                    sb = subset[bin_filter].copy()
                    n = sb.n_obs
                    if n == 0:
                        continue

                    pct = (n / total_reads) * 100 if total_reads else 0
                    percentages[bin_label] = pct

                    # ---- sorting ----
                    if sort_by.startswith("obs:"):
                        colname = sort_by.split("obs:")[1]
                        order = np.argsort(sb.obs[colname].values)

                    elif sort_by == "gpc" and gpc_sites.size:
                        gpc_matrix = _layer_to_numpy(
                            sb,
                            layer_gpc,
                            gpc_sites,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                        linkage = sch.linkage(gpc_matrix, method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "cpg" and cpg_sites.size:
                        cpg_matrix = _layer_to_numpy(
                            sb,
                            layer_cpg,
                            cpg_sites,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                        linkage = sch.linkage(cpg_matrix, method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "c" and any_c_sites.size:
                        any_c_matrix = _layer_to_numpy(
                            sb,
                            layer_c,
                            any_c_sites,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                        linkage = sch.linkage(any_c_matrix, method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "a" and any_a_sites.size:
                        any_a_matrix = _layer_to_numpy(
                            sb,
                            layer_a,
                            any_a_sites,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                        linkage = sch.linkage(any_a_matrix, method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
                        # NOTE(review): this branch clusters on *all* columns of
                        # layer_gpc (sites=None) rather than the union of the
                        # GpC+CpG site sets — confirm this is intended.
                        gpc_matrix = _layer_to_numpy(
                            sb,
                            layer_gpc,
                            None,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                        linkage = sch.linkage(gpc_matrix, method="ward")
                        order = sch.leaves_list(linkage)

                    elif sort_by == "hmm" and hmm_sites.size:
                        hmm_matrix = _layer_to_numpy(
                            sb,
                            hmm_feature_layer,
                            hmm_sites,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                        linkage = sch.linkage(hmm_matrix, method="ward")
                        order = sch.leaves_list(linkage)

                    else:
                        order = np.arange(n)

                    sb = sb[order]
                    ordered_obs_names.extend(sb.obs_names.tolist())

                    # ---- collect matrices ----
                    stacked_hmm.append(
                        _layer_to_numpy(
                            sb,
                            hmm_feature_layer,
                            None,
                            fill_nan_strategy=fill_nan_strategy,
                            fill_nan_value=fill_nan_value,
                        )
                    )
                    stacked_hmm_raw.append(
                        _layer_to_numpy(
                            sb,
                            hmm_feature_layer,
                            None,
                            fill_nan_strategy="none",
                            fill_nan_value=fill_nan_value,
                        )
                    )
                    if any_c_sites.size:
                        stacked_any_c.append(
                            _layer_to_numpy(
                                sb,
                                layer_c,
                                any_c_sites,
                                fill_nan_strategy=fill_nan_strategy,
                                fill_nan_value=fill_nan_value,
                            )
                        )
                        stacked_any_c_raw.append(
                            _layer_to_numpy(
                                sb,
                                layer_c,
                                any_c_sites,
                                fill_nan_strategy="none",
                                fill_nan_value=fill_nan_value,
                            )
                        )
                    if gpc_sites.size:
                        stacked_gpc.append(
                            _layer_to_numpy(
                                sb,
                                layer_gpc,
                                gpc_sites,
                                fill_nan_strategy=fill_nan_strategy,
                                fill_nan_value=fill_nan_value,
                            )
                        )
                        stacked_gpc_raw.append(
                            _layer_to_numpy(
                                sb,
                                layer_gpc,
                                gpc_sites,
                                fill_nan_strategy="none",
                                fill_nan_value=fill_nan_value,
                            )
                        )
                    if cpg_sites.size:
                        stacked_cpg.append(
                            _layer_to_numpy(
                                sb,
                                layer_cpg,
                                cpg_sites,
                                fill_nan_strategy=fill_nan_strategy,
                                fill_nan_value=fill_nan_value,
                            )
                        )
                        stacked_cpg_raw.append(
                            _layer_to_numpy(
                                sb,
                                layer_cpg,
                                cpg_sites,
                                fill_nan_strategy="none",
                                fill_nan_value=fill_nan_value,
                            )
                        )
                    if any_a_sites.size:
                        stacked_any_a.append(
                            _layer_to_numpy(
                                sb,
                                layer_a,
                                any_a_sites,
                                fill_nan_strategy=fill_nan_strategy,
                                fill_nan_value=fill_nan_value,
                            )
                        )
                        stacked_any_a_raw.append(
                            _layer_to_numpy(
                                sb,
                                layer_a,
                                any_a_sites,
                                fill_nan_strategy="none",
                                fill_nan_value=fill_nan_value,
                            )
                        )

                    row_labels.extend([bin_label] * n)
                    bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
                    last_idx += n
                    bin_boundaries.append(last_idx)

                # ---------------- stack ----------------
                hmm_matrix = np.vstack(stacked_hmm)
                hmm_matrix_raw = np.vstack(stacked_hmm_raw)
                mean_hmm = (
                    normalized_mean(hmm_matrix_raw)
                    if normalize_hmm
                    else np.nanmean(hmm_matrix_raw, axis=0)
                )
                hmm_plot_matrix = hmm_matrix_raw
                hmm_plot_cmap = _build_hmm_feature_cmap(cmap_hmm)

                panels = [
                    (
                        f"HMM - {hmm_feature_layer}",
                        hmm_plot_matrix,
                        hmm_labels,
                        hmm_plot_cmap,
                        mean_hmm,
                        n_xticks_hmm,
                    ),
                ]

                if stacked_any_c:
                    m = np.vstack(stacked_any_c)
                    m_raw = np.vstack(stacked_any_c_raw)
                    panels.append(
                        (
                            "C",
                            m,
                            any_c_labels,
                            cmap_c,
                            _methylation_fraction_for_layer(m_raw, layer_c),
                            n_xticks_any_c,
                        )
                    )

                if stacked_gpc:
                    m = np.vstack(stacked_gpc)
                    m_raw = np.vstack(stacked_gpc_raw)
                    panels.append(
                        (
                            "GpC",
                            m,
                            gpc_labels,
                            cmap_gpc,
                            _methylation_fraction_for_layer(m_raw, layer_gpc),
                            n_xticks_gpc,
                        )
                    )

                if stacked_cpg:
                    m = np.vstack(stacked_cpg)
                    m_raw = np.vstack(stacked_cpg_raw)
                    panels.append(
                        (
                            "CpG",
                            m,
                            cpg_labels,
                            cmap_cpg,
                            _methylation_fraction_for_layer(m_raw, layer_cpg),
                            n_xticks_cpg,
                        )
                    )

                if stacked_any_a:
                    m = np.vstack(stacked_any_a)
                    m_raw = np.vstack(stacked_any_a_raw)
                    panels.append(
                        (
                            "A",
                            m,
                            any_a_labels,
                            cmap_a,
                            _methylation_fraction_for_layer(m_raw, layer_a),
                            n_xticks_a,
                        )
                    )

                # ---------------- plotting ----------------
                n_panels = len(panels)
                fig = plt.figure(figsize=(4.5 * n_panels, 10))
                gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
                # Use the remapped display name; previously the raw `sample`
                # was shown here even when sample_mapping was supplied.
                fig.suptitle(
                    f"{display_sample} — {ref} — {total_reads} reads ({signal_type})",
                    fontsize=14,
                    y=0.98,
                )

                axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
                axes_bar = [
                    fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)
                ]

                for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
                    # ---- summary barplot above each heatmap ----
                    clean_barplot(axes_bar[i], mean_vec, name)

                    # ---- heatmap ----
                    heatmap_kwargs = dict(
                        cmap=cmap,
                        ax=axes_heat[i],
                        yticklabels=False,
                        cbar=False,
                    )
                    if name.startswith("HMM -"):
                        # HMM posteriors live in [0, 1]; pin the color scale.
                        heatmap_kwargs.update(vmin=0.0, vmax=1.0)
                    sns.heatmap(matrix, **heatmap_kwargs)

                    # ---- xticks ----
                    xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
                    axes_heat[i].set_xticks(xtick_pos)
                    axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)

                    # Horizontal separators between bins (skip the final edge).
                    for boundary in bin_boundaries[:-1]:
                        axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)

                if overlay_variant_calls and ordered_obs_names:
                    try:
                        # Map panel sites from subset-local coordinates to full adata indices
                        hmm_sites_global = _local_sites_to_global_indices(adata, subset, hmm_sites)
                        any_c_sites_global = _local_sites_to_global_indices(
                            adata, subset, any_c_sites
                        )
                        gpc_sites_global = _local_sites_to_global_indices(adata, subset, gpc_sites)
                        cpg_sites_global = _local_sites_to_global_indices(adata, subset, cpg_sites)
                        any_a_sites_global = _local_sites_to_global_indices(
                            adata, subset, any_a_sites
                        )

                        # Build panels_with_indices using site indices for each panel
                        name_to_sites = {
                            f"HMM - {hmm_feature_layer}": hmm_sites_global,
                            "C": any_c_sites_global,
                            "GpC": gpc_sites_global,
                            "CpG": cpg_sites_global,
                            "A": any_a_sites_global,
                        }
                        panels_with_indices = []
                        for idx, (name, *_rest) in enumerate(panels):
                            sites = name_to_sites.get(name)
                            if sites is not None and len(sites) > 0:
                                panels_with_indices.append((axes_heat[idx], sites))
                        if panels_with_indices:
                            _overlay_variant_calls_on_panels(
                                adata,
                                ref,
                                ordered_obs_names,
                                panels_with_indices,
                                seq1_color=variant_overlay_seq1_color,
                                seq2_color=variant_overlay_seq2_color,
                                marker_size=variant_overlay_marker_size,
                            )
                    except Exception as overlay_err:
                        logger.warning(
                            "Variant overlay failed for %s - %s: %s",
                            display_sample,
                            ref,
                            overlay_err,
                        )

                plt.tight_layout()

                if save_path:
                    save_path = Path(save_path)
                    save_path.mkdir(parents=True, exist_ok=True)
                    safe_name = f"{ref}__{sample}".replace("/", "_")
                    out_file = save_path / f"{safe_name}.png"
                    plt.savefig(out_file, dpi=300)
                    plt.close(fig)
                else:
                    plt.show()

            except Exception:
                # Log and move on so one bad (sample, reference) pair does not
                # abort the remaining combinations.
                logger.exception(
                    "combined_hmm_raw_clustermap failed for %s - %s", display_sample, ref
                )
                continue
1153
+
1154
+
1155
+ def combined_hmm_length_clustermap(
1156
+ adata,
1157
+ sample_col: str = "Sample_Names",
1158
+ reference_col: str = "Reference_strand",
1159
+ length_layer: str = "hmm_combined_lengths",
1160
+ layer_gpc: str = "nan0_0minus1",
1161
+ layer_cpg: str = "nan0_0minus1",
1162
+ layer_c: str = "nan0_0minus1",
1163
+ layer_a: str = "nan0_0minus1",
1164
+ cmap_lengths: Any = "Greens",
1165
+ cmap_gpc: str = "coolwarm",
1166
+ cmap_cpg: str = "viridis",
1167
+ cmap_c: str = "coolwarm",
1168
+ cmap_a: str = "coolwarm",
1169
+ min_quality: int = 20,
1170
+ min_length: int = 200,
1171
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
1172
+ min_position_valid_fraction: float = 0.5,
1173
+ demux_types: Sequence[str] = ("single", "double", "already"),
1174
+ sample_mapping: Optional[Mapping[str, str]] = None,
1175
+ save_path: str | Path | None = None,
1176
+ sort_by: str = "gpc",
1177
+ bins: Optional[Dict[str, Any]] = None,
1178
+ deaminase: bool = False,
1179
+ min_signal: float = 0.0,
1180
+ n_xticks_lengths: int = 10,
1181
+ n_xticks_any_c: int = 8,
1182
+ n_xticks_gpc: int = 8,
1183
+ n_xticks_cpg: int = 8,
1184
+ n_xticks_a: int = 8,
1185
+ index_col_suffix: str | None = None,
1186
+ fill_nan_strategy: str = "value",
1187
+ fill_nan_value: float = -1,
1188
+ length_feature_ranges: Optional[Sequence[Tuple[int, int, Any]]] = None,
1189
+ overlay_variant_calls: bool = False,
1190
+ variant_overlay_seq1_color: str = "white",
1191
+ variant_overlay_seq2_color: str = "black",
1192
+ variant_overlay_marker_size: float = 4.0,
1193
+ ):
1194
+ """
1195
+ Plot clustermaps for length-encoded HMM feature layers with optional subclass colors.
1196
+
1197
+ Length-based feature ranges map integer lengths into subclass colors for accessible
1198
+ and footprint layers. Raw methylation panels are included when available.
1199
+ """
1200
+ if fill_nan_strategy not in {"none", "value", "col_mean"}:
1201
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
1202
+
1203
+ def pick_xticks(labels: np.ndarray, n_ticks: int):
1204
+ """Pick tick indices/labels from an array."""
1205
+ if labels.size == 0:
1206
+ return [], []
1207
+ idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
1208
+ idx = np.unique(idx)
1209
+ return idx.tolist(), labels[idx].tolist()
1210
+
1211
+ def _mask_or_true(series_name: str, predicate):
1212
+ """Return a mask from predicate or an all-True mask."""
1213
+ if series_name not in adata.obs:
1214
+ return pd.Series(True, index=adata.obs.index)
1215
+ s = adata.obs[series_name]
1216
+ try:
1217
+ return predicate(s)
1218
+ except Exception:
1219
+ return pd.Series(True, index=adata.obs.index)
1220
+
1221
+ results = []
1222
+ signal_type = "deamination" if deaminase else "methylation"
1223
+ feature_ranges = tuple(length_feature_ranges or ())
1224
+
1225
+ for ref in adata.obs[reference_col].cat.categories:
1226
+ for sample in adata.obs[sample_col].cat.categories:
1227
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
1228
+ qmask = _mask_or_true(
1229
+ "read_quality",
1230
+ (lambda s: s >= float(min_quality))
1231
+ if (min_quality is not None)
1232
+ else (lambda s: pd.Series(True, index=s.index)),
1233
+ )
1234
+ lm_mask = _mask_or_true(
1235
+ "mapped_length",
1236
+ (lambda s: s >= float(min_length))
1237
+ if (min_length is not None)
1238
+ else (lambda s: pd.Series(True, index=s.index)),
1239
+ )
1240
+ lrr_mask = _mask_or_true(
1241
+ "mapped_length_to_reference_length_ratio",
1242
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
1243
+ if (min_mapped_length_to_reference_length_ratio is not None)
1244
+ else (lambda s: pd.Series(True, index=s.index)),
1245
+ )
1246
+
1247
+ demux_mask = _mask_or_true(
1248
+ "demux_type",
1249
+ (lambda s: s.astype("string").isin(list(demux_types)))
1250
+ if (demux_types is not None)
1251
+ else (lambda s: pd.Series(True, index=s.index)),
1252
+ )
1253
+
1254
+ ref_mask = adata.obs[reference_col] == ref
1255
+ sample_mask = adata.obs[sample_col] == sample
1256
+
1257
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
1258
+
1259
+ if not bool(row_mask.any()):
1260
+ print(
1261
+ f"No reads for {display_sample} - {ref} after read quality and length filtering"
1262
+ )
1263
+ continue
1264
+
1265
+ try:
1266
+ subset = adata[row_mask, :].copy()
1267
+
1268
+ if min_position_valid_fraction is not None:
1269
+ valid_key = f"{ref}_valid_fraction"
1270
+ if valid_key in subset.var:
1271
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
1272
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
1273
+ if col_mask.any():
1274
+ subset = subset[:, col_mask].copy()
1275
+ else:
1276
+ print(
1277
+ f"No positions left after valid_fraction filter for {display_sample} - {ref}"
1278
+ )
1279
+ continue
1280
+
1281
+ if subset.shape[0] == 0:
1282
+ print(f"No reads left after filtering for {display_sample} - {ref}")
1283
+ continue
1284
+
1285
+ if bins is None:
1286
+ bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
1287
+ else:
1288
+ bins_temp = bins
1289
+
1290
+ def _sites(*keys):
1291
+ """Return indices for the first matching site key."""
1292
+ for k in keys:
1293
+ if k in subset.var:
1294
+ return np.where(subset.var[k].values)[0]
1295
+ return np.array([], dtype=int)
1296
+
1297
+ gpc_sites = _sites(f"{ref}_GpC_site")
1298
+ cpg_sites = _sites(f"{ref}_CpG_site")
1299
+ any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
1300
+ any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
1301
+
1302
+ length_sites = np.arange(subset.n_vars, dtype=int)
1303
+ length_labels = _select_labels(subset, length_sites, ref, index_col_suffix)
1304
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
1305
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
1306
+ any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
1307
+ any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
1308
+
1309
+ stacked_lengths = []
1310
+ stacked_lengths_raw = []
1311
+ stacked_any_c = []
1312
+ stacked_any_c_raw = []
1313
+ stacked_gpc = []
1314
+ stacked_gpc_raw = []
1315
+ stacked_cpg = []
1316
+ stacked_cpg_raw = []
1317
+ stacked_any_a = []
1318
+ stacked_any_a_raw = []
1319
+ ordered_obs_names = []
1320
+
1321
+ row_labels, bin_labels, bin_boundaries = [], [], []
1322
+ total_reads = subset.n_obs
1323
+ percentages = {}
1324
+ last_idx = 0
1325
+
1326
+ for bin_label, bin_filter in bins_temp.items():
1327
+ sb = subset[bin_filter].copy()
1328
+ n = sb.n_obs
1329
+ if n == 0:
1330
+ continue
1331
+
1332
+ pct = (n / total_reads) * 100 if total_reads else 0
1333
+ percentages[bin_label] = pct
1334
+
1335
+ if sort_by.startswith("obs:"):
1336
+ colname = sort_by.split("obs:")[1]
1337
+ order = np.argsort(sb.obs[colname].values)
1338
+ elif sort_by == "gpc" and gpc_sites.size:
1339
+ gpc_matrix = _layer_to_numpy(
1340
+ sb,
1341
+ layer_gpc,
1342
+ gpc_sites,
1343
+ fill_nan_strategy=fill_nan_strategy,
1344
+ fill_nan_value=fill_nan_value,
1345
+ )
1346
+ linkage = sch.linkage(gpc_matrix, method="ward")
1347
+ order = sch.leaves_list(linkage)
1348
+ elif sort_by == "cpg" and cpg_sites.size:
1349
+ cpg_matrix = _layer_to_numpy(
1350
+ sb,
1351
+ layer_cpg,
1352
+ cpg_sites,
1353
+ fill_nan_strategy=fill_nan_strategy,
1354
+ fill_nan_value=fill_nan_value,
1355
+ )
1356
+ linkage = sch.linkage(cpg_matrix, method="ward")
1357
+ order = sch.leaves_list(linkage)
1358
+ elif sort_by == "c" and any_c_sites.size:
1359
+ any_c_matrix = _layer_to_numpy(
1360
+ sb,
1361
+ layer_c,
1362
+ any_c_sites,
1363
+ fill_nan_strategy=fill_nan_strategy,
1364
+ fill_nan_value=fill_nan_value,
1365
+ )
1366
+ linkage = sch.linkage(any_c_matrix, method="ward")
1367
+ order = sch.leaves_list(linkage)
1368
+ elif sort_by == "a" and any_a_sites.size:
1369
+ any_a_matrix = _layer_to_numpy(
1370
+ sb,
1371
+ layer_a,
1372
+ any_a_sites,
1373
+ fill_nan_strategy=fill_nan_strategy,
1374
+ fill_nan_value=fill_nan_value,
1375
+ )
1376
+ linkage = sch.linkage(any_a_matrix, method="ward")
1377
+ order = sch.leaves_list(linkage)
1378
+ elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
1379
+ gpc_matrix = _layer_to_numpy(
1380
+ sb,
1381
+ layer_gpc,
1382
+ None,
1383
+ fill_nan_strategy=fill_nan_strategy,
1384
+ fill_nan_value=fill_nan_value,
1385
+ )
1386
+ linkage = sch.linkage(gpc_matrix, method="ward")
1387
+ order = sch.leaves_list(linkage)
1388
+ elif sort_by == "hmm" and length_sites.size:
1389
+ length_matrix = _layer_to_numpy(
1390
+ sb,
1391
+ length_layer,
1392
+ length_sites,
1393
+ fill_nan_strategy=fill_nan_strategy,
1394
+ fill_nan_value=fill_nan_value,
1395
+ )
1396
+ linkage = sch.linkage(length_matrix, method="ward")
1397
+ order = sch.leaves_list(linkage)
1398
+ else:
1399
+ order = np.arange(n)
1400
+
1401
+ sb = sb[order]
1402
+ ordered_obs_names.extend(sb.obs_names.tolist())
1403
+
1404
+ stacked_lengths.append(
1405
+ _layer_to_numpy(
1406
+ sb,
1407
+ length_layer,
1408
+ None,
1409
+ fill_nan_strategy=fill_nan_strategy,
1410
+ fill_nan_value=fill_nan_value,
1411
+ )
1412
+ )
1413
+ stacked_lengths_raw.append(
1414
+ _layer_to_numpy(
1415
+ sb,
1416
+ length_layer,
1417
+ None,
1418
+ fill_nan_strategy="none",
1419
+ fill_nan_value=fill_nan_value,
1420
+ )
1421
+ )
1422
+ if any_c_sites.size:
1423
+ stacked_any_c.append(
1424
+ _layer_to_numpy(
1425
+ sb,
1426
+ layer_c,
1427
+ any_c_sites,
1428
+ fill_nan_strategy=fill_nan_strategy,
1429
+ fill_nan_value=fill_nan_value,
1430
+ )
1431
+ )
1432
+ stacked_any_c_raw.append(
1433
+ _layer_to_numpy(
1434
+ sb,
1435
+ layer_c,
1436
+ any_c_sites,
1437
+ fill_nan_strategy="none",
1438
+ fill_nan_value=fill_nan_value,
1439
+ )
1440
+ )
1441
+ if gpc_sites.size:
1442
+ stacked_gpc.append(
1443
+ _layer_to_numpy(
1444
+ sb,
1445
+ layer_gpc,
1446
+ gpc_sites,
1447
+ fill_nan_strategy=fill_nan_strategy,
1448
+ fill_nan_value=fill_nan_value,
1449
+ )
1450
+ )
1451
+ stacked_gpc_raw.append(
1452
+ _layer_to_numpy(
1453
+ sb,
1454
+ layer_gpc,
1455
+ gpc_sites,
1456
+ fill_nan_strategy="none",
1457
+ fill_nan_value=fill_nan_value,
1458
+ )
1459
+ )
1460
+ if cpg_sites.size:
1461
+ stacked_cpg.append(
1462
+ _layer_to_numpy(
1463
+ sb,
1464
+ layer_cpg,
1465
+ cpg_sites,
1466
+ fill_nan_strategy=fill_nan_strategy,
1467
+ fill_nan_value=fill_nan_value,
1468
+ )
1469
+ )
1470
+ stacked_cpg_raw.append(
1471
+ _layer_to_numpy(
1472
+ sb,
1473
+ layer_cpg,
1474
+ cpg_sites,
1475
+ fill_nan_strategy="none",
1476
+ fill_nan_value=fill_nan_value,
1477
+ )
1478
+ )
1479
+ if any_a_sites.size:
1480
+ stacked_any_a.append(
1481
+ _layer_to_numpy(
1482
+ sb,
1483
+ layer_a,
1484
+ any_a_sites,
1485
+ fill_nan_strategy=fill_nan_strategy,
1486
+ fill_nan_value=fill_nan_value,
1487
+ )
1488
+ )
1489
+ stacked_any_a_raw.append(
1490
+ _layer_to_numpy(
1491
+ sb,
1492
+ layer_a,
1493
+ any_a_sites,
1494
+ fill_nan_strategy="none",
1495
+ fill_nan_value=fill_nan_value,
1496
+ )
1497
+ )
1498
+
1499
+ row_labels.extend([bin_label] * n)
1500
+ bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
1501
+ last_idx += n
1502
+ bin_boundaries.append(last_idx)
1503
+
1504
+ length_matrix = np.vstack(stacked_lengths)
1505
+ length_matrix_raw = np.vstack(stacked_lengths_raw)
1506
+ capped_lengths = np.where(length_matrix_raw > 1, 1.0, length_matrix_raw)
1507
+ mean_lengths = np.nanmean(capped_lengths, axis=0)
1508
+ length_plot_matrix = length_matrix_raw
1509
+ length_plot_cmap = cmap_lengths
1510
+ length_plot_norm = None
1511
+
1512
+ if feature_ranges:
1513
+ length_plot_matrix = _map_length_matrix_to_subclasses(
1514
+ length_matrix_raw, feature_ranges
1515
+ )
1516
+ length_plot_cmap, length_plot_norm = _build_length_feature_cmap(feature_ranges)
1517
+
1518
+ panels = [
1519
+ (
1520
+ f"HMM lengths - {length_layer}",
1521
+ length_plot_matrix,
1522
+ length_labels,
1523
+ length_plot_cmap,
1524
+ mean_lengths,
1525
+ n_xticks_lengths,
1526
+ length_plot_norm,
1527
+ ),
1528
+ ]
1529
+
1530
+ if stacked_any_c:
1531
+ m = np.vstack(stacked_any_c)
1532
+ m_raw = np.vstack(stacked_any_c_raw)
1533
+ panels.append(
1534
+ (
1535
+ "C",
1536
+ m,
1537
+ any_c_labels,
1538
+ cmap_c,
1539
+ _methylation_fraction_for_layer(m_raw, layer_c),
1540
+ n_xticks_any_c,
1541
+ None,
1542
+ )
1543
+ )
1544
+
1545
+ if stacked_gpc:
1546
+ m = np.vstack(stacked_gpc)
1547
+ m_raw = np.vstack(stacked_gpc_raw)
1548
+ panels.append(
1549
+ (
1550
+ "GpC",
1551
+ m,
1552
+ gpc_labels,
1553
+ cmap_gpc,
1554
+ _methylation_fraction_for_layer(m_raw, layer_gpc),
1555
+ n_xticks_gpc,
1556
+ None,
1557
+ )
1558
+ )
1559
+
1560
+ if stacked_cpg:
1561
+ m = np.vstack(stacked_cpg)
1562
+ m_raw = np.vstack(stacked_cpg_raw)
1563
+ panels.append(
1564
+ (
1565
+ "CpG",
1566
+ m,
1567
+ cpg_labels,
1568
+ cmap_cpg,
1569
+ _methylation_fraction_for_layer(m_raw, layer_cpg),
1570
+ n_xticks_cpg,
1571
+ None,
1572
+ )
1573
+ )
1574
+
1575
+ if stacked_any_a:
1576
+ m = np.vstack(stacked_any_a)
1577
+ m_raw = np.vstack(stacked_any_a_raw)
1578
+ panels.append(
1579
+ (
1580
+ "A",
1581
+ m,
1582
+ any_a_labels,
1583
+ cmap_a,
1584
+ _methylation_fraction_for_layer(m_raw, layer_a),
1585
+ n_xticks_a,
1586
+ None,
1587
+ )
1588
+ )
1589
+
1590
+ n_panels = len(panels)
1591
+ fig = plt.figure(figsize=(4.5 * n_panels, 10))
1592
+ gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
1593
+ fig.suptitle(
1594
+ f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
1595
+ )
1596
+
1597
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
1598
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
1599
+
1600
+ for i, (name, matrix, labels, cmap, mean_vec, n_ticks, norm) in enumerate(panels):
1601
+ clean_barplot(axes_bar[i], mean_vec, name)
1602
+
1603
+ heatmap_kwargs = dict(
1604
+ cmap=cmap,
1605
+ ax=axes_heat[i],
1606
+ yticklabels=False,
1607
+ cbar=False,
1608
+ )
1609
+ if norm is not None:
1610
+ heatmap_kwargs["norm"] = norm
1611
+ sns.heatmap(matrix, **heatmap_kwargs)
1612
+
1613
+ xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
1614
+ axes_heat[i].set_xticks(xtick_pos)
1615
+ axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
1616
+
1617
+ for boundary in bin_boundaries[:-1]:
1618
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
1619
+
1620
+ if overlay_variant_calls and ordered_obs_names:
1621
+ try:
1622
+ # Map panel sites from subset-local coordinates to full adata indices
1623
+ length_sites_global = _local_sites_to_global_indices(
1624
+ adata, subset, length_sites
1625
+ )
1626
+ any_c_sites_global = _local_sites_to_global_indices(
1627
+ adata, subset, any_c_sites
1628
+ )
1629
+ gpc_sites_global = _local_sites_to_global_indices(adata, subset, gpc_sites)
1630
+ cpg_sites_global = _local_sites_to_global_indices(adata, subset, cpg_sites)
1631
+ any_a_sites_global = _local_sites_to_global_indices(
1632
+ adata, subset, any_a_sites
1633
+ )
1634
+
1635
+ # Build panels_with_indices using site indices for each panel
1636
+ name_to_sites = {
1637
+ f"HMM lengths - {length_layer}": length_sites_global,
1638
+ "C": any_c_sites_global,
1639
+ "GpC": gpc_sites_global,
1640
+ "CpG": cpg_sites_global,
1641
+ "A": any_a_sites_global,
1642
+ }
1643
+ panels_with_indices = []
1644
+ for idx, (name, *_rest) in enumerate(panels):
1645
+ sites = name_to_sites.get(name)
1646
+ if sites is not None and len(sites) > 0:
1647
+ panels_with_indices.append((axes_heat[idx], sites))
1648
+ if panels_with_indices:
1649
+ _overlay_variant_calls_on_panels(
1650
+ adata,
1651
+ ref,
1652
+ ordered_obs_names,
1653
+ panels_with_indices,
1654
+ seq1_color=variant_overlay_seq1_color,
1655
+ seq2_color=variant_overlay_seq2_color,
1656
+ marker_size=variant_overlay_marker_size,
1657
+ )
1658
+ except Exception as overlay_err:
1659
+ logger.warning(
1660
+ "Variant overlay failed for %s - %s: %s", sample, ref, overlay_err
1661
+ )
1662
+
1663
+ plt.tight_layout()
1664
+
1665
+ if save_path:
1666
+ save_path = Path(save_path)
1667
+ save_path.mkdir(parents=True, exist_ok=True)
1668
+ safe_name = f"{ref}__{sample}".replace("/", "_")
1669
+ out_file = save_path / f"{safe_name}.png"
1670
+ plt.savefig(out_file, dpi=300)
1671
+ plt.close(fig)
1672
+ else:
1673
+ plt.show()
1674
+
1675
+ results.append((sample, ref))
1676
+
1677
+ except Exception:
1678
+ import traceback
1679
+
1680
+ traceback.print_exc()
1681
+ print(f"Failed {sample} - {ref} - {length_layer}")
1682
+
1683
+ return results
1684
+
1685
+
1686
def plot_hmm_layers_rolling_by_sample_ref(
    adata,
    layers: Optional[Sequence[str]] = None,
    sample_col: str = "Barcode",
    ref_col: str = "Reference_strand",
    samples: Optional[Sequence[str]] = None,
    references: Optional[Sequence[str]] = None,
    window: int = 51,
    min_periods: int = 1,
    center: bool = True,
    rows_per_page: int = 6,
    figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
    dpi: int = 160,
    output_dir: Optional[str] = None,
    save: bool = True,
    show_raw: bool = False,
    cmap: str = "tab20",
    layer_colors: Optional[Mapping[str, Any]] = None,
    use_var_coords: bool = True,
    reindexed_var_suffix: str = "reindexed",
):
    """
    For each sample (row) and reference (col) plot the rolling average of the
    positional mean (mean across reads) for each layer listed.

    Parameters
    ----------
    adata : AnnData
        Input annotated data (expects obs columns sample_col and ref_col).
    layers : list[str] | None
        Which adata.layers to plot. If None, attempts to autodetect layers whose
        matrices look like "HMM" outputs (else will error). If None and layers
        cannot be found, user must pass a list.
    sample_col, ref_col : str
        obs columns used to group rows.
    samples, references : optional lists
        explicit ordering of samples / references. If None, categories in adata.obs are used.
    window : int
        rolling window size (odd recommended). If window <= 1, no smoothing applied.
    min_periods : int
        min periods param for pd.Series.rolling.
    center : bool
        center the rolling window.
    rows_per_page : int
        paginate rows per page into multiple figures if needed.
    figsize_per_cell : (w,h)
        per-subplot size in inches.
    dpi : int
        figure dpi when saving.
    output_dir : str | None
        directory to save pages; created if necessary. If None and save=True, uses cwd.
    save : bool
        whether to save PNG files.
    show_raw : bool
        draw unsmoothed mean as faint line under smoothed curve.
    cmap : str
        matplotlib colormap for layer lines.
    layer_colors : dict[str, Any] | None
        Optional mapping of layer name to explicit line colors.
    use_var_coords : bool
        if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
    reindexed_var_suffix : str
        Suffix for per-reference reindexed var columns (e.g., ``Reference_reindexed``) used when available.

    Returns
    -------
    saved_files : list[str]
        list of saved filenames (may be empty if save=False).

    Raises
    ------
    ValueError
        If sample_col/ref_col are missing from adata.obs, or if layers is None
        and adata has no layers to autodetect.
    """
    logger.info("Plotting rolling HMM layers by sample/ref.")

    # Fail fast: both grouping columns are mandatory.
    if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
        raise ValueError(
            f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
        )

    # Resolve row (sample) ordering: caller-provided list, or the categorical
    # categories of the obs column (coerced to categorical if needed).
    if samples is None:
        sseries = adata.obs[sample_col]
        if not isinstance(sseries.dtype, pd.CategoricalDtype):
            sseries = sseries.astype("category")
        samples_all = list(sseries.cat.categories)
    else:
        samples_all = list(samples)

    # Resolve column (reference) ordering the same way.
    if references is None:
        rseries = adata.obs[ref_col]
        if not isinstance(rseries.dtype, pd.CategoricalDtype):
            rseries = rseries.astype("category")
        refs_all = list(rseries.cat.categories)
    else:
        refs_all = list(references)

    # Autodetect: with no explicit list, plot every layer present; error when
    # there is nothing to plot rather than producing empty figures.
    if layers is None:
        layers = list(adata.layers.keys())
        if len(layers) == 0:
            raise ValueError(
                "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
            )
    layers = list(layers)

    # Build x-axis coordinates. Preferred: var_names coerced to int (genomic
    # positions). The raise inside the try is deliberate control flow: it routes
    # both "var_names not numeric" and "user disabled var coords" into the same
    # fallback of 0..n-1 indices with string tick labels.
    x_labels = None
    try:
        if use_var_coords:
            x_coords = np.array([int(v) for v in adata.var_names])
        else:
            raise Exception("user disabled var coords")
    except Exception:
        x_coords = np.arange(adata.shape[1], dtype=int)
        x_labels = adata.var_names.astype(str).tolist()

    # Per-reference reindexed coordinate columns (e.g. "<ref>_reindexed") that
    # exist in adata.var; used in preference to the global x_coords below.
    ref_reindexed_cols = {
        ref: f"{ref}_{reindexed_var_suffix}"
        for ref in refs_all
        if f"{ref}_{reindexed_var_suffix}" in adata.var
    }

    if save:
        outdir = output_dir or os.getcwd()
        os.makedirs(outdir, exist_ok=True)
    else:
        outdir = None

    n_samples = len(samples_all)
    n_refs = len(refs_all)
    total_pages = math.ceil(n_samples / rows_per_page)
    saved_files = []

    # One line color per layer: explicit mapping wins, otherwise evenly spaced
    # samples from the requested colormap.
    cmap_obj = plt.get_cmap(cmap)
    n_layers = max(1, len(layers))
    fallback_colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
    layer_colors = layer_colors or {}
    colors = [layer_colors.get(layer, fallback_colors[idx]) for idx, layer in enumerate(layers)]

    # Paginate samples across figures: rows_per_page sample-rows per page.
    for page in range(total_pages):
        start = page * rows_per_page
        end = min(start + rows_per_page, n_samples)
        chunk = samples_all[start:end]
        nrows = len(chunk)
        ncols = n_refs

        fig_w = figsize_per_cell[0] * ncols
        fig_h = figsize_per_cell[1] * nrows
        # squeeze=False keeps axes 2-D even for a single row/column.
        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
        )

        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(refs_all):
                ax = axes[r_idx][c_idx]

                mask = (adata.obs[sample_col].values == sample_name) & (
                    adata.obs[ref_col].values == ref_name
                )
                sub = adata[mask]
                if sub.n_obs == 0:
                    # Empty cell: annotate and still apply the row/column labels
                    # so the grid stays readable.
                    ax.text(
                        0.5,
                        0.5,
                        "No reads",
                        ha="center",
                        va="center",
                        transform=ax.transAxes,
                        color="gray",
                    )
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if r_idx == 0:
                        ax.set_title(str(ref_name), fontsize=9)
                    if c_idx == 0:
                        total_reads = int((adata.obs[sample_col] == sample_name).sum())
                        ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                    continue

                plotted_any = False
                # Prefer the reference-specific reindexed coordinates when the
                # var column exists and is numeric; otherwise the global coords.
                reindexed_col = ref_reindexed_cols.get(ref_name)
                if reindexed_col is not None:
                    try:
                        ref_coords = np.asarray(adata.var[reindexed_col], dtype=int)
                    except Exception:
                        ref_coords = x_coords
                else:
                    ref_coords = x_coords
                for li, layer in enumerate(layers):
                    if layer in sub.layers:
                        mat = sub.layers[layer]
                    else:
                        # Only the first requested layer may fall back to .X;
                        # other missing layers are simply skipped.
                        if layer == layers[0] and getattr(sub, "X", None) is not None:
                            mat = sub.X
                        else:
                            continue

                    # Densify sparse matrices; fall back to np.asarray when
                    # toarray exists but fails.
                    if hasattr(mat, "toarray"):
                        try:
                            arr = mat.toarray()
                        except Exception:
                            arr = np.asarray(mat)
                    else:
                        arr = np.asarray(mat)

                    if arr.size == 0 or arr.shape[1] == 0:
                        continue

                    arr = arr.astype(float)
                    # Positional mean across reads; errstate silences
                    # all-NaN-column warnings.
                    with np.errstate(all="ignore"):
                        col_mean = np.nanmean(arr, axis=0)

                    if np.all(np.isnan(col_mean)):
                        continue

                    valid_mask = np.isfinite(col_mean)

                    if (window is None) or (window <= 1):
                        smoothed = col_mean
                    else:
                        ser = pd.Series(col_mean)
                        smoothed = (
                            ser.rolling(window=window, min_periods=min_periods, center=center)
                            .mean()
                            .to_numpy()
                        )
                        # Re-mask positions that had no valid mean: rolling with
                        # min_periods >= 1 can fill across NaN gaps otherwise.
                        # NOTE(review): nesting reconstructed from a flattened
                        # diff — confirm this line sits inside the else branch.
                        smoothed = np.where(valid_mask, smoothed, np.nan)

                    # Clip coordinates to the layer width in case the layer has
                    # fewer columns than the coordinate vector.
                    L = len(col_mean)
                    x = ref_coords[:L]

                    if show_raw:
                        ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)

                    ax.plot(
                        x,
                        smoothed[:L],
                        label=layer,
                        color=colors[li],
                        linewidth=1.2,
                        alpha=0.95,
                        zorder=2,
                    )
                    plotted_any = True

                # Grid furniture: titles on the first row, sample labels on the
                # first column, x labels only on the bottom row of the page.
                if r_idx == 0:
                    ax.set_title(str(ref_name), fontsize=9)
                if c_idx == 0:
                    total_reads = int((adata.obs[sample_col] == sample_name).sum())
                    ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                if r_idx == nrows - 1:
                    ax.set_xlabel("position", fontsize=8)
                    # String tick labels only apply in the fallback-coordinate
                    # case and when no reindexed column overrode the x-axis.
                    if x_labels is not None and reindexed_col is None:
                        max_ticks = 8
                        tick_step = max(1, int(math.ceil(len(x_labels) / max_ticks)))
                        tick_positions = x_coords[::tick_step]
                        tick_labels = x_labels[::tick_step]
                        ax.set_xticks(tick_positions)
                        ax.set_xticklabels(tick_labels, fontsize=7, rotation=45, ha="right")

                # Single legend for the whole page, in the top-left cell.
                if (r_idx == 0 and c_idx == 0) and plotted_any:
                    ax.legend(fontsize=7, loc="upper right")

                ax.grid(True, alpha=0.2)

        fig.suptitle(
            f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
            fontsize=11,
            y=0.995,
        )
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        if save:
            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
            plt.savefig(fname, bbox_inches="tight", dpi=dpi)
            saved_files.append(fname)
            logger.info("Saved HMM layers rolling plot to %s.", fname)
        else:
            plt.show()
        plt.close(fig)

    return saved_files