smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +18 -2
- smftools/cli/hmm_adata.py +18 -1
- smftools/cli/latent_adata.py +522 -67
- smftools/cli/load_adata.py +2 -2
- smftools/cli/preprocess_adata.py +32 -93
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +23 -109
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +41 -5
- smftools/config/conversion.yaml +0 -10
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +49 -13
- smftools/config/experiment_config.py +96 -3
- smftools/constants.py +4 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +53 -13
- smftools/informatics/h5ad_functions.py +83 -0
- smftools/informatics/modkit_extract_to_adata.py +4 -0
- smftools/plotting/__init__.py +26 -12
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +58 -3362
- smftools/plotting/hmm_plotting.py +1586 -2
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +3 -0
- smftools/preprocessing/append_base_context.py +1 -1
- smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +109 -85
- smftools/tools/__init__.py +6 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_nmf.py +18 -7
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +70 -154
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +640 -3
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +52 -4
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/hmm_plotting.py

@@ -1,12 +1,27 @@
 from __future__ import annotations
 
 import math
-
+import os
+from pathlib import Path
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import pandas as pd
+import scipy.cluster.hierarchy as sch
 
+from smftools.logging_utils import get_logger
 from smftools.optional_imports import require
+from smftools.plotting.plotting_utils import (
+    _layer_to_numpy,
+    _methylation_fraction_for_layer,
+    _select_labels,
+    clean_barplot,
+    normalized_mean,
+)
+
+gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
+patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
+sns = require("seaborn", extra="plotting", purpose="plot styling")
 
 plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
 mpl_colors = require("matplotlib.colors", extra="plotting", purpose="HMM plots")
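Note: the new imports route all plotting backends through `smftools.optional_imports.require`, which defers optional dependencies until a plotting extra is installed. A minimal sketch of what such a helper typically looks like, assuming a lazy import plus a guided error message (hypothetical reimplementation, not the actual smftools code):

    import importlib

    def require(module_name: str, *, extra: str, purpose: str):
        # Lazily import an optional dependency; point to the install extra if missing.
        try:
            return importlib.import_module(module_name)
        except ImportError as err:
            raise ImportError(
                f"{module_name} is required for {purpose}; "
                f"install it with `pip install smftools[{extra}]`."
            ) from err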
@@ -17,6 +32,153 @@ pdf_backend = require(
 )
 PdfPages = pdf_backend.PdfPages
 
+logger = get_logger(__name__)
+
+
+def _local_sites_to_global_indices(
+    adata,
+    subset,
+    local_sites: np.ndarray,
+) -> np.ndarray:
+    """Translate subset-local column indices into global ``adata.var`` indices."""
+    local_sites = np.asarray(local_sites, dtype=int)
+    if local_sites.size == 0:
+        return local_sites
+    subset_to_global = adata.var_names.get_indexer(subset.var_names)
+    global_sites = subset_to_global[local_sites]
+    if np.any(global_sites < 0):
+        missing = int(np.sum(global_sites < 0))
+        logger.warning(
+            "Could not map %d plotted positions back to full var index; skipping those points.",
+            missing,
+        )
+        global_sites = global_sites[global_sites >= 0]
+    return global_sites
+
+
+def _overlay_variant_calls_on_panels(
+    adata,
+    reference: str,
+    ordered_obs_names: list,
+    panels_with_indices: list,
+    seq1_color: str = "white",
+    seq2_color: str = "black",
+    marker_size: float = 4.0,
+) -> bool:
+    """
+    Overlay variant call circles on heatmap panels using nearest-neighbor mapping.
+
+    This function maps variant call column indices to the nearest displayed column
+    in each panel, using var index space for mapping. This handles both regular
+    var_names and reindexed label coordinates.
+
+    Parameters
+    ----------
+    adata : AnnData
+        The AnnData containing variant call layers (should be full adata, not subset).
+    reference : str
+        Reference name used to auto-detect the variant call layer.
+    ordered_obs_names : list
+        Obs names in display order (rows of the heatmap).
+    panels_with_indices : list of (ax, site_indices)
+        Each entry is a matplotlib axes and the var indices for that panel's columns.
+    seq1_color, seq2_color : str
+        Colors for seq1 (value 1) and seq2 (value 2) variant calls.
+    marker_size : float
+        Size of the circle markers.
+
+    Returns
+    -------
+    bool
+        True if overlay was applied to at least one panel.
+    """
+    # Auto-detect variant call layer - find any layer ending with _variant_call
+    vc_layer_key = None
+    for key in adata.layers:
+        if key.endswith("_variant_call"):
+            vc_layer_key = key
+            break
+
+    if vc_layer_key is None:
+        return False
+
+    # Build row index mapping
+    obs_name_to_idx = {str(name): i for i, name in enumerate(adata.obs_names)}
+    common_obs = [str(name) for name in ordered_obs_names if str(name) in obs_name_to_idx]
+    if not common_obs:
+        return False
+
+    obs_idx = [obs_name_to_idx[name] for name in common_obs]
+    row_index_map = {str(name): i for i, name in enumerate(ordered_obs_names)}
+    heatmap_row_indices = np.array([row_index_map[name] for name in common_obs])
+
+    # Get variant call matrix for the ordered obs
+    vc_data = adata.layers[vc_layer_key]
+    if hasattr(vc_data, "toarray"):
+        vc_data = vc_data.toarray()
+    vc_matrix = np.asarray(vc_data)[obs_idx, :]
+
+    # Find columns with actual variant calls (value 1 or 2)
+    has_calls = np.isin(vc_matrix, [1, 2]).any(axis=0)
+    call_col_indices = np.where(has_calls)[0]
+
+    if len(call_col_indices) == 0:
+        return False
+
+    call_sub = vc_matrix[:, call_col_indices]
+
+    applied = False
+    for ax, site_indices in panels_with_indices:
+        site_indices = np.asarray(site_indices)
+        if site_indices.size == 0:
+            continue
+
+        # Use nearest-neighbor mapping in var index space
+        # site_indices are the var indices displayed in this panel (sorted by position)
+        # call_col_indices are var indices where calls exist
+        # Map each call to the nearest displayed site index
+
+        # Sort site indices for searchsorted
+        sorted_order = np.argsort(site_indices)
+        sorted_sites = site_indices[sorted_order]
+
+        # Find nearest site for each call
+        insert_idx = np.searchsorted(sorted_sites, call_col_indices)
+        insert_idx = np.clip(insert_idx, 0, len(sorted_sites) - 1)
+        left_idx = np.clip(insert_idx - 1, 0, len(sorted_sites) - 1)
+
+        dist_right = np.abs(sorted_sites[insert_idx].astype(float) - call_col_indices.astype(float))
+        dist_left = np.abs(sorted_sites[left_idx].astype(float) - call_col_indices.astype(float))
+        nearest_sorted = np.where(dist_left < dist_right, left_idx, insert_idx)
+
+        # Map back to original (unsorted) heatmap column positions
+        nearest_heatmap_col = sorted_order[nearest_sorted]
+
+        # Plot circles for each variant value
+        for call_val, color in [(1, seq1_color), (2, seq2_color)]:
+            local_rows, local_cols = np.where(call_sub == call_val)
+            if len(local_rows) == 0:
+                continue
+
+            plot_y = heatmap_row_indices[local_rows]
+            plot_x = nearest_heatmap_col[local_cols]
+
+            ax.scatter(
+                plot_x + 0.5,
+                plot_y + 0.5,
+                c=color,
+                s=marker_size,
+                marker="o",
+                edgecolors="gray",
+                linewidths=0.3,
+                zorder=3,
+            )
+            applied = True
+
+    if applied:
+        logger.info("Overlaid variant calls from layer '%s'.", vc_layer_key)
+    return applied
+
 
 def plot_hmm_size_contours(
     adata,
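The nearest-neighbor step above pairs each variant-call column with the closest displayed column by sorting the panel's site indices, using `np.searchsorted` to find the insertion point, and then comparing the left and right candidates. A toy illustration of that mapping (the index values are made up):

    import numpy as np

    site_indices = np.array([40, 10, 25])      # columns shown in a panel (unsorted)
    call_col_indices = np.array([11, 30, 39])  # columns that carry variant calls

    sorted_order = np.argsort(site_indices)    # [1, 2, 0]
    sorted_sites = site_indices[sorted_order]  # [10, 25, 40]

    insert_idx = np.clip(np.searchsorted(sorted_sites, call_col_indices), 0, len(sorted_sites) - 1)
    left_idx = np.clip(insert_idx - 1, 0, len(sorted_sites) - 1)
    dist_right = np.abs(sorted_sites[insert_idx] - call_col_indices)
    dist_left = np.abs(sorted_sites[left_idx] - call_col_indices)
    nearest_sorted = np.where(dist_left < dist_right, left_idx, insert_idx)

    # Heatmap column nearest each call, in panel coordinates: prints [1 2 0]
    print(sorted_order[nearest_sorted])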
@@ -57,6 +219,7 @@ def plot_hmm_size_contours(
     Other args are the same as prior function.
     """
     feature_ranges = tuple(feature_ranges or ())
+    logger.info("Plotting HMM size contours%s.", f" -> {save_path}" if save_path else "")
 
     def _resolve_length_color(length: int, fallback: str) -> Tuple[float, float, float, float]:
         for min_len, max_len, color in feature_ranges:
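As the context lines show, `feature_ranges` is iterated as `(min_len, max_len, color)` triples, so callers pass length bounds paired with a matplotlib color spec. An illustrative value only (the bounds and colors here are placeholders, not smftools defaults):

    feature_ranges = [
        (20, 80, "tab:blue"),     # short footprints
        (81, 200, "tab:orange"),  # nucleosome-scale features
    ]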
@@ -365,6 +528,7 @@ def plot_hmm_size_contours(
             fname = f"hmm_size_page_{p + 1:03d}.png"
             out = os.path.join(save_path, fname)
             fig.savefig(out, dpi=dpi, bbox_inches="tight")
+            logger.info("Saved HMM size contour page to %s.", out)
 
     # multipage PDF if requested
     if save_path is not None and save_pdf:
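The next hunk wraps the same per-page figures into a single multipage PDF via matplotlib's `PdfPages` backend. A minimal standalone sketch of that pattern (the output filename is hypothetical):

    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    figs = [plt.figure() for _ in range(3)]        # stand-ins for the contour pages
    with PdfPages("hmm_size_pages.pdf") as pp:     # one PDF, one page per figure
        for fig in figs:
            pp.savefig(fig, bbox_inches="tight")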
@@ -372,6 +536,1426 @@ def plot_hmm_size_contours(
|
|
|
372
536
|
with PdfPages(pdf_file) as pp:
|
|
373
537
|
for fig in figs:
|
|
374
538
|
pp.savefig(fig, bbox_inches="tight")
|
|
375
|
-
|
|
539
|
+
logger.info("Saved HMM size contour PDF to %s.", pdf_file)
|
|
376
540
|
|
|
377
541
|
return figs
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _resolve_feature_color(cmap: Any) -> Tuple[float, float, float, float]:
|
|
545
|
+
"""Resolve a representative feature color from a colormap or color spec."""
|
|
546
|
+
if isinstance(cmap, str):
|
|
547
|
+
try:
|
|
548
|
+
cmap_obj = plt.get_cmap(cmap)
|
|
549
|
+
return mpl_colors.to_rgba(cmap_obj(1.0))
|
|
550
|
+
except Exception:
|
|
551
|
+
return mpl_colors.to_rgba(cmap)
|
|
552
|
+
|
|
553
|
+
if isinstance(cmap, mpl_colors.Colormap):
|
|
554
|
+
if hasattr(cmap, "colors") and cmap.colors:
|
|
555
|
+
return mpl_colors.to_rgba(cmap.colors[-1])
|
|
556
|
+
return mpl_colors.to_rgba(cmap(1.0))
|
|
557
|
+
|
|
558
|
+
return mpl_colors.to_rgba("black")
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def _build_hmm_feature_cmap(
|
|
562
|
+
cmap: Any,
|
|
563
|
+
*,
|
|
564
|
+
zero_color: str = "#f5f1e8",
|
|
565
|
+
nan_color: str = "#E6E6E6",
|
|
566
|
+
) -> mpl_colors.Colormap:
|
|
567
|
+
"""Build a two-color HMM colormap with explicit NaN/under handling."""
|
|
568
|
+
feature_color = _resolve_feature_color(cmap)
|
|
569
|
+
hmm_cmap = mpl_colors.LinearSegmentedColormap.from_list(
|
|
570
|
+
"hmm_feature_cmap",
|
|
571
|
+
[zero_color, feature_color],
|
|
572
|
+
)
|
|
573
|
+
hmm_cmap.set_bad(nan_color)
|
|
574
|
+
hmm_cmap.set_under(nan_color)
|
|
575
|
+
return hmm_cmap
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def _map_length_matrix_to_subclasses(
|
|
579
|
+
length_matrix: np.ndarray,
|
|
580
|
+
feature_ranges: Sequence[Tuple[int, int, Any]],
|
|
581
|
+
) -> np.ndarray:
|
|
582
|
+
"""Map length values into subclass integer codes based on feature ranges."""
|
|
583
|
+
mapped = np.zeros_like(length_matrix, dtype=float)
|
|
584
|
+
finite_mask = np.isfinite(length_matrix)
|
|
585
|
+
for idx, (min_len, max_len, _color) in enumerate(feature_ranges, start=1):
|
|
586
|
+
mask = finite_mask & (length_matrix >= min_len) & (length_matrix <= max_len)
|
|
587
|
+
mapped[mask] = float(idx)
|
|
588
|
+
mapped[~finite_mask] = np.nan
|
|
589
|
+
return mapped
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def _build_length_feature_cmap(
|
|
593
|
+
feature_ranges: Sequence[Tuple[int, int, Any]],
|
|
594
|
+
*,
|
|
595
|
+
zero_color: str = "#f5f1e8",
|
|
596
|
+
nan_color: str = "#E6E6E6",
|
|
597
|
+
) -> Tuple[mpl_colors.Colormap, mpl_colors.BoundaryNorm]:
|
|
598
|
+
"""Build a discrete colormap and norm for length-based subclasses."""
|
|
599
|
+
color_list = [zero_color] + [color for _, _, color in feature_ranges]
|
|
600
|
+
cmap = mpl_colors.ListedColormap(color_list, name="hmm_length_feature_cmap")
|
|
601
|
+
cmap.set_bad(nan_color)
|
|
602
|
+
bounds = np.arange(-0.5, len(color_list) + 0.5, 1)
|
|
603
|
+
norm = mpl_colors.BoundaryNorm(bounds, cmap.N)
|
|
604
|
+
return cmap, norm
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
def combined_hmm_raw_clustermap(
|
|
608
|
+
adata,
|
|
609
|
+
sample_col: str = "Sample_Names",
|
|
610
|
+
reference_col: str = "Reference_strand",
|
|
611
|
+
hmm_feature_layer: str = "hmm_combined",
|
|
612
|
+
layer_gpc: str = "nan0_0minus1",
|
|
613
|
+
layer_cpg: str = "nan0_0minus1",
|
|
614
|
+
layer_c: str = "nan0_0minus1",
|
|
615
|
+
layer_a: str = "nan0_0minus1",
|
|
616
|
+
cmap_hmm: str = "tab10",
|
|
617
|
+
cmap_gpc: str = "coolwarm",
|
|
618
|
+
cmap_cpg: str = "viridis",
|
|
619
|
+
cmap_c: str = "coolwarm",
|
|
620
|
+
cmap_a: str = "coolwarm",
|
|
621
|
+
min_quality: int = 20,
|
|
622
|
+
min_length: int = 200,
|
|
623
|
+
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
624
|
+
min_position_valid_fraction: float = 0.5,
|
|
625
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
626
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
627
|
+
save_path: str | Path | None = None,
|
|
628
|
+
normalize_hmm: bool = False,
|
|
629
|
+
sort_by: str = "gpc",
|
|
630
|
+
bins: Optional[Dict[str, Any]] = None,
|
|
631
|
+
deaminase: bool = False,
|
|
632
|
+
min_signal: float = 0.0,
|
|
633
|
+
# ---- fixed tick label controls (counts, not spacing)
|
|
634
|
+
n_xticks_hmm: int = 10,
|
|
635
|
+
n_xticks_any_c: int = 8,
|
|
636
|
+
n_xticks_gpc: int = 8,
|
|
637
|
+
n_xticks_cpg: int = 8,
|
|
638
|
+
n_xticks_a: int = 8,
|
|
639
|
+
index_col_suffix: str | None = None,
|
|
640
|
+
fill_nan_strategy: str = "value",
|
|
641
|
+
fill_nan_value: float = -1,
|
|
642
|
+
overlay_variant_calls: bool = False,
|
|
643
|
+
variant_overlay_seq1_color: str = "white",
|
|
644
|
+
variant_overlay_seq2_color: str = "black",
|
|
645
|
+
variant_overlay_marker_size: float = 4.0,
|
|
646
|
+
):
|
|
647
|
+
"""
|
|
648
|
+
Makes a multi-panel clustermap per (sample, reference):
|
|
649
|
+
HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.
|
|
650
|
+
|
|
651
|
+
Panels are added only if the corresponding site mask exists AND has >0 sites.
|
|
652
|
+
|
|
653
|
+
sort_by options:
|
|
654
|
+
'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
|
|
655
|
+
|
|
656
|
+
NaN fill strategy is applied in-memory for clustering/plotting only.
|
|
657
|
+
"""
|
|
658
|
+
if fill_nan_strategy not in {"none", "value", "col_mean"}:
|
|
659
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
660
|
+
|
|
661
|
+
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
662
|
+
"""Pick tick indices/labels from an array."""
|
|
663
|
+
if labels.size == 0:
|
|
664
|
+
return [], []
|
|
665
|
+
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
666
|
+
idx = np.unique(idx)
|
|
667
|
+
return idx.tolist(), labels[idx].tolist()
|
|
668
|
+
|
|
669
|
+
# Helper: build a True mask if filter is inactive or column missing
|
|
670
|
+
def _mask_or_true(series_name: str, predicate):
|
|
671
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
672
|
+
if series_name not in adata.obs:
|
|
673
|
+
return pd.Series(True, index=adata.obs.index)
|
|
674
|
+
s = adata.obs[series_name]
|
|
675
|
+
try:
|
|
676
|
+
return predicate(s)
|
|
677
|
+
except Exception:
|
|
678
|
+
# Fallback: all True if bad dtype / predicate failure
|
|
679
|
+
return pd.Series(True, index=adata.obs.index)
|
|
680
|
+
|
|
681
|
+
results = []
|
|
682
|
+
signal_type = "deamination" if deaminase else "methylation"
|
|
683
|
+
|
|
684
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
685
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
686
|
+
# Optionally remap sample label for display
|
|
687
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
688
|
+
# Row-level masks (obs)
|
|
689
|
+
qmask = _mask_or_true(
|
|
690
|
+
"read_quality",
|
|
691
|
+
(lambda s: s >= float(min_quality))
|
|
692
|
+
if (min_quality is not None)
|
|
693
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
694
|
+
)
|
|
695
|
+
lm_mask = _mask_or_true(
|
|
696
|
+
"mapped_length",
|
|
697
|
+
(lambda s: s >= float(min_length))
|
|
698
|
+
if (min_length is not None)
|
|
699
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
700
|
+
)
|
|
701
|
+
lrr_mask = _mask_or_true(
|
|
702
|
+
"mapped_length_to_reference_length_ratio",
|
|
703
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
704
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
705
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
706
|
+
)
|
|
707
|
+
|
|
708
|
+
demux_mask = _mask_or_true(
|
|
709
|
+
"demux_type",
|
|
710
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
711
|
+
if (demux_types is not None)
|
|
712
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
713
|
+
)
|
|
714
|
+
|
|
715
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
716
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
717
|
+
|
|
718
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
719
|
+
|
|
720
|
+
if not bool(row_mask.any()):
|
|
721
|
+
print(
|
|
722
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
723
|
+
)
|
|
724
|
+
continue
|
|
725
|
+
|
|
726
|
+
try:
|
|
727
|
+
# ---- subset reads ----
|
|
728
|
+
subset = adata[row_mask, :].copy()
|
|
729
|
+
|
|
730
|
+
# Column-level mask (var)
|
|
731
|
+
if min_position_valid_fraction is not None:
|
|
732
|
+
valid_key = f"{ref}_valid_fraction"
|
|
733
|
+
if valid_key in subset.var:
|
|
734
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
735
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
736
|
+
if col_mask.any():
|
|
737
|
+
subset = subset[:, col_mask].copy()
|
|
738
|
+
else:
|
|
739
|
+
print(
|
|
740
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
741
|
+
)
|
|
742
|
+
continue
|
|
743
|
+
|
|
744
|
+
if subset.shape[0] == 0:
|
|
745
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
746
|
+
continue
|
|
747
|
+
|
|
748
|
+
# ---- bins ----
|
|
749
|
+
if bins is None:
|
|
750
|
+
bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
|
|
751
|
+
else:
|
|
752
|
+
bins_temp = bins
|
|
753
|
+
|
|
754
|
+
# ---- site masks (robust) ----
|
|
755
|
+
def _sites(*keys):
|
|
756
|
+
"""Return indices for the first matching site key."""
|
|
757
|
+
for k in keys:
|
|
758
|
+
if k in subset.var:
|
|
759
|
+
return np.where(subset.var[k].values)[0]
|
|
760
|
+
return np.array([], dtype=int)
|
|
761
|
+
|
|
762
|
+
gpc_sites = _sites(f"{ref}_GpC_site")
|
|
763
|
+
cpg_sites = _sites(f"{ref}_CpG_site")
|
|
764
|
+
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
765
|
+
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
766
|
+
|
|
767
|
+
# ---- labels via _select_labels ----
|
|
768
|
+
# HMM uses *all* columns
|
|
769
|
+
hmm_sites = np.arange(subset.n_vars, dtype=int)
|
|
770
|
+
hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
|
|
771
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
772
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
773
|
+
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
774
|
+
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
775
|
+
|
|
776
|
+
# storage
|
|
777
|
+
stacked_hmm = []
|
|
778
|
+
stacked_hmm_raw = []
|
|
779
|
+
stacked_any_c = []
|
|
780
|
+
stacked_any_c_raw = []
|
|
781
|
+
stacked_gpc = []
|
|
782
|
+
stacked_gpc_raw = []
|
|
783
|
+
stacked_cpg = []
|
|
784
|
+
stacked_cpg_raw = []
|
|
785
|
+
stacked_any_a = []
|
|
786
|
+
stacked_any_a_raw = []
|
|
787
|
+
ordered_obs_names = []
|
|
788
|
+
|
|
789
|
+
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
790
|
+
total_reads = subset.n_obs
|
|
791
|
+
percentages = {}
|
|
792
|
+
last_idx = 0
|
|
793
|
+
|
|
794
|
+
# ---------------- process bins ----------------
|
|
795
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
796
|
+
sb = subset[bin_filter].copy()
|
|
797
|
+
n = sb.n_obs
|
|
798
|
+
if n == 0:
|
|
799
|
+
continue
|
|
800
|
+
|
|
801
|
+
pct = (n / total_reads) * 100 if total_reads else 0
|
|
802
|
+
percentages[bin_label] = pct
|
|
803
|
+
|
|
804
|
+
# ---- sorting ----
|
|
805
|
+
if sort_by.startswith("obs:"):
|
|
806
|
+
colname = sort_by.split("obs:")[1]
|
|
807
|
+
order = np.argsort(sb.obs[colname].values)
|
|
808
|
+
|
|
809
|
+
elif sort_by == "gpc" and gpc_sites.size:
|
|
810
|
+
gpc_matrix = _layer_to_numpy(
|
|
811
|
+
sb,
|
|
812
|
+
layer_gpc,
|
|
813
|
+
gpc_sites,
|
|
814
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
815
|
+
fill_nan_value=fill_nan_value,
|
|
816
|
+
)
|
|
817
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
818
|
+
order = sch.leaves_list(linkage)
|
|
819
|
+
|
|
820
|
+
elif sort_by == "cpg" and cpg_sites.size:
|
|
821
|
+
cpg_matrix = _layer_to_numpy(
|
|
822
|
+
sb,
|
|
823
|
+
layer_cpg,
|
|
824
|
+
cpg_sites,
|
|
825
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
826
|
+
fill_nan_value=fill_nan_value,
|
|
827
|
+
)
|
|
828
|
+
linkage = sch.linkage(cpg_matrix, method="ward")
|
|
829
|
+
order = sch.leaves_list(linkage)
|
|
830
|
+
|
|
831
|
+
elif sort_by == "c" and any_c_sites.size:
|
|
832
|
+
any_c_matrix = _layer_to_numpy(
|
|
833
|
+
sb,
|
|
834
|
+
layer_c,
|
|
835
|
+
any_c_sites,
|
|
836
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
837
|
+
fill_nan_value=fill_nan_value,
|
|
838
|
+
)
|
|
839
|
+
linkage = sch.linkage(any_c_matrix, method="ward")
|
|
840
|
+
order = sch.leaves_list(linkage)
|
|
841
|
+
|
|
842
|
+
elif sort_by == "a" and any_a_sites.size:
|
|
843
|
+
any_a_matrix = _layer_to_numpy(
|
|
844
|
+
sb,
|
|
845
|
+
layer_a,
|
|
846
|
+
any_a_sites,
|
|
847
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
848
|
+
fill_nan_value=fill_nan_value,
|
|
849
|
+
)
|
|
850
|
+
linkage = sch.linkage(any_a_matrix, method="ward")
|
|
851
|
+
order = sch.leaves_list(linkage)
|
|
852
|
+
|
|
853
|
+
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
854
|
+
gpc_matrix = _layer_to_numpy(
|
|
855
|
+
sb,
|
|
856
|
+
layer_gpc,
|
|
857
|
+
None,
|
|
858
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
859
|
+
fill_nan_value=fill_nan_value,
|
|
860
|
+
)
|
|
861
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
862
|
+
order = sch.leaves_list(linkage)
|
|
863
|
+
|
|
864
|
+
elif sort_by == "hmm" and hmm_sites.size:
|
|
865
|
+
hmm_matrix = _layer_to_numpy(
|
|
866
|
+
sb,
|
|
867
|
+
hmm_feature_layer,
|
|
868
|
+
hmm_sites,
|
|
869
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
870
|
+
fill_nan_value=fill_nan_value,
|
|
871
|
+
)
|
|
872
|
+
linkage = sch.linkage(hmm_matrix, method="ward")
|
|
873
|
+
order = sch.leaves_list(linkage)
|
|
874
|
+
|
|
875
|
+
else:
|
|
876
|
+
order = np.arange(n)
|
|
877
|
+
|
|
878
|
+
sb = sb[order]
|
|
879
|
+
ordered_obs_names.extend(sb.obs_names.tolist())
|
|
880
|
+
|
|
881
|
+
# ---- collect matrices ----
|
|
882
|
+
stacked_hmm.append(
|
|
883
|
+
_layer_to_numpy(
|
|
884
|
+
sb,
|
|
885
|
+
hmm_feature_layer,
|
|
886
|
+
None,
|
|
887
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
888
|
+
fill_nan_value=fill_nan_value,
|
|
889
|
+
)
|
|
890
|
+
)
|
|
891
|
+
stacked_hmm_raw.append(
|
|
892
|
+
_layer_to_numpy(
|
|
893
|
+
sb,
|
|
894
|
+
hmm_feature_layer,
|
|
895
|
+
None,
|
|
896
|
+
fill_nan_strategy="none",
|
|
897
|
+
fill_nan_value=fill_nan_value,
|
|
898
|
+
)
|
|
899
|
+
)
|
|
900
|
+
if any_c_sites.size:
|
|
901
|
+
stacked_any_c.append(
|
|
902
|
+
_layer_to_numpy(
|
|
903
|
+
sb,
|
|
904
|
+
layer_c,
|
|
905
|
+
any_c_sites,
|
|
906
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
907
|
+
fill_nan_value=fill_nan_value,
|
|
908
|
+
)
|
|
909
|
+
)
|
|
910
|
+
stacked_any_c_raw.append(
|
|
911
|
+
_layer_to_numpy(
|
|
912
|
+
sb,
|
|
913
|
+
layer_c,
|
|
914
|
+
any_c_sites,
|
|
915
|
+
fill_nan_strategy="none",
|
|
916
|
+
fill_nan_value=fill_nan_value,
|
|
917
|
+
)
|
|
918
|
+
)
|
|
919
|
+
if gpc_sites.size:
|
|
920
|
+
stacked_gpc.append(
|
|
921
|
+
_layer_to_numpy(
|
|
922
|
+
sb,
|
|
923
|
+
layer_gpc,
|
|
924
|
+
gpc_sites,
|
|
925
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
926
|
+
fill_nan_value=fill_nan_value,
|
|
927
|
+
)
|
|
928
|
+
)
|
|
929
|
+
stacked_gpc_raw.append(
|
|
930
|
+
_layer_to_numpy(
|
|
931
|
+
sb,
|
|
932
|
+
layer_gpc,
|
|
933
|
+
gpc_sites,
|
|
934
|
+
fill_nan_strategy="none",
|
|
935
|
+
fill_nan_value=fill_nan_value,
|
|
936
|
+
)
|
|
937
|
+
)
|
|
938
|
+
if cpg_sites.size:
|
|
939
|
+
stacked_cpg.append(
|
|
940
|
+
_layer_to_numpy(
|
|
941
|
+
sb,
|
|
942
|
+
layer_cpg,
|
|
943
|
+
cpg_sites,
|
|
944
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
945
|
+
fill_nan_value=fill_nan_value,
|
|
946
|
+
)
|
|
947
|
+
)
|
|
948
|
+
stacked_cpg_raw.append(
|
|
949
|
+
_layer_to_numpy(
|
|
950
|
+
sb,
|
|
951
|
+
layer_cpg,
|
|
952
|
+
cpg_sites,
|
|
953
|
+
fill_nan_strategy="none",
|
|
954
|
+
fill_nan_value=fill_nan_value,
|
|
955
|
+
)
|
|
956
|
+
)
|
|
957
|
+
if any_a_sites.size:
|
|
958
|
+
stacked_any_a.append(
|
|
959
|
+
_layer_to_numpy(
|
|
960
|
+
sb,
|
|
961
|
+
layer_a,
|
|
962
|
+
any_a_sites,
|
|
963
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
964
|
+
fill_nan_value=fill_nan_value,
|
|
965
|
+
)
|
|
966
|
+
)
|
|
967
|
+
stacked_any_a_raw.append(
|
|
968
|
+
_layer_to_numpy(
|
|
969
|
+
sb,
|
|
970
|
+
layer_a,
|
|
971
|
+
any_a_sites,
|
|
972
|
+
fill_nan_strategy="none",
|
|
973
|
+
fill_nan_value=fill_nan_value,
|
|
974
|
+
)
|
|
975
|
+
)
|
|
976
|
+
|
|
977
|
+
row_labels.extend([bin_label] * n)
|
|
978
|
+
bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
|
|
979
|
+
last_idx += n
|
|
980
|
+
bin_boundaries.append(last_idx)
|
|
981
|
+
|
|
982
|
+
# ---------------- stack ----------------
|
|
983
|
+
hmm_matrix = np.vstack(stacked_hmm)
|
|
984
|
+
hmm_matrix_raw = np.vstack(stacked_hmm_raw)
|
|
985
|
+
mean_hmm = (
|
|
986
|
+
normalized_mean(hmm_matrix_raw)
|
|
987
|
+
if normalize_hmm
|
|
988
|
+
else np.nanmean(hmm_matrix_raw, axis=0)
|
|
989
|
+
)
|
|
990
|
+
hmm_plot_matrix = hmm_matrix_raw
|
|
991
|
+
hmm_plot_cmap = _build_hmm_feature_cmap(cmap_hmm)
|
|
992
|
+
|
|
993
|
+
panels = [
|
|
994
|
+
(
|
|
995
|
+
f"HMM - {hmm_feature_layer}",
|
|
996
|
+
hmm_plot_matrix,
|
|
997
|
+
hmm_labels,
|
|
998
|
+
hmm_plot_cmap,
|
|
999
|
+
mean_hmm,
|
|
1000
|
+
n_xticks_hmm,
|
|
1001
|
+
),
|
|
1002
|
+
]
|
|
1003
|
+
|
|
1004
|
+
if stacked_any_c:
|
|
1005
|
+
m = np.vstack(stacked_any_c)
|
|
1006
|
+
m_raw = np.vstack(stacked_any_c_raw)
|
|
1007
|
+
panels.append(
|
|
1008
|
+
(
|
|
1009
|
+
"C",
|
|
1010
|
+
m,
|
|
1011
|
+
any_c_labels,
|
|
1012
|
+
cmap_c,
|
|
1013
|
+
_methylation_fraction_for_layer(m_raw, layer_c),
|
|
1014
|
+
n_xticks_any_c,
|
|
1015
|
+
)
|
|
1016
|
+
)
|
|
1017
|
+
|
|
1018
|
+
if stacked_gpc:
|
|
1019
|
+
m = np.vstack(stacked_gpc)
|
|
1020
|
+
m_raw = np.vstack(stacked_gpc_raw)
|
|
1021
|
+
panels.append(
|
|
1022
|
+
(
|
|
1023
|
+
"GpC",
|
|
1024
|
+
m,
|
|
1025
|
+
gpc_labels,
|
|
1026
|
+
cmap_gpc,
|
|
1027
|
+
_methylation_fraction_for_layer(m_raw, layer_gpc),
|
|
1028
|
+
n_xticks_gpc,
|
|
1029
|
+
)
|
|
1030
|
+
)
|
|
1031
|
+
|
|
1032
|
+
if stacked_cpg:
|
|
1033
|
+
m = np.vstack(stacked_cpg)
|
|
1034
|
+
m_raw = np.vstack(stacked_cpg_raw)
|
|
1035
|
+
panels.append(
|
|
1036
|
+
(
|
|
1037
|
+
"CpG",
|
|
1038
|
+
m,
|
|
1039
|
+
cpg_labels,
|
|
1040
|
+
cmap_cpg,
|
|
1041
|
+
_methylation_fraction_for_layer(m_raw, layer_cpg),
|
|
1042
|
+
n_xticks_cpg,
|
|
1043
|
+
)
|
|
1044
|
+
)
|
|
1045
|
+
|
|
1046
|
+
if stacked_any_a:
|
|
1047
|
+
m = np.vstack(stacked_any_a)
|
|
1048
|
+
m_raw = np.vstack(stacked_any_a_raw)
|
|
1049
|
+
panels.append(
|
|
1050
|
+
(
|
|
1051
|
+
"A",
|
|
1052
|
+
m,
|
|
1053
|
+
any_a_labels,
|
|
1054
|
+
cmap_a,
|
|
1055
|
+
_methylation_fraction_for_layer(m_raw, layer_a),
|
|
1056
|
+
n_xticks_a,
|
|
1057
|
+
)
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
# ---------------- plotting ----------------
|
|
1061
|
+
n_panels = len(panels)
|
|
1062
|
+
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
1063
|
+
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
1064
|
+
fig.suptitle(
|
|
1065
|
+
f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
|
|
1066
|
+
)
|
|
1067
|
+
|
|
1068
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
1069
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
1070
|
+
|
|
1071
|
+
for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
|
|
1072
|
+
# ---- your clean barplot ----
|
|
1073
|
+
clean_barplot(axes_bar[i], mean_vec, name)
|
|
1074
|
+
|
|
1075
|
+
# ---- heatmap ----
|
|
1076
|
+
heatmap_kwargs = dict(
|
|
1077
|
+
cmap=cmap,
|
|
1078
|
+
ax=axes_heat[i],
|
|
1079
|
+
yticklabels=False,
|
|
1080
|
+
cbar=False,
|
|
1081
|
+
)
|
|
1082
|
+
if name.startswith("HMM -"):
|
|
1083
|
+
heatmap_kwargs.update(vmin=0.0, vmax=1.0)
|
|
1084
|
+
sns.heatmap(matrix, **heatmap_kwargs)
|
|
1085
|
+
|
|
1086
|
+
# ---- xticks ----
|
|
1087
|
+
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
1088
|
+
axes_heat[i].set_xticks(xtick_pos)
|
|
1089
|
+
axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
|
|
1090
|
+
|
|
1091
|
+
for boundary in bin_boundaries[:-1]:
|
|
1092
|
+
axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
|
|
1093
|
+
|
|
1094
|
+
if overlay_variant_calls and ordered_obs_names:
|
|
1095
|
+
try:
|
|
1096
|
+
# Map panel sites from subset-local coordinates to full adata indices
|
|
1097
|
+
hmm_sites_global = _local_sites_to_global_indices(adata, subset, hmm_sites)
|
|
1098
|
+
any_c_sites_global = _local_sites_to_global_indices(
|
|
1099
|
+
adata, subset, any_c_sites
|
|
1100
|
+
)
|
|
1101
|
+
gpc_sites_global = _local_sites_to_global_indices(adata, subset, gpc_sites)
|
|
1102
|
+
cpg_sites_global = _local_sites_to_global_indices(adata, subset, cpg_sites)
|
|
1103
|
+
any_a_sites_global = _local_sites_to_global_indices(
|
|
1104
|
+
adata, subset, any_a_sites
|
|
1105
|
+
)
|
|
1106
|
+
|
|
1107
|
+
# Build panels_with_indices using site indices for each panel
|
|
1108
|
+
# Map panel names to their site index arrays
|
|
1109
|
+
name_to_sites = {
|
|
1110
|
+
f"HMM - {hmm_feature_layer}": hmm_sites_global,
|
|
1111
|
+
"C": any_c_sites_global,
|
|
1112
|
+
"GpC": gpc_sites_global,
|
|
1113
|
+
"CpG": cpg_sites_global,
|
|
1114
|
+
"A": any_a_sites_global,
|
|
1115
|
+
}
|
|
1116
|
+
panels_with_indices = []
|
|
1117
|
+
for idx, (name, *_rest) in enumerate(panels):
|
|
1118
|
+
sites = name_to_sites.get(name)
|
|
1119
|
+
if sites is not None and len(sites) > 0:
|
|
1120
|
+
panels_with_indices.append((axes_heat[idx], sites))
|
|
1121
|
+
if panels_with_indices:
|
|
1122
|
+
_overlay_variant_calls_on_panels(
|
|
1123
|
+
adata,
|
|
1124
|
+
ref,
|
|
1125
|
+
ordered_obs_names,
|
|
1126
|
+
panels_with_indices,
|
|
1127
|
+
seq1_color=variant_overlay_seq1_color,
|
|
1128
|
+
seq2_color=variant_overlay_seq2_color,
|
|
1129
|
+
marker_size=variant_overlay_marker_size,
|
|
1130
|
+
)
|
|
1131
|
+
except Exception as overlay_err:
|
|
1132
|
+
logger.warning(
|
|
1133
|
+
"Variant overlay failed for %s - %s: %s", sample, ref, overlay_err
|
|
1134
|
+
)
|
|
1135
|
+
|
|
1136
|
+
plt.tight_layout()
|
|
1137
|
+
|
|
1138
|
+
if save_path:
|
|
1139
|
+
save_path = Path(save_path)
|
|
1140
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
1141
|
+
safe_name = f"{ref}__{sample}".replace("/", "_")
|
|
1142
|
+
out_file = save_path / f"{safe_name}.png"
|
|
1143
|
+
plt.savefig(out_file, dpi=300)
|
|
1144
|
+
plt.close(fig)
|
|
1145
|
+
else:
|
|
1146
|
+
plt.show()
|
|
1147
|
+
|
|
1148
|
+
except Exception:
|
|
1149
|
+
import traceback
|
|
1150
|
+
|
|
1151
|
+
traceback.print_exc()
|
|
1152
|
+
continue
|
|
1153
|
+
|
|
1154
|
+
|
|
1155
|
+
def combined_hmm_length_clustermap(
|
|
1156
|
+
adata,
|
|
1157
|
+
sample_col: str = "Sample_Names",
|
|
1158
|
+
reference_col: str = "Reference_strand",
|
|
1159
|
+
length_layer: str = "hmm_combined_lengths",
|
|
1160
|
+
layer_gpc: str = "nan0_0minus1",
|
|
1161
|
+
layer_cpg: str = "nan0_0minus1",
|
|
1162
|
+
layer_c: str = "nan0_0minus1",
|
|
1163
|
+
layer_a: str = "nan0_0minus1",
|
|
1164
|
+
cmap_lengths: Any = "Greens",
|
|
1165
|
+
cmap_gpc: str = "coolwarm",
|
|
1166
|
+
cmap_cpg: str = "viridis",
|
|
1167
|
+
cmap_c: str = "coolwarm",
|
|
1168
|
+
cmap_a: str = "coolwarm",
|
|
1169
|
+
min_quality: int = 20,
|
|
1170
|
+
min_length: int = 200,
|
|
1171
|
+
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
1172
|
+
min_position_valid_fraction: float = 0.5,
|
|
1173
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
1174
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
1175
|
+
save_path: str | Path | None = None,
|
|
1176
|
+
sort_by: str = "gpc",
|
|
1177
|
+
bins: Optional[Dict[str, Any]] = None,
|
|
1178
|
+
deaminase: bool = False,
|
|
1179
|
+
min_signal: float = 0.0,
|
|
1180
|
+
n_xticks_lengths: int = 10,
|
|
1181
|
+
n_xticks_any_c: int = 8,
|
|
1182
|
+
n_xticks_gpc: int = 8,
|
|
1183
|
+
n_xticks_cpg: int = 8,
|
|
1184
|
+
n_xticks_a: int = 8,
|
|
1185
|
+
index_col_suffix: str | None = None,
|
|
1186
|
+
fill_nan_strategy: str = "value",
|
|
1187
|
+
fill_nan_value: float = -1,
|
|
1188
|
+
length_feature_ranges: Optional[Sequence[Tuple[int, int, Any]]] = None,
|
|
1189
|
+
overlay_variant_calls: bool = False,
|
|
1190
|
+
variant_overlay_seq1_color: str = "white",
|
|
1191
|
+
variant_overlay_seq2_color: str = "black",
|
|
1192
|
+
variant_overlay_marker_size: float = 4.0,
|
|
1193
|
+
):
|
|
1194
|
+
"""
|
|
1195
|
+
Plot clustermaps for length-encoded HMM feature layers with optional subclass colors.
|
|
1196
|
+
|
|
1197
|
+
Length-based feature ranges map integer lengths into subclass colors for accessible
|
|
1198
|
+
and footprint layers. Raw methylation panels are included when available.
|
|
1199
|
+
"""
|
|
1200
|
+
if fill_nan_strategy not in {"none", "value", "col_mean"}:
|
|
1201
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
1202
|
+
|
|
1203
|
+
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
1204
|
+
"""Pick tick indices/labels from an array."""
|
|
1205
|
+
if labels.size == 0:
|
|
1206
|
+
return [], []
|
|
1207
|
+
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
1208
|
+
idx = np.unique(idx)
|
|
1209
|
+
return idx.tolist(), labels[idx].tolist()
|
|
1210
|
+
|
|
1211
|
+
def _mask_or_true(series_name: str, predicate):
|
|
1212
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
1213
|
+
if series_name not in adata.obs:
|
|
1214
|
+
return pd.Series(True, index=adata.obs.index)
|
|
1215
|
+
s = adata.obs[series_name]
|
|
1216
|
+
try:
|
|
1217
|
+
return predicate(s)
|
|
1218
|
+
except Exception:
|
|
1219
|
+
return pd.Series(True, index=adata.obs.index)
|
|
1220
|
+
|
|
1221
|
+
results = []
|
|
1222
|
+
signal_type = "deamination" if deaminase else "methylation"
|
|
1223
|
+
feature_ranges = tuple(length_feature_ranges or ())
|
|
1224
|
+
|
|
1225
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
1226
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
1227
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
1228
|
+
qmask = _mask_or_true(
|
|
1229
|
+
"read_quality",
|
|
1230
|
+
(lambda s: s >= float(min_quality))
|
|
1231
|
+
if (min_quality is not None)
|
|
1232
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1233
|
+
)
|
|
1234
|
+
lm_mask = _mask_or_true(
|
|
1235
|
+
"mapped_length",
|
|
1236
|
+
(lambda s: s >= float(min_length))
|
|
1237
|
+
if (min_length is not None)
|
|
1238
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1239
|
+
)
|
|
1240
|
+
lrr_mask = _mask_or_true(
|
|
1241
|
+
"mapped_length_to_reference_length_ratio",
|
|
1242
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
1243
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
1244
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1245
|
+
)
|
|
1246
|
+
|
|
1247
|
+
demux_mask = _mask_or_true(
|
|
1248
|
+
"demux_type",
|
|
1249
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
1250
|
+
if (demux_types is not None)
|
|
1251
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1252
|
+
)
|
|
1253
|
+
|
|
1254
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
1255
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
1256
|
+
|
|
1257
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
1258
|
+
|
|
1259
|
+
if not bool(row_mask.any()):
|
|
1260
|
+
print(
|
|
1261
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
1262
|
+
)
|
|
1263
|
+
continue
|
|
1264
|
+
|
|
1265
|
+
try:
|
|
1266
|
+
subset = adata[row_mask, :].copy()
|
|
1267
|
+
|
|
1268
|
+
if min_position_valid_fraction is not None:
|
|
1269
|
+
valid_key = f"{ref}_valid_fraction"
|
|
1270
|
+
if valid_key in subset.var:
|
|
1271
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
1272
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
1273
|
+
if col_mask.any():
|
|
1274
|
+
subset = subset[:, col_mask].copy()
|
|
1275
|
+
else:
|
|
1276
|
+
print(
|
|
1277
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
1278
|
+
)
|
|
1279
|
+
continue
|
|
1280
|
+
|
|
1281
|
+
if subset.shape[0] == 0:
|
|
1282
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
1283
|
+
continue
|
|
1284
|
+
|
|
1285
|
+
if bins is None:
|
|
1286
|
+
bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
|
|
1287
|
+
else:
|
|
1288
|
+
bins_temp = bins
|
|
1289
|
+
|
|
1290
|
+
def _sites(*keys):
|
|
1291
|
+
"""Return indices for the first matching site key."""
|
|
1292
|
+
for k in keys:
|
|
1293
|
+
if k in subset.var:
|
|
1294
|
+
return np.where(subset.var[k].values)[0]
|
|
1295
|
+
return np.array([], dtype=int)
|
|
1296
|
+
|
|
1297
|
+
gpc_sites = _sites(f"{ref}_GpC_site")
|
|
1298
|
+
cpg_sites = _sites(f"{ref}_CpG_site")
|
|
1299
|
+
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
1300
|
+
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
1301
|
+
|
|
1302
|
+
length_sites = np.arange(subset.n_vars, dtype=int)
|
|
1303
|
+
length_labels = _select_labels(subset, length_sites, ref, index_col_suffix)
|
|
1304
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
1305
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
1306
|
+
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
1307
|
+
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
1308
|
+
|
|
1309
|
+
stacked_lengths = []
|
|
1310
|
+
stacked_lengths_raw = []
|
|
1311
|
+
stacked_any_c = []
|
|
1312
|
+
stacked_any_c_raw = []
|
|
1313
|
+
stacked_gpc = []
|
|
1314
|
+
stacked_gpc_raw = []
|
|
1315
|
+
stacked_cpg = []
|
|
1316
|
+
stacked_cpg_raw = []
|
|
1317
|
+
stacked_any_a = []
|
|
1318
|
+
stacked_any_a_raw = []
|
|
1319
|
+
ordered_obs_names = []
|
|
1320
|
+
|
|
1321
|
+
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
1322
|
+
total_reads = subset.n_obs
|
|
1323
|
+
percentages = {}
|
|
1324
|
+
last_idx = 0
|
|
1325
|
+
|
|
1326
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
1327
|
+
sb = subset[bin_filter].copy()
|
|
1328
|
+
n = sb.n_obs
|
|
1329
|
+
if n == 0:
|
|
1330
|
+
continue
|
|
1331
|
+
|
|
1332
|
+
pct = (n / total_reads) * 100 if total_reads else 0
|
|
1333
|
+
percentages[bin_label] = pct
|
|
1334
|
+
|
|
1335
|
+
if sort_by.startswith("obs:"):
|
|
1336
|
+
colname = sort_by.split("obs:")[1]
|
|
1337
|
+
order = np.argsort(sb.obs[colname].values)
|
|
1338
|
+
elif sort_by == "gpc" and gpc_sites.size:
|
|
1339
|
+
gpc_matrix = _layer_to_numpy(
|
|
1340
|
+
sb,
|
|
1341
|
+
layer_gpc,
|
|
1342
|
+
gpc_sites,
|
|
1343
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1344
|
+
fill_nan_value=fill_nan_value,
|
|
1345
|
+
)
|
|
1346
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
1347
|
+
order = sch.leaves_list(linkage)
|
|
1348
|
+
elif sort_by == "cpg" and cpg_sites.size:
|
|
1349
|
+
cpg_matrix = _layer_to_numpy(
|
|
1350
|
+
sb,
|
|
1351
|
+
layer_cpg,
|
|
1352
|
+
cpg_sites,
|
|
1353
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1354
|
+
fill_nan_value=fill_nan_value,
|
|
1355
|
+
)
|
|
1356
|
+
linkage = sch.linkage(cpg_matrix, method="ward")
|
|
1357
|
+
order = sch.leaves_list(linkage)
|
|
1358
|
+
elif sort_by == "c" and any_c_sites.size:
|
|
1359
|
+
any_c_matrix = _layer_to_numpy(
|
|
1360
|
+
sb,
|
|
1361
|
+
layer_c,
|
|
1362
|
+
any_c_sites,
|
|
1363
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1364
|
+
fill_nan_value=fill_nan_value,
|
|
1365
|
+
)
|
|
1366
|
+
linkage = sch.linkage(any_c_matrix, method="ward")
|
|
1367
|
+
order = sch.leaves_list(linkage)
|
|
1368
|
+
elif sort_by == "a" and any_a_sites.size:
|
|
1369
|
+
any_a_matrix = _layer_to_numpy(
|
|
1370
|
+
sb,
|
|
1371
|
+
layer_a,
|
|
1372
|
+
any_a_sites,
|
|
1373
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1374
|
+
fill_nan_value=fill_nan_value,
|
|
1375
|
+
)
|
|
1376
|
+
linkage = sch.linkage(any_a_matrix, method="ward")
|
|
1377
|
+
order = sch.leaves_list(linkage)
|
|
1378
|
+
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
1379
|
+
gpc_matrix = _layer_to_numpy(
|
|
1380
|
+
sb,
|
|
1381
|
+
layer_gpc,
|
|
1382
|
+
None,
|
|
1383
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1384
|
+
fill_nan_value=fill_nan_value,
|
|
1385
|
+
)
|
|
1386
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
1387
|
+
order = sch.leaves_list(linkage)
|
|
1388
|
+
elif sort_by == "hmm" and length_sites.size:
|
|
1389
|
+
length_matrix = _layer_to_numpy(
|
|
1390
|
+
sb,
|
|
1391
|
+
length_layer,
|
|
1392
|
+
length_sites,
|
|
1393
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1394
|
+
fill_nan_value=fill_nan_value,
|
|
1395
|
+
)
|
|
1396
|
+
linkage = sch.linkage(length_matrix, method="ward")
|
|
1397
|
+
order = sch.leaves_list(linkage)
|
|
1398
|
+
else:
|
|
1399
|
+
order = np.arange(n)
|
|
1400
|
+
|
|
1401
|
+
sb = sb[order]
|
|
1402
|
+
ordered_obs_names.extend(sb.obs_names.tolist())
|
|
1403
|
+
|
|
1404
|
+
stacked_lengths.append(
|
|
1405
|
+
_layer_to_numpy(
|
|
1406
|
+
sb,
|
|
1407
|
+
length_layer,
|
|
1408
|
+
None,
|
|
1409
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1410
|
+
fill_nan_value=fill_nan_value,
|
|
1411
|
+
)
|
|
1412
|
+
)
|
|
1413
|
+
stacked_lengths_raw.append(
|
|
1414
|
+
_layer_to_numpy(
|
|
1415
|
+
sb,
|
|
1416
|
+
length_layer,
|
|
1417
|
+
None,
|
|
1418
|
+
fill_nan_strategy="none",
|
|
1419
|
+
fill_nan_value=fill_nan_value,
|
|
1420
|
+
)
|
|
1421
|
+
)
|
|
1422
|
+
if any_c_sites.size:
|
|
1423
|
+
stacked_any_c.append(
|
|
1424
|
+
_layer_to_numpy(
|
|
1425
|
+
sb,
|
|
1426
|
+
layer_c,
|
|
1427
|
+
any_c_sites,
|
|
1428
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1429
|
+
fill_nan_value=fill_nan_value,
|
|
1430
|
+
)
|
|
1431
|
+
)
|
|
1432
|
+
stacked_any_c_raw.append(
|
|
1433
|
+
_layer_to_numpy(
|
|
1434
|
+
sb,
|
|
1435
|
+
layer_c,
|
|
1436
|
+
any_c_sites,
|
|
1437
|
+
fill_nan_strategy="none",
|
|
1438
|
+
fill_nan_value=fill_nan_value,
|
|
1439
|
+
)
|
|
1440
|
+
)
|
|
1441
|
+
if gpc_sites.size:
|
|
1442
|
+
stacked_gpc.append(
|
|
1443
|
+
_layer_to_numpy(
|
|
1444
|
+
sb,
|
|
1445
|
+
layer_gpc,
|
|
1446
|
+
gpc_sites,
|
|
1447
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1448
|
+
fill_nan_value=fill_nan_value,
|
|
1449
|
+
)
|
|
1450
|
+
)
|
|
1451
|
+
stacked_gpc_raw.append(
|
|
1452
|
+
_layer_to_numpy(
|
|
1453
|
+
sb,
|
|
1454
|
+
layer_gpc,
|
|
1455
|
+
gpc_sites,
|
|
1456
|
+
fill_nan_strategy="none",
|
|
1457
|
+
fill_nan_value=fill_nan_value,
|
|
1458
|
+
)
|
|
1459
|
+
)
|
|
1460
|
+
if cpg_sites.size:
|
|
1461
|
+
stacked_cpg.append(
|
|
1462
|
+
_layer_to_numpy(
|
|
1463
|
+
sb,
|
|
1464
|
+
layer_cpg,
|
|
1465
|
+
cpg_sites,
|
|
1466
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1467
|
+
fill_nan_value=fill_nan_value,
|
|
1468
|
+
)
|
|
1469
|
+
)
|
|
1470
|
+
stacked_cpg_raw.append(
|
|
1471
|
+
_layer_to_numpy(
|
|
1472
|
+
sb,
|
|
1473
|
+
layer_cpg,
|
|
1474
|
+
cpg_sites,
|
|
1475
|
+
fill_nan_strategy="none",
|
|
1476
|
+
fill_nan_value=fill_nan_value,
|
|
1477
|
+
)
|
|
1478
|
+
)
|
|
1479
|
+
if any_a_sites.size:
|
|
1480
|
+
stacked_any_a.append(
|
|
1481
|
+
_layer_to_numpy(
|
|
1482
|
+
sb,
|
|
1483
|
+
layer_a,
|
|
1484
|
+
any_a_sites,
|
|
1485
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1486
|
+
fill_nan_value=fill_nan_value,
|
|
1487
|
+
)
|
|
1488
|
+
)
|
|
1489
|
+
stacked_any_a_raw.append(
|
|
1490
|
+
_layer_to_numpy(
|
|
1491
|
+
sb,
|
|
1492
|
+
layer_a,
|
|
1493
|
+
any_a_sites,
|
|
1494
|
+
fill_nan_strategy="none",
|
|
1495
|
+
fill_nan_value=fill_nan_value,
|
|
1496
|
+
)
|
|
1497
|
+
)
|
|
1498
|
+
|
|
1499
|
+
row_labels.extend([bin_label] * n)
|
|
1500
|
+
bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
|
|
1501
|
+
last_idx += n
|
|
1502
|
+
bin_boundaries.append(last_idx)
|
|
1503
|
+
|
|
1504
|
+
length_matrix = np.vstack(stacked_lengths)
|
|
1505
|
+
length_matrix_raw = np.vstack(stacked_lengths_raw)
|
|
1506
|
+
capped_lengths = np.where(length_matrix_raw > 1, 1.0, length_matrix_raw)
|
|
1507
|
+
mean_lengths = np.nanmean(capped_lengths, axis=0)
|
|
1508
|
+
length_plot_matrix = length_matrix_raw
|
|
1509
|
+
length_plot_cmap = cmap_lengths
|
|
1510
|
+
length_plot_norm = None
|
|
1511
|
+
|
|
1512
|
+
if feature_ranges:
|
|
1513
|
+
length_plot_matrix = _map_length_matrix_to_subclasses(
|
|
1514
|
+
length_matrix_raw, feature_ranges
|
|
1515
|
+
)
|
|
1516
|
+
length_plot_cmap, length_plot_norm = _build_length_feature_cmap(feature_ranges)
|
|
1517
|
+
|
|
1518
|
+
panels = [
|
|
1519
|
+
(
|
|
1520
|
+
f"HMM lengths - {length_layer}",
|
|
1521
|
+
length_plot_matrix,
|
|
1522
|
+
length_labels,
|
|
1523
|
+
length_plot_cmap,
|
|
1524
|
+
mean_lengths,
|
|
1525
|
+
n_xticks_lengths,
|
|
1526
|
+
length_plot_norm,
|
|
1527
|
+
),
|
|
1528
|
+
]
|
|
1529
|
+
|
|
1530
|
+
if stacked_any_c:
|
|
1531
|
+
m = np.vstack(stacked_any_c)
|
|
1532
|
+
m_raw = np.vstack(stacked_any_c_raw)
|
|
1533
|
+
panels.append(
|
|
1534
|
+
(
|
|
1535
|
+
"C",
|
|
1536
|
+
m,
|
|
1537
|
+
any_c_labels,
|
|
1538
|
+
cmap_c,
|
|
1539
|
+
_methylation_fraction_for_layer(m_raw, layer_c),
|
|
1540
|
+
n_xticks_any_c,
|
|
1541
|
+
None,
|
|
1542
|
+
)
|
|
1543
|
+
)
|
|
1544
|
+
|
|
1545
|
+
if stacked_gpc:
|
|
1546
|
+
m = np.vstack(stacked_gpc)
|
|
1547
|
+
m_raw = np.vstack(stacked_gpc_raw)
|
|
1548
|
+
panels.append(
|
|
1549
|
+
(
|
|
1550
|
+
"GpC",
|
|
1551
|
+
m,
|
|
1552
|
+
gpc_labels,
|
|
1553
|
+
cmap_gpc,
|
|
1554
|
+
_methylation_fraction_for_layer(m_raw, layer_gpc),
|
|
1555
|
+
n_xticks_gpc,
|
|
1556
|
+
None,
|
|
1557
|
+
)
|
|
1558
|
+
)
|
|
1559
|
+
|
|
1560
|
+
if stacked_cpg:
|
|
1561
|
+
m = np.vstack(stacked_cpg)
|
|
1562
|
+
m_raw = np.vstack(stacked_cpg_raw)
|
|
1563
|
+
panels.append(
|
|
1564
|
+
(
|
|
1565
|
+
"CpG",
|
|
1566
|
+
m,
|
|
1567
|
+
cpg_labels,
|
|
1568
|
+
cmap_cpg,
|
|
1569
|
+
_methylation_fraction_for_layer(m_raw, layer_cpg),
|
|
1570
|
+
n_xticks_cpg,
|
|
1571
|
+
None,
|
|
1572
|
+
)
|
|
1573
|
+
)
|
|
1574
|
+
|
|
1575
|
+
if stacked_any_a:
|
|
1576
|
+
m = np.vstack(stacked_any_a)
|
|
1577
|
+
m_raw = np.vstack(stacked_any_a_raw)
|
|
1578
|
+
panels.append(
|
|
1579
|
+
(
|
|
1580
|
+
"A",
|
|
1581
|
+
m,
|
|
1582
|
+
any_a_labels,
|
|
1583
|
+
cmap_a,
|
|
1584
|
+
_methylation_fraction_for_layer(m_raw, layer_a),
|
|
1585
|
+
n_xticks_a,
|
|
1586
|
+
None,
|
|
1587
|
+
)
|
|
1588
|
+
)
|
|
1589
|
+
|
|
1590
|
+
n_panels = len(panels)
|
|
1591
|
+
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
1592
|
+
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
1593
|
+
fig.suptitle(
|
|
1594
|
+
f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
|
|
1595
|
+
)
|
|
1596
|
+
|
|
1597
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
1598
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
1599
|
+
|
|
1600
|
+
for i, (name, matrix, labels, cmap, mean_vec, n_ticks, norm) in enumerate(panels):
|
|
1601
|
+
clean_barplot(axes_bar[i], mean_vec, name)
|
|
1602
|
+
|
|
1603
|
+
heatmap_kwargs = dict(
|
|
1604
|
+
cmap=cmap,
|
|
1605
|
+
ax=axes_heat[i],
|
|
1606
|
+
yticklabels=False,
|
|
1607
|
+
cbar=False,
|
|
1608
|
+
)
|
|
1609
|
+
if norm is not None:
|
|
1610
|
+
heatmap_kwargs["norm"] = norm
|
|
1611
|
+
sns.heatmap(matrix, **heatmap_kwargs)
|
|
1612
|
+
|
|
1613
|
+
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
1614
|
+
axes_heat[i].set_xticks(xtick_pos)
|
|
1615
|
+
axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
|
|
1616
|
+
|
|
1617
|
+
for boundary in bin_boundaries[:-1]:
|
|
1618
|
+
axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
|
|
1619
|
+
|
|
1620
|
+
if overlay_variant_calls and ordered_obs_names:
|
|
1621
|
+
try:
|
|
1622
|
+
# Map panel sites from subset-local coordinates to full adata indices
|
|
1623
|
+
length_sites_global = _local_sites_to_global_indices(
|
|
1624
|
+
adata, subset, length_sites
|
|
1625
|
+
)
|
|
1626
|
+
any_c_sites_global = _local_sites_to_global_indices(
|
|
1627
|
+
adata, subset, any_c_sites
|
|
1628
|
+
)
|
|
1629
|
+
gpc_sites_global = _local_sites_to_global_indices(adata, subset, gpc_sites)
|
|
1630
|
+
cpg_sites_global = _local_sites_to_global_indices(adata, subset, cpg_sites)
|
|
1631
|
+
any_a_sites_global = _local_sites_to_global_indices(
|
|
1632
|
+
adata, subset, any_a_sites
|
|
1633
|
+
)
|
|
1634
|
+
|
|
1635
|
+
# Build panels_with_indices using site indices for each panel
|
|
1636
|
+
name_to_sites = {
|
|
1637
|
+
f"HMM lengths - {length_layer}": length_sites_global,
|
|
1638
|
+
"C": any_c_sites_global,
|
|
1639
|
+
"GpC": gpc_sites_global,
|
|
1640
|
+
"CpG": cpg_sites_global,
|
|
1641
|
+
"A": any_a_sites_global,
|
|
1642
|
+
}
|
|
1643
|
+
panels_with_indices = []
|
|
1644
|
+
for idx, (name, *_rest) in enumerate(panels):
|
|
1645
|
+
sites = name_to_sites.get(name)
|
|
1646
|
+
if sites is not None and len(sites) > 0:
|
|
1647
|
+
panels_with_indices.append((axes_heat[idx], sites))
|
|
1648
|
+
if panels_with_indices:
|
|
1649
|
+
_overlay_variant_calls_on_panels(
|
|
1650
|
+
adata,
|
|
1651
|
+
ref,
|
|
1652
|
+
ordered_obs_names,
|
|
1653
|
+
panels_with_indices,
|
|
1654
|
+
seq1_color=variant_overlay_seq1_color,
|
|
1655
|
+
seq2_color=variant_overlay_seq2_color,
|
|
1656
|
+
marker_size=variant_overlay_marker_size,
|
|
1657
|
+
)
|
|
1658
|
+
except Exception as overlay_err:
|
|
1659
|
+
logger.warning(
|
|
1660
|
+
"Variant overlay failed for %s - %s: %s", sample, ref, overlay_err
|
|
1661
|
+
)
|
|
1662
|
+
|
|
1663
|
+
plt.tight_layout()
|
|
1664
|
+
|
|
1665
|
+
if save_path:
|
|
1666
|
+
save_path = Path(save_path)
|
|
1667
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
1668
|
+
safe_name = f"{ref}__{sample}".replace("/", "_")
|
|
1669
|
+
out_file = save_path / f"{safe_name}.png"
|
|
1670
|
+
plt.savefig(out_file, dpi=300)
|
|
1671
|
+
plt.close(fig)
|
|
1672
|
+
else:
|
|
1673
|
+
plt.show()
|
|
1674
|
+
|
|
1675
|
+
results.append((sample, ref))
|
|
1676
|
+
|
|
1677
|
+
except Exception:
|
|
1678
|
+
import traceback
|
|
1679
|
+
|
|
1680
|
+
traceback.print_exc()
|
|
1681
|
+
print(f"Failed {sample} - {ref} - {length_layer}")
|
|
1682
|
+
|
|
1683
|
+
return results
|
|
1684
|
+
|
|
1685
|
+
|
|
1686
|
+
def plot_hmm_layers_rolling_by_sample_ref(
|
|
1687
|
+
adata,
|
|
1688
|
+
layers: Optional[Sequence[str]] = None,
|
|
1689
|
+
sample_col: str = "Barcode",
|
|
1690
|
+
ref_col: str = "Reference_strand",
|
|
1691
|
+
samples: Optional[Sequence[str]] = None,
|
|
1692
|
+
references: Optional[Sequence[str]] = None,
|
|
1693
|
+
window: int = 51,
|
|
1694
|
+
min_periods: int = 1,
|
|
1695
|
+
center: bool = True,
|
|
1696
|
+
rows_per_page: int = 6,
|
|
1697
|
+
figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
|
|
1698
|
+
dpi: int = 160,
|
|
1699
|
+
output_dir: Optional[str] = None,
|
|
1700
|
+
save: bool = True,
|
|
1701
|
+
show_raw: bool = False,
|
|
1702
|
+
cmap: str = "tab20",
|
|
1703
|
+
layer_colors: Optional[Mapping[str, Any]] = None,
|
|
1704
|
+
use_var_coords: bool = True,
|
|
1705
|
+
reindexed_var_suffix: str = "reindexed",
|
|
1706
|
+
):
|
|
1707
|
+
"""
|
|
1708
|
+
For each sample (row) and reference (col) plot the rolling average of the
|
|
1709
|
+
positional mean (mean across reads) for each layer listed.
|
|
1710
|
+
|
|
1711
|
+
Parameters
|
|
1712
|
+
----------
|
|
1713
|
+
adata : AnnData
|
|
1714
|
+
Input annotated data (expects obs columns sample_col and ref_col).
|
|
1715
|
+
layers : list[str] | None
|
|
1716
|
+
Which adata.layers to plot. If None, attempts to autodetect layers whose
|
|
1717
|
+
matrices look like "HMM" outputs (else will error). If None and layers
|
|
1718
|
+
cannot be found, user must pass a list.
|
|
1719
|
+
sample_col, ref_col : str
|
|
1720
|
+
obs columns used to group rows.
|
|
1721
|
+
samples, references : optional lists
|
|
1722
|
+
explicit ordering of samples / references. If None, categories in adata.obs are used.
|
|
1723
|
+
window : int
|
|
1724
|
+
rolling window size (odd recommended). If window <= 1, no smoothing applied.
|
|
1725
|
+
min_periods : int
|
|
1726
|
+
min periods param for pd.Series.rolling.
|
|
1727
|
+
center : bool
|
|
1728
|
+
center the rolling window.
|
|
1729
|
+
rows_per_page : int
|
|
1730
|
+
paginate rows per page into multiple figures if needed.
|
|
1731
|
+
figsize_per_cell : (w,h)
|
|
1732
|
+
per-subplot size in inches.
|
|
1733
|
+
dpi : int
|
|
1734
|
+
figure dpi when saving.
|
|
1735
|
+
output_dir : str | None
|
|
1736
|
+
directory to save pages; created if necessary. If None and save=True, uses cwd.
|
|
1737
|
+
save : bool
|
|
1738
|
+
whether to save PNG files.
|
|
1739
|
+
show_raw : bool
|
|
1740
|
+
draw unsmoothed mean as faint line under smoothed curve.
|
|
1741
|
+
cmap : str
|
|
1742
|
+
matplotlib colormap for layer lines.
|
|
1743
|
+
layer_colors : dict[str, Any] | None
|
|
1744
|
+
Optional mapping of layer name to explicit line colors.
|
|
1745
|
+
use_var_coords : bool
|
|
1746
|
+
if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
|
|
1747
|
+
reindexed_var_suffix : str
|
|
1748
|
+
Suffix for per-reference reindexed var columns (e.g., ``Reference_reindexed``) used when available.
|
|
1749
|
+
|
|
1750
|
+
Returns
|
|
1751
|
+
-------
|
|
1752
|
+
saved_files : list[str]
|
|
1753
|
+
list of saved filenames (may be empty if save=False).
|
|
1754
|
+
"""
|
|
1755
|
+
    logger.info("Plotting rolling HMM layers by sample/ref.")

    if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
        raise ValueError(
            f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
        )

    if samples is None:
        sseries = adata.obs[sample_col]
        if not isinstance(sseries.dtype, pd.CategoricalDtype):
            sseries = sseries.astype("category")
        samples_all = list(sseries.cat.categories)
    else:
        samples_all = list(samples)

    if references is None:
        rseries = adata.obs[ref_col]
        if not isinstance(rseries.dtype, pd.CategoricalDtype):
            rseries = rseries.astype("category")
        refs_all = list(rseries.cat.categories)
    else:
        refs_all = list(references)

    if layers is None:
        layers = list(adata.layers.keys())
        if len(layers) == 0:
            raise ValueError(
                "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
            )
    layers = list(layers)

    x_labels = None
    try:
        if use_var_coords:
            x_coords = np.array([int(v) for v in adata.var_names])
        else:
            raise Exception("user disabled var coords")
    except Exception:
        x_coords = np.arange(adata.shape[1], dtype=int)
        x_labels = adata.var_names.astype(str).tolist()

    ref_reindexed_cols = {
        ref: f"{ref}_{reindexed_var_suffix}"
        for ref in refs_all
        if f"{ref}_{reindexed_var_suffix}" in adata.var
    }

    if save:
        outdir = output_dir or os.getcwd()
        os.makedirs(outdir, exist_ok=True)
    else:
        outdir = None

    n_samples = len(samples_all)
    n_refs = len(refs_all)
    total_pages = math.ceil(n_samples / rows_per_page)
    saved_files = []

    cmap_obj = plt.get_cmap(cmap)
    n_layers = max(1, len(layers))
    fallback_colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
    layer_colors = layer_colors or {}
    colors = [layer_colors.get(layer, fallback_colors[idx]) for idx, layer in enumerate(layers)]

    for page in range(total_pages):
        start = page * rows_per_page
        end = min(start + rows_per_page, n_samples)
        chunk = samples_all[start:end]
        nrows = len(chunk)
        ncols = n_refs

        fig_w = figsize_per_cell[0] * ncols
        fig_h = figsize_per_cell[1] * nrows
        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
        )

        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(refs_all):
                ax = axes[r_idx][c_idx]

                mask = (adata.obs[sample_col].values == sample_name) & (
                    adata.obs[ref_col].values == ref_name
                )
                sub = adata[mask]
                if sub.n_obs == 0:
                    ax.text(
                        0.5,
                        0.5,
                        "No reads",
                        ha="center",
                        va="center",
                        transform=ax.transAxes,
                        color="gray",
                    )
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if r_idx == 0:
                        ax.set_title(str(ref_name), fontsize=9)
                    if c_idx == 0:
                        total_reads = int((adata.obs[sample_col] == sample_name).sum())
                        ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                    continue

                plotted_any = False
                reindexed_col = ref_reindexed_cols.get(ref_name)
                if reindexed_col is not None:
                    try:
                        ref_coords = np.asarray(adata.var[reindexed_col], dtype=int)
                    except Exception:
                        ref_coords = x_coords
                else:
                    ref_coords = x_coords
                for li, layer in enumerate(layers):
                    if layer in sub.layers:
                        mat = sub.layers[layer]
                    else:
                        if layer == layers[0] and getattr(sub, "X", None) is not None:
                            mat = sub.X
                        else:
                            continue

                    if hasattr(mat, "toarray"):
                        try:
                            arr = mat.toarray()
                        except Exception:
                            arr = np.asarray(mat)
                    else:
                        arr = np.asarray(mat)

                    if arr.size == 0 or arr.shape[1] == 0:
                        continue

                    arr = arr.astype(float)
                    with np.errstate(all="ignore"):
                        col_mean = np.nanmean(arr, axis=0)

                    if np.all(np.isnan(col_mean)):
                        continue

                    valid_mask = np.isfinite(col_mean)

                    if (window is None) or (window <= 1):
                        smoothed = col_mean
                    else:
                        ser = pd.Series(col_mean)
                        smoothed = (
                            ser.rolling(window=window, min_periods=min_periods, center=center)
                            .mean()
                            .to_numpy()
                        )
                        smoothed = np.where(valid_mask, smoothed, np.nan)

                    L = len(col_mean)
                    x = ref_coords[:L]

                    if show_raw:
                        ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)

                    ax.plot(
                        x,
                        smoothed[:L],
                        label=layer,
                        color=colors[li],
                        linewidth=1.2,
                        alpha=0.95,
                        zorder=2,
                    )
                    plotted_any = True

                if r_idx == 0:
                    ax.set_title(str(ref_name), fontsize=9)
                if c_idx == 0:
                    total_reads = int((adata.obs[sample_col] == sample_name).sum())
                    ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
                if r_idx == nrows - 1:
                    ax.set_xlabel("position", fontsize=8)
                    if x_labels is not None and reindexed_col is None:
                        max_ticks = 8
                        tick_step = max(1, int(math.ceil(len(x_labels) / max_ticks)))
                        tick_positions = x_coords[::tick_step]
                        tick_labels = x_labels[::tick_step]
                        ax.set_xticks(tick_positions)
                        ax.set_xticklabels(tick_labels, fontsize=7, rotation=45, ha="right")

                if (r_idx == 0 and c_idx == 0) and plotted_any:
                    ax.legend(fontsize=7, loc="upper right")

                ax.grid(True, alpha=0.2)

        fig.suptitle(
            f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
            fontsize=11,
            y=0.995,
        )
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        if save:
            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
            plt.savefig(fname, bbox_inches="tight", dpi=dpi)
            saved_files.append(fname)
            logger.info("Saved HMM layers rolling plot to %s.", fname)
        else:
            plt.show()
        plt.close(fig)

    return saved_files
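
A minimal usage sketch for the function above. The AnnData ``adata``, the layer names, and the output directory are hypothetical placeholders; the keyword arguments are the ones defined in the signature:

    # Plot per-position rolling means of two (hypothetical) HMM layers,
    # one grid row per barcode and one column per reference, saved as PNG pages.
    saved = plot_hmm_layers_rolling_by_sample_ref(
        adata,
        layers=["hmm_accessible", "hmm_footprint"],  # hypothetical layer names
        sample_col="Barcode",
        ref_col="Reference_strand",
        window=101,             # centered rolling window over positions
        rows_per_page=6,
        output_dir="hmm_rolling_plots",
        save=True,
        show_raw=True,          # also draw the unsmoothed positional mean
    )
    print(saved)                # list of saved PNG paths, one per page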
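
The smoothing itself is a plain centered rolling mean over the per-position read means, with NaN positions preserved. A self-contained sketch of that step with toy numbers, independent of the package:

    import numpy as np
    import pandas as pd

    col_mean = np.array([0.10, 0.20, np.nan, 0.40, 0.50])  # mean across reads per position
    valid = np.isfinite(col_mean)
    smoothed = (
        pd.Series(col_mean)
        .rolling(window=3, min_periods=1, center=True)
        .mean()
        .to_numpy()
    )
    smoothed = np.where(valid, smoothed, np.nan)  # keep NaN where no data was present
    # smoothed -> [0.15, 0.15, nan, 0.45, 0.45]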