smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +18 -2
- smftools/cli/hmm_adata.py +18 -1
- smftools/cli/latent_adata.py +522 -67
- smftools/cli/load_adata.py +2 -2
- smftools/cli/preprocess_adata.py +32 -93
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +23 -109
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +41 -5
- smftools/config/conversion.yaml +0 -10
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +49 -13
- smftools/config/experiment_config.py +96 -3
- smftools/constants.py +4 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +53 -13
- smftools/informatics/h5ad_functions.py +83 -0
- smftools/informatics/modkit_extract_to_adata.py +4 -0
- smftools/plotting/__init__.py +26 -12
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +58 -3362
- smftools/plotting/hmm_plotting.py +1586 -2
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +3 -0
- smftools/preprocessing/append_base_context.py +1 -1
- smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +109 -85
- smftools/tools/__init__.py +6 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_nmf.py +18 -7
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +70 -154
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +640 -3
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +52 -4
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1134 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from math import floor
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, Mapping, Optional, Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import scipy.cluster.hierarchy as sch
|
|
11
|
+
|
|
12
|
+
from smftools.logging_utils import get_logger
|
|
13
|
+
from smftools.optional_imports import require
|
|
14
|
+
from smftools.plotting.plotting_utils import (
|
|
15
|
+
_fixed_tick_positions,
|
|
16
|
+
_layer_to_numpy,
|
|
17
|
+
_methylation_fraction_for_layer,
|
|
18
|
+
_select_labels,
|
|
19
|
+
clean_barplot,
|
|
20
|
+
make_row_colors,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
|
|
24
|
+
colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
|
|
25
|
+
grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
|
|
26
|
+
sns = require("seaborn", extra="plotting", purpose="plot styling")
|
|
27
|
+
|
|
28
|
+
logger = get_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def plot_rolling_nn_and_layer(
    subset,
    obsm_key: str = "rolling_nn_dist",
    layer_key: str = "nan0_0minus1",
    meta_cols: tuple[str, ...] = ("Reference_strand", "Sample"),
    col_cluster: bool = False,
    fill_nn_with_colmax: bool = True,
    fill_layer_value: float = 0.0,
    drop_all_nan_windows: bool = True,
    max_nan_fraction: float | None = None,
    var_valid_fraction_col: str | None = None,
    var_nan_fraction_col: str | None = None,
    read_span_layer: str | None = "read_span_mask",
    outside_read_color: str = "#bdbdbd",
    figsize: tuple[float, float] = (14, 10),
    right_panel_var_mask=None,  # optional boolean mask over subset.var to reduce width
    robust: bool = True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name: str | None = None,
):
    """
    Cluster reads by rolling NN distance and plot the distances beside a layer.

    1) Cluster rows by subset.obsm[obsm_key] (rolling NN distances)
    2) Plot two heatmaps side-by-side in the SAME row order, with mean barplots above:
       - left: rolling NN distance matrix
       - right: subset.layers[layer_key] matrix

    Reads whose rolling NN distances are all NaN are excluded from both panels.
    Row color annotations are built with ``make_row_colors``.

    Args:
        subset: AnnData subset with rolling NN distances stored in ``obsm``.
        obsm_key: Key in ``subset.obsm`` containing rolling NN distances.
        layer_key: Layer name to plot alongside rolling NN distances.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the rolling NN clustermap.
        fill_nn_with_colmax: Fill NaNs in rolling NN distances with per-column max values.
        fill_layer_value: Fill NaNs in the layer heatmap with this value.
        drop_all_nan_windows: Drop rolling windows that are all NaN.
        max_nan_fraction: Maximum allowed NaN fraction per position (filtering columns).
        var_valid_fraction_col: ``subset.var`` column with valid fractions (1 - NaN fraction).
        var_nan_fraction_col: ``subset.var`` column with NaN fractions.
        read_span_layer: Layer name with read span mask; 0 values are treated as outside read.
        outside_read_color: Color used to show positions outside each read.
        figsize: Figure size for the combined plot.
        right_panel_var_mask: Optional boolean mask over ``subset.var`` for the right panel.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels.
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path for saving the plot.

    Returns:
        The read names (as strings) in the clustered row order shared by both heatmaps.

    Raises:
        ValueError: If ``max_nan_fraction`` is outside ``[0, 1]`` or if
            ``right_panel_var_mask`` does not have one entry per ``subset.var_names``.
    """
    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
        raise ValueError("max_nan_fraction must be between 0 and 1.")

    logger.info("Plotting rolling NN distances with layer '%s'.", layer_key)

    def _apply_xticks(ax, labels, step):
        """Place every ``step``-th label at the center of its heatmap cell."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            # Default to roughly ten labels across the axis.
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        # Heatmap cells span [i, i + 1), so +0.5 centers the tick on the cell.
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    def _format_labels(values):
        """Stringify labels, rendering integer-valued numbers without a decimal part."""
        values = np.asarray(values)
        if np.issubdtype(values.dtype, np.number):
            if np.all(np.isfinite(values)) and np.all(np.isclose(values, np.round(values))):
                values = np.round(values).astype(int)
        return [str(v) for v in values]

    X = subset.obsm[obsm_key]
    # Reads with no finite distance in any window cannot be clustered; drop them.
    valid = ~np.all(np.isnan(X), axis=1)

    X_df = pd.DataFrame(X[valid], index=subset.obs_names[valid])

    if drop_all_nan_windows:
        X_df = X_df.loc[:, ~X_df.isna().all(axis=0)]

    X_df_filled = X_df.copy()
    if fill_nn_with_colmax:
        col_max = X_df_filled.max(axis=0, skipna=True)
        X_df_filled = X_df_filled.fillna(col_max)

    X_df_filled.index = X_df_filled.index.astype(str)

    meta = subset.obs.loc[X_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    # The clustermap is used only to obtain the row order; its figure is discarded.
    g = sns.clustermap(
        X_df_filled,
        cmap="viridis",
        col_cluster=col_cluster,
        row_cluster=True,
        row_colors=row_colors,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
    )
    row_order = g.dendrogram_row.reordered_ind
    ordered_index = X_df_filled.index[row_order]
    plt.close(g.fig)

    X_ord = X_df_filled.loc[ordered_index]

    L = subset.layers[layer_key]
    L = L.toarray() if hasattr(L, "toarray") else np.asarray(L)

    L_df = pd.DataFrame(L[valid], index=subset.obs_names[valid], columns=subset.var_names)
    L_df.index = L_df.index.astype(str)

    if right_panel_var_mask is not None:
        if hasattr(right_panel_var_mask, "values"):
            right_panel_var_mask = right_panel_var_mask.values
        right_panel_var_mask = np.asarray(right_panel_var_mask, dtype=bool)
        # Validate before combining with the NaN-fraction mask: otherwise a
        # mis-sized mask fails with a broadcast error or is silently broadcast.
        if right_panel_var_mask.size != L_df.shape[1]:
            raise ValueError("right_panel_var_mask must align with subset.var_names.")

    if max_nan_fraction is not None:
        nan_fraction = None
        if var_nan_fraction_col and var_nan_fraction_col in subset.var:
            nan_fraction = pd.to_numeric(
                subset.var[var_nan_fraction_col], errors="coerce"
            ).to_numpy()
        elif var_valid_fraction_col and var_valid_fraction_col in subset.var:
            valid_fraction = pd.to_numeric(
                subset.var[var_valid_fraction_col], errors="coerce"
            ).to_numpy()
            nan_fraction = 1 - valid_fraction
        if nan_fraction is not None:
            # Positions whose fraction could not be parsed compare False and are dropped.
            nan_mask = nan_fraction <= max_nan_fraction
            if right_panel_var_mask is None:
                right_panel_var_mask = nan_mask
            else:
                right_panel_var_mask = right_panel_var_mask & nan_mask

    if right_panel_var_mask is not None:
        L_df = L_df.loc[:, right_panel_var_mask]

    read_span_mask = None
    if read_span_layer and read_span_layer in subset.layers:
        span = subset.layers[read_span_layer]
        span = span.toarray() if hasattr(span, "toarray") else np.asarray(span)
        span_df = pd.DataFrame(span[valid], index=subset.obs_names[valid], columns=subset.var_names)
        span_df.index = span_df.index.astype(str)
        if right_panel_var_mask is not None:
            span_df = span_df.loc[:, right_panel_var_mask]
        read_span_mask = span_df.loc[ordered_index].to_numpy() == 0

    L_ord = L_df.loc[ordered_index]
    L_plot = L_ord.fillna(fill_layer_value)
    if read_span_mask is not None:
        # Re-introduce NaN outside each read so the colormap's "bad" color is used there.
        L_plot = L_plot.mask(read_span_mask)

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2,
        4,
        width_ratios=[1, 0.05, 1, 0.05],
        height_ratios=[1, 6],
        wspace=0.2,
        hspace=0.05,
    )

    ax1 = fig.add_subplot(gs[1, 0])
    ax1_cbar = fig.add_subplot(gs[1, 1])
    ax2 = fig.add_subplot(gs[1, 2])
    ax2_cbar = fig.add_subplot(gs[1, 3])
    ax1_bar = fig.add_subplot(gs[0, 0], sharex=ax1)
    ax2_bar = fig.add_subplot(gs[0, 2], sharex=ax2)
    fig.add_subplot(gs[0, 1]).axis("off")
    fig.add_subplot(gs[0, 3]).axis("off")

    mean_nn = np.nanmean(X_ord.to_numpy(), axis=0)
    clean_barplot(
        ax1_bar,
        mean_nn,
        obsm_key,
        y_max=None,
        y_label="Mean distance",
        y_ticks=None,
    )

    sns.heatmap(
        X_ord,
        ax=ax1,
        cmap="viridis",
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax1_cbar,
    )
    # Prefer window centers for tick labels, falling back to window starts.
    label_source = subset.uns.get(f"{obsm_key}_centers")
    if label_source is None:
        label_source = subset.uns.get(f"{obsm_key}_starts")
    if label_source is not None:
        label_source = np.asarray(label_source)
        window_labels = _format_labels(label_source)
        try:
            # X_ord columns are positions into the original window axis; use them to
            # pick the labels of the windows that survived the all-NaN filter.
            col_idx = X_ord.columns.to_numpy()
            if np.issubdtype(col_idx.dtype, np.number):
                col_idx = col_idx.astype(int)
                if col_idx.size and col_idx.max() < len(label_source):
                    window_labels = _format_labels(label_source[col_idx])
        except Exception:
            # Best effort: fall back to the unfiltered labels rather than fail the plot.
            window_labels = _format_labels(label_source)
        _apply_xticks(ax1, window_labels, xtick_step)

    methylation_fraction = _methylation_fraction_for_layer(L_ord.to_numpy(), layer_key)
    clean_barplot(
        ax2_bar,
        methylation_fraction,
        layer_key,
        y_max=1.0,
        y_label="Methylation fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    layer_cmap = plt.get_cmap("coolwarm").copy()
    if read_span_mask is not None:
        layer_cmap.set_bad(outside_read_color)

    sns.heatmap(
        L_plot,
        ax=ax2,
        cmap=layer_cmap,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax2_cbar,
    )
    _apply_xticks(ax2, [str(x) for x in L_plot.columns], xtick_step)

    if title:
        fig.suptitle(title)

    if save_name is not None:
        fname = os.fspath(save_name)
        # Save the figure built here explicitly instead of the implicit current figure.
        fig.savefig(fname, dpi=200, bbox_inches="tight")
        logger.info("Saved rolling NN/layer plot to %s.", fname)
    else:
        plt.show()

    return ordered_index
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def plot_zero_hamming_span_and_layer(
    subset,
    span_layer_key: str,
    layer_key: str = "nan0_0minus1",
    meta_cols: tuple[str, ...] = ("Reference_strand", "Sample"),
    col_cluster: bool = False,
    fill_span_value: float = 0.0,
    fill_layer_value: float = 0.0,
    drop_all_nan_positions: bool = True,
    max_nan_fraction: float | None = None,
    var_valid_fraction_col: str | None = None,
    var_nan_fraction_col: str | None = None,
    read_span_layer: str | None = "read_span_mask",
    outside_read_color: str = "#bdbdbd",
    span_color: str = "#2ca25f",
    figsize: tuple[float, float] = (14, 10),
    robust: bool = True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name: str | None = None,
):
    """
    Plot zero-Hamming span clustermap alongside a layer clustermap.

    Rows are clustered on the span mask and the same row order is applied to the
    layer heatmap. Per-position mean barplots are drawn above each heatmap.

    Args:
        subset: AnnData subset with zero-Hamming span annotations stored in ``layers``.
        span_layer_key: Layer name with the binary zero-Hamming span mask.
        layer_key: Layer name to plot alongside the span mask.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the span mask clustermap.
        fill_span_value: Value to fill NaNs in the span mask.
        fill_layer_value: Value to fill NaNs in the layer heatmap.
        drop_all_nan_positions: Drop positions that are all NaN in the span mask.
        max_nan_fraction: Maximum allowed NaN fraction per position (filtering columns).
        var_valid_fraction_col: ``subset.var`` column with valid fractions (1 - NaN fraction).
        var_nan_fraction_col: ``subset.var`` column with NaN fractions.
        read_span_layer: Layer name with read span mask; 0 values are treated as outside read.
        outside_read_color: Color used to show positions outside each read.
        span_color: Color for zero-Hamming span mask values.
        figsize: Figure size for the combined plot.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels.
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path for saving the plot.

    Returns:
        The read names (as strings) in the clustered row order shared by both heatmaps.

    Raises:
        ValueError: If ``max_nan_fraction`` is outside ``[0, 1]``.
    """
    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
        raise ValueError("max_nan_fraction must be between 0 and 1.")

    logger.info(
        "Plotting zero-Hamming span mask '%s' with layer '%s'.",
        span_layer_key,
        layer_key,
    )

    def _apply_xticks(ax, labels, step):
        """Place every ``step``-th label at the center of its heatmap cell."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            # Default to roughly ten labels across the axis.
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        # Heatmap cells span [i, i + 1), so +0.5 centers the tick on the cell.
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    span = subset.layers[span_layer_key]
    span = span.toarray() if hasattr(span, "toarray") else np.asarray(span)
    span_df = pd.DataFrame(span, index=subset.obs_names, columns=subset.var_names)
    span_df.index = span_df.index.astype(str)

    # Boolean mask with one entry per subset.var row; None when no filter applies.
    nan_mask = None
    if max_nan_fraction is not None:
        nan_fraction = None
        if var_nan_fraction_col and var_nan_fraction_col in subset.var:
            nan_fraction = pd.to_numeric(
                subset.var[var_nan_fraction_col], errors="coerce"
            ).to_numpy()
        elif var_valid_fraction_col and var_valid_fraction_col in subset.var:
            valid_fraction = pd.to_numeric(
                subset.var[var_valid_fraction_col], errors="coerce"
            ).to_numpy()
            nan_fraction = 1 - valid_fraction
        if nan_fraction is not None:
            # Positions whose fraction could not be parsed compare False and are dropped.
            nan_mask = nan_fraction <= max_nan_fraction

    # Apply the var-aligned mask while span_df still has every position; applying it
    # after all-NaN columns are dropped would mismatch lengths and fail in ``.loc``.
    if nan_mask is not None:
        span_df = span_df.loc[:, nan_mask]

    if drop_all_nan_positions:
        span_df = span_df.loc[:, ~span_df.isna().all(axis=0)]

    span_df_filled = span_df.fillna(fill_span_value)

    meta = subset.obs.loc[span_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    # Two-color map: values in [-0.5, 0.5) are white, values in [0.5, 1.5) use span_color.
    span_cmap = colors.ListedColormap(["white", span_color])
    span_norm = colors.BoundaryNorm([-0.5, 0.5, 1.5], span_cmap.N)

    # The clustermap is used only to obtain the row order; its figure is discarded.
    g = sns.clustermap(
        span_df_filled,
        cmap=span_cmap,
        norm=span_norm,
        col_cluster=col_cluster,
        row_cluster=True,
        row_colors=row_colors,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
    )
    row_order = g.dendrogram_row.reordered_ind
    ordered_index = span_df_filled.index[row_order]
    plt.close(g.fig)

    span_ord = span_df_filled.loc[ordered_index]

    layer = subset.layers[layer_key]
    layer = layer.toarray() if hasattr(layer, "toarray") else np.asarray(layer)
    layer_df = pd.DataFrame(layer, index=subset.obs_names, columns=subset.var_names)
    layer_df.index = layer_df.index.astype(str)

    if nan_mask is not None:
        layer_df = layer_df.loc[:, nan_mask]

    read_span_mask = None
    if read_span_layer and read_span_layer in subset.layers:
        span_mask = subset.layers[read_span_layer]
        span_mask = span_mask.toarray() if hasattr(span_mask, "toarray") else np.asarray(span_mask)
        span_mask_df = pd.DataFrame(span_mask, index=subset.obs_names, columns=subset.var_names)
        span_mask_df.index = span_mask_df.index.astype(str)
        if nan_mask is not None:
            span_mask_df = span_mask_df.loc[:, nan_mask]
        read_span_mask = span_mask_df.loc[ordered_index].to_numpy() == 0

    layer_ord = layer_df.loc[ordered_index]
    layer_plot = layer_ord.fillna(fill_layer_value)
    if read_span_mask is not None:
        # Re-introduce NaN outside each read so the colormap's "bad" color is used there.
        layer_plot = layer_plot.mask(read_span_mask)

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(
        2,
        4,
        width_ratios=[1, 0.05, 1, 0.05],
        height_ratios=[1, 6],
        wspace=0.2,
        hspace=0.05,
    )

    ax1 = fig.add_subplot(gs[1, 0])
    ax1_cbar = fig.add_subplot(gs[1, 1])
    ax2 = fig.add_subplot(gs[1, 2])
    ax2_cbar = fig.add_subplot(gs[1, 3])
    ax1_bar = fig.add_subplot(gs[0, 0], sharex=ax1)
    ax2_bar = fig.add_subplot(gs[0, 2], sharex=ax2)
    fig.add_subplot(gs[0, 1]).axis("off")
    fig.add_subplot(gs[0, 3]).axis("off")

    mean_span = np.nanmean(span_ord.to_numpy(), axis=0)
    clean_barplot(
        ax1_bar,
        mean_span,
        span_layer_key,
        y_max=1.0,
        y_label="Span fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    methylation_fraction = _methylation_fraction_for_layer(layer_ord.to_numpy(), layer_key)
    clean_barplot(
        ax2_bar,
        methylation_fraction,
        layer_key,
        y_max=1.0,
        y_label="Methylation fraction",
        y_ticks=[0.0, 0.5, 1.0],
    )

    sns.heatmap(
        span_ord,
        ax=ax1,
        cmap=span_cmap,
        norm=span_norm,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax1_cbar,
    )

    layer_cmap = plt.get_cmap("coolwarm").copy()
    if read_span_mask is not None:
        layer_cmap.set_bad(outside_read_color)

    sns.heatmap(
        layer_plot,
        ax=ax2,
        cmap=layer_cmap,
        xticklabels=False,
        yticklabels=False,
        robust=robust,
        cbar_ax=ax2_cbar,
    )

    _apply_xticks(ax1, [str(x) for x in span_ord.columns], xtick_step)
    _apply_xticks(ax2, [str(x) for x in layer_plot.columns], xtick_step)

    if title:
        fig.suptitle(title)

    if save_name is not None:
        fname = os.fspath(save_name)
        # Save the figure built here explicitly instead of the implicit current figure.
        fig.savefig(fname, dpi=200, bbox_inches="tight")
        logger.info("Saved zero-Hamming span/layer plot to %s.", fname)
    else:
        plt.show()

    return ordered_index
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def _window_center_labels(var_names: Sequence, starts: np.ndarray, window: int) -> list[str]:
|
|
514
|
+
coords = np.asarray(var_names)
|
|
515
|
+
if coords.size == 0:
|
|
516
|
+
return []
|
|
517
|
+
try:
|
|
518
|
+
coords_numeric = coords.astype(float)
|
|
519
|
+
centers = np.array(
|
|
520
|
+
[floor(np.nanmean(coords_numeric[s : s + window])) for s in starts], dtype=float
|
|
521
|
+
)
|
|
522
|
+
return [str(c) for c in centers]
|
|
523
|
+
except Exception:
|
|
524
|
+
mid = np.clip(starts + (window // 2), 0, coords.size - 1)
|
|
525
|
+
return [str(coords[idx]) for idx in mid]
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def plot_zero_hamming_pair_counts(
    subset,
    zero_pairs_uns_key: str,
    meta_cols: tuple[str, ...] = ("Reference_strand", "Sample"),
    col_cluster: bool = False,
    figsize: tuple[float, float] = (14, 10),
    robust: bool = True,
    title: str | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 8,
    save_name: str | None = None,
):
    """
    Plot a heatmap of zero-Hamming pair counts per read across rolling windows.

    For every window, each read's count is the number of times its row index
    appears in that window's list of zero-Hamming pairs.

    Args:
        subset: AnnData subset containing zero-pair window data in ``.uns``.
        zero_pairs_uns_key: Key in ``subset.uns`` with zero-pair window data.
        meta_cols: Obs columns used for row color annotations.
        col_cluster: Whether to cluster columns in the heatmap.
        figsize: Figure size for the plot.
        robust: Use robust color scaling in seaborn.
        title: Optional figure title (suptitle).
        xtick_step: Spacing between x-axis tick labels.
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        save_name: Optional output path for saving the plot.

    Returns:
        The object returned by ``sns.clustermap`` for the count matrix.

    Raises:
        KeyError: If ``zero_pairs_uns_key`` is missing from ``subset.uns``.
        ValueError: If a non-empty window entry is not an array of shape ``(n, 2)``.
    """
    if zero_pairs_uns_key not in subset.uns:
        raise KeyError(f"Missing zero-pair data in subset.uns[{zero_pairs_uns_key!r}].")

    zero_pairs_by_window = subset.uns[zero_pairs_uns_key]
    starts = np.asarray(subset.uns.get(f"{zero_pairs_uns_key}_starts", []))
    window = int(subset.uns.get(f"{zero_pairs_uns_key}_window", 0))

    n_windows = len(zero_pairs_by_window)
    counts = np.zeros((subset.n_obs, n_windows), dtype=int)

    for wi, pairs in enumerate(zero_pairs_by_window):
        if pairs is None or len(pairs) == 0:
            continue
        pair_arr = np.asarray(pairs, dtype=int)
        if pair_arr.size == 0:
            continue
        if pair_arr.ndim != 2 or pair_arr.shape[1] != 2:
            raise ValueError("Zero-pair entries must be arrays of shape (n, 2).")
        # np.add.at accumulates repeated indices, unlike ``counts[idx, wi] += 1``.
        np.add.at(counts[:, wi], pair_arr[:, 0], 1)
        np.add.at(counts[:, wi], pair_arr[:, 1], 1)

    if starts.size == n_windows and window > 0:
        labels = _window_center_labels(subset.var_names, starts, window)
    else:
        # No usable window geometry was stored; label windows by their ordinal.
        labels = [str(i) for i in range(n_windows)]

    counts_df = pd.DataFrame(counts, index=subset.obs_names.astype(str), columns=labels)
    meta = subset.obs.loc[counts_df.index, list(meta_cols)].copy()
    meta.index = meta.index.astype(str)
    row_colors = make_row_colors(meta)

    def _apply_xticks(ax, labels, step):
        """Place every ``step``-th label at the center of its heatmap cell."""
        if labels is None or len(labels) == 0:
            ax.set_xticks([])
            return
        if step is None or step <= 0:
            # Default to roughly ten labels across the axis.
            step = max(1, len(labels) // 10)
        ticks = np.arange(0, len(labels), step)
        # Heatmap cells span [i, i + 1), so +0.5 centers the tick on the cell.
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(
            [labels[i] for i in ticks],
            rotation=xtick_rotation,
            fontsize=xtick_fontsize,
        )

    g = sns.clustermap(
        counts_df,
        cmap="viridis",
        col_cluster=col_cluster,
        row_cluster=True,
        row_colors=row_colors,
        xticklabels=False,
        yticklabels=False,
        figsize=figsize,
        robust=robust,
    )

    # With column clustering the heatmap columns are drawn in dendrogram order, so the
    # labels must be permuted the same way to stay attached to their windows.
    tick_labels = labels
    col_dendrogram = getattr(g, "dendrogram_col", None)
    if col_cluster and col_dendrogram is not None:
        tick_labels = [labels[i] for i in col_dendrogram.reordered_ind]
    _apply_xticks(g.ax_heatmap, tick_labels, xtick_step)

    if title:
        g.fig.suptitle(title)

    if save_name is not None:
        fname = os.fspath(save_name)
        g.fig.savefig(fname, dpi=200, bbox_inches="tight")
        logger.info("Saved zero-Hamming pair count plot to %s.", fname)
    else:
        plt.show()

    return g
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def combined_raw_clustermap(
|
|
629
|
+
adata,
|
|
630
|
+
sample_col: str = "Sample_Names",
|
|
631
|
+
reference_col: str = "Reference_strand",
|
|
632
|
+
mod_target_bases: Sequence[str] = ("GpC", "CpG"),
|
|
633
|
+
layer_c: str = "nan0_0minus1",
|
|
634
|
+
layer_gpc: str = "nan0_0minus1",
|
|
635
|
+
layer_cpg: str = "nan0_0minus1",
|
|
636
|
+
layer_a: str = "nan0_0minus1",
|
|
637
|
+
cmap_c: str = "coolwarm",
|
|
638
|
+
cmap_gpc: str = "coolwarm",
|
|
639
|
+
cmap_cpg: str = "viridis",
|
|
640
|
+
cmap_a: str = "coolwarm",
|
|
641
|
+
min_quality: float | None = 20,
|
|
642
|
+
min_length: int | None = 200,
|
|
643
|
+
min_mapped_length_to_reference_length_ratio: float | None = 0,
|
|
644
|
+
min_position_valid_fraction: float | None = 0,
|
|
645
|
+
demux_types: Sequence[str] = ("single", "double", "already"),
|
|
646
|
+
sample_mapping: Optional[Mapping[str, str]] = None,
|
|
647
|
+
save_path: str | Path | None = None,
|
|
648
|
+
sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
|
|
649
|
+
bins: Optional[Dict[str, Any]] = None,
|
|
650
|
+
deaminase: bool = False,
|
|
651
|
+
min_signal: float = 0,
|
|
652
|
+
n_xticks_any_c: int = 10,
|
|
653
|
+
n_xticks_gpc: int = 10,
|
|
654
|
+
n_xticks_cpg: int = 10,
|
|
655
|
+
n_xticks_any_a: int = 10,
|
|
656
|
+
xtick_rotation: int = 90,
|
|
657
|
+
xtick_fontsize: int = 9,
|
|
658
|
+
index_col_suffix: str | None = None,
|
|
659
|
+
fill_nan_strategy: str = "value",
|
|
660
|
+
fill_nan_value: float = -1,
|
|
661
|
+
):
|
|
662
|
+
"""
|
|
663
|
+
Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
|
|
664
|
+
|
|
665
|
+
Key fixes vs old version:
|
|
666
|
+
- order computed ONCE per bin, applied to all matrices
|
|
667
|
+
- no hard-coded axes indices
|
|
668
|
+
- NaNs excluded from methylation denominators
|
|
669
|
+
- var_names not forced to int
|
|
670
|
+
- fixed count of x tick labels per block (controllable)
|
|
671
|
+
- optional NaN fill strategy for clustering/plotting (in-memory only)
|
|
672
|
+
- adata.uns updated once at end
|
|
673
|
+
|
|
674
|
+
Returns
|
|
675
|
+
-------
|
|
676
|
+
results : list[dict]
|
|
677
|
+
One entry per (sample, ref) plot with matrices + bin metadata.
|
|
678
|
+
"""
|
|
679
|
+
logger.info("Plotting combined raw clustermaps.")
|
|
680
|
+
if fill_nan_strategy not in {"none", "value", "col_mean"}:
|
|
681
|
+
raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
|
|
682
|
+
|
|
683
|
+
def _mask_or_true(series_name: str, predicate):
    """Build a per-read boolean mask from an ``adata.obs`` column, best effort.

    Parameters
    ----------
    series_name
        Name of the ``adata.obs`` column the filter is based on.
    predicate
        Callable that receives ``adata.obs[series_name]`` and returns the mask.

    Returns
    -------
    The value returned by ``predicate``, or an all-True ``pd.Series`` indexed
    like ``adata.obs`` when the column is absent or ``predicate`` raises.
    """
    if series_name not in adata.obs:
        # Column not recorded for this dataset: the filter does not apply.
        return pd.Series(True, index=adata.obs.index)
    try:
        return predicate(adata.obs[series_name])
    except Exception as exc:
        # Stay best-effort (keep every read), but make the skipped filter
        # visible instead of silently ignoring the requested threshold.
        logger.warning(
            "Could not apply filter on adata.obs[%r] (%s); keeping all reads for this filter.",
            series_name,
            exc,
        )
        return pd.Series(True, index=adata.obs.index)
|
|
692
|
+
|
|
693
|
+
results: list[Dict[str, Any]] = []
|
|
694
|
+
save_path = Path(save_path) if save_path is not None else None
|
|
695
|
+
if save_path is not None:
|
|
696
|
+
save_path.mkdir(parents=True, exist_ok=True)
|
|
697
|
+
|
|
698
|
+
for col in (sample_col, reference_col):
|
|
699
|
+
if col not in adata.obs:
|
|
700
|
+
raise KeyError(f"{col} not in adata.obs")
|
|
701
|
+
if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
|
|
702
|
+
adata.obs[col] = adata.obs[col].astype("category")
|
|
703
|
+
|
|
704
|
+
base_set = set(mod_target_bases)
|
|
705
|
+
include_any_c = any(b in {"C", "CpG", "GpC"} for b in base_set)
|
|
706
|
+
include_any_a = "A" in base_set
|
|
707
|
+
|
|
708
|
+
for ref in adata.obs[reference_col].cat.categories:
|
|
709
|
+
for sample in adata.obs[sample_col].cat.categories:
|
|
710
|
+
display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
|
|
711
|
+
|
|
712
|
+
qmask = _mask_or_true(
|
|
713
|
+
"read_quality",
|
|
714
|
+
(lambda s: s >= float(min_quality))
|
|
715
|
+
if (min_quality is not None)
|
|
716
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
717
|
+
)
|
|
718
|
+
lm_mask = _mask_or_true(
|
|
719
|
+
"mapped_length",
|
|
720
|
+
(lambda s: s >= float(min_length))
|
|
721
|
+
if (min_length is not None)
|
|
722
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
723
|
+
)
|
|
724
|
+
lrr_mask = _mask_or_true(
|
|
725
|
+
"mapped_length_to_reference_length_ratio",
|
|
726
|
+
(lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
|
|
727
|
+
if (min_mapped_length_to_reference_length_ratio is not None)
|
|
728
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
729
|
+
)
|
|
730
|
+
|
|
731
|
+
demux_mask = _mask_or_true(
|
|
732
|
+
"demux_type",
|
|
733
|
+
(lambda s: s.astype("string").isin(list(demux_types)))
|
|
734
|
+
if (demux_types is not None)
|
|
735
|
+
else (lambda s: pd.Series(True, index=s.index)),
|
|
736
|
+
)
|
|
737
|
+
|
|
738
|
+
ref_mask = adata.obs[reference_col] == ref
|
|
739
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
740
|
+
|
|
741
|
+
row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
|
|
742
|
+
|
|
743
|
+
if not bool(row_mask.any()):
|
|
744
|
+
logger.warning(
|
|
745
|
+
"No reads for %s - %s after read quality and length filtering.",
|
|
746
|
+
display_sample,
|
|
747
|
+
ref,
|
|
748
|
+
)
|
|
749
|
+
continue
|
|
750
|
+
|
|
751
|
+
try:
|
|
752
|
+
subset = adata[row_mask, :].copy()
|
|
753
|
+
|
|
754
|
+
if min_position_valid_fraction is not None:
|
|
755
|
+
valid_key = f"{ref}_valid_fraction"
|
|
756
|
+
if valid_key in subset.var:
|
|
757
|
+
v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
|
|
758
|
+
col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
|
|
759
|
+
if col_mask.any():
|
|
760
|
+
subset = subset[:, col_mask].copy()
|
|
761
|
+
else:
|
|
762
|
+
logger.warning(
|
|
763
|
+
"No positions left after valid_fraction filter for %s - %s.",
|
|
764
|
+
display_sample,
|
|
765
|
+
ref,
|
|
766
|
+
)
|
|
767
|
+
continue
|
|
768
|
+
|
|
769
|
+
if subset.shape[0] == 0:
|
|
770
|
+
logger.warning(
|
|
771
|
+
"No reads left after filtering for %s - %s.", display_sample, ref
|
|
772
|
+
)
|
|
773
|
+
continue
|
|
774
|
+
|
|
775
|
+
if bins is None:
|
|
776
|
+
bins_temp = {"All": (subset.obs[reference_col] == ref)}
|
|
777
|
+
else:
|
|
778
|
+
bins_temp = bins
|
|
779
|
+
|
|
780
|
+
any_c_sites = gpc_sites = cpg_sites = np.array([], dtype=int)
|
|
781
|
+
any_a_sites = np.array([], dtype=int)
|
|
782
|
+
|
|
783
|
+
num_any_c = num_gpc = num_cpg = num_any_a = 0
|
|
784
|
+
|
|
785
|
+
if include_any_c:
|
|
786
|
+
any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
|
|
787
|
+
gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
|
|
788
|
+
cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
|
|
789
|
+
|
|
790
|
+
num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
|
|
791
|
+
|
|
792
|
+
any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
|
|
793
|
+
gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
|
|
794
|
+
cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
|
|
795
|
+
|
|
796
|
+
if include_any_a:
|
|
797
|
+
any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
|
|
798
|
+
num_any_a = len(any_a_sites)
|
|
799
|
+
any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
|
|
800
|
+
|
|
801
|
+
stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
|
|
802
|
+
stacked_any_c_raw, stacked_gpc_raw, stacked_cpg_raw, stacked_any_a_raw = (
|
|
803
|
+
[],
|
|
804
|
+
[],
|
|
805
|
+
[],
|
|
806
|
+
[],
|
|
807
|
+
)
|
|
808
|
+
row_labels, bin_labels, bin_boundaries = [], [], []
|
|
809
|
+
percentages = {}
|
|
810
|
+
last_idx = 0
|
|
811
|
+
total_reads = subset.shape[0]
|
|
812
|
+
|
|
813
|
+
for bin_label, bin_filter in bins_temp.items():
|
|
814
|
+
subset_bin = subset[bin_filter].copy()
|
|
815
|
+
num_reads = subset_bin.shape[0]
|
|
816
|
+
if num_reads == 0:
|
|
817
|
+
percentages[bin_label] = 0.0
|
|
818
|
+
continue
|
|
819
|
+
|
|
820
|
+
percent_reads = (num_reads / total_reads) * 100
|
|
821
|
+
percentages[bin_label] = percent_reads
|
|
822
|
+
|
|
823
|
+
if sort_by.startswith("obs:"):
|
|
824
|
+
colname = sort_by.split("obs:")[1]
|
|
825
|
+
order = np.argsort(subset_bin.obs[colname].values)
|
|
826
|
+
|
|
827
|
+
elif sort_by == "gpc" and num_gpc > 0:
|
|
828
|
+
gpc_matrix = _layer_to_numpy(
|
|
829
|
+
subset_bin,
|
|
830
|
+
layer_gpc,
|
|
831
|
+
gpc_sites,
|
|
832
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
833
|
+
fill_nan_value=fill_nan_value,
|
|
834
|
+
)
|
|
835
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
836
|
+
order = sch.leaves_list(linkage)
|
|
837
|
+
|
|
838
|
+
elif sort_by == "cpg" and num_cpg > 0:
|
|
839
|
+
cpg_matrix = _layer_to_numpy(
|
|
840
|
+
subset_bin,
|
|
841
|
+
layer_cpg,
|
|
842
|
+
cpg_sites,
|
|
843
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
844
|
+
fill_nan_value=fill_nan_value,
|
|
845
|
+
)
|
|
846
|
+
linkage = sch.linkage(cpg_matrix, method="ward")
|
|
847
|
+
order = sch.leaves_list(linkage)
|
|
848
|
+
|
|
849
|
+
elif sort_by == "c" and num_any_c > 0:
|
|
850
|
+
any_c_matrix = _layer_to_numpy(
|
|
851
|
+
subset_bin,
|
|
852
|
+
layer_c,
|
|
853
|
+
any_c_sites,
|
|
854
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
855
|
+
fill_nan_value=fill_nan_value,
|
|
856
|
+
)
|
|
857
|
+
linkage = sch.linkage(any_c_matrix, method="ward")
|
|
858
|
+
order = sch.leaves_list(linkage)
|
|
859
|
+
|
|
860
|
+
elif sort_by == "gpc_cpg":
|
|
861
|
+
gpc_matrix = _layer_to_numpy(
|
|
862
|
+
subset_bin,
|
|
863
|
+
layer_gpc,
|
|
864
|
+
None,
|
|
865
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
866
|
+
fill_nan_value=fill_nan_value,
|
|
867
|
+
)
|
|
868
|
+
linkage = sch.linkage(gpc_matrix, method="ward")
|
|
869
|
+
order = sch.leaves_list(linkage)
|
|
870
|
+
|
|
871
|
+
elif sort_by == "a" and num_any_a > 0:
|
|
872
|
+
any_a_matrix = _layer_to_numpy(
|
|
873
|
+
subset_bin,
|
|
874
|
+
layer_a,
|
|
875
|
+
any_a_sites,
|
|
876
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
877
|
+
fill_nan_value=fill_nan_value,
|
|
878
|
+
)
|
|
879
|
+
linkage = sch.linkage(any_a_matrix, method="ward")
|
|
880
|
+
order = sch.leaves_list(linkage)
|
|
881
|
+
|
|
882
|
+
elif sort_by == "none":
|
|
883
|
+
order = np.arange(num_reads)
|
|
884
|
+
|
|
885
|
+
else:
|
|
886
|
+
order = np.arange(num_reads)
|
|
887
|
+
|
|
888
|
+
subset_bin = subset_bin[order]
|
|
889
|
+
|
|
890
|
+
if include_any_c and num_any_c > 0:
|
|
891
|
+
stacked_any_c.append(
|
|
892
|
+
_layer_to_numpy(
|
|
893
|
+
subset_bin,
|
|
894
|
+
layer_c,
|
|
895
|
+
any_c_sites,
|
|
896
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
897
|
+
fill_nan_value=fill_nan_value,
|
|
898
|
+
)
|
|
899
|
+
)
|
|
900
|
+
stacked_any_c_raw.append(
|
|
901
|
+
_layer_to_numpy(
|
|
902
|
+
subset_bin,
|
|
903
|
+
layer_c,
|
|
904
|
+
any_c_sites,
|
|
905
|
+
fill_nan_strategy="none",
|
|
906
|
+
fill_nan_value=fill_nan_value,
|
|
907
|
+
)
|
|
908
|
+
)
|
|
909
|
+
if include_any_c and num_gpc > 0:
|
|
910
|
+
stacked_gpc.append(
|
|
911
|
+
_layer_to_numpy(
|
|
912
|
+
subset_bin,
|
|
913
|
+
layer_gpc,
|
|
914
|
+
gpc_sites,
|
|
915
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
916
|
+
fill_nan_value=fill_nan_value,
|
|
917
|
+
)
|
|
918
|
+
)
|
|
919
|
+
stacked_gpc_raw.append(
|
|
920
|
+
_layer_to_numpy(
|
|
921
|
+
subset_bin,
|
|
922
|
+
layer_gpc,
|
|
923
|
+
gpc_sites,
|
|
924
|
+
fill_nan_strategy="none",
|
|
925
|
+
fill_nan_value=fill_nan_value,
|
|
926
|
+
)
|
|
927
|
+
)
|
|
928
|
+
if include_any_c and num_cpg > 0:
|
|
929
|
+
stacked_cpg.append(
|
|
930
|
+
_layer_to_numpy(
|
|
931
|
+
subset_bin,
|
|
932
|
+
layer_cpg,
|
|
933
|
+
cpg_sites,
|
|
934
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
935
|
+
fill_nan_value=fill_nan_value,
|
|
936
|
+
)
|
|
937
|
+
)
|
|
938
|
+
stacked_cpg_raw.append(
|
|
939
|
+
_layer_to_numpy(
|
|
940
|
+
subset_bin,
|
|
941
|
+
layer_cpg,
|
|
942
|
+
cpg_sites,
|
|
943
|
+
fill_nan_strategy="none",
|
|
944
|
+
fill_nan_value=fill_nan_value,
|
|
945
|
+
)
|
|
946
|
+
)
|
|
947
|
+
if include_any_a and num_any_a > 0:
|
|
948
|
+
stacked_any_a.append(
|
|
949
|
+
_layer_to_numpy(
|
|
950
|
+
subset_bin,
|
|
951
|
+
layer_a,
|
|
952
|
+
any_a_sites,
|
|
953
|
+
fill_nan_strategy=fill_nan_strategy,
|
|
954
|
+
fill_nan_value=fill_nan_value,
|
|
955
|
+
)
|
|
956
|
+
)
|
|
957
|
+
stacked_any_a_raw.append(
|
|
958
|
+
_layer_to_numpy(
|
|
959
|
+
subset_bin,
|
|
960
|
+
layer_a,
|
|
961
|
+
any_a_sites,
|
|
962
|
+
fill_nan_strategy="none",
|
|
963
|
+
fill_nan_value=fill_nan_value,
|
|
964
|
+
)
|
|
965
|
+
)
|
|
966
|
+
|
|
967
|
+
row_labels.extend([bin_label] * num_reads)
|
|
968
|
+
bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
|
|
969
|
+
last_idx += num_reads
|
|
970
|
+
bin_boundaries.append(last_idx)
|
|
971
|
+
|
|
972
|
+
blocks = []
|
|
973
|
+
|
|
974
|
+
if include_any_c and stacked_any_c:
|
|
975
|
+
any_c_matrix = np.vstack(stacked_any_c)
|
|
976
|
+
any_c_matrix_raw = np.vstack(stacked_any_c_raw)
|
|
977
|
+
gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
|
|
978
|
+
gpc_matrix_raw = (
|
|
979
|
+
np.vstack(stacked_gpc_raw) if stacked_gpc_raw else np.empty((0, 0))
|
|
980
|
+
)
|
|
981
|
+
cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
|
|
982
|
+
cpg_matrix_raw = (
|
|
983
|
+
np.vstack(stacked_cpg_raw) if stacked_cpg_raw else np.empty((0, 0))
|
|
984
|
+
)
|
|
985
|
+
|
|
986
|
+
mean_any_c = (
|
|
987
|
+
_methylation_fraction_for_layer(any_c_matrix_raw, layer_c)
|
|
988
|
+
if any_c_matrix_raw.size
|
|
989
|
+
else None
|
|
990
|
+
)
|
|
991
|
+
mean_gpc = (
|
|
992
|
+
_methylation_fraction_for_layer(gpc_matrix_raw, layer_gpc)
|
|
993
|
+
if gpc_matrix_raw.size
|
|
994
|
+
else None
|
|
995
|
+
)
|
|
996
|
+
mean_cpg = (
|
|
997
|
+
_methylation_fraction_for_layer(cpg_matrix_raw, layer_cpg)
|
|
998
|
+
if cpg_matrix_raw.size
|
|
999
|
+
else None
|
|
1000
|
+
)
|
|
1001
|
+
|
|
1002
|
+
if any_c_matrix.size:
|
|
1003
|
+
blocks.append(
|
|
1004
|
+
dict(
|
|
1005
|
+
name="c",
|
|
1006
|
+
matrix=any_c_matrix,
|
|
1007
|
+
mean=mean_any_c,
|
|
1008
|
+
labels=any_c_labels,
|
|
1009
|
+
cmap=cmap_c,
|
|
1010
|
+
n_xticks=n_xticks_any_c,
|
|
1011
|
+
title="any C site Modification Signal",
|
|
1012
|
+
)
|
|
1013
|
+
)
|
|
1014
|
+
if gpc_matrix.size:
|
|
1015
|
+
blocks.append(
|
|
1016
|
+
dict(
|
|
1017
|
+
name="gpc",
|
|
1018
|
+
matrix=gpc_matrix,
|
|
1019
|
+
mean=mean_gpc,
|
|
1020
|
+
labels=gpc_labels,
|
|
1021
|
+
cmap=cmap_gpc,
|
|
1022
|
+
n_xticks=n_xticks_gpc,
|
|
1023
|
+
title="GpC Modification Signal",
|
|
1024
|
+
)
|
|
1025
|
+
)
|
|
1026
|
+
if cpg_matrix.size:
|
|
1027
|
+
blocks.append(
|
|
1028
|
+
dict(
|
|
1029
|
+
name="cpg",
|
|
1030
|
+
matrix=cpg_matrix,
|
|
1031
|
+
mean=mean_cpg,
|
|
1032
|
+
labels=cpg_labels,
|
|
1033
|
+
cmap=cmap_cpg,
|
|
1034
|
+
n_xticks=n_xticks_cpg,
|
|
1035
|
+
title="CpG Modification Signal",
|
|
1036
|
+
)
|
|
1037
|
+
)
|
|
1038
|
+
|
|
1039
|
+
if include_any_a and stacked_any_a:
|
|
1040
|
+
any_a_matrix = np.vstack(stacked_any_a)
|
|
1041
|
+
any_a_matrix_raw = np.vstack(stacked_any_a_raw)
|
|
1042
|
+
mean_any_a = (
|
|
1043
|
+
_methylation_fraction_for_layer(any_a_matrix_raw, layer_a)
|
|
1044
|
+
if any_a_matrix_raw.size
|
|
1045
|
+
else None
|
|
1046
|
+
)
|
|
1047
|
+
if any_a_matrix.size:
|
|
1048
|
+
blocks.append(
|
|
1049
|
+
dict(
|
|
1050
|
+
name="a",
|
|
1051
|
+
matrix=any_a_matrix,
|
|
1052
|
+
mean=mean_any_a,
|
|
1053
|
+
labels=any_a_labels,
|
|
1054
|
+
cmap=cmap_a,
|
|
1055
|
+
n_xticks=n_xticks_any_a,
|
|
1056
|
+
title="any A site Modification Signal",
|
|
1057
|
+
)
|
|
1058
|
+
)
|
|
1059
|
+
|
|
1060
|
+
if not blocks:
|
|
1061
|
+
logger.warning("No matrices to plot for %s - %s.", display_sample, ref)
|
|
1062
|
+
continue
|
|
1063
|
+
|
|
1064
|
+
gs_dim = len(blocks)
|
|
1065
|
+
fig = plt.figure(figsize=(5.5 * gs_dim, 11))
|
|
1066
|
+
gs = grid_spec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.02)
|
|
1067
|
+
fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
|
|
1068
|
+
|
|
1069
|
+
axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
|
|
1070
|
+
axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
|
|
1071
|
+
|
|
1072
|
+
for i, blk in enumerate(blocks):
|
|
1073
|
+
mat = blk["matrix"]
|
|
1074
|
+
mean = blk["mean"]
|
|
1075
|
+
labels = np.asarray(blk["labels"], dtype=str)
|
|
1076
|
+
n_xticks = blk["n_xticks"]
|
|
1077
|
+
|
|
1078
|
+
clean_barplot(axes_bar[i], mean, blk["title"])
|
|
1079
|
+
|
|
1080
|
+
sns.heatmap(
|
|
1081
|
+
mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
|
|
1082
|
+
)
|
|
1083
|
+
|
|
1084
|
+
tick_pos = _fixed_tick_positions(len(labels), n_xticks)
|
|
1085
|
+
axes_heat[i].set_xticks(tick_pos)
|
|
1086
|
+
axes_heat[i].set_xticklabels(
|
|
1087
|
+
labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
|
|
1088
|
+
)
|
|
1089
|
+
|
|
1090
|
+
for boundary in bin_boundaries[:-1]:
|
|
1091
|
+
axes_heat[i].axhline(y=boundary, color="black", linewidth=2)
|
|
1092
|
+
|
|
1093
|
+
axes_heat[i].set_xlabel("Position", fontsize=9)
|
|
1094
|
+
|
|
1095
|
+
plt.tight_layout()
|
|
1096
|
+
|
|
1097
|
+
if save_path is not None:
|
|
1098
|
+
safe_name = (
|
|
1099
|
+
f"{ref}__{display_sample}".replace("=", "")
|
|
1100
|
+
.replace("__", "_")
|
|
1101
|
+
.replace(",", "_")
|
|
1102
|
+
.replace(" ", "_")
|
|
1103
|
+
)
|
|
1104
|
+
out_file = save_path / f"{safe_name}.png"
|
|
1105
|
+
fig.savefig(out_file, dpi=300)
|
|
1106
|
+
plt.close(fig)
|
|
1107
|
+
logger.info("Saved combined raw clustermap to %s.", out_file)
|
|
1108
|
+
else:
|
|
1109
|
+
plt.show()
|
|
1110
|
+
|
|
1111
|
+
rec = {
|
|
1112
|
+
"sample": str(sample),
|
|
1113
|
+
"ref": str(ref),
|
|
1114
|
+
"row_labels": row_labels,
|
|
1115
|
+
"bin_labels": bin_labels,
|
|
1116
|
+
"bin_boundaries": bin_boundaries,
|
|
1117
|
+
"percentages": percentages,
|
|
1118
|
+
}
|
|
1119
|
+
for blk in blocks:
|
|
1120
|
+
rec[f"{blk['name']}_matrix"] = blk["matrix"]
|
|
1121
|
+
rec[f"{blk['name']}_labels"] = list(map(str, blk["labels"]))
|
|
1122
|
+
results.append(rec)
|
|
1123
|
+
|
|
1124
|
+
logger.info("Summary for %s - %s:", display_sample, ref)
|
|
1125
|
+
for bin_label, percent in percentages.items():
|
|
1126
|
+
logger.info(" - %s: %.1f%%", bin_label, percent)
|
|
1127
|
+
|
|
1128
|
+
except Exception:
|
|
1129
|
+
import traceback
|
|
1130
|
+
|
|
1131
|
+
traceback.print_exc()
|
|
1132
|
+
continue
|
|
1133
|
+
|
|
1134
|
+
return results
|