smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +39 -7
- smftools/_settings.py +2 -0
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +34 -6
- smftools/cli/hmm_adata.py +239 -33
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +167 -131
- smftools/cli/preprocess_adata.py +180 -53
- smftools/cli/spatial_adata.py +152 -100
- smftools/cli_entry.py +38 -1
- smftools/config/__init__.py +2 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +42 -2
- smftools/config/experiment_config.py +59 -1
- smftools/constants.py +65 -0
- smftools/datasets/__init__.py +2 -0
- smftools/hmm/HMM.py +97 -3
- smftools/hmm/__init__.py +24 -13
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +2 -0
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +5 -2
- smftools/hmm/display_hmm.py +4 -1
- smftools/hmm/hmm_readwrite.py +7 -2
- smftools/hmm/nucleosome_hmm_refinement.py +2 -0
- smftools/informatics/__init__.py +59 -34
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +2 -0
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1093 -176
- smftools/informatics/basecalling.py +2 -0
- smftools/informatics/bed_functions.py +271 -61
- smftools/informatics/binarize_converted_base_identities.py +3 -0
- smftools/informatics/complement_base_list.py +2 -0
- smftools/informatics/converted_BAM_to_adata.py +641 -176
- smftools/informatics/fasta_functions.py +94 -10
- smftools/informatics/h5ad_functions.py +123 -4
- smftools/informatics/modkit_extract_to_adata.py +1019 -431
- smftools/informatics/modkit_functions.py +2 -0
- smftools/informatics/ohe.py +2 -0
- smftools/informatics/pod5_functions.py +3 -2
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/machine_learning/__init__.py +22 -6
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +18 -4
- smftools/machine_learning/data/preprocessing.py +2 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +2 -0
- smftools/machine_learning/evaluation/evaluators.py +14 -9
- smftools/machine_learning/inference/__init__.py +2 -0
- smftools/machine_learning/inference/inference_utils.py +2 -0
- smftools/machine_learning/inference/lightning_inference.py +6 -1
- smftools/machine_learning/inference/sklearn_inference.py +2 -0
- smftools/machine_learning/inference/sliding_window_inference.py +2 -0
- smftools/machine_learning/models/__init__.py +2 -0
- smftools/machine_learning/models/base.py +7 -2
- smftools/machine_learning/models/cnn.py +7 -2
- smftools/machine_learning/models/lightning_base.py +16 -11
- smftools/machine_learning/models/mlp.py +5 -1
- smftools/machine_learning/models/positional.py +7 -2
- smftools/machine_learning/models/rnn.py +5 -1
- smftools/machine_learning/models/sklearn_models.py +14 -9
- smftools/machine_learning/models/transformer.py +7 -2
- smftools/machine_learning/models/wrappers.py +6 -2
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +13 -3
- smftools/machine_learning/training/train_sklearn_model.py +2 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +5 -1
- smftools/machine_learning/utils/grl.py +5 -1
- smftools/metadata.py +1 -1
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +41 -31
- smftools/plotting/autocorrelation_plotting.py +9 -5
- smftools/plotting/classifiers.py +16 -4
- smftools/plotting/general_plotting.py +2415 -629
- smftools/plotting/hmm_plotting.py +97 -9
- smftools/plotting/position_stats.py +15 -7
- smftools/plotting/qc_plotting.py +6 -1
- smftools/preprocessing/__init__.py +36 -37
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/archived/calculate_complexity.py +2 -0
- smftools/preprocessing/archived/mark_duplicates.py +2 -0
- smftools/preprocessing/archived/preprocessing.py +2 -0
- smftools/preprocessing/archived/remove_duplicates.py +2 -0
- smftools/preprocessing/binary_layers_to_ohe.py +2 -1
- smftools/preprocessing/calculate_complexity_II.py +4 -1
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_pairwise_differences.py +2 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
- smftools/preprocessing/calculate_position_Youden.py +9 -2
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
- smftools/preprocessing/flag_duplicate_reads.py +42 -54
- smftools/preprocessing/make_dirs.py +2 -1
- smftools/preprocessing/min_non_diagonal.py +2 -0
- smftools/preprocessing/recipes.py +2 -0
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +30 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +2 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +2 -0
- smftools/tools/archived/subset_adata_v2.py +2 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +93 -8
- smftools/tools/cluster_adata_on_methylation.py +7 -1
- smftools/tools/position_stats.py +17 -27
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
- smftools-0.3.1.dist-info/RECORD +189 -0
- smftools-0.2.5.dist-info/RECORD +0 -181
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,16 +1,37 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import ast
|
|
4
|
+
import json
|
|
3
5
|
import math
|
|
4
6
|
import os
|
|
5
7
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
7
9
|
|
|
8
|
-
import matplotlib.gridspec as gridspec
|
|
9
|
-
import matplotlib.pyplot as plt
|
|
10
10
|
import numpy as np
|
|
11
11
|
import pandas as pd
|
|
12
12
|
import scipy.cluster.hierarchy as sch
|
|
13
|
-
|
|
13
|
+
|
|
14
|
+
from smftools.logging_utils import get_logger
|
|
15
|
+
from smftools.optional_imports import require
|
|
16
|
+
|
|
17
|
+
colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
|
|
18
|
+
gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
|
|
19
|
+
patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
|
|
20
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
|
|
21
|
+
sns = require("seaborn", extra="plotting", purpose="plot styling")
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
DNA_5COLOR_PALETTE = {
|
|
26
|
+
"A": "#00A000", # green
|
|
27
|
+
"C": "#0000FF", # blue
|
|
28
|
+
"G": "#FF7F00", # orange
|
|
29
|
+
"T": "#FF0000", # red
|
|
30
|
+
"OTHER": "#808080", # gray (N, PAD, unknown)
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
import anndata as ad
|
|
14
35
|
|
|
15
36
|
|
|
16
37
|
def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
|
|
@@ -68,7 +89,7 @@ def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix:
|
|
|
68
89
|
return labels[sites]
|
|
69
90
|
|
|
70
91
|
|
|
71
|
-
def normalized_mean(matrix: np.ndarray) -> np.ndarray:
|
|
92
|
+
def normalized_mean(matrix: np.ndarray, *, ignore_nan: bool = True) -> np.ndarray:
|
|
72
93
|
"""Compute normalized column means for a matrix.
|
|
73
94
|
|
|
74
95
|
Args:
|
|
@@ -77,19 +98,362 @@ def normalized_mean(matrix: np.ndarray) -> np.ndarray:
|
|
|
77
98
|
Returns:
|
|
78
99
|
1D array of normalized means.
|
|
79
100
|
"""
|
|
80
|
-
mean = np.nanmean(matrix, axis=0)
|
|
101
|
+
mean = np.nanmean(matrix, axis=0) if ignore_nan else np.mean(matrix, axis=0)
|
|
81
102
|
denom = (mean.max() - mean.min()) + 1e-9
|
|
82
103
|
return (mean - mean.min()) / denom
|
|
83
104
|
|
|
84
105
|
|
|
85
|
-
def
|
|
106
|
+
def plot_nmf_components(
|
|
107
|
+
adata: "ad.AnnData",
|
|
108
|
+
*,
|
|
109
|
+
output_dir: Path | str,
|
|
110
|
+
components_key: str = "H_nmf",
|
|
111
|
+
heatmap_name: str = "nmf_H_heatmap.png",
|
|
112
|
+
lineplot_name: str = "nmf_H_lineplot.png",
|
|
113
|
+
max_features: int = 2000,
|
|
114
|
+
) -> Dict[str, Path]:
|
|
115
|
+
"""Plot NMF component weights as a heatmap and per-component line plot.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
adata: AnnData object containing NMF results.
|
|
119
|
+
output_dir: Directory to write plots into.
|
|
120
|
+
components_key: Key in ``adata.varm`` storing the H matrix.
|
|
121
|
+
heatmap_name: Filename for the heatmap plot.
|
|
122
|
+
lineplot_name: Filename for the line plot.
|
|
123
|
+
max_features: Maximum number of features to plot (top-weighted by component).
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
Dict[str, Path]: Paths to created plots (keys: ``heatmap`` and ``lineplot``).
|
|
127
|
+
"""
|
|
128
|
+
if components_key not in adata.varm:
|
|
129
|
+
logger.warning("NMF components key '%s' not found in adata.varm.", components_key)
|
|
130
|
+
return {}
|
|
131
|
+
|
|
132
|
+
output_path = Path(output_dir)
|
|
133
|
+
output_path.mkdir(parents=True, exist_ok=True)
|
|
134
|
+
|
|
135
|
+
components = np.asarray(adata.varm[components_key])
|
|
136
|
+
if components.ndim != 2:
|
|
137
|
+
raise ValueError(f"NMF components must be 2D; got shape {components.shape}.")
|
|
138
|
+
|
|
139
|
+
feature_labels = (
|
|
140
|
+
np.asarray(adata.var_names).astype(str)
|
|
141
|
+
if adata.shape[1] == components.shape[0]
|
|
142
|
+
else np.array([str(i) for i in range(components.shape[0])])
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
nonzero_mask = np.any(components != 0, axis=1)
|
|
146
|
+
if not np.any(nonzero_mask):
|
|
147
|
+
logger.warning("NMF components are all zeros; skipping plot generation.")
|
|
148
|
+
return {}
|
|
149
|
+
|
|
150
|
+
components = components[nonzero_mask]
|
|
151
|
+
feature_labels = feature_labels[nonzero_mask]
|
|
152
|
+
|
|
153
|
+
if max_features and components.shape[0] > max_features:
|
|
154
|
+
scores = np.nanmax(components, axis=1)
|
|
155
|
+
top_idx = np.argsort(scores)[-max_features:]
|
|
156
|
+
top_idx = np.sort(top_idx)
|
|
157
|
+
components = components[top_idx]
|
|
158
|
+
feature_labels = feature_labels[top_idx]
|
|
159
|
+
logger.info(
|
|
160
|
+
"Downsampled NMF features from %s to %s for plotting.",
|
|
161
|
+
nonzero_mask.sum(),
|
|
162
|
+
components.shape[0],
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
n_features, n_components = components.shape
|
|
166
|
+
component_labels = [f"C{i + 1}" for i in range(n_components)]
|
|
167
|
+
|
|
168
|
+
heatmap_width = max(8, min(20, n_features / 60))
|
|
169
|
+
heatmap_height = max(2.5, 0.6 * n_components + 1.5)
|
|
170
|
+
fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
|
|
171
|
+
sns.heatmap(
|
|
172
|
+
components.T,
|
|
173
|
+
ax=ax,
|
|
174
|
+
cmap="viridis",
|
|
175
|
+
cbar_kws={"label": "Component weight"},
|
|
176
|
+
xticklabels=feature_labels if n_features <= 60 else False,
|
|
177
|
+
yticklabels=component_labels,
|
|
178
|
+
)
|
|
179
|
+
ax.set_xlabel("Feature")
|
|
180
|
+
ax.set_ylabel("NMF component")
|
|
181
|
+
fig.tight_layout()
|
|
182
|
+
heatmap_path = output_path / heatmap_name
|
|
183
|
+
fig.savefig(heatmap_path, dpi=200)
|
|
184
|
+
plt.close(fig)
|
|
185
|
+
|
|
186
|
+
fig, ax = plt.subplots(figsize=(max(8, min(20, n_features / 50)), 3.5))
|
|
187
|
+
x = np.arange(n_features)
|
|
188
|
+
for idx, label in enumerate(component_labels):
|
|
189
|
+
ax.plot(x, components[:, idx], label=label, linewidth=1.5)
|
|
190
|
+
ax.set_xlabel("Feature index")
|
|
191
|
+
ax.set_ylabel("Component weight")
|
|
192
|
+
if n_features <= 60:
|
|
193
|
+
ax.set_xticks(x)
|
|
194
|
+
ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
|
|
195
|
+
ax.legend(loc="upper right", frameon=False)
|
|
196
|
+
fig.tight_layout()
|
|
197
|
+
lineplot_path = output_path / lineplot_name
|
|
198
|
+
fig.savefig(lineplot_path, dpi=200)
|
|
199
|
+
plt.close(fig)
|
|
200
|
+
|
|
201
|
+
return {"heatmap": heatmap_path, "lineplot": lineplot_path}
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def plot_cp_sequence_components(
|
|
205
|
+
adata: "ad.AnnData",
|
|
206
|
+
*,
|
|
207
|
+
output_dir: Path | str,
|
|
208
|
+
components_key: str = "H_cp_sequence",
|
|
209
|
+
uns_key: str = "cp_sequence",
|
|
210
|
+
heatmap_name: str = "cp_sequence_position_heatmap.png",
|
|
211
|
+
lineplot_name: str = "cp_sequence_position_lineplot.png",
|
|
212
|
+
base_name: str = "cp_sequence_base_weights.png",
|
|
213
|
+
max_positions: int = 2000,
|
|
214
|
+
) -> Dict[str, Path]:
|
|
215
|
+
"""Plot CP decomposition position and base factors.
|
|
216
|
+
|
|
217
|
+
Args:
|
|
218
|
+
adata: AnnData object containing CP decomposition results.
|
|
219
|
+
output_dir: Directory to write plots into.
|
|
220
|
+
components_key: Key in ``adata.varm`` storing position factors.
|
|
221
|
+
uns_key: Key in ``adata.uns`` storing base factors.
|
|
222
|
+
heatmap_name: Filename for position heatmap.
|
|
223
|
+
lineplot_name: Filename for position line plot.
|
|
224
|
+
base_name: Filename for base factor bar plot.
|
|
225
|
+
max_positions: Maximum number of positions to plot.
|
|
226
|
+
|
|
227
|
+
Returns:
|
|
228
|
+
Dict[str, Path]: Paths to created plots.
|
|
229
|
+
"""
|
|
230
|
+
if components_key not in adata.varm:
|
|
231
|
+
logger.warning("CP components key '%s' not found in adata.varm.", components_key)
|
|
232
|
+
return {}
|
|
233
|
+
|
|
234
|
+
output_path = Path(output_dir)
|
|
235
|
+
output_path.mkdir(parents=True, exist_ok=True)
|
|
236
|
+
|
|
237
|
+
components = np.asarray(adata.varm[components_key])
|
|
238
|
+
if components.ndim != 2:
|
|
239
|
+
raise ValueError(f"CP position factors must be 2D; got shape {components.shape}.")
|
|
240
|
+
|
|
241
|
+
feature_labels = (
|
|
242
|
+
np.asarray(adata.var_names).astype(str)
|
|
243
|
+
if adata.shape[1] == components.shape[0]
|
|
244
|
+
else np.array([str(i) for i in range(components.shape[0])])
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
if max_positions and components.shape[0] > max_positions:
|
|
248
|
+
original_count = components.shape[0]
|
|
249
|
+
scores = np.nanmax(np.abs(components), axis=1)
|
|
250
|
+
top_idx = np.argsort(scores)[-max_positions:]
|
|
251
|
+
top_idx = np.sort(top_idx)
|
|
252
|
+
components = components[top_idx]
|
|
253
|
+
feature_labels = feature_labels[top_idx]
|
|
254
|
+
logger.info(
|
|
255
|
+
"Downsampled CP positions from %s to %s for plotting.",
|
|
256
|
+
original_count,
|
|
257
|
+
max_positions,
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
n_positions, n_components = components.shape
|
|
261
|
+
component_labels = [f"C{i + 1}" for i in range(n_components)]
|
|
262
|
+
|
|
263
|
+
heatmap_width = max(8, min(20, n_positions / 60))
|
|
264
|
+
heatmap_height = max(2.5, 0.6 * n_components + 1.5)
|
|
265
|
+
fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
|
|
266
|
+
sns.heatmap(
|
|
267
|
+
components.T,
|
|
268
|
+
ax=ax,
|
|
269
|
+
cmap="viridis",
|
|
270
|
+
cbar_kws={"label": "Component weight"},
|
|
271
|
+
xticklabels=feature_labels if n_positions <= 60 else False,
|
|
272
|
+
yticklabels=component_labels,
|
|
273
|
+
)
|
|
274
|
+
ax.set_xlabel("Position")
|
|
275
|
+
ax.set_ylabel("CP component")
|
|
276
|
+
fig.tight_layout()
|
|
277
|
+
heatmap_path = output_path / heatmap_name
|
|
278
|
+
fig.savefig(heatmap_path, dpi=200)
|
|
279
|
+
plt.close(fig)
|
|
280
|
+
|
|
281
|
+
fig, ax = plt.subplots(figsize=(max(8, min(20, n_positions / 50)), 3.5))
|
|
282
|
+
x = np.arange(n_positions)
|
|
283
|
+
for idx, label in enumerate(component_labels):
|
|
284
|
+
ax.plot(x, components[:, idx], label=label, linewidth=1.5)
|
|
285
|
+
ax.set_xlabel("Position index")
|
|
286
|
+
ax.set_ylabel("Component weight")
|
|
287
|
+
if n_positions <= 60:
|
|
288
|
+
ax.set_xticks(x)
|
|
289
|
+
ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
|
|
290
|
+
ax.legend(loc="upper right", frameon=False)
|
|
291
|
+
fig.tight_layout()
|
|
292
|
+
lineplot_path = output_path / lineplot_name
|
|
293
|
+
fig.savefig(lineplot_path, dpi=200)
|
|
294
|
+
plt.close(fig)
|
|
295
|
+
|
|
296
|
+
outputs = {"heatmap": heatmap_path, "lineplot": lineplot_path}
|
|
297
|
+
if uns_key in adata.uns:
|
|
298
|
+
base_factors = adata.uns[uns_key].get("base_factors")
|
|
299
|
+
base_labels = adata.uns[uns_key].get("base_labels")
|
|
300
|
+
if base_factors is not None:
|
|
301
|
+
base_factors = np.asarray(base_factors)
|
|
302
|
+
if base_factors.ndim != 2 or base_factors.size == 0:
|
|
303
|
+
logger.warning(
|
|
304
|
+
"CP base factors must be 2D and non-empty; got shape %s.",
|
|
305
|
+
base_factors.shape,
|
|
306
|
+
)
|
|
307
|
+
else:
|
|
308
|
+
base_labels = base_labels or [f"B{i + 1}" for i in range(base_factors.shape[0])]
|
|
309
|
+
fig, ax = plt.subplots(figsize=(4.5, 3))
|
|
310
|
+
width = 0.8 / base_factors.shape[1]
|
|
311
|
+
x = np.arange(base_factors.shape[0])
|
|
312
|
+
for idx in range(base_factors.shape[1]):
|
|
313
|
+
ax.bar(
|
|
314
|
+
x + idx * width,
|
|
315
|
+
base_factors[:, idx],
|
|
316
|
+
width=width,
|
|
317
|
+
label=f"C{idx + 1}",
|
|
318
|
+
)
|
|
319
|
+
ax.set_xticks(x + width * (base_factors.shape[1] - 1) / 2)
|
|
320
|
+
ax.set_xticklabels(base_labels)
|
|
321
|
+
ax.set_ylabel("Base factor weight")
|
|
322
|
+
ax.legend(loc="upper right", frameon=False)
|
|
323
|
+
fig.tight_layout()
|
|
324
|
+
base_path = output_path / base_name
|
|
325
|
+
fig.savefig(base_path, dpi=200)
|
|
326
|
+
plt.close(fig)
|
|
327
|
+
outputs["base_factors"] = base_path
|
|
328
|
+
|
|
329
|
+
return outputs
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _resolve_feature_color(cmap: Any) -> Tuple[float, float, float, float]:
|
|
333
|
+
"""Resolve a representative feature color from a colormap or color spec."""
|
|
334
|
+
if isinstance(cmap, str):
|
|
335
|
+
try:
|
|
336
|
+
cmap_obj = plt.get_cmap(cmap)
|
|
337
|
+
return colors.to_rgba(cmap_obj(1.0))
|
|
338
|
+
except Exception:
|
|
339
|
+
return colors.to_rgba(cmap)
|
|
340
|
+
|
|
341
|
+
if isinstance(cmap, colors.Colormap):
|
|
342
|
+
if hasattr(cmap, "colors") and cmap.colors:
|
|
343
|
+
return colors.to_rgba(cmap.colors[-1])
|
|
344
|
+
return colors.to_rgba(cmap(1.0))
|
|
345
|
+
|
|
346
|
+
return colors.to_rgba("black")
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _build_hmm_feature_cmap(
|
|
350
|
+
cmap: Any,
|
|
351
|
+
*,
|
|
352
|
+
zero_color: str = "#f5f1e8",
|
|
353
|
+
nan_color: str = "#E6E6E6",
|
|
354
|
+
) -> colors.Colormap:
|
|
355
|
+
"""Build a two-color HMM colormap with explicit NaN/under handling."""
|
|
356
|
+
feature_color = _resolve_feature_color(cmap)
|
|
357
|
+
hmm_cmap = colors.LinearSegmentedColormap.from_list(
|
|
358
|
+
"hmm_feature_cmap",
|
|
359
|
+
[zero_color, feature_color],
|
|
360
|
+
)
|
|
361
|
+
hmm_cmap.set_bad(nan_color)
|
|
362
|
+
hmm_cmap.set_under(nan_color)
|
|
363
|
+
return hmm_cmap
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _map_length_matrix_to_subclasses(
|
|
367
|
+
length_matrix: np.ndarray,
|
|
368
|
+
feature_ranges: Sequence[Tuple[int, int, Any]],
|
|
369
|
+
) -> np.ndarray:
|
|
370
|
+
"""Map length values into subclass integer codes based on feature ranges."""
|
|
371
|
+
mapped = np.zeros_like(length_matrix, dtype=float)
|
|
372
|
+
finite_mask = np.isfinite(length_matrix)
|
|
373
|
+
for idx, (min_len, max_len, _color) in enumerate(feature_ranges, start=1):
|
|
374
|
+
mask = finite_mask & (length_matrix >= min_len) & (length_matrix <= max_len)
|
|
375
|
+
mapped[mask] = float(idx)
|
|
376
|
+
mapped[~finite_mask] = np.nan
|
|
377
|
+
return mapped
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _build_length_feature_cmap(
|
|
381
|
+
feature_ranges: Sequence[Tuple[int, int, Any]],
|
|
382
|
+
*,
|
|
383
|
+
zero_color: str = "#f5f1e8",
|
|
384
|
+
nan_color: str = "#E6E6E6",
|
|
385
|
+
) -> Tuple[colors.Colormap, colors.BoundaryNorm]:
|
|
386
|
+
"""Build a discrete colormap and norm for length-based subclasses."""
|
|
387
|
+
color_list = [zero_color] + [color for _, _, color in feature_ranges]
|
|
388
|
+
cmap = colors.ListedColormap(color_list, name="hmm_length_feature_cmap")
|
|
389
|
+
cmap.set_bad(nan_color)
|
|
390
|
+
bounds = np.arange(-0.5, len(color_list) + 0.5, 1)
|
|
391
|
+
norm = colors.BoundaryNorm(bounds, cmap.N)
|
|
392
|
+
return cmap, norm
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def _layer_to_numpy(
|
|
396
|
+
subset,
|
|
397
|
+
layer_name: str,
|
|
398
|
+
sites: np.ndarray | None = None,
|
|
399
|
+
*,
|
|
400
|
+
fill_nan_strategy: str = "value",
|
|
401
|
+
fill_nan_value: float = -1,
|
|
402
|
+
) -> np.ndarray:
|
|
403
|
+
"""Return a (copied) numpy array for a layer with optional NaN filling."""
|
|
404
|
+
if sites is not None:
|
|
405
|
+
layer_data = subset[:, sites].layers[layer_name]
|
|
406
|
+
else:
|
|
407
|
+
layer_data = subset.layers[layer_name]
|
|
408
|
+
|
|
409
|
+
if hasattr(layer_data, "toarray"):
|
|
410
|
+
arr = layer_data.toarray()
|
|
411
|
+
else:
|
|
412
|
+
arr = np.asarray(layer_data)
|
|
413
|
+
|
|
414
|
+
arr = np.array(arr, copy=True)
|
|
415
|
+
|
|
416
|
+
if fill_nan_strategy == "none":
|
|
417
|
+
return arr
|
|
418
|
+
|
|
419
|
+
if fill_nan_strategy not in {"value", "col_mean"}:
|
|
420
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
421
|
+
|
|
422
|
+
arr = arr.astype(float, copy=False)
|
|
423
|
+
|
|
424
|
+
if fill_nan_strategy == "value":
|
|
425
|
+
return np.where(np.isnan(arr), fill_nan_value, arr)
|
|
426
|
+
|
|
427
|
+
col_mean = np.nanmean(arr, axis=0)
|
|
428
|
+
if np.any(np.isnan(col_mean)):
|
|
429
|
+
col_mean = np.where(np.isnan(col_mean), fill_nan_value, col_mean)
|
|
430
|
+
return np.where(np.isnan(arr), col_mean, arr)
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _infer_zero_is_valid(layer_name: str | None, matrix: np.ndarray) -> bool:
|
|
434
|
+
"""Infer whether zeros should count as valid (unmethylated) values."""
|
|
435
|
+
if layer_name and "nan0_0minus1" in layer_name:
|
|
436
|
+
return False
|
|
437
|
+
if np.isnan(matrix).any():
|
|
438
|
+
return True
|
|
439
|
+
if np.any(matrix < 0):
|
|
440
|
+
return False
|
|
441
|
+
return True
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def methylation_fraction(
|
|
445
|
+
matrix: np.ndarray, *, ignore_nan: bool = True, zero_is_valid: bool = False
|
|
446
|
+
) -> np.ndarray:
|
|
86
447
|
"""
|
|
87
448
|
Fraction methylated per column.
|
|
88
449
|
Methylated = 1
|
|
89
|
-
Valid = finite AND not 0
|
|
450
|
+
Valid = finite AND not 0 (unless zero_is_valid=True)
|
|
90
451
|
"""
|
|
91
452
|
matrix = np.asarray(matrix)
|
|
92
|
-
|
|
453
|
+
if not ignore_nan:
|
|
454
|
+
matrix = np.where(np.isnan(matrix), 0, matrix)
|
|
455
|
+
finite_mask = np.isfinite(matrix)
|
|
456
|
+
valid_mask = finite_mask if zero_is_valid else (finite_mask & (matrix != 0))
|
|
93
457
|
methyl_mask = (matrix == 1) & np.isfinite(matrix)
|
|
94
458
|
|
|
95
459
|
methylated = methyl_mask.sum(axis=0)
|
|
@@ -100,20 +464,53 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
|
|
|
100
464
|
)
|
|
101
465
|
|
|
102
466
|
|
|
103
|
-
def
|
|
467
|
+
def _methylation_fraction_for_layer(
|
|
468
|
+
matrix: np.ndarray,
|
|
469
|
+
layer_name: str | None,
|
|
470
|
+
*,
|
|
471
|
+
ignore_nan: bool = True,
|
|
472
|
+
zero_is_valid: bool | None = None,
|
|
473
|
+
) -> np.ndarray:
|
|
474
|
+
"""Compute methylation fractions with layer-aware zero handling."""
|
|
475
|
+
matrix = np.asarray(matrix)
|
|
476
|
+
if zero_is_valid is None:
|
|
477
|
+
zero_is_valid = _infer_zero_is_valid(layer_name, matrix)
|
|
478
|
+
return methylation_fraction(matrix, ignore_nan=ignore_nan, zero_is_valid=zero_is_valid)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def clean_barplot(
|
|
482
|
+
ax,
|
|
483
|
+
mean_values,
|
|
484
|
+
title,
|
|
485
|
+
*,
|
|
486
|
+
y_max: float | None = 1.0,
|
|
487
|
+
y_label: str = "Mean",
|
|
488
|
+
y_ticks: list[float] | None = None,
|
|
489
|
+
):
|
|
104
490
|
"""Format a barplot with consistent axes and labels.
|
|
105
491
|
|
|
106
492
|
Args:
|
|
107
493
|
ax: Matplotlib axes.
|
|
108
494
|
mean_values: Values to plot.
|
|
109
495
|
title: Plot title.
|
|
496
|
+
y_max: Optional y-axis max; inferred from data if not provided.
|
|
497
|
+
y_label: Y-axis label.
|
|
498
|
+
y_ticks: Optional y-axis ticks.
|
|
110
499
|
"""
|
|
111
500
|
x = np.arange(len(mean_values))
|
|
112
501
|
ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
|
|
113
502
|
ax.set_xlim(0, len(mean_values))
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
503
|
+
if y_ticks is None and y_max == 1.0:
|
|
504
|
+
y_ticks = [0.0, 0.5, 1.0]
|
|
505
|
+
if y_max is None:
|
|
506
|
+
y_max = np.nanmax(mean_values) if len(mean_values) else 1.0
|
|
507
|
+
if not np.isfinite(y_max) or y_max <= 0:
|
|
508
|
+
y_max = 1.0
|
|
509
|
+
y_max *= 1.05
|
|
510
|
+
ax.set_ylim(0, y_max)
|
|
511
|
+
if y_ticks is not None:
|
|
512
|
+
ax.set_yticks(y_ticks)
|
|
513
|
+
ax.set_ylabel(y_label)
|
|
117
514
|
ax.set_title(title, fontsize=12, pad=2)
|
|
118
515
|
|
|
119
516
|
# Hide all spines except left
|
|
@@ -123,222 +520,6 @@ def clean_barplot(ax, mean_values, title):
|
|
|
123
520
|
ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
|
|
124
521
|
|
|
125
522
|
|
|
126
|
-
# def combined_hmm_raw_clustermap(
|
|
127
|
-
# adata,
|
|
128
|
-
# sample_col='Sample_Names',
|
|
129
|
-
# reference_col='Reference_strand',
|
|
130
|
-
# hmm_feature_layer="hmm_combined",
|
|
131
|
-
# layer_gpc="nan0_0minus1",
|
|
132
|
-
# layer_cpg="nan0_0minus1",
|
|
133
|
-
# layer_any_c="nan0_0minus1",
|
|
134
|
-
# cmap_hmm="tab10",
|
|
135
|
-
# cmap_gpc="coolwarm",
|
|
136
|
-
# cmap_cpg="viridis",
|
|
137
|
-
# cmap_any_c='coolwarm',
|
|
138
|
-
# min_quality=20,
|
|
139
|
-
# min_length=200,
|
|
140
|
-
# min_mapped_length_to_reference_length_ratio=0.8,
|
|
141
|
-
# min_position_valid_fraction=0.5,
|
|
142
|
-
# sample_mapping=None,
|
|
143
|
-
# save_path=None,
|
|
144
|
-
# normalize_hmm=False,
|
|
145
|
-
# sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
|
|
146
|
-
# bins=None,
|
|
147
|
-
# deaminase=False,
|
|
148
|
-
# min_signal=0
|
|
149
|
-
# ):
|
|
150
|
-
|
|
151
|
-
# results = []
|
|
152
|
-
# if deaminase:
|
|
153
|
-
# signal_type = 'deamination'
|
|
154
|
-
# else:
|
|
155
|
-
# signal_type = 'methylation'
|
|
156
|
-
|
|
157
|
-
# for ref in adata.obs[reference_col].cat.categories:
|
|
158
|
-
# for sample in adata.obs[sample_col].cat.categories:
|
|
159
|
-
# try:
|
|
160
|
-
# subset = adata[
|
|
161
|
-
# (adata.obs[reference_col] == ref) &
|
|
162
|
-
# (adata.obs[sample_col] == sample) &
|
|
163
|
-
# (adata.obs['read_quality'] >= min_quality) &
|
|
164
|
-
# (adata.obs['read_length'] >= min_length) &
|
|
165
|
-
# (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
|
|
166
|
-
# ]
|
|
167
|
-
|
|
168
|
-
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
169
|
-
# subset = subset[:, mask]
|
|
170
|
-
|
|
171
|
-
# if subset.shape[0] == 0:
|
|
172
|
-
# print(f" No reads left after filtering for {sample} - {ref}")
|
|
173
|
-
# continue
|
|
174
|
-
|
|
175
|
-
# if bins:
|
|
176
|
-
# print(f"Using defined bins to subset clustermap for {sample} - {ref}")
|
|
177
|
-
# bins_temp = bins
|
|
178
|
-
# else:
|
|
179
|
-
# print(f"Using all reads for clustermap for {sample} - {ref}")
|
|
180
|
-
# bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
|
|
181
|
-
|
|
182
|
-
# # Get column positions (not var_names!) of site masks
|
|
183
|
-
# gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
|
|
184
|
-
# cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
|
|
185
|
-
# any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
|
|
186
|
-
# num_gpc = len(gpc_sites)
|
|
187
|
-
# num_cpg = len(cpg_sites)
|
|
188
|
-
# num_c = len(any_c_sites)
|
|
189
|
-
# print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
|
|
190
|
-
|
|
191
|
-
# # Use var_names for x-axis tick labels
|
|
192
|
-
# gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
193
|
-
# cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
194
|
-
# any_c_labels = subset.var_names[any_c_sites].astype(int)
|
|
195
|
-
|
|
196
|
-
# stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
|
|
197
|
-
# row_labels, bin_labels = [], []
|
|
198
|
-
# bin_boundaries = []
|
|
199
|
-
|
|
200
|
-
# total_reads = subset.shape[0]
|
|
201
|
-
# percentages = {}
|
|
202
|
-
# last_idx = 0
|
|
203
|
-
|
|
204
|
-
# for bin_label, bin_filter in bins_temp.items():
|
|
205
|
-
# subset_bin = subset[bin_filter].copy()
|
|
206
|
-
# num_reads = subset_bin.shape[0]
|
|
207
|
-
# print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
|
|
208
|
-
# percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
|
|
209
|
-
# percentages[bin_label] = percent_reads
|
|
210
|
-
|
|
211
|
-
# if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
|
|
212
|
-
# # Determine sorting order
|
|
213
|
-
# if sort_by.startswith("obs:"):
|
|
214
|
-
# colname = sort_by.split("obs:")[1]
|
|
215
|
-
# order = np.argsort(subset_bin.obs[colname].values)
|
|
216
|
-
# elif sort_by == "gpc":
|
|
217
|
-
# linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
218
|
-
# order = sch.leaves_list(linkage)
|
|
219
|
-
# elif sort_by == "cpg":
|
|
220
|
-
# linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
221
|
-
# order = sch.leaves_list(linkage)
|
|
222
|
-
# elif sort_by == "gpc_cpg":
|
|
223
|
-
# linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
224
|
-
# order = sch.leaves_list(linkage)
|
|
225
|
-
# elif sort_by == "none":
|
|
226
|
-
# order = np.arange(num_reads)
|
|
227
|
-
# elif sort_by == "any_c":
|
|
228
|
-
# linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
|
|
229
|
-
# order = sch.leaves_list(linkage)
|
|
230
|
-
# else:
|
|
231
|
-
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
232
|
-
|
|
233
|
-
# stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
|
|
234
|
-
# stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
|
|
235
|
-
# stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
|
|
236
|
-
# stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
|
|
237
|
-
|
|
238
|
-
# row_labels.extend([bin_label] * num_reads)
|
|
239
|
-
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
240
|
-
# last_idx += num_reads
|
|
241
|
-
# bin_boundaries.append(last_idx)
|
|
242
|
-
|
|
243
|
-
# if stacked_hmm_feature:
|
|
244
|
-
# hmm_matrix = np.vstack(stacked_hmm_feature)
|
|
245
|
-
# gpc_matrix = np.vstack(stacked_gpc)
|
|
246
|
-
# cpg_matrix = np.vstack(stacked_cpg)
|
|
247
|
-
# any_c_matrix = np.vstack(stacked_any_c)
|
|
248
|
-
|
|
249
|
-
# if hmm_matrix.size > 0:
|
|
250
|
-
# def normalized_mean(matrix):
|
|
251
|
-
# mean = np.nanmean(matrix, axis=0)
|
|
252
|
-
# normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
|
|
253
|
-
# return normalized
|
|
254
|
-
|
|
255
|
-
# def methylation_fraction(matrix):
|
|
256
|
-
# methylated = (matrix == 1).sum(axis=0)
|
|
257
|
-
# valid = (matrix != 0).sum(axis=0)
|
|
258
|
-
# return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
|
|
259
|
-
|
|
260
|
-
# if normalize_hmm:
|
|
261
|
-
# mean_hmm = normalized_mean(hmm_matrix)
|
|
262
|
-
# else:
|
|
263
|
-
# mean_hmm = np.nanmean(hmm_matrix, axis=0)
|
|
264
|
-
# mean_gpc = methylation_fraction(gpc_matrix)
|
|
265
|
-
# mean_cpg = methylation_fraction(cpg_matrix)
|
|
266
|
-
# mean_any_c = methylation_fraction(any_c_matrix)
|
|
267
|
-
|
|
268
|
-
# fig = plt.figure(figsize=(18, 12))
|
|
269
|
-
# gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
|
|
270
|
-
# fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
|
|
271
|
-
|
|
272
|
-
# axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
|
|
273
|
-
# axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
|
|
274
|
-
|
|
275
|
-
# clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
|
|
276
|
-
# clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
|
|
277
|
-
# clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
|
|
278
|
-
# clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
|
|
279
|
-
|
|
280
|
-
# hmm_labels = subset.var_names.astype(int)
|
|
281
|
-
# hmm_label_spacing = 150
|
|
282
|
-
# sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
|
|
283
|
-
# axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
|
|
284
|
-
# axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
|
|
285
|
-
# for boundary in bin_boundaries[:-1]:
|
|
286
|
-
# axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
|
|
287
|
-
|
|
288
|
-
# sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
|
|
289
|
-
# axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
|
|
290
|
-
# axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
|
|
291
|
-
# for boundary in bin_boundaries[:-1]:
|
|
292
|
-
# axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
|
|
293
|
-
|
|
294
|
-
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
295
|
-
# axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
296
|
-
# for boundary in bin_boundaries[:-1]:
|
|
297
|
-
# axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
|
|
298
|
-
|
|
299
|
-
# sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
|
|
300
|
-
# axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
|
|
301
|
-
# axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
|
|
302
|
-
# for boundary in bin_boundaries[:-1]:
|
|
303
|
-
# axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
|
|
304
|
-
|
|
305
|
-
# plt.tight_layout()
|
|
306
|
-
|
|
307
|
-
# if save_path:
|
|
308
|
-
# save_name = f"{ref} — {sample}"
|
|
309
|
-
# os.makedirs(save_path, exist_ok=True)
|
|
310
|
-
# safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
311
|
-
# out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
312
|
-
# plt.savefig(out_file, dpi=300)
|
|
313
|
-
# print(f"Saved: {out_file}")
|
|
314
|
-
# plt.close()
|
|
315
|
-
# else:
|
|
316
|
-
# plt.show()
|
|
317
|
-
|
|
318
|
-
# print(f"Summary for {sample} - {ref}:")
|
|
319
|
-
# for bin_label, percent in percentages.items():
|
|
320
|
-
# print(f" - {bin_label}: {percent:.1f}%")
|
|
321
|
-
|
|
322
|
-
# results.append({
|
|
323
|
-
# "sample": sample,
|
|
324
|
-
# "ref": ref,
|
|
325
|
-
# "hmm_matrix": hmm_matrix,
|
|
326
|
-
# "gpc_matrix": gpc_matrix,
|
|
327
|
-
# "cpg_matrix": cpg_matrix,
|
|
328
|
-
# "row_labels": row_labels,
|
|
329
|
-
# "bin_labels": bin_labels,
|
|
330
|
-
# "bin_boundaries": bin_boundaries,
|
|
331
|
-
# "percentages": percentages
|
|
332
|
-
# })
|
|
333
|
-
|
|
334
|
-
# #adata.uns['clustermap_results'] = results
|
|
335
|
-
|
|
336
|
-
# except Exception as e:
|
|
337
|
-
# import traceback
|
|
338
|
-
# traceback.print_exc()
|
|
339
|
-
# continue
|
|
340
|
-
|
|
341
|
-
|
|
342
523
|
def combined_hmm_raw_clustermap(
|
|
343
524
|
adata,
|
|
344
525
|
sample_col: str = "Sample_Names",
|
|
@@ -372,6 +553,8 @@ def combined_hmm_raw_clustermap(
|
|
|
372
553
|
n_xticks_cpg: int = 8,
|
|
373
554
|
n_xticks_a: int = 8,
|
|
374
555
|
index_col_suffix: str | None = None,
|
|
556
|
+
fill_nan_strategy: str = "value",
|
|
557
|
+
fill_nan_value: float = -1,
|
|
375
558
|
):
|
|
376
559
|
"""
|
|
377
560
|
Makes a multi-panel clustermap per (sample, reference):
|
|
@@ -381,7 +564,11 @@ def combined_hmm_raw_clustermap(
|
|
|
381
564
|
|
|
382
565
|
sort_by options:
|
|
383
566
|
'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
|
|
567
|
+
|
|
568
|
+
NaN fill strategy is applied in-memory for clustering/plotting only.
|
|
384
569
|
"""
|
|
570
|
+
if fill_nan_strategy not in {"none", "value", "col_mean"}:
|
|
571
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
385
572
|
|
|
386
573
|
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
387
574
|
"""Pick tick indices/labels from an array."""
|
|
@@ -500,10 +687,15 @@ def combined_hmm_raw_clustermap(
|
|
|
500
687
|
|
|
501
688
|
# storage
|
|
502
689
|
stacked_hmm = []
|
|
690
|
+
stacked_hmm_raw = []
|
|
503
691
|
stacked_any_c = []
|
|
692
|
+
stacked_any_c_raw = []
|
|
504
693
|
stacked_gpc = []
|
|
694
|
+
stacked_gpc_raw = []
|
|
505
695
|
stacked_cpg = []
|
|
696
|
+
stacked_cpg_raw = []
|
|
506
697
|
stacked_any_a = []
|
|
698
|
+
stacked_any_a_raw = []
|
|
507
699
|
|
|
508
700
|
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
509
701
|
total_reads = subset.n_obs
|
|
@@ -526,29 +718,69 @@ def combined_hmm_raw_clustermap(
|
|
|
526
718
|
order = np.argsort(sb.obs[colname].values)
|
|
527
719
|
|
|
528
720
|
elif sort_by == "gpc" and gpc_sites.size:
|
|
529
|
-
|
|
721
|
+
gpc_matrix = _layer_to_numpy(
|
|
722
|
+
sb,
|
|
723
|
+
layer_gpc,
|
|
724
|
+
gpc_sites,
|
|
725
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
726
|
+
fill_nan_value=fill_nan_value,
|
|
727
|
+
)
|
|
728
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
530
729
|
order = sch.leaves_list(linkage)
|
|
531
730
|
|
|
532
731
|
elif sort_by == "cpg" and cpg_sites.size:
|
|
533
|
-
|
|
732
|
+
cpg_matrix = _layer_to_numpy(
|
|
733
|
+
sb,
|
|
734
|
+
layer_cpg,
|
|
735
|
+
cpg_sites,
|
|
736
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
737
|
+
fill_nan_value=fill_nan_value,
|
|
738
|
+
)
|
|
739
|
+
linkage = sch.linkage(cpg_matrix, method="ward")
|
|
534
740
|
order = sch.leaves_list(linkage)
|
|
535
741
|
|
|
536
742
|
elif sort_by == "c" and any_c_sites.size:
|
|
537
|
-
|
|
743
|
+
any_c_matrix = _layer_to_numpy(
|
|
744
|
+
sb,
|
|
745
|
+
layer_c,
|
|
746
|
+
any_c_sites,
|
|
747
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
748
|
+
fill_nan_value=fill_nan_value,
|
|
749
|
+
)
|
|
750
|
+
linkage = sch.linkage(any_c_matrix, method="ward")
|
|
538
751
|
order = sch.leaves_list(linkage)
|
|
539
752
|
|
|
540
753
|
elif sort_by == "a" and any_a_sites.size:
|
|
541
|
-
|
|
754
|
+
any_a_matrix = _layer_to_numpy(
|
|
755
|
+
sb,
|
|
756
|
+
layer_a,
|
|
757
|
+
any_a_sites,
|
|
758
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
759
|
+
fill_nan_value=fill_nan_value,
|
|
760
|
+
)
|
|
761
|
+
linkage = sch.linkage(any_a_matrix, method="ward")
|
|
542
762
|
order = sch.leaves_list(linkage)
|
|
543
763
|
|
|
544
764
|
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
545
|
-
|
|
765
|
+
gpc_matrix = _layer_to_numpy(
|
|
766
|
+
sb,
|
|
767
|
+
layer_gpc,
|
|
768
|
+
None,
|
|
769
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
770
|
+
fill_nan_value=fill_nan_value,
|
|
771
|
+
)
|
|
772
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
546
773
|
order = sch.leaves_list(linkage)
|
|
547
774
|
|
|
548
775
|
elif sort_by == "hmm" and hmm_sites.size:
|
|
549
|
-
|
|
550
|
-
sb
|
|
776
|
+
hmm_matrix = _layer_to_numpy(
|
|
777
|
+
sb,
|
|
778
|
+
hmm_feature_layer,
|
|
779
|
+
hmm_sites,
|
|
780
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
781
|
+
fill_nan_value=fill_nan_value,
|
|
551
782
|
)
|
|
783
|
+
linkage = sch.linkage(hmm_matrix, method="ward")
|
|
552
784
|
order = sch.leaves_list(linkage)
|
|
553
785
|
|
|
554
786
|
else:
|
|
@@ -557,15 +789,100 @@ def combined_hmm_raw_clustermap(
|
|
|
557
789
|
sb = sb[order]
|
|
558
790
|
|
|
559
791
|
# ---- collect matrices ----
|
|
560
|
-
stacked_hmm.append(
|
|
792
|
+
stacked_hmm.append(
|
|
793
|
+
_layer_to_numpy(
|
|
794
|
+
sb,
|
|
795
|
+
hmm_feature_layer,
|
|
796
|
+
None,
|
|
797
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
798
|
+
fill_nan_value=fill_nan_value,
|
|
799
|
+
)
|
|
800
|
+
)
|
|
801
|
+
stacked_hmm_raw.append(
|
|
802
|
+
_layer_to_numpy(
|
|
803
|
+
sb,
|
|
804
|
+
hmm_feature_layer,
|
|
805
|
+
None,
|
|
806
|
+
fill_nan_strategy="none",
|
|
807
|
+
fill_nan_value=fill_nan_value,
|
|
808
|
+
)
|
|
809
|
+
)
|
|
561
810
|
if any_c_sites.size:
|
|
562
|
-
stacked_any_c.append(
|
|
811
|
+
stacked_any_c.append(
|
|
812
|
+
_layer_to_numpy(
|
|
813
|
+
sb,
|
|
814
|
+
layer_c,
|
|
815
|
+
any_c_sites,
|
|
816
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
817
|
+
fill_nan_value=fill_nan_value,
|
|
818
|
+
)
|
|
819
|
+
)
|
|
820
|
+
stacked_any_c_raw.append(
|
|
821
|
+
_layer_to_numpy(
|
|
822
|
+
sb,
|
|
823
|
+
layer_c,
|
|
824
|
+
any_c_sites,
|
|
825
|
+
fill_nan_strategy="none",
|
|
826
|
+
fill_nan_value=fill_nan_value,
|
|
827
|
+
)
|
|
828
|
+
)
|
|
563
829
|
if gpc_sites.size:
|
|
564
|
-
stacked_gpc.append(
|
|
830
|
+
stacked_gpc.append(
|
|
831
|
+
_layer_to_numpy(
|
|
832
|
+
sb,
|
|
833
|
+
layer_gpc,
|
|
834
|
+
gpc_sites,
|
|
835
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
836
|
+
fill_nan_value=fill_nan_value,
|
|
837
|
+
)
|
|
838
|
+
)
|
|
839
|
+
stacked_gpc_raw.append(
|
|
840
|
+
_layer_to_numpy(
|
|
841
|
+
sb,
|
|
842
|
+
layer_gpc,
|
|
843
|
+
gpc_sites,
|
|
844
|
+
fill_nan_strategy="none",
|
|
845
|
+
fill_nan_value=fill_nan_value,
|
|
846
|
+
)
|
|
847
|
+
)
|
|
565
848
|
if cpg_sites.size:
|
|
566
|
-
stacked_cpg.append(
|
|
849
|
+
stacked_cpg.append(
|
|
850
|
+
_layer_to_numpy(
|
|
851
|
+
sb,
|
|
852
|
+
layer_cpg,
|
|
853
|
+
cpg_sites,
|
|
854
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
855
|
+
fill_nan_value=fill_nan_value,
|
|
856
|
+
)
|
|
857
|
+
)
|
|
858
|
+
stacked_cpg_raw.append(
|
|
859
|
+
_layer_to_numpy(
|
|
860
|
+
sb,
|
|
861
|
+
layer_cpg,
|
|
862
|
+
cpg_sites,
|
|
863
|
+
fill_nan_strategy="none",
|
|
864
|
+
fill_nan_value=fill_nan_value,
|
|
865
|
+
)
|
|
866
|
+
)
|
|
567
867
|
if any_a_sites.size:
|
|
568
|
-
stacked_any_a.append(
|
|
868
|
+
stacked_any_a.append(
|
|
869
|
+
_layer_to_numpy(
|
|
870
|
+
sb,
|
|
871
|
+
layer_a,
|
|
872
|
+
any_a_sites,
|
|
873
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
874
|
+
fill_nan_value=fill_nan_value,
|
|
875
|
+
)
|
|
876
|
+
)
|
|
877
|
+
stacked_any_a_raw.append(
|
|
878
|
+
_layer_to_numpy(
|
|
879
|
+
sb,
|
|
880
|
+
layer_a,
|
|
881
|
+
any_a_sites,
|
|
882
|
+
fill_nan_strategy="none",
|
|
883
|
+
fill_nan_value=fill_nan_value,
|
|
884
|
+
)
|
|
885
|
+
)
|
|
569
886
|
|
|
570
887
|
row_labels.extend([bin_label] * n)
|
|
571
888
|
bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
|
|
@@ -574,16 +891,21 @@ def combined_hmm_raw_clustermap(
|
|
|
574
891
|
|
|
575
892
|
# ---------------- stack ----------------
|
|
576
893
|
hmm_matrix = np.vstack(stacked_hmm)
|
|
894
|
+
hmm_matrix_raw = np.vstack(stacked_hmm_raw)
|
|
577
895
|
mean_hmm = (
|
|
578
|
-
normalized_mean(
|
|
896
|
+
normalized_mean(hmm_matrix_raw)
|
|
897
|
+
if normalize_hmm
|
|
898
|
+
else np.nanmean(hmm_matrix_raw, axis=0)
|
|
579
899
|
)
|
|
900
|
+
hmm_plot_matrix = hmm_matrix_raw
|
|
901
|
+
hmm_plot_cmap = _build_hmm_feature_cmap(cmap_hmm)
|
|
580
902
|
|
|
581
903
|
panels = [
|
|
582
904
|
(
|
|
583
905
|
f"HMM - {hmm_feature_layer}",
|
|
584
|
-
|
|
906
|
+
hmm_plot_matrix,
|
|
585
907
|
hmm_labels,
|
|
586
|
-
|
|
908
|
+
hmm_plot_cmap,
|
|
587
909
|
mean_hmm,
|
|
588
910
|
n_xticks_hmm,
|
|
589
911
|
),
|
|
@@ -591,26 +913,58 @@ def combined_hmm_raw_clustermap(
|
|
|
591
913
|
|
|
592
914
|
if stacked_any_c:
|
|
593
915
|
m = np.vstack(stacked_any_c)
|
|
916
|
+
m_raw = np.vstack(stacked_any_c_raw)
|
|
594
917
|
panels.append(
|
|
595
|
-
(
|
|
918
|
+
(
|
|
919
|
+
"C",
|
|
920
|
+
m,
|
|
921
|
+
any_c_labels,
|
|
922
|
+
cmap_c,
|
|
923
|
+
_methylation_fraction_for_layer(m_raw, layer_c),
|
|
924
|
+
n_xticks_any_c,
|
|
925
|
+
)
|
|
596
926
|
)
|
|
597
927
|
|
|
598
928
|
if stacked_gpc:
|
|
599
929
|
m = np.vstack(stacked_gpc)
|
|
930
|
+
m_raw = np.vstack(stacked_gpc_raw)
|
|
600
931
|
panels.append(
|
|
601
|
-
(
|
|
932
|
+
(
|
|
933
|
+
"GpC",
|
|
934
|
+
m,
|
|
935
|
+
gpc_labels,
|
|
936
|
+
cmap_gpc,
|
|
937
|
+
_methylation_fraction_for_layer(m_raw, layer_gpc),
|
|
938
|
+
n_xticks_gpc,
|
|
939
|
+
)
|
|
602
940
|
)
|
|
603
941
|
|
|
604
942
|
if stacked_cpg:
|
|
605
943
|
m = np.vstack(stacked_cpg)
|
|
944
|
+
m_raw = np.vstack(stacked_cpg_raw)
|
|
606
945
|
panels.append(
|
|
607
|
-
(
|
|
946
|
+
(
|
|
947
|
+
"CpG",
|
|
948
|
+
m,
|
|
949
|
+
cpg_labels,
|
|
950
|
+
cmap_cpg,
|
|
951
|
+
_methylation_fraction_for_layer(m_raw, layer_cpg),
|
|
952
|
+
n_xticks_cpg,
|
|
953
|
+
)
|
|
608
954
|
)
|
|
609
955
|
|
|
610
956
|
if stacked_any_a:
|
|
611
957
|
m = np.vstack(stacked_any_a)
|
|
958
|
+
m_raw = np.vstack(stacked_any_a_raw)
|
|
612
959
|
panels.append(
|
|
613
|
-
(
|
|
960
|
+
(
|
|
961
|
+
"A",
|
|
962
|
+
m,
|
|
963
|
+
any_a_labels,
|
|
964
|
+
cmap_a,
|
|
965
|
+
_methylation_fraction_for_layer(m_raw, layer_a),
|
|
966
|
+
n_xticks_a,
|
|
967
|
+
)
|
|
614
968
|
)
|
|
615
969
|
|
|
616
970
|
# ---------------- plotting ----------------
|
|
@@ -629,7 +983,15 @@ def combined_hmm_raw_clustermap(
|
|
|
629
983
|
clean_barplot(axes_bar[i], mean_vec, name)
|
|
630
984
|
|
|
631
985
|
# ---- heatmap ----
|
|
632
|
-
|
|
986
|
+
heatmap_kwargs = dict(
|
|
987
|
+
cmap=cmap,
|
|
988
|
+
ax=axes_heat[i],
|
|
989
|
+
yticklabels=False,
|
|
990
|
+
cbar=False,
|
|
991
|
+
)
|
|
992
|
+
if name.startswith("HMM -"):
|
|
993
|
+
heatmap_kwargs.update(vmin=0.0, vmax=1.0)
|
|
994
|
+
sns.heatmap(matrix, **heatmap_kwargs)
|
|
633
995
|
|
|
634
996
|
# ---- xticks ----
|
|
635
997
|
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
@@ -658,271 +1020,6 @@ def combined_hmm_raw_clustermap(
|
|
|
658
1020
|
continue
|
|
659
1021
|
|
|
660
1022
|
|
|
661
|
-
# def combined_raw_clustermap(
|
|
662
|
-
# adata,
|
|
663
|
-
# sample_col='Sample_Names',
|
|
664
|
-
# reference_col='Reference_strand',
|
|
665
|
-
# mod_target_bases=['GpC', 'CpG'],
|
|
666
|
-
# layer_any_c="nan0_0minus1",
|
|
667
|
-
# layer_gpc="nan0_0minus1",
|
|
668
|
-
# layer_cpg="nan0_0minus1",
|
|
669
|
-
# layer_a="nan0_0minus1",
|
|
670
|
-
# cmap_any_c="coolwarm",
|
|
671
|
-
# cmap_gpc="coolwarm",
|
|
672
|
-
# cmap_cpg="viridis",
|
|
673
|
-
# cmap_a="coolwarm",
|
|
674
|
-
# min_quality=20,
|
|
675
|
-
# min_length=200,
|
|
676
|
-
# min_mapped_length_to_reference_length_ratio=0.8,
|
|
677
|
-
# min_position_valid_fraction=0.5,
|
|
678
|
-
# sample_mapping=None,
|
|
679
|
-
# save_path=None,
|
|
680
|
-
# sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
|
|
681
|
-
# bins=None,
|
|
682
|
-
# deaminase=False,
|
|
683
|
-
# min_signal=0
|
|
684
|
-
# ):
|
|
685
|
-
|
|
686
|
-
# results = []
|
|
687
|
-
|
|
688
|
-
# for ref in adata.obs[reference_col].cat.categories:
|
|
689
|
-
# for sample in adata.obs[sample_col].cat.categories:
|
|
690
|
-
# try:
|
|
691
|
-
# subset = adata[
|
|
692
|
-
# (adata.obs[reference_col] == ref) &
|
|
693
|
-
# (adata.obs[sample_col] == sample) &
|
|
694
|
-
# (adata.obs['read_quality'] >= min_quality) &
|
|
695
|
-
# (adata.obs['mapped_length'] >= min_length) &
|
|
696
|
-
# (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
|
|
697
|
-
# ]
|
|
698
|
-
|
|
699
|
-
# mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
|
|
700
|
-
# subset = subset[:, mask]
|
|
701
|
-
|
|
702
|
-
# if subset.shape[0] == 0:
|
|
703
|
-
# print(f" No reads left after filtering for {sample} - {ref}")
|
|
704
|
-
# continue
|
|
705
|
-
|
|
706
|
-
# if bins:
|
|
707
|
-
# print(f"Using defined bins to subset clustermap for {sample} - {ref}")
|
|
708
|
-
# bins_temp = bins
|
|
709
|
-
# else:
|
|
710
|
-
# print(f"Using all reads for clustermap for {sample} - {ref}")
|
|
711
|
-
# bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
|
|
712
|
-
|
|
713
|
-
# num_any_c = 0
|
|
714
|
-
# num_gpc = 0
|
|
715
|
-
# num_cpg = 0
|
|
716
|
-
# num_any_a = 0
|
|
717
|
-
|
|
718
|
-
# # Get column positions (not var_names!) of site masks
|
|
719
|
-
# if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
|
|
720
|
-
# any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
|
|
721
|
-
# gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
|
|
722
|
-
# cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
|
|
723
|
-
# num_any_c = len(any_c_sites)
|
|
724
|
-
# num_gpc = len(gpc_sites)
|
|
725
|
-
# num_cpg = len(cpg_sites)
|
|
726
|
-
# print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
|
|
727
|
-
|
|
728
|
-
# # Use var_names for x-axis tick labels
|
|
729
|
-
# gpc_labels = subset.var_names[gpc_sites].astype(int)
|
|
730
|
-
# cpg_labels = subset.var_names[cpg_sites].astype(int)
|
|
731
|
-
# any_c_labels = subset.var_names[any_c_sites].astype(int)
|
|
732
|
-
# stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
|
|
733
|
-
|
|
734
|
-
# if "A" in mod_target_bases:
|
|
735
|
-
# any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
|
|
736
|
-
# num_any_a = len(any_a_sites)
|
|
737
|
-
# print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
|
|
738
|
-
# any_a_labels = subset.var_names[any_a_sites].astype(int)
|
|
739
|
-
# stacked_any_a = []
|
|
740
|
-
|
|
741
|
-
# row_labels, bin_labels = [], []
|
|
742
|
-
# bin_boundaries = []
|
|
743
|
-
|
|
744
|
-
# total_reads = subset.shape[0]
|
|
745
|
-
# percentages = {}
|
|
746
|
-
# last_idx = 0
|
|
747
|
-
|
|
748
|
-
# for bin_label, bin_filter in bins_temp.items():
|
|
749
|
-
# subset_bin = subset[bin_filter].copy()
|
|
750
|
-
# num_reads = subset_bin.shape[0]
|
|
751
|
-
# print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
|
|
752
|
-
# percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
|
|
753
|
-
# percentages[bin_label] = percent_reads
|
|
754
|
-
|
|
755
|
-
# if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
|
|
756
|
-
# # Determine sorting order
|
|
757
|
-
# if sort_by.startswith("obs:"):
|
|
758
|
-
# colname = sort_by.split("obs:")[1]
|
|
759
|
-
# order = np.argsort(subset_bin.obs[colname].values)
|
|
760
|
-
# elif sort_by == "gpc":
|
|
761
|
-
# linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
762
|
-
# order = sch.leaves_list(linkage)
|
|
763
|
-
# elif sort_by == "cpg":
|
|
764
|
-
# linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
765
|
-
# order = sch.leaves_list(linkage)
|
|
766
|
-
# elif sort_by == "any_c":
|
|
767
|
-
# linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
768
|
-
# order = sch.leaves_list(linkage)
|
|
769
|
-
# elif sort_by == "gpc_cpg":
|
|
770
|
-
# linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
771
|
-
# order = sch.leaves_list(linkage)
|
|
772
|
-
# elif sort_by == "none":
|
|
773
|
-
# order = np.arange(num_reads)
|
|
774
|
-
# elif sort_by == "any_a":
|
|
775
|
-
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
776
|
-
# order = sch.leaves_list(linkage)
|
|
777
|
-
# else:
|
|
778
|
-
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
779
|
-
|
|
780
|
-
# stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
|
|
781
|
-
# stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
|
|
782
|
-
# stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
|
|
783
|
-
|
|
784
|
-
# if num_reads > 0 and num_any_a > 0:
|
|
785
|
-
# # Determine sorting order
|
|
786
|
-
# if sort_by.startswith("obs:"):
|
|
787
|
-
# colname = sort_by.split("obs:")[1]
|
|
788
|
-
# order = np.argsort(subset_bin.obs[colname].values)
|
|
789
|
-
# elif sort_by == "gpc":
|
|
790
|
-
# linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
|
|
791
|
-
# order = sch.leaves_list(linkage)
|
|
792
|
-
# elif sort_by == "cpg":
|
|
793
|
-
# linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
|
|
794
|
-
# order = sch.leaves_list(linkage)
|
|
795
|
-
# elif sort_by == "any_c":
|
|
796
|
-
# linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
|
|
797
|
-
# order = sch.leaves_list(linkage)
|
|
798
|
-
# elif sort_by == "gpc_cpg":
|
|
799
|
-
# linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
|
|
800
|
-
# order = sch.leaves_list(linkage)
|
|
801
|
-
# elif sort_by == "none":
|
|
802
|
-
# order = np.arange(num_reads)
|
|
803
|
-
# elif sort_by == "any_a":
|
|
804
|
-
# linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
|
|
805
|
-
# order = sch.leaves_list(linkage)
|
|
806
|
-
# else:
|
|
807
|
-
# raise ValueError(f"Unsupported sort_by option: {sort_by}")
|
|
808
|
-
|
|
809
|
-
# stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
# row_labels.extend([bin_label] * num_reads)
|
|
813
|
-
# bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
814
|
-
# last_idx += num_reads
|
|
815
|
-
# bin_boundaries.append(last_idx)
|
|
816
|
-
|
|
817
|
-
# gs_dim = 0
|
|
818
|
-
|
|
819
|
-
# if stacked_any_c:
|
|
820
|
-
# any_c_matrix = np.vstack(stacked_any_c)
|
|
821
|
-
# gpc_matrix = np.vstack(stacked_gpc)
|
|
822
|
-
# cpg_matrix = np.vstack(stacked_cpg)
|
|
823
|
-
# if any_c_matrix.size > 0:
|
|
824
|
-
# mean_gpc = methylation_fraction(gpc_matrix)
|
|
825
|
-
# mean_cpg = methylation_fraction(cpg_matrix)
|
|
826
|
-
# mean_any_c = methylation_fraction(any_c_matrix)
|
|
827
|
-
# gs_dim += 3
|
|
828
|
-
|
|
829
|
-
# if stacked_any_a:
|
|
830
|
-
# any_a_matrix = np.vstack(stacked_any_a)
|
|
831
|
-
# if any_a_matrix.size > 0:
|
|
832
|
-
# mean_any_a = methylation_fraction(any_a_matrix)
|
|
833
|
-
# gs_dim += 1
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
# fig = plt.figure(figsize=(18, 12))
|
|
837
|
-
# gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
|
|
838
|
-
# fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
|
|
839
|
-
# axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
840
|
-
# axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
841
|
-
|
|
842
|
-
# current_ax = 0
|
|
843
|
-
|
|
844
|
-
# if stacked_any_c:
|
|
845
|
-
# if any_c_matrix.size > 0:
|
|
846
|
-
# clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
|
|
847
|
-
# sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
|
|
848
|
-
# axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
|
|
849
|
-
# axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
|
|
850
|
-
# for boundary in bin_boundaries[:-1]:
|
|
851
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
852
|
-
# current_ax +=1
|
|
853
|
-
|
|
854
|
-
# clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
|
|
855
|
-
# sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
|
|
856
|
-
# axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
|
|
857
|
-
# axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
|
|
858
|
-
# for boundary in bin_boundaries[:-1]:
|
|
859
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
860
|
-
# current_ax +=1
|
|
861
|
-
|
|
862
|
-
# clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
|
|
863
|
-
# sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
|
|
864
|
-
# axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
|
|
865
|
-
# for boundary in bin_boundaries[:-1]:
|
|
866
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
867
|
-
# current_ax +=1
|
|
868
|
-
|
|
869
|
-
# results.append({
|
|
870
|
-
# "sample": sample,
|
|
871
|
-
# "ref": ref,
|
|
872
|
-
# "any_c_matrix": any_c_matrix,
|
|
873
|
-
# "gpc_matrix": gpc_matrix,
|
|
874
|
-
# "cpg_matrix": cpg_matrix,
|
|
875
|
-
# "row_labels": row_labels,
|
|
876
|
-
# "bin_labels": bin_labels,
|
|
877
|
-
# "bin_boundaries": bin_boundaries,
|
|
878
|
-
# "percentages": percentages
|
|
879
|
-
# })
|
|
880
|
-
|
|
881
|
-
# if stacked_any_a:
|
|
882
|
-
# if any_a_matrix.size > 0:
|
|
883
|
-
# clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
|
|
884
|
-
# sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
|
|
885
|
-
# axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
|
|
886
|
-
# axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
|
|
887
|
-
# for boundary in bin_boundaries[:-1]:
|
|
888
|
-
# axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
|
|
889
|
-
# current_ax +=1
|
|
890
|
-
|
|
891
|
-
# results.append({
|
|
892
|
-
# "sample": sample,
|
|
893
|
-
# "ref": ref,
|
|
894
|
-
# "any_a_matrix": any_a_matrix,
|
|
895
|
-
# "row_labels": row_labels,
|
|
896
|
-
# "bin_labels": bin_labels,
|
|
897
|
-
# "bin_boundaries": bin_boundaries,
|
|
898
|
-
# "percentages": percentages
|
|
899
|
-
# })
|
|
900
|
-
|
|
901
|
-
# plt.tight_layout()
|
|
902
|
-
|
|
903
|
-
# if save_path:
|
|
904
|
-
# save_name = f"{ref} — {sample}"
|
|
905
|
-
# os.makedirs(save_path, exist_ok=True)
|
|
906
|
-
# safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
907
|
-
# out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
908
|
-
# plt.savefig(out_file, dpi=300)
|
|
909
|
-
# print(f"Saved: {out_file}")
|
|
910
|
-
# plt.close()
|
|
911
|
-
# else:
|
|
912
|
-
# plt.show()
|
|
913
|
-
|
|
914
|
-
# print(f"Summary for {sample} - {ref}:")
|
|
915
|
-
# for bin_label, percent in percentages.items():
|
|
916
|
-
# print(f" - {bin_label}: {percent:.1f}%")
|
|
917
|
-
|
|
918
|
-
# adata.uns['clustermap_results'] = results
|
|
919
|
-
|
|
920
|
-
# except Exception as e:
|
|
921
|
-
# import traceback
|
|
922
|
-
# traceback.print_exc()
|
|
923
|
-
# continue
|
|
924
|
-
|
|
925
|
-
|
|
926
1023
|
def combined_raw_clustermap(
|
|
927
1024
|
adata,
|
|
928
1025
|
sample_col: str = "Sample_Names",
|
|
@@ -954,6 +1051,8 @@ def combined_raw_clustermap(
|
|
|
954
1051
|
xtick_rotation: int = 90,
|
|
955
1052
|
xtick_fontsize: int = 9,
|
|
956
1053
|
index_col_suffix: str | None = None,
|
|
1054
|
+
fill_nan_strategy: str = "value",
|
|
1055
|
+
fill_nan_value: float = -1,
|
|
957
1056
|
):
|
|
958
1057
|
"""
|
|
959
1058
|
Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
|
|
@@ -964,6 +1063,7 @@ def combined_raw_clustermap(
|
|
|
964
1063
|
- NaNs excluded from methylation denominators
|
|
965
1064
|
- var_names not forced to int
|
|
966
1065
|
- fixed count of x tick labels per block (controllable)
|
|
1066
|
+
- optional NaN fill strategy for clustering/plotting (in-memory only)
|
|
967
1067
|
- adata.uns updated once at end
|
|
968
1068
|
|
|
969
1069
|
Returns
|
|
@@ -971,6 +1071,8 @@ def combined_raw_clustermap(
|
|
|
971
1071
|
results : list[dict]
|
|
972
1072
|
One entry per (sample, ref) plot with matrices + bin metadata.
|
|
973
1073
|
"""
|
|
1074
|
+
if fill_nan_strategy not in {"none", "value", "col_mean"}:
|
|
1075
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
974
1076
|
|
|
975
1077
|
# Helper: build a True mask if filter is inactive or column missing
|
|
976
1078
|
def _mask_or_true(series_name: str, predicate):
|
|
@@ -1093,6 +1195,12 @@ def combined_raw_clustermap(
|
|
|
1093
1195
|
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
1094
1196
|
|
|
1095
1197
|
stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
|
|
1198
|
+
stacked_any_c_raw, stacked_gpc_raw, stacked_cpg_raw, stacked_any_a_raw = (
|
|
1199
|
+
[],
|
|
1200
|
+
[],
|
|
1201
|
+
[],
|
|
1202
|
+
[],
|
|
1203
|
+
)
|
|
1096
1204
|
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
1097
1205
|
percentages = {}
|
|
1098
1206
|
last_idx = 0
|
|
@@ -1117,31 +1225,58 @@ def combined_raw_clustermap(
|
|
|
1117
1225
|
order = np.argsort(subset_bin.obs[colname].values)
|
|
1118
1226
|
|
|
1119
1227
|
elif sort_by == "gpc" and num_gpc > 0:
|
|
1120
|
-
|
|
1121
|
-
subset_bin
|
|
1228
|
+
gpc_matrix = _layer_to_numpy(
|
|
1229
|
+
subset_bin,
|
|
1230
|
+
layer_gpc,
|
|
1231
|
+
gpc_sites,
|
|
1232
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1233
|
+
fill_nan_value=fill_nan_value,
|
|
1122
1234
|
)
|
|
1235
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
1123
1236
|
order = sch.leaves_list(linkage)
|
|
1124
1237
|
|
|
1125
1238
|
elif sort_by == "cpg" and num_cpg > 0:
|
|
1126
|
-
|
|
1127
|
-
subset_bin
|
|
1239
|
+
cpg_matrix = _layer_to_numpy(
|
|
1240
|
+
subset_bin,
|
|
1241
|
+
layer_cpg,
|
|
1242
|
+
cpg_sites,
|
|
1243
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1244
|
+
fill_nan_value=fill_nan_value,
|
|
1128
1245
|
)
|
|
1246
|
+
linkage = sch.linkage(cpg_matrix, method="ward")
|
|
1129
1247
|
order = sch.leaves_list(linkage)
|
|
1130
1248
|
|
|
1131
1249
|
elif sort_by == "c" and num_any_c > 0:
|
|
1132
|
-
|
|
1133
|
-
subset_bin
|
|
1250
|
+
any_c_matrix = _layer_to_numpy(
|
|
1251
|
+
subset_bin,
|
|
1252
|
+
layer_c,
|
|
1253
|
+
any_c_sites,
|
|
1254
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1255
|
+
fill_nan_value=fill_nan_value,
|
|
1134
1256
|
)
|
|
1257
|
+
linkage = sch.linkage(any_c_matrix, method="ward")
|
|
1135
1258
|
order = sch.leaves_list(linkage)
|
|
1136
1259
|
|
|
1137
1260
|
elif sort_by == "gpc_cpg":
|
|
1138
|
-
|
|
1261
|
+
gpc_matrix = _layer_to_numpy(
|
|
1262
|
+
subset_bin,
|
|
1263
|
+
layer_gpc,
|
|
1264
|
+
None,
|
|
1265
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1266
|
+
fill_nan_value=fill_nan_value,
|
|
1267
|
+
)
|
|
1268
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
1139
1269
|
order = sch.leaves_list(linkage)
|
|
1140
1270
|
|
|
1141
1271
|
elif sort_by == "a" and num_any_a > 0:
|
|
1142
|
-
|
|
1143
|
-
subset_bin
|
|
1272
|
+
any_a_matrix = _layer_to_numpy(
|
|
1273
|
+
subset_bin,
|
|
1274
|
+
layer_a,
|
|
1275
|
+
any_a_sites,
|
|
1276
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1277
|
+
fill_nan_value=fill_nan_value,
|
|
1144
1278
|
)
|
|
1279
|
+
linkage = sch.linkage(any_a_matrix, method="ward")
|
|
1145
1280
|
order = sch.leaves_list(linkage)
|
|
1146
1281
|
|
|
1147
1282
|
elif sort_by == "none":
|
|
@@ -1154,13 +1289,81 @@ def combined_raw_clustermap(
|
|
|
1154
1289
|
|
|
1155
1290
|
# stack consistently
|
|
1156
1291
|
if include_any_c and num_any_c > 0:
|
|
1157
|
-
stacked_any_c.append(
|
|
1292
|
+
stacked_any_c.append(
|
|
1293
|
+
_layer_to_numpy(
|
|
1294
|
+
subset_bin,
|
|
1295
|
+
layer_c,
|
|
1296
|
+
any_c_sites,
|
|
1297
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1298
|
+
fill_nan_value=fill_nan_value,
|
|
1299
|
+
)
|
|
1300
|
+
)
|
|
1301
|
+
stacked_any_c_raw.append(
|
|
1302
|
+
_layer_to_numpy(
|
|
1303
|
+
subset_bin,
|
|
1304
|
+
layer_c,
|
|
1305
|
+
any_c_sites,
|
|
1306
|
+
fill_nan_strategy="none",
|
|
1307
|
+
fill_nan_value=fill_nan_value,
|
|
1308
|
+
)
|
|
1309
|
+
)
|
|
1158
1310
|
if include_any_c and num_gpc > 0:
|
|
1159
|
-
stacked_gpc.append(
|
|
1311
|
+
stacked_gpc.append(
|
|
1312
|
+
_layer_to_numpy(
|
|
1313
|
+
subset_bin,
|
|
1314
|
+
layer_gpc,
|
|
1315
|
+
gpc_sites,
|
|
1316
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1317
|
+
fill_nan_value=fill_nan_value,
|
|
1318
|
+
)
|
|
1319
|
+
)
|
|
1320
|
+
stacked_gpc_raw.append(
|
|
1321
|
+
_layer_to_numpy(
|
|
1322
|
+
subset_bin,
|
|
1323
|
+
layer_gpc,
|
|
1324
|
+
gpc_sites,
|
|
1325
|
+
fill_nan_strategy="none",
|
|
1326
|
+
fill_nan_value=fill_nan_value,
|
|
1327
|
+
)
|
|
1328
|
+
)
|
|
1160
1329
|
if include_any_c and num_cpg > 0:
|
|
1161
|
-
stacked_cpg.append(
|
|
1330
|
+
stacked_cpg.append(
|
|
1331
|
+
_layer_to_numpy(
|
|
1332
|
+
subset_bin,
|
|
1333
|
+
layer_cpg,
|
|
1334
|
+
cpg_sites,
|
|
1335
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1336
|
+
fill_nan_value=fill_nan_value,
|
|
1337
|
+
)
|
|
1338
|
+
)
|
|
1339
|
+
stacked_cpg_raw.append(
|
|
1340
|
+
_layer_to_numpy(
|
|
1341
|
+
subset_bin,
|
|
1342
|
+
layer_cpg,
|
|
1343
|
+
cpg_sites,
|
|
1344
|
+
fill_nan_strategy="none",
|
|
1345
|
+
fill_nan_value=fill_nan_value,
|
|
1346
|
+
)
|
|
1347
|
+
)
|
|
1162
1348
|
if include_any_a and num_any_a > 0:
|
|
1163
|
-
stacked_any_a.append(
|
|
1349
|
+
stacked_any_a.append(
|
|
1350
|
+
_layer_to_numpy(
|
|
1351
|
+
subset_bin,
|
|
1352
|
+
layer_a,
|
|
1353
|
+
any_a_sites,
|
|
1354
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1355
|
+
fill_nan_value=fill_nan_value,
|
|
1356
|
+
)
|
|
1357
|
+
)
|
|
1358
|
+
stacked_any_a_raw.append(
|
|
1359
|
+
_layer_to_numpy(
|
|
1360
|
+
subset_bin,
|
|
1361
|
+
layer_a,
|
|
1362
|
+
any_a_sites,
|
|
1363
|
+
fill_nan_strategy="none",
|
|
1364
|
+
fill_nan_value=fill_nan_value,
|
|
1365
|
+
)
|
|
1366
|
+
)
|
|
1164
1367
|
|
|
1165
1368
|
row_labels.extend([bin_label] * num_reads)
|
|
1166
1369
|
bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
@@ -1174,12 +1377,31 @@ def combined_raw_clustermap(
|
|
|
1174
1377
|
|
|
1175
1378
|
if include_any_c and stacked_any_c:
|
|
1176
1379
|
any_c_matrix = np.vstack(stacked_any_c)
|
|
1380
|
+
any_c_matrix_raw = np.vstack(stacked_any_c_raw)
|
|
1177
1381
|
gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
|
|
1382
|
+
gpc_matrix_raw = (
|
|
1383
|
+
np.vstack(stacked_gpc_raw) if stacked_gpc_raw else np.empty((0, 0))
|
|
1384
|
+
)
|
|
1178
1385
|
cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
|
|
1386
|
+
cpg_matrix_raw = (
|
|
1387
|
+
np.vstack(stacked_cpg_raw) if stacked_cpg_raw else np.empty((0, 0))
|
|
1388
|
+
)
|
|
1179
1389
|
|
|
1180
|
-
mean_any_c =
|
|
1181
|
-
|
|
1182
|
-
|
|
1390
|
+
mean_any_c = (
|
|
1391
|
+
_methylation_fraction_for_layer(any_c_matrix_raw, layer_c)
|
|
1392
|
+
if any_c_matrix_raw.size
|
|
1393
|
+
else None
|
|
1394
|
+
)
|
|
1395
|
+
mean_gpc = (
|
|
1396
|
+
_methylation_fraction_for_layer(gpc_matrix_raw, layer_gpc)
|
|
1397
|
+
if gpc_matrix_raw.size
|
|
1398
|
+
else None
|
|
1399
|
+
)
|
|
1400
|
+
mean_cpg = (
|
|
1401
|
+
_methylation_fraction_for_layer(cpg_matrix_raw, layer_cpg)
|
|
1402
|
+
if cpg_matrix_raw.size
|
|
1403
|
+
else None
|
|
1404
|
+
)
|
|
1183
1405
|
|
|
1184
1406
|
if any_c_matrix.size:
|
|
1185
1407
|
blocks.append(
|
|
@@ -1220,7 +1442,12 @@ def combined_raw_clustermap(
|
|
|
1220
1442
|
|
|
1221
1443
|
if include_any_a and stacked_any_a:
|
|
1222
1444
|
any_a_matrix = np.vstack(stacked_any_a)
|
|
1223
|
-
|
|
1445
|
+
any_a_matrix_raw = np.vstack(stacked_any_a_raw)
|
|
1446
|
+
mean_any_a = (
|
|
1447
|
+
_methylation_fraction_for_layer(any_a_matrix_raw, layer_a)
|
|
1448
|
+
if any_a_matrix_raw.size
|
|
1449
|
+
else None
|
|
1450
|
+
)
|
|
1224
1451
|
if any_a_matrix.size:
|
|
1225
1452
|
blocks.append(
|
|
1226
1453
|
dict(
|
|
@@ -1320,112 +1547,1530 @@ def combined_raw_clustermap(
|
|
|
1320
1547
|
return results
|
|
1321
1548
|
|
|
1322
1549
|
|
|
1323
|
-
def
|
|
1550
|
+
def combined_hmm_length_clustermap(
|
|
1324
1551
|
adata,
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1552
|
+
sample_col: str = "Sample_Names",
|
|
1553
|
+
reference_col: str = "Reference_strand",
|
|
1554
|
+
length_layer: str = "hmm_combined_lengths",
|
|
1555
|
+
layer_gpc: str = "nan0_0minus1",
|
|
1556
|
+
layer_cpg: str = "nan0_0minus1",
|
|
1557
|
+
layer_c: str = "nan0_0minus1",
|
|
1558
|
+
layer_a: str = "nan0_0minus1",
|
|
1559
|
+
cmap_lengths: Any = "Greens",
|
|
1560
|
+
cmap_gpc: str = "coolwarm",
|
|
1561
|
+
cmap_cpg: str = "viridis",
|
|
1562
|
+
cmap_c: str = "coolwarm",
|
|
1563
|
+
cmap_a: str = "coolwarm",
|
|
1564
|
+
min_quality: int = 20,
|
|
1565
|
+
min_length: int = 200,
|
|
1566
|
+
min_mapped_length_to_reference_length_ratio: float = 0.8,
|
|
1567
|
+
min_position_valid_fraction: float = 0.5,
|
|
1568
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
1569
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
1570
|
+
save_path: str | Path | None = None,
|
|
1571
|
+
sort_by: str = "gpc",
|
|
1572
|
+
bins: Optional[Dict[str, Any]] = None,
|
|
1573
|
+
deaminase: bool = False,
|
|
1574
|
+
min_signal: float = 0.0,
|
|
1575
|
+
n_xticks_lengths: int = 10,
|
|
1576
|
+
n_xticks_any_c: int = 8,
|
|
1577
|
+
n_xticks_gpc: int = 8,
|
|
1578
|
+
n_xticks_cpg: int = 8,
|
|
1579
|
+
n_xticks_a: int = 8,
|
|
1580
|
+
index_col_suffix: str | None = None,
|
|
1581
|
+
fill_nan_strategy: str = "value",
|
|
1582
|
+
fill_nan_value: float = -1,
|
|
1583
|
+
length_feature_ranges: Optional[Sequence[Tuple[int, int, Any]]] = None,
|
|
1341
1584
|
):
|
|
1342
1585
|
"""
|
|
1343
|
-
|
|
1344
|
-
positional mean (mean across reads) for each layer listed.
|
|
1345
|
-
|
|
1346
|
-
Parameters
|
|
1347
|
-
----------
|
|
1348
|
-
adata : AnnData
|
|
1349
|
-
Input annotated data (expects obs columns sample_col and ref_col).
|
|
1350
|
-
layers : list[str] | None
|
|
1351
|
-
Which adata.layers to plot. If None, attempts to autodetect layers whose
|
|
1352
|
-
matrices look like "HMM" outputs (else will error). If None and layers
|
|
1353
|
-
cannot be found, user must pass a list.
|
|
1354
|
-
sample_col, ref_col : str
|
|
1355
|
-
obs columns used to group rows.
|
|
1356
|
-
samples, references : optional lists
|
|
1357
|
-
explicit ordering of samples / references. If None, categories in adata.obs are used.
|
|
1358
|
-
window : int
|
|
1359
|
-
rolling window size (odd recommended). If window <= 1, no smoothing applied.
|
|
1360
|
-
min_periods : int
|
|
1361
|
-
min periods param for pd.Series.rolling.
|
|
1362
|
-
center : bool
|
|
1363
|
-
center the rolling window.
|
|
1364
|
-
rows_per_page : int
|
|
1365
|
-
paginate rows per page into multiple figures if needed.
|
|
1366
|
-
figsize_per_cell : (w,h)
|
|
1367
|
-
per-subplot size in inches.
|
|
1368
|
-
dpi : int
|
|
1369
|
-
figure dpi when saving.
|
|
1370
|
-
output_dir : str | None
|
|
1371
|
-
directory to save pages; created if necessary. If None and save=True, uses cwd.
|
|
1372
|
-
save : bool
|
|
1373
|
-
whether to save PNG files.
|
|
1374
|
-
show_raw : bool
|
|
1375
|
-
draw unsmoothed mean as faint line under smoothed curve.
|
|
1376
|
-
cmap : str
|
|
1377
|
-
matplotlib colormap for layer lines.
|
|
1378
|
-
use_var_coords : bool
|
|
1379
|
-
if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
|
|
1586
|
+
Plot clustermaps for length-encoded HMM feature layers with optional subclass colors.
|
|
1380
1587
|
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
saved_files : list[str]
|
|
1384
|
-
list of saved filenames (may be empty if save=False).
|
|
1588
|
+
Length-based feature ranges map integer lengths into subclass colors for accessible
|
|
1589
|
+
and footprint layers. Raw methylation panels are included when available.
|
|
1385
1590
|
"""
|
|
1591
|
+
if fill_nan_strategy not in {"none", "value", "col_mean"}:
|
|
1592
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
1386
1593
|
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
)
|
|
1594
|
+
def pick_xticks(labels: np.ndarray, n_ticks: int):
|
|
1595
|
+
"""Pick tick indices/labels from an array."""
|
|
1596
|
+
if labels.size == 0:
|
|
1597
|
+
return [], []
|
|
1598
|
+
idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
|
|
1599
|
+
idx = np.unique(idx)
|
|
1600
|
+
return idx.tolist(), labels[idx].tolist()
|
|
1392
1601
|
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
|
|
1602
|
+
def _mask_or_true(series_name: str, predicate):
|
|
1603
|
+
"""Return a mask from predicate or an all-True mask."""
|
|
1604
|
+
if series_name not in adata.obs:
|
|
1605
|
+
return pd.Series(True, index=adata.obs.index)
|
|
1606
|
+
s = adata.obs[series_name]
|
|
1607
|
+
try:
|
|
1608
|
+
return predicate(s)
|
|
1609
|
+
except Exception:
|
|
1610
|
+
return pd.Series(True, index=adata.obs.index)
|
|
1401
1611
|
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
rseries = rseries.astype("category")
|
|
1406
|
-
refs_all = list(rseries.cat.categories)
|
|
1407
|
-
else:
|
|
1408
|
-
refs_all = list(references)
|
|
1612
|
+
results = []
|
|
1613
|
+
signal_type = "deamination" if deaminase else "methylation"
|
|
1614
|
+
feature_ranges = tuple(length_feature_ranges or ())
|
|
1409
1615
|
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
|
|
1413
|
-
|
|
1414
|
-
|
|
1415
|
-
|
|
1616
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
1617
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
1618
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
1619
|
+
qmask = _mask_or_true(
|
|
1620
|
+
"read_quality",
|
|
1621
|
+
(lambda s: s >= float(min_quality))
|
|
1622
|
+
if (min_quality is not None)
|
|
1623
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1416
1624
|
)
|
|
1417
|
-
|
|
1418
|
-
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
|
|
1428
|
-
|
|
1625
|
+
lm_mask = _mask_or_true(
|
|
1626
|
+
"mapped_length",
|
|
1627
|
+
(lambda s: s >= float(min_length))
|
|
1628
|
+
if (min_length is not None)
|
|
1629
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1630
|
+
)
|
|
1631
|
+
lrr_mask = _mask_or_true(
|
|
1632
|
+
"mapped_length_to_reference_length_ratio",
|
|
1633
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
1634
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
1635
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1636
|
+
)
|
|
1637
|
+
|
|
1638
|
+
demux_mask = _mask_or_true(
|
|
1639
|
+
"demux_type",
|
|
1640
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
1641
|
+
if (demux_types is not None)
|
|
1642
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
1643
|
+
)
|
|
1644
|
+
|
|
1645
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
1646
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
1647
|
+
|
|
1648
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
1649
|
+
|
|
1650
|
+
if not bool(row_mask.any()):
|
|
1651
|
+
print(
|
|
1652
|
+
f"No reads for {display_sample} - {ref} after read quality and length filtering"
|
|
1653
|
+
)
|
|
1654
|
+
continue
|
|
1655
|
+
|
|
1656
|
+
try:
|
|
1657
|
+
subset = adata[row_mask, :].copy()
|
|
1658
|
+
|
|
1659
|
+
if min_position_valid_fraction is not None:
|
|
1660
|
+
valid_key = f"{ref}_valid_fraction"
|
|
1661
|
+
if valid_key in subset.var:
|
|
1662
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
1663
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
1664
|
+
if col_mask.any():
|
|
1665
|
+
subset = subset[:, col_mask].copy()
|
|
1666
|
+
else:
|
|
1667
|
+
print(
|
|
1668
|
+
f"No positions left after valid_fraction filter for {display_sample} - {ref}"
|
|
1669
|
+
)
|
|
1670
|
+
continue
|
|
1671
|
+
|
|
1672
|
+
if subset.shape[0] == 0:
|
|
1673
|
+
print(f"No reads left after filtering for {display_sample} - {ref}")
|
|
1674
|
+
continue
|
|
1675
|
+
|
|
1676
|
+
if bins is None:
|
|
1677
|
+
bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
|
|
1678
|
+
else:
|
|
1679
|
+
bins_temp = bins
|
|
1680
|
+
|
|
1681
|
+
def _sites(*keys):
|
|
1682
|
+
"""Return indices for the first matching site key."""
|
|
1683
|
+
for k in keys:
|
|
1684
|
+
if k in subset.var:
|
|
1685
|
+
return np.where(subset.var[k].values)[0]
|
|
1686
|
+
return np.array([], dtype=int)
|
|
1687
|
+
|
|
1688
|
+
gpc_sites = _sites(f"{ref}_GpC_site")
|
|
1689
|
+
cpg_sites = _sites(f"{ref}_CpG_site")
|
|
1690
|
+
any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
|
|
1691
|
+
any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
|
|
1692
|
+
|
|
1693
|
+
length_sites = np.arange(subset.n_vars, dtype=int)
|
|
1694
|
+
length_labels = _select_labels(subset, length_sites, ref, index_col_suffix)
|
|
1695
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
1696
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
1697
|
+
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
1698
|
+
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
1699
|
+
|
|
1700
|
+
stacked_lengths = []
|
|
1701
|
+
stacked_lengths_raw = []
|
|
1702
|
+
stacked_any_c = []
|
|
1703
|
+
stacked_any_c_raw = []
|
|
1704
|
+
stacked_gpc = []
|
|
1705
|
+
stacked_gpc_raw = []
|
|
1706
|
+
stacked_cpg = []
|
|
1707
|
+
stacked_cpg_raw = []
|
|
1708
|
+
stacked_any_a = []
|
|
1709
|
+
stacked_any_a_raw = []
|
|
1710
|
+
|
|
1711
|
+
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
1712
|
+
total_reads = subset.n_obs
|
|
1713
|
+
percentages = {}
|
|
1714
|
+
last_idx = 0
|
|
1715
|
+
|
|
1716
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
1717
|
+
sb = subset[bin_filter].copy()
|
|
1718
|
+
n = sb.n_obs
|
|
1719
|
+
if n == 0:
|
|
1720
|
+
continue
|
|
1721
|
+
|
|
1722
|
+
pct = (n / total_reads) * 100 if total_reads else 0
|
|
1723
|
+
percentages[bin_label] = pct
|
|
1724
|
+
|
|
1725
|
+
if sort_by.startswith("obs:"):
|
|
1726
|
+
colname = sort_by.split("obs:")[1]
|
|
1727
|
+
order = np.argsort(sb.obs[colname].values)
|
|
1728
|
+
elif sort_by == "gpc" and gpc_sites.size:
|
|
1729
|
+
gpc_matrix = _layer_to_numpy(
|
|
1730
|
+
sb,
|
|
1731
|
+
layer_gpc,
|
|
1732
|
+
gpc_sites,
|
|
1733
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1734
|
+
fill_nan_value=fill_nan_value,
|
|
1735
|
+
)
|
|
1736
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
1737
|
+
order = sch.leaves_list(linkage)
|
|
1738
|
+
elif sort_by == "cpg" and cpg_sites.size:
|
|
1739
|
+
cpg_matrix = _layer_to_numpy(
|
|
1740
|
+
sb,
|
|
1741
|
+
layer_cpg,
|
|
1742
|
+
cpg_sites,
|
|
1743
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1744
|
+
fill_nan_value=fill_nan_value,
|
|
1745
|
+
)
|
|
1746
|
+
linkage = sch.linkage(cpg_matrix, method="ward")
|
|
1747
|
+
order = sch.leaves_list(linkage)
|
|
1748
|
+
elif sort_by == "c" and any_c_sites.size:
|
|
1749
|
+
any_c_matrix = _layer_to_numpy(
|
|
1750
|
+
sb,
|
|
1751
|
+
layer_c,
|
|
1752
|
+
any_c_sites,
|
|
1753
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1754
|
+
fill_nan_value=fill_nan_value,
|
|
1755
|
+
)
|
|
1756
|
+
linkage = sch.linkage(any_c_matrix, method="ward")
|
|
1757
|
+
order = sch.leaves_list(linkage)
|
|
1758
|
+
elif sort_by == "a" and any_a_sites.size:
|
|
1759
|
+
any_a_matrix = _layer_to_numpy(
|
|
1760
|
+
sb,
|
|
1761
|
+
layer_a,
|
|
1762
|
+
any_a_sites,
|
|
1763
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1764
|
+
fill_nan_value=fill_nan_value,
|
|
1765
|
+
)
|
|
1766
|
+
linkage = sch.linkage(any_a_matrix, method="ward")
|
|
1767
|
+
order = sch.leaves_list(linkage)
|
|
1768
|
+
elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
|
|
1769
|
+
gpc_matrix = _layer_to_numpy(
|
|
1770
|
+
sb,
|
|
1771
|
+
layer_gpc,
|
|
1772
|
+
None,
|
|
1773
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1774
|
+
fill_nan_value=fill_nan_value,
|
|
1775
|
+
)
|
|
1776
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
1777
|
+
order = sch.leaves_list(linkage)
|
|
1778
|
+
elif sort_by == "hmm" and length_sites.size:
|
|
1779
|
+
length_matrix = _layer_to_numpy(
|
|
1780
|
+
sb,
|
|
1781
|
+
length_layer,
|
|
1782
|
+
length_sites,
|
|
1783
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1784
|
+
fill_nan_value=fill_nan_value,
|
|
1785
|
+
)
|
|
1786
|
+
linkage = sch.linkage(length_matrix, method="ward")
|
|
1787
|
+
order = sch.leaves_list(linkage)
|
|
1788
|
+
else:
|
|
1789
|
+
order = np.arange(n)
|
|
1790
|
+
|
|
1791
|
+
sb = sb[order]
|
|
1792
|
+
|
|
1793
|
+
stacked_lengths.append(
|
|
1794
|
+
_layer_to_numpy(
|
|
1795
|
+
sb,
|
|
1796
|
+
length_layer,
|
|
1797
|
+
None,
|
|
1798
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1799
|
+
fill_nan_value=fill_nan_value,
|
|
1800
|
+
)
|
|
1801
|
+
)
|
|
1802
|
+
stacked_lengths_raw.append(
|
|
1803
|
+
_layer_to_numpy(
|
|
1804
|
+
sb,
|
|
1805
|
+
length_layer,
|
|
1806
|
+
None,
|
|
1807
|
+
fill_nan_strategy="none",
|
|
1808
|
+
fill_nan_value=fill_nan_value,
|
|
1809
|
+
)
|
|
1810
|
+
)
|
|
1811
|
+
if any_c_sites.size:
|
|
1812
|
+
stacked_any_c.append(
|
|
1813
|
+
_layer_to_numpy(
|
|
1814
|
+
sb,
|
|
1815
|
+
layer_c,
|
|
1816
|
+
any_c_sites,
|
|
1817
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1818
|
+
fill_nan_value=fill_nan_value,
|
|
1819
|
+
)
|
|
1820
|
+
)
|
|
1821
|
+
stacked_any_c_raw.append(
|
|
1822
|
+
_layer_to_numpy(
|
|
1823
|
+
sb,
|
|
1824
|
+
layer_c,
|
|
1825
|
+
any_c_sites,
|
|
1826
|
+
fill_nan_strategy="none",
|
|
1827
|
+
fill_nan_value=fill_nan_value,
|
|
1828
|
+
)
|
|
1829
|
+
)
|
|
1830
|
+
if gpc_sites.size:
|
|
1831
|
+
stacked_gpc.append(
|
|
1832
|
+
_layer_to_numpy(
|
|
1833
|
+
sb,
|
|
1834
|
+
layer_gpc,
|
|
1835
|
+
gpc_sites,
|
|
1836
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1837
|
+
fill_nan_value=fill_nan_value,
|
|
1838
|
+
)
|
|
1839
|
+
)
|
|
1840
|
+
stacked_gpc_raw.append(
|
|
1841
|
+
_layer_to_numpy(
|
|
1842
|
+
sb,
|
|
1843
|
+
layer_gpc,
|
|
1844
|
+
gpc_sites,
|
|
1845
|
+
fill_nan_strategy="none",
|
|
1846
|
+
fill_nan_value=fill_nan_value,
|
|
1847
|
+
)
|
|
1848
|
+
)
|
|
1849
|
+
if cpg_sites.size:
|
|
1850
|
+
stacked_cpg.append(
|
|
1851
|
+
_layer_to_numpy(
|
|
1852
|
+
sb,
|
|
1853
|
+
layer_cpg,
|
|
1854
|
+
cpg_sites,
|
|
1855
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1856
|
+
fill_nan_value=fill_nan_value,
|
|
1857
|
+
)
|
|
1858
|
+
)
|
|
1859
|
+
stacked_cpg_raw.append(
|
|
1860
|
+
_layer_to_numpy(
|
|
1861
|
+
sb,
|
|
1862
|
+
layer_cpg,
|
|
1863
|
+
cpg_sites,
|
|
1864
|
+
fill_nan_strategy="none",
|
|
1865
|
+
fill_nan_value=fill_nan_value,
|
|
1866
|
+
)
|
|
1867
|
+
)
|
|
1868
|
+
if any_a_sites.size:
|
|
1869
|
+
stacked_any_a.append(
|
|
1870
|
+
_layer_to_numpy(
|
|
1871
|
+
sb,
|
|
1872
|
+
layer_a,
|
|
1873
|
+
any_a_sites,
|
|
1874
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
1875
|
+
fill_nan_value=fill_nan_value,
|
|
1876
|
+
)
|
|
1877
|
+
)
|
|
1878
|
+
stacked_any_a_raw.append(
|
|
1879
|
+
_layer_to_numpy(
|
|
1880
|
+
sb,
|
|
1881
|
+
layer_a,
|
|
1882
|
+
any_a_sites,
|
|
1883
|
+
fill_nan_strategy="none",
|
|
1884
|
+
fill_nan_value=fill_nan_value,
|
|
1885
|
+
)
|
|
1886
|
+
)
|
|
1887
|
+
|
|
1888
|
+
row_labels.extend([bin_label] * n)
|
|
1889
|
+
bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
|
|
1890
|
+
last_idx += n
|
|
1891
|
+
bin_boundaries.append(last_idx)
|
|
1892
|
+
|
|
1893
|
+
length_matrix = np.vstack(stacked_lengths)
|
|
1894
|
+
length_matrix_raw = np.vstack(stacked_lengths_raw)
|
|
1895
|
+
capped_lengths = np.where(length_matrix_raw > 1, 1.0, length_matrix_raw)
|
|
1896
|
+
mean_lengths = np.nanmean(capped_lengths, axis=0)
|
|
1897
|
+
length_plot_matrix = length_matrix_raw
|
|
1898
|
+
length_plot_cmap = cmap_lengths
|
|
1899
|
+
length_plot_norm = None
|
|
1900
|
+
|
|
1901
|
+
if feature_ranges:
|
|
1902
|
+
length_plot_matrix = _map_length_matrix_to_subclasses(
|
|
1903
|
+
length_matrix_raw, feature_ranges
|
|
1904
|
+
)
|
|
1905
|
+
length_plot_cmap, length_plot_norm = _build_length_feature_cmap(feature_ranges)
|
|
1906
|
+
|
|
1907
|
+
panels = [
|
|
1908
|
+
(
|
|
1909
|
+
f"HMM lengths - {length_layer}",
|
|
1910
|
+
length_plot_matrix,
|
|
1911
|
+
length_labels,
|
|
1912
|
+
length_plot_cmap,
|
|
1913
|
+
mean_lengths,
|
|
1914
|
+
n_xticks_lengths,
|
|
1915
|
+
length_plot_norm,
|
|
1916
|
+
),
|
|
1917
|
+
]
|
|
1918
|
+
|
|
1919
|
+
if stacked_any_c:
|
|
1920
|
+
m = np.vstack(stacked_any_c)
|
|
1921
|
+
m_raw = np.vstack(stacked_any_c_raw)
|
|
1922
|
+
panels.append(
|
|
1923
|
+
(
|
|
1924
|
+
"C",
|
|
1925
|
+
m,
|
|
1926
|
+
any_c_labels,
|
|
1927
|
+
cmap_c,
|
|
1928
|
+
_methylation_fraction_for_layer(m_raw, layer_c),
|
|
1929
|
+
n_xticks_any_c,
|
|
1930
|
+
None,
|
|
1931
|
+
)
|
|
1932
|
+
)
|
|
1933
|
+
|
|
1934
|
+
if stacked_gpc:
|
|
1935
|
+
m = np.vstack(stacked_gpc)
|
|
1936
|
+
m_raw = np.vstack(stacked_gpc_raw)
|
|
1937
|
+
panels.append(
|
|
1938
|
+
(
|
|
1939
|
+
"GpC",
|
|
1940
|
+
m,
|
|
1941
|
+
gpc_labels,
|
|
1942
|
+
cmap_gpc,
|
|
1943
|
+
_methylation_fraction_for_layer(m_raw, layer_gpc),
|
|
1944
|
+
n_xticks_gpc,
|
|
1945
|
+
None,
|
|
1946
|
+
)
|
|
1947
|
+
)
|
|
1948
|
+
|
|
1949
|
+
if stacked_cpg:
|
|
1950
|
+
m = np.vstack(stacked_cpg)
|
|
1951
|
+
m_raw = np.vstack(stacked_cpg_raw)
|
|
1952
|
+
panels.append(
|
|
1953
|
+
(
|
|
1954
|
+
"CpG",
|
|
1955
|
+
m,
|
|
1956
|
+
cpg_labels,
|
|
1957
|
+
cmap_cpg,
|
|
1958
|
+
_methylation_fraction_for_layer(m_raw, layer_cpg),
|
|
1959
|
+
n_xticks_cpg,
|
|
1960
|
+
None,
|
|
1961
|
+
)
|
|
1962
|
+
)
|
|
1963
|
+
|
|
1964
|
+
if stacked_any_a:
|
|
1965
|
+
m = np.vstack(stacked_any_a)
|
|
1966
|
+
m_raw = np.vstack(stacked_any_a_raw)
|
|
1967
|
+
panels.append(
|
|
1968
|
+
(
|
|
1969
|
+
"A",
|
|
1970
|
+
m,
|
|
1971
|
+
any_a_labels,
|
|
1972
|
+
cmap_a,
|
|
1973
|
+
_methylation_fraction_for_layer(m_raw, layer_a),
|
|
1974
|
+
n_xticks_a,
|
|
1975
|
+
None,
|
|
1976
|
+
)
|
|
1977
|
+
)
|
|
1978
|
+
|
|
1979
|
+
n_panels = len(panels)
|
|
1980
|
+
fig = plt.figure(figsize=(4.5 * n_panels, 10))
|
|
1981
|
+
gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
|
|
1982
|
+
fig.suptitle(
|
|
1983
|
+
f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
|
|
1984
|
+
)
|
|
1985
|
+
|
|
1986
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
|
|
1987
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
|
|
1988
|
+
|
|
1989
|
+
for i, (name, matrix, labels, cmap, mean_vec, n_ticks, norm) in enumerate(panels):
|
|
1990
|
+
clean_barplot(axes_bar[i], mean_vec, name)
|
|
1991
|
+
|
|
1992
|
+
heatmap_kwargs = dict(
|
|
1993
|
+
cmap=cmap,
|
|
1994
|
+
ax=axes_heat[i],
|
|
1995
|
+
yticklabels=False,
|
|
1996
|
+
cbar=False,
|
|
1997
|
+
)
|
|
1998
|
+
if norm is not None:
|
|
1999
|
+
heatmap_kwargs["norm"] = norm
|
|
2000
|
+
sns.heatmap(matrix, **heatmap_kwargs)
|
|
2001
|
+
|
|
2002
|
+
xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
|
|
2003
|
+
axes_heat[i].set_xticks(xtick_pos)
|
|
2004
|
+
axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
|
|
2005
|
+
|
|
2006
|
+
for boundary in bin_boundaries[:-1]:
|
|
2007
|
+
axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
|
|
2008
|
+
|
|
2009
|
+
plt.tight_layout()
|
|
2010
|
+
|
|
2011
|
+
if save_path:
|
|
2012
|
+
save_path = Path(save_path)
|
|
2013
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
2014
|
+
safe_name = f"{ref}__{sample}".replace("/", "_")
|
|
2015
|
+
out_file = save_path / f"{safe_name}.png"
|
|
2016
|
+
plt.savefig(out_file, dpi=300)
|
|
2017
|
+
plt.close(fig)
|
|
2018
|
+
else:
|
|
2019
|
+
plt.show()
|
|
2020
|
+
|
|
2021
|
+
results.append((sample, ref))
|
|
2022
|
+
|
|
2023
|
+
except Exception:
|
|
2024
|
+
import traceback
|
|
2025
|
+
|
|
2026
|
+
traceback.print_exc()
|
|
2027
|
+
print(f"Failed {sample} - {ref} - {length_layer}")
|
|
2028
|
+
|
|
2029
|
+
return results
|
|
2030
|
+
|
|
2031
|
+
|
|
2032
|
+
def make_row_colors(meta: pd.DataFrame) -> pd.DataFrame:
|
|
2033
|
+
"""
|
|
2034
|
+
Convert metadata columns to RGB colors without invoking pandas Categorical.map
|
|
2035
|
+
(MultiIndex-safe, category-safe).
|
|
2036
|
+
"""
|
|
2037
|
+
row_colors = pd.DataFrame(index=meta.index)
|
|
2038
|
+
|
|
2039
|
+
for col in meta.columns:
|
|
2040
|
+
# Force plain python objects to avoid ExtensionArray/Categorical behavior
|
|
2041
|
+
s = meta[col].astype("object")
|
|
2042
|
+
|
|
2043
|
+
def _to_label(x):
|
|
2044
|
+
if x is None:
|
|
2045
|
+
return "NA"
|
|
2046
|
+
if isinstance(x, float) and np.isnan(x):
|
|
2047
|
+
return "NA"
|
|
2048
|
+
# If a MultiIndex object is stored in a cell (rare), bucket it
|
|
2049
|
+
if isinstance(x, pd.MultiIndex):
|
|
2050
|
+
return "MultiIndex"
|
|
2051
|
+
# Tuples are common when MultiIndex-ish things get stored as values
|
|
2052
|
+
if isinstance(x, tuple):
|
|
2053
|
+
return "|".join(map(str, x))
|
|
2054
|
+
return str(x)
|
|
2055
|
+
|
|
2056
|
+
labels = np.array([_to_label(x) for x in s.to_numpy()], dtype=object)
|
|
2057
|
+
uniq = pd.unique(labels)
|
|
2058
|
+
palette = dict(zip(uniq, sns.color_palette(n_colors=len(uniq))))
|
|
2059
|
+
|
|
2060
|
+
# Map via python loop -> no pandas map machinery
|
|
2061
|
+
colors = [palette.get(lbl, (0.7, 0.7, 0.7)) for lbl in labels]
|
|
2062
|
+
row_colors[col] = colors
|
|
2063
|
+
|
|
2064
|
+
return row_colors
|
|
2065
|
+
|
|
2066
|
+
|
|
2067
|
+
def plot_rolling_nn_and_layer(
    subset,
    obsm_key: str = "rolling_nn_dist",
    layer_key: str = "nan0_0minus1",
    meta_cols=("Reference_strand", "Sample"),
    col_cluster: bool = False,
    fill_nn_with_colmax: bool = True,
    fill_layer_value: float = 0.0,
    drop_all_nan_windows: bool = True,
    max_nan_fraction: float | None = None,
    var_valid_fraction_col: str | None = None,
    var_nan_fraction_col: str | None = None,
    figsize=(14, 10),
    right_panel_var_mask=None,  # optional boolean mask over subset.var to reduce width
    robust=True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name=None,
):
    """
    1) Cluster rows by subset.obsm[obsm_key] (rolling NN distances)
    2) Plot two heatmaps side-by-side in the SAME row order, with mean barplots above:
       - left: rolling NN distance matrix
       - right: subset.layers[layer_key] matrix

    Reads whose rolling NN distances are all NaN are dropped from both panels.
    Handles categorical/MultiIndex issues in metadata coloring.

    Args:
        subset: AnnData subset with rolling NN distances stored in ``obsm``.
        obsm_key: Key in ``subset.obsm`` containing rolling NN distances.
        layer_key: Layer name to plot alongside rolling NN distances.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the rolling NN clustermap.
        fill_nn_with_colmax: Fill NaNs in rolling NN distances with per-column max values.
        fill_layer_value: Fill NaNs in the layer heatmap with this value.
        drop_all_nan_windows: Drop rolling windows that are all NaN.
        max_nan_fraction: Maximum allowed NaN fraction per position (filtering columns).
        var_valid_fraction_col: ``subset.var`` column with valid fractions (1 - NaN fraction).
        var_nan_fraction_col: ``subset.var`` column with NaN fractions.
        figsize: Figure size for the combined plot.
        right_panel_var_mask: Optional boolean mask over ``subset.var`` for the right panel.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels.
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path for saving the plot. When given, the figure
            is written there and then closed; otherwise it is shown with ``plt.show()``.

    Returns:
        Index of obs names (as strings) in the clustered row order shared by both panels.

    Raises:
        ValueError: If ``max_nan_fraction`` is outside [0, 1], no read has a non-NaN
            rolling NN distance, ``right_panel_var_mask`` does not align with
            ``subset.var_names``, or no positions remain for the right panel.
    """
    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
        raise ValueError("max_nan_fraction must be between 0 and 1.")

    def _apply_xticks(ax, labels, step):
        """Label every ``step``-th heatmap column (auto step when None/<=0)."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        # Heatmap cell i spans [i, i + 1); +0.5 centers the tick on the cell.
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    # --- rolling NN distances
    X = subset.obsm[obsm_key]
    # Reads whose distances are all NaN are excluded from both panels.
    valid = ~np.all(np.isnan(X), axis=1)
    if not np.any(valid):
        raise ValueError(
            f"No reads have a non-NaN value in subset.obsm['{obsm_key}']; nothing to cluster."
        )

    X_df = pd.DataFrame(X[valid], index=subset.obs_names[valid])

    if drop_all_nan_windows:
        X_df = X_df.loc[:, ~X_df.isna().all(axis=0)]

    X_df_filled = X_df.copy()
    if fill_nn_with_colmax:
        # Fill each window's NaNs with that window's largest observed distance.
        col_max = X_df_filled.max(axis=0, skipna=True)
        X_df_filled = X_df_filled.fillna(col_max)

    # Ensure non-MultiIndex index for seaborn
    X_df_filled.index = X_df_filled.index.astype(str)

    # --- row colors from metadata (MultiIndex-safe)
    meta = subset.obs.loc[X_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    # --- get row order via clustermap (its figure is only needed for the dendrogram)
    g = sns.clustermap(
        X_df_filled,
        cmap="viridis",
        col_cluster=col_cluster,
        row_cluster=True,
        row_colors=row_colors,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
    )
    row_order = g.dendrogram_row.reordered_ind
    ordered_index = X_df_filled.index[row_order]
    plt.close(g.fig)

    # reorder rolling NN matrix
    X_ord = X_df_filled.loc[ordered_index]

    # --- layer matrix (dense, same valid reads as the left panel)
    L = subset.layers[layer_key]
    L = L.toarray() if hasattr(L, "toarray") else np.asarray(L)

    L_df = pd.DataFrame(L[valid], index=subset.obs_names[valid], columns=subset.var_names)
    L_df.index = L_df.index.astype(str)

    if right_panel_var_mask is not None:
        # right_panel_var_mask must be boolean array/Series aligned to subset.var_names
        if hasattr(right_panel_var_mask, "values"):
            right_panel_var_mask = right_panel_var_mask.values
        right_panel_var_mask = np.asarray(right_panel_var_mask, dtype=bool)

    if max_nan_fraction is not None:
        # Prefer an explicit NaN-fraction column; otherwise derive it from a
        # valid-fraction column. If neither exists, no NaN filtering is applied.
        nan_fraction = None
        if var_nan_fraction_col and var_nan_fraction_col in subset.var:
            nan_fraction = pd.to_numeric(
                subset.var[var_nan_fraction_col], errors="coerce"
            ).to_numpy()
        elif var_valid_fraction_col and var_valid_fraction_col in subset.var:
            valid_fraction = pd.to_numeric(
                subset.var[var_valid_fraction_col], errors="coerce"
            ).to_numpy()
            nan_fraction = 1 - valid_fraction
        if nan_fraction is not None:
            # NaN fractions (unparseable entries) compare False and are dropped.
            nan_mask = nan_fraction <= max_nan_fraction
            if right_panel_var_mask is None:
                right_panel_var_mask = nan_mask
            else:
                right_panel_var_mask = right_panel_var_mask & nan_mask

    if right_panel_var_mask is not None:
        if right_panel_var_mask.size != L_df.shape[1]:
            raise ValueError("right_panel_var_mask must align with subset.var_names.")
        L_df = L_df.loc[:, right_panel_var_mask]

    if L_df.shape[1] == 0:
        raise ValueError(
            "No positions remain for the right panel after applying "
            "right_panel_var_mask / max_nan_fraction."
        )

    L_ord = L_df.loc[ordered_index]
    L_plot = L_ord.fillna(fill_layer_value)

    # --- plot side-by-side with barplots above
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2,
        4,
        width_ratios=[1, 0.05, 1, 0.05],
        height_ratios=[1, 6],
        wspace=0.2,
        hspace=0.05,
    )

    ax1 = fig.add_subplot(gs[1, 0])
    ax1_cbar = fig.add_subplot(gs[1, 1])
    ax2 = fig.add_subplot(gs[1, 2])
    ax2_cbar = fig.add_subplot(gs[1, 3])
    ax1_bar = fig.add_subplot(gs[0, 0], sharex=ax1)
    ax2_bar = fig.add_subplot(gs[0, 2], sharex=ax2)
    fig.add_subplot(gs[0, 1]).axis("off")
    fig.add_subplot(gs[0, 3]).axis("off")

    # NOTE: computed on the filled matrix, so with fill_nn_with_colmax the filled
    # windows contribute their column max to the mean.
    mean_nn = np.nanmean(X_ord.to_numpy(), axis=0)
    clean_barplot(
        ax1_bar,
        mean_nn,
        obsm_key,
        y_max=None,
        y_label="Mean distance",
        y_ticks=None,
    )

    sns.heatmap(
        X_ord,
        ax=ax1,
        cmap="viridis",
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax1_cbar,
    )
    starts = subset.uns.get(f"{obsm_key}_starts")
    if starts is not None:
        starts = np.asarray(starts)
        window_labels = [str(s) for s in starts]
        try:
            # Columns are positional window indices unless dropped; map the
            # surviving ones back to their start coordinates.
            col_idx = X_ord.columns.to_numpy()
            if np.issubdtype(col_idx.dtype, np.number):
                col_idx = col_idx.astype(int)
                if col_idx.size and col_idx.max() < len(starts):
                    window_labels = [str(s) for s in starts[col_idx]]
        except Exception:
            window_labels = [str(s) for s in starts]
        _apply_xticks(ax1, window_labels, xtick_step)

    methylation_fraction = _methylation_fraction_for_layer(L_ord.to_numpy(), layer_key)
    clean_barplot(
        ax2_bar,
        methylation_fraction,
        layer_key,
        y_max=1.0,
        y_label="Methylation fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    sns.heatmap(
        L_plot,
        ax=ax2,
        cmap="coolwarm",
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax2_cbar,
    )
    _apply_xticks(ax2, [str(x) for x in L_plot.columns], xtick_step)

    if title:
        fig.suptitle(title)

    if save_name is not None:
        fig.savefig(save_name, dpi=200, bbox_inches="tight")
        # Close after saving so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

    return ordered_index
|
|
2298
|
+
|
|
2299
|
+
|
|
2300
|
+
def plot_sequence_integer_encoding_clustermaps(
    adata,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",
    layer: str = "sequence_integer_encoding",
    mismatch_layer: str = "mismatch_integer_encoding",
    min_quality: float | None = 20,
    min_length: int | None = 200,
    min_mapped_length_to_reference_length_ratio: float | None = 0,
    demux_types: Sequence[str] = ("single", "double", "already"),
    sort_by: str = "none",  # "none", "hierarchical", "obs:<col>"
    cmap: str = "viridis",
    max_unknown_fraction: float | None = None,
    unknown_values: Sequence[int] = (4, 5),
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 9,
    max_reads: int | None = None,
    save_path: str | Path | None = None,
    use_dna_5color_palette: bool = True,
    show_numeric_colorbar: bool = False,
    show_position_axis: bool = False,
    position_axis_tick_target: int = 25,
):
    """Plot integer-encoded sequence clustermaps per sample/reference.

    For every (reference, sample) category pair, reads passing the optional read
    filters are drawn as a heatmap of ``layer``. When ``mismatch_layer`` exists in
    ``adata.layers`` it is drawn in a second panel with the same row order. A read
    filter is skipped when its obs column is missing or cannot be compared.

    Note:
        ``sample_col`` and ``reference_col`` are converted to categorical dtype in
        ``adata.obs`` in place when they are not categorical already.

    Args:
        adata: AnnData with a ``sequence_integer_encoding`` layer.
        sample_col: Column in ``adata.obs`` that identifies samples.
        reference_col: Column in ``adata.obs`` that identifies references.
        layer: Layer name containing integer-encoded sequences.
        mismatch_layer: Optional layer name containing mismatch integer encodings.
        min_quality: Optional minimum read quality filter.
        min_length: Optional minimum mapped length filter.
        min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
        demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
        sort_by: Row sorting strategy: ``none``, ``hierarchical``, or ``obs:<col>``.
            Hierarchical sorting is skipped for groups with a single read.
        cmap: Matplotlib colormap for the heatmap when ``use_dna_5color_palette`` is False.
        max_unknown_fraction: Optional maximum fraction of ``unknown_values`` allowed per
            position; positions above this threshold are excluded.
        unknown_values: Integer values to treat as unknown/padding.
        xtick_step: Spacing between x-axis tick labels (None = no labels).
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        max_reads: Optional maximum number of reads to plot per sample/reference.
        save_path: Optional output directory for saving plots.
        use_dna_5color_palette: Whether to use a fixed A/C/G/T/Other palette.
        show_numeric_colorbar: If False, use a legend instead of a numeric colorbar.
        show_position_axis: Whether to draw a position axis with tick labels.
        position_axis_tick_target: Approximate number of ticks to show when auto-sizing.

    Returns:
        List of dictionaries with per-plot metadata and output paths.

    Raises:
        KeyError: If ``layer`` is not a layer, or ``sample_col``/``reference_col``
            is not an obs column.
        ValueError: If ``max_unknown_fraction`` is outside [0, 1],
            ``position_axis_tick_target`` < 1, or ``sort_by`` is unsupported.
    """
    if layer not in adata.layers:
        raise KeyError(f"Layer '{layer}' not found in adata.layers")

    if max_unknown_fraction is not None and not (0 <= max_unknown_fraction <= 1):
        raise ValueError("max_unknown_fraction must be between 0 and 1.")

    if position_axis_tick_target < 1:
        raise ValueError("position_axis_tick_target must be at least 1.")

    # Validate once up front instead of only when the first plottable group is hit.
    if sort_by not in ("none", "hierarchical") and not sort_by.startswith("obs:"):
        raise ValueError("sort_by must be 'none', 'hierarchical', or 'obs:<col>'")

    results: List[Dict[str, Any]] = []
    save_path = Path(save_path) if save_path is not None else None
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    for col in (sample_col, reference_col):
        if col not in adata.obs:
            raise KeyError(f"{col} not in adata.obs")
        if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
            adata.obs[col] = adata.obs[col].astype("category")

    # ---- helpers ------------------------------------------------------------

    def _all_true() -> pd.Series:
        return pd.Series(True, index=adata.obs.index)

    def _mask_or_true(series_name: str, predicate):
        """Apply ``predicate`` to an obs column; all-True if absent or it fails."""
        if series_name not in adata.obs:
            return _all_true()
        try:
            return predicate(adata.obs[series_name])
        except Exception:
            # Deliberate best-effort: an unusable column disables that filter.
            return _all_true()

    def _min_mask(series_name: str, threshold):
        """Keep reads with ``obs[series_name] >= threshold`` (no-op when None)."""
        if threshold is None:
            return _all_true()
        return _mask_or_true(series_name, lambda s: s >= float(threshold))

    def _dense(values) -> np.ndarray:
        """Return a dense ndarray for dense or sparse layer data."""
        return values.toarray() if hasattr(values, "toarray") else np.asarray(values)

    def normalize_base(base: str) -> str:
        return base if base in {"A", "C", "G", "T"} else "OTHER"

    def _resolve_xtick_step(n_positions: int) -> int | None:
        if xtick_step is not None:
            return xtick_step
        if not show_position_axis:
            return None
        return max(1, int(np.ceil(n_positions / position_axis_tick_target)))

    def _observed_ints(mat):
        """Yield distinct non-NaN values of ``mat`` that convert to int."""
        for val in np.unique(mat[~pd.isna(mat)]):
            try:
                yield int(val)
            except (TypeError, ValueError, OverflowError):
                continue

    def _with_unseen_as_other(int_to_color: dict, mat) -> dict:
        """Give codes present in ``mat`` but missing from the map the OTHER color."""
        for code in _observed_ints(mat):
            int_to_color.setdefault(code, DNA_5COLOR_PALETTE["OTHER"])
        return int_to_color

    def _discrete_cmap(int_to_color: dict):
        """Build a ListedColormap + BoundaryNorm with one bin per integer code."""
        ordered = sorted(int_to_color.items(), key=lambda kv: kv[0])
        cmap_obj = colors.ListedColormap([color for _, color in ordered])
        bounds = [code - 0.5 for code, _ in ordered]
        bounds.append(ordered[-1][0] + 0.5)
        return cmap_obj, colors.BoundaryNorm(bounds, cmap_obj.N)

    def _mismatch_color(base):
        base_upper = str(base).upper()
        if base_upper == "PAD":
            return "#D3D3D3"
        if base_upper == "N":
            return "#808080"
        return DNA_5COLOR_PALETTE[normalize_base(base_upper)]

    def _acgt_patches():
        return [patches.Patch(facecolor=DNA_5COLOR_PALETTE[b], label=b) for b in "ACGT"]

    def _add_legend(ax, handles, legend_title: str) -> None:
        ax.legend(
            handles=handles,
            title=legend_title,
            loc="upper left",
            bbox_to_anchor=(1.02, 1.0),
            frameon=False,
        )

    def _apply_position_ticks(ax, n_positions: int, var_names, step) -> None:
        if step is not None and step > 0:
            sites = np.arange(0, n_positions, step)
            # Heatmap cell i spans [i, i + 1); +0.5 centers each label on its column.
            ax.set_xticks(sites + 0.5)
            ax.set_xticklabels(
                var_names[sites].astype(str),
                rotation=xtick_rotation,
                fontsize=xtick_fontsize,
            )
        else:
            ax.set_xticks([])
        if show_position_axis or xtick_step is not None:
            ax.set_xlabel("Position")

    # ---- decoding maps ------------------------------------------------------

    raw_int_to_base = adata.uns.get("sequence_integer_decoding_map", {}) or {}
    if not raw_int_to_base:
        encoding_map = adata.uns.get("sequence_integer_encoding_map", {}) or {}
        raw_int_to_base = (
            {int(v): str(k) for k, v in encoding_map.items()} if encoding_map else {}
        )

    # Keys may come back as strings (e.g. after h5ad round-trips); coerce to int.
    int_to_base = {}
    for key, value in raw_int_to_base.items():
        try:
            code = int(key)
        except (TypeError, ValueError, OverflowError):
            continue
        int_to_base[code] = str(value)

    mismatch_int_to_base = {}
    if mismatch_layer in adata.layers:
        mismatch_encoding_map = adata.uns.get("mismatch_integer_encoding_map", {}) or {}
        mismatch_int_to_base = {
            int(v): str(k)
            for k, v in mismatch_encoding_map.items()
            if isinstance(v, (int, np.integer))
        }

    # Fallback guess when no decoding map is stored.
    acgt_guess = {0: "A", 1: "C", 2: "G", 3: "T"}

    # Read-level filters do not depend on the (reference, sample) pair.
    if demux_types is None:
        demux_mask = _all_true()
    else:
        demux_mask = _mask_or_true(
            "demux_type", lambda s: s.astype("string").isin(list(demux_types))
        )
    read_filter = (
        _min_mask("read_quality", min_quality)
        & _min_mask("mapped_length", min_length)
        & _min_mask(
            "mapped_length_to_reference_length_ratio",
            min_mapped_length_to_reference_length_ratio,
        )
        & demux_mask
    )

    # ---- per (reference, sample) plots --------------------------------------

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            row_mask = (
                (adata.obs[reference_col] == ref)
                & (adata.obs[sample_col] == sample)
                & read_filter
            )
            if not bool(row_mask.any()):
                continue

            # A view suffices: only the two layers, obs and var_names are read,
            # so the full AnnData (all layers/obsm) is not copied per group.
            view = adata[row_mask, :]
            matrix = _dense(view.layers[layer])
            mismatch_matrix = (
                _dense(view.layers[mismatch_layer])
                if mismatch_layer in view.layers
                else None
            )
            var_names = view.var_names
            obs = view.obs

            if max_unknown_fraction is not None:
                unknown_mask = np.isin(matrix, np.asarray(unknown_values))
                keep_columns = unknown_mask.mean(axis=0) <= max_unknown_fraction
                if not np.any(keep_columns):
                    continue
                matrix = matrix[:, keep_columns]
                var_names = var_names[keep_columns]
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[:, keep_columns]

            if max_reads is not None and matrix.shape[0] > max_reads:
                matrix = matrix[:max_reads]
                obs = obs.iloc[:max_reads]
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[:max_reads]

            if matrix.size == 0:
                continue

            if use_dna_5color_palette and not int_to_base:
                int_to_base_local = {
                    code: acgt_guess.get(code, "OTHER") for code in _observed_ints(matrix)
                }
            else:
                int_to_base_local = int_to_base

            order = None
            if sort_by.startswith("obs:"):
                order = np.argsort(obs[sort_by[len("obs:"):]].values)
            elif sort_by == "hierarchical" and matrix.shape[0] > 1:
                # Ward linkage needs at least two observations.
                linkage = sch.linkage(np.nan_to_num(matrix), method="ward")
                order = sch.leaves_list(linkage)

            if order is not None:
                matrix = matrix[order]
                if mismatch_matrix is not None:
                    mismatch_matrix = mismatch_matrix[order]

            has_mismatch = mismatch_matrix is not None
            fig, axes = plt.subplots(
                ncols=2 if has_mismatch else 1,
                figsize=(18, 6) if has_mismatch else (12, 6),
                sharey=has_mismatch,
                squeeze=False,
            )
            ax = axes[0, 0]

            # -- sequence panel
            if use_dna_5color_palette and int_to_base_local:
                int_to_color = {
                    int(code): DNA_5COLOR_PALETTE[normalize_base(str(base))]
                    for code, base in int_to_base_local.items()
                }
                cmap_obj, norm = _discrete_cmap(_with_unseen_as_other(int_to_color, matrix))
                sns.heatmap(
                    matrix,
                    cmap=cmap_obj,
                    norm=norm,
                    ax=ax,
                    yticklabels=False,
                    cbar=show_numeric_colorbar,
                )
                _add_legend(
                    ax,
                    _acgt_patches()
                    + [
                        patches.Patch(
                            facecolor=DNA_5COLOR_PALETTE["OTHER"],
                            label="Other (N / PAD / unknown)",
                        )
                    ],
                    "Base",
                )
            else:
                sns.heatmap(matrix, cmap=cmap, ax=ax, yticklabels=False, cbar=True)

            ax.set_title(layer)
            resolved_step = _resolve_xtick_step(matrix.shape[1])
            _apply_position_ticks(ax, matrix.shape[1], var_names, resolved_step)

            # -- mismatch panel
            if has_mismatch:
                mismatch_ax = axes[0, 1]
                mismatch_int_to_base_local = mismatch_int_to_base or int_to_base_local
                if use_dna_5color_palette and mismatch_int_to_base_local:
                    mismatch_int_to_color = {
                        int(code): _mismatch_color(base)
                        for code, base in mismatch_int_to_base_local.items()
                    }
                    mismatch_cmap, mismatch_norm = _discrete_cmap(
                        _with_unseen_as_other(mismatch_int_to_color, mismatch_matrix)
                    )
                    sns.heatmap(
                        mismatch_matrix,
                        cmap=mismatch_cmap,
                        norm=mismatch_norm,
                        ax=mismatch_ax,
                        yticklabels=False,
                        cbar=show_numeric_colorbar,
                    )
                    _add_legend(
                        mismatch_ax,
                        _acgt_patches()
                        + [
                            patches.Patch(facecolor="#808080", label="Match/N"),
                            patches.Patch(facecolor="#D3D3D3", label="PAD"),
                        ],
                        "Mismatch base",
                    )
                else:
                    sns.heatmap(
                        mismatch_matrix,
                        cmap=cmap,
                        ax=mismatch_ax,
                        yticklabels=False,
                        cbar=True,
                    )

                mismatch_ax.set_title(mismatch_layer)
                _apply_position_ticks(
                    mismatch_ax, mismatch_matrix.shape[1], var_names, resolved_step
                )

            fig.suptitle(f"{sample} - {ref}")
            fig.tight_layout(rect=(0, 0, 1, 0.95))

            out_file = None
            if save_path is not None:
                # "/" would otherwise be interpreted as a sub-directory by savefig.
                safe_name = (
                    f"{ref}__{sample}__{layer}".replace("=", "")
                    .replace(",", "_")
                    .replace("/", "_")
                )
                out_file = save_path / f"{safe_name}.png"
                fig.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)
            else:
                plt.show()

            results.append(
                {
                    "reference": str(ref),
                    "sample": str(sample),
                    "layer": layer,
                    "n_positions": int(matrix.shape[1]),
                    "mismatch_layer": mismatch_layer if has_mismatch else None,
                    "mismatch_layer_present": bool(has_mismatch),
                    "output_path": str(out_file) if out_file is not None else None,
                }
            )

    return results
|
|
2691
|
+
|
|
2692
|
+
|
|
2693
|
+
def plot_read_span_quality_clustermaps(
|
|
2694
|
+
adata,
|
|
2695
|
+
sample_col: str = "Sample_Names",
|
|
2696
|
+
reference_col: str = "Reference_strand",
|
|
2697
|
+
quality_layer: str = "base_quality_scores",
|
|
2698
|
+
read_span_layer: str = "read_span_mask",
|
|
2699
|
+
quality_cmap: str = "viridis",
|
|
2700
|
+
read_span_color: str = "#2ca25f",
|
|
2701
|
+
max_nan_fraction: float | None = None,
|
|
2702
|
+
min_quality: float | None = None,
|
|
2703
|
+
min_length: int | None = None,
|
|
2704
|
+
min_mapped_length_to_reference_length_ratio: float | None = None,
|
|
2705
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
2706
|
+
max_reads: int | None = None,
|
|
2707
|
+
xtick_step: int | None = None,
|
|
2708
|
+
xtick_rotation: int = 90,
|
|
2709
|
+
xtick_fontsize: int = 9,
|
|
2710
|
+
show_position_axis: bool = False,
|
|
2711
|
+
position_axis_tick_target: int = 25,
|
|
2712
|
+
save_path: str | Path | None = None,
|
|
2713
|
+
) -> List[Dict[str, Any]]:
|
|
2714
|
+
"""Plot read-span mask and base quality clustermaps side by side.
|
|
2715
|
+
|
|
2716
|
+
Clustering is performed using the base-quality layer ordering, which is then
|
|
2717
|
+
applied to the read-span mask to keep the two panels aligned.
|
|
2718
|
+
|
|
2719
|
+
Args:
|
|
2720
|
+
adata: AnnData with read-span and base-quality layers.
|
|
2721
|
+
sample_col: Column in ``adata.obs`` that identifies samples.
|
|
2722
|
+
reference_col: Column in ``adata.obs`` that identifies references.
|
|
2723
|
+
quality_layer: Layer name containing base-quality scores.
|
|
2724
|
+
read_span_layer: Layer name containing read-span masks.
|
|
2725
|
+
quality_cmap: Colormap for base-quality scores.
|
|
2726
|
+
read_span_color: Color for read-span mask (1-values); 0-values are white.
|
|
2727
|
+
max_nan_fraction: Optional maximum fraction of NaNs allowed per position; positions
|
|
2728
|
+
above this threshold are excluded.
|
|
2729
|
+
min_quality: Optional minimum read quality filter.
|
|
2730
|
+
min_length: Optional minimum mapped length filter.
|
|
2731
|
+
min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
|
|
2732
|
+
demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
|
|
2733
|
+
max_reads: Optional maximum number of reads to plot per sample/reference.
|
|
2734
|
+
xtick_step: Spacing between x-axis tick labels (None = no labels).
|
|
2735
|
+
xtick_rotation: Rotation for x-axis tick labels.
|
|
2736
|
+
xtick_fontsize: Font size for x-axis tick labels.
|
|
2737
|
+
show_position_axis: Whether to draw a position axis with tick labels.
|
|
2738
|
+
position_axis_tick_target: Approximate number of ticks to show when auto-sizing.
|
|
2739
|
+
save_path: Optional output directory for saving plots.
|
|
2740
|
+
|
|
2741
|
+
Returns:
|
|
2742
|
+
List of dictionaries with per-plot metadata and output paths.
|
|
2743
|
+
"""
|
|
2744
|
+
|
|
2745
|
+
def _mask_or_true(series_name: str, predicate):
|
|
2746
|
+
if series_name not in adata.obs:
|
|
2747
|
+
return pd.Series(True, index=adata.obs.index)
|
|
2748
|
+
s = adata.obs[series_name]
|
|
2749
|
+
try:
|
|
2750
|
+
return predicate(s)
|
|
2751
|
+
except Exception:
|
|
2752
|
+
return pd.Series(True, index=adata.obs.index)
|
|
2753
|
+
|
|
2754
|
+
def _resolve_xtick_step(n_positions: int) -> int | None:
|
|
2755
|
+
if xtick_step is not None:
|
|
2756
|
+
return xtick_step
|
|
2757
|
+
if not show_position_axis:
|
|
2758
|
+
return None
|
|
2759
|
+
return max(1, int(np.ceil(n_positions / position_axis_tick_target)))
|
|
2760
|
+
|
|
2761
|
+
def _fill_nan_with_col_means(matrix: np.ndarray) -> np.ndarray:
|
|
2762
|
+
filled = matrix.copy()
|
|
2763
|
+
col_means = np.nanmean(filled, axis=0)
|
|
2764
|
+
col_means = np.where(np.isnan(col_means), 0.0, col_means)
|
|
2765
|
+
nan_rows, nan_cols = np.where(np.isnan(filled))
|
|
2766
|
+
filled[nan_rows, nan_cols] = col_means[nan_cols]
|
|
2767
|
+
return filled
|
|
2768
|
+
|
|
2769
|
+
if quality_layer not in adata.layers:
|
|
2770
|
+
raise KeyError(f"Layer '{quality_layer}' not found in adata.layers")
|
|
2771
|
+
if read_span_layer not in adata.layers:
|
|
2772
|
+
raise KeyError(f"Layer '{read_span_layer}' not found in adata.layers")
|
|
2773
|
+
if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
|
|
2774
|
+
raise ValueError("max_nan_fraction must be between 0 and 1.")
|
|
2775
|
+
if position_axis_tick_target < 1:
|
|
2776
|
+
raise ValueError("position_axis_tick_target must be at least 1.")
|
|
2777
|
+
|
|
2778
|
+
results: List[Dict[str, Any]] = []
|
|
2779
|
+
save_path = Path(save_path) if save_path is not None else None
|
|
2780
|
+
if save_path is not None:
|
|
2781
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
2782
|
+
|
|
2783
|
+
for col in (sample_col, reference_col):
|
|
2784
|
+
if col not in adata.obs:
|
|
2785
|
+
raise KeyError(f"{col} not in adata.obs")
|
|
2786
|
+
if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
|
|
2787
|
+
adata.obs[col] = adata.obs[col].astype("category")
|
|
2788
|
+
|
|
2789
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
2790
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
2791
|
+
qmask = _mask_or_true(
|
|
2792
|
+
"read_quality",
|
|
2793
|
+
(lambda s: s >= float(min_quality))
|
|
2794
|
+
if (min_quality is not None)
|
|
2795
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
2796
|
+
)
|
|
2797
|
+
lm_mask = _mask_or_true(
|
|
2798
|
+
"mapped_length",
|
|
2799
|
+
(lambda s: s >= float(min_length))
|
|
2800
|
+
if (min_length is not None)
|
|
2801
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
2802
|
+
)
|
|
2803
|
+
lrr_mask = _mask_or_true(
|
|
2804
|
+
"mapped_length_to_reference_length_ratio",
|
|
2805
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
2806
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
2807
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
2808
|
+
)
|
|
2809
|
+
demux_mask = _mask_or_true(
|
|
2810
|
+
"demux_type",
|
|
2811
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
2812
|
+
if (demux_types is not None)
|
|
2813
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
2814
|
+
)
|
|
2815
|
+
|
|
2816
|
+
row_mask = (
|
|
2817
|
+
(adata.obs[reference_col] == ref)
|
|
2818
|
+
& (adata.obs[sample_col] == sample)
|
|
2819
|
+
& qmask
|
|
2820
|
+
& lm_mask
|
|
2821
|
+
& lrr_mask
|
|
2822
|
+
& demux_mask
|
|
2823
|
+
)
|
|
2824
|
+
if not bool(row_mask.any()):
|
|
2825
|
+
continue
|
|
2826
|
+
|
|
2827
|
+
subset = adata[row_mask, :].copy()
|
|
2828
|
+
quality_matrix = np.asarray(subset.layers[quality_layer]).astype(float)
|
|
2829
|
+
quality_matrix[quality_matrix < 0] = np.nan
|
|
2830
|
+
read_span_matrix = np.asarray(subset.layers[read_span_layer]).astype(float)
|
|
2831
|
+
|
|
2832
|
+
if max_nan_fraction is not None:
|
|
2833
|
+
nan_mask = np.isnan(quality_matrix) | np.isnan(read_span_matrix)
|
|
2834
|
+
nan_fraction = nan_mask.mean(axis=0)
|
|
2835
|
+
keep_columns = nan_fraction <= max_nan_fraction
|
|
2836
|
+
if not np.any(keep_columns):
|
|
2837
|
+
continue
|
|
2838
|
+
quality_matrix = quality_matrix[:, keep_columns]
|
|
2839
|
+
read_span_matrix = read_span_matrix[:, keep_columns]
|
|
2840
|
+
subset = subset[:, keep_columns].copy()
|
|
2841
|
+
|
|
2842
|
+
if max_reads is not None and quality_matrix.shape[0] > max_reads:
|
|
2843
|
+
quality_matrix = quality_matrix[:max_reads]
|
|
2844
|
+
read_span_matrix = read_span_matrix[:max_reads]
|
|
2845
|
+
subset = subset[:max_reads, :].copy()
|
|
2846
|
+
|
|
2847
|
+
if quality_matrix.size == 0:
|
|
2848
|
+
continue
|
|
2849
|
+
|
|
2850
|
+
quality_filled = _fill_nan_with_col_means(quality_matrix)
|
|
2851
|
+
linkage = sch.linkage(quality_filled, method="ward")
|
|
2852
|
+
order = sch.leaves_list(linkage)
|
|
2853
|
+
|
|
2854
|
+
quality_matrix = quality_matrix[order]
|
|
2855
|
+
read_span_matrix = read_span_matrix[order]
|
|
2856
|
+
|
|
2857
|
+
fig, axes = plt.subplots(
|
|
2858
|
+
nrows=2,
|
|
2859
|
+
ncols=3,
|
|
2860
|
+
figsize=(18, 6),
|
|
2861
|
+
sharex="col",
|
|
2862
|
+
gridspec_kw={"height_ratios": [1, 4], "width_ratios": [1, 1, 0.05]},
|
|
2863
|
+
)
|
|
2864
|
+
span_bar_ax, quality_bar_ax, bar_spacer_ax = axes[0]
|
|
2865
|
+
span_ax, quality_ax, cbar_ax = axes[1]
|
|
2866
|
+
bar_spacer_ax.set_axis_off()
|
|
2867
|
+
|
|
2868
|
+
span_mean = np.nanmean(read_span_matrix, axis=0)
|
|
2869
|
+
quality_mean = np.nanmean(quality_matrix, axis=0)
|
|
2870
|
+
bar_positions = np.arange(read_span_matrix.shape[1]) + 0.5
|
|
2871
|
+
span_bar_ax.bar(
|
|
2872
|
+
bar_positions,
|
|
2873
|
+
span_mean,
|
|
2874
|
+
color=read_span_color,
|
|
2875
|
+
width=1.0,
|
|
2876
|
+
)
|
|
2877
|
+
span_bar_ax.set_title(f"{read_span_layer} mean")
|
|
2878
|
+
span_bar_ax.set_xlim(0, read_span_matrix.shape[1])
|
|
2879
|
+
span_bar_ax.tick_params(axis="x", labelbottom=False)
|
|
2880
|
+
|
|
2881
|
+
quality_bar_ax.bar(
|
|
2882
|
+
bar_positions,
|
|
2883
|
+
quality_mean,
|
|
2884
|
+
color="#4c72b0",
|
|
2885
|
+
width=1.0,
|
|
2886
|
+
)
|
|
2887
|
+
quality_bar_ax.set_title(f"{quality_layer} mean")
|
|
2888
|
+
quality_bar_ax.set_xlim(0, quality_matrix.shape[1])
|
|
2889
|
+
quality_bar_ax.tick_params(axis="x", labelbottom=False)
|
|
2890
|
+
|
|
2891
|
+
span_cmap = colors.ListedColormap(["white", read_span_color])
|
|
2892
|
+
span_norm = colors.BoundaryNorm([-0.5, 0.5, 1.5], span_cmap.N)
|
|
2893
|
+
sns.heatmap(
|
|
2894
|
+
read_span_matrix,
|
|
2895
|
+
cmap=span_cmap,
|
|
2896
|
+
norm=span_norm,
|
|
2897
|
+
ax=span_ax,
|
|
2898
|
+
yticklabels=False,
|
|
2899
|
+
cbar=False,
|
|
2900
|
+
)
|
|
2901
|
+
span_ax.set_title(read_span_layer)
|
|
2902
|
+
|
|
2903
|
+
sns.heatmap(
|
|
2904
|
+
quality_matrix,
|
|
2905
|
+
cmap=quality_cmap,
|
|
2906
|
+
ax=quality_ax,
|
|
2907
|
+
yticklabels=False,
|
|
2908
|
+
cbar=True,
|
|
2909
|
+
cbar_ax=cbar_ax,
|
|
2910
|
+
)
|
|
2911
|
+
quality_ax.set_title(quality_layer)
|
|
2912
|
+
|
|
2913
|
+
resolved_step = _resolve_xtick_step(quality_matrix.shape[1])
|
|
2914
|
+
for axis in (span_ax, quality_ax):
|
|
2915
|
+
if resolved_step is not None and resolved_step > 0:
|
|
2916
|
+
sites = np.arange(0, quality_matrix.shape[1], resolved_step)
|
|
2917
|
+
axis.set_xticks(sites)
|
|
2918
|
+
axis.set_xticklabels(
|
|
2919
|
+
subset.var_names[sites].astype(str),
|
|
2920
|
+
rotation=xtick_rotation,
|
|
2921
|
+
fontsize=xtick_fontsize,
|
|
2922
|
+
)
|
|
2923
|
+
else:
|
|
2924
|
+
axis.set_xticks([])
|
|
2925
|
+
if show_position_axis or xtick_step is not None:
|
|
2926
|
+
axis.set_xlabel("Position")
|
|
2927
|
+
|
|
2928
|
+
fig.suptitle(f"{sample} - {ref}")
|
|
2929
|
+
fig.tight_layout(rect=(0, 0, 1, 0.95))
|
|
2930
|
+
|
|
2931
|
+
out_file = None
|
|
2932
|
+
if save_path is not None:
|
|
2933
|
+
safe_name = f"{ref}__{sample}__read_span_quality".replace("=", "").replace(",", "_")
|
|
2934
|
+
out_file = save_path / f"{safe_name}.png"
|
|
2935
|
+
fig.savefig(out_file, dpi=300, bbox_inches="tight")
|
|
2936
|
+
plt.close(fig)
|
|
2937
|
+
else:
|
|
2938
|
+
plt.show()
|
|
2939
|
+
|
|
2940
|
+
results.append(
|
|
2941
|
+
{
|
|
2942
|
+
"reference": str(ref),
|
|
2943
|
+
"sample": str(sample),
|
|
2944
|
+
"quality_layer": quality_layer,
|
|
2945
|
+
"read_span_layer": read_span_layer,
|
|
2946
|
+
"n_positions": int(quality_matrix.shape[1]),
|
|
2947
|
+
"output_path": str(out_file) if out_file is not None else None,
|
|
2948
|
+
}
|
|
2949
|
+
)
|
|
2950
|
+
|
|
2951
|
+
return results
|
|
2952
|
+
|
|
2953
|
+
|
|
2954
|
+
def plot_hmm_layers_rolling_by_sample_ref(
|
|
2955
|
+
adata,
|
|
2956
|
+
layers: Optional[Sequence[str]] = None,
|
|
2957
|
+
sample_col: str = "Barcode",
|
|
2958
|
+
ref_col: str = "Reference_strand",
|
|
2959
|
+
samples: Optional[Sequence[str]] = None,
|
|
2960
|
+
references: Optional[Sequence[str]] = None,
|
|
2961
|
+
window: int = 51,
|
|
2962
|
+
min_periods: int = 1,
|
|
2963
|
+
center: bool = True,
|
|
2964
|
+
rows_per_page: int = 6,
|
|
2965
|
+
figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
|
|
2966
|
+
dpi: int = 160,
|
|
2967
|
+
output_dir: Optional[str] = None,
|
|
2968
|
+
save: bool = True,
|
|
2969
|
+
show_raw: bool = False,
|
|
2970
|
+
cmap: str = "tab20",
|
|
2971
|
+
layer_colors: Optional[Mapping[str, Any]] = None,
|
|
2972
|
+
use_var_coords: bool = True,
|
|
2973
|
+
reindexed_var_suffix: str = "reindexed",
|
|
2974
|
+
):
|
|
2975
|
+
"""
|
|
2976
|
+
For each sample (row) and reference (col) plot the rolling average of the
|
|
2977
|
+
positional mean (mean across reads) for each layer listed.
|
|
2978
|
+
|
|
2979
|
+
Parameters
|
|
2980
|
+
----------
|
|
2981
|
+
adata : AnnData
|
|
2982
|
+
Input annotated data (expects obs columns sample_col and ref_col).
|
|
2983
|
+
layers : list[str] | None
|
|
2984
|
+
Which adata.layers to plot. If None, attempts to autodetect layers whose
|
|
2985
|
+
matrices look like "HMM" outputs (else will error). If None and layers
|
|
2986
|
+
cannot be found, user must pass a list.
|
|
2987
|
+
sample_col, ref_col : str
|
|
2988
|
+
obs columns used to group rows.
|
|
2989
|
+
samples, references : optional lists
|
|
2990
|
+
explicit ordering of samples / references. If None, categories in adata.obs are used.
|
|
2991
|
+
window : int
|
|
2992
|
+
rolling window size (odd recommended). If window <= 1, no smoothing applied.
|
|
2993
|
+
min_periods : int
|
|
2994
|
+
min periods param for pd.Series.rolling.
|
|
2995
|
+
center : bool
|
|
2996
|
+
center the rolling window.
|
|
2997
|
+
rows_per_page : int
|
|
2998
|
+
paginate rows per page into multiple figures if needed.
|
|
2999
|
+
figsize_per_cell : (w,h)
|
|
3000
|
+
per-subplot size in inches.
|
|
3001
|
+
dpi : int
|
|
3002
|
+
figure dpi when saving.
|
|
3003
|
+
output_dir : str | None
|
|
3004
|
+
directory to save pages; created if necessary. If None and save=True, uses cwd.
|
|
3005
|
+
save : bool
|
|
3006
|
+
whether to save PNG files.
|
|
3007
|
+
show_raw : bool
|
|
3008
|
+
draw unsmoothed mean as faint line under smoothed curve.
|
|
3009
|
+
cmap : str
|
|
3010
|
+
matplotlib colormap for layer lines.
|
|
3011
|
+
layer_colors : dict[str, Any] | None
|
|
3012
|
+
Optional mapping of layer name to explicit line colors.
|
|
3013
|
+
use_var_coords : bool
|
|
3014
|
+
if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
|
|
3015
|
+
reindexed_var_suffix : str
|
|
3016
|
+
Suffix for per-reference reindexed var columns (e.g., ``Reference_reindexed``) used when available.
|
|
3017
|
+
|
|
3018
|
+
Returns
|
|
3019
|
+
-------
|
|
3020
|
+
saved_files : list[str]
|
|
3021
|
+
list of saved filenames (may be empty if save=False).
|
|
3022
|
+
"""
|
|
3023
|
+
|
|
3024
|
+
# --- basic checks / defaults ---
|
|
3025
|
+
if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
|
|
3026
|
+
raise ValueError(
|
|
3027
|
+
f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
|
|
3028
|
+
)
|
|
3029
|
+
|
|
3030
|
+
# canonicalize samples / refs
|
|
3031
|
+
if samples is None:
|
|
3032
|
+
sseries = adata.obs[sample_col]
|
|
3033
|
+
if not pd.api.types.is_categorical_dtype(sseries):
|
|
3034
|
+
sseries = sseries.astype("category")
|
|
3035
|
+
samples_all = list(sseries.cat.categories)
|
|
3036
|
+
else:
|
|
3037
|
+
samples_all = list(samples)
|
|
3038
|
+
|
|
3039
|
+
if references is None:
|
|
3040
|
+
rseries = adata.obs[ref_col]
|
|
3041
|
+
if not pd.api.types.is_categorical_dtype(rseries):
|
|
3042
|
+
rseries = rseries.astype("category")
|
|
3043
|
+
refs_all = list(rseries.cat.categories)
|
|
3044
|
+
else:
|
|
3045
|
+
refs_all = list(references)
|
|
3046
|
+
|
|
3047
|
+
# choose layers: if not provided, try a sensible default: all layers
|
|
3048
|
+
if layers is None:
|
|
3049
|
+
layers = list(adata.layers.keys())
|
|
3050
|
+
if len(layers) == 0:
|
|
3051
|
+
raise ValueError(
|
|
3052
|
+
"No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
|
|
3053
|
+
)
|
|
3054
|
+
layers = list(layers)
|
|
3055
|
+
|
|
3056
|
+
# x coordinates (positions) + optional labels
|
|
3057
|
+
x_labels = None
|
|
3058
|
+
try:
|
|
3059
|
+
if use_var_coords:
|
|
3060
|
+
x_coords = np.array([int(v) for v in adata.var_names])
|
|
3061
|
+
else:
|
|
3062
|
+
raise Exception("user disabled var coords")
|
|
3063
|
+
except Exception:
|
|
3064
|
+
# fallback to 0..n_vars-1, but keep var_names as labels
|
|
3065
|
+
x_coords = np.arange(adata.shape[1], dtype=int)
|
|
3066
|
+
x_labels = adata.var_names.astype(str).tolist()
|
|
3067
|
+
|
|
3068
|
+
ref_reindexed_cols = {
|
|
3069
|
+
ref: f"{ref}_{reindexed_var_suffix}"
|
|
3070
|
+
for ref in refs_all
|
|
3071
|
+
if f"{ref}_{reindexed_var_suffix}" in adata.var
|
|
3072
|
+
}
|
|
3073
|
+
|
|
1429
3074
|
# make output dir
|
|
1430
3075
|
if save:
|
|
1431
3076
|
outdir = output_dir or os.getcwd()
|
|
@@ -1441,7 +3086,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1441
3086
|
# color cycle for layers
|
|
1442
3087
|
cmap_obj = plt.get_cmap(cmap)
|
|
1443
3088
|
n_layers = max(1, len(layers))
|
|
1444
|
-
|
|
3089
|
+
fallback_colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
|
|
3090
|
+
layer_colors = layer_colors or {}
|
|
3091
|
+
colors = [layer_colors.get(layer, fallback_colors[idx]) for idx, layer in enumerate(layers)]
|
|
1445
3092
|
|
|
1446
3093
|
for page in range(total_pages):
|
|
1447
3094
|
start = page * rows_per_page
|
|
@@ -1486,6 +3133,14 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1486
3133
|
|
|
1487
3134
|
# for each layer, compute positional mean across reads (ignore NaNs)
|
|
1488
3135
|
plotted_any = False
|
|
3136
|
+
reindexed_col = ref_reindexed_cols.get(ref_name)
|
|
3137
|
+
if reindexed_col is not None:
|
|
3138
|
+
try:
|
|
3139
|
+
ref_coords = np.asarray(adata.var[reindexed_col], dtype=int)
|
|
3140
|
+
except Exception:
|
|
3141
|
+
ref_coords = x_coords
|
|
3142
|
+
else:
|
|
3143
|
+
ref_coords = x_coords
|
|
1489
3144
|
for li, layer in enumerate(layers):
|
|
1490
3145
|
if layer in sub.layers:
|
|
1491
3146
|
mat = sub.layers[layer]
|
|
@@ -1519,6 +3174,8 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1519
3174
|
if np.all(np.isnan(col_mean)):
|
|
1520
3175
|
continue
|
|
1521
3176
|
|
|
3177
|
+
valid_mask = np.isfinite(col_mean)
|
|
3178
|
+
|
|
1522
3179
|
# smooth via pandas rolling (centered)
|
|
1523
3180
|
if (window is None) or (window <= 1):
|
|
1524
3181
|
smoothed = col_mean
|
|
@@ -1529,10 +3186,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1529
3186
|
.mean()
|
|
1530
3187
|
.to_numpy()
|
|
1531
3188
|
)
|
|
3189
|
+
smoothed = np.where(valid_mask, smoothed, np.nan)
|
|
1532
3190
|
|
|
1533
3191
|
# x axis: x_coords (trim/pad to match length)
|
|
1534
3192
|
L = len(col_mean)
|
|
1535
|
-
x =
|
|
3193
|
+
x = ref_coords[:L]
|
|
1536
3194
|
|
|
1537
3195
|
# optionally plot raw faint line first
|
|
1538
3196
|
if show_raw:
|
|
@@ -1557,6 +3215,13 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1557
3215
|
ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
|
|
1558
3216
|
if r_idx == nrows - 1:
|
|
1559
3217
|
ax.set_xlabel("position", fontsize=8)
|
|
3218
|
+
if x_labels is not None and reindexed_col is None:
|
|
3219
|
+
max_ticks = 8
|
|
3220
|
+
tick_step = max(1, int(math.ceil(len(x_labels) / max_ticks)))
|
|
3221
|
+
tick_positions = x_coords[::tick_step]
|
|
3222
|
+
tick_labels = x_labels[::tick_step]
|
|
3223
|
+
ax.set_xticks(tick_positions)
|
|
3224
|
+
ax.set_xticklabels(tick_labels, fontsize=7, rotation=45, ha="right")
|
|
1560
3225
|
|
|
1561
3226
|
# legend (only show in top-left plot to reduce clutter)
|
|
1562
3227
|
if (r_idx == 0 and c_idx == 0) and plotted_any:
|
|
@@ -1580,3 +3245,124 @@ def plot_hmm_layers_rolling_by_sample_ref(
|
|
|
1580
3245
|
plt.close(fig)
|
|
1581
3246
|
|
|
1582
3247
|
return saved_files
|
|
3248
|
+
|
|
3249
|
+
|
|
3250
|
+
def _resolve_embedding(adata: "ad.AnnData", basis: str) -> np.ndarray:
|
|
3251
|
+
key = basis if basis.startswith("X_") else f"X_{basis}"
|
|
3252
|
+
if key not in adata.obsm:
|
|
3253
|
+
raise KeyError(f"Embedding '{key}' not found in adata.obsm.")
|
|
3254
|
+
embedding = np.asarray(adata.obsm[key])
|
|
3255
|
+
if embedding.shape[1] < 2:
|
|
3256
|
+
raise ValueError(f"Embedding '{key}' must have at least two dimensions.")
|
|
3257
|
+
return embedding[:, :2]
|
|
3258
|
+
|
|
3259
|
+
|
|
3260
|
+
def plot_embedding(
    adata: "ad.AnnData",
    *,
    basis: str,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot a 2D embedding with scanpy-style color options.

    One PNG is written per color key.

    - Discrete columns (categorical, object-dtype and pandas ``string``-dtype)
      use the ``tab20`` palette with a legend. Missing values are drawn in
      grey and labelled ``NA``.
    - Any other column is cast to float and drawn with ``viridis`` plus a
      colorbar.

    Args:
        adata: AnnData object with ``obsm['X_<basis>']``.
        basis: Embedding basis name (e.g., ``'umap'``, ``'pca'``); the ``X_``
            prefix is optional.
        color: Obs column name or list of names to color by. Names missing
            from ``adata.obs`` are logged and skipped.
        output_dir: Directory to save plots (created if it does not exist).
        prefix: Optional filename prefix; defaults to ``basis``.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.

    Raises:
        KeyError: If the embedding is not present in ``adata.obsm``.
        ValueError: If the embedding has fewer than two columns.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    embedding = _resolve_embedding(adata, basis)
    color_keys = [color] if isinstance(color, str) else list(color)
    filename_prefix = prefix or basis
    missing_color = "#bdbdbd"
    saved: Dict[str, Path] = {}

    for color_key in color_keys:
        if color_key not in adata.obs:
            logger.warning("Color key '%s' not found in adata.obs; skipping.", color_key)
            continue
        values = adata.obs[color_key]
        # ``pd.api.types.is_categorical_dtype`` is deprecated since pandas 2.1;
        # the isinstance check is its supported replacement. String dtype is
        # treated as discrete because ``astype(float)`` would fail on it.
        is_discrete = (
            isinstance(values.dtype, pd.CategoricalDtype)
            or values.dtype == object
            or isinstance(values.dtype, pd.StringDtype)
        )

        fig, ax = plt.subplots(figsize=(5.5, 4.5))
        try:
            if is_discrete:
                categories = pd.Categorical(values)
                label_strings = categories.categories.astype(str)
                palette = (
                    sns.color_palette("tab20", n_colors=len(label_strings))
                    if len(label_strings) > 0
                    else []
                )
                codes = categories.codes
                # pandas.Categorical encodes missing values as code -1.
                point_colors = [palette[code] if code >= 0 else missing_color for code in codes]
                ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=point_colors,
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                handles = [
                    patches.Patch(color=swatch, label=str(label))
                    for label, swatch in zip(label_strings, palette)
                ]
                if np.any(codes < 0):
                    handles.append(patches.Patch(color=missing_color, label="NA"))
                if handles:
                    ax.legend(handles=handles, loc="best", fontsize=8, frameon=False)
            else:
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=values.astype(float),
                    cmap="viridis",
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                fig.colorbar(scatter, ax=ax, label=color_key)

            ax.set_xlabel(f"{basis.upper()} 1")
            ax.set_ylabel(f"{basis.upper()} 2")
            ax.set_title(f"{basis.upper()} colored by {color_key}")
            fig.tight_layout()

            safe_key = str(color_key).replace(" ", "_")
            output_file = output_path / f"{filename_prefix}_{safe_key}.png"
            fig.savefig(output_file, dpi=200)
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
        saved[color_key] = output_file

    return saved
|
|
3349
|
+
|
|
3350
|
+
|
|
3351
|
+
def plot_umap(
    adata: "ad.AnnData",
    *,
    color: str | Sequence[str],
    output_dir: Path | str,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot UMAP embedding with scanpy-style color options.

    Thin wrapper around ``plot_embedding`` with ``basis="umap"``. Files are
    named ``umap_<color_key>.png``.

    Args:
        adata: AnnData object with ``obsm['X_umap']``.
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots.
        point_size: Marker size for scatter plots (same default as
            ``plot_embedding``).
        alpha: Marker transparency (same default as ``plot_embedding``).

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.
    """
    return plot_embedding(
        adata,
        basis="umap",
        color=color,
        output_dir=output_dir,
        prefix="umap",
        point_size=point_size,
        alpha=alpha,
    )
|
|
3359
|
+
|
|
3360
|
+
|
|
3361
|
+
def plot_pca(
    adata: "ad.AnnData",
    *,
    color: str | Sequence[str],
    output_dir: Path | str,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot PCA embedding with scanpy-style color options.

    Thin wrapper around ``plot_embedding`` with ``basis="pca"``. Files are
    named ``pca_<color_key>.png``.

    Args:
        adata: AnnData object with ``obsm['X_pca']``.
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots.
        point_size: Marker size for scatter plots (same default as
            ``plot_embedding``).
        alpha: Marker transparency (same default as ``plot_embedding``).

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.
    """
    return plot_embedding(
        adata,
        basis="pca",
        color=color,
        output_dir=output_dir,
        prefix="pca",
        point_size=point_size,
        alpha=alpha,
    )
|